Compare commits

..

1 Commits

Author SHA1 Message Date
Sergey Borisov
0ebe2c0ebc Add mask to l2l, MaskEdge and ColorCorrect nodes 2023-07-24 14:25:54 +03:00
337 changed files with 9740 additions and 16281 deletions

View File

@@ -20,13 +20,13 @@ def calc_images_mean_L1(image1_path, image2_path):
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("image1_path")
parser.add_argument("image2_path")
parser.add_argument('image1_path')
parser.add_argument('image2_path')
args = parser.parse_args()
return args
if __name__ == "__main__":
if __name__ == '__main__':
args = parse_args()
mean_L1 = calc_images_mean_L1(args.image1_path, args.image2_path)
print(mean_L1)

View File

@@ -1,2 +1 @@
b3dccfaeb636599c02effc377cdd8a87d658256c
218b6d0546b990fc449c876fb99f44b50c4daa35

View File

@@ -1,11 +1,11 @@
name: Close inactive issues
on:
schedule:
- cron: "00 4 * * *"
- cron: "00 6 * * *"
env:
DAYS_BEFORE_ISSUE_STALE: 30
DAYS_BEFORE_ISSUE_CLOSE: 14
DAYS_BEFORE_ISSUE_STALE: 14
DAYS_BEFORE_ISSUE_CLOSE: 28
jobs:
close-issues:
@@ -14,7 +14,7 @@ jobs:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v8
- uses: actions/stale@v5
with:
days-before-issue-stale: ${{ env.DAYS_BEFORE_ISSUE_STALE }}
days-before-issue-close: ${{ env.DAYS_BEFORE_ISSUE_CLOSE }}
@@ -23,6 +23,5 @@ jobs:
close-issue-message: "Due to inactivity, this issue was automatically closed. If you are still experiencing the issue, please recreate the issue."
days-before-pr-stale: -1
days-before-pr-close: -1
exempt-issue-labels: "Active Issue"
repo-token: ${{ secrets.GITHUB_TOKEN }}
operations-per-run: 500

View File

@@ -2,6 +2,8 @@ name: Lint frontend
on:
pull_request:
paths:
- 'invokeai/frontend/web/**'
types:
- 'ready_for_review'
- 'opened'
@@ -9,6 +11,8 @@ on:
push:
branches:
- 'main'
paths:
- 'invokeai/frontend/web/**'
merge_group:
workflow_dispatch:

View File

@@ -1,27 +0,0 @@
name: Black # TODO: add isort and flake8 later
on:
pull_request: {}
push:
branches: master
tags: "*"
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install dependencies with pip
run: |
pip install --upgrade pip wheel
pip install .[test]
# - run: isort --check-only .
- run: black --check .
# - run: flake8

1
.gitignore vendored
View File

@@ -38,6 +38,7 @@ develop-eggs/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/

View File

@@ -1,10 +0,0 @@
# See https://pre-commit.com/ for usage and config
repos:
- repo: local
hooks:
- id: black
name: black
stages: [commit]
language: system
entry: black
types: [python]

View File

@@ -1,290 +0,0 @@
Copyright (c) 2023 Stability AI
CreativeML Open RAIL++-M License dated July 26, 2023
Section I: PREAMBLE
Multimodal generative models are being widely adopted and used, and
have the potential to transform the way artists, among other
individuals, conceive and benefit from AI or ML technologies as a tool
for content creation.
Notwithstanding the current and potential benefits that these
artifacts can bring to society at large, there are also concerns about
potential misuses of them, either due to their technical limitations
or ethical considerations.
In short, this license strives for both the open and responsible
downstream use of the accompanying model. When it comes to the open
character, we took inspiration from open source permissive licenses
regarding the grant of IP rights. Referring to the downstream
responsible use, we added use-based restrictions not permitting the
use of the model in very specific scenarios, in order for the licensor
to be able to enforce the license in case potential misuses of the
Model may occur. At the same time, we strive to promote open and
responsible research on generative models for art and content
generation.
Even though downstream derivative versions of the model could be
released under different licensing terms, the latter will always have
to include - at minimum - the same use-based restrictions as the ones
in the original license (this license). We believe in the intersection
between open and responsible AI development; thus, this agreement aims
to strike a balance between both in order to enable responsible
open-science in the field of AI.
This CreativeML Open RAIL++-M License governs the use of the model
(and its derivatives) and is informed by the model card associated
with the model.
NOW THEREFORE, You and Licensor agree as follows:
Definitions
"License" means the terms and conditions for use, reproduction, and
Distribution as defined in this document.
"Data" means a collection of information and/or content extracted from
the dataset used with the Model, including to train, pretrain, or
otherwise evaluate the Model. The Data is not licensed under this
License.
"Output" means the results of operating a Model as embodied in
informational content resulting therefrom.
"Model" means any accompanying machine-learning based assemblies
(including checkpoints), consisting of learnt weights, parameters
(including optimizer states), corresponding to the model architecture
as embodied in the Complementary Material, that have been trained or
tuned, in whole or in part on the Data, using the Complementary
Material.
"Derivatives of the Model" means all modifications to the Model, works
based on the Model, or any other model which is created or initialized
by transfer of patterns of the weights, parameters, activations or
output of the Model, to the other model, in order to cause the other
model to perform similarly to the Model, including - but not limited
to - distillation methods entailing the use of intermediate data
representations or methods based on the generation of synthetic data
by the Model for training the other model.
"Complementary Material" means the accompanying source code and
scripts used to define, run, load, benchmark or evaluate the Model,
and used to prepare data for training or evaluation, if any. This
includes any accompanying documentation, tutorials, examples, etc, if
any.
"Distribution" means any transmission, reproduction, publication or
other sharing of the Model or Derivatives of the Model to a third
party, including providing the Model as a hosted service made
available by electronic or other remote means - e.g. API-based or web
access.
"Licensor" means the copyright owner or entity authorized by the
copyright owner that is granting the License, including the persons or
entities that may have rights in the Model and/or distributing the
Model.
"You" (or "Your") means an individual or Legal Entity exercising
permissions granted by this License and/or making use of the Model for
whichever purpose and in any field of use, including usage of the
Model in an end-use application - e.g. chatbot, translator, image
generator.
"Third Parties" means individuals or legal entities that are not under
common control with Licensor or You.
"Contribution" means any work of authorship, including the original
version of the Model and any modifications or additions to that Model
or Derivatives of the Model thereof, that is intentionally submitted
to Licensor for inclusion in the Model by the copyright owner or by an
individual or Legal Entity authorized to submit on behalf of the
copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent to
the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control
systems, and issue tracking systems that are managed by, or on behalf
of, the Licensor for the purpose of discussing and improving the
Model, but excluding communication that is conspicuously marked or
otherwise designated in writing by the copyright owner as "Not a
Contribution."
"Contributor" means Licensor and any individual or Legal Entity on
behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Model.
Section II: INTELLECTUAL PROPERTY RIGHTS
Both copyright and patent grants apply to the Model, Derivatives of
the Model and Complementary Material. The Model and Derivatives of the
Model are subject to additional terms as described in
Section III.
Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare, publicly display, publicly
perform, sublicense, and distribute the Complementary Material, the
Model, and Derivatives of the Model.
Grant of Patent License. Subject to the terms and conditions of this
License and where and as applicable, each Contributor hereby grants to
You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable (except as stated in this paragraph) patent license to
make, have made, use, offer to sell, sell, import, and otherwise
transfer the Model and the Complementary Material, where such license
applies only to those patent claims licensable by such Contributor
that are necessarily infringed by their Contribution(s) alone or by
combination of their Contribution(s) with the Model to which such
Contribution(s) was submitted. If You institute patent litigation
against any entity (including a cross-claim or counterclaim in a
lawsuit) alleging that the Model and/or Complementary Material or a
Contribution incorporated within the Model and/or Complementary
Material constitutes direct or contributory patent infringement, then
any patent licenses granted to You under this License for the Model
and/or Work shall terminate as of the date such litigation is asserted
or filed.
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
Distribution and Redistribution. You may host for Third Party remote
access purposes (e.g. software-as-a-service), reproduce and distribute
copies of the Model or Derivatives of the Model thereof in any medium,
with or without modifications, provided that You meet the following
conditions: Use-based restrictions as referenced in paragraph 5 MUST
be included as an enforceable provision by You in any type of legal
agreement (e.g. a license) governing the use and/or distribution of
the Model or Derivatives of the Model, and You shall give notice to
subsequent users You Distribute to, that the Model or Derivatives of
the Model are subject to paragraph 5. This provision does not apply to
the use of Complementary Material. You must give any Third Party
recipients of the Model or Derivatives of the Model a copy of this
License; You must cause any modified files to carry prominent notices
stating that You changed the files; You must retain all copyright,
patent, trademark, and attribution notices excluding those notices
that do not pertain to any part of the Model, Derivatives of the
Model. You may add Your own copyright statement to Your modifications
and may provide additional or different license terms and conditions -
respecting paragraph 4.a. - for use, reproduction, or Distribution of
Your modifications, or for any such Derivatives of the Model as a
whole, provided Your use, reproduction, and Distribution of the Model
otherwise complies with the conditions stated in this License.
Use-based restrictions. The restrictions set forth in Attachment A are
considered Use-based restrictions. Therefore You cannot use the Model
and the Derivatives of the Model for the specified restricted
uses. You may use the Model subject to this License, including only
for lawful purposes and in accordance with the License. Use may
include creating any content with, finetuning, updating, running,
training, evaluating and/or reparametrizing the Model. You shall
require all of Your users who use the Model or a Derivative of the
Model to comply with the terms of this paragraph (paragraph 5).
The Output You Generate. Except as set forth herein, Licensor claims
no rights in the Output You generate using the Model. You are
accountable for the Output you generate and its subsequent uses. No
use of the output can contravene any provision as stated in the
License.
Section IV: OTHER PROVISIONS
Updates and Runtime Restrictions. To the maximum extent permitted by
law, Licensor reserves the right to restrict (remotely or otherwise)
usage of the Model in violation of this License.
Trademarks and related. Nothing in this License permits You to make
use of Licensors trademarks, trade names, logos or to otherwise
suggest endorsement or misrepresent the relationship between the
parties; and any rights not expressly granted herein are reserved by
the Licensors.
Disclaimer of Warranty. Unless required by applicable law or agreed to
in writing, Licensor provides the Model and the Complementary Material
(and each Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Model, Derivatives of
the Model, and the Complementary Material and assume any risks
associated with Your exercise of permissions under this License.
Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise, unless
required by applicable law (such as deliberate and grossly negligent
acts) or agreed to in writing, shall any Contributor be liable to You
for damages, including any direct, indirect, special, incidental, or
consequential damages of any character arising as a result of this
License or out of the use or inability to use the Model and the
Complementary Material (including but not limited to damages for loss
of goodwill, work stoppage, computer failure or malfunction, or any
and all other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
Accepting Warranty or Additional Liability. While redistributing the
Model, Derivatives of the Model and the Complementary Material
thereof, You may choose to offer, and charge a fee for, acceptance of
support, warranty, indemnity, or other liability obligations and/or
rights consistent with this License. However, in accepting such
obligations, You may act only on Your own behalf and on Your sole
responsibility, not on behalf of any other Contributor, and only if
You agree to indemnify, defend, and hold each Contributor harmless for
any liability incurred by, or claims asserted against, such
Contributor by reason of your accepting any such warranty or
additional liability.
If any provision of this License is held to be invalid, illegal or
unenforceable, the remaining provisions shall be unaffected thereby
and remain valid as if such provision had not been set forth herein.
END OF TERMS AND CONDITIONS
Attachment A
Use Restrictions
You agree not to use the Model or Derivatives of the Model:
* In any way that violates any applicable national, federal, state,
local or international law or regulation;
* For the purpose of exploiting, harming or attempting to exploit or
harm minors in any way;
* To generate or disseminate verifiably false information and/or
content with the purpose of harming others;
* To generate or disseminate personal identifiable information that
can be used to harm an individual;
* To defame, disparage or otherwise harass others;
* For fully automated decision making that adversely impacts an
individuals legal rights or otherwise creates or modifies a
binding, enforceable obligation;
* For any use intended to or which has the effect of discriminating
against or harming individuals or groups based on online or offline
social behavior or known or predicted personal or personality
characteristics;
* To exploit any of the vulnerabilities of a specific group of persons
based on their age, social, physical or mental characteristics, in
order to materially distort the behavior of a person pertaining to
that group in a manner that causes or is likely to cause that person
or another person physical or psychological harm;
* For any use intended to or which has the effect of discriminating
against individuals or groups based on legally protected
characteristics or categories;
* To provide medical advice and medical results interpretation;
* To generate or disseminate information for the purpose to be used
for administration of justice, law enforcement, immigration or
asylum processes, such as predicting an individual will commit
fraud/crime commitment (e.g. by text profiling, drawing causal
relationships between assertions made in documents, indiscriminate
and arbitrarily-targeted use).

View File

@@ -123,7 +123,7 @@ and go to http://localhost:9090.
### Command-Line Installation (for developers and users familiar with Terminals)
You must have Python 3.9 through 3.11 installed on your machine. Earlier or
You must have Python 3.9 or 3.10 installed on your machine. Earlier or
later versions are not supported.
Node.js also needs to be installed along with yarn (can be installed with
the command `npm install -g yarn` if needed)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 490 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 335 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 217 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 244 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 948 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 292 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 420 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 179 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 216 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 439 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 563 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 353 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 131 KiB

View File

@@ -16,7 +16,7 @@ If you don't feel ready to make a code contribution yet, no problem! You can als
There are two paths to making a development contribution:
1. Choosing an open issue to address. Open issues can be found in the [Issues](https://github.com/invoke-ai/InvokeAI/issues?q=is%3Aissue+is%3Aopen) section of the InvokeAI repository. These are tagged by the issue type (bug, enhancement, etc.) along with the “good first issues” tag denoting if they are suitable for first time contributors.
1. Additional items can be found on our [roadmap](https://github.com/orgs/invoke-ai/projects/7). The roadmap is organized in terms of priority, and contains features of varying size and complexity. If there is an inflight item youd like to help with, reach out to the contributor assigned to the item to see how you can help.
1. Additional items can be found on our roadmap <******************************link to roadmap>******************************. The roadmap is organized in terms of priority, and contains features of varying size and complexity. If there is an inflight item youd like to help with, reach out to the contributor assigned to the item to see how you can help.
2. Opening a new issue or feature to add. **Please make sure you have searched through existing issues before creating new ones.**
*Regardless of what you choose, please post in the [#dev-chat](https://discord.com/channels/1020123559063990373/1049495067846524939) channel of the Discord before you start development in order to confirm that the issue or feature is aligned with the current direction of the project. We value our contributors time and effort and want to ensure that no ones time is being misspent.*

View File

@@ -65,6 +65,7 @@ InvokeAI:
esrgan: true
internet_available: true
log_tokenization: false
nsfw_checker: false
patchmatch: true
restore: true
...
@@ -135,16 +136,19 @@ command-line options by giving the `--help` argument:
```
(.venv) > invokeai-web --help
usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials] [--allow_methods [ALLOW_METHODS ...]]
[--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan] [--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization]
[--patchmatch | --no-patchmatch] [--restore | --no-restore]
[--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_loaded_models MAX_LOADED_MODELS] [--max_cache_size MAX_CACHE_SIZE]
[--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--gpu_mem_reserved GPU_MEM_RESERVED] [--precision {auto,float16,float32,autocast}]
[--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled] [--tiled_decode | --no-tiled_decode] [--root ROOT]
[--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR] [--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH]
[--models_dir MODELS_DIR] [--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE]
[--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]] [--log_format {plain,color,syslog,legacy}]
[--log_level {debug,info,warning,error,critical}] [--version | --no-version]
usage: InvokeAI [-h] [--host HOST] [--port PORT] [--allow_origins [ALLOW_ORIGINS ...]] [--allow_credentials | --no-allow_credentials]
[--allow_methods [ALLOW_METHODS ...]] [--allow_headers [ALLOW_HEADERS ...]] [--esrgan | --no-esrgan]
[--internet_available | --no-internet_available] [--log_tokenization | --no-log_tokenization]
[--nsfw_checker | --no-nsfw_checker] [--patchmatch | --no-patchmatch] [--restore | --no-restore]
[--always_use_cpu | --no-always_use_cpu] [--free_gpu_mem | --no-free_gpu_mem] [--max_cache_size MAX_CACHE_SIZE]
[--max_vram_cache_size MAX_VRAM_CACHE_SIZE] [--precision {auto,float16,float32,autocast}]
[--sequential_guidance | --no-sequential_guidance] [--xformers_enabled | --no-xformers_enabled]
[--tiled_decode | --no-tiled_decode] [--root ROOT] [--autoimport_dir AUTOIMPORT_DIR] [--lora_dir LORA_DIR]
[--embedding_dir EMBEDDING_DIR] [--controlnet_dir CONTROLNET_DIR] [--conf_path CONF_PATH] [--models_dir MODELS_DIR]
[--legacy_conf_dir LEGACY_CONF_DIR] [--db_dir DB_DIR] [--outdir OUTDIR] [--from_file FROM_FILE]
[--use_memory_db | --no-use_memory_db] [--model MODEL] [--log_handlers [LOG_HANDLERS ...]]
[--log_format {plain,color,syslog,legacy}] [--log_level {debug,info,warning,error,critical}]
...
```
## The Configuration Settings
@@ -174,6 +178,7 @@ These configuration settings allow you to enable and disable various InvokeAI fe
| `esrgan` | `true` | Activate the ESRGAN upscaling options|
| `internet_available` | `true` | When a resource is not available locally, try to fetch it via the internet |
| `log_tokenization` | `false` | Before each text2image generation, print a color-coded representation of the prompt to the console; this can help understand why a prompt is not working as expected |
| `nsfw_checker` | `true` | Activate the NSFW checker to blur out risque images |
| `patchmatch` | `true` | Activate the "patchmatch" algorithm for improved inpainting |
| `restore` | `true` | Activate the facial restoration features (DEPRECATED; restoration features will be removed in 3.0.0) |

View File

@@ -61,13 +61,11 @@ A noise scheduler (eg. DPM++ 2M Karras) schedules the subtraction of noise from
| ImageInverseLerp | Inverse linear interpolation of all pixels of an image |
| ImageLerp | Linear interpolation of all pixels of an image |
| ImageMultiply | Multiplies two images together using `PIL.ImageChops.Multiply()` |
| ImageNSFWBlurInvocation | Detects and blurs images that may contain sexually explicit content |
| ImagePaste | Pastes an image into another image |
| ImageProcessor | Base class for invocations that reprocess images for ControlNet |
| ImageResize | Resizes an image to specific dimensions |
| ImageScale | Scales an image by a factor |
| ImageToLatents | Scales latents by a given factor |
| ImageWatermarkInvocation | Adds an invisible watermark to images |
| InfillColor | Infills transparent areas of an image with a solid color |
| InfillPatchMatch | Infills transparent areas of an image using the PatchMatch algorithm |
| InfillTile | Infills transparent areas of an image with tiles of the image |
@@ -118,49 +116,49 @@ There are several node grouping concepts that can be examined with a narrow focu
As described, an initial noise tensor is necessary for the latent diffusion process. As a result, all non-image *ToLatents nodes require a noise node input.
![groupsnoise](../assets/nodes/groupsnoise.png)
<img width="654" alt="groupsnoise" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/2e8d297e-ad55-4d27-bc93-c119dad2a2c5">
### Conditioning
As described, conditioning is necessary for the latent diffusion process, whether empty or not. As a result, all non-image *ToLatents nodes require positive and negative conditioning inputs. Conditioning is reliant on a CLIP tokenizer provided by the Model Loader node.
![groupsconditioning](../assets/nodes/groupsconditioning.png)
<img width="1024" alt="groupsconditioning" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/f8f7ad8a-8d9c-418e-b5ad-1437b774b27e">
### Image Space & VAE
The ImageToLatents node doesn't require a noise node input, but requires a VAE input to convert the image from image space into latent space. In reverse, the LatentsToImage node requires a VAE input to convert from latent space back into image space.
![groupsimgvae](../assets/nodes/groupsimgvae.png)
<img width="637" alt="groupsimgvae" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/dd99969c-e0a8-4f78-9b17-3ffe179cef9a">
### Defined & Random Seeds
It is common to want to use both the same seed (for continuity) and random seeds (for variance). To define a seed, simply enter it into the 'Seed' field on a noise node. Conversely, the RandomInt node generates a random integer between 'Low' and 'High', and can be used as input to the 'Seed' edge point on a noise node to randomize your seed.
![groupsrandseed](../assets/nodes/groupsrandseed.png)
<img width="922" alt="groupsrandseed" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/af55bc20-60f6-438e-aba5-3ec871443710">
### Control
Control means to guide the diffusion process to adhere to a defined input or structure. Control can be provided as input to non-image *ToLatents nodes from ControlNet nodes. ControlNet nodes usually require an image processor which converts an input image for use with ControlNet.
![groupscontrol](../assets/nodes/groupscontrol.png)
<img width="805" alt="groupscontrol" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/cc9c5de7-23a7-46c8-bbad-1f3609d999a6">
### LoRA
The Lora Loader node lets you load a LoRA (say that ten times fast) and pass it as output to both the Prompt (Compel) and non-image *ToLatents nodes. A model's CLIP tokenizer is passed through the LoRA into Prompt (Compel), where it affects conditioning. A model's U-Net is also passed through the LoRA into a non-image *ToLatents node, where it affects noise prediction.
![groupslora](../assets/nodes/groupslora.png)
<img width="993" alt="groupslora" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/630962b0-d914-4505-b3ea-ccae9b0269da">
### Scaling
Use the ImageScale, ScaleLatents, and Upscale nodes to upscale images and/or latent images. The chosen method differs across contexts. However, be aware that latents are already noisy and compressed at their original resolution; scaling an image could produce more detailed results.
![groupsallscale](../assets/nodes/groupsallscale.png)
<img width="644" alt="groupsallscale" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/99314f05-dd9f-4b6d-b378-31de55346a13">
### Iteration + Multiple Images as Input
Iteration is a common concept in any processing, and means to repeat a process with given input. In nodes, you're able to use the Iterate node to iterate through collections usually gathered by the Collect node. The Iterate node has many potential uses, from processing a collection of images one after another, to varying seeds across multiple image generations and more. This screenshot demonstrates how to collect several images and pass them out one at a time.
![groupsiterate](../assets/nodes/groupsiterate.png)
<img width="788" alt="groupsiterate" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/4af5ca27-82c9-4018-8c5b-024d3ee0a121">
### Multiple Image Generation + Random Seeds
@@ -168,7 +166,7 @@ Multiple image generation in the node editor is done using the RandomRange node.
To control seeds across generations takes some care. The first row in the screenshot will generate multiple images with different seeds, but using the same RandomRange parameters across invocations will result in the same group of random seeds being used across the images, producing repeatable results. In the second row, adding the RandomInt node as input to RandomRange's 'Seed' edge point will ensure that seeds are varied across all images across invocations, producing varied results.
![groupsmultigenseeding](../assets/nodes/groupsmultigenseeding.png)
<img width="1027" alt="groupsmultigenseeding" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/518d1b2b-fed1-416b-a052-ab06552521b3">
## Examples
@@ -176,7 +174,7 @@ With our knowledge of node grouping and the diffusion process, lets break dow
### Basic text-to-image Node Graph
![nodest2i](../assets/nodes/nodest2i.png)
<img width="875" alt="nodest2i" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/17c67720-c376-4db8-94f0-5e00381a61ee">
- Model Loader: A necessity to generating images (as weve read above). We choose our model from the dropdown. It outputs a U-Net, CLIP tokenizer, and VAE.
- Prompt (Compel): Another necessity. Two prompt nodes are created. One will output positive conditioning (what you want, dog), one will output negative (what you dont want, cat). They both input the CLIP tokenizer that the Model Loader node outputs.
@@ -186,7 +184,7 @@ With our knowledge of node grouping and the diffusion process, lets break dow
### Basic image-to-image Node Graph
![nodesi2i](../assets/nodes/nodesi2i.png)
<img width="998" alt="nodesi2i" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/3f2c95d5-cee7-4415-9b79-b46ee60a92fe">
- Model Loader: Choose a model from the dropdown.
- Prompt (Compel): Two prompt nodes. One positive (dog), one negative (dog). Same CLIP inputs from the Model Loader node as before.
@@ -197,7 +195,7 @@ With our knowledge of node grouping and the diffusion process, lets break dow
### Basic ControlNet Node Graph
![nodescontrol](../assets/nodes/nodescontrol.png)
<img width="703" alt="nodescontrol" src="https://github.com/ymgenesis/InvokeAI/assets/25252829/b02ded86-ceb4-44a2-9910-e19ad184d471">
- Model Loader
- Prompt (Compel)

View File

@@ -1,40 +1,12 @@
---
title: Watermarking, NSFW Image Checking
title: The NSFW Checker
---
# :material-image-off: Invisible Watermark and the NSFW Checker
## Watermarking
InvokeAI does not apply watermarking to images by default. However,
many computer scientists working in the field of generative AI worry
that a flood of computer-generated imagery will contaminate the image
data sets needed to train future generations of generative models.
InvokeAI offers an optional watermarking mode that writes a small bit
of text, **InvokeAI**, into each image that it generates using an
"invisible" watermarking library that spreads the information
throughout the image in a way that is not perceptible to the human
eye. If you are planning to share your generated images on
internet-accessible services, we encourage you to activate the
invisible watermark mode in order to help preserve the digital image
environment.
The downside of watermarking is that it increases the size of the
image moderately, and has been reported by some individuals to degrade
image quality. Your mileage may vary.
To read the watermark in an image, activate the InvokeAI virtual
environment (called the "developer's console" in the launcher) and run
the command:
```
invisible-watermark -a decode -t bytes -m dwtDct -l 64 /path/to/image.png
```
# :material-image-off: NSFW Checker
## The NSFW ("Safety") Checker
Stable Diffusion 1.5-based image generation models will produce sexual
The Stable Diffusion image generation models will produce sexual
imagery if deliberately prompted, and will occasionally produce such
images when this is not intended. Such images are colloquially known
as "Not Safe for Work" (NSFW). This behavior is due to the nature of
@@ -46,17 +18,35 @@ jurisdictions it may be illegal to publicly distribute such imagery,
including mounting a publicly-available server that provides
unfiltered images to the public. Furthermore, the [Stable Diffusion
weights
License](https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE-SD1+SD2.txt),
and the [Stable Diffusion XL
License][https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE-SDXL.txt]
both forbid the models from being used to "exploit any of the
License](https://github.com/invoke-ai/InvokeAI/blob/main/LICENSE-ModelWeights.txt)
forbids the model from being used to "exploit any of the
vulnerabilities of a specific group of persons."
For these reasons Stable Diffusion offers a "safety checker," a
machine learning model trained to recognize potentially disturbing
imagery. When a potentially NSFW image is detected, the checker will
blur the image and paste a warning icon on top. The checker can be
turned on and off in the Web interface under Settings.
turned on and off on the command line using `--nsfw_checker` and
`--no-nsfw_checker`.
At installation time, InvokeAI will ask whether the checker should be
activated by default (neither argument given on the command line). The
response is stored in the InvokeAI initialization file
(`invokeai.yaml` in the InvokeAI root directory). You can change the
default at any time by opening this file in a text editor and
changing the line `nsfw_checker:` from true to false or vice-versa:
```
...
Features:
esrgan: true
internet_available: true
log_tokenization: false
nsfw_checker: true
patchmatch: true
restore: true
```
## Caveats
@@ -94,3 +84,10 @@ are encouraged to turn **off** intermediate image rendering when you
are using the checker. Future versions of InvokeAI will apply
additional blurring to intermediate images when the checker is active.
### Watermarking
InvokeAI does not apply any sort of watermark to images it
generates. However, it does write metadata into the PNG data area,
including the prompt used to generate the image and relevant parameter
settings. These fields can be examined using the `sd-metadata.py`
script that comes with the InvokeAI package.

View File

@@ -16,24 +16,21 @@ Output Example:
---
## **Invisible Watermark**
## **Seamless Tiling**
In keeping with the principles for responsible AI generation, and to
help AI researchers avoid synthetic images contaminating their
training sets, InvokeAI adds an invisible watermark to each of the
final images it generates. The watermark consists of the text
"InvokeAI" and can be viewed using the
[invisible-watermarks](https://github.com/ShieldMnt/invisible-watermark)
tool.
The seamless tiling mode causes generated images to seamlessly tile
with itself creating repetitive wallpaper-like patterns. To use it,
activate the Seamless Tiling option in the Web GUI and then select
whether to tile on the X (horizontal) and/or Y (vertical) axes. Tiling
will then be active for the next set of generations.
Watermarking is controlled using the `invisible-watermark` setting in
`invokeai.yaml`. To turn it off, add the following line under the `Features`
category.
A nice prompt to test seamless tiling with is:
```
invisible_watermark: false
pond garden with lotus by claude monet"
```
---
## **Weighted Prompts**
@@ -42,10 +39,34 @@ priority to them, by adding `:<percent>` to the end of the section you wish to u
example consider this prompt:
```bash
(tabby cat):0.25 (white duck):0.75 hybrid
tabby cat:0.25 white duck:0.75 hybrid
```
This will tell the sampler to invest 25% of its effort on the tabby cat aspect of the image and 75%
on the white duck aspect (surprisingly, this example actually works). The prompt weights can use any
combination of integers and floating point numbers, and they do not need to add up to 1.
## **Thresholding and Perlin Noise Initialization Options**
Under the Noise section of the Web UI, you will find two options named
Perlin Noise and Noise Threshold. [Perlin
noise](https://en.wikipedia.org/wiki/Perlin_noise) is a type of
structured noise used to simulate terrain and other natural
textures. The slider controls the percentage of perlin noise that will
be mixed into the image at the beginning of generation. Adding a little
perlin noise to a generation will alter the image substantially.
The noise threshold limits the range of the latent values during
sampling and helps combat the oversharpening seem with higher CFG
scale values.
For better intuition into what these options do in practice:
![here is a graphic demonstrating them both](../assets/truncation_comparison.jpg)
In generating this graphic, perlin noise at initialization was
programmatically varied going across on the diagram by values 0.0,
0.1, 0.2, 0.4, 0.5, 0.6, 0.8, 0.9, 1.0; and the threshold was varied
going down from 0, 1, 2, 3, 4, 5, 10, 20, 100. The other options are
fixed using the prompt "a portrait of a beautiful young lady" a CFG of
20, 100 steps, and a seed of 1950357039.

View File

@@ -4,9 +4,6 @@ title: Overview
Here you can find the documentation for InvokeAI's various features.
## The [Getting Started Guide](../help/gettingStartedWithAI)
A getting started guide for those new to AI image generation.
## The Basics
### * The [Web User Interface](WEB.md)
Guide to the Web interface. Also see the [WebUI Hotkeys Reference Guide](WEBUIHOTKEYS.md)
@@ -49,7 +46,7 @@ Personalize models by adding your own style or subjects.
## Other Features
### * [The NSFW Checker](WATERMARK+NSFW.md)
### * [The NSFW Checker](NSFW.md)
Prevent InvokeAI from displaying unwanted racy images.
### * [Controlling Logging](LOGGING.md)

View File

@@ -1,95 +0,0 @@
# Getting Started with AI Image Generation
New to image generation with AI? Youre in the right place!
This is a high level walkthrough of some of the concepts and terms youll see as you start using InvokeAI. Please note, this is not an exhaustive guide and may be out of date due to the rapidly changing nature of the space.
## Using InvokeAI
### **Prompt Crafting**
- Prompts are the basis of using InvokeAI, providing the models directions on what to generate. As a general rule of thumb, the more detailed your prompt is, the better your result will be.
*To get started, heres an easy template to use for structuring your prompts:*
- Subject, Style, Quality, Aesthetic
- **Subject:** What your image will be about. E.g. “a futuristic city with trains”, “penguins floating on icebergs”, “friends sharing beers”
- **Style:** The style or medium in which your image will be in. E.g. “photograph”, “pencil sketch”, “oil paints”, or “pop art”, “cubism”, “abstract”
- **Quality:** A particular aspect or trait that you would like to see emphasized in your image. E.g. "award-winning", "featured in {relevant set of high quality works}", "professionally acclaimed". Many people often use "masterpiece".
- **Aesthetics:** The visual impact and design of the artwork. This can be colors, mood, lighting, setting, etc.
- There are two prompt boxes: *Positive Prompt* & *Negative Prompt*.
- A **Positive** Prompt includes words you want the model to reference when creating an image.
- Negative Prompt is for anything you want the model to eliminate when creating an image. It doesnt always interpret things exactly the way you would, but helps control the generation process. Always try to include a few terms - you can typically use lower quality image terms like “blurry” or “distorted” with good success.
- Some examples prompts you can try on your own:
- A detailed oil painting of a tranquil forest at sunset with vibrant+ colors and soft, golden light filtering through the trees
- friends sharing beers in a busy city, realistic colored pencil sketch, twilight, masterpiece, bright, lively
### Generation Workflows
- Invoke offers a number of different workflows for interacting with models to produce images. Each is extremely powerful on its own, but together provide you an unparalleled way of producing high quality creative outputs that align with your vision.
- **Text to Image:** The text to image tab focuses on the key workflow of using a prompt to generate a new image. It includes other features that help control the generation process as well.
- **Image to Image:** With image to image, you provide an image as a reference (called the “initial image”), which provides more guidance around color and structure to the AI as it generates a new image. This is provided alongside the same features as Text to Image.
- **Unified Canvas:** The Unified Canvas is an advanced AI-first image editing tool that is easy to use, but hard to master. Drag an image onto the canvas from your gallery in order to regenerate certain elements, edit content or colors (known as inpainting), or extend the image with an exceptional degree of consistency and clarity (called outpainting).
### Improving Image Quality
- Fine tuning your prompt - the more specific you are, the closer the image will turn out to what is in your head! Adding more details in the Positive Prompt or Negative Prompt can help add / remove pieces of your image to improve it - You can also use advanced techniques like upweighting and downweighting to control the influence of certain words. [Learn more here](https://invoke-ai.github.io/InvokeAI/features/PROMPTS/#prompt-syntax-features).
- **Tip: If youre seeing poor results, try adding the things you dont like about the image to your negative prompt may help. E.g. distorted, low quality, unrealistic, etc.**
- Explore different models - Other models can produce different results due to the data theyve been trained on. Each model has specific language and settings it works best with; a models documentation is your friend here. Play around with some and see what works best for you!
- Increasing Steps - The number of steps used controls how much time the model is given to produce an image, and depends on the “Scheduler” used. The schedule controls how each step is processed by the model. More steps tends to mean better results, but will take longer - We recommend at least 30 steps for most
- Tweak and Iterate - Remember, its best to change one thing at a time so you know what is working and what isn't. Sometimes you just need to try a new image, and other times using a new prompt might be the ticket. For testing, consider turning off the “random” Seed - Using the same seed with the same settings will produce the same image, which makes it the perfect way to learn exactly what your changes are doing.
- Explore Advanced Settings - InvokeAI has a full suite of tools available to allow you complete control over your image creation process - Check out our [docs if you want to learn more](https://invoke-ai.github.io/InvokeAI/features/).
## Terms & Concepts
If you're interested in learning more, check out [this presentation](https://docs.google.com/presentation/d/1IO78i8oEXFTZ5peuHHYkVF-Y3e2M6iM5tCnc-YBfcCM/edit?usp=sharing) from one of our maintainers (@lstein).
### Stable Diffusion
Stable Diffusion is deep learning, text-to-image model that is the foundation of the capabilities found in InvokeAI. Since the release of Stable Diffusion, there have been many subsequent models created based on Stable Diffusion that are designed to generate specific types of images.
### Prompts
Prompts provide the models directions on what to generate. As a general rule of thumb, the more detailed your prompt is, the better your result will be.
### Models
Models are the magic that power InvokeAI. These files represent the output of training a machine on understanding massive amounts of images - providing them with the capability to generate new images using just a text description of what youd like to see. (Like Stable Diffusion!)
Invoke offers a simple way to download several different models upon installation, but many more can be discovered online, including at ****. Each model can produce a unique style of output, based on the images it was trained on - Try out different models to see which best fits your creative vision!
- *Models that contain “inpainting” in the name are designed for use with the inpainting feature of the Unified Canvas*
### Scheduler
Schedulers guide the process of removing noise (de-noising) from data. They determine:
1. The number of steps to take to remove the noise.
2. Whether the steps are random (stochastic) or predictable (deterministic).
3. The specific method (algorithm) used for de-noising.
Experimenting with different schedulers is recommended as each will produce different outputs!
### Steps
The number of de-noising steps each generation through.
Schedulers can be intricate and there's often a balance to strike between how quickly they can de-noise data and how well they can do it. It's typically advised to experiment with different schedulers to see which one gives the best results. There has been a lot written on the internet about different schedulers, as well as exploring what the right level of "steps" are for each. You can save generation time by reducing the number of steps used, but you'll want to make sure that you are satisfied with the quality of images produced!
### Low-Rank Adaptations / LoRAs
Low-Rank Adaptations (LoRAs) are like a smaller, more focused version of models, intended to focus on training a better understanding of how a specific character, style, or concept looks.
### Textual Inversion Embeddings
Textual Inversion Embeddings, like LoRAs, assist with more easily prompting for certain characters, styles, or concepts. However, embeddings are trained to update the relationship between a specific word (known as the “trigger”) and the intended output.
### ControlNet
ControlNets are neural network models that are able to extract key features from an existing image and use these features to guide the output of the image generation model.
### VAE
Variational auto-encoder (VAE) is a encode/decode model that translates the "latents" image produced during the image generation procees to the large pixel images that we see.

View File

@@ -11,33 +11,6 @@ title: Home
```
-->
<!-- CSS styling -->
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/@fortawesome/fontawesome-free@6.2.1/css/fontawesome.min.css">
<style>
.button {
width: 300px;
height: 50px;
background-color: #448AFF;
color: #fff;
font-size: 16px;
border: none;
cursor: pointer;
border-radius: 0.2rem;
}
.button-container {
display: grid;
grid-template-columns: repeat(3, 300px);
gap: 20px;
}
.button:hover {
background-color: #526CFE;
}
</style>
<div align="center" markdown>
@@ -97,23 +70,61 @@ image-to-image generator. It provides a streamlined process with various new
features and options to aid the image generation process. It runs on Windows,
Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
**Quick links**: [<a href="https://discord.gg/ZmtBAhwWhy">Discord Server</a>]
[<a href="https://github.com/invoke-ai/InvokeAI/">Code and Downloads</a>] [<a
href="https://github.com/invoke-ai/InvokeAI/issues">Bug Reports</a>] [<a
href="https://github.com/invoke-ai/InvokeAI/discussions">Discussion, Ideas &
Q&A</a>]
<div align="center"><img src="assets/invoke-web-server-1.png" width=640></div>
!!! Note
!!! note
This project is rapidly evolving. Please use the [Issues tab](https://github.com/invoke-ai/InvokeAI/issues) to report bugs and make feature requests. Be sure to use the provided templates as it will help aid response time.
This fork is rapidly evolving. Please use the [Issues tab](https://github.com/invoke-ai/InvokeAI/issues) to report bugs and make feature requests. Be sure to use the provided templates. They will help aid diagnose issues faster.
## :octicons-link-24: Quick Links
## :octicons-package-dependencies-24: Installation
<div class="button-container">
<a href="installation/INSTALLATION"> <button class="button">Installation</button> </a>
<a href="features/"> <button class="button">Features</button> </a>
<a href="help/gettingStartedWithAI/"> <button class="button">Getting Started</button> </a>
<a href="contributing/CONTRIBUTING/"> <button class="button">Contributing</button> </a>
<a href="https://github.com/invoke-ai/InvokeAI/"> <button class="button">Code and Downloads</button> </a>
<a href="https://github.com/invoke-ai/InvokeAI/issues"> <button class="button">Bug Reports </button> </a>
<a href="https://discord.gg/ZmtBAhwWhy"> <button class="button"> Join the Discord Server!</button> </a>
</div>
This fork is supported across Linux, Windows and Macintosh. Linux users can use
either an Nvidia-based card (with CUDA support) or an AMD card (using the ROCm
driver).
### [Installation Getting Started Guide](installation)
#### **[Automated Installer](installation/010_INSTALL_AUTOMATED.md)**
✅ This is the recommended installation method for first-time users.
#### [Manual Installation](installation/020_INSTALL_MANUAL.md)
This method is recommended for experienced users and developers
#### [Docker Installation](installation/040_INSTALL_DOCKER.md)
This method is recommended for those familiar with running Docker containers
### Other Installation Guides
- [PyPatchMatch](installation/060_INSTALL_PATCHMATCH.md)
- [XFormers](installation/070_INSTALL_XFORMERS.md)
- [CUDA and ROCm Drivers](installation/030_INSTALL_CUDA_AND_ROCM.md)
- [Installing New Models](installation/050_INSTALLING_MODELS.md)
## :fontawesome-solid-computer: Hardware Requirements
### :octicons-cpu-24: System
You wil need one of the following:
- :simple-nvidia: An NVIDIA-based graphics card with 4 GB or more VRAM memory.
- :simple-amd: An AMD-based graphics card with 4 GB or more VRAM memory (Linux
only)
- :fontawesome-brands-apple: An Apple computer with an M1 chip.
We do **not recommend** the following video cards due to issues with their
running in half-precision mode and having insufficient VRAM to render 512x512
images in full-precision mode:
- NVIDIA 10xx series cards such as the 1080ti
- GTX 1650 series cards
- GTX 1660 series cards
### :fontawesome-solid-memory: Memory and Disk
- At least 12 GB Main Memory RAM.
- At least 18 GB of free disk space for the machine learning model, Python, and
all its dependencies.
## :octicons-gift-24: InvokeAI Features
@@ -137,7 +148,7 @@ Mac and Linux machines, and runs on GPU cards with as little as 4 GB of RAM.
- [Model Merging](features/MODEL_MERGING.md)
- [ControlNet Models](features/CONTROLNET.md)
- [Style/Subject Concepts and Embeddings](features/CONCEPTS.md)
- [Watermarking and the Not Safe for Work (NSFW) Checker](features/WATERMARK+NSFW.md)
- [Not Safe for Work (NSFW) Checker](features/NSFW.md)
<!-- seperator -->
### Prompt Engineering
- [Prompt Syntax](features/PROMPTS.md)
@@ -219,7 +230,7 @@ encouraged to do so.
## :octicons-person-24: Contributors
This software is a combined effort of various people from across the world.
This fork is a combined effort of various people from across the world.
[Check out the list of all these amazing people](other/CONTRIBUTORS.md). We
thank them for their time, hard work and effort.

View File

@@ -40,8 +40,10 @@ experimental versions later.
this, open up a command-line window ("Terminal" on Linux and
Macintosh, "Command" or "Powershell" on Windows) and type `python
--version`. If Python is installed, it will print out the version
number. If it is version `3.9.*`, `3.10.*` or `3.11.*` you meet
requirements.
number. If it is version `3.9.*` or `3.10.*`, you meet
requirements. We do not recommend using Python 3.11 or higher,
as not all the libraries that InvokeAI depends on work properly
with this version.
!!! warning "What to do if you have an unsupported version"
@@ -213,6 +215,17 @@ experimental versions later.
Generally the defaults are fine, and you can come back to this screen at
any time to tweak your system. Here are the options you can adjust:
- ***Output directory for images***
This is the path to a directory in which InvokeAI will store all its
generated images.
- ***NSFW checker***
If checked, InvokeAI will test images for potential sexual content
and blur them out if found. Note that the NSFW checker consumes
an additional 0.6 GB of VRAM on top of the 2-3 GB of VRAM used
by most image models. If you have a low VRAM GPU (4-6 GB), you
can reduce out of memory errors by disabling the checker.
- ***HuggingFace Access Token***
InvokeAI has the ability to download embedded styles and subjects
from the HuggingFace Concept Library on-demand. However, some of
@@ -244,30 +257,20 @@ experimental versions later.
and graphics cards. The "autocast" option is deprecated and
shouldn't be used unless you are asked to by a member of the team.
- **Size of the RAM cache used for fast model switching***
- ***Number of models to cache in CPU memory***
This allows you to keep models in memory and switch rapidly among
them rather than having them load from disk each time. This slider
controls how many models to keep loaded at once. A typical SD-1 or SD-2 model
uses 2-3 GB of memory. A typical SDXL model uses 6-7 GB. Providing more
RAM will allow more models to be co-resident.
controls how many models to keep loaded at once. Each
model will use 2-4 GB of RAM, so use this cautiously
- ***Output directory for images***
This is the path to a directory in which InvokeAI will store all its
generated images.
- ***Autoimport Folder***
This is the directory in which you can place models you have
downloaded and wish to load into InvokeAI. You can place a variety
of models in this directory, including diffusers folders, .ckpt files,
.safetensors files, as well as LoRAs, ControlNet and Textual Inversion
files (both folder and file versions). To help organize this folder,
you can create several levels of subfolders and drop your models into
whichever ones you want.
- ***Autoimport FolderLICENSE***
- ***Directory containing embedding/textual inversion files***
This is the directory in which you can place custom embedding
files (.pt or .bin). During startup, this directory will be
scanned and InvokeAI will print out the text terms that
are available to trigger the embeddings.
At the bottom of the screen you will see a checkbox for accepting
the CreativeML Responsible AI Licenses. You need to accept the license
the CreativeML Responsible AI License. You need to accept the license
in order to download Stable Diffusion models from the next screen.
_You can come back to the startup options form_ as many times as you like.
@@ -372,71 +375,8 @@ experimental versions later.
Once InvokeAI is installed, do not move or remove this directory."
<a name="troubleshooting"></a>
## Troubleshooting
### _OSErrors on Windows while installing dependencies_
During a zip file installation or an online update, installation stops
with an error like this:
![broken-dependency-screenshot](../assets/troubleshooting/broken-dependency.png){:width="800px"}
This seems to happen particularly often with the `pydantic` and
`numpy` packages. The most reliable solution requires several manual
steps to complete installation.
Open up a Powershell window and navigate to the `invokeai` directory
created by the installer. Then give the following series of commands:
```cmd
rm .\.venv -r -force
python -mvenv .venv
.\.venv\Scripts\activate
pip install invokeai
invokeai-configure --yes --root .
```
If you see anything marked as an error during this process please stop
and seek help on the Discord [installation support
channel](https://discord.com/channels/1020123559063990373/1041391462190956654). A
few warning messages are OK.
If you are updating from a previous version, this should restore your
system to a working state. If you are installing from scratch, there
is one additional command to give:
```cmd
wget -O invoke.bat https://raw.githubusercontent.com/invoke-ai/InvokeAI/main/installer/templates/invoke.bat.in
```
This will create the `invoke.bat` script needed to launch InvokeAI and
its related programs.
### _Stable Diffusion XL Generation Fails after Trying to Load unet_
InvokeAI is working in other respects, but when trying to generate
images with Stable Diffusion XL you get a "Server Error". The text log
in the launch window contains this log line above several more lines of
error messages:
```INFO --> Loading model:D:\LONG\PATH\TO\MODEL, type sdxl:main:unet```
This failure mode occurs when there is a network glitch during
downloading the very large SDXL model.
To address this, first go to the Web Model Manager and delete the
Stable-Diffusion-XL-base-1.X model. Then navigate to HuggingFace and
manually download the .safetensors version of the model. The 1.0
version is located at
https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main
and the file is named `sd_xl_base_1.0.safetensors`.
Save this file to disk and then reenter the Model Manager. Navigate to
Import Models->Add Model, then type (or drag-and-drop) the path to the
.safetensors file. Press "Add Model".
### _Package dependency conflicts_
If you have previously installed InvokeAI or another Stable Diffusion

View File

@@ -32,7 +32,7 @@ gaming):
* **Python**
version 3.9 through 3.11
version 3.9 or 3.10 (3.11 is not recommended).
* **CUDA Tools**
@@ -65,7 +65,7 @@ gaming):
To install InvokeAI with virtual environments and the PIP package
manager, please follow these steps:
1. Please make sure you are using Python 3.9 through 3.11. The rest of the install
1. Please make sure you are using Python 3.9 or 3.10. The rest of the install
procedure depends on this and will not work with other versions:
```bash

View File

@@ -1,4 +1,6 @@
# Overview
---
title: Overview
---
We offer several ways to install InvokeAI, each one suited to your
experience and preferences. We suggest that everyone start by
@@ -13,56 +15,6 @@ See the [troubleshooting
section](010_INSTALL_AUTOMATED.md#troubleshooting) of the automated
install guide for frequently-encountered installation issues.
This fork is supported across Linux, Windows and Macintosh. Linux users can use
either an Nvidia-based card (with CUDA support) or an AMD card (using the ROCm
driver).
### [Installation Getting Started Guide](installation)
#### **[Automated Installer](010_INSTALL_AUTOMATED.md)**
✅ This is the recommended installation method for first-time users.
#### [Manual Installation](020_INSTALL_MANUAL.md)
This method is recommended for experienced users and developers
#### [Docker Installation](040_INSTALL_DOCKER.md)
This method is recommended for those familiar with running Docker containers
### Other Installation Guides
- [PyPatchMatch](installation/060_INSTALL_PATCHMATCH.md)
- [XFormers](installation/070_INSTALL_XFORMERS.md)
- [CUDA and ROCm Drivers](installation/030_INSTALL_CUDA_AND_ROCM.md)
- [Installing New Models](installation/050_INSTALLING_MODELS.md)
## :fontawesome-solid-computer: Hardware Requirements
### :octicons-cpu-24: System
You wil need one of the following:
- :simple-nvidia: An NVIDIA-based graphics card with 4 GB or more VRAM memory.
- :simple-amd: An AMD-based graphics card with 4 GB or more VRAM memory (Linux
only)
- :fontawesome-brands-apple: An Apple computer with an M1 chip.
** SDXL 1.0 Requirements*
To use SDXL, user must have one of the following:
- :simple-nvidia: An NVIDIA-based graphics card with 8 GB or more VRAM memory.
- :simple-amd: An AMD-based graphics card with 16 GB or more VRAM memory (Linux
only)
- :fontawesome-brands-apple: An Apple computer with an M1 chip.
### :fontawesome-solid-memory: Memory and Disk
- At least 12 GB Main Memory RAM.
- At least 18 GB of free disk space for the machine learning model, Python, and
all its dependencies.
We do **not recommend** the following video cards due to issues with their
running in half-precision mode and having insufficient VRAM to render 512x512
images in full-precision mode:
- NVIDIA 10xx series cards such as the 1080ti
- GTX 1650 series cards
- GTX 1660 series cards
## Installation options
1. [Automated Installer](010_INSTALL_AUTOMATED.md)

View File

@@ -14,28 +14,23 @@ The nodes linked below have been developed and contributed by members of the Inv
## List of Nodes
### FaceTools
### Face Mask
**Description:** FaceTools is a collection of nodes created to manipulate faces as you would in Unified Canvas. It includes FaceMask, FaceOff, and FacePlace. FaceMask autodetects a face in the image using MediaPipe and creates a mask from it. FaceOff similarly detects a face, then takes the face off of the image by adding a square bounding box around it and cropping/scaling it. FacePlace puts the bounded face image from FaceOff back onto the original image. Using these nodes with other inpainting node(s), you can put new faces on existing things, put new things around existing faces, and work closer with a face as a bounded image. Additionally, you can supply X and Y offset values to scale/change the shape of the mask for finer control on FaceMask and FaceOff. See GitHub repository below for usage examples.
**Description:** This node autodetects a face in the image using MediaPipe and masks it by making it transparent. Via outpainting you can swap faces with other faces, or invert the mask and swap things around the face with other things. Additionally, you can supply X and Y offset values to scale/change the shape of the mask for finer control. The node also outputs an all-white mask in the same dimensions as the input image. This is needed by the inpaint node (and unified canvas) for outpainting.
**Node Link:** https://github.com/ymgenesis/FaceTools/
**Node Link:** https://github.com/ymgenesis/InvokeAI/blob/facemaskmediapipe/invokeai/app/invocations/facemask.py
**FaceMask Output Examples**
**Example Node Graph:** https://www.mediafire.com/file/gohn5sb1bfp8use/21-July_2023-FaceMask.json/file
![5cc8abce-53b0-487a-b891-3bf94dcc8960](https://github.com/invoke-ai/InvokeAI/assets/25252829/43f36d24-1429-4ab1-bd06-a4bedfe0955e)
![b920b710-1882-49a0-8d02-82dff2cca907](https://github.com/invoke-ai/InvokeAI/assets/25252829/7660c1ed-bf7d-4d0a-947f-1fc1679557ba)
![71a91805-fda5-481c-b380-264665703133](https://github.com/invoke-ai/InvokeAI/assets/25252829/f8f6a2ee-2b68-4482-87da-b90221d5c3e2)
**Output Examples**
<hr>
### Ideal Size
**Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.
**Node Link:** https://github.com/JPPhoto/ideal-size-node
![2e3168cb-af6a-475d-bfac-c7b7fd67b4c2](https://github.com/ymgenesis/InvokeAI/assets/25252829/a5ad7d44-2ada-4b3c-a56e-a21f8244a1ac)
![2_annotated](https://github.com/ymgenesis/InvokeAI/assets/25252829/53416c8a-a23b-4d76-bb6d-3cfd776e0096)
![2fe2150c-fd08-4e26-8c36-f0610bf441bb](https://github.com/ymgenesis/InvokeAI/assets/25252829/b0f7ecfe-f093-4147-a904-b9f131b41dc9)
![831b6b98-4f0f-4360-93c8-69a9c1338cbe](https://github.com/ymgenesis/InvokeAI/assets/25252829/fc7b0622-e361-4155-8a76-082894d084f0)
--------------------------------
### Example Node Template
### Super Cool Node Template
**Description:** This node allows you to do super cool things with InvokeAI.
@@ -45,9 +40,13 @@ The nodes linked below have been developed and contributed by members of the Inv
**Output Examples**
![Example Image](https://invoke-ai.github.io/InvokeAI/assets/invoke_ai_banner.png){: style="height:115px;width:240px"}
![Invoke AI](https://invoke-ai.github.io/InvokeAI/assets/invoke_ai_banner.png)
### Ideal Size
**Description:** This node calculates an ideal image size for a first pass of a multi-pass upscaling. The aim is to avoid duplication that results from choosing a size larger than the model is capable of.
**Node Link:** https://github.com/JPPhoto/ideal-size-node
## Help
If you run into any issues with a node, please post in the [InvokeAI Discord](https://discord.gg/ZmtBAhwWhy).

25
flake.lock generated
View File

@@ -1,25 +0,0 @@
{
"nodes": {
"nixpkgs": {
"locked": {
"lastModified": 1690630721,
"narHash": "sha256-Y04onHyBQT4Erfr2fc82dbJTfXGYrf4V0ysLUYnPOP8=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "d2b52322f35597c62abf56de91b0236746b2a03d",
"type": "github"
},
"original": {
"id": "nixpkgs",
"type": "indirect"
}
},
"root": {
"inputs": {
"nixpkgs": "nixpkgs"
}
}
},
"root": "root",
"version": 7
}

View File

@@ -1,81 +0,0 @@
# Important note: this flake does not attempt to create a fully isolated, 'pure'
# Python environment for InvokeAI. Instead, it depends on local invocations of
# virtualenv/pip to install the required (binary) packages, most importantly the
# prebuilt binary pytorch packages with CUDA support.
# ML Python packages with CUDA support, like pytorch, are notoriously expensive
# to compile so it's purposefuly not what this flake does.
{
description = "An (impure) flake to develop on InvokeAI.";
outputs = { self, nixpkgs }:
let
system = "x86_64-linux";
pkgs = import nixpkgs {
inherit system;
config.allowUnfree = true;
};
python = pkgs.python310;
mkShell = { dir, install }:
let
setupScript = pkgs.writeScript "setup-invokai" ''
# This must be sourced using 'source', not executed.
${python}/bin/python -m venv ${dir}
${dir}/bin/python -m pip install ${install}
# ${dir}/bin/python -c 'import torch; assert(torch.cuda.is_available())'
source ${dir}/bin/activate
'';
in
pkgs.mkShell rec {
buildInputs = with pkgs; [
# Backend: graphics, CUDA.
cudaPackages.cudnn
cudaPackages.cuda_nvrtc
cudatoolkit
freeglut
glib
gperf
procps
libGL
libGLU
linuxPackages.nvidia_x11
python
stdenv.cc
stdenv.cc.cc.lib
xorg.libX11
xorg.libXext
xorg.libXi
xorg.libXmu
xorg.libXrandr
xorg.libXv
zlib
# Pre-commit hooks.
black
# Frontend.
yarn
nodejs
];
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath buildInputs;
CUDA_PATH = pkgs.cudatoolkit;
EXTRA_LDFLAGS = "-L${pkgs.linuxPackages.nvidia_x11}/lib";
shellHook = ''
if [[ -f "${dir}/bin/activate" ]]; then
source "${dir}/bin/activate"
echo "Using Python: $(which python)"
else
echo "Use 'source ${setupScript}' to set up the environment."
fi
'';
};
in
{
devShells.${system} = rec {
develop = mkShell { dir = "venv"; install = "-e '.[xformers]' --extra-index-url https://download.pytorch.org/whl/cu118"; };
default = develop;
};
};
}

View File

@@ -9,20 +9,16 @@ cd $scriptdir
function version { echo "$@" | awk -F. '{ printf("%d%03d%03d%03d\n", $1,$2,$3,$4); }'; }
MINIMUM_PYTHON_VERSION=3.9.0
MAXIMUM_PYTHON_VERSION=3.11.100
MAXIMUM_PYTHON_VERSION=3.11.0
PYTHON=""
for candidate in python3.11 python3.10 python3.9 python3 python ; do
for candidate in python3.10 python3.9 python3 python ; do
if ppath=`which $candidate`; then
# when using `pyenv`, the executable for an inactive Python version will exist but will not be operational
# we check that this found executable can actually run
if [ $($candidate --version &>/dev/null; echo ${PIPESTATUS}) -gt 0 ]; then continue; fi
python_version=$($ppath -V | awk '{ print $2 }')
if [ $(version $python_version) -ge $(version "$MINIMUM_PYTHON_VERSION") ]; then
if [ $(version $python_version) -le $(version "$MAXIMUM_PYTHON_VERSION") ]; then
PYTHON=$ppath
break
fi
if [ $(version $python_version) -lt $(version "$MAXIMUM_PYTHON_VERSION") ]; then
PYTHON=$ppath
break
fi
fi
fi
done

View File

@@ -13,7 +13,7 @@ from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Union
SUPPORTED_PYTHON = ">=3.9.0,<=3.11.100"
SUPPORTED_PYTHON = ">=3.9.0,<3.11"
INSTALLER_REQS = ["rich", "semver", "requests", "plumbum", "prompt-toolkit"]
BOOTSTRAP_VENV_PREFIX = "invokeai-installer-tmp"
@@ -141,16 +141,15 @@ class Installer:
# upgrade pip in Python 3.9 environments
if int(platform.python_version_tuple()[1]) == 9:
from plumbum import FG, local
pip = local[get_pip_from_venv(venv_dir)]
pip["install", "--upgrade", "pip"] & FG
pip[ "install", "--upgrade", "pip"] & FG
return venv_dir
def install(
self, root: str = "~/invokeai", version: str = "latest", yes_to_all=False, find_links: Path = None
) -> None:
def install(self, root: str = "~/invokeai-3", version: str = "latest", yes_to_all=False, find_links: Path = None) -> None:
"""
Install the InvokeAI application into the given runtime path
@@ -168,8 +167,7 @@ class Installer:
messages.welcome()
default_path = os.environ.get("INVOKEAI_ROOT") or Path(root).expanduser().resolve()
self.dest = default_path if yes_to_all else messages.dest_path(root)
self.dest = Path(root).expanduser().resolve() if yes_to_all else messages.dest_path(root)
# create the venv for the app
self.venv = self.app_venv()
@@ -177,7 +175,7 @@ class Installer:
self.instance = InvokeAiInstance(runtime=self.dest, venv=self.venv, version=version)
# install dependencies and the InvokeAI application
(extra_index_url, optional_modules) = get_torch_source() if not yes_to_all else (None, None)
(extra_index_url,optional_modules) = get_torch_source() if not yes_to_all else (None,None)
self.instance.install(
extra_index_url,
optional_modules,
@@ -190,7 +188,6 @@ class Installer:
# run through the configuration flow
self.instance.configure()
class InvokeAiInstance:
"""
Manages an installed instance of InvokeAI, comprising a virtual environment and a runtime directory.
@@ -199,6 +196,7 @@ class InvokeAiInstance:
"""
def __init__(self, runtime: Path, venv: Path, version: str) -> None:
self.runtime = runtime
self.venv = venv
self.pip = get_pip_from_venv(venv)
@@ -249,9 +247,6 @@ class InvokeAiInstance:
pip[
"install",
"--require-virtualenv",
"numpy~=1.24.0", # choose versions that won't be uninstalled during phase 2
"urllib3~=1.26.0",
"requests~=2.28.0",
"torch~=2.0.0",
"torchmetrics==0.11.4",
"torchvision>=0.14.1",
@@ -317,7 +312,7 @@ class InvokeAiInstance:
"install",
"--require-virtualenv",
"--use-pep517",
str(src) + (optional_modules if optional_modules else ""),
str(src)+(optional_modules if optional_modules else ''),
"--find-links" if find_links is not None else None,
find_links,
"--extra-index-url" if extra_index_url is not None else None,
@@ -334,15 +329,15 @@ class InvokeAiInstance:
# set sys.argv to a consistent state
new_argv = [sys.argv[0]]
for i in range(1, len(sys.argv)):
for i in range(1,len(sys.argv)):
el = sys.argv[i]
if el in ["-r", "--root"]:
if el in ['-r','--root']:
new_argv.append(el)
new_argv.append(sys.argv[i + 1])
elif el in ["-y", "--yes", "--yes-to-all"]:
new_argv.append(sys.argv[i+1])
elif el in ['-y','--yes','--yes-to-all']:
new_argv.append(el)
sys.argv = new_argv
import requests # to catch download exceptions
from messages import introduction
@@ -358,16 +353,16 @@ class InvokeAiInstance:
invokeai_configure()
succeeded = True
except requests.exceptions.ConnectionError as e:
print(f"\nA network error was encountered during configuration and download: {str(e)}")
print(f'\nA network error was encountered during configuration and download: {str(e)}')
except OSError as e:
print(f"\nAn OS error was encountered during configuration and download: {str(e)}")
print(f'\nAn OS error was encountered during configuration and download: {str(e)}')
except Exception as e:
print(f"\nA problem was encountered during the configuration and download steps: {str(e)}")
print(f'\nA problem was encountered during the configuration and download steps: {str(e)}')
finally:
if not succeeded:
print('To try again, find the "invokeai" directory, run the script "invoke.sh" or "invoke.bat"')
print("and choose option 7 to fix a broken install, optionally followed by option 5 to install models.")
print("Alternatively you can relaunch the installer.")
print('and choose option 7 to fix a broken install, optionally followed by option 5 to install models.')
print('Alternatively you can relaunch the installer.')
def install_user_scripts(self):
"""
@@ -376,11 +371,11 @@ class InvokeAiInstance:
ext = "bat" if OS == "Windows" else "sh"
# scripts = ['invoke', 'update']
scripts = ["invoke"]
#scripts = ['invoke', 'update']
scripts = ['invoke']
for script in scripts:
src = Path(__file__).parent / ".." / "templates" / f"{script}.{ext}.in"
src = Path(__file__).parent / '..' / "templates" / f"{script}.{ext}.in"
dest = self.runtime / f"{script}.{ext}"
shutil.copy(src, dest)
os.chmod(dest, 0o0755)
@@ -425,7 +420,11 @@ def set_sys_path(venv_path: Path) -> None:
# filter out any paths in sys.path that may be system- or user-wide
# but leave the temporary bootstrap virtualenv as it contains packages we
# temporarily need at install time
sys.path = list(filter(lambda p: not p.endswith("-packages") or p.find(BOOTSTRAP_VENV_PREFIX) != -1, sys.path))
sys.path = list(filter(
lambda p: not p.endswith("-packages")
or p.find(BOOTSTRAP_VENV_PREFIX) != -1,
sys.path
))
# determine site-packages/lib directory location for the venv
lib = "Lib" if OS == "Windows" else f"lib/python{sys.version_info.major}.{sys.version_info.minor}"
@@ -434,7 +433,7 @@ def set_sys_path(venv_path: Path) -> None:
sys.path.append(str(Path(venv_path, lib, "site-packages").expanduser().resolve()))
def get_torch_source() -> (Union[str, None], str):
def get_torch_source() -> (Union[str, None],str):
"""
Determine the extra index URL for pip to use for torch installation.
This depends on the OS and the graphics accelerator in use.
@@ -455,19 +454,16 @@ def get_torch_source() -> (Union[str, None], str):
device = graphical_accelerator()
url = None
optional_modules = "[onnx]"
optional_modules = None
if OS == "Linux":
if device == "rocm":
url = "https://download.pytorch.org/whl/rocm5.4.2"
elif device == "cpu":
url = "https://download.pytorch.org/whl/cpu"
if device == "cuda":
url = "https://download.pytorch.org/whl/cu117"
optional_modules = "[xformers,onnx-cuda]"
if device == "cuda_and_dml":
url = "https://download.pytorch.org/whl/cu117"
optional_modules = "[xformers,onnx-directml]"
if device == 'cuda':
url = 'https://download.pytorch.org/whl/cu117'
optional_modules = '[xformers]'
# in all other cases, Torch wheels should be coming from PyPi as of Torch 1.13

View File

@@ -3,7 +3,6 @@ InvokeAI Installer
"""
import argparse
import os
from pathlib import Path
from installer import Installer
@@ -16,7 +15,7 @@ if __name__ == "__main__":
dest="root",
type=str,
help="Destination path for installation",
default=os.environ.get("INVOKEAI_ROOT") or "~/invokeai",
default="~/invokeai",
)
parser.add_argument(
"-y",
@@ -42,7 +41,7 @@ if __name__ == "__main__":
type=Path,
default=None,
)
args = parser.parse_args()
inst = Installer()

View File

@@ -36,15 +36,13 @@ else:
def welcome():
@group()
def text():
if (platform_specific := _platform_specific_help()) != "":
yield platform_specific
yield ""
yield Text.from_markup(
"Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.",
justify="center",
)
yield Text.from_markup("Some of the installation steps take a long time to run. Please be patient. If the script appears to hang for more than 10 minutes, please interrupt with [i]Control-C[/] and retry.", justify="center")
console.rule()
print(
@@ -60,7 +58,6 @@ def welcome():
)
console.line()
def confirm_install(dest: Path) -> bool:
if dest.exists():
print(f":exclamation: Directory {dest} already exists :exclamation:")
@@ -95,6 +92,7 @@ def dest_path(dest=None) -> Path:
dest_confirmed = confirm_install(dest)
while not dest_confirmed:
# if the given destination already exists, the starting point for browsing is its parent directory.
# the user may have made a typo, or otherwise wants to place the root dir next to an existing one.
# if the destination dir does NOT exist, then the user must have changed their mind about the selection.
@@ -167,10 +165,6 @@ def graphical_accelerator():
"an [gold1 b]NVIDIA[/] GPU (using CUDA™)",
"cuda",
)
nvidia_with_dml = (
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
"cuda_and_dml",
)
amd = (
"an [gold1 b]AMD[/] GPU (using ROCm™)",
"rocm",
@@ -185,7 +179,7 @@ def graphical_accelerator():
)
if OS == "Windows":
options = [nvidia, nvidia_with_dml, cpu]
options = [nvidia, cpu]
if OS == "Linux":
options = [nvidia, amd, cpu]
elif OS == "Darwin":
@@ -306,20 +300,15 @@ def introduction() -> None:
)
console.line(2)
def _platform_specific_help() -> str:
def _platform_specific_help()->str:
if OS == "Darwin":
text = Text.from_markup(
"""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/]."""
)
text = Text.from_markup("""[b wheat1]macOS Users![/]\n\nPlease be sure you have the [b wheat1]Xcode command-line tools[/] installed before continuing.\nIf not, cancel with [i]Control-C[/] and follow the Xcode install instructions at [deep_sky_blue1]https://www.freecodecamp.org/news/install-xcode-command-line-tools/[/].""")
elif OS == "Windows":
text = Text.from_markup(
"""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
text = Text.from_markup("""[b wheat1]Windows Users![/]\n\nBefore you start, please do the following:
1. Double-click on the file [b wheat1]WinLongPathsEnabled.reg[/] in order to
enable long path support on your system.
2. Make sure you have the [b wheat1]Visual C++ core libraries[/] installed. If not, install from
[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]"""
)
[deep_sky_blue1]https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170[/]""")
else:
text = ""
return text

View File

@@ -41,7 +41,7 @@ IF /I "%choice%" == "1" (
python .venv\Scripts\invokeai-configure.exe --skip-sd-weight --skip-support-models
) ELSE IF /I "%choice%" == "7" (
echo Running invokeai-configure...
python .venv\Scripts\invokeai-configure.exe --yes --skip-sd-weight
python .venv\Scripts\invokeai-configure.exe --yes --default_only
) ELSE IF /I "%choice%" == "8" (
echo Developer Console
echo Python command is:

View File

@@ -82,7 +82,7 @@ do_choice() {
7)
clear
printf "Re-run the configure script to fix a broken install or to complete a major upgrade\n"
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only --skip-sd-weights
invokeai-configure --root ${INVOKEAI_ROOT} --yes --default_only
;;
8)
clear

View File

@@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Optional
from logging import Logger
import os
from invokeai.app.services.board_image_record_storage import (
@@ -55,7 +54,7 @@ logger = InvokeAILogger.getLogger()
class ApiDependencies:
"""Contains and initializes all dependencies for the API"""
invoker: Optional[Invoker] = None
invoker: Invoker = None
@staticmethod
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
@@ -79,7 +78,9 @@ class ApiDependencies:
image_record_storage = SqliteImageRecordStorage(db_location)
image_file_storage = DiskImageFileStorage(f"{output_folder}/images")
names = SimpleNameService()
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
latents = ForwardCacheLatentsStorage(
DiskLatentsStorage(f"{output_folder}/latents")
)
board_record_storage = SqliteBoardRecordStorage(db_location)
board_image_record_storage = SqliteBoardImageRecordStorage(db_location)
@@ -124,7 +125,9 @@ class ApiDependencies:
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
configuration=config,

View File

@@ -1,21 +1,14 @@
import typing
from enum import Enum
from fastapi import Body
from fastapi.routing import APIRouter
from pathlib import Path
from pydantic import BaseModel, Field
from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.version import __version__
from ..dependencies import ApiDependencies
from invokeai.backend.util.logging import logging
class LogLevel(int, Enum):
NotSet = logging.NOTSET
Debug = logging.DEBUG
@@ -23,13 +16,7 @@ class LogLevel(int, Enum):
Warning = logging.WARNING
Error = logging.ERROR
Critical = logging.CRITICAL
class Upscaler(BaseModel):
upscaling_method: str = Field(description="Name of upscaling method")
upscaling_models: list[str] = Field(description="List of upscaling models for this method")
app_router = APIRouter(prefix="/v1/app", tags=["app"])
@@ -43,62 +30,43 @@ class AppConfig(BaseModel):
"""App Config Response"""
infill_methods: list[str] = Field(description="List of available infill methods")
upscaling_methods: list[Upscaler] = Field(description="List of upscaling methods")
nsfw_methods: list[str] = Field(description="List of NSFW checking methods")
watermarking_methods: list[str] = Field(description="List of invisible watermark methods")
@app_router.get("/version", operation_id="app_version", status_code=200, response_model=AppVersion)
@app_router.get(
"/version", operation_id="app_version", status_code=200, response_model=AppVersion
)
async def get_version() -> AppVersion:
return AppVersion(version=__version__)
@app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
@app_router.get(
"/config", operation_id="get_config", status_code=200, response_model=AppConfig
)
async def get_config() -> AppConfig:
infill_methods = ["tile"]
infill_methods = ['tile']
if PatchMatch.patchmatch_available():
infill_methods.append("patchmatch")
upscaling_models = []
for model in typing.get_args(ESRGAN_MODELS):
upscaling_models.append(str(Path(model).stem))
upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)
nsfw_methods = []
if SafetyChecker.safety_checker_available():
nsfw_methods.append("nsfw_checker")
watermarking_methods = []
if InvisibleWatermark.invisible_watermark_available():
watermarking_methods.append("invisible_watermark")
return AppConfig(
infill_methods=infill_methods,
upscaling_methods=[upscaler],
nsfw_methods=nsfw_methods,
watermarking_methods=watermarking_methods,
)
infill_methods.append('patchmatch')
return AppConfig(infill_methods=infill_methods)
@app_router.get(
"/logging",
operation_id="get_log_level",
responses={200: {"description": "The operation was successful"}},
response_model=LogLevel,
responses={200: {"description" : "The operation was successful"}},
response_model = LogLevel,
)
async def get_log_level() -> LogLevel:
async def get_log_level(
) -> LogLevel:
"""Returns the log level"""
return LogLevel(ApiDependencies.invoker.services.logger.level)
@app_router.post(
"/logging",
operation_id="set_log_level",
responses={200: {"description": "The operation was successful"}},
response_model=LogLevel,
responses={200: {"description" : "The operation was successful"}},
response_model = LogLevel,
)
async def set_log_level(
level: LogLevel = Body(description="New log verbosity level"),
level: LogLevel = Body(description="New log verbosity level"),
) -> LogLevel:
"""Sets the log verbosity level"""
ApiDependencies.invoker.services.logger.setLevel(level)

View File

@@ -52,3 +52,4 @@ async def remove_board_image(
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to update board")

View File

@@ -18,7 +18,9 @@ class DeleteBoardResult(BaseModel):
deleted_board_images: list[str] = Field(
description="The image names of the board-images relationships that were deleted."
)
deleted_images: list[str] = Field(description="The names of the images that were deleted.")
deleted_images: list[str] = Field(
description="The names of the images that were deleted."
)
@boards_router.post(
@@ -71,16 +73,22 @@ async def update_board(
) -> BoardDTO:
"""Updates a board"""
try:
result = ApiDependencies.invoker.services.boards.update(board_id=board_id, changes=changes)
result = ApiDependencies.invoker.services.boards.update(
board_id=board_id, changes=changes
)
return result
except Exception as e:
raise HTTPException(status_code=500, detail="Failed to update board")
@boards_router.delete("/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult)
@boards_router.delete(
"/{board_id}", operation_id="delete_board", response_model=DeleteBoardResult
)
async def delete_board(
board_id: str = Path(description="The id of board to delete"),
include_images: Optional[bool] = Query(description="Permanently delete all images on the board", default=False),
include_images: Optional[bool] = Query(
description="Permanently delete all images on the board", default=False
),
) -> DeleteBoardResult:
"""Deletes a board"""
try:
@@ -88,7 +96,9 @@ async def delete_board(
deleted_images = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
board_id=board_id
)
ApiDependencies.invoker.services.images.delete_images_on_board(board_id=board_id)
ApiDependencies.invoker.services.images.delete_images_on_board(
board_id=board_id
)
ApiDependencies.invoker.services.boards.delete(board_id=board_id)
return DeleteBoardResult(
board_id=board_id,
@@ -117,7 +127,9 @@ async def delete_board(
async def list_boards(
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
offset: Optional[int] = Query(default=None, description="The page offset"),
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
limit: Optional[int] = Query(
default=None, description="The number of boards per page"
),
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
"""Gets a list of boards"""
if all:

View File

@@ -40,9 +40,15 @@ async def upload_image(
response: Response,
image_category: ImageCategory = Query(description="The category of the image"),
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
board_id: Optional[str] = Query(
default=None, description="The board to add this image to, if any"
),
session_id: Optional[str] = Query(
default=None, description="The session ID associated with this upload, if any"
),
crop_visible: Optional[bool] = Query(
default=False, description="Whether to crop the image"
),
) -> ImageDTO:
"""Uploads an image"""
if not file.content_type.startswith("image"):
@@ -109,7 +115,9 @@ async def clear_intermediates() -> int:
)
async def update_image(
image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"),
image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image"
),
) -> ImageDTO:
"""Updates an image"""
@@ -204,11 +212,15 @@ async def get_image_thumbnail(
"""Gets a thumbnail image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
path = ApiDependencies.invoker.services.images.get_path(
image_name, thumbnail=True
)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
response = FileResponse(path, media_type="image/webp", content_disposition_type="inline")
response = FileResponse(
path, media_type="image/webp", content_disposition_type="inline"
)
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
return response
except Exception as e:
@@ -227,7 +239,9 @@ async def get_image_urls(
try:
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
thumbnail_url = ApiDependencies.invoker.services.images.get_url(image_name, thumbnail=True)
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_name, thumbnail=True
)
return ImageUrlsDTO(
image_name=image_name,
image_url=image_url,
@@ -243,9 +257,15 @@ async def get_image_urls(
response_model=OffsetPaginatedResults[ImageDTO],
)
async def list_image_dtos(
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
image_origin: Optional[ResourceOrigin] = Query(
default=None, description="The origin of images to list."
),
categories: Optional[list[ImageCategory]] = Query(
default=None, description="The categories of image to include."
),
is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images."
),
board_id: Optional[str] = Query(
default=None,
description="The board id to filter by. Use 'none' to find images without a board.",

View File

@@ -28,52 +28,49 @@ ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@models_router.get(
"/",
operation_id="list_models",
responses={200: {"model": ModelsList}},
responses={200: {"model": ModelsList }},
)
async def list_models(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
"""Gets a list of models"""
if base_models and len(base_models) > 0:
if base_models and len(base_models)>0:
models_raw = list()
for base_model in base_models:
models_raw.extend(ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type))
else:
models_raw = ApiDependencies.invoker.services.model_manager.list_models(None, model_type)
models = parse_obj_as(ModelsList, {"models": models_raw})
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
@models_router.patch(
"/{base_model}/{model_type}/{model_name}",
operation_id="update_model",
responses={
200: {"description": "The model was updated successfully"},
400: {"description": "Bad request"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to the new name"},
},
status_code=200,
response_model=UpdateModelResponse,
responses={200: {"description" : "The model was updated successfully"},
400: {"description" : "Bad request"},
404: {"description" : "The model could not be found"},
409: {"description" : "There is already a model corresponding to the new name"},
},
status_code = 200,
response_model = UpdateModelResponse,
)
async def update_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
""" Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
logger = ApiDependencies.invoker.services.logger
try:
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
@@ -84,13 +81,13 @@ async def update_model(
# rename operation requested
if info.model_name != model_name or info.base_model != base_model:
ApiDependencies.invoker.services.model_manager.rename_model(
base_model=base_model,
model_type=model_type,
model_name=model_name,
new_name=info.model_name,
new_base=info.base_model,
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = info.model_name,
new_base = info.base_model,
)
logger.info(f"Successfully renamed {base_model.value}/{model_name}=>{info.base_model}/{info.model_name}")
logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
# update information to support an update of attributes
model_name = info.model_name
base_model = info.base_model
@@ -99,15 +96,16 @@ async def update_model(
base_model=base_model,
model_type=model_type,
)
if new_info.get("path") != previous_info.get(
"path"
): # model manager moved model path during rename - don't overwrite it
info.path = new_info.get("path")
if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it
info.path = new_info.get('path')
ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict()
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_attributes=info.dict()
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
@@ -125,48 +123,49 @@ async def update_model(
return model_response
@models_router.post(
"/import",
operation_id="import_model",
responses={
201: {"description": "The model imported successfully"},
404: {"description": "The model could not be found"},
415: {"description": "Unrecognized file/folder format"},
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
responses= {
201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"},
415: {"description" : "Unrecognized file/folder format"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse,
response_model=ImportModelResponse
)
async def import_model(
location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
description="Prediction type for SDv2 checkpoint files", default="v_prediction"
),
location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse:
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
items_to_import = {location}
prediction_types = {x.value: x for x in SchedulerPredictionType}
prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger
try:
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
)
info = installed_models.get(location)
if not info:
logger.error("Import failed")
raise HTTPException(status_code=415)
logger.info(f"Successfully imported {location}, got {info}")
logger.info(f'Successfully imported {location}, got {info}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.name, base_model=info.base_model, model_type=info.model_type
model_name=info.name,
base_model=info.base_model,
model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@@ -176,34 +175,38 @@ async def import_model(
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/add",
operation_id="add_model",
responses={
201: {"description": "The model added successfully"},
404: {"description": "The model could not be found"},
424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
responses= {
201: {"description" : "The model added successfully"},
404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse,
response_model=ImportModelResponse
)
async def add_model(
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> ImportModelResponse:
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
logger = ApiDependencies.invoker.services.logger
try:
ApiDependencies.invoker.services.model_manager.add_model(
info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
info.model_name,
info.base_model,
info.model_type,
model_attributes = info.dict()
)
logger.info(f"Successfully added {info.model_name}")
logger.info(f'Successfully added {info.model_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
model_name=info.model_name,
base_model=info.base_model,
model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except ModelNotFoundException as e:
@@ -213,66 +216,66 @@ async def add_model(
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.delete(
"/{base_model}/{model_type}/{model_name}",
operation_id="del_model",
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
status_code=204,
response_model=None,
responses={
204: { "description": "Model deleted successfully" },
404: { "description": "Model not found" }
},
status_code = 204,
response_model = None,
)
async def delete_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
) -> Response:
"""Delete Model"""
logger = ApiDependencies.invoker.services.logger
try:
ApiDependencies.invoker.services.model_manager.del_model(
model_name, base_model=base_model, model_type=model_type
)
ApiDependencies.invoker.services.model_manager.del_model(model_name,
base_model = base_model,
model_type = model_type
)
logger.info(f"Deleted model: {model_name}")
return Response(status_code=204)
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@models_router.put(
"/convert/{base_model}/{model_type}/{model_name}",
operation_id="convert_model",
responses={
200: {"description": "Model converted successfully"},
400: {"description": "Bad request"},
404: {"description": "Model not found"},
200: { "description": "Model converted successfully" },
400: {"description" : "Bad request" },
404: { "description": "Model not found" },
},
status_code=200,
response_model=ConvertModelResponse,
status_code = 200,
response_model = ConvertModelResponse,
)
async def convert_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
convert_dest_directory: Optional[str] = Query(
default=None, description="Save the converted model to the designated directory"
),
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
) -> ConvertModelResponse:
"""Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Converting model: {model_name}")
dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
ApiDependencies.invoker.services.model_manager.convert_model(
model_name,
base_model=base_model,
model_type=model_type,
convert_dest_directory=dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name, base_model=base_model, model_type=model_type
)
ApiDependencies.invoker.services.model_manager.convert_model(model_name,
base_model = base_model,
model_type = model_type,
convert_dest_directory = dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(model_name,
base_model = base_model,
model_type = model_type)
response = parse_obj_as(ConvertModelResponse, model_raw)
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
@@ -280,101 +283,91 @@ async def convert_model(
raise HTTPException(status_code=400, detail=str(e))
return response
@models_router.get(
"/search",
operation_id="search_for_models",
responses={
200: {"description": "Directory searched successfully"},
404: {"description": "Invalid directory path"},
200: { "description": "Directory searched successfully" },
404: { "description": "Invalid directory path" },
},
status_code=200,
response_model=List[pathlib.Path],
status_code = 200,
response_model = List[pathlib.Path]
)
async def search_for_models(
search_path: pathlib.Path = Query(description="Directory path to search for models"),
) -> List[pathlib.Path]:
search_path: pathlib.Path = Query(description="Directory path to search for models")
)->List[pathlib.Path]:
if not search_path.is_dir():
raise HTTPException(
status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory"
)
return ApiDependencies.invoker.services.model_manager.search_for_models(search_path)
raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
@models_router.get(
"/ckpt_confs",
operation_id="list_ckpt_configs",
responses={
200: {"description": "paths retrieved successfully"},
200: { "description" : "paths retrieved successfully" },
},
status_code=200,
response_model=List[pathlib.Path],
status_code = 200,
response_model = List[pathlib.Path]
)
async def list_ckpt_configs() -> List[pathlib.Path]:
async def list_ckpt_configs(
)->List[pathlib.Path]:
"""Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
return ApiDependencies.invoker.services.model_manager.list_checkpoint_configs()
@models_router.post(
"/sync",
operation_id="sync_to_config",
responses={
201: {"description": "synchronization successful"},
201: { "description": "synchronization successful" },
},
status_code=201,
response_model=bool,
status_code = 201,
response_model = bool
)
async def sync_to_config() -> bool:
async def sync_to_config(
)->bool:
"""Call after making changes to models.yaml, autoimport directories or models directory to synchronize
in-memory data structures with disk data structures."""
ApiDependencies.invoker.services.model_manager.sync_to_config()
return True
@models_router.put(
"/merge/{base_model}",
operation_id="merge_models",
responses={
200: {"description": "Model converted successfully"},
400: {"description": "Incompatible models"},
404: {"description": "One or more models not found"},
200: { "description": "Model converted successfully" },
400: { "description": "Incompatible models" },
404: { "description": "One or more models not found" },
},
status_code=200,
response_model=MergeModelResponse,
status_code = 200,
response_model = MergeModelResponse,
)
async def merge_models(
base_model: BaseModelType = Path(description="Base model"),
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description="Name of destination model"),
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
force: Optional[bool] = Body(
description="Force merging of models created with different versions of diffusers", default=False
),
merge_dest_directory: Optional[str] = Body(
description="Save the merged model to the designated directory (with 'merged_model_name' appended)",
default=None,
),
base_model: BaseModelType = Path(description="Base model"),
model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
merged_model_name: Optional[str] = Body(description="Name of destination model"),
alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
) -> MergeModelResponse:
"""Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{merged_model_name}")
dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
result = ApiDependencies.invoker.services.model_manager.merge_models(
model_names,
base_model,
merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory=dest,
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
result.name,
base_model=base_model,
model_type=ModelType.Main,
)
result = ApiDependencies.invoker.services.model_manager.merge_models(model_names,
base_model,
merged_model_name=merged_model_name or "+".join(model_names),
alpha=alpha,
interp=interp,
force=force,
merge_dest_directory = dest
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(result.name,
base_model = base_model,
model_type = ModelType.Main,
)
response = parse_obj_as(ConvertModelResponse, model_raw)
except ModelNotFoundException:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")

View File

@@ -30,7 +30,9 @@ session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
},
)
async def create_session(
graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with")
graph: Optional[Graph] = Body(
default=None, description="The graph to initialize the session with"
)
) -> GraphExecutionState:
"""Creates a new session, optionally initializing it with an invocation graph"""
session = ApiDependencies.invoker.create_execution_state(graph)
@@ -49,9 +51,13 @@ async def list_sessions(
) -> PaginatedResults[GraphExecutionState]:
"""Gets a list of sessions, optionally searching"""
if query == "":
result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
result = ApiDependencies.invoker.services.graph_execution_manager.list(
page, per_page
)
else:
result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
result = ApiDependencies.invoker.services.graph_execution_manager.search(
query, page, per_page
)
return result
@@ -85,9 +91,9 @@ async def get_session(
)
async def add_node(
session_id: str = Path(description="The id of the session"),
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
description="The node to add"
),
node: Annotated[
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
] = Body(description="The node to add"),
) -> str:
"""Adds a node to the graph"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
@@ -118,9 +124,9 @@ async def add_node(
async def update_node(
session_id: str = Path(description="The id of the session"),
node_path: str = Path(description="The path to the node in the graph"),
node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
description="The new node"
),
node: Annotated[
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
] = Body(description="The new node"),
) -> GraphExecutionState:
"""Updates a node in the graph and removes all linked edges"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
@@ -224,7 +230,7 @@ async def delete_edge(
try:
edge = Edge(
source=EdgeConnection(node_id=from_node_id, field=from_field),
destination=EdgeConnection(node_id=to_node_id, field=to_field),
destination=EdgeConnection(node_id=to_node_id, field=to_field)
)
session.delete_edge(edge)
ApiDependencies.invoker.services.graph_execution_manager.set(
@@ -249,7 +255,9 @@ async def delete_edge(
)
async def invoke_session(
session_id: str = Path(description="The id of the session to invoke"),
all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
all: bool = Query(
default=False, description="Whether or not to invoke all remaining invocations"
),
) -> Response:
"""Invokes a session"""
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
@@ -266,7 +274,9 @@ async def invoke_session(
@session_router.delete(
"/{session_id}/invoke",
operation_id="cancel_session_invoke",
responses={202: {"description": "The invocation is canceled"}},
responses={
202: {"description": "The invocation is canceled"}
},
)
async def cancel_session_invoke(
session_id: str = Path(description="The id of the session to cancel"),

View File

@@ -16,7 +16,9 @@ class SocketIO:
self.__sio.on("subscribe", handler=self._handle_sub)
self.__sio.on("unsubscribe", handler=self._handle_unsub)
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
local_handler.register(
event_name=EventServiceBase.session_event, _func=self._handle_session_event
)
async def _handle_session_event(self, event: Event):
await self.__sio.emit(

View File

@@ -3,7 +3,6 @@ import asyncio
import sys
from inspect import signature
import logging
import uvicorn
import socket
@@ -17,10 +16,9 @@ from fastapi_events.middleware import EventHandlerASGIMiddleware
from pathlib import Path
from pydantic.schema import schema
# This should come early so that modules can log their initialization properly
#This should come early so that modules can log their initialization properly
from .services.config import InvokeAIAppConfig
from ..backend.util.logging import InvokeAILogger
app_config = InvokeAIAppConfig.get_config()
app_config.parse_args()
logger = InvokeAILogger.getLogger(config=app_config)
@@ -29,7 +27,7 @@ from invokeai.version.invokeai_version import __version__
# we call this early so that the message appears before
# other invokeai initialization messages
if app_config.version:
print(f"InvokeAI version {__version__}")
print(f'InvokeAI version {__version__}')
sys.exit(0)
import invokeai.frontend.web as web_dir
@@ -39,18 +37,17 @@ from .api.dependencies import ApiDependencies
from .api.routers import sessions, models, images, boards, board_images, app_info
from .api.sockets import SocketIO
from .invocations.baseinvocation import BaseInvocation
import torch
import invokeai.backend.util.hotfixes
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes
# fix for windows mimetypes registry entries being borked
# see https://github.com/invoke-ai/InvokeAI/discussions/3684#discussioncomment-6391352
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")
mimetypes.add_type('application/javascript', '.js')
mimetypes.add_type('text/css', '.css')
# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
@@ -60,13 +57,14 @@ app = FastAPI(title="Invoke AI", docs_url=None, redoc_url=None)
event_handler_id: int = id(app)
app.add_middleware(
EventHandlerASGIMiddleware,
handlers=[local_handler], # TODO: consider doing this in services to support different configurations
handlers=[
local_handler
], # TODO: consider doing this in services to support different configurations
middleware_id=event_handler_id,
)
socket_io = SocketIO(app)
# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event():
@@ -78,7 +76,9 @@ async def startup_event():
allow_headers=app_config.allow_headers,
)
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
ApiDependencies.initialize(
config=app_config, event_handler_id=event_handler_id, logger=logger
)
# Shut down threads
@@ -103,8 +103,7 @@ app.include_router(boards.boards_router, prefix="/api")
app.include_router(board_images.board_images_router, prefix="/api")
app.include_router(app_info.app_router, prefix="/api")
app.include_router(app_info.app_router, prefix='/api')
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
@@ -145,7 +144,6 @@ def custom_openapi():
invoker_schema["output"] = outputs_ref
from invokeai.backend.model_management.models import get_model_config_enums
for model_config_format_enum in set(get_model_config_enums()):
name = model_config_format_enum.__qualname__
@@ -168,8 +166,7 @@ def custom_openapi():
app.openapi = custom_openapi
# Override API doc favicons
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], "static/dream_web")), name="static")
app.mount("/static", StaticFiles(directory=Path(web_dir.__path__[0], 'static/dream_web')), name="static")
@app.get("/docs", include_in_schema=False)
def overridden_swagger():
@@ -190,8 +187,11 @@ def overridden_redoc():
# Must mount *after* the other routes else it borks em
app.mount("/", StaticFiles(directory=Path(web_dir.__path__[0], "dist"), html=True), name="ui")
app.mount("/",
StaticFiles(directory=Path(web_dir.__path__[0],"dist"),
html=True
), name="ui"
)
def invoke_api():
def find_port(port: int):
@@ -204,34 +204,15 @@ def invoke_api():
else:
return port
from invokeai.backend.install.check_root import check_invokeai_root
check_invokeai_root(app_config) # note, may exit with an exception if root not set up
port = find_port(app_config.port)
if port != app_config.port:
logger.warn(f"Port {app_config.port} in use, using port {port}")
# Start our own event loop for eventing usage
loop = asyncio.new_event_loop()
config = uvicorn.Config(
app=app,
host=app_config.host,
port=port,
loop=loop,
log_level=app_config.log_level,
)
config = uvicorn.Config(app=app, host=app_config.host, port=port, loop=loop)
# Use access_log to turn off logging
server = uvicorn.Server(config)
# replace uvicorn's loggers with InvokeAI's for consistent appearance
for logname in ["uvicorn.access", "uvicorn"]:
l = logging.getLogger(logname)
l.handlers.clear()
for ch in logger.handlers:
l.addHandler(ch)
loop.run_until_complete(server.serve())
if __name__ == "__main__":
invoke_api()

View File

@@ -14,14 +14,8 @@ from ..services.graph import GraphExecutionState, LibraryGraph, Edge
from ..services.invoker import Invoker
def add_field_argument(command_parser, name: str, field, default_override=None):
default = (
default_override
if default_override is not None
else field.default
if field.default_factory is None
else field.default_factory()
)
def add_field_argument(command_parser, name: str, field, default_override = None):
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
if get_origin(field.type_) == Literal:
allowed_values = get_args(field.type_)
allowed_types = set()
@@ -53,8 +47,8 @@ def add_parsers(
commands: list[type],
command_field: str = "type",
exclude_fields: list[str] = ["id", "type"],
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None,
):
add_arguments: Union[Callable[[argparse.ArgumentParser], None],None] = None
):
"""Adds parsers for each command to the subparsers"""
# Create subparsers for each command
@@ -67,7 +61,7 @@ def add_parsers(
add_arguments(command_parser)
# Convert all fields to arguments
fields = command.__fields__ # type: ignore
fields = command.__fields__ # type: ignore
for name, field in fields.items():
if name in exclude_fields:
continue
@@ -76,11 +70,13 @@ def add_parsers(
def add_graph_parsers(
subparsers, graphs: list[LibraryGraph], add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
subparsers,
graphs: list[LibraryGraph],
add_arguments: Union[Callable[[argparse.ArgumentParser], None], None] = None
):
for graph in graphs:
command_parser = subparsers.add_parser(graph.name, help=graph.description)
if add_arguments is not None:
add_arguments(command_parser)
@@ -132,7 +128,6 @@ class CliContext:
class ExitCli(Exception):
"""Exception to exit the CLI"""
pass
@@ -160,7 +155,7 @@ class BaseCommand(ABC, BaseModel):
@classmethod
def get_commands_map(cls):
# Get the type strings out of the literals and into a dictionary
return dict(map(lambda t: (get_args(get_type_hints(t)["type"])[0], t), BaseCommand.get_all_subclasses()))
return dict(map(lambda t: (get_args(get_type_hints(t)['type'])[0], t),BaseCommand.get_all_subclasses()))
@abstractmethod
def run(self, context: CliContext) -> None:
@@ -170,8 +165,7 @@ class BaseCommand(ABC, BaseModel):
class ExitCommand(BaseCommand):
"""Exits the CLI"""
type: Literal["exit"] = "exit"
type: Literal['exit'] = 'exit'
def run(self, context: CliContext) -> None:
raise ExitCli()
@@ -179,8 +173,7 @@ class ExitCommand(BaseCommand):
class HelpCommand(BaseCommand):
"""Shows help"""
type: Literal["help"] = "help"
type: Literal['help'] = 'help'
def run(self, context: CliContext) -> None:
context.parser.print_help()
@@ -190,7 +183,11 @@ def get_graph_execution_history(
graph_execution_state: GraphExecutionState,
) -> Iterable[str]:
"""Gets the history of fully-executed invocations for a graph execution"""
return (n for n in reversed(graph_execution_state.executed_history) if n in graph_execution_state.graph.nodes)
return (
n
for n in reversed(graph_execution_state.executed_history)
if n in graph_execution_state.graph.nodes
)
def get_invocation_command(invocation) -> str:
@@ -221,8 +218,7 @@ def get_invocation_command(invocation) -> str:
class HistoryCommand(BaseCommand):
"""Shows the invocation history"""
type: Literal["history"] = "history"
type: Literal['history'] = 'history'
# Inputs
# fmt: off
@@ -239,8 +235,7 @@ class HistoryCommand(BaseCommand):
class SetDefaultCommand(BaseCommand):
"""Sets a default value for a field"""
type: Literal["default"] = "default"
type: Literal['default'] = 'default'
# Inputs
# fmt: off
@@ -258,8 +253,7 @@ class SetDefaultCommand(BaseCommand):
class DrawGraphCommand(BaseCommand):
"""Debugs a graph"""
type: Literal["draw_graph"] = "draw_graph"
type: Literal['draw_graph'] = 'draw_graph'
def run(self, context: CliContext) -> None:
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
@@ -277,8 +271,7 @@ class DrawGraphCommand(BaseCommand):
class DrawExecutionGraphCommand(BaseCommand):
"""Debugs an execution graph"""
type: Literal["draw_xgraph"] = "draw_xgraph"
type: Literal['draw_xgraph'] = 'draw_xgraph'
def run(self, context: CliContext) -> None:
session: GraphExecutionState = context.invoker.services.graph_execution_manager.get(context.session.id)
@@ -293,7 +286,6 @@ class DrawExecutionGraphCommand(BaseCommand):
plt.axis("off")
plt.show()
class SortedHelpFormatter(argparse.HelpFormatter):
def _iter_indented_subactions(self, action):
try:

View File

@@ -19,8 +19,8 @@ from ..services.invocation_services import InvocationServices
# singleton object, class variable
completer = None
class Completer(object):
def __init__(self, model_manager: ModelManager):
self.commands = self.get_commands()
self.matches = None
@@ -43,7 +43,7 @@ class Completer(object):
except IndexError:
pass
options = options or list(self.parse_commands().keys())
if not text: # first time
self.matches = options
else:
@@ -56,17 +56,17 @@ class Completer(object):
return match
@classmethod
def get_commands(self) -> List[object]:
def get_commands(self)->List[object]:
"""
Return a list of all the client commands and invocations.
"""
return BaseCommand.get_commands() + BaseInvocation.get_invocations()
def get_current_command(self, buffer: str) -> tuple[str, str]:
def get_current_command(self, buffer: str)->tuple[str, str]:
"""
Parse the readline buffer to find the most recent command and its switch.
"""
if len(buffer) == 0:
if len(buffer)==0:
return None, None
tokens = shlex.split(buffer)
command = None
@@ -78,11 +78,11 @@ class Completer(object):
else:
switch = t
# don't try to autocomplete switches that are already complete
if switch and buffer.endswith(" "):
switch = None
return command or "", switch or ""
if switch and buffer.endswith(' '):
switch=None
return command or '', switch or ''
def parse_commands(self) -> Dict[str, List[str]]:
def parse_commands(self)->Dict[str, List[str]]:
"""
Return a dict in which the keys are the command name
and the values are the parameters the command takes.
@@ -90,11 +90,11 @@ class Completer(object):
result = dict()
for command in self.commands:
hints = get_type_hints(command)
name = get_args(hints["type"])[0]
result.update({name: hints})
name = get_args(hints['type'])[0]
result.update({name:hints})
return result
def get_command_options(self, command: str, switch: str) -> List[str]:
def get_command_options(self, command: str, switch: str)->List[str]:
"""
Return all the parameters that can be passed to the command as
command-line switches. Returns None if the command is unrecognized.
@@ -102,46 +102,42 @@ class Completer(object):
parsed_commands = self.parse_commands()
if command not in parsed_commands:
return None
# handle switches in the format "-foo=bar"
argument = None
if switch and "=" in switch:
switch, argument = switch.split("=")
parameter = switch.strip("-")
if switch and '=' in switch:
switch, argument = switch.split('=')
parameter = switch.strip('-')
if parameter in parsed_commands[command]:
if argument is None:
return self.get_parameter_options(parameter, parsed_commands[command][parameter])
else:
return [
f"--{parameter}={x}"
for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])
]
return [f"--{parameter}={x}" for x in self.get_parameter_options(parameter, parsed_commands[command][parameter])]
else:
return [f"--{x}" for x in parsed_commands[command].keys()]
def get_parameter_options(self, parameter: str, typehint) -> List[str]:
def get_parameter_options(self, parameter: str, typehint)->List[str]:
"""
Given a parameter type (such as Literal), offers autocompletions.
"""
if get_origin(typehint) == Literal:
return get_args(typehint)
if parameter == "model":
if parameter == 'model':
return self.manager.model_names()
def _pre_input_hook(self):
if self.linebuffer:
readline.insert_text(self.linebuffer)
readline.redisplay()
self.linebuffer = None
def set_autocompleter(services: InvocationServices) -> Completer:
global completer
if completer:
return completer
completer = Completer(services.model_manager)
readline.set_completer(completer.complete)
@@ -166,6 +162,8 @@ def set_autocompleter(services: InvocationServices) -> Completer:
pass
except OSError: # file likely corrupted
newname = f"{histfile}.old"
logger.error(f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}")
logger.error(
f"Your history file {histfile} couldn't be loaded and may be corrupted. Renaming it to {newname}"
)
histfile.replace(Path(newname))
atexit.register(readline.write_history_file, histfile)

View File

@@ -13,7 +13,6 @@ from pydantic.fields import Field
# This should come early so that the logger can pick up its configuration options
from .services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
config = InvokeAIAppConfig.get_config()
config.parse_args()
logger = InvokeAILogger().getLogger(config=config)
@@ -21,7 +20,7 @@ from invokeai.version.invokeai_version import __version__
# we call this early so that the message appears before other invokeai initialization messages
if config.version:
print(f"InvokeAI version {__version__}")
print(f'InvokeAI version {__version__}')
sys.exit(0)
from invokeai.app.services.board_image_record_storage import (
@@ -37,21 +36,18 @@ from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from .services.default_graphs import default_text_to_image_graph_id, create_system_graphs
from .services.default_graphs import (default_text_to_image_graph_id,
create_system_graphs)
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from .cli.commands import BaseCommand, CliContext, ExitCli, SortedHelpFormatter, add_graph_parsers, add_parsers
from .cli.commands import (BaseCommand, CliContext, ExitCli,
SortedHelpFormatter, add_graph_parsers, add_parsers)
from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.graph import (
Edge,
EdgeConnection,
GraphExecutionState,
GraphInvocation,
LibraryGraph,
are_connection_types_compatible,
)
from .services.graph import (Edge, EdgeConnection, GraphExecutionState,
GraphInvocation, LibraryGraph,
are_connection_types_compatible)
from .services.image_file_storage import DiskImageFileStorage
from .services.invocation_queue import MemoryInvocationQueue
from .services.invocation_services import InvocationServices
@@ -62,7 +58,6 @@ from .services.sqlite import SqliteItemStorage
import torch
import invokeai.backend.util.hotfixes
if torch.backends.mps.is_available():
import invokeai.backend.util.mps_fixes
@@ -74,7 +69,6 @@ class CliCommand(BaseModel):
class InvalidArgs(Exception):
pass
def add_invocation_args(command_parser):
# Add linking capability
command_parser.add_argument(
@@ -119,7 +113,7 @@ def get_command_parser(services: InvocationServices) -> argparse.ArgumentParser:
return parser
class NodeField:
class NodeField():
alias: str
node_path: str
field: str
@@ -132,20 +126,15 @@ class NodeField:
self.field_type = field_type
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str, NodeField]:
return {k: NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
def fields_from_type_hints(hints: dict[str, type], node_path: str) -> dict[str,NodeField]:
return {k:NodeField(alias=k, node_path=node_path, field=k, field_type=v) for k, v in hints.items()}
def get_node_input_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
"""Gets the node field for the specified field alias"""
exposed_input = next(e for e in graph.exposed_inputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_input.node_path))
return NodeField(
alias=exposed_input.alias,
node_path=f"{node_id}.{exposed_input.node_path}",
field=exposed_input.field,
field_type=get_type_hints(node_type)[exposed_input.field],
)
return NodeField(alias=exposed_input.alias, node_path=f'{node_id}.{exposed_input.node_path}', field=exposed_input.field, field_type=get_type_hints(node_type)[exposed_input.field])
def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -> NodeField:
@@ -153,12 +142,7 @@ def get_node_output_field(graph: LibraryGraph, field_alias: str, node_id: str) -
exposed_output = next(e for e in graph.exposed_outputs if e.alias == field_alias)
node_type = type(graph.graph.get_node(exposed_output.node_path))
node_output_type = node_type.get_output_type()
return NodeField(
alias=exposed_output.alias,
node_path=f"{node_id}.{exposed_output.node_path}",
field=exposed_output.field,
field_type=get_type_hints(node_output_type)[exposed_output.field],
)
return NodeField(alias=exposed_output.alias, node_path=f'{node_id}.{exposed_output.node_path}', field=exposed_output.field, field_type=get_type_hints(node_output_type)[exposed_output.field])
def get_node_inputs(invocation: BaseInvocation, context: CliContext) -> dict[str, NodeField]:
@@ -181,7 +165,9 @@ def get_node_outputs(invocation: BaseInvocation, context: CliContext) -> dict[st
return {e.alias: get_node_output_field(graph, e.alias, invocation.id) for e in graph.exposed_outputs}
def generate_matching_edges(a: BaseInvocation, b: BaseInvocation, context: CliContext) -> list[Edge]:
def generate_matching_edges(
a: BaseInvocation, b: BaseInvocation, context: CliContext
) -> list[Edge]:
"""Generates all possible edges between two invocations"""
afields = get_node_outputs(a, context)
bfields = get_node_inputs(b, context)
@@ -193,14 +179,12 @@ def generate_matching_edges(a: BaseInvocation, b: BaseInvocation, context: CliCo
matching_fields = matching_fields.difference(invalid_fields)
# Validate types
matching_fields = [
f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)
]
matching_fields = [f for f in matching_fields if are_connection_types_compatible(afields[f].field_type, bfields[f].field_type)]
edges = [
Edge(
source=EdgeConnection(node_id=afields[alias].node_path, field=afields[alias].field),
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field),
destination=EdgeConnection(node_id=bfields[alias].node_path, field=bfields[alias].field)
)
for alias in matching_fields
]
@@ -209,7 +193,6 @@ def generate_matching_edges(a: BaseInvocation, b: BaseInvocation, context: CliCo
class SessionError(Exception):
"""Raised when a session error has occurred"""
pass
@@ -226,23 +209,22 @@ def invoke_all(context: CliContext):
context.invoker.services.logger.error(
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
)
raise SessionError()
def invoke_cli():
logger.info(f"InvokeAI version {__version__}")
logger.info(f'InvokeAI version {__version__}')
# get the optional list of invocations to execute on the command line
parser = config.get_parser()
parser.add_argument("commands", nargs="*")
parser.add_argument('commands',nargs='*')
invocation_commands = parser.parse_args().commands
# get the optional file to read commands from.
# Simplest is to use it for STDIN
if infile := config.from_file:
sys.stdin = open(infile, "r")
model_manager = ModelManagerService(config, logger)
sys.stdin = open(infile,"r")
model_manager = ModelManagerService(config,logger)
events = EventServiceBase()
output_folder = config.output_path
@@ -252,13 +234,13 @@ def invoke_cli():
db_location = ":memory:"
else:
db_location = config.db_path
db_location.parent.mkdir(parents=True, exist_ok=True)
db_location.parent.mkdir(parents=True,exist_ok=True)
logger.info(f'InvokeAI database location is "{db_location}"')
graph_execution_manager = SqliteItemStorage[GraphExecutionState](
filename=db_location, table_name="graph_executions"
)
filename=db_location, table_name="graph_executions"
)
urls = LocalUrlService()
image_record_storage = SqliteImageRecordStorage(db_location)
@@ -299,21 +281,24 @@ def invoke_cli():
graph_execution_manager=graph_execution_manager,
)
)
services = InvocationServices(
model_manager=model_manager,
events=events,
latents=ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")),
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f'{output_folder}/latents')),
images=images,
boards=boards,
board_images=board_images,
queue=MemoryInvocationQueue(),
graph_library=SqliteItemStorage[LibraryGraph](filename=db_location, table_name="graphs"),
graph_library=SqliteItemStorage[LibraryGraph](
filename=db_location, table_name="graphs"
),
graph_execution_manager=graph_execution_manager,
processor=DefaultInvocationProcessor(),
logger=logger,
configuration=config,
)
system_graphs = create_system_graphs(services.graph_library)
system_graph_names = set([g.name for g in system_graphs])
@@ -323,7 +308,7 @@ def invoke_cli():
session: GraphExecutionState = invoker.create_execution_state()
parser = get_command_parser(services)
re_negid = re.compile("^-[0-9]+$")
re_negid = re.compile('^-[0-9]+$')
# Uncomment to print out previous sessions at startup
# print(services.session_manager.list())
@@ -333,7 +318,7 @@ def invoke_cli():
command_line_args_exist = len(invocation_commands) > 0
done = False
while not done:
try:
if command_line_args_exist:
@@ -347,7 +332,7 @@ def invoke_cli():
try:
# Refresh the state of the session
# history = list(get_graph_execution_history(context.session))
#history = list(get_graph_execution_history(context.session))
history = list(reversed(context.nodes_added))
# Split the command for piping
@@ -368,17 +353,17 @@ def invoke_cli():
args[field_name] = field_default
# Parse invocation
command: CliCommand = None # type:ignore
command: CliCommand = None # type:ignore
system_graph: Optional[LibraryGraph] = None
if args["type"] in system_graph_names:
system_graph = next(filter(lambda g: g.name == args["type"], system_graphs))
if args['type'] in system_graph_names:
system_graph = next(filter(lambda g: g.name == args['type'], system_graphs))
invocation = GraphInvocation(graph=system_graph.graph, id=str(current_id))
for exposed_input in system_graph.exposed_inputs:
if exposed_input.alias in args:
node = invocation.graph.get_node(exposed_input.node_path)
field = exposed_input.field
setattr(node, field, args[exposed_input.alias])
command = CliCommand(command=invocation)
command = CliCommand(command = invocation)
context.graph_nodes[invocation.id] = system_graph.id
else:
args["id"] = current_id
@@ -400,13 +385,17 @@ def invoke_cli():
# Pipe previous command output (if there was a previous command)
edges: list[Edge] = list()
if len(history) > 0 or current_id != start_id:
from_id = history[0] if current_id == start_id else str(current_id - 1)
from_id = (
history[0] if current_id == start_id else str(current_id - 1)
)
from_node = (
next(filter(lambda n: n[0].id == from_id, new_invocations))[0]
if current_id != start_id
else context.session.graph.get_node(from_id)
)
matching_edges = generate_matching_edges(from_node, command.command, context)
matching_edges = generate_matching_edges(
from_node, command.command, context
)
edges.extend(matching_edges)
# Parse provided links
@@ -417,18 +406,16 @@ def invoke_cli():
node_id = str(current_id + int(node_id))
link_node = context.session.graph.get_node(node_id)
matching_edges = generate_matching_edges(link_node, command.command, context)
matching_edges = generate_matching_edges(
link_node, command.command, context
)
matching_destinations = [e.destination for e in matching_edges]
edges = [e for e in edges if e.destination not in matching_destinations]
edges.extend(matching_edges)
if "link" in args and args["link"]:
for link in args["link"]:
edges = [
e
for e in edges
if e.destination.node_id != command.command.id or e.destination.field != link[2]
]
edges = [e for e in edges if e.destination.node_id != command.command.id or e.destination.field != link[2]]
node_id = link[0]
if re_negid.match(node_id):
@@ -441,7 +428,7 @@ def invoke_cli():
edges.append(
Edge(
source=EdgeConnection(node_id=node_output.node_path, field=node_output.field),
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field),
destination=EdgeConnection(node_id=node_input.node_path, field=node_input.field)
)
)

View File

@@ -4,5 +4,9 @@ __all__ = []
dirname = os.path.dirname(os.path.abspath(__file__))
for f in os.listdir(dirname):
if f != "__init__.py" and os.path.isfile("%s/%s" % (dirname, f)) and f[-3:] == ".py":
if (
f != "__init__.py"
and os.path.isfile("%s/%s" % (dirname, f))
and f[-3:] == ".py"
):
__all__.append(f[:-3])

View File

@@ -4,7 +4,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, get_type_hints
from typing import (TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args,
get_type_hints)
from pydantic import BaseConfig, BaseModel, Field

View File

@@ -8,7 +8,8 @@ from pydantic import Field, validator
from invokeai.app.models.image import ImageField
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext, UIConfig
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext, UIConfig)
class IntCollectionOutput(BaseInvocationOutput):
@@ -26,7 +27,8 @@ class FloatCollectionOutput(BaseInvocationOutput):
type: Literal["float_collection"] = "float_collection"
# Outputs
collection: list[float] = Field(default=[], description="The float collection")
collection: list[float] = Field(
default=[], description="The float collection")
class ImageCollectionOutput(BaseInvocationOutput):
@@ -35,7 +37,8 @@ class ImageCollectionOutput(BaseInvocationOutput):
type: Literal["image_collection"] = "image_collection"
# Outputs
collection: list[ImageField] = Field(default=[], description="The output images")
collection: list[ImageField] = Field(
default=[], description="The output images")
class Config:
schema_extra = {"required": ["type", "collection"]}
@@ -53,7 +56,10 @@ class RangeInvocation(BaseInvocation):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Range", "tags": ["range", "integer", "collection"]},
"ui": {
"title": "Range",
"tags": ["range", "integer", "collection"]
},
}
@validator("stop")
@@ -63,7 +69,9 @@ class RangeInvocation(BaseInvocation):
return v
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step)))
return IntCollectionOutput(
collection=list(range(self.start, self.stop, self.step))
)
class RangeOfSizeInvocation(BaseInvocation):
@@ -78,11 +86,18 @@ class RangeOfSizeInvocation(BaseInvocation):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]},
"ui": {
"title": "Sized Range",
"tags": ["range", "integer", "size", "collection"]
},
}
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step)))
return IntCollectionOutput(
collection=list(
range(
self.start, self.start + self.size,
self.step)))
class RandomRangeInvocation(BaseInvocation):
@@ -92,7 +107,9 @@ class RandomRangeInvocation(BaseInvocation):
# Inputs
low: int = Field(default=0, description="The inclusive low value")
high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value")
high: int = Field(
default=np.iinfo(np.int32).max, description="The exclusive high value"
)
size: int = Field(default=1, description="The number of values to generate")
seed: int = Field(
ge=0,
@@ -103,12 +120,19 @@ class RandomRangeInvocation(BaseInvocation):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]},
"ui": {
"title": "Random Range",
"tags": ["range", "integer", "random", "collection"]
},
}
def invoke(self, context: InvocationContext) -> IntCollectionOutput:
rng = np.random.default_rng(self.seed)
return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size)))
return IntCollectionOutput(
collection=list(
rng.integers(
low=self.low, high=self.high,
size=self.size)))
class ImageCollectionInvocation(BaseInvocation):

View File

@@ -1,73 +1,66 @@
from typing import Literal, Optional, Union, List, Annotated
from pydantic import BaseModel, Field
import re
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
from .model import ClipField
from ...backend.util.devices import torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher
import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from compel.prompt_parser import (Blend, Conjunction,
CrossAttentionControlSubstitute,
FlattenedPrompt, Fragment)
from ...backend.util.devices import torch_dtype
from ...backend.model_management import ModelType
from ...backend.model_management.models import ModelNotFoundException
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .model import ClipField
from dataclasses import dataclass
class ConditioningField(BaseModel):
conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")
conditioning_name: Optional[str] = Field(
default=None, description="The name of conditioning data")
class Config:
schema_extra = {"required": ["conditioning_name"]}
@dataclass
class BasicConditioningInfo:
# type: Literal["basic_conditioning"] = "basic_conditioning"
#type: Literal["basic_conditioning"] = "basic_conditioning"
embeds: torch.Tensor
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
# type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
#type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]
ConditioningInfoType = Annotated[
Union[BasicConditioningInfo, SDXLConditioningInfo],
Field(discriminator="type")
]
@dataclass
class ConditioningFieldData:
conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
# unconditioned: Optional[torch.Tensor]
#unconditioned: Optional[torch.Tensor]
# class ConditioningAlgo(str, Enum):
#class ConditioningAlgo(str, Enum):
# Compose = "compose"
# ComposeEx = "compose_ex"
# PerpNeg = "perp_neg"
class CompelOutput(BaseInvocationOutput):
"""Compel parser output"""
# fmt: off
#fmt: off
type: Literal["compel_output"] = "compel_output"
conditioning: ConditioningField = Field(default=None, description="Conditioning")
# fmt: on
#fmt: on
class CompelInvocation(BaseInvocation):
@@ -81,28 +74,33 @@ class CompelInvocation(BaseInvocation):
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
"ui": {
"title": "Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
}
@torch.no_grad()
def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(),
context=context,
**self.clip.tokenizer.dict(), context=context,
)
text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(),
context=context,
**self.clip.text_encoder.dict(), context=context,
)
def _lora_loader():
for lora in self.clip.loras:
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
@@ -118,18 +116,15 @@ class CompelInvocation(BaseInvocation):
)
except ModelNotFoundException:
# print(e)
# import traceback
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),\
text_encoder_info as text_encoder:
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
), ModelPatcher.apply_clip_skip(
text_encoder_info.context.model, self.clip.skipped_layers
), text_encoder_info as text_encoder:
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
@@ -144,12 +139,14 @@ class CompelInvocation(BaseInvocation):
if context.services.configuration.log_tokenization:
log_tokenization_for_prompt_object(prompt, tokenizer)
c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)
c, options = compel.build_conditioning_tensor_for_prompt_object(
prompt)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
tokens_count_including_eos_bos=get_max_token_count(
tokenizer, conjunction),
cross_attention_control_args=options.get(
"cross_attention_control", None),)
c = c.detach().to("cpu")
@@ -171,26 +168,24 @@ class CompelInvocation(BaseInvocation):
),
)
class SDXLPromptInvocationBase:
def run_clip_raw(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
context=context,
)
text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(),
context=context,
)
def _lora_loader():
for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
@@ -201,23 +196,19 @@ class SDXLPromptInvocationBase:
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model
)
except ModelNotFoundException:
# print(e)
# import traceback
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
text_encoder_info as text_encoder:
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
), ModelPatcher.apply_clip_skip(
text_encoder_info.context.model, clip_field.skipped_layers
), text_encoder_info as text_encoder:
text_inputs = tokenizer(
prompt,
padding="max_length",
@@ -250,21 +241,20 @@ class SDXLPromptInvocationBase:
def run_clip_compel(self, context, clip_field, prompt, get_pooled):
tokenizer_info = context.services.model_manager.get_model(
**clip_field.tokenizer.dict(),
context=context,
)
text_encoder_info = context.services.model_manager.get_model(
**clip_field.text_encoder.dict(),
context=context,
)
def _lora_loader():
for lora in clip_field.loras:
lora_info = context.services.model_manager.get_model(**lora.dict(exclude={"weight"}), context=context)
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}))
yield (lora_info.context.model, lora.weight)
del lora_info
return
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
#loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
@@ -275,30 +265,26 @@ class SDXLPromptInvocationBase:
model_name=name,
base_model=clip_field.text_encoder.base_model,
model_type=ModelType.TextualInversion,
context=context,
).context.model
)
except ModelNotFoundException:
# print(e)
# import traceback
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
#import traceback
#print(traceback.format_exc())
print(f"Warn: trigger: \"{trigger}\" not found")
with ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),\
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (tokenizer, ti_manager),\
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),\
text_encoder_info as text_encoder:
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
), ModelPatcher.apply_clip_skip(
text_encoder_info.context.model, clip_field.skipped_layers
), text_encoder_info as text_encoder:
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=torch_dtype,
truncate_long_prompts=True, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=True,
)
@@ -332,7 +318,6 @@ class SDXLPromptInvocationBase:
return c, c_pooled, ec
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
@@ -352,7 +337,13 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
"ui": {
"title": "SDXL Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
}
@torch.no_grad()
@@ -367,7 +358,9 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
crop_coords = (self.crop_top, self.crop_left)
target_size = (self.target_height, self.target_width)
add_time_ids = torch.tensor([original_size + crop_coords + target_size])
add_time_ids = torch.tensor([
original_size + crop_coords + target_size
])
conditioning_data = ConditioningFieldData(
conditionings=[
@@ -389,13 +382,12 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
),
)
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt"
style: str = Field(default="", description="Style prompt") # TODO: ?
style: str = Field(default="", description="Style prompt") # TODO: ?
original_width: int = Field(1024, description="")
original_height: int = Field(1024, description="")
crop_top: int = Field(0, description="")
@@ -409,7 +401,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
"ui": {
"title": "SDXL Refiner Prompt (Compel)",
"tags": ["prompt", "compel"],
"type_hints": {"model": "model"},
"type_hints": {
"model": "model"
}
},
}
@@ -420,7 +414,9 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
add_time_ids = torch.tensor([
original_size + crop_coords + (self.aesthetic_score,)
])
conditioning_data = ConditioningFieldData(
conditionings=[
@@ -428,7 +424,7 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
embeds=c2,
pooled_embeds=c2_pooled,
add_time_ids=add_time_ids,
extra_conditioning=ec2, # or None
extra_conditioning=ec2, # or None
)
]
)
@@ -442,7 +438,6 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
),
)
class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Pass unmodified prompt to conditioning without compel processing."""
@@ -462,7 +457,13 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "SDXL Prompt (Raw)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}},
"ui": {
"title": "SDXL Prompt (Raw)",
"tags": ["prompt", "compel"],
"type_hints": {
"model": "model"
}
},
}
@torch.no_grad()
@@ -477,7 +478,9 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
crop_coords = (self.crop_top, self.crop_left)
target_size = (self.target_height, self.target_width)
add_time_ids = torch.tensor([original_size + crop_coords + target_size])
add_time_ids = torch.tensor([
original_size + crop_coords + target_size
])
conditioning_data = ConditioningFieldData(
conditionings=[
@@ -499,13 +502,12 @@ class SDXLRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
),
)
class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"""Parse prompt using compel package to conditioning."""
type: Literal["sdxl_refiner_raw_prompt"] = "sdxl_refiner_raw_prompt"
style: str = Field(default="", description="Style prompt") # TODO: ?
style: str = Field(default="", description="Style prompt") # TODO: ?
original_width: int = Field(1024, description="")
original_height: int = Field(1024, description="")
crop_top: int = Field(0, description="")
@@ -519,7 +521,9 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
"ui": {
"title": "SDXL Refiner Prompt (Raw)",
"tags": ["prompt", "compel"],
"type_hints": {"model": "model"},
"type_hints": {
"model": "model"
}
},
}
@@ -530,7 +534,9 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
original_size = (self.original_height, self.original_width)
crop_coords = (self.crop_top, self.crop_left)
add_time_ids = torch.tensor([original_size + crop_coords + (self.aesthetic_score,)])
add_time_ids = torch.tensor([
original_size + crop_coords + (self.aesthetic_score,)
])
conditioning_data = ConditioningFieldData(
conditionings=[
@@ -538,7 +544,7 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
embeds=c2,
pooled_embeds=c2_pooled,
add_time_ids=add_time_ids,
extra_conditioning=ec2, # or None
extra_conditioning=ec2, # or None
)
]
)
@@ -555,14 +561,11 @@ class SDXLRefinerRawPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
class ClipSkipInvocationOutput(BaseInvocationOutput):
"""Clip skip node output"""
type: Literal["clip_skip_output"] = "clip_skip_output"
clip: ClipField = Field(None, description="Clip with skipped layers")
class ClipSkipInvocation(BaseInvocation):
"""Skip layers in clip text_encoder model."""
type: Literal["clip_skip"] = "clip_skip"
clip: ClipField = Field(None, description="Clip to use")
@@ -570,7 +573,10 @@ class ClipSkipInvocation(BaseInvocation):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]},
"ui": {
"title": "CLIP Skip",
"tags": ["clip", "skip"]
},
}
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
@@ -581,26 +587,46 @@ class ClipSkipInvocation(BaseInvocation):
def get_max_token_count(
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction], truncate_if_too_long=False
) -> int:
tokenizer, prompt: Union[FlattenedPrompt, Blend, Conjunction],
truncate_if_too_long=False) -> int:
if type(prompt) is Blend:
blend: Blend = prompt
return max([get_max_token_count(tokenizer, p, truncate_if_too_long) for p in blend.prompts])
return max(
[
get_max_token_count(tokenizer, p, truncate_if_too_long)
for p in blend.prompts
]
)
elif type(prompt) is Conjunction:
conjunction: Conjunction = prompt
return sum([get_max_token_count(tokenizer, p, truncate_if_too_long) for p in conjunction.prompts])
return sum(
[
get_max_token_count(tokenizer, p, truncate_if_too_long)
for p in conjunction.prompts
]
)
else:
return len(get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long))
return len(
get_tokens_for_prompt_object(
tokenizer, prompt, truncate_if_too_long))
def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> List[str]:
def get_tokens_for_prompt_object(
tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> List[str]:
if type(parsed_prompt) is Blend:
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
raise ValueError(
"Blend is not supported here - you need to get tokens for each of its .children"
)
text_fragments = [
x.text
if type(x) is Fragment
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
else (
" ".join([f.text for f in x.original])
if type(x) is CrossAttentionControlSubstitute
else str(x)
)
for x in parsed_prompt.children
]
text = " ".join(text_fragments)
@@ -611,17 +637,25 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
return tokens
def log_tokenization_for_conjunction(c: Conjunction, tokenizer, display_label_prefix=None):
def log_tokenization_for_conjunction(
c: Conjunction, tokenizer, display_label_prefix=None
):
display_label_prefix = display_label_prefix or ""
for i, p in enumerate(c.prompts):
if len(c.prompts) > 1:
this_display_label_prefix = f"{display_label_prefix}(conjunction part {i + 1}, weight={c.weights[i]})"
else:
this_display_label_prefix = display_label_prefix
log_tokenization_for_prompt_object(p, tokenizer, display_label_prefix=this_display_label_prefix)
log_tokenization_for_prompt_object(
p,
tokenizer,
display_label_prefix=this_display_label_prefix
)
def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None):
def log_tokenization_for_prompt_object(
p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
display_label_prefix = display_label_prefix or ""
if type(p) is Blend:
blend: Blend = p
@@ -658,10 +692,13 @@ def log_tokenization_for_prompt_object(p: Union[Blend, FlattenedPrompt], tokeniz
)
else:
text = " ".join([x.text for x in flattened_prompt.children])
log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
log_tokenization_for_text(
text, tokenizer, display_label=display_label_prefix
)
def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
def log_tokenization_for_text(
text, tokenizer, display_label=None, truncate_if_too_long=False):
"""shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '

View File

@@ -6,30 +6,21 @@ from typing import Dict, List, Literal, Optional, Union
import cv2
import numpy as np
from controlnet_aux import (
CannyDetector,
ContentShuffleDetector,
HEDdetector,
LeresDetector,
LineartAnimeDetector,
LineartDetector,
MediapipeFaceDetector,
MidasDetector,
MLSDdetector,
NormalBaeDetector,
OpenposeDetector,
PidiNetDetector,
SamDetector,
ZoeDetector,
)
from controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector,
LeresDetector, LineartAnimeDetector,
LineartDetector, MediapipeFaceDetector,
MidasDetector, MLSDdetector, NormalBaeDetector,
OpenposeDetector, PidiNetDetector, SamDetector,
ZoeDetector)
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, Field, validator
from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from ..models.image import ImageOutput, PILInvocationConfig
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .image import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [
###########################################
@@ -43,6 +34,7 @@ CONTROLNET_DEFAULT_MODELS = [
"lllyasviel/sd-controlnet-scribble",
"lllyasviel/sd-controlnet-normal",
"lllyasviel/sd-controlnet-mlsd",
#############################################
# lllyasviel sd v1.5, ControlNet v1.1 models
#############################################
@@ -64,6 +56,7 @@ CONTROLNET_DEFAULT_MODELS = [
"lllyasviel/control_v11e_sd15_shuffle",
"lllyasviel/control_v11e_sd15_ip2p",
"lllyasviel/control_v11f1e_sd15_tile",
#################################################
# thibaud sd v2.1 models (ControlNet v1.0? or v1.1?
##################################################
@@ -78,6 +71,7 @@ CONTROLNET_DEFAULT_MODELS = [
"thibaud/controlnet-sd21-lineart-diffusers",
"thibaud/controlnet-sd21-normalbae-diffusers",
"thibaud/controlnet-sd21-ade20k-diffusers",
##############################################
# ControlNetMediaPipeface, ControlNet v1.1
##############################################
@@ -89,17 +83,10 @@ CONTROLNET_DEFAULT_MODELS = [
]
CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)]
CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])]
CONTROLNET_RESIZE_VALUES = Literal[
tuple(
[
"just_resize",
"crop_resize",
"fill_resize",
"just_resize_simple",
]
)
]
CONTROLNET_MODE_VALUES = Literal[tuple(
["balanced", "more_prompt", "more_control", "unbalanced"])]
CONTROLNET_RESIZE_VALUES = Literal[tuple(
["just_resize", "crop_resize", "fill_resize", "just_resize_simple",])]
class ControlNetModelField(BaseModel):
@@ -111,17 +98,21 @@ class ControlNetModelField(BaseModel):
class ControlField(BaseModel):
image: ImageField = Field(default=None, description="The control image")
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
control_model: Optional[ControlNetModelField] = Field(
default=None, description="The ControlNet model to use")
# control_weight: Optional[float] = Field(default=1, description="weight given to controlnet")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
control_weight: Union[float, List[float]] = Field(
default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
default=0, ge=0, le=1,
description="When the ControlNet is first applied (% of total steps)")
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
default=1, ge=0, le=1,
description="When the ControlNet is last applied (% of total steps)")
control_mode: CONTROLNET_MODE_VALUES = Field(
default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(
default="just_resize", description="The resize mode to use")
@validator("control_weight")
def validate_control_weight(cls, v):
@@ -129,10 +120,11 @@ class ControlField(BaseModel):
if isinstance(v, list):
for i in v:
if i < -1 or i > 2:
raise ValueError("Control weights must be within -1 to 2 range")
raise ValueError(
'Control weights must be within -1 to 2 range')
else:
if v < -1 or v > 2:
raise ValueError("Control weights must be within -1 to 2 range")
raise ValueError('Control weights must be within -1 to 2 range')
return v
class Config:
@@ -144,13 +136,12 @@ class ControlField(BaseModel):
"control_model": "controlnet_model",
# "control_weight": "number",
}
},
}
}
class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info"""
# fmt: off
type: Literal["control_output"] = "control_output"
control: ControlField = Field(default=None, description="The control info")
@@ -159,7 +150,6 @@ class ControlOutput(BaseInvocationOutput):
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
# fmt: off
type: Literal["controlnet"] = "controlnet"
# Inputs
@@ -186,7 +176,7 @@ class ControlNetInvocation(BaseInvocation):
# "cfg_scale": "float",
"cfg_scale": "number",
"control_weight": "float",
},
}
},
}
@@ -215,7 +205,10 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Image Processor", "tags": ["image", "processor"]},
"ui": {
"title": "Image Processor",
"tags": ["image", "processor"]
},
}
def run_processor(self, image):
@@ -240,7 +233,7 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
image_category=ImageCategory.CONTROL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.is_intermediate,
is_intermediate=self.is_intermediate
)
"""Builds an ImageOutput and its ImageField"""
@@ -255,9 +248,9 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig):
)
class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class CannyImageProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
"""Canny edge detection for ControlNet"""
# fmt: off
type: Literal["canny_image_processor"] = "canny_image_processor"
# Input
@@ -267,18 +260,22 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]},
"ui": {
"title": "Canny Processor",
"tags": ["controlnet", "canny", "image", "processor"]
},
}
def run_processor(self, image):
canny_processor = CannyDetector()
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
processed_image = canny_processor(
image, self.low_threshold, self.high_threshold)
return processed_image
class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class HedImageProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
"""Applies HED edge detection to image"""
# fmt: off
type: Literal["hed_image_processor"] = "hed_image_processor"
# Inputs
@@ -291,25 +288,27 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig)
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]},
"ui": {
"title": "Softedge(HED) Processor",
"tags": ["controlnet", "softedge", "hed", "image", "processor"]
},
}
def run_processor(self, image):
hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = hed_processor(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
# safe not supported in controlnet_aux v0.0.3
# safe=self.safe,
scribble=self.scribble,
)
processed_image = hed_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
# safe not supported in controlnet_aux v0.0.3
# safe=self.safe,
scribble=self.scribble,
)
return processed_image
class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class LineartImageProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
"""Applies line art processing to image"""
# fmt: off
type: Literal["lineart_image_processor"] = "lineart_image_processor"
# Inputs
@@ -320,20 +319,24 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]},
"ui": {
"title": "Lineart Processor",
"tags": ["controlnet", "lineart", "image", "processor"]
},
}
def run_processor(self, image):
lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
lineart_processor = LineartDetector.from_pretrained(
"lllyasviel/Annotators")
processed_image = lineart_processor(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
)
image, detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution, coarse=self.coarse)
return processed_image
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class LineartAnimeImageProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
"""Applies line art anime processing to image"""
# fmt: off
type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor"
# Inputs
@@ -345,23 +348,23 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati
schema_extra = {
"ui": {
"title": "Lineart Anime Processor",
"tags": ["controlnet", "lineart", "anime", "image", "processor"],
"tags": ["controlnet", "lineart", "anime", "image", "processor"]
},
}
def run_processor(self, image):
processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = processor(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
processor = LineartAnimeDetector.from_pretrained(
"lllyasviel/Annotators")
processed_image = processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
return processed_image
class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class OpenposeImageProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
"""Applies Openpose processing to image"""
# fmt: off
type: Literal["openpose_image_processor"] = "openpose_image_processor"
# Inputs
@@ -372,23 +375,25 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]},
"ui": {
"title": "Openpose Processor",
"tags": ["controlnet", "openpose", "image", "processor"]
},
}
def run_processor(self, image):
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
openpose_processor = OpenposeDetector.from_pretrained(
"lllyasviel/Annotators")
processed_image = openpose_processor(
image,
detect_resolution=self.detect_resolution,
image, detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
hand_and_face=self.hand_and_face,
)
hand_and_face=self.hand_and_face,)
return processed_image
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class MidasDepthImageProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
"""Applies Midas depth processing to image"""
# fmt: off
type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor"
# Inputs
@@ -400,24 +405,26 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]},
"ui": {
"title": "Midas (Depth) Processor",
"tags": ["controlnet", "midas", "depth", "image", "processor"]
},
}
def run_processor(self, image):
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
processed_image = midas_processor(
image,
a=np.pi * self.a_mult,
bg_th=self.bg_th,
# dept_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal=self.depth_and_normal,
)
processed_image = midas_processor(image,
a=np.pi * self.a_mult,
bg_th=self.bg_th,
# dept_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal=self.depth_and_normal,
)
return processed_image
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class NormalbaeImageProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
"""Applies NormalBae processing to image"""
# fmt: off
type: Literal["normalbae_image_processor"] = "normalbae_image_processor"
# Inputs
@@ -427,20 +434,24 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]},
"ui": {
"title": "Normal BAE Processor",
"tags": ["controlnet", "normal", "bae", "image", "processor"]
},
}
def run_processor(self, image):
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
normalbae_processor = NormalBaeDetector.from_pretrained(
"lllyasviel/Annotators")
processed_image = normalbae_processor(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
)
image, detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution)
return processed_image
class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class MlsdImageProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
"""Applies MLSD processing to image"""
# fmt: off
type: Literal["mlsd_image_processor"] = "mlsd_image_processor"
# Inputs
@@ -452,24 +463,24 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]},
"ui": {
"title": "MLSD Processor",
"tags": ["controlnet", "mlsd", "image", "processor"]
},
}
def run_processor(self, image):
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
processed_image = mlsd_processor(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
thr_v=self.thr_v,
thr_d=self.thr_d,
)
image, detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution, thr_v=self.thr_v,
thr_d=self.thr_d)
return processed_image
class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class PidiImageProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
"""Applies PIDI processing to image"""
# fmt: off
type: Literal["pidi_image_processor"] = "pidi_image_processor"
# Inputs
@@ -481,24 +492,25 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]},
"ui": {
"title": "PIDI Processor",
"tags": ["controlnet", "pidi", "image", "processor"]
},
}
def run_processor(self, image):
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
pidi_processor = PidiNetDetector.from_pretrained(
"lllyasviel/Annotators")
processed_image = pidi_processor(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
safe=self.safe,
scribble=self.scribble,
)
image, detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution, safe=self.safe,
scribble=self.scribble)
return processed_image
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class ContentShuffleImageProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
"""Applies content shuffle processing to image"""
# fmt: off
type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor"
# Inputs
@@ -513,45 +525,48 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca
schema_extra = {
"ui": {
"title": "Content Shuffle Processor",
"tags": ["controlnet", "contentshuffle", "image", "processor"],
"tags": ["controlnet", "contentshuffle", "image", "processor"]
},
}
def run_processor(self, image):
content_shuffle_processor = ContentShuffleDetector()
processed_image = content_shuffle_processor(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
h=self.h,
w=self.w,
f=self.f,
)
processed_image = content_shuffle_processor(image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
h=self.h,
w=self.w,
f=self.f
)
return processed_image
# should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class ZoeDepthImageProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
"""Applies Zoe depth processing to image"""
# fmt: off
type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor"
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]},
"ui": {
"title": "Zoe (Depth) Processor",
"tags": ["controlnet", "zoe", "depth", "image", "processor"]
},
}
def run_processor(self, image):
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
zoe_depth_processor = ZoeDetector.from_pretrained(
"lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
return processed_image
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class MediapipeFaceProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
"""Applies mediapipe face processing to image"""
# fmt: off
type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor"
# Inputs
@@ -561,22 +576,26 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]},
"ui": {
"title": "Mediapipe Processor",
"tags": ["controlnet", "mediapipe", "image", "processor"]
},
}
def run_processor(self, image):
# MediaPipeFaceDetector throws an error if image has alpha channel
# so convert to RGB if needed
if image.mode == "RGBA":
image = image.convert("RGB")
if image.mode == 'RGBA':
image = image.convert('RGB')
mediapipe_face_processor = MediapipeFaceDetector()
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
processed_image = mediapipe_face_processor(
image, max_faces=self.max_faces, min_confidence=self.min_confidence)
return processed_image
class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class LeresImageProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
"""Applies leres processing to image"""
# fmt: off
type: Literal["leres_image_processor"] = "leres_image_processor"
# Inputs
@@ -589,23 +608,24 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]},
"ui": {
"title": "Leres (Depth) Processor",
"tags": ["controlnet", "leres", "depth", "image", "processor"]
},
}
def run_processor(self, image):
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor(
image,
thr_a=self.thr_a,
thr_b=self.thr_b,
boost=self.boost,
image, thr_a=self.thr_a, thr_b=self.thr_b, boost=self.boost,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
image_resolution=self.image_resolution)
return processed_image
class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class TileResamplerProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
# fmt: off
type: Literal["tile_image_processor"] = "tile_image_processor"
# Inputs
@@ -617,17 +637,16 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
schema_extra = {
"ui": {
"title": "Tile Resample Processor",
"tags": ["controlnet", "tile", "resample", "image", "processor"],
"tags": ["controlnet", "tile", "resample", "image", "processor"]
},
}
# tile_resample copied from sd-webui-controlnet/scripts/processor.py
def tile_resample(
self,
np_img: np.ndarray,
res=512, # never used?
down_sampling_rate=1.0,
):
def tile_resample(self,
np_img: np.ndarray,
res=512, # never used?
down_sampling_rate=1.0,
):
np_img = HWC3(np_img)
if down_sampling_rate < 1.1:
return np_img
@@ -639,41 +658,36 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo
def run_processor(self, img):
np_img = np.array(img, dtype=np.uint8)
processed_np_image = self.tile_resample(
np_img,
# res=self.tile_size,
down_sampling_rate=self.down_sampling_rate,
)
processed_np_image = self.tile_resample(np_img,
# res=self.tile_size,
down_sampling_rate=self.down_sampling_rate
)
processed_image = Image.fromarray(processed_np_image)
return processed_image
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig):
class SegmentAnythingProcessorInvocation(
ImageProcessorInvocation, PILInvocationConfig):
"""Applies segment anything processing to image"""
# fmt: off
type: Literal["segment_anything_processor"] = "segment_anything_processor"
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Segment Anything Processor",
"tags": ["controlnet", "segment", "anything", "sam", "image", "processor"],
},
}
schema_extra = {"ui": {"title": "Segment Anything Processor", "tags": [
"controlnet", "segment", "anything", "sam", "image", "processor"]}, }
def run_processor(self, image):
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
"ybelkada/segment-anything", subfolder="checkpoints"
)
"ybelkada/segment-anything", subfolder="checkpoints")
np_img = np.array(image, dtype=np.uint8)
processed_image = segment_anything_processor(np_img)
return processed_image
class SamDetectorReproducibleColors(SamDetector):
# overriding SamDetector.show_anns() method to use reproducible colors for segmentation image
# base class show_anns() method randomizes colors,
# which seems to also lead to non-reproducible image generation
@@ -681,15 +695,19 @@ class SamDetectorReproducibleColors(SamDetector):
def show_anns(self, anns: List[Dict]):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x["area"]), reverse=True)
h, w = anns[0]["segmentation"].shape
final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
h, w = anns[0]['segmentation'].shape
final_img = Image.fromarray(
np.zeros((h, w, 3), dtype=np.uint8), mode="RGB")
palette = ade_palette()
for i, ann in enumerate(sorted_anns):
m = ann["segmentation"]
m = ann['segmentation']
img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8)
# doing modulo just in case number of annotated regions exceeds number of colors in palette
ann_color = palette[i % len(palette)]
img[:, :] = ann_color
final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m * 255)))
final_img.paste(
Image.fromarray(img, mode="RGB"),
(0, 0),
Image.fromarray(np.uint8(m * 255)))
return np.array(final_img, dtype=np.uint8)

View File

@@ -37,7 +37,10 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]},
"ui": {
"title": "OpenCV Inpaint",
"tags": ["opencv", "inpaint"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:

View File

@@ -6,7 +6,8 @@ from typing import Literal, Optional, get_args
import torch
from pydantic import Field
from invokeai.app.models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from invokeai.app.models.image import (ColorField, ImageCategory, ImageField,
ResourceOrigin)
from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods
@@ -24,12 +25,13 @@ from contextlib import contextmanager, ExitStack, ContextDecorator
SAMPLER_NAME_VALUES = Literal[tuple(InvokeAIGenerator.schedulers())]
INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
DEFAULT_INFILL_METHOD = (
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
)
from .latent import get_scheduler
class OldModelContext(ContextDecorator):
model: StableDiffusionGeneratorPipeline
@@ -42,7 +44,6 @@ class OldModelContext(ContextDecorator):
def __exit__(self, *exc):
return False
class OldModelInfo:
name: str
hash: str
@@ -63,34 +64,20 @@ class InpaintInvocation(BaseInvocation):
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
seed: int = Field(
ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed
)
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
width: int = Field(
default=512,
multiple_of=8,
gt=0,
description="The width of the resulting image",
)
height: int = Field(
default=512,
multiple_of=8,
gt=0,
description="The height of the resulting image",
)
cfg_scale: float = Field(
default=7.5,
ge=1,
description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt",
)
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use")
seed: int = Field(ge=0, le=SEED_MAX, description="The seed to use (omit for random)", default_factory=get_random_seed)
steps: int = Field(default=30, gt=0, description="The number of steps to use to generate the image")
width: int = Field(default=512, multiple_of=8, gt=0, description="The width of the resulting image", )
height: int = Field(default=512, multiple_of=8, gt=0, description="The height of the resulting image", )
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
unet: UNetField = Field(default=None, description="UNet model")
vae: VaeField = Field(default=None, description="Vae model")
# Inputs
image: Optional[ImageField] = Field(description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the original image")
strength: float = Field(
default=0.75, gt=0, le=1, description="The strength of the original image"
)
fit: bool = Field(
default=True,
description="Whether or not the result should be fit to the aspect ratio of the input image",
@@ -99,10 +86,18 @@ class InpaintInvocation(BaseInvocation):
# Inputs
mask: Optional[ImageField] = Field(description="The mask")
seam_size: int = Field(default=96, ge=1, description="The seam inpaint size (px)")
seam_blur: int = Field(default=16, ge=0, description="The seam inpaint blur radius (px)")
seam_strength: float = Field(default=0.75, gt=0, le=1, description="The seam inpaint strength")
seam_steps: int = Field(default=30, ge=1, description="The number of steps to use for seam inpaint")
tile_size: int = Field(default=32, ge=1, description="The tile infill method size (px)")
seam_blur: int = Field(
default=16, ge=0, description="The seam inpaint blur radius (px)"
)
seam_strength: float = Field(
default=0.75, gt=0, le=1, description="The seam inpaint strength"
)
seam_steps: int = Field(
default=30, ge=1, description="The number of steps to use for seam inpaint"
)
tile_size: int = Field(
default=32, ge=1, description="The tile infill method size (px)"
)
infill_method: INFILL_METHODS = Field(
default=DEFAULT_INFILL_METHOD,
description="The method used to infill empty regions (px)",
@@ -133,7 +128,10 @@ class InpaintInvocation(BaseInvocation):
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {"tags": ["stable-diffusion", "image"], "title": "Inpaint"},
"ui": {
"tags": ["stable-diffusion", "image"],
"title": "Inpaint"
},
}
def dispatch_progress(
@@ -164,23 +162,18 @@ class InpaintInvocation(BaseInvocation):
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
context=context,
)
**lora.dict(exclude={"weight"}), context=context,)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context,)
vae_info = context.services.model_manager.get_model(**self.vae.vae.dict(), context=context,)
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
context=context,
)
with vae_info as vae,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
with vae_info as vae, ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()), unet_info as unet:
device = context.services.model_manager.mgr.cache.execution_device
dtype = context.services.model_manager.mgr.cache.precision
@@ -204,11 +197,21 @@ class InpaintInvocation(BaseInvocation):
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = None if self.image is None else context.services.images.get_pil_image(self.image.image_name)
mask = None if self.mask is None else context.services.images.get_pil_image(self.mask.image_name)
image = (
None
if self.image is None
else context.services.images.get_pil_image(self.image.image_name)
)
mask = (
None
if self.mask is None
else context.services.images.get_pil_image(self.mask.image_name)
)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
scheduler = get_scheduler(

View File

@@ -2,27 +2,63 @@
from typing import Literal, Optional
import cv2
import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import Field
from pathlib import Path
from pydantic import BaseModel, Field
from typing import Union
from invokeai.app.invocations.metadata import CoreMetadata
from ..models.image import (
ImageCategory,
ImageField,
ResourceOrigin,
PILInvocationConfig,
ImageOutput,
MaskOutput,
)
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
InvocationConfig,
)
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
class PILInvocationConfig(BaseModel):
"""Helper class to provide all PIL invocations with additional config"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["PIL", "image"],
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
]
}
class LoadImageInvocation(BaseInvocation):
@@ -39,7 +75,10 @@ class LoadImageInvocation(BaseInvocation):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Load Image", "tags": ["image", "load"]},
"ui": {
"title": "Load Image",
"tags": ["image", "load"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -58,11 +97,16 @@ class ShowImageInvocation(BaseInvocation):
type: Literal["show_image"] = "show_image"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to show")
image: Optional[ImageField] = Field(
default=None, description="The image to show"
)
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Show Image", "tags": ["image", "show"]},
"ui": {
"title": "Show Image",
"tags": ["image", "show"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -95,13 +139,18 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Crop Image", "tags": ["image", "crop"]},
"ui": {
"title": "Crop Image",
"tags": ["image", "crop"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
image_crop = Image.new(mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0))
image_crop = Image.new(
mode="RGBA", size=(self.width, self.height), color=(0, 0, 0, 0)
)
image_crop.paste(image, (-self.x, -self.y))
image_dto = context.services.images.create(
@@ -136,15 +185,19 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Paste Image", "tags": ["image", "paste"]},
"ui": {
"title": "Paste Image",
"tags": ["image", "paste"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image(self.base_image.image_name)
image = context.services.images.get_pil_image(self.image.image_name)
mask = (
None if self.mask is None else ImageOps.invert(context.services.images.get_pil_image(self.mask.image_name))
)
mask = None
if self.mask is not None:
mask = context.services.images.get_pil_image(self.mask.image_name)
mask = ImageOps.invert(mask.convert("L"))
# TODO: probably shouldn't invert mask here... should user be required to do it?
min_x = min(0, self.x)
@@ -152,7 +205,9 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
max_x = max(base_image.width, image.width + self.x)
max_y = max(base_image.height, image.height + self.y)
new_image = Image.new(mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0))
new_image = Image.new(
mode="RGBA", size=(max_x - min_x, max_y - min_y), color=(0, 0, 0, 0)
)
new_image.paste(base_image, (abs(min_x), abs(min_y)))
new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask)
@@ -185,7 +240,10 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]},
"ui": {
"title": "Mask From Alpha",
"tags": ["image", "mask", "alpha"]
},
}
def invoke(self, context: InvocationContext) -> MaskOutput:
@@ -224,7 +282,10 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Multiply Images", "tags": ["image", "multiply"]},
"ui": {
"title": "Multiply Images",
"tags": ["image", "multiply"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -265,7 +326,10 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Image Channel", "tags": ["image", "channel"]},
"ui": {
"title": "Image Channel",
"tags": ["image", "channel"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -305,7 +369,10 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Convert Image", "tags": ["image", "convert"]},
"ui": {
"title": "Convert Image",
"tags": ["image", "convert"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -343,14 +410,19 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Blur Image", "tags": ["image", "blur"]},
"ui": {
"title": "Blur Image",
"tags": ["image", "blur"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
blur = (
ImageFilter.GaussianBlur(self.radius) if self.blur_type == "gaussian" else ImageFilter.BoxBlur(self.radius)
ImageFilter.GaussianBlur(self.radius)
if self.blur_type == "gaussian"
else ImageFilter.BoxBlur(self.radius)
)
blur_image = image.filter(blur)
@@ -405,7 +477,10 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Resize Image", "tags": ["image", "resize"]},
"ui": {
"title": "Resize Image",
"tags": ["image", "resize"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -448,7 +523,10 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Scale Image", "tags": ["image", "scale"]},
"ui": {
"title": "Scale Image",
"tags": ["image", "scale"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -493,7 +571,10 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]},
"ui": {
"title": "Image Linear Interpolation",
"tags": ["image", "linear", "interpolation", "lerp"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -536,7 +617,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
schema_extra = {
"ui": {
"title": "Image Inverse Linear Interpolation",
"tags": ["image", "linear", "interpolation", "inverse"],
"tags": ["image", "linear", "interpolation", "inverse"]
},
}
@@ -544,7 +625,12 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image = context.services.images.get_pil_image(self.image.image_name)
image_arr = numpy.asarray(image, dtype=numpy.float32)
image_arr = numpy.minimum(numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1) * 255
image_arr = (
numpy.minimum(
numpy.maximum(image_arr - self.min, 0) / float(self.max - self.min), 1
)
* 255
)
ilerp_image = Image.fromarray(numpy.uint8(image_arr))
@@ -563,86 +649,162 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
height=image_dto.height,
)
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
"""Add blur to NSFW-flagged images"""
class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
"""Applies an edge mask to an image"""
# fmt: off
type: Literal["img_nsfw"] = "img_nsfw"
type: Literal["mask_edge"] = "mask_edge"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to check")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
image: Optional[ImageField] = Field(default=None, description="The image to apply the mask to")
edge_size: int = Field(description="The size of the edge")
edge_blur: int = Field(description="The amount of blur on the edge")
low_threshold: int = Field(description="First threshold for the hysteresis procedure in Canny edge detection")
high_threshold: int = Field(description="Second threshold for the hysteresis procedure in Canny edge detection")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]},
}
def invoke(self, context: InvocationContext) -> MaskOutput:
mask = context.services.images.get_pil_image(self.image.image_name)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
npimg = numpy.asarray(mask, dtype=numpy.uint8)
npgradient = numpy.uint8(
255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0))
)
npedge = cv2.Canny(npimg, threshold1=self.low_threshold, threshold2=self.high_threshold)
npmask = npgradient + npedge
npmask = cv2.dilate(
npmask, numpy.ones((3, 3), numpy.uint8), iterations=int(self.edge_size / 2)
)
logger = context.services.logger
logger.debug("Running NSFW checker")
if SafetyChecker.has_nsfw_concept(image):
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = self._get_caution_img()
blurry_image.paste(caution, (0, 0), caution)
image = blurry_image
new_mask = Image.fromarray(npmask)
if self.edge_blur > 0:
new_mask = new_mask.filter(ImageFilter.BoxBlur(self.edge_blur))
new_mask = ImageOps.invert(new_mask)
image_dto = context.services.images.create(
image=image,
image=new_mask,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.MASK,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return MaskOutput(
mask=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["color_correct"] = "color_correct"
init: Optional[ImageField] = Field(default=None, description="Initial image")
result: Optional[ImageField] = Field(default=None, description="Resulted image")
mask: Optional[ImageField] = Field(default=None, description="Mask image")
mask_blur_radius: float = Field(default=8, description="Mask blur radius")
def invoke(self, context: InvocationContext) -> ImageOutput:
pil_init_mask = None
if self.mask is not None:
pil_init_mask = context.services.images.get_pil_image(
self.mask.image_name
).convert("L")
init_image = context.services.images.get_pil_image(
self.init.image_name
)
result = context.services.images.get_pil_image(
self.result.image_name
).convert("RGBA")
#if init_image is None or init_mask is None:
# return result
# Get the original alpha channel of the mask if there is one.
# Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
#pil_init_mask = (
# init_mask.getchannel("A")
# if init_mask.mode == "RGBA"
# else init_mask.convert("L")
#)
pil_init_image = init_image.convert(
"RGBA"
) # Add an alpha channel if one doesn't exist
# Build an image with only visible pixels from source to use as reference for color-matching.
init_rgb_pixels = numpy.asarray(init_image.convert("RGB"), dtype=numpy.uint8)
init_a_pixels = numpy.asarray(pil_init_image.getchannel("A"), dtype=numpy.uint8)
init_mask_pixels = numpy.asarray(pil_init_mask, dtype=numpy.uint8)
# Get numpy version of result
np_image = numpy.asarray(result.convert("RGB"), dtype=numpy.uint8)
# Mask and calculate mean and standard deviation
mask_pixels = init_a_pixels * init_mask_pixels > 0
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
np_image_masked = np_image[mask_pixels, :]
if np_init_rgb_pixels_masked.size > 0:
init_means = np_init_rgb_pixels_masked.mean(axis=0)
init_std = np_init_rgb_pixels_masked.std(axis=0)
gen_means = np_image_masked.mean(axis=0)
gen_std = np_image_masked.std(axis=0)
# Color correct
np_matched_result = np_image.copy()
np_matched_result[:, :, :] = (
(
(
(
np_matched_result[:, :, :].astype(numpy.float32)
- gen_means[None, None, :]
)
/ gen_std[None, None, :]
)
* init_std[None, None, :]
+ init_means[None, None, :]
)
.clip(0, 255)
.astype(numpy.uint8)
)
matched_result = Image.fromarray(np_matched_result, mode="RGB")
else:
matched_result = Image.fromarray(np_image, mode="RGB")
# Blur the mask out (into init image) by specified amount
if self.mask_blur_radius > 0:
nm = numpy.asarray(pil_init_mask, dtype=numpy.uint8)
nmd = cv2.erode(
nm,
kernel=numpy.ones((3, 3), dtype=numpy.uint8),
iterations=int(self.mask_blur_radius / 2),
)
pmd = Image.fromarray(nmd, mode="L")
blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(self.mask_blur_radius))
else:
blurred_init_mask = pil_init_mask
multiplied_blurred_init_mask = ImageChops.multiply(
blurred_init_mask, result.split()[-1]
)
# Paste original on color-corrected generation (using blurred mask)
matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)
image_dto = context.services.images.create(
image=matched_result,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
def _get_caution_img(self) -> Image:
import invokeai.app.assets.images as image_assets
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
return caution.resize((caution.width // 2, caution.height // 2))
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
"""Add an invisible watermark to an image"""
# fmt: off
type: Literal["img_watermark"] = "img_watermark"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to check")
text: str = Field(default='InvokeAI', description="Watermark text")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
new_image = InvisibleWatermark.add_watermark(image, self.text)
image_dto = context.services.images.create(
image=new_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
)
return ImageOutput(

View File

@@ -30,7 +30,9 @@ def infill_methods() -> list[str]:
INFILL_METHODS = Literal[tuple(infill_methods())]
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
DEFAULT_INFILL_METHOD = (
"patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"
)
def infill_patchmatch(im: Image.Image) -> Image.Image:
@@ -42,7 +44,9 @@ def infill_patchmatch(im: Image.Image) -> Image.Image:
return im
# Patchmatch (note, we may want to expose patch_size? Increasing it significantly impacts performance though)
im_patched_np = PatchMatch.inpaint(im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3)
im_patched_np = PatchMatch.inpaint(
im.convert("RGB"), ImageOps.invert(im.split()[-1]), patch_size=3
)
im_patched = Image.fromarray(im_patched_np, mode="RGB")
return im_patched
@@ -64,7 +68,9 @@ def get_tile_images(image: np.ndarray, width=8, height=8):
)
def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
def tile_fill_missing(
im: Image.Image, tile_size: int = 16, seed: Optional[int] = None
) -> Image.Image:
# Only fill if there's an alpha layer
if im.mode != "RGBA":
return im
@@ -97,7 +103,9 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
# Find all invalid tiles and replace with a random valid tile
replace_count = (tiles_mask == False).sum()
rng = np.random.default_rng(seed=seed)
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]
tiles_all[np.logical_not(tiles_mask)] = filtered_tiles[
rng.choice(filtered_tiles.shape[0], replace_count), :, :, :
]
# Convert back to an image
tiles_all = tiles_all.reshape(tshape)
@@ -118,7 +126,9 @@ class InfillColorInvocation(BaseInvocation):
"""Infills transparent areas of an image with a solid color"""
type: Literal["infill_rgba"] = "infill_rgba"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
image: Optional[ImageField] = Field(
default=None, description="The image to infill"
)
color: ColorField = Field(
default=ColorField(r=127, g=127, b=127, a=255),
description="The color to use to infill",
@@ -126,7 +136,10 @@ class InfillColorInvocation(BaseInvocation):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]},
"ui": {
"title": "Color Infill",
"tags": ["image", "inpaint", "color", "infill"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -158,7 +171,9 @@ class InfillTileInvocation(BaseInvocation):
type: Literal["infill_tile"] = "infill_tile"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
image: Optional[ImageField] = Field(
default=None, description="The image to infill"
)
tile_size: int = Field(default=32, ge=1, description="The tile size (px)")
seed: int = Field(
ge=0,
@@ -169,13 +184,18 @@ class InfillTileInvocation(BaseInvocation):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]},
"ui": {
"title": "Tile Infill",
"tags": ["image", "inpaint", "tile", "infill"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
infilled = tile_fill_missing(
image.copy(), seed=self.seed, tile_size=self.tile_size
)
infilled.paste(image, (0, 0), image.split()[-1])
image_dto = context.services.images.create(
@@ -199,11 +219,16 @@ class InfillPatchMatchInvocation(BaseInvocation):
type: Literal["infill_patchmatch"] = "infill_patchmatch"
image: Optional[ImageField] = Field(default=None, description="The image to infill")
image: Optional[ImageField] = Field(
default=None, description="The image to infill"
)
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]},
"ui": {
"title": "Patch Match Infill",
"tags": ["image", "inpaint", "patchmatch", "infill"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:

View File

@@ -3,6 +3,8 @@
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import torchvision.transforms as T
from torchvision.transforms.functional import resize as tv_resize
import einops
import torch
from diffusers import ControlNetModel
@@ -12,22 +14,20 @@ from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.model_management.models.base import ModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,
ControlNetData,
StableDiffusionGeneratorPipeline,
image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
ConditioningData, ControlNetData, StableDiffusionGeneratorPipeline,
image_resized_to_grid_as_tensor)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import \
PostprocessingSettings
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.model_management import ModelPatcher
from ...backend.util.devices import choose_torch_device, torch_dtype, choose_precision
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
@@ -48,7 +48,8 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
class LatentsField(BaseModel):
"""A latents field used for passing latents between invocations"""
latents_name: Optional[str] = Field(default=None, description="The name of the latents")
latents_name: Optional[str] = Field(
default=None, description="The name of the latents")
class Config:
schema_extra = {"required": ["latents_name"]}
@@ -56,15 +57,14 @@ class LatentsField(BaseModel):
class LatentsOutput(BaseInvocationOutput):
"""Base class for invocations that output latents"""
# fmt: off
#fmt: off
type: Literal["latents_output"] = "latents_output"
# Inputs
latents: LatentsField = Field(default=None, description="The output latents")
width: int = Field(description="The width of the latents in pixels")
height: int = Field(description="The height of the latents in pixels")
# fmt: on
#fmt: on
def build_latents_output(latents_name: str, latents: torch.Tensor):
@@ -75,7 +75,9 @@ def build_latents_output(latents_name: str, latents: torch.Tensor):
)
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
SAMPLER_NAME_VALUES = Literal[
tuple(list(SCHEDULER_MAP.keys()))
]
def get_scheduler(
@@ -83,10 +85,11 @@ def get_scheduler(
scheduler_info: ModelInfo,
scheduler_name: str,
) -> Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(
scheduler_name, SCHEDULER_MAP['ddim']
)
orig_scheduler_info = context.services.model_manager.get_model(
**scheduler_info.dict(),
context=context,
**scheduler_info.dict(), context=context,
)
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@@ -101,7 +104,7 @@ def get_scheduler(
scheduler = scheduler_class.from_config(scheduler_config)
# hack copied over from generate.py
if not hasattr(scheduler, "uses_inpainting_model"):
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
return scheduler
@@ -122,8 +125,8 @@ class TextToLatentsInvocation(BaseInvocation):
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
unet: UNetField = Field(default=None, description="UNet submodel")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
@@ -132,10 +135,10 @@ class TextToLatentsInvocation(BaseInvocation):
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
raise ValueError('cfg_scale must be greater than 1')
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
raise ValueError('cfg_scale must be greater than 1')
return v
# Schema customisation
@@ -148,8 +151,8 @@ class TextToLatentsInvocation(BaseInvocation):
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number",
},
"cfg_scale": "number"
}
},
}
@@ -189,14 +192,16 @@ class TextToLatentsInvocation(BaseInvocation):
threshold=0.0, # threshold,
warmup=0.2, # warmup,
h_symmetry_time_pct=None, # h_symmetry_time_pct,
v_symmetry_time_pct=None, # v_symmetry_time_pct,
v_symmetry_time_pct=None # v_symmetry_time_pct,
),
)
conditioning_data = conditioning_data.add_scheduler_args_if_applicable(
scheduler,
# for ddim scheduler
eta=0.0, # ddim_eta
# for ancestral and sde schedulers
generator=torch.Generator(device=unet.device).manual_seed(0),
)
@@ -244,6 +249,7 @@ class TextToLatentsInvocation(BaseInvocation):
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
# assuming fixed dimensional scaling of 8:1 for image:latents
control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8
@@ -257,7 +263,7 @@ class TextToLatentsInvocation(BaseInvocation):
control_list = control_input
else:
control_list = None
if control_list is None:
if (control_list is None):
control_data = None
# from above handling, any control that is not None should now be of type list[ControlField]
else:
@@ -277,7 +283,9 @@ class TextToLatentsInvocation(BaseInvocation):
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
input_image = context.services.images.get_pil_image(
control_image_field.image_name
)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
@@ -312,71 +320,69 @@ class TextToLatentsInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
with SilenceWarnings():
noise = context.services.latents.get(self.noise.latents_name)
noise = context.services.latents.get(self.noise.latents_name)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}), context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(), context=context,
)
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
control_data = self.prep_control_data(
model=pipeline,
context=context,
control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(unet.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
class LatentsToLatentsInvocation(TextToLatentsInvocation):
@@ -385,8 +391,14 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
type: Literal["l2l"] = "l2l"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to use as a base image")
strength: float = Field(default=0.7, ge=0, le=1, description="The strength of the latents to use")
latents: Optional[LatentsField] = Field(
description="The latents to use as a base image")
strength: float = Field(
default=0.7, ge=0, le=1,
description="The strength of the latents to use")
mask: Optional[ImageField] = Field(
None, description="Mask",
)
# Schema customisation
class Config(InvocationConfig):
@@ -398,89 +410,113 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
"model": "model",
"control": "control",
"cfg_scale": "number",
},
}
},
}
def prep_mask_tensor(self, context, lantents):
if self.mask is None:
return None
mask_image = context.services.images.get_pil_image(self.mask.image_name)
if mask_image.mode != "L":
# FIXME: why do we get passed an RGB image here? We can only use single-channel.
mask_image = mask_image.convert("L")
mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
mask_tensor = tv_resize(
mask_tensor, lantents.shape[-2:], T.InterpolationMode.BILINEAR
)
return mask_tensor
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
with SilenceWarnings(): # this quenches NSFW nag from diffusers
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
noise = context.services.latents.get(self.noise.latents_name)
latent = context.services.latents.get(self.latents.latents_name)
mask = self.prep_mask_tensor(context, latent)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
def step_callback(state: PipelineIntermediateState):
self.dispatch_progress(context, source_node_id, state)
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}),
context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
def _lora_loader():
for lora in self.unet.loras:
lora_info = context.services.model_manager.get_model(
**lora.dict(exclude={"weight"}), context=context,
)
yield (lora_info.context.model, lora.weight)
del lora_info
return
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(), context=context,
)
with ExitStack() as exit_stack,\
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
latent = latent.to(device=unet.device, dtype=unet.dtype)
mask = mask.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), unet_info as unet:
noise = noise.to(device=unet.device, dtype=unet.dtype)
latent = latent.to(device=unet.device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)
control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
control_data = self.prep_control_data(
model=pipeline,
context=context,
control_input=self.control,
latents_shape=noise.shape,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# TODO: Verify the noise is the right size
initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
latent, device=unet.device, dtype=latent.dtype
)
# TODO: Verify the noise is the right size
initial_latents = (
latent if self.strength < 1.0 else torch.zeros_like(latent, device=unet.device, dtype=latent.dtype)
)
timesteps, _ = pipeline.get_img2img_timesteps(
self.steps,
self.strength,
device=unet.device,
)
timesteps, _ = pipeline.get_img2img_timesteps(
self.steps,
self.strength,
device=unet.device,
)
def _apply_mask_on_step(step_output, timestep, conditioning_data):
noised_init = scheduler.add_noise(initial_latents, noise, timestep.unsqueeze(0))
step_output.prev_sample = step_output.prev_sample * (1 - mask) + noised_init * mask
return step_output
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)
additional_guidance = []
if mask is not None:
additional_guidance.append(_apply_mask_on_step)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
latents=initial_latents,
timesteps=timesteps,
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback,
additional_guidance=additional_guidance,
)
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, result_latents)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, result_latents)
return build_latents_output(latents_name=name, latents=result_latents)
@@ -491,13 +527,14 @@ class LatentsToImageInvocation(BaseInvocation):
type: Literal["l2i"] = "l2i"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
latents: Optional[LatentsField] = Field(
description="The latents to generate an image from")
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)")
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
metadata: Optional[CoreMetadata] = Field(
default=None, description="Optional core metadata to be written to the image"
)
tiled: bool = Field(
default=False,
description="Decode latents by overlaping tiles(less memory consumption)")
fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# Schema customisation
class Config(InvocationConfig):
@@ -513,8 +550,7 @@ class LatentsToImageInvocation(BaseInvocation):
latents = context.services.latents.get(self.latents.latents_name)
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
context=context,
**self.vae.vae.dict(), context=context,
)
with vae_info as vae:
@@ -581,7 +617,8 @@ class LatentsToImageInvocation(BaseInvocation):
)
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear",
"bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
class ResizeLatentsInvocation(BaseInvocation):
@@ -590,30 +627,36 @@ class ResizeLatentsInvocation(BaseInvocation):
type: Literal["lresize"] = "lresize"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to resize")
width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)")
height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
latents: Optional[LatentsField] = Field(
description="The latents to resize")
width: Union[int, None] = Field(default=512,
ge=64, multiple_of=8, description="The width to resize to (px)")
height: Union[int, None] = Field(default=512,
ge=64, multiple_of=8, description="The height to resize to (px)")
mode: LATENTS_INTERPOLATION_MODE = Field(
default="bilinear", description="The interpolation mode")
antialias: bool = Field(
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
)
default=False,
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Resize Latents", "tags": ["latents", "resize"]},
"ui": {
"title": "Resize Latents",
"tags": ["latents", "resize"]
},
}
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO:
device = choose_torch_device()
device=choose_torch_device()
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
size=(self.height // 8, self.width // 8),
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
latents.to(device), size=(self.height // 8, self.width // 8),
mode=self.mode, antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
@@ -632,30 +675,35 @@ class ScaleLatentsInvocation(BaseInvocation):
type: Literal["lscale"] = "lscale"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the latents")
mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode")
latents: Optional[LatentsField] = Field(
description="The latents to scale")
scale_factor: float = Field(
gt=0, description="The factor by which to scale the latents")
mode: LATENTS_INTERPOLATION_MODE = Field(
default="bilinear", description="The interpolation mode")
antialias: bool = Field(
default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)"
)
default=False,
description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Scale Latents", "tags": ["latents", "scale"]},
"ui": {
"title": "Scale Latents",
"tags": ["latents", "scale"]
},
}
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)
# TODO:
device = choose_torch_device()
device=choose_torch_device()
# resizing
resized_latents = torch.nn.functional.interpolate(
latents.to(device),
scale_factor=self.scale_factor,
mode=self.mode,
antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
latents.to(device), scale_factor=self.scale_factor, mode=self.mode,
antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,
)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
@@ -676,13 +724,19 @@ class ImageToLatentsInvocation(BaseInvocation):
# Inputs
image: Optional[ImageField] = Field(description="The image to encode")
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)")
fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision")
tiled: bool = Field(
default=False,
description="Encode latents by overlaping tiles(less memory consumption)")
fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Image To Latents", "tags": ["latents", "image"]},
"ui": {
"title": "Image To Latents",
"tags": ["latents", "image"]
},
}
@torch.no_grad()
@@ -692,10 +746,9 @@ class ImageToLatentsInvocation(BaseInvocation):
# )
image = context.services.images.get_pil_image(self.image.image_name)
# vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
#vae_info = context.services.model_manager.get_model(**self.vae.vae.dict())
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
context=context,
**self.vae.vae.dict(), context=context,
)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
@@ -722,12 +775,12 @@ class ImageToLatentsInvocation(BaseInvocation):
vae.post_quant_conv.to(orig_dtype)
vae.decoder.conv_in.to(orig_dtype)
vae.decoder.mid_block.to(orig_dtype)
# else:
#else:
# latents = latents.float()
else:
vae.to(dtype=torch.float16)
# latents = latents.half()
#latents = latents.half()
if self.tiled:
vae.enable_tiling()
@@ -738,7 +791,9 @@ class ImageToLatentsInvocation(BaseInvocation):
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
latents = image_tensor_dist.sample().to(dtype=vae.dtype) # FIXME: uses torch.randn. make reproducible!
latents = image_tensor_dist.sample().to(
dtype=vae.dtype
) # FIXME: uses torch.randn. make reproducible!
latents = vae.config.scaling_factor * latents
latents = latents.to(dtype=orig_dtype)

View File

@@ -54,7 +54,10 @@ class AddInvocation(BaseInvocation, MathInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Add", "tags": ["math", "add"]},
"ui": {
"title": "Add",
"tags": ["math", "add"]
},
}
def invoke(self, context: InvocationContext) -> IntOutput:
@@ -72,7 +75,10 @@ class SubtractInvocation(BaseInvocation, MathInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Subtract", "tags": ["math", "subtract"]},
"ui": {
"title": "Subtract",
"tags": ["math", "subtract"]
},
}
def invoke(self, context: InvocationContext) -> IntOutput:
@@ -90,7 +96,10 @@ class MultiplyInvocation(BaseInvocation, MathInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Multiply", "tags": ["math", "multiply"]},
"ui": {
"title": "Multiply",
"tags": ["math", "multiply"]
},
}
def invoke(self, context: InvocationContext) -> IntOutput:
@@ -108,7 +117,10 @@ class DivideInvocation(BaseInvocation, MathInvocationConfig):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Divide", "tags": ["math", "divide"]},
"ui": {
"title": "Divide",
"tags": ["math", "divide"]
},
}
def invoke(self, context: InvocationContext) -> IntOutput:
@@ -128,7 +140,10 @@ class RandomIntInvocation(BaseInvocation):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]},
"ui": {
"title": "Random Integer",
"tags": ["math", "random", "integer"]
},
}
def invoke(self, context: InvocationContext) -> IntOutput:

View File

@@ -2,19 +2,16 @@ from typing import Literal, Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationConfig,
InvocationContext,
)
from invokeai.app.invocations.baseinvocation import (BaseInvocation,
BaseInvocationOutput, InvocationConfig,
InvocationContext)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.model import (LoRAModelField, MainModelField,
VAEModelField)
class LoRAMetadataField(BaseModel):
"""LoRA metadata for an image generated in InvokeAI."""
lora: LoRAModelField = Field(description="The LoRA model")
weight: float = Field(description="The weight of the LoRA model")
@@ -22,9 +19,7 @@ class LoRAMetadataField(BaseModel):
class CoreMetadata(BaseModel):
"""Core generation metadata for an image generated in InvokeAI."""
generation_mode: str = Field(
description="The generation mode that output this image",
)
generation_mode: str = Field(description="The generation mode that output this image",)
positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter")
@@ -34,40 +29,21 @@ class CoreMetadata(BaseModel):
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(
description="The number of skipped CLIP layers",
)
clip_skip: int = Field(description="The number of skipped CLIP layers",)
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
vae: Union[VAEModelField, None] = Field(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
# Latents-to-Latents
strength: Union[float, None] = Field(
default=None,
description="The strength used for latents-to-latents",
)
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
# SDXL
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
# SDXL Refiner
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
refiner_cfg_scale: Union[float, None] = Field(
init_image: Union[str, None] = Field(
default=None, description="The name of the initial image"
)
vae: Union[VAEModelField, None] = Field(
default=None,
description="The classifier-free guidance scale parameter used for the refiner",
description="The VAE used for decoding, if the main model's default was not used",
)
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
refiner_aesthetic_store: Union[float, None] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
class ImageMetadata(BaseModel):
@@ -77,7 +53,9 @@ class ImageMetadata(BaseModel):
default=None,
description="The image's core metadata, if it was created in the Linear or Canvas UI",
)
graph: Optional[dict] = Field(default=None, description="The graph that created the image")
graph: Optional[dict] = Field(
default=None, description="The graph that created the image"
)
class MetadataAccumulatorOutput(BaseInvocationOutput):
@@ -93,9 +71,7 @@ class MetadataAccumulatorInvocation(BaseInvocation):
type: Literal["metadata_accumulator"] = "metadata_accumulator"
generation_mode: str = Field(
description="The generation mode that output this image",
)
generation_mode: str = Field(description="The generation mode that output this image",)
positive_prompt: str = Field(description="The positive prompt parameter")
negative_prompt: str = Field(description="The negative prompt parameter")
width: int = Field(description="The width parameter")
@@ -105,48 +81,52 @@ class MetadataAccumulatorInvocation(BaseInvocation):
cfg_scale: float = Field(description="The classifier-free guidance scale parameter")
steps: int = Field(description="The number of steps used for inference")
scheduler: str = Field(description="The scheduler used for inference")
clip_skip: int = Field(
description="The number of skipped CLIP layers",
)
clip_skip: int = Field(description="The number of skipped CLIP layers",)
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
controlnets: list[ControlField]= Field(description="The ControlNets used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
strength: Union[float, None] = Field(
default=None,
description="The strength used for latents-to-latents",
)
init_image: Union[str, None] = Field(default=None, description="The name of the initial image")
init_image: Union[str, None] = Field(
default=None, description="The name of the initial image"
)
vae: Union[VAEModelField, None] = Field(
default=None,
description="The VAE used for decoding, if the main model's default was not used",
)
# SDXL
positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter")
negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter")
# SDXL Refiner
refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used")
refiner_cfg_scale: Union[float, None] = Field(
default=None,
description="The classifier-free guidance scale parameter used for the refiner",
)
refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner")
refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner")
refiner_aesthetic_store: Union[float, None] = Field(
default=None, description="The aesthetic score used for the refiner"
)
refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising")
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Metadata Accumulator",
"tags": ["image", "metadata", "generation"],
"tags": ["image", "metadata", "generation"]
},
}
def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput:
"""Collects and outputs a CoreMetadata object"""
return MetadataAccumulatorOutput(metadata=CoreMetadata(**self.dict()))
return MetadataAccumulatorOutput(
metadata=CoreMetadata(
generation_mode=self.generation_mode,
positive_prompt=self.positive_prompt,
negative_prompt=self.negative_prompt,
width=self.width,
height=self.height,
seed=self.seed,
rand_device=self.rand_device,
cfg_scale=self.cfg_scale,
steps=self.steps,
scheduler=self.scheduler,
model=self.model,
strength=self.strength,
init_image=self.init_image,
vae=self.vae,
controlnets=self.controlnets,
loras=self.loras,
clip_skip=self.clip_skip,
)
)

View File

@@ -4,14 +4,17 @@ from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from ...backend.model_management import BaseModelType, ModelType, SubModelType
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Info to load submodel")
submodel: Optional[SubModelType] = Field(default=None, description="Info to load submodel")
submodel: Optional[SubModelType] = Field(
default=None, description="Info to load submodel"
)
class LoraInfo(ModelInfo):
@@ -30,7 +33,6 @@ class ClipField(BaseModel):
skipped_layers: int = Field(description="Number of skipped layers in text_encoder")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
@@ -47,13 +49,11 @@ class ModelLoaderOutput(BaseInvocationOutput):
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
class MainModelField(BaseModel):
"""Main model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")
class LoRAModelField(BaseModel):
@@ -62,7 +62,6 @@ class LoRAModelField(BaseModel):
model_name: str = Field(description="Name of the LoRA model")
base_model: BaseModelType = Field(description="Base model")
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
@@ -181,7 +180,7 @@ class MainModelLoaderInvocation(BaseInvocation):
),
)
class LoraLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
@@ -198,7 +197,9 @@ class LoraLoaderInvocation(BaseInvocation):
type: Literal["lora_loader"] = "lora_loader"
lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name")
lora: Union[LoRAModelField, None] = Field(
default=None, description="Lora model name"
)
weight: float = Field(default=0.75, description="With what weight to apply lora")
unet: Optional[UNetField] = Field(description="UNet model for applying lora")
@@ -227,10 +228,14 @@ class LoraLoaderInvocation(BaseInvocation):
):
raise Exception(f"Unkown lora name: {lora_name}!")
if self.unet is not None and any(lora.model_name == lora_name for lora in self.unet.loras):
if self.unet is not None and any(
lora.model_name == lora_name for lora in self.unet.loras
):
raise Exception(f'Lora "{lora_name}" already applied to unet')
if self.clip is not None and any(lora.model_name == lora_name for lora in self.clip.loras):
if self.clip is not None and any(
lora.model_name == lora_name for lora in self.clip.loras
):
raise Exception(f'Lora "{lora_name}" already applied to clip')
output = LoraLoaderOutput()

View File

@@ -119,8 +119,8 @@ class NoiseInvocation(BaseInvocation):
@validator("seed", pre=True)
def modulo_seed(cls, v):
"""Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
return v % (SEED_MAX + 1)
"""Returns the seed modulo SEED_MAX to ensure it is within the valid range."""
return v % SEED_MAX
def invoke(self, context: InvocationContext) -> NoiseOutput:
noise = get_noise(

View File

@@ -1,578 +0,0 @@
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import re
import inspect
from pydantic import BaseModel, Field, validator
import torch
import numpy as np
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management import ONNXModelPatcher
from ...backend.util import choose_torch_device
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.backend import BaseModelType, ModelType, SubModelType
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from ...backend.stable_diffusion import PipelineIntermediateState
from tqdm import tqdm
from .model import ClipField
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
from .compel import CompelOutput
ORT_TO_NP_TYPE = {
"tensor(bool)": np.bool_,
"tensor(int8)": np.int8,
"tensor(uint8)": np.uint8,
"tensor(int16)": np.int16,
"tensor(uint16)": np.uint16,
"tensor(int32)": np.int32,
"tensor(uint32)": np.uint32,
"tensor(int64)": np.int64,
"tensor(uint64)": np.uint64,
"tensor(float16)": np.float16,
"tensor(float)": np.float32,
"tensor(double)": np.float64,
}
PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
class ONNXPromptInvocation(BaseInvocation):
type: Literal["prompt_onnx"] = "prompt_onnx"
prompt: str = Field(default="", description="Prompt")
clip: ClipField = Field(None, description="Clip to use")
def invoke(self, context: InvocationContext) -> CompelOutput:
tokenizer_info = context.services.model_manager.get_model(
**self.clip.tokenizer.dict(),
)
text_encoder_info = context.services.model_manager.get_model(
**self.clip.text_encoder.dict(),
)
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
loras = [
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
for lora in self.clip.loras
]
ti_list = []
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
# stack.enter_context(
# context.services.model_manager.get_model(
# model_name=name,
# base_model=self.clip.text_encoder.base_model,
# model_type=ModelType.TextualInversion,
# )
# )
context.services.model_manager.get_model(
model_name=name,
base_model=self.clip.text_encoder.base_model,
model_type=ModelType.TextualInversion,
).context.model
)
except Exception:
# print(e)
# import traceback
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
if loras or ti_list:
text_encoder.release_session()
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti(
orig_tokenizer, text_encoder, ti_list
) as (tokenizer, ti_manager):
text_encoder.create_session()
# copy from
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153
text_inputs = tokenizer(
self.prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input_ids = text_inputs.input_ids
"""
untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
if not np.array_equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
"""
prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (prompt_embeds, None))
return CompelOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)
# Text to image
class ONNXTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
type: Literal["t2l_onnx"] = "t2l_onnx"
# Inputs
# fmt: off
positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
noise: Optional[LatentsField] = Field(description="The noise to use")
steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
unet: UNetField = Field(default=None, description="UNet submodel")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model",
"control": "control",
# "cfg_scale": "float",
"cfg_scale": "number",
},
},
}
# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
def invoke(self, context: InvocationContext) -> LatentsOutput:
c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
if isinstance(c, torch.Tensor):
c = c.cpu().numpy()
if isinstance(uc, torch.Tensor):
uc = uc.cpu().numpy()
device = torch.device(choose_torch_device())
prompt_embeds = np.concatenate([uc, c])
latents = context.services.latents.get(self.noise.latents_name)
if isinstance(latents, torch.Tensor):
latents = latents.cpu().numpy()
# TODO: better execution device handling
latents = latents.astype(ORT_TO_NP_TYPE[self.precision])
# get the initial random noise unless the user supplied it
do_classifier_free_guidance = True
# latents_dtype = prompt_embeds.dtype
# latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
# if latents.shape != latents_shape:
# raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)
def torch2numpy(latent: torch.Tensor):
return latent.cpu().numpy()
def numpy2torch(latent, device):
return torch.from_numpy(latent).to(device)
def dispatch_progress(
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.dict(),
source_node_id=source_node_id,
)
scheduler.set_timesteps(self.steps)
latents = latents * np.float64(scheduler.init_noise_sigma)
extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
eta=0.0,
)
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
with unet_info as unet, ExitStack() as stack:
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
loras = [
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
for lora in self.unet.loras
]
if loras:
unet.release_session()
with ONNXModelPatcher.apply_lora_unet(unet, loras):
# TODO:
_, _, h, w = latents.shape
unet.create_session(h, w)
timestep_dtype = next(
(input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
for i in tqdm(range(len(scheduler.timesteps))):
t = scheduler.timesteps[i]
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
latent_model_input = latent_model_input.cpu().numpy()
# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
noise_pred = noise_pred[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = scheduler.step(
numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
)
latents = torch2numpy(scheduler_output.prev_sample)
state = PipelineIntermediateState(
run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample
)
dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state)
# call the callback, if provided
# if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))
# Latent to image
class ONNXLatentsToImageInvocation(BaseInvocation):
"""Generates an image from latents."""
type: Literal["l2i_onnx"] = "l2i_onnx"
# Inputs
latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
vae: VaeField = Field(default=None, description="Vae submodel")
metadata: Optional[CoreMetadata] = Field(
default=None, description="Optional core metadata to be written to the image"
)
# tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["latents", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)
if self.vae.vae.submodel != SubModelType.VaeDecoder:
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")
vae_info = context.services.model_manager.get_model(
**self.vae.vae.dict(),
)
# clear memory as vae decode can request a lot
torch.cuda.empty_cache()
with vae_info as vae:
vae.create_session()
# copied from
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])])
image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
image = VaeImageProcessor.numpy_to_pil(image)[0]
torch.cuda.empty_cache()
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
class ONNXModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""
# fmt: off
type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"
unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae_decoder: VaeField = Field(default=None, description="Vae submodel")
vae_encoder: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
class ONNXSD1ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model."""
type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"
model_name: str = Field(default="", description="Model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}}, # TODO: rename to model_name?
}
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
model_name = "stable-diffusion-v1-5"
base_model = BaseModelType.StableDiffusion1
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.ONNX,
):
raise Exception(f"Unkown model name: {model_name}!")
return ONNXModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae_decoder=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.VaeDecoder,
),
),
vae_encoder=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=ModelType.ONNX,
submodel=SubModelType.VaeEncoder,
),
),
)
class OnnxModelField(BaseModel):
"""Onnx model field"""
model_name: str = Field(description="Name of the model")
base_model: BaseModelType = Field(description="Base model")
model_type: ModelType = Field(description="Model Type")
class OnnxModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
type: Literal["onnx_model_loader"] = "onnx_model_loader"
model: OnnxModelField = Field(description="The model to load")
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"title": "Onnx Model Loader",
"tags": ["model", "loader"],
"type_hints": {"model": "model"},
},
}
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
base_model = self.model.base_model
model_name = self.model.model_name
model_type = ModelType.ONNX
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
)
"""
return ONNXModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
vae_decoder=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.VaeDecoder,
),
),
vae_encoder=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.VaeEncoder,
),
),
)

View File

@@ -12,37 +12,16 @@ import matplotlib.pyplot as plt
from easing_functions import (
LinearInOut,
QuadEaseInOut,
QuadEaseIn,
QuadEaseOut,
CubicEaseInOut,
CubicEaseIn,
CubicEaseOut,
QuarticEaseInOut,
QuarticEaseIn,
QuarticEaseOut,
QuinticEaseInOut,
QuinticEaseIn,
QuinticEaseOut,
SineEaseInOut,
SineEaseIn,
SineEaseOut,
CircularEaseIn,
CircularEaseInOut,
CircularEaseOut,
ExponentialEaseInOut,
ExponentialEaseIn,
ExponentialEaseOut,
ElasticEaseIn,
ElasticEaseInOut,
ElasticEaseOut,
BackEaseIn,
BackEaseInOut,
BackEaseOut,
BounceEaseIn,
BounceEaseInOut,
BounceEaseOut,
)
QuadEaseInOut, QuadEaseIn, QuadEaseOut,
CubicEaseInOut, CubicEaseIn, CubicEaseOut,
QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
SineEaseInOut, SineEaseIn, SineEaseOut,
CircularEaseIn, CircularEaseInOut, CircularEaseOut,
ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
BackEaseIn, BackEaseInOut, BackEaseOut,
BounceEaseIn, BounceEaseInOut, BounceEaseOut)
from .baseinvocation import (
BaseInvocation,
@@ -66,12 +45,17 @@ class FloatLinearRangeInvocation(BaseInvocation):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]},
"ui": {
"title": "Linear Range (Float)",
"tags": ["math", "float", "linear", "range"]
},
}
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
param_list = list(np.linspace(self.start, self.stop, self.steps))
return FloatCollectionOutput(collection=param_list)
return FloatCollectionOutput(
collection=param_list
)
EASING_FUNCTIONS_MAP = {
@@ -108,7 +92,9 @@ EASING_FUNCTIONS_MAP = {
"BounceInOut": BounceEaseInOut,
}
EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))]
EASING_FUNCTION_KEYS: Any = Literal[
tuple(list(EASING_FUNCTIONS_MAP.keys()))
]
# actually I think for now could just use CollectionOutput (which is list[Any]
@@ -137,9 +123,13 @@ class StepParamEasingInvocation(BaseInvocation):
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]},
"ui": {
"title": "Param Easing By Step",
"tags": ["param", "step", "easing"]
},
}
def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
log_diagnostics = False
# convert from start_step_percent to nearest step <= (steps * start_step_percent)
@@ -180,13 +170,12 @@ class StepParamEasingInvocation(BaseInvocation):
# and create reverse copy of list[1:end-1]
# but if even then number_of_steps/2 === ceil(number_of_steps/2), so can just use ceil always
base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
if log_diagnostics:
context.services.logger.debug("base easing duration: " + str(base_easing_duration))
even_num_steps = num_easing_steps % 2 == 0 # even number of steps
easing_function = easing_class(
start=self.start_value, end=self.end_value, duration=base_easing_duration - 1
)
base_easing_duration = int(np.ceil(num_easing_steps/2.0))
if log_diagnostics: context.services.logger.debug("base easing duration: " + str(base_easing_duration))
even_num_steps = (num_easing_steps % 2 == 0) # even number of steps
easing_function = easing_class(start=self.start_value,
end=self.end_value,
duration=base_easing_duration - 1)
base_easing_vals = list()
for step_index in range(base_easing_duration):
easing_val = easing_function.ease(step_index)
@@ -225,7 +214,9 @@ class StepParamEasingInvocation(BaseInvocation):
#
else: # no mirroring (default)
easing_function = easing_class(start=self.start_value, end=self.end_value, duration=num_easing_steps - 1)
easing_function = easing_class(start=self.start_value,
end=self.end_value,
duration=num_easing_steps - 1)
for step_index in range(num_easing_steps):
step_val = easing_function.ease(step_index)
easing_list.append(step_val)
@@ -249,11 +240,13 @@ class StepParamEasingInvocation(BaseInvocation):
ax = plt.gca()
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
buf = io.BytesIO()
plt.savefig(buf, format="png")
plt.savefig(buf, format='png')
buf.seek(0)
im = PIL.Image.open(buf)
im.show()
buf.close()
# output array of size steps, each entry list[i] is param value for step i
return FloatCollectionOutput(collection=param_list)
return FloatCollectionOutput(
collection=param_list
)

View File

@@ -4,80 +4,67 @@ from typing import Literal
from pydantic import Field
from invokeai.app.invocations.prompt import PromptOutput
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .math import FloatOutput, IntOutput
# Pass-through parameter nodes - used by subgraphs
class ParamIntInvocation(BaseInvocation):
"""An integer parameter"""
# fmt: off
#fmt: off
type: Literal["param_int"] = "param_int"
a: int = Field(default=0, description="The integer value")
# fmt: on
#fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"tags": ["param", "integer"], "title": "Integer Parameter"},
}
schema_extra = {
"ui": {
"tags": ["param", "integer"],
"title": "Integer Parameter"
},
}
def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a)
class ParamFloatInvocation(BaseInvocation):
"""A float parameter"""
# fmt: off
#fmt: off
type: Literal["param_float"] = "param_float"
param: float = Field(default=0.0, description="The float value")
# fmt: on
#fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"tags": ["param", "float"], "title": "Float Parameter"},
}
schema_extra = {
"ui": {
"tags": ["param", "float"],
"title": "Float Parameter"
},
}
def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(param=self.param)
class StringOutput(BaseInvocationOutput):
"""A string output"""
type: Literal["string_output"] = "string_output"
text: str = Field(default=None, description="The output string")
class ParamStringInvocation(BaseInvocation):
"""A string parameter"""
type: Literal["param_string"] = "param_string"
text: str = Field(default="", description="The string value")
type: Literal['param_string'] = 'param_string'
text: str = Field(default='', description='The string value')
class Config(InvocationConfig):
schema_extra = {
"ui": {"tags": ["param", "string"], "title": "String Parameter"},
}
schema_extra = {
"ui": {
"tags": ["param", "string"],
"title": "String Parameter"
},
}
def invoke(self, context: InvocationContext) -> StringOutput:
return StringOutput(text=self.text)
class ParamPromptInvocation(BaseInvocation):
"""A prompt input parameter"""
type: Literal["param_prompt"] = "param_prompt"
prompt: str = Field(default="", description="The prompt value")
class Config(InvocationConfig):
schema_extra = {
"ui": {"tags": ["param", "prompt"], "title": "Prompt"},
}
def invoke(self, context: InvocationContext) -> PromptOutput:
return PromptOutput(prompt=self.prompt)

View File

@@ -7,21 +7,19 @@ from pydantic import Field, validator
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator
class PromptOutput(BaseInvocationOutput):
"""Base class for invocations that output a prompt"""
# fmt: off
#fmt: off
type: Literal["prompt"] = "prompt"
prompt: str = Field(default=None, description="The output prompt")
# fmt: on
#fmt: on
class Config:
schema_extra = {
"required": [
"type",
"prompt",
'required': [
'type',
'prompt',
]
}
@@ -46,11 +44,16 @@ class DynamicPromptInvocation(BaseInvocation):
type: Literal["dynamic_prompt"] = "dynamic_prompt"
prompt: str = Field(description="The prompt to parse with dynamicprompts")
max_prompts: int = Field(default=1, description="The number of prompts to generate")
combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator")
combinatorial: bool = Field(
default=False, description="Whether to use the combinatorial generator"
)
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]},
"ui": {
"title": "Dynamic Prompt",
"tags": ["prompt", "dynamic"]
},
}
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
@@ -62,11 +65,10 @@ class DynamicPromptInvocation(BaseInvocation):
prompts = generator.generate(self.prompt, num_images=self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))
class PromptsFromFileInvocation(BaseInvocation):
"""Loads prompts from a text file"""
'''Loads prompts from a text file'''
# fmt: off
type: Literal['prompt_from_file'] = 'prompt_from_file'
@@ -76,11 +78,14 @@ class PromptsFromFileInvocation(BaseInvocation):
post_prompt: Optional[str] = Field(description="String to append to each prompt")
start_line: int = Field(default=1, ge=1, description="Line in the file to start start from")
max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)")
# fmt: on
#fmt: on
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Prompts From File", "tags": ["prompt", "file"]},
"ui": {
"title": "Prompts From File",
"tags": ["prompt", "file"]
},
}
@validator("file_path")
@@ -98,13 +103,11 @@ class PromptsFromFileInvocation(BaseInvocation):
with open(file_path) as f:
for i, line in enumerate(f):
if i >= start_line and i < end_line:
prompts.append((pre_prompt or "") + line.strip() + (post_prompt or ""))
prompts.append((pre_prompt or '') + line.strip() + (post_prompt or ''))
if i >= end_line:
break
return prompts
def invoke(self, context: InvocationContext) -> PromptCollectionOutput:
prompts = self.promptsFromFile(
self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts
)
prompts = self.promptsFromFile(self.file_path, self.pre_prompt, self.post_prompt, self.start_line, self.max_prompts)
return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts))

View File

@@ -7,13 +7,13 @@ from pydantic import Field, validator
from ...backend.model_management import ModelType, SubModelType
from invokeai.app.util.step_callback import stable_diffusion_xl_step_callback
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
from .compel import ConditioningField
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output
class SDXLModelLoaderOutput(BaseInvocationOutput):
"""SDXL base model loader output"""
@@ -26,19 +26,16 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output"""
# fmt: off
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
# fmt: on
#fmt: on
class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels."""
@@ -128,10 +125,8 @@ class SDXLModelLoaderInvocation(BaseInvocation):
),
)
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels."""
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
model: MainModelField = Field(description="The model to load")
@@ -143,7 +138,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"ui": {
"title": "SDXL Refiner Model Loader",
"tags": ["model", "loader", "sdxl_refiner"],
"type_hints": {"model": "refiner_model"},
"type_hints": {"model": "model"},
},
}
@@ -201,8 +196,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
),
),
)
# Text to image
class SDXLTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
@@ -219,9 +213,9 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
unet: UNetField = Field(default=None, description="UNet submodel")
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
# control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
@@ -230,10 +224,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
raise ValueError('cfg_scale must be greater than 1')
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
raise ValueError('cfg_scale must be greater than 1')
return v
# Schema customisation
@@ -243,10 +237,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
"title": "SDXL Text To Latents",
"tags": ["latents"],
"type_hints": {
"model": "model",
# "cfg_scale": "float",
"cfg_scale": "number",
},
"model": "model",
# "cfg_scale": "float",
"cfg_scale": "number"
}
},
}
@@ -271,7 +265,9 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
latents = context.services.latents.get(self.noise.latents_name)
@@ -292,15 +288,18 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
)
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps)
timesteps = scheduler.timesteps
unet_info = context.services.model_manager.get_model(**self.unet.unet.dict(), context=context)
latents = latents * scheduler.init_noise_sigma
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with unet_info as unet:
scheduler.set_timesteps(num_inference_steps, device=unet.device)
timesteps = scheduler.timesteps
latents = latents.to(device=unet.device, dtype=unet.dtype) * scheduler.init_noise_sigma
extra_step_kwargs = dict()
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
@@ -351,10 +350,10 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
# del noise_pred_uncond
# del noise_pred_text
#del noise_pred_uncond
#del noise_pred_text
# if do_classifier_free_guidance and guidance_rescale > 0.0:
#if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
@@ -365,7 +364,7 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
# if callback is not None and i % callback_steps == 0:
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
else:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
@@ -379,13 +378,13 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
# latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latents, t)
# import gc
# gc.collect()
# torch.cuda.empty_cache()
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# predict the noise residual
@@ -412,41 +411,42 @@ class SDXLTextToLatentsInvocation(BaseInvocation):
# perform guidance
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
# del noise_pred_text
# del noise_pred_uncond
# import gc
# gc.collect()
# torch.cuda.empty_cache()
#del noise_pred_text
#del noise_pred_uncond
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# if do_classifier_free_guidance and guidance_rescale > 0.0:
#if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# del noise_pred
# import gc
# gc.collect()
# torch.cuda.empty_cache()
#del noise_pred
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
# if callback is not None and i % callback_steps == 0:
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
#################
latents = latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)
class SDXLLatentsToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""
@@ -463,12 +463,12 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
unet: UNetField = Field(default=None, description="UNet submodel")
latents: Optional[LatentsField] = Field(description="Initial latents")
denoising_start: float = Field(default=0.0, ge=0, le=1, description="")
denoising_end: float = Field(default=1.0, ge=0, le=1, description="")
denoising_start: float = Field(default=0.0, ge=0, lt=1, description="")
denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
# control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
#control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
#seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
#seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
@validator("cfg_scale")
@@ -477,10 +477,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
raise ValueError('cfg_scale must be greater than 1')
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
raise ValueError('cfg_scale must be greater than 1')
return v
# Schema customisation
@@ -490,10 +490,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
"title": "SDXL Latents to Latents",
"tags": ["latents"],
"type_hints": {
"model": "model",
# "cfg_scale": "float",
"cfg_scale": "number",
},
"model": "model",
# "cfg_scale": "float",
"cfg_scale": "number"
}
},
}
@@ -518,7 +518,9 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
latents = context.services.latents.get(self.latents.latents_name)
@@ -538,27 +540,26 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
scheduler_name=self.scheduler,
)
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict(),
context=context,
)
# apply denoising_start
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps)
t_start = int(round(self.denoising_start * num_inference_steps))
timesteps = scheduler.timesteps[t_start * scheduler.order:]
num_inference_steps = num_inference_steps - t_start
# apply noise(if provided)
if self.noise is not None:
noise = context.services.latents.get(self.noise.latents_name)
latents = scheduler.add_noise(latents, noise, timesteps[:1])
del noise
unet_info = context.services.model_manager.get_model(
**self.unet.unet.dict()
)
do_classifier_free_guidance = True
cross_attention_kwargs = None
with unet_info as unet:
# apply denoising_start
num_inference_steps = self.steps
scheduler.set_timesteps(num_inference_steps, device=unet.device)
t_start = int(round(self.denoising_start * num_inference_steps))
timesteps = scheduler.timesteps[t_start * scheduler.order :]
num_inference_steps = num_inference_steps - t_start
# apply noise(if provided)
if self.noise is not None and timesteps.shape[0] > 0:
noise = context.services.latents.get(self.noise.latents_name)
latents = scheduler.add_noise(latents, noise, timesteps[:1])
del noise
# apply scheduler extra args
extra_step_kwargs = dict()
@@ -610,10 +611,10 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
# del noise_pred_uncond
# del noise_pred_text
#del noise_pred_uncond
#del noise_pred_text
# if do_classifier_free_guidance and guidance_rescale > 0.0:
#if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
@@ -624,7 +625,7 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
# if callback is not None and i % callback_steps == 0:
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
else:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
@@ -638,13 +639,13 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
# latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(latents, t)
# import gc
# gc.collect()
# torch.cuda.empty_cache()
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# predict the noise residual
@@ -671,36 +672,38 @@ class SDXLLatentsToLatentsInvocation(BaseInvocation):
# perform guidance
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
# del noise_pred_text
# del noise_pred_uncond
# import gc
# gc.collect()
# torch.cuda.empty_cache()
#del noise_pred_text
#del noise_pred_uncond
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# if do_classifier_free_guidance and guidance_rescale > 0.0:
#if do_classifier_free_guidance and guidance_rescale > 0.0:
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
# noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# del noise_pred
# import gc
# gc.collect()
# torch.cuda.empty_cache()
#del noise_pred
#import gc
#gc.collect()
#torch.cuda.empty_cache()
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
self.dispatch_progress(context, source_node_id, latents, i, num_inference_steps)
# if callback is not None and i % callback_steps == 0:
#if callback is not None and i % callback_steps == 0:
# callback(i, t, latents)
#################
latents = latents.to("cpu")
torch.cuda.empty_cache()
name = f"{context.graph_execution_state_id}__{self.id}"
name = f'{context.graph_execution_state_id}__{self.id}'
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)

View File

@@ -29,11 +29,16 @@ class ESRGANInvocation(BaseInvocation):
type: Literal["esrgan"] = "esrgan"
image: Union[ImageField, None] = Field(default=None, description="The input image")
model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
model_name: ESRGAN_MODELS = Field(
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
)
class Config(InvocationConfig):
schema_extra = {
"ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]},
"ui": {
"title": "Upscale (RealESRGAN)",
"tags": ["image", "upscale", "realesrgan"]
},
}
def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -103,7 +108,9 @@ class ESRGANInvocation(BaseInvocation):
upscaled_image, img_mode = upsampler.enhance(cv_image)
# back to PIL
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
pil_image = Image.fromarray(
cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
).convert("RGBA")
image_dto = context.services.images.create(
image=pil_image,

View File

@@ -1,4 +1,3 @@
class CanceledException(Exception):
"""Execution canceled by user."""
pass

View File

@@ -1,83 +1,8 @@
from enum import Enum
from typing import Optional, Tuple, Literal
from typing import Optional, Tuple
from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum
from ..invocations.baseinvocation import (
BaseInvocationOutput,
InvocationConfig,
)
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: int = Field(ge=0, le=255, description="The alpha component")
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
class PILInvocationConfig(BaseModel):
"""Helper class to provide all PIL invocations with additional config"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["PIL", "image"],
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
]
}
class ResourceOrigin(str, Enum, metaclass=MetaEnum):
@@ -136,3 +61,30 @@ class InvalidImageCategoryException(ValueError):
def __init__(self, message="Invalid image category."):
super().__init__(message)
class ImageField(BaseModel):
"""An image field used for passing image objects between invocations"""
image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config:
schema_extra = {"required": ["image_name"]}
class ColorField(BaseModel):
r: int = Field(ge=0, le=255, description="The red component")
g: int = Field(ge=0, le=255, description="The green component")
b: int = Field(ge=0, le=255, description="The blue component")
a: int = Field(ge=0, le=255, description="The alpha component")
def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

View File

@@ -207,7 +207,9 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
raise e
finally:
self._lock.release()
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
return OffsetPaginatedResults(
items=images, offset=offset, limit=limit, total=count
)
def get_all_board_image_names_for_board(self, board_id: str) -> list[str]:
try:

View File

@@ -102,7 +102,9 @@ class BoardImagesService(BoardImagesServiceABC):
self,
board_id: str,
) -> list[str]:
return self._services.board_image_records.get_all_board_image_names_for_board(board_id)
return self._services.board_image_records.get_all_board_image_names_for_board(
board_id
)
def get_board_for_image(
self,
@@ -112,7 +114,9 @@ class BoardImagesService(BoardImagesServiceABC):
return board_id
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
def board_record_to_dto(
board_record: BoardRecord, cover_image_name: Optional[str], image_count: int
) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
**board_record.dict(exclude={"cover_image_name"}),

View File

@@ -15,7 +15,9 @@ from pydantic import BaseModel, Field, Extra
class BoardChanges(BaseModel, extra=Extra.forbid):
board_name: Optional[str] = Field(description="The board's new name.")
cover_image_name: Optional[str] = Field(description="The name of the board's new cover image.")
cover_image_name: Optional[str] = Field(
description="The name of the board's new cover image."
)
class BoardRecordNotFoundException(Exception):
@@ -290,7 +292,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
count = cast(int, self._cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
return OffsetPaginatedResults[BoardRecord](
items=boards, offset=offset, limit=limit, total=count
)
except sqlite3.Error as e:
self._conn.rollback()

View File

@@ -108,12 +108,16 @@ class BoardService(BoardServiceABC):
def get_dto(self, board_id: str) -> BoardDTO:
board_record = self._services.board_records.get(board_id)
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
cover_image = self._services.image_records.get_most_recent_image_for_board(
board_record.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
image_count = self._services.board_image_records.get_image_count_for_board(
board_id
)
return board_record_to_dto(board_record, cover_image_name, image_count)
def update(
@@ -122,44 +126,60 @@ class BoardService(BoardServiceABC):
changes: BoardChanges,
) -> BoardDTO:
board_record = self._services.board_records.update(board_id, changes)
cover_image = self._services.image_records.get_most_recent_image_for_board(board_record.board_id)
cover_image = self._services.image_records.get_most_recent_image_for_board(
board_record.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(board_id)
image_count = self._services.board_image_records.get_image_count_for_board(
board_id
)
return board_record_to_dto(board_record, cover_image_name, image_count)
def delete(self, board_id: str) -> None:
self._services.board_records.delete(board_id)
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
def get_many(
self, offset: int = 0, limit: int = 10
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self._services.board_records.get_many(offset, limit)
board_dtos = []
for r in board_records.items:
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
cover_image = self._services.image_records.get_most_recent_image_for_board(
r.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
return OffsetPaginatedResults[BoardDTO](
items=board_dtos, offset=offset, limit=limit, total=len(board_dtos)
)
def get_all(self) -> list[BoardDTO]:
board_records = self._services.board_records.get_all()
board_dtos = []
for r in board_records:
cover_image = self._services.image_records.get_most_recent_image_for_board(r.board_id)
cover_image = self._services.image_records.get_most_recent_image_for_board(
r.board_id
)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self._services.board_image_records.get_image_count_for_board(r.board_id)
image_count = self._services.board_image_records.get_image_count_for_board(
r.board_id
)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
return board_dtos
return board_dtos

View File

@@ -1,6 +1,6 @@
# Copyright (c) 2023 Lincoln Stein (https://github.com/lstein) and the InvokeAI Development Team
"""Invokeai configuration system.
'''Invokeai configuration system.
Arguments and fields are taken from the pydantic definition of the
model. Defaults can be set by creating a yaml configuration file that
@@ -28,6 +28,7 @@ InvokeAI:
always_use_cpu: false
free_gpu_mem: false
Features:
nsfw_checker: true
restore: true
esrgan: true
patchmatch: true
@@ -91,18 +92,18 @@ Typical usage at the top level file:
from invokeai.app.services.config import InvokeAIAppConfig
# get global configuration and print its cache size
# get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig.get_config()
conf.parse_args()
print(conf.max_cache_size)
print(conf.nsfw_checker)
Typical usage in a backend module:
from invokeai.app.services.config import InvokeAIAppConfig
# get global configuration and print its cache size value
# get global configuration and print its nsfw_checker value
conf = InvokeAIAppConfig.get_config()
print(conf.max_cache_size)
print(conf.nsfw_checker)
Computed properties:
@@ -158,7 +159,7 @@ two configs are kept in separate sections of the config file:
outdir: outputs
...
"""
'''
from __future__ import annotations
import argparse
import pydoc
@@ -170,67 +171,64 @@ from pathlib import Path
from pydantic import BaseSettings, Field, parse_obj_as
from typing import ClassVar, Dict, List, Set, Literal, Union, get_origin, get_type_hints, get_args
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
INIT_FILE = Path('invokeai.yaml')
MODEL_CORE = Path('models/core')
DB_FILE = Path('invokeai.db')
LEGACY_INIT_FILE = Path('invokeai.init')
class InvokeAISettings(BaseSettings):
"""
'''
Runtime configuration settings in which default values are
read from an omegaconf .yaml file.
"""
'''
initconf : ClassVar[DictConfig] = None
argparse_groups : ClassVar[Dict] = {}
initconf: ClassVar[DictConfig] = None
argparse_groups: ClassVar[Dict] = {}
def parse_args(self, argv: list = sys.argv[1:]):
def parse_args(self, argv: list=sys.argv[1:]):
parser = self.get_parser()
opt = parser.parse_args(argv)
for name in self.__fields__:
if name not in self._excluded():
setattr(self, name, getattr(opt, name))
setattr(self, name, getattr(opt,name))
def to_yaml(self) -> str:
def to_yaml(self)->str:
"""
Return a YAML string representing our settings. This can be used
as the contents of `invokeai.yaml` to restore settings later.
"""
cls = self.__class__
type = get_args(get_type_hints(cls)["type"])[0]
field_dict = dict({type: dict()})
for name, field in self.__fields__.items():
type = get_args(get_type_hints(cls)['type'])[0]
field_dict = dict({type:dict()})
for name,field in self.__fields__.items():
if name in cls._excluded_from_yaml():
continue
category = field.field_info.extra.get("category") or "Uncategorized"
value = getattr(self, name)
value = getattr(self,name)
if category not in field_dict[type]:
field_dict[type][category] = dict()
# keep paths as strings to make it easier to read
field_dict[type][category][name] = str(value) if isinstance(value, Path) else value
field_dict[type][category][name] = str(value) if isinstance(value,Path) else value
conf = OmegaConf.create(field_dict)
return OmegaConf.to_yaml(conf)
@classmethod
def add_parser_arguments(cls, parser):
if "type" in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
if 'type' in get_type_hints(cls):
settings_stanza = get_args(get_type_hints(cls)['type'])[0]
else:
settings_stanza = "Uncategorized"
env_prefix = cls.Config.env_prefix if hasattr(cls.Config, "env_prefix") else settings_stanza.upper()
env_prefix = cls.Config.env_prefix if hasattr(cls.Config,'env_prefix') else settings_stanza.upper()
initconf = (
cls.initconf.get(settings_stanza)
if cls.initconf and settings_stanza in cls.initconf
else OmegaConf.create()
)
initconf = cls.initconf.get(settings_stanza) \
if cls.initconf and settings_stanza in cls.initconf \
else OmegaConf.create()
# create an upcase version of the environment in
# order to achieve case-insensitive environment
# variables (the way Windows does)
upcase_environ = dict()
for key, value in os.environ.items():
for key,value in os.environ.items():
upcase_environ[key.upper()] = value
fields = cls.__fields__
@@ -240,8 +238,8 @@ class InvokeAISettings(BaseSettings):
if name not in cls._excluded():
current_default = field.default
category = field.field_info.extra.get("category", "Uncategorized")
env_name = env_prefix + "_" + name
category = field.field_info.extra.get("category","Uncategorized")
env_name = env_prefix + '_' + name
if category in initconf and name in initconf.get(category):
field.default = initconf.get(category).get(name)
if env_name.upper() in upcase_environ:
@@ -251,15 +249,15 @@ class InvokeAISettings(BaseSettings):
field.default = current_default
@classmethod
def cmd_name(self, command_field: str = "type") -> str:
def cmd_name(self, command_field: str='type')->str:
hints = get_type_hints(self)
if command_field in hints:
return get_args(hints[command_field])[0]
else:
return "Uncategorized"
return 'Uncategorized'
@classmethod
def get_parser(cls) -> ArgumentParser:
def get_parser(cls)->ArgumentParser:
parser = PagingArgumentParser(
prog=cls.cmd_name(),
description=cls.__doc__,
@@ -272,42 +270,24 @@ class InvokeAISettings(BaseSettings):
parser.add_parser(cls.cmd_name(), help=cls.__doc__)
@classmethod
def _excluded(self) -> List[str]:
def _excluded(self)->List[str]:
# internal fields that shouldn't be exposed as command line options
return ["type", "initconf", "cached_root"]
return ['type','initconf']
@classmethod
def _excluded_from_yaml(self) -> List[str]:
def _excluded_from_yaml(self)->List[str]:
# combination of deprecated parameters and internal ones that shouldn't be exposed as invokeai.yaml options
return [
"type",
"initconf",
"gpu_mem_reserved",
"max_loaded_models",
"version",
"from_file",
"model",
"restore",
"root",
"nsfw_checker",
"cached_root",
]
return ['type','initconf', 'gpu_mem_reserved', 'max_loaded_models', 'version', 'from_file', 'model', 'restore', 'root']
class Config:
env_file_encoding = "utf-8"
env_file_encoding = 'utf-8'
arbitrary_types_allowed = True
case_sensitive = True
@classmethod
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
def add_field_argument(cls, command_parser, name: str, field, default_override = None):
field_type = get_type_hints(cls).get(name)
default = (
default_override
if default_override is not None
else field.default
if field.default_factory is None
else field.default_factory()
)
default = default_override if default_override is not None else field.default if field.default_factory is None else field.default_factory()
if category := field.field_info.extra.get("category"):
if category not in cls.argparse_groups:
cls.argparse_groups[category] = command_parser.add_argument_group(category)
@@ -336,10 +316,10 @@ class InvokeAISettings(BaseSettings):
argparse_group.add_argument(
f"--{name}",
dest=name,
nargs="*",
nargs='*',
type=field.type_,
default=default,
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
help=field.field_info.description,
)
else:
@@ -348,35 +328,31 @@ class InvokeAISettings(BaseSettings):
dest=name,
type=field.type_,
default=default,
action=argparse.BooleanOptionalAction if field.type_ == bool else "store",
action=argparse.BooleanOptionalAction if field.type_==bool else 'store',
help=field.field_info.description,
)
def _find_root() -> Path:
def _find_root()->Path:
venv = Path(os.environ.get("VIRTUAL_ENV") or ".")
if os.environ.get("INVOKEAI_ROOT"):
root = Path(os.environ.get("INVOKEAI_ROOT")).resolve()
elif any([(venv.parent / x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE]]):
elif any([(venv.parent/x).exists() for x in [INIT_FILE, LEGACY_INIT_FILE, MODEL_CORE]]):
root = (venv.parent).resolve()
else:
root = Path("~/invokeai").expanduser().resolve()
return root
class InvokeAIAppConfig(InvokeAISettings):
"""
Generate images using Stable Diffusion. Use "invokeai" to launch
the command-line client (recommended for experts only), or
"invokeai-web" to launch the web server. Global options
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
setting environment variables INVOKEAI_<setting>.
"""
'''
Generate images using Stable Diffusion. Use "invokeai" to launch
the command-line client (recommended for experts only), or
"invokeai-web" to launch the web server. Global options
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
setting environment variables INVOKEAI_<setting>.
'''
singleton_config: ClassVar[InvokeAIAppConfig] = None
singleton_init: ClassVar[Dict] = None
# fmt: off
#fmt: off
type: Literal["InvokeAI"] = "InvokeAI"
host : str = Field(default="127.0.0.1", description="IP address to bind to", category='Web Server')
port : int = Field(default=9090, description="Port to bind to", category='Web Server')
@@ -388,6 +364,7 @@ class InvokeAIAppConfig(InvokeAISettings):
esrgan : bool = Field(default=True, description="Enable/disable upscaling code", category='Features')
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
@@ -397,7 +374,6 @@ class InvokeAIAppConfig(InvokeAISettings):
max_cache_size : float = Field(default=6.0, gt=0, description="Maximum memory amount used by model cache for rapid switching", category='Memory/Performance')
max_vram_cache_size : float = Field(default=2.75, ge=0, description="Amount of VRAM reserved for model storage", category='Memory/Performance')
gpu_mem_reserved : float = Field(default=2.75, ge=0, description="DEPRECATED: use max_vram_cache_size. Amount of VRAM reserved for model storage", category='DEPRECATED')
nsfw_checker : bool = Field(default=True, description="DEPRECATED: use Web settings to enable/disable", category='DEPRECATED')
precision : Literal[tuple(['auto','float16','float32','autocast'])] = Field(default='auto',description='Floating point precision', category='Memory/Performance')
sequential_guidance : bool = Field(default=False, description="Whether to calculate guidance in serial instead of in parallel, lowering memory requirements", category='Memory/Performance')
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", category='Memory/Performance')
@@ -424,17 +400,16 @@ class InvokeAIAppConfig(InvokeAISettings):
log_level : Literal[tuple(["debug","info","warning","error","critical"])] = Field(default="info", description="Emit logging messages at this level or higher", category="Logging")
version : bool = Field(default=False, description="Show InvokeAI version and exit", category="Other")
cached_root : Path = Field(default=None, description="internal use only", category="DEPRECATED")
# fmt: on
#fmt: on
def parse_args(self, argv: List[str] = None, conf: DictConfig = None, clobber=False):
"""
def parse_args(self, argv: List[str]=None, conf: DictConfig = None, clobber=False):
'''
Update settings with contents of init file, environment, and
command-line settings.
:param conf: alternate Omegaconf dictionary object
:param argv: aternate sys.argv list
:param clobber: ovewrite any initialization parameters passed during initialization
"""
'''
# Set the runtime root directory. We parse command-line switches here
# in order to pick up the --root_dir option.
super().parse_args(argv)
@@ -451,144 +426,125 @@ class InvokeAIAppConfig(InvokeAISettings):
if self.singleton_init and not clobber:
hints = get_type_hints(self.__class__)
for k in self.singleton_init:
setattr(self, k, parse_obj_as(hints[k], self.singleton_init[k]))
setattr(self,k,parse_obj_as(hints[k],self.singleton_init[k]))
@classmethod
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
"""
def get_config(cls,**kwargs)->InvokeAIAppConfig:
'''
This returns a singleton InvokeAIAppConfig configuration object.
"""
if (
cls.singleton_config is None
or type(cls.singleton_config) != cls
or (kwargs and cls.singleton_init != kwargs)
):
'''
if cls.singleton_config is None \
or type(cls.singleton_config)!=cls \
or (kwargs and cls.singleton_init != kwargs):
cls.singleton_config = cls(**kwargs)
cls.singleton_init = kwargs
return cls.singleton_config
@property
def root_path(self) -> Path:
"""
def root_path(self)->Path:
'''
Path to the runtime root directory
"""
# we cache value of root to protect against it being '.' and the cwd changing
if self.cached_root:
root = self.cached_root
elif self.root:
root = Path(self.root).expanduser().absolute()
'''
if self.root:
return Path(self.root).expanduser().absolute()
else:
root = self.find_root()
self.cached_root = root
return self.cached_root
return self.find_root()
@property
def root_dir(self) -> Path:
"""
def root_dir(self)->Path:
'''
Alias for above.
"""
'''
return self.root_path
def _resolve(self, partial_path: Path) -> Path:
def _resolve(self,partial_path:Path)->Path:
return (self.root_path / partial_path).resolve()
@property
def init_file_path(self) -> Path:
"""
def init_file_path(self)->Path:
'''
Path to invokeai.yaml
"""
'''
return self._resolve(INIT_FILE)
@property
def output_path(self) -> Path:
"""
def output_path(self)->Path:
'''
Path to defaults outputs directory.
"""
'''
return self._resolve(self.outdir)
@property
def db_path(self) -> Path:
"""
def db_path(self)->Path:
'''
Path to the invokeai.db file.
"""
'''
return self._resolve(self.db_dir) / DB_FILE
@property
def model_conf_path(self) -> Path:
"""
def model_conf_path(self)->Path:
'''
Path to models configuration file.
"""
'''
return self._resolve(self.conf_path)
@property
def legacy_conf_path(self) -> Path:
"""
def legacy_conf_path(self)->Path:
'''
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
"""
'''
return self._resolve(self.legacy_conf_dir)
@property
def models_path(self) -> Path:
"""
def models_path(self)->Path:
'''
Path to the models directory
"""
'''
return self._resolve(self.models_dir)
@property
def autoconvert_path(self) -> Path:
"""
def autoconvert_path(self)->Path:
'''
Path to the directory containing models to be imported automatically at startup.
"""
'''
return self._resolve(self.autoconvert_dir) if self.autoconvert_dir else None
# the following methods support legacy calls leftover from the Globals era
@property
def full_precision(self) -> bool:
def full_precision(self)->bool:
"""Return true if precision set to float32"""
return self.precision == "float32"
return self.precision=='float32'
@property
def disable_xformers(self) -> bool:
def disable_xformers(self)->bool:
"""Return true if xformers_enabled is false"""
return not self.xformers_enabled
@property
def try_patchmatch(self) -> bool:
def try_patchmatch(self)->bool:
"""Return true if patchmatch true"""
return self.patchmatch
@property
def nsfw_checker(self) -> bool:
"""NSFW node is always active and disabled from Web UIe"""
return True
@property
def invisible_watermark(self) -> bool:
"""invisible watermark node is always active and disabled from Web UIe"""
return True
@staticmethod
def find_root() -> Path:
"""
def find_root()->Path:
'''
Choose the runtime root directory when not specified on command line or
init file.
"""
'''
return _find_root()
class PagingArgumentParser(argparse.ArgumentParser):
"""
'''
A custom ArgumentParser that uses pydoc to page its output.
It also supports reading defaults from an init file.
"""
'''
def print_help(self, file=None):
text = self.format_help()
pydoc.pager(text)
def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
"""
def get_invokeai_config(**kwargs)->InvokeAIAppConfig:
'''
Legacy function which returns InvokeAIAppConfig.get_config()
"""
'''
return InvokeAIAppConfig.get_config(**kwargs)

View File

@@ -1,5 +1,4 @@
from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation
from ..invocations.image import ImageNSFWBlurInvocation
from ..invocations.noise import NoiseInvocation
from ..invocations.compel import CompelInvocation
from ..invocations.params import ParamIntInvocation
@@ -7,80 +6,55 @@ from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Gr
from .item_storage import ItemStorageABC
default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"
default_text_to_image_graph_id = '539b2af5-2b4d-4d8c-8071-e54a3255fc74'
def create_text_to_image() -> LibraryGraph:
return LibraryGraph(
id=default_text_to_image_graph_id,
name="t2i",
description="Converts text to an image",
name='t2i',
description='Converts text to an image',
graph=Graph(
nodes={
"width": ParamIntInvocation(id="width", a=512),
"height": ParamIntInvocation(id="height", a=512),
"seed": ParamIntInvocation(id="seed", a=-1),
"3": NoiseInvocation(id="3"),
"4": CompelInvocation(id="4"),
"5": CompelInvocation(id="5"),
"6": TextToLatentsInvocation(id="6"),
"7": LatentsToImageInvocation(id="7"),
"8": ImageNSFWBlurInvocation(id="8"),
'width': ParamIntInvocation(id='width', a=512),
'height': ParamIntInvocation(id='height', a=512),
'seed': ParamIntInvocation(id='seed', a=-1),
'3': NoiseInvocation(id='3'),
'4': CompelInvocation(id='4'),
'5': CompelInvocation(id='5'),
'6': TextToLatentsInvocation(id='6'),
'7': LatentsToImageInvocation(id='7'),
},
edges=[
Edge(
source=EdgeConnection(node_id="width", field="a"),
destination=EdgeConnection(node_id="3", field="width"),
),
Edge(
source=EdgeConnection(node_id="height", field="a"),
destination=EdgeConnection(node_id="3", field="height"),
),
Edge(
source=EdgeConnection(node_id="seed", field="a"),
destination=EdgeConnection(node_id="3", field="seed"),
),
Edge(
source=EdgeConnection(node_id="3", field="noise"),
destination=EdgeConnection(node_id="6", field="noise"),
),
Edge(
source=EdgeConnection(node_id="6", field="latents"),
destination=EdgeConnection(node_id="7", field="latents"),
),
Edge(
source=EdgeConnection(node_id="4", field="conditioning"),
destination=EdgeConnection(node_id="6", field="positive_conditioning"),
),
Edge(
source=EdgeConnection(node_id="5", field="conditioning"),
destination=EdgeConnection(node_id="6", field="negative_conditioning"),
),
Edge(
source=EdgeConnection(node_id="7", field="image"),
destination=EdgeConnection(node_id="8", field="image"),
),
],
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
Edge(source=EdgeConnection(node_id='height', field='a'), destination=EdgeConnection(node_id='3', field='height')),
Edge(source=EdgeConnection(node_id='seed', field='a'), destination=EdgeConnection(node_id='3', field='seed')),
Edge(source=EdgeConnection(node_id='3', field='noise'), destination=EdgeConnection(node_id='6', field='noise')),
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
]
),
exposed_inputs=[
ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
ExposedNodeInput(node_path="width", field="a", alias="width"),
ExposedNodeInput(node_path="height", field="a", alias="height"),
ExposedNodeInput(node_path="seed", field="a", alias="seed"),
ExposedNodeInput(node_path='4', field='prompt', alias='positive_prompt'),
ExposedNodeInput(node_path='5', field='prompt', alias='negative_prompt'),
ExposedNodeInput(node_path='width', field='a', alias='width'),
ExposedNodeInput(node_path='height', field='a', alias='height'),
ExposedNodeInput(node_path='seed', field='a', alias='seed'),
],
exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
)
exposed_outputs=[
ExposedNodeOutput(node_path='7', field='image', alias='image')
])
def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
"""Creates the default system graphs, or adds new versions if the old ones don't match"""
# TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
graphs: list[LibraryGraph] = list()
# text_to_image = graph_library.get(default_text_to_image_graph_id)
# # TODO: Check if the graph is the same as the default one, and if not, update it
# #if text_to_image is None:
text_to_image = create_text_to_image()

View File

@@ -3,13 +3,7 @@
from typing import Any, Optional
from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp
from invokeai.app.services.model_manager_service import (
BaseModelType,
ModelType,
SubModelType,
ModelInfo,
)
from invokeai.app.services.model_manager_service import BaseModelType, ModelType, SubModelType, ModelInfo
class EventServiceBase:
session_event: str = "session_event"
@@ -73,7 +67,6 @@ class EventServiceBase:
graph_execution_state_id: str,
node: dict,
source_node_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when an invocation has completed"""
@@ -83,12 +76,13 @@ class EventServiceBase:
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
error_type=error_type,
error=error,
),
)
def emit_invocation_started(self, graph_execution_state_id: str, node: dict, source_node_id: str) -> None:
def emit_invocation_started(
self, graph_execution_state_id: str, node: dict, source_node_id: str
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(
event_name="invocation_started",
@@ -108,13 +102,13 @@ class EventServiceBase:
),
)
def emit_model_load_started(
self,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
def emit_model_load_started (
self,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
) -> None:
"""Emitted when a model is requested"""
self.__emit_session_event(
@@ -129,13 +123,13 @@ class EventServiceBase:
)
def emit_model_load_completed(
self,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
model_info: ModelInfo,
self,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel: SubModelType,
model_info: ModelInfo,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_session_event(
@@ -151,37 +145,3 @@ class EventServiceBase:
precision=str(model_info.precision),
),
)
def emit_session_retrieval_error(
self,
graph_execution_state_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when session retrieval fails"""
self.__emit_session_event(
event_name="session_retrieval_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
error_type=error_type,
error=error,
),
)
def emit_invocation_retrieval_error(
self,
graph_execution_state_id: str,
node_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when invocation retrieval fails"""
self.__emit_session_event(
event_name="invocation_retrieval_error",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
node_id=node_id,
error_type=error_type,
error=error,
),
)

View File

@@ -28,7 +28,6 @@ from ..invocations.baseinvocation import (
# in 3.10 this would be "from types import NoneType"
NoneType = type(None)
class EdgeConnection(BaseModel):
node_id: str = Field(description="The id of the node for this edge connection")
field: str = Field(description="The field for this connection")
@@ -62,7 +61,6 @@ def get_input_field(node: BaseInvocation, field: str) -> Any:
node_input_field = node_inputs.get(field) or None
return node_input_field
def is_union_subtype(t1, t2):
t1_args = get_args(t1)
t2_args = get_args(t2)
@@ -73,7 +71,6 @@ def is_union_subtype(t1, t2):
# t1 is a Union, check that all of its types are in t2_args
return all(arg in t2_args for arg in t1_args)
def is_list_or_contains_list(t):
t_args = get_args(t)
@@ -157,17 +154,15 @@ class GraphInvocationOutput(BaseInvocationOutput):
class Config:
schema_extra = {
"required": [
"type",
"image",
'required': [
'type',
'image',
]
}
# TODO: Fill this out and move to invocations
class GraphInvocation(BaseInvocation):
"""Execute a graph"""
type: Literal["graph"] = "graph"
# TODO: figure out how to create a default here
@@ -187,21 +182,23 @@ class IterateInvocationOutput(BaseInvocationOutput):
class Config:
schema_extra = {
"required": [
"type",
"item",
'required': [
'type',
'item',
]
}
# TODO: Fill this out and move to invocations
class IterateInvocation(BaseInvocation):
"""Iterates over a list of items"""
type: Literal["iterate"] = "iterate"
collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
index: int = Field(description="The index, will be provided on executed iterators", default=0)
collection: list[Any] = Field(
description="The list of items to iterate over", default_factory=list
)
index: int = Field(
description="The index, will be provided on executed iterators", default=0
)
def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
"""Produces the outputs as values"""
@@ -215,13 +212,12 @@ class CollectInvocationOutput(BaseInvocationOutput):
class Config:
schema_extra = {
"required": [
"type",
"collection",
'required': [
'type',
'collection',
]
}
class CollectInvocation(BaseInvocation):
"""Collects values into a collection"""
@@ -273,7 +269,9 @@ class Graph(BaseModel):
if node_path in self.nodes:
return (self, node_path)
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
node_id = (
node_path if "." not in node_path else node_path[: node_path.index(".")]
)
if node_id not in self.nodes:
raise NodeNotFoundError(f"Node {node_path} not found in graph")
@@ -335,7 +333,9 @@ class Graph(BaseModel):
return False
# Validate all edges reference nodes in the graph
node_ids = set([e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges])
node_ids = set(
[e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]
)
if not all((self.has_node(node_id) for node_id in node_ids)):
return False
@@ -361,14 +361,22 @@ class Graph(BaseModel):
# Validate all iterators
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
if not all(
(self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))
(
self._is_iterator_connection_valid(n.id)
for n in self.nodes.values()
if isinstance(n, IterateInvocation)
)
):
return False
# Validate all collectors
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
if not all(
(self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))
(
self._is_collector_connection_valid(n.id)
for n in self.nodes.values()
if isinstance(n, CollectInvocation)
)
):
return False
@@ -387,51 +395,48 @@ class Graph(BaseModel):
# Validate that an edge to this node+field doesn't already exist
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
raise InvalidEdgeError(
f"Edge to node {edge.destination.node_id} field {edge.destination.field} already exists"
)
raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')
# Validate that no cycles would be created
g = self.nx_graph_flat()
g.add_edge(edge.source.node_id, edge.destination.node_id)
if not nx.is_directed_acyclic_graph(g):
raise InvalidEdgeError(
f"Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}"
)
raise InvalidEdgeError(f'Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}')
# Validate that the field types are compatible
if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field):
raise InvalidEdgeError(
f"Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
if not are_connections_compatible(
from_node, edge.source.field, to_node, edge.destination.field
):
raise InvalidEdgeError(f'Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source):
raise InvalidEdgeError(
f"Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
if not self._is_iterator_connection_valid(
edge.destination.node_id, new_input=edge.source
):
raise InvalidEdgeError(f'Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if iterator input type matches output type (if this edge results in both being set)
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination):
raise InvalidEdgeError(
f"Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
if not self._is_iterator_connection_valid(
edge.source.node_id, new_output=edge.destination
):
raise InvalidEdgeError(f'Iterator output type does not match iterator input type:, {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if collector input type matches output type (if this edge results in both being set)
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source):
raise InvalidEdgeError(
f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
if not self._is_collector_connection_valid(
edge.destination.node_id, new_input=edge.source
):
raise InvalidEdgeError(f'Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
# Validate if collector output type matches input type (if this edge results in both being set)
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
raise InvalidEdgeError(
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
)
if not self._is_collector_connection_valid(
edge.source.node_id, new_output=edge.destination
):
raise InvalidEdgeError(f'Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}')
def has_node(self, node_path: str) -> bool:
"""Determines whether or not a node exists in the graph."""
@@ -460,13 +465,17 @@ class Graph(BaseModel):
# Ensure the node type matches the new node
if type(node) != type(new_node):
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
raise TypeError(
f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}"
)
# Ensure the new id is either the same or is not in the graph
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
new_path = self._get_node_path(new_node.id, prefix=prefix)
if new_node.id != node.id and self.has_node(new_path):
raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph")
raise NodeAlreadyInGraphError(
"Node with id {new_node.id} already exists in graph"
)
# Set the new node in the graph
graph.nodes[new_node.id] = new_node
@@ -488,7 +497,9 @@ class Graph(BaseModel):
graph.add_edge(
Edge(
source=edge.source,
destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field),
destination=EdgeConnection(
node_id=new_graph_node_path, field=edge.destination.field
)
)
)
@@ -501,12 +512,16 @@ class Graph(BaseModel):
)
graph.add_edge(
Edge(
source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field),
destination=edge.destination,
source=EdgeConnection(
node_id=new_graph_node_path, field=edge.source.field
),
destination=edge.destination
)
)
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]:
def _get_input_edges(
self, node_path: str, field: Optional[str] = None
) -> list[Edge]:
"""Gets all input edges for a node"""
edges = self._get_input_edges_and_graphs(node_path)
@@ -523,7 +538,7 @@ class Graph(BaseModel):
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
),
)
)
for _, prefix, e in filtered_edges
]
@@ -535,20 +550,32 @@ class Graph(BaseModel):
edges = list()
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
edges.extend(
[(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
)
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
node_id = (
node_path if "." not in node_path else node_path[: node_path.index(".")]
)
node = self.nodes[node_id]
if isinstance(node, GraphInvocation):
graph = node.graph
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
graph_path = (
node.id
if prefix is None or prefix == ""
else self._get_node_path(node.id, prefix=prefix)
)
graph_edges = graph._get_input_edges_and_graphs(
node_path[(len(node_id) + 1) :], prefix=graph_path
)
edges.extend(graph_edges)
return edges
def _get_output_edges(self, node_path: str, field: str) -> list[Edge]:
def _get_output_edges(
self, node_path: str, field: str
) -> list[Edge]:
"""Gets all output edges for a node"""
edges = self._get_output_edges_and_graphs(node_path)
@@ -565,7 +592,7 @@ class Graph(BaseModel):
destination=EdgeConnection(
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
field=e.destination.field,
),
)
)
for _, prefix, e in filtered_edges
]
@@ -577,15 +604,25 @@ class Graph(BaseModel):
edges = list()
# Return any input edges that appear in this graph
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
edges.extend(
[(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
)
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
node_id = (
node_path if "." not in node_path else node_path[: node_path.index(".")]
)
node = self.nodes[node_id]
if isinstance(node, GraphInvocation):
graph = node.graph
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
graph_path = (
node.id
if prefix is None or prefix == ""
else self._get_node_path(node.id, prefix=prefix)
)
graph_edges = graph._get_output_edges_and_graphs(
node_path[(len(node_id) + 1) :], prefix=graph_path
)
edges.extend(graph_edges)
return edges
@@ -609,8 +646,12 @@ class Graph(BaseModel):
return False
# Get input and output fields (the fields linked to the iterator's input/output)
input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
input_field = get_output_field(
self.get_node(inputs[0].node_id), inputs[0].field
)
output_fields = list(
[get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
)
# Input type must be a list
if get_origin(input_field) != list:
@@ -618,7 +659,12 @@ class Graph(BaseModel):
# Validate that all outputs match the input type
input_field_item_type = get_args(input_field)[0]
if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
if not all(
(
are_connection_types_compatible(input_field_item_type, f)
for f in output_fields
)
):
return False
return True
@@ -638,21 +684,35 @@ class Graph(BaseModel):
outputs.append(new_output)
# Get input and output fields (the fields linked to the iterator's input/output)
input_fields = list([get_output_field(self.get_node(e.node_id), e.field) for e in inputs])
output_fields = list([get_input_field(self.get_node(e.node_id), e.field) for e in outputs])
input_fields = list(
[get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
)
output_fields = list(
[get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
)
# Validate that all inputs are derived from or match a single type
input_field_types = set(
[
t
for input_field in input_fields
for t in ([input_field] if get_origin(input_field) == None else get_args(input_field))
for t in (
[input_field]
if get_origin(input_field) == None
else get_args(input_field)
)
if t != NoneType
]
) # Get unique types
type_tree = nx.DiGraph()
type_tree.add_nodes_from(input_field_types)
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
type_tree.add_edges_from(
[
e
for e in itertools.permutations(input_field_types, 2)
if issubclass(e[1], e[0])
]
)
type_degrees = type_tree.in_degree(type_tree.nodes)
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
return False # There is more than one root type
@@ -669,7 +729,9 @@ class Graph(BaseModel):
return False
# Verify that all outputs match the input type (are a base class or the same class)
if not all((issubclass(input_root_type, get_args(f)[0]) for f in output_fields)):
if not all(
(issubclass(input_root_type, get_args(f)[0]) for f in output_fields)
):
return False
return True
@@ -689,7 +751,9 @@ class Graph(BaseModel):
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
return g
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
def nx_graph_flat(
self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
) -> nx.DiGraph:
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
g = nx_graph or nx.DiGraph()
@@ -698,18 +762,26 @@ class Graph(BaseModel):
[
self._get_node_path(n.id, prefix)
for n in self.nodes.values()
if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)
if not isinstance(n, GraphInvocation)
and not isinstance(n, IterateInvocation)
]
)
# Expand graph nodes
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
for sgn in (
gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
):
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
# TODO: figure out if iteration nodes need to be expanded
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
g.add_edges_from(
[
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
for e in unique_edges
]
)
return g
@@ -728,19 +800,23 @@ class GraphExecutionState(BaseModel):
)
# Nodes that have been executed
executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set)
executed: set[str] = Field(
description="The set of node ids that have been executed", default_factory=set
)
executed_history: list[str] = Field(
description="The list of node ids that have been executed, in order of execution",
default_factory=list,
)
# The results of executed nodes
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
description="The results of node executions", default_factory=dict
)
results: dict[
str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]
] = Field(description="The results of node executions", default_factory=dict)
# Errors raised when executing nodes
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
errors: dict[str, str] = Field(
description="Errors raised when executing nodes", default_factory=dict
)
# Map of prepared/executed nodes to their original nodes
prepared_source_mapping: dict[str, str] = Field(
@@ -756,16 +832,16 @@ class GraphExecutionState(BaseModel):
class Config:
schema_extra = {
"required": [
"id",
"graph",
"execution_graph",
"executed",
"executed_history",
"results",
"errors",
"prepared_source_mapping",
"source_prepared_mapping",
'required': [
'id',
'graph',
'execution_graph',
'executed',
'executed_history',
'results',
'errors',
'prepared_source_mapping',
'source_prepared_mapping',
]
}
@@ -823,7 +899,9 @@ class GraphExecutionState(BaseModel):
"""Returns true if the graph has any errors"""
return len(self.errors) > 0
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
def _create_execution_node(
self, node_path: str, iteration_node_map: list[tuple[str, str]]
) -> list[str]:
"""Prepares an iteration node and connects all edges, returning the new node id"""
node = self.graph.get_node(node_path)
@@ -833,12 +911,20 @@ class GraphExecutionState(BaseModel):
# If this is an iterator node, we must create a copy for each iteration
if isinstance(node, IterateInvocation):
# Get input collection edge (should error if there are no inputs)
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection")))
input_collection_prepared_node_id = next(
n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
input_collection_edge = next(
iter(self.graph._get_input_edges(node_path, "collection"))
)
input_collection_prepared_node_id = next(
n[1]
for n in iteration_node_map
if n[0] == input_collection_edge.source.node_id
)
input_collection_prepared_node_output = self.results[
input_collection_prepared_node_id
]
input_collection = getattr(
input_collection_prepared_node_output, input_collection_edge.source.field
)
input_collection_prepared_node_output = self.results[input_collection_prepared_node_id]
input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
self_iteration_count = len(input_collection)
new_nodes = list()
@@ -853,7 +939,9 @@ class GraphExecutionState(BaseModel):
# For collect nodes, this may contain multiple inputs to the same field
new_edges = list()
for edge in input_edges:
for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
for input_node_id in (
n[1] for n in iteration_node_map if n[0] == edge.source.node_id
):
new_edge = Edge(
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
destination=EdgeConnection(node_id="", field=edge.destination.field),
@@ -894,7 +982,11 @@ class GraphExecutionState(BaseModel):
def _iterator_graph(self) -> nx.DiGraph:
"""Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
g = self.graph.nx_graph_flat()
collectors = (n for n in self.graph.nodes if isinstance(self.graph.get_node(n), CollectInvocation))
collectors = (
n
for n in self.graph.nodes
if isinstance(self.graph.get_node(n), CollectInvocation)
)
for c in collectors:
g.remove_edges_from(list(g.in_edges(c)))
return g
@@ -902,7 +994,11 @@ class GraphExecutionState(BaseModel):
def _get_node_iterators(self, node_id: str) -> list[str]:
"""Gets iterators for a node"""
g = self._iterator_graph()
iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
iterators = [
n
for n in nx.ancestors(g, node_id)
if isinstance(self.graph.get_node(n), IterateInvocation)
]
return iterators
def _prepare(self) -> Optional[str]:
@@ -949,18 +1045,29 @@ class GraphExecutionState(BaseModel):
if isinstance(next_node, CollectInvocation):
# Collapse all iterator input mappings and create a single execution node for the collect invocation
all_iteration_mappings = list(
itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))
itertools.chain(
*(
((s, p) for p in self.source_prepared_mapping[s])
for s in next_node_parents
)
)
)
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
create_results = self._create_execution_node(
next_node_id, all_iteration_mappings
)
if create_results is not None:
new_node_ids.extend(create_results)
else: # Iterators or normal nodes
# Get all iterator combinations for this node
# Will produce a list of lists of prepared iterator nodes, from which results can be iterated
iterator_nodes = self._get_node_iterators(next_node_id)
iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))
iterator_nodes_prepared = [
list(self.source_prepared_mapping[n]) for n in iterator_nodes
]
iterator_node_prepared_combinations = list(
itertools.product(*iterator_nodes_prepared)
)
# Select the correct prepared parents for each iteration
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
@@ -989,16 +1096,31 @@ class GraphExecutionState(BaseModel):
return next(iter(prepared_nodes))
# Check if the requested node is an iterator
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
prepared_iterator = next(
(n for n in prepared_nodes if n in prepared_iterator_nodes), None
)
if prepared_iterator is not None:
return prepared_iterator
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
iterator_source_node_mapping = [
(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes
]
parent_iterators = [
itn
for itn in iterator_source_node_mapping
if nx.has_path(graph, itn[1], source_node_path)
]
return next(
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
(
n
for n in prepared_nodes
if all(
nx.has_path(execution_graph, pit[0], n)
for pit in parent_iterators
)
),
None,
)
@@ -1008,13 +1130,13 @@ class GraphExecutionState(BaseModel):
# Depth-first search with pre-order traversal is a depth-first topological sort
sorted_nodes = nx.dfs_preorder_nodes(g)
next_node = next(
(
n
for n in sorted_nodes
if n not in self.executed # the node must not already be executed...
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
if n not in self.executed # the node must not already be executed...
and all((e[0] in self.executed for e in g.in_edges(n))) # ...and all its inputs must be executed
),
None,
)
@@ -1099,18 +1221,15 @@ class ExposedNodeOutput(BaseModel):
field: str = Field(description="The field name of the output")
alias: str = Field(description="The alias of the output")
class LibraryGraph(BaseModel):
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid.uuid4)
graph: Graph = Field(description="The graph")
name: str = Field(description="The name of the graph")
description: str = Field(description="The description of the graph")
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
exposed_outputs: list[ExposedNodeOutput] = Field(
description="The outputs exposed by this graph", default_factory=list
)
exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)
@validator("exposed_inputs", "exposed_outputs")
@validator('exposed_inputs', 'exposed_outputs')
def validate_exposed_aliases(cls, v):
if len(v) != len(set(i.alias for i in v)):
raise ValueError("Duplicate exposed alias")
@@ -1118,27 +1237,23 @@ class LibraryGraph(BaseModel):
@root_validator
def validate_exposed_nodes(cls, values):
graph = values["graph"]
graph = values['graph']
# Validate exposed inputs
for exposed_input in values["exposed_inputs"]:
for exposed_input in values['exposed_inputs']:
if not graph.has_node(exposed_input.node_path):
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
node = graph.get_node(exposed_input.node_path)
if get_input_field(node, exposed_input.field) is None:
raise ValueError(
f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}"
)
raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")
# Validate exposed outputs
for exposed_output in values["exposed_outputs"]:
for exposed_output in values['exposed_outputs']:
if not graph.has_node(exposed_output.node_path):
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
node = graph.get_node(exposed_output.node_path)
if get_output_field(node, exposed_output.field) is None:
raise ValueError(
f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}"
)
raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")
return values

View File

@@ -85,7 +85,9 @@ class DiskImageFileStorage(ImageFileStorageBase):
self.__cache_ids = Queue()
self.__max_cache_size = 10 # TODO: get this from config
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder: Path = (
output_folder if isinstance(output_folder, Path) else Path(output_folder)
)
self.__thumbnails_folder = self.__output_folder / "thumbnails"
# Validate required output folders at launch
@@ -118,7 +120,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image_path = self.get_path(image_name)
pnginfo = PngImagePlugin.PngInfo()
if metadata is not None:
pnginfo.add_text("invokeai_metadata", json.dumps(metadata))
if graph is not None:
@@ -181,7 +183,9 @@ class DiskImageFileStorage(ImageFileStorageBase):
def __set_cache(self, image_name: Path, image: PILImageType):
if not image_name in self.__cache:
self.__cache[image_name] = image
self.__cache_ids.put(image_name) # TODO: this should refresh position for LRU cache
self.__cache_ids.put(
image_name
) # TODO: this should refresh position for LRU cache
if len(self.__cache) > self.__max_cache_size:
cache_id = self.__cache_ids.get()
if cache_id in self.__cache:

View File

@@ -426,7 +426,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally:
self._lock.release()
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
return OffsetPaginatedResults(
items=images, offset=offset, limit=limit, total=count
)
def delete(self, image_name: str) -> None:
try:
@@ -464,6 +466,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally:
self._lock.release()
def delete_intermediates(self) -> list[str]:
try:
self._lock.acquire()
@@ -502,7 +505,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
is_intermediate: bool = False,
) -> datetime:
try:
metadata_json = None if metadata is None else json.dumps(metadata)
metadata_json = (
None if metadata is None else json.dumps(metadata)
)
self._lock.acquire()
self._cursor.execute(
"""--sql

View File

@@ -216,9 +216,16 @@ class ImageService(ImageServiceABC):
metadata=metadata,
session_id=session_id,
)
if board_id is not None:
self._services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self._services.image_files.save(image_name=image_name, image=image, metadata=metadata, graph=graph)
self._services.board_image_records.add_image_to_board(
board_id=board_id, image_name=image_name
)
self._services.image_files.save(
image_name=image_name, image=image, metadata=metadata, graph=graph
)
image_dto = self.get_dto(image_name)
return image_dto
@@ -229,7 +236,7 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Failed to save image file")
raise
except Exception as e:
self._services.logger.error(f"Problem saving image record and file: {str(e)}")
self._services.logger.error("Problem saving image record and file")
raise e
def update(
@@ -293,7 +300,9 @@ class ImageService(ImageServiceABC):
if not image_record.session_id:
return ImageMetadata()
session_raw = self._services.graph_execution_manager.get_raw(image_record.session_id)
session_raw = self._services.graph_execution_manager.get_raw(
image_record.session_id
)
graph = None
if session_raw:
@@ -358,7 +367,9 @@ class ImageService(ImageServiceABC):
r,
self._services.urls.get_image_url(r.image_name),
self._services.urls.get_image_url(r.image_name, True),
self._services.board_image_records.get_board_for_image(r.image_name),
self._services.board_image_records.get_board_for_image(
r.image_name
),
),
results.items,
)
@@ -390,7 +401,11 @@ class ImageService(ImageServiceABC):
def delete_images_on_board(self, board_id: str):
try:
image_names = self._services.board_image_records.get_all_board_image_names_for_board(board_id)
image_names = (
self._services.board_image_records.get_all_board_image_names_for_board(
board_id
)
)
for image_name in image_names:
self._services.image_files.delete(image_name)
self._services.image_records.delete_many(image_names)

View File

@@ -7,7 +7,6 @@ from queue import Queue
from pydantic import BaseModel, Field
from typing import Optional
class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
invocation_id: str = Field(description="The ID of the node being invoked")
@@ -46,11 +45,9 @@ class MemoryInvocationQueue(InvocationQueueABC):
def get(self) -> InvocationQueueItem:
item = self.__queue.get()
while (
isinstance(item, InvocationQueueItem)
and item.graph_execution_state_id in self.__cancellations
and self.__cancellations[item.graph_execution_state_id] > item.timestamp
):
while isinstance(item, InvocationQueueItem) \
and item.graph_execution_state_id in self.__cancellations \
and self.__cancellations[item.graph_execution_state_id] > item.timestamp:
item = self.__queue.get()
# Clear old items

View File

@@ -7,7 +7,6 @@ from .graph import Graph, GraphExecutionState
from .invocation_queue import InvocationQueueItem
from .invocation_services import InvocationServices
class Invoker:
"""The invoker, used to execute invocations"""
@@ -17,7 +16,9 @@ class Invoker:
self.services = services
self._start()
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
def invoke(
self, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""

View File

@@ -9,15 +9,13 @@ T = TypeVar("T", bound=BaseModel)
class PaginatedResults(GenericModel, Generic[T]):
"""Paginated results"""
# fmt: off
#fmt: off
items: list[T] = Field(description="Items")
page: int = Field(description="Current Page")
pages: int = Field(description="Total number of pages")
per_page: int = Field(description="Number of items per page")
total: int = Field(description="Total number of items in result")
# fmt: on
#fmt: on
class ItemStorageABC(ABC, Generic[T]):
_on_changed_callbacks: list[Callable[[T], None]]
@@ -50,7 +48,9 @@ class ItemStorageABC(ABC, Generic[T]):
pass
@abstractmethod
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[T]:
pass
def on_changed(self, on_changed: Callable[[T], None]) -> None:

View File

@@ -7,7 +7,6 @@ from typing import Dict, Union, Optional
import torch
class LatentsStorageBase(ABC):
"""Responsible for storing and retrieving latents."""
@@ -26,7 +25,7 @@ class LatentsStorageBase(ABC):
class ForwardCacheLatentsStorage(LatentsStorageBase):
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
__cache: Dict[str, torch.Tensor]
__cache_ids: Queue
__max_cache_size: int
@@ -88,6 +87,8 @@ class DiskLatentsStorage(LatentsStorageBase):
def delete(self, name: str) -> None:
latent_path = self.get_path(name)
latent_path.unlink()
def get_path(self, name: str) -> Path:
return self.__output_folder / name

View File

@@ -103,7 +103,7 @@ class ModelManagerServiceBase(ABC):
}
"""
pass
@abstractmethod
def list_model(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
"""
@@ -125,7 +125,7 @@ class ModelManagerServiceBase(ABC):
base_model: BaseModelType,
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
clobber: bool = False
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with an
@@ -148,12 +148,12 @@ class ModelManagerServiceBase(ABC):
Update the named model with a dictionary of attributes. Will fail with a
ModelNotFoundException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
pass
@abstractmethod
def del_model(
self,
@@ -169,20 +169,21 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str,
):
def rename_model(self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str,
):
"""
Rename the indicated model.
"""
pass
@abstractmethod
def list_checkpoint_configs(self) -> List[Path]:
def list_checkpoint_configs(
self
)->List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
@@ -193,7 +194,7 @@ class ModelManagerServiceBase(ABC):
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main, ModelType.Vae],
model_type: Union[ModelType.Main,ModelType.Vae],
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
@@ -210,12 +211,11 @@ class ModelManagerServiceBase(ABC):
pass
@abstractmethod
def heuristic_import(
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
def heuristic_import(self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
)->dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
@@ -230,23 +230,19 @@ class ModelManagerServiceBase(ABC):
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
"""
'''
pass
@abstractmethod
def merge_models(
self,
model_names: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None,
self,
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = None
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
@@ -254,27 +250,27 @@ class ModelManagerServiceBase(ABC):
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
pass
@abstractmethod
def search_for_models(self, directory: Path) -> List[Path]:
def search_for_models(self, directory: Path)->List[Path]:
"""
Return list of all models found in the designated directory.
"""
pass
@abstractmethod
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
pass
@abstractmethod
def commit(self, conf_file: Optional[Path] = None) -> None:
"""
@@ -284,11 +280,9 @@ class ModelManagerServiceBase(ABC):
"""
pass
# simple implementation
class ModelManagerService(ModelManagerServiceBase):
"""Responsible for managing models on disk and in memory"""
def __init__(
self,
config: InvokeAIAppConfig,
@@ -304,17 +298,17 @@ class ModelManagerService(ModelManagerServiceBase):
config_file = config.model_conf_path
else:
config_file = config.root_dir / "configs/models.yaml"
logger.debug(f"Config file={config_file}")
logger.debug(f'Config file={config_file}')
device = torch.device(choose_torch_device())
device_name = torch.cuda.get_device_name() if device == torch.device("cuda") else ""
logger.info(f"GPU device = {device} {device_name}")
device_name = torch.cuda.get_device_name() if device==torch.device('cuda') else ''
logger.info(f'GPU device = {device} {device_name}')
precision = config.precision
if precision == "auto":
precision = choose_precision(device)
dtype = torch.float32 if precision == "float32" else torch.float16
dtype = torch.float32 if precision == 'float32' else torch.float16
# this is transitional backward compatibility
# support for the deprecated `max_loaded_models`
@@ -322,7 +316,9 @@ class ModelManagerService(ModelManagerServiceBase):
# cache size is set to 2.5 GB times
# the number of max_loaded_models. Otherwise
# use new `max_cache_size` config setting
max_cache_size = config.max_cache_size if hasattr(config, "max_cache_size") else config.max_loaded_models * 2.5
max_cache_size = config.max_cache_size \
if hasattr(config,'max_cache_size') \
else config.max_loaded_models * 2.5
logger.debug(f"Maximum RAM cache size: {max_cache_size} GiB")
@@ -336,7 +332,7 @@ class ModelManagerService(ModelManagerServiceBase):
sequential_offload=sequential_offload,
logger=logger,
)
logger.info("Model manager service initialized")
logger.info('Model manager service initialized')
def get_model(
self,
@@ -375,7 +371,7 @@ class ModelManagerService(ModelManagerServiceBase):
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info,
model_info=model_info
)
return model_info
@@ -409,7 +405,9 @@ class ModelManagerService(ModelManagerServiceBase):
return self.mgr.model_names()
def list_models(
self, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None
self,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None
) -> list[dict]:
"""
Return a list of models.
@@ -420,7 +418,9 @@ class ModelManagerService(ModelManagerServiceBase):
"""
Return information about the model using the same format as list_models()
"""
return self.mgr.list_model(model_name=model_name, base_model=base_model, model_type=model_type)
return self.mgr.list_model(model_name=model_name,
base_model=base_model,
model_type=model_type)
def add_model(
self,
@@ -429,7 +429,7 @@ class ModelManagerService(ModelManagerServiceBase):
model_type: ModelType,
model_attributes: dict,
clobber: bool = False,
) -> None:
)->None:
"""
Update the named model with a dictionary of attributes. Will fail with an
assertion error if the name already exists. Pass clobber=True to overwrite.
@@ -437,7 +437,7 @@ class ModelManagerService(ModelManagerServiceBase):
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f"add/update model {model_name}")
self.logger.debug(f'add/update model {model_name}')
return self.mgr.add_model(model_name, base_model, model_type, model_attributes, clobber)
def update_model(
@@ -450,15 +450,15 @@ class ModelManagerService(ModelManagerServiceBase):
"""
Update the named model with a dictionary of attributes. Will fail with a
ModelNotFoundException exception if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f"update model {model_name}")
self.logger.debug(f'update model {model_name}')
if not self.model_exists(model_name, base_model, model_type):
raise ModelNotFoundException(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
def del_model(
self,
model_name: str,
@@ -470,7 +470,7 @@ class ModelManagerService(ModelManagerServiceBase):
then the underlying weight file or diffusers directory will be deleted
as well.
"""
self.logger.debug(f"delete model {model_name}")
self.logger.debug(f'delete model {model_name}')
self.mgr.del_model(model_name, base_model, model_type)
self.mgr.commit()
@@ -478,10 +478,8 @@ class ModelManagerService(ModelManagerServiceBase):
self,
model_name: str,
base_model: BaseModelType,
model_type: Union[ModelType.Main, ModelType.Vae],
convert_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
model_type: Union[ModelType.Main,ModelType.Vae],
convert_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
) -> AddModelResult:
"""
Convert a checkpoint file into a diffusers folder, deleting the cached
@@ -496,10 +494,10 @@ class ModelManagerService(ModelManagerServiceBase):
also raise a ValueError in the event that there is a similarly-named diffusers
directory already in place.
"""
self.logger.debug(f"convert model {model_name}")
self.logger.debug(f'convert model {model_name}')
return self.mgr.convert_model(model_name, base_model, model_type, convert_dest_directory)
def commit(self, conf_file: Optional[Path] = None):
def commit(self, conf_file: Optional[Path]=None):
"""
Write current configuration out to the indicated file.
If no conf_file is provided, then replaces the
@@ -526,7 +524,7 @@ class ModelManagerService(ModelManagerServiceBase):
base_model=base_model,
model_type=model_type,
submodel=submodel,
model_info=model_info,
model_info=model_info
)
else:
context.services.events.emit_model_load_started(
@@ -537,16 +535,16 @@ class ModelManagerService(ModelManagerServiceBase):
submodel=submodel,
)
@property
def logger(self):
return self.mgr.logger
def heuristic_import(
self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
) -> dict[str, AddModelResult]:
"""Import a list of paths, repo_ids or URLs. Returns the set of
def heuristic_import(self,
items_to_import: set[str],
prediction_type_helper: Optional[Callable[[Path],SchedulerPredictionType]]=None,
)->dict[str, AddModelResult]:
'''Import a list of paths, repo_ids or URLs. Returns the set of
successfully imported items.
:param items_to_import: Set of strings corresponding to models to be imported.
:param prediction_type_helper: A callback that receives the Path of a Stable Diffusion 2 checkpoint model and returns a SchedulerPredictionType.
@@ -561,24 +559,18 @@ class ModelManagerService(ModelManagerServiceBase):
The result is a set of successfully installed models. Each element
of the set is a dict corresponding to the newly-created OmegaConf stanza for
that model.
"""
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
'''
return self.mgr.heuristic_import(items_to_import, prediction_type_helper)
def merge_models(
self,
model_names: List[str] = Field(
default=None, min_items=2, max_items=3, description="List of model names to merge"
),
base_model: Union[BaseModelType, str] = Field(
default=None, description="Base model shared by all models to be merged"
),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = Field(
default=None, description="Optional directory location for merged model"
),
self,
model_names: List[str] = Field(default=None, min_items=2, max_items=3, description="List of model names to merge"),
base_model: Union[BaseModelType,str] = Field(default=None, description="Base model shared by all models to be merged"),
merged_model_name: str = Field(default=None, description="Name of destination model after merging"),
alpha: Optional[float] = 0.5,
interp: Optional[MergeInterpolationMethod] = None,
force: Optional[bool] = False,
merge_dest_directory: Optional[Path] = Field(default=None, description="Optional directory location for merged model"),
) -> AddModelResult:
"""
Merge two to three diffusrs pipeline models and save as a new model.
@@ -586,56 +578,55 @@ class ModelManagerService(ModelManagerServiceBase):
:param base_model: Base model to use for all models
:param merged_model_name: Name of destination merged model
:param alpha: Alpha strength to apply to 2d and 3d model
:param interp: Interpolation method. None (default)
:param interp: Interpolation method. None (default)
:param merge_dest_directory: Save the merged model to the designated directory (with 'merged_model_name' appended)
"""
merger = ModelMerger(self.mgr)
try:
result = merger.merge_diffusion_models_and_save(
model_names=model_names,
base_model=base_model,
merged_model_name=merged_model_name,
alpha=alpha,
interp=interp,
force=force,
model_names = model_names,
base_model = base_model,
merged_model_name = merged_model_name,
alpha = alpha,
interp = interp,
force = force,
merge_dest_directory=merge_dest_directory,
)
except AssertionError as e:
raise ValueError(e)
return result
def search_for_models(self, directory: Path) -> List[Path]:
def search_for_models(self, directory: Path)->List[Path]:
"""
Return list of all models found in the designated directory.
"""
search = FindModels([directory], self.logger)
search = FindModels(directory,self.logger)
return search.list_models()
def sync_to_config(self):
"""
Re-read models.yaml, rescan the models directory, and reimport models
Re-read models.yaml, rescan the models directory, and reimport models
in the autoimport directories. Call after making changes outside the
model manager API.
"""
return self.mgr.sync_to_config()
def list_checkpoint_configs(self) -> List[Path]:
def list_checkpoint_configs(self)->List[Path]:
"""
List the checkpoint config paths from ROOT/configs/stable-diffusion.
"""
config = self.mgr.app_config
conf_path = config.legacy_conf_path
root_path = config.root_path
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob("**/*.yaml")]
return [(conf_path / x).relative_to(root_path) for x in conf_path.glob('**/*.yaml')]
def rename_model(
self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str = None,
new_base: BaseModelType = None,
):
def rename_model(self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
new_name: str = None,
new_base: BaseModelType = None,
):
"""
Rename the indicated model. Can provide a new name and/or a new base.
:param model_name: Current name of the model
@@ -644,10 +635,10 @@ class ModelManagerService(ModelManagerServiceBase):
:param new_name: New name for the model
:param new_base: New base for the model
"""
self.mgr.rename_model(
base_model=base_model,
model_type=model_type,
model_name=model_name,
new_name=new_name,
new_base=new_base,
)
self.mgr.rename_model(base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
)

View File

@@ -11,20 +11,30 @@ class BoardRecord(BaseModel):
"""The unique ID of the board."""
board_name: str = Field(description="The name of the board.")
"""The name of the board."""
created_at: Union[datetime, str] = Field(description="The created timestamp of the board.")
created_at: Union[datetime, str] = Field(
description="The created timestamp of the board."
)
"""The created timestamp of the image."""
updated_at: Union[datetime, str] = Field(description="The updated timestamp of the board.")
updated_at: Union[datetime, str] = Field(
description="The updated timestamp of the board."
)
"""The updated timestamp of the image."""
deleted_at: Union[datetime, str, None] = Field(description="The deleted timestamp of the board.")
deleted_at: Union[datetime, str, None] = Field(
description="The deleted timestamp of the board."
)
"""The updated timestamp of the image."""
cover_image_name: Optional[str] = Field(description="The name of the cover image of the board.")
cover_image_name: Optional[str] = Field(
description="The name of the cover image of the board."
)
"""The name of the cover image of the board."""
class BoardDTO(BoardRecord):
"""Deserialized board record with cover image URL and image count."""
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
cover_image_name: Optional[str] = Field(
description="The name of the board's cover image."
)
"""The URL of the thumbnail of the most recent image in the board."""
image_count: int = Field(description="The number of images in the board.")
"""The number of images in the board."""

View File

@@ -20,11 +20,17 @@ class ImageRecord(BaseModel):
"""The actual width of the image in px. This may be different from the width in metadata."""
height: int = Field(description="The height of the image in px.")
"""The actual height of the image in px. This may be different from the height in metadata."""
created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the image.")
created_at: Union[datetime.datetime, str] = Field(
description="The created timestamp of the image."
)
"""The created timestamp of the image."""
updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the image.")
updated_at: Union[datetime.datetime, str] = Field(
description="The updated timestamp of the image."
)
"""The updated timestamp of the image."""
deleted_at: Union[datetime.datetime, str, None] = Field(description="The deleted timestamp of the image.")
deleted_at: Union[datetime.datetime, str, None] = Field(
description="The deleted timestamp of the image."
)
"""The deleted timestamp of the image."""
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
"""Whether this is an intermediate image."""
@@ -49,14 +55,18 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
- `is_intermediate`: change the image's `is_intermediate` flag
"""
image_category: Optional[ImageCategory] = Field(description="The image's new category.")
image_category: Optional[ImageCategory] = Field(
description="The image's new category."
)
"""The image's new category."""
session_id: Optional[StrictStr] = Field(
default=None,
description="The image's new session ID.",
)
"""The image's new session ID."""
is_intermediate: Optional[StrictBool] = Field(default=None, description="The image's new `is_intermediate` flag.")
is_intermediate: Optional[StrictBool] = Field(
default=None, description="The image's new `is_intermediate` flag."
)
"""The image's new `is_intermediate` flag."""
@@ -74,7 +84,9 @@ class ImageUrlsDTO(BaseModel):
class ImageDTO(ImageRecord, ImageUrlsDTO):
"""Deserialized image record, enriched for the frontend."""
board_id: Optional[str] = Field(description="The id of the board the image belongs to, if one exists.")
board_id: Optional[str] = Field(
description="The id of the board the image belongs to, if one exists."
)
"""The id of the board the image belongs to, if one exists."""
pass
@@ -98,8 +110,12 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
# TODO: do we really need to handle default values here? ideally the data is the correct shape...
image_name = image_dict.get("image_name", "unknown")
image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value))
image_category = ImageCategory(image_dict.get("image_category", ImageCategory.GENERAL.value))
image_origin = ResourceOrigin(
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
)
image_category = ImageCategory(
image_dict.get("image_category", ImageCategory.GENERAL.value)
)
width = image_dict.get("width", 0)
height = image_dict.get("height", 0)
session_id = image_dict.get("session_id", None)

View File

@@ -8,8 +8,6 @@ from .invoker import InvocationProcessorABC, Invoker
from ..models.exceptions import CanceledException
import invokeai.backend.util.logging as logger
class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker_thread: Thread
__stop_event: Event
@@ -26,7 +24,9 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
target=self.__process,
kwargs=dict(stop_event=self.__stop_event),
)
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
self.__invoker_thread.daemon = (
True # TODO: make async and do not use threads
)
self.__invoker_thread.start()
def stop(self, *args, **kwargs) -> None:
@@ -39,37 +39,21 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
try:
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
except Exception as e:
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
logger.debug("Exception while getting from queue: %s" % e)
if not queue_item: # Probably stopping
# do not hammer the queue
time.sleep(0.5)
continue
try:
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
graph_execution_state = (
self.__invoker.services.graph_execution_manager.get(
queue_item.graph_execution_state_id
)
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
self.__invoker.services.events.emit_session_retrieval_error(
graph_execution_state_id=queue_item.graph_execution_state_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
continue
try:
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
self.__invoker.services.events.emit_invocation_retrieval_error(
graph_execution_state_id=queue_item.graph_execution_state_id,
node_id=queue_item.invocation_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
)
continue
)
invocation = graph_execution_state.execution_graph.get_node(
queue_item.invocation_id
)
# get the source node id to provide to clients (the prepared node id is not as useful)
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
@@ -78,7 +62,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
self.__invoker.services.events.emit_invocation_started(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
source_node_id=source_node_id
)
# Invoke
@@ -91,14 +75,18 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
)
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
if self.__invoker.services.queue.is_canceled(
graph_execution_state.id
):
continue
# Save outputs and history
graph_execution_state.complete(invocation.id, outputs)
# Save the state changes
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
self.__invoker.services.graph_execution_manager.set(
graph_execution_state
)
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
@@ -122,22 +110,24 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
graph_execution_state.set_node_error(invocation.id, error)
# Save the state changes
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
self.__invoker.services.graph_execution_manager.set(
graph_execution_state
)
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
# Send error event
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# Check queue to see if this is canceled, and skip if so
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
if self.__invoker.services.queue.is_canceled(
graph_execution_state.id
):
continue
# Queue any further commands if invoking all
@@ -146,16 +136,17 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
try:
self.__invoker.invoke(graph_execution_state, invoke_all=True)
except Exception as e:
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
logger.error("Error while invoking: %s" % e)
self.__invoker.services.events.emit_invocation_error(
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
error=traceback.format_exc()
)
elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
self.__invoker.services.events.emit_graph_execution_complete(
graph_execution_state.id
)
except KeyboardInterrupt:
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor

View File

@@ -66,7 +66,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def get(self, id: str) -> Optional[T]:
try:
self._lock.acquire()
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
result = self._cursor.fetchone()
finally:
self._lock.release()
@@ -79,7 +81,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def get_raw(self, id: str) -> Optional[str]:
try:
self._lock.acquire()
self._cursor.execute(f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),))
self._cursor.execute(
f"""SELECT item FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
result = self._cursor.fetchone()
finally:
self._lock.release()
@@ -92,7 +96,9 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
def delete(self, id: str):
try:
self._lock.acquire()
self._cursor.execute(f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),))
self._cursor.execute(
f"""DELETE FROM {self._table_name} WHERE id = ?;""", (str(id),)
)
self._conn.commit()
finally:
self._lock.release()
@@ -116,9 +122,13 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
pageCount = int(count / per_page) + 1
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
return PaginatedResults[T](
items=items, page=page, pages=pageCount, per_page=per_page, total=count
)
def search(self, query: str, page: int = 0, per_page: int = 10) -> PaginatedResults[T]:
def search(
self, query: str, page: int = 0, per_page: int = 10
) -> PaginatedResults[T]:
try:
self._lock.acquire()
self._cursor.execute(
@@ -139,4 +149,6 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]):
pageCount = int(count / per_page) + 1
return PaginatedResults[T](items=items, page=page, pages=pageCount, per_page=per_page, total=count)
return PaginatedResults[T](
items=items, page=page, pages=pageCount, per_page=per_page, total=count
)

View File

@@ -17,8 +17,16 @@ from controlnet_aux.util import HWC3, resize_image
# If you use this, please Cite "High Quality Edge Thinning using Pure Python", Lvmin Zhang, In Mikubill/sd-webui-controlnet.
lvmin_kernels_raw = [
np.array([[-1, -1, -1], [0, 1, 0], [1, 1, 1]], dtype=np.int32),
np.array([[0, -1, -1], [1, 1, -1], [0, 1, 0]], dtype=np.int32),
np.array([
[-1, -1, -1],
[0, 1, 0],
[1, 1, 1]
], dtype=np.int32),
np.array([
[0, -1, -1],
[1, 1, -1],
[0, 1, 0]
], dtype=np.int32)
]
lvmin_kernels = []
@@ -28,8 +36,16 @@ lvmin_kernels += [np.rot90(x, k=2, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_kernels += [np.rot90(x, k=3, axes=(0, 1)) for x in lvmin_kernels_raw]
lvmin_prunings_raw = [
np.array([[-1, -1, -1], [-1, 1, -1], [0, 0, -1]], dtype=np.int32),
np.array([[-1, -1, -1], [-1, 1, -1], [-1, 0, 0]], dtype=np.int32),
np.array([
[-1, -1, -1],
[-1, 1, -1],
[0, 0, -1]
], dtype=np.int32),
np.array([
[-1, -1, -1],
[-1, 1, -1],
[-1, 0, 0]
], dtype=np.int32)
]
lvmin_prunings = []
@@ -83,10 +99,10 @@ def nake_nms(x):
################################################################################
# FIXME: not using yet, if used in the future will most likely require modification of preprocessors
def pixel_perfect_resolution(
image: np.ndarray,
target_H: int,
target_W: int,
resize_mode: str,
image: np.ndarray,
target_H: int,
target_W: int,
resize_mode: str,
) -> int:
"""
Calculate the estimated resolution for resizing an image while preserving aspect ratio.
@@ -119,7 +135,7 @@ def pixel_perfect_resolution(
if resize_mode == "fill_resize":
estimation = min(k0, k1) * float(min(raw_H, raw_W))
else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
else: # "crop_resize" or "just_resize" (or possibly "just_resize_simple"?)
estimation = max(k0, k1) * float(min(raw_H, raw_W))
# print(f"Pixel Perfect Computation:")
@@ -138,7 +154,13 @@ def pixel_perfect_resolution(
# modified for InvokeAI
###########################################################################
# def detectmap_proc(detected_map, module, resize_mode, h, w):
def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device: torch.device = torch.device("cpu")):
def np_img_resize(
np_img: np.ndarray,
resize_mode: str,
h: int,
w: int,
device: torch.device = torch.device('cpu')
):
# if 'inpaint' in module:
# np_img = np_img.astype(np.float32)
# else:
@@ -162,14 +184,15 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
# below is very boring but do not change these. If you change these Apple or Mac may fail.
y = torch.from_numpy(y)
y = y.float() / 255.0
y = rearrange(y, "h w c -> 1 c h w")
y = rearrange(y, 'h w c -> 1 c h w')
y = y.clone()
# y = y.to(devices.get_device_for("controlnet"))
y = y.to(device)
y = y.clone()
return y
def high_quality_resize(x: np.ndarray, size):
def high_quality_resize(x: np.ndarray,
size):
# Written by lvmin
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
inpaint_mask = None
@@ -221,7 +244,7 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
return y
# if resize_mode == external_code.ResizeMode.RESIZE:
if resize_mode == "just_resize": # RESIZE
if resize_mode == "just_resize": # RESIZE
np_img = high_quality_resize(np_img, (w, h))
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
@@ -247,21 +270,20 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
new_h, new_w, _ = np_img.shape
pad_h = max(0, (h - new_h) // 2)
pad_w = max(0, (w - new_w) // 2)
high_quality_background[pad_h : pad_h + new_h, pad_w : pad_w + new_w] = np_img
high_quality_background[pad_h:pad_h + new_h, pad_w:pad_w + new_w] = np_img
np_img = high_quality_background
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
else: # resize_mode == "crop_resize" (INNER_FIT)
else: # resize_mode == "crop_resize" (INNER_FIT)
k = max(k0, k1)
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape
pad_h = max(0, (new_h - h) // 2)
pad_w = max(0, (new_w - w) // 2)
np_img = np_img[pad_h : pad_h + h, pad_w : pad_w + w]
np_img = np_img[pad_h:pad_h + h, pad_w:pad_w + w]
np_img = safe_numpy(np_img)
return get_pytorch_control(np_img), np_img
def prepare_control_image(
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
# but now should be able to assume that image is a single PIL.Image, which simplifies things
@@ -279,17 +301,15 @@ def prepare_control_image(
resize_mode="just_resize_simple",
):
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
if (
resize_mode == "just_resize_simple"
or resize_mode == "crop_resize_simple"
or resize_mode == "fill_resize_simple"
):
if (resize_mode == "just_resize_simple" or
resize_mode == "crop_resize_simple" or
resize_mode == "fill_resize_simple"):
image = image.convert("RGB")
if resize_mode == "just_resize_simple":
if (resize_mode == "just_resize_simple"):
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
elif resize_mode == "crop_resize_simple": # not yet implemented
elif (resize_mode == "crop_resize_simple"): # not yet implemented
pass
elif resize_mode == "fill_resize_simple": # not yet implemented
elif (resize_mode == "fill_resize_simple"): # not yet implemented
pass
nimage = np.array(image)
nimage = nimage[None, :]
@@ -300,7 +320,7 @@ def prepare_control_image(
timage = torch.from_numpy(nimage)
# use fancy lvmin controlnet resizing
elif resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize":
elif (resize_mode == "just_resize" or resize_mode == "crop_resize" or resize_mode == "fill_resize"):
nimage = np.array(image)
timage, nimage = np_img_resize(
np_img=nimage,
@@ -316,7 +336,7 @@ def prepare_control_image(
exit(1)
timage = timage.to(device=device, dtype=dtype)
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
cfg_injection = (control_mode == "more_control" or control_mode == "unbalanced")
if do_classifier_free_guidance and not cfg_injection:
timage = torch.cat([timage] * 2)
return timage

View File

@@ -14,9 +14,8 @@ def get_datetime_from_iso_timestamp(iso_timestamp: str) -> datetime.datetime:
return datetime.datetime.fromisoformat(iso_timestamp)
SEED_MAX = np.iinfo(np.uint32).max
SEED_MAX = np.iinfo(np.int32).max
def get_random_seed():
rng = np.random.default_rng(seed=0)
return int(rng.integers(0, SEED_MAX))
return np.random.randint(0, SEED_MAX)

View File

@@ -9,16 +9,19 @@ from ...backend.stable_diffusion import PipelineIntermediateState
from invokeai.app.services.config import InvokeAIAppConfig
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix = None):
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
if smooth_matrix is not None:
latent_image = latent_image.unsqueeze(0).permute(3, 0, 1, 2)
latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1, 1, 3, 3)), padding=1)
latent_image = torch.nn.functional.conv2d(latent_image, smooth_matrix.reshape((1,1,3,3)), padding=1)
latent_image = latent_image.permute(1, 2, 3, 0).squeeze(0)
latents_ubyte = (
((latent_image + 1) / 2).clamp(0, 1).mul(0xFF).byte() # change scale from -1..1 to 0..1 # to 0..255
((latent_image + 1) / 2)
.clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
.byte()
).cpu()
return Image.fromarray(latents_ubyte.numpy())
@@ -89,7 +92,6 @@ def stable_diffusion_step_callback(
total_steps=node["steps"],
)
def stable_diffusion_xl_step_callback(
context: InvocationContext,
node: dict,
@@ -104,9 +106,9 @@ def stable_diffusion_xl_step_callback(
sdxl_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[0.1770, 0.3588, -0.2048],
[ 0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[ 0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
],
dtype=sample.dtype,
@@ -115,9 +117,9 @@ def stable_diffusion_xl_step_callback(
sdxl_smooth_matrix = torch.tensor(
[
# [ 0.0478, 0.1285, 0.0478],
# [ 0.1285, 0.2948, 0.1285],
# [ 0.0478, 0.1285, 0.0478],
#[ 0.0478, 0.1285, 0.0478],
#[ 0.1285, 0.2948, 0.1285],
#[ 0.0478, 0.1285, 0.0478],
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
@@ -141,4 +143,4 @@ def stable_diffusion_xl_step_callback(
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=step,
total_steps=total_steps,
)
)

Some files were not shown because too many files have changed in this diff Show More