Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-17 14:08:20 -05:00)

Compare commits: 3.4.0post1 ... feat/workf (175 commits)
Commits in this comparison, by SHA1:

2b1762d8da b056a9c181 f56c47b550 f8e35aec7b 0463541d99
10cf10c16c e45704833e 0fdcc0af65 4fc2ed7195 d0464a5793
7d4a78e470 37c87affd0 3863bd9da3 4b2e3aa54d d699efa5bc
b9a1374b8f 411ea75861 375c9a1c20 907340b1e1 0f32d260b7
92bc04dc87 929b1f4a41 6d7b4b8e8a 4a14ee0e01 f268ea4e39
78face3481 5a0e8261bf 0447fa2dcb fb9b471150 3f0e0af177
0228aba06f 1fd6666682 4fd163698c cff6600ded 04ddcf53f3
0539a64569 224438a108 81d2d5abae 734e871e8f b0350e9bc8
5a3f1f2b22 f95ce1870c 0719a46372 0a25efd054 a8ef4e5be8
e6fe2540b8 aadcde3edd 984e609c61 57e70aaf50 bfdef120d1
32da359ba5 b19ed36b43 e5a212b5c8 9b863fb9bc 7cab51745b
18c6ff427e 843f2d71d6 67540c9ee0 7f816c9243 76b888de17
65a16be299 1c8ff0ae66 29eade4880 86fd1d5b22 909b78a1cb
2f81f9fb22 a6d4e4ed57 3e01c396e1 0beb08686c 46905175a9
11085783ef 693c6cf5e4 77933a0a85 2a087bf161 3d57c14bb3
18f3190857 fcc056fe6a c1bfc1f47b b0fe57ec80 09cb40786f
18ecfc0521 14bf87e5e7 715ce8538b 1987bc9cc5 0b079df4ae
a514c9e28b 8cf2806489 eb446471cc 7392d07331 59d932e9c1
578c8ce5dd 3d4874dc34 5aaf2e8873 f3fd0f6d73 4468581d2e
da642b7aad b379e3d187 6867c79185 a1705dc6b3 4af4486dd9
282a7f32d3 4c6a88a642 e41d0b9a76 a02090b06b 0d9a546d74
8d99113bef 4309f3bd58 42370939a8 654591cbf3 ad9c954a58
a703e1b3d3 e85f2254f0 8f2cf30191 296741306c 5386a286fd
803fb393bb ab944bd13a 514c49d946 858bcdd3ff ed79980dd4
86a74e929a 0d52430481 4eca802cdd ff0a25bd9c ace0eb366b
d971c5fa64 ae82df0fda e28262ebd9 250ee4b11c b7293d638b
eee863e380 e509d719ee 1d8f44d356 7653d21cf5 46a2d83b84
79efc6789e 2192210910 84629df49c ef6b27ab35 17420f76b3
45213aa631 4381dabbd9 b4a03fcf42 714be33850 5f23fc493d
4fe93e521e 6e6d903f99 667a2a3d84 f57b277d5a e62991c54d
785d584603 da4aab9233 591b601fd3 317b5ebae1 98a4930a52
1a596a5684 84a0a0fa14 da443973cb d073d10f9f 2b7e7496f7
50ab677ea4 cb81558302 9259483081 4ece322f82 13e8fa733e
3e473ae008 487fda0226 74d3b22533 b5e018972f 2af844385f
540047e26e 4d8b8a2db8 d581a3289b d756c9b10a 63d3212bec
Makefile (new file, 21 lines)

@@ -0,0 +1,21 @@
+# simple Makefile with scripts that are otherwise hard to remember
+# to use, run from the repo root `make <command>`
+
+# Runs ruff, fixing any safely-fixable errors and formatting
+ruff:
+	ruff check . --fix
+	ruff format .
+
+# Runs ruff, fixing all errors it can fix and formatting
+ruff-unsafe:
+	ruff check . --fix --unsafe-fixes
+	ruff format .
+
+# Runs mypy, using the config in pyproject.toml
+mypy:
+	mypy scripts/invokeai-web.py
+
+# Runs mypy, ignoring the config in pyproject.toml but still ignoring missing (untyped) imports
+# (many files are ignored by the config, so this is useful for checking all files)
+mypy-all:
+	mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
@@ -395,7 +395,7 @@ Notes](https://github.com/invoke-ai/InvokeAI/releases) and the
 
 ### Troubleshooting
 
-Please check out our **[Q&A](https://invoke-ai.github.io/InvokeAI/help/TROUBLESHOOT/#faq)** to get solutions for common installation
+Please check out our **[Troubleshooting Guide](https://invoke-ai.github.io/InvokeAI/installation/010_INSTALL_AUTOMATED/#troubleshooting)** to get solutions for common installation
 problems and other issues. For more help, please join our [Discord][discord link]
 
 ## Contributing
@@ -65,7 +65,7 @@ The first set of things we need to do when creating a new Invocation are -
 So let us do that.
 
 ```python
-from .baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 
 @invocation('resize')
 class ResizeInvocation(BaseInvocation):
@@ -99,8 +99,8 @@ create your own custom field types later in this guide. For now, let's go ahead
 and use it.
 
 ```python
-from .baseinvocation import BaseInvocation, InputField, invocation
-from .primitives import ImageField
+from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
+from invokeai.app.invocations.primitives import ImageField
 
 @invocation('resize')
 class ResizeInvocation(BaseInvocation):
@@ -124,8 +124,8 @@ image: ImageField = InputField(description="The input image")
 Great. Now let us create our other inputs for `width` and `height`
 
 ```python
-from .baseinvocation import BaseInvocation, InputField, invocation
-from .primitives import ImageField
+from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation
+from invokeai.app.invocations.primitives import ImageField
 
 @invocation('resize')
 class ResizeInvocation(BaseInvocation):
@@ -160,8 +160,8 @@ that are provided by it by InvokeAI.
 Let us create this function first.
 
 ```python
-from .baseinvocation import BaseInvocation, InputField, invocation
-from .primitives import ImageField
+from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
+from invokeai.app.invocations.primitives import ImageField
 
 @invocation('resize')
 class ResizeInvocation(BaseInvocation):
@@ -189,9 +189,9 @@ all the necessary info related to image outputs. So let us use that.
 We will cover how to create your own output types later in this guide.
 
 ```python
-from .baseinvocation import BaseInvocation, InputField, invocation
-from .primitives import ImageField
-from .image import ImageOutput
+from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
+from invokeai.app.invocations.primitives import ImageField
+from invokeai.app.invocations.image import ImageOutput
 
 @invocation('resize')
 class ResizeInvocation(BaseInvocation):
@@ -216,9 +216,9 @@ Perfect. Now that we have our Invocation setup, let us do what we want to do.
 So let's do that.
 
 ```python
-from .baseinvocation import BaseInvocation, InputField, invocation
-from .primitives import ImageField
-from .image import ImageOutput
+from invokeai.app.invocations.baseinvocation import BaseInvocation, InputField, invocation, InvocationContext
+from invokeai.app.invocations.primitives import ImageField
+from invokeai.app.invocations.image import ImageOutput, ResourceOrigin, ImageCategory
 
 @invocation("resize")
 class ResizeInvocation(BaseInvocation):
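Taken together, the import fixes above yield a complete node along these lines. This is a sketch: the `invoke` body and the image-service calls (`get_pil_image`, `create`) follow the 3.x-era API that this tutorial documents, but they are reconstructed here rather than copied from the tutorial, so treat the exact service method names and parameters as assumptions.

```python
from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    InputField,
    InvocationContext,
    invocation,
)
from invokeai.app.invocations.image import ImageCategory, ImageOutput, ResourceOrigin
from invokeai.app.invocations.primitives import ImageField


@invocation("resize")
class ResizeInvocation(BaseInvocation):
    """Resizes an image to the given dimensions (sketch)."""

    image: ImageField = InputField(description="The input image")
    width: int = InputField(default=512, description="Width of the new image")
    height: int = InputField(default=512, description="Height of the new image")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Load the PIL image referenced by the ImageField input.
        image = context.services.images.get_pil_image(self.image.image_name)

        # The actual work: resize to the requested dimensions.
        resized_image = image.resize((self.width, self.height))

        # Save the result through the image service and return an ImageOutput
        # pointing at the new image (service signature assumed).
        image_dto = context.services.images.create(
            image=resized_image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )
        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
```

The key contract is that `invoke` receives an `InvocationContext` and returns the node's declared output type.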
@@ -120,7 +120,7 @@ Generate an image with a given prompt, record the seed of the image, and then
 use the `prompt2prompt` syntax to substitute words in the original prompt for
 words in a new prompt. This works for `img2img` as well.
 
-For example, consider the prompt `a cat.swap(dog) playing with a ball in the forest`. Normally, because of the word words interact with each other when doing a stable diffusion image generation, these two prompts would generate different compositions:
+For example, consider the prompt `a cat.swap(dog) playing with a ball in the forest`. Normally, because the words interact with each other when doing a stable diffusion image generation, these two prompts would generate different compositions:
 - `a cat playing with a ball in the forest`
 - `a dog playing with a ball in the forest`
@@ -8,7 +8,7 @@ To use a node, add the node to the `nodes` folder found in your InvokeAI install
 
 The suggested method is to use `git clone` to clone the repository the node is found in. This allows for easy updates of the node in the future.
 
-If you'd prefer, you can also just download the `.py` file from the linked repository and add it to the `nodes` folder.
+If you'd prefer, you can also just download the whole node folder from the linked repository and add it to the `nodes` folder.
 
 To use a community workflow, download the the `.json` node graph file and load it into Invoke AI via the **Load Workflow** button in the Workflow Editor.
@@ -26,8 +26,10 @@ To use a community workflow, download the the `.json` node graph file and load i
 + [Image Picker](#image-picker)
 + [Load Video Frame](#load-video-frame)
 + [Make 3D](#make-3d)
++ [Match Histogram](#match-histogram)
 + [Oobabooga](#oobabooga)
 + [Prompt Tools](#prompt-tools)
++ [Remote Image](#remote-image)
 + [Retroize](#retroize)
 + [Size Stepper Nodes](#size-stepper-nodes)
 + [Text font to Image](#text-font-to-image)
@@ -207,6 +209,23 @@ This includes 15 Nodes:
 <img src="https://gitlab.com/srcrr/shift3d/-/raw/main/example-1.png" width="300" />
 <img src="https://gitlab.com/srcrr/shift3d/-/raw/main/example-2.png" width="300" />
 
+--------------------------------
+### Match Histogram
+
+**Description:** An InvokeAI node to match a histogram from one image to another. This is a bit like the `color correct` node in the main InvokeAI, but it works in the YCbCr colourspace and can handle images of different sizes. It also does not require a mask input.
+- Option to transfer only the luminance channel.
+- Option to save the output as grayscale.
+
+A good use case for this node is to normalize the colors of an image that has been through the tiled scaling workflow of my XYGrid Nodes.
+
+See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/main/README.md
+
+**Node Link:** https://github.com/skunkworxdark/match_histogram
+
+**Output Examples**
+
+<img src="https://github.com/skunkworxdark/match_histogram/assets/21961335/ed12f329-a0ef-444a-9bae-129ed60d6097" width="300" />
+
 --------------------------------
 ### Oobabooga
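The YCbCr histogram matching that the Match Histogram node describes is easy to prototype with scikit-image. The sketch below illustrates the technique under the assumption that `scikit-image` and Pillow are available; it is not the node's actual implementation.

```python
import numpy as np
from PIL import Image
from skimage.exposure import match_histograms


def match_histogram_ycbcr(target: Image.Image, reference: Image.Image, luminance_only: bool = False) -> Image.Image:
    """Match target's histogram to reference in YCbCr space."""
    t = np.asarray(target.convert("YCbCr"), dtype=np.uint8)
    r = np.asarray(reference.convert("YCbCr"), dtype=np.uint8)
    if luminance_only:
        # Transfer only the Y (luma) channel; chroma stays untouched.
        out = t.copy()
        out[..., 0] = match_histograms(t[..., 0], r[..., 0]).astype(np.uint8)
    else:
        # channel_axis=-1 matches Y, Cb and Cr independently. The two images
        # may have different sizes, since only the distributions are compared.
        out = match_histograms(t, r, channel_axis=-1).astype(np.uint8)
    return Image.fromarray(out, mode="YCbCr").convert("RGB")
```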
@@ -236,22 +255,41 @@ This node works best with SDXL models, especially as the style can be described
 --------------------------------
 ### Prompt Tools
 
-**Description:** A set of InvokeAI nodes that add general prompt manipulation tools. These were written to accompany the PromptsFromFile node and other prompt generation nodes.
+**Description:** A set of InvokeAI nodes that add general prompt (string) manipulation tools. Designed to accompany the `Prompts From File` node and other prompt generation nodes.
+
+1. `Prompt To File` - Saves a prompt or collection of prompts to a file, one per line. Has an append/overwrite option.
+2. `PTFields Collect` - Converts image generation fields into a JSON-format string that can be passed to `Prompt To File`.
+3. `PTFields Expand` - Takes a JSON string and converts it to individual generation parameters. Can be fed from the `Prompt To File` node.
+4. `Prompt Strength` - Formats a prompt with a strength, like the weighted format of compel.
+5. `Prompt Strength Combine` - Combines weighted prompts for `.and()`/`.blend()`.
+6. `CSV To Index String` - Gets a string from a CSV by index. Includes a random-index option.
+
+The following nodes are now included in v3.2 of Invoke and are no longer in this set of tools.
+- `Prompt Join` -> `String Join`
+- `Prompt Join Three` -> `String Join Three`
+- `Prompt Replace` -> `String Replace`
+- `Prompt Split Neg` -> `String Split Neg`
 
-1. PromptJoin - Joins to prompts into one.
-2. PromptReplace - performs a search and replace on a prompt. With the option of using regex.
-3. PromptSplitNeg - splits a prompt into positive and negative using the old V2 method of [] for negative.
-4. PromptToFile - saves a prompt or collection of prompts to a file. one per line. There is an append/overwrite option.
-5. PTFieldsCollect - Converts image generation fields into a Json format string that can be passed to Prompt to file.
-6. PTFieldsExpand - Takes Json string and converts it to individual generation parameters This can be fed from the Prompts to file node.
-7. PromptJoinThree - Joins 3 prompt together.
-8. PromptStrength - This take a string and float and outputs another string in the format of (string)strength like the weighted format of compel.
-9. PromptStrengthCombine - This takes a collection of prompt strength strings and outputs a string in the .and() or .blend() format that can be fed into a proper prompt node.
 
 See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/main/README.md
 
 **Node Link:** https://github.com/skunkworxdark/Prompt-tools-nodes
 
 **Workflow Examples**
 
+<img src="https://github.com/skunkworxdark/prompt-tools/blob/main/images/CSVToIndexStringNode.png" width="300" />
 
 --------------------------------
+### Remote Image
+
+**Description:** This is a pack of nodes to interoperate with other services, be they public websites or bespoke local servers. The pack consists of these nodes:
+
+- *Load Remote Image* - Lets you load remote images such as a realtime webcam image, an image of the day, or dynamically created images.
+- *Post Image to Remote Server* - Lets you upload an image to a remote server using an HTTP POST request, e.g. for storage, display or further processing.
+
+**Node Link:** https://github.com/fieldOfView/InvokeAI-remote_image
+
+--------------------------------
 ### Retroize
@@ -327,15 +365,27 @@ Highlights/Midtones/Shadows (with LUT blur enabled):
 --------------------------------
 ### XY Image to Grid and Images to Grids nodes
 
-**Description:** Image to grid nodes and supporting tools.
+**Description:** These nodes add the following to InvokeAI:
+- Generate grids of images from multiple input images
+- Create XY grid images with labels from parameters
+- Split images into overlapping tiles for processing (for super-resolution workflows)
+- Recombine image tiles into a single output image, blending the seams
 
-1. "Images To Grids" node - Takes a collection of images and creates a grid(s) of images. If there are more images than the size of a single grid then multiple grids will be created until it runs out of images.
-2. "XYImage To Grid" node - Converts a collection of XYImages into a labeled Grid of images. The XYImages collection has to be built using the supporting nodes. See example node setups for more details.
+The nodes include:
+1. `Images To Grids` - Combines multiple images into a grid of images
+2. `XYImage To Grid` - Takes X & Y params and creates a labeled image grid.
+3. `XYImage Tiles` - Super-resolution (embiggen) style tiled resizing
+4. `Image To XYImages` - Takes an image and cuts it up into a number of columns and rows.
+5. Multiple supporting nodes - Helper nodes for data wrangling and building `XYImage` collections
 
 See full docs here: https://github.com/skunkworxdark/XYGrid_nodes/edit/main/README.md
 
 **Node Link:** https://github.com/skunkworxdark/XYGrid_nodes
 
 **Output Examples**
 
 <img src="https://github.com/skunkworxdark/XYGrid_nodes/blob/main/images/collage.png" width="300" />
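The tile split and seam-blending steps these nodes describe can be sketched as follows. Tile size, overlap handling, and the linear feathering are illustrative choices, not the nodes' actual code.

```python
import numpy as np
from PIL import Image


def split_tiles(img: Image.Image, tile: int = 512, overlap: int = 64):
    """Yield (x, y, tile_image) with tiles overlapping by `overlap` pixels."""
    step = tile - overlap
    for y in range(0, max(img.height - overlap, 1), step):
        for x in range(0, max(img.width - overlap, 1), step):
            box = (x, y, min(x + tile, img.width), min(y + tile, img.height))
            yield x, y, img.crop(box)


def recombine(tiles, size, overlap: int = 64) -> Image.Image:
    """Paste tiles back, feathering overlapping regions with a linear ramp."""
    acc = np.zeros((size[1], size[0], 3), dtype=np.float64)
    weight = np.zeros((size[1], size[0], 1), dtype=np.float64)
    for x, y, t in tiles:
        a = np.asarray(t.convert("RGB"), dtype=np.float64)
        h, w = a.shape[:2]
        # Per-tile weights ramp from 0 to 1 across the overlap band on each edge.
        wx = np.minimum(np.arange(w) + 1, overlap) / overlap
        wx = np.minimum(wx, wx[::-1])
        wy = np.minimum(np.arange(h) + 1, overlap) / overlap
        wy = np.minimum(wy, wy[::-1])
        wmap = np.clip(np.outer(wy, wx), 1e-6, 1.0)[..., None]
        acc[y : y + h, x : x + w] += a * wmap
        weight[y : y + h, x : x + w] += wmap
    return Image.fromarray((acc / weight).astype(np.uint8))
```

`recombine(split_tiles(img), (img.width, img.height))` round-trips an image apart from rounding; in a super-resolution workflow each tile would be processed between the two calls.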
 --------------------------------
 ### Example Node Template
@@ -1,104 +1,106 @@
 # List of Default Nodes
 
-The table below contains a list of the default nodes shipped with InvokeAI and their descriptions.
+The table below contains a list of the default nodes shipped with InvokeAI and
+their descriptions.
 
 | Node <img width=160 align="right"> | Function |
 | :--- | :--- |
 | Add Integers | Adds two numbers |
 | Boolean Primitive Collection | A collection of boolean primitive values |
 | Boolean Primitive | A boolean primitive value |
 | Canny Processor | Canny edge detection for ControlNet |
+| CenterPadCrop | Pad or crop an image's sides from the center by specified pixels. Positive values are outside of the image. |
 | CLIP Skip | Skip layers in clip text_encoder model. |
 | Collect | Collects values into a collection |
 | Color Correct | Shifts the colors of a target image to match the reference image, optionally using a mask to only color-correct certain regions of the target image. |
 | Color Primitive | A color primitive value |
 | Compel Prompt | Parse prompt using compel package to conditioning. |
 | Conditioning Primitive Collection | A collection of conditioning tensor primitive values |
 | Conditioning Primitive | A conditioning tensor primitive value |
 | Content Shuffle Processor | Applies content shuffle processing to image |
 | ControlNet | Collects ControlNet info to pass to other nodes |
 | Denoise Latents | Denoises noisy latents to decodable images |
 | Divide Integers | Divides two numbers |
 | Dynamic Prompt | Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator |
 | [FaceMask](./detailedNodes/faceTools.md#facemask) | Generates masks for faces in an image to use with Inpainting |
 | [FaceIdentifier](./detailedNodes/faceTools.md#faceidentifier) | Identifies and labels faces in an image |
 | [FaceOff](./detailedNodes/faceTools.md#faceoff) | Creates a new image that is a scaled bounding box with a mask on the face for Inpainting |
 | Float Math | Perform basic math operations on two floats |
 | Float Primitive Collection | A collection of float primitive values |
 | Float Primitive | A float primitive value |
 | Float Range | Creates a range |
 | HED (softedge) Processor | Applies HED edge detection to image |
 | Blur Image | Blurs an image |
 | Extract Image Channel | Gets a channel from an image. |
 | Image Primitive Collection | A collection of image primitive values |
 | Integer Math | Perform basic math operations on two integers |
 | Convert Image Mode | Converts an image to a different mode. |
 | Crop Image | Crops an image to a specified box. The box can be outside of the image. |
 | Image Hue Adjustment | Adjusts the Hue of an image. |
 | Inverse Lerp Image | Inverse linear interpolation of all pixels of an image |
 | Image Primitive | An image primitive value |
 | Lerp Image | Linear interpolation of all pixels of an image |
 | Offset Image Channel | Add to or subtract from an image color channel by a uniform value. |
 | Multiply Image Channel | Multiply or Invert an image color channel by a scalar value. |
 | Multiply Images | Multiplies two images together using `PIL.ImageChops.multiply()`. |
 | Blur NSFW Image | Add blur to NSFW-flagged images |
 | Paste Image | Pastes an image into another image. |
 | ImageProcessor | Base class for invocations that preprocess images for ControlNet |
 | Resize Image | Resizes an image to specific dimensions |
 | Round Float | Rounds a float to a specified number of decimal places |
 | Float to Integer | Converts a float to an integer. Optionally rounds to an even multiple of an input number. |
 | Scale Image | Scales an image by a factor |
 | Image to Latents | Encodes an image into latents. |
 | Add Invisible Watermark | Add an invisible watermark to an image |
 | Solid Color Infill | Infills transparent areas of an image with a solid color |
 | PatchMatch Infill | Infills transparent areas of an image using the PatchMatch algorithm |
 | Tile Infill | Infills transparent areas of an image with tiles of the image |
 | Integer Primitive Collection | A collection of integer primitive values |
 | Integer Primitive | An integer primitive value |
 | Iterate | Iterates over a list of items |
 | Latents Primitive Collection | A collection of latents tensor primitive values |
 | Latents Primitive | A latents tensor primitive value |
 | Latents to Image | Generates an image from latents. |
 | Leres (Depth) Processor | Applies leres processing to image |
 | Lineart Anime Processor | Applies line art anime processing to image |
 | Lineart Processor | Applies line art processing to image |
 | LoRA Loader | Apply selected lora to unet and text_encoder. |
 | Main Model Loader | Loads a main model, outputting its submodels. |
 | Combine Mask | Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`. |
 | Mask Edge | Applies an edge mask to an image |
 | Mask from Alpha | Extracts the alpha channel of an image as a mask. |
 | Mediapipe Face Processor | Applies mediapipe face processing to image |
 | Midas (Depth) Processor | Applies Midas depth processing to image |
 | MLSD Processor | Applies MLSD processing to image |
 | Multiply Integers | Multiplies two numbers |
 | Noise | Generates latent noise. |
 | Normal BAE Processor | Applies NormalBae processing to image |
 | ONNX Latents to Image | Generates an image from latents. |
 | ONNX Prompt (Raw) | A node to process inputs and produce outputs. May use dependency injection in `__init__` to receive providers. |
 | ONNX Text to Latents | Generates latents from conditionings. |
 | ONNX Model Loader | Loads a main model, outputting its submodels. |
 | OpenCV Inpaint | Simple inpaint using opencv. |
 | Openpose Processor | Applies Openpose processing to image |
 | PIDI Processor | Applies PIDI processing to image |
 | Prompts from File | Loads prompts from a text file |
 | Random Integer | Outputs a single random integer. |
 | Random Range | Creates a collection of random numbers |
 | Integer Range | Creates a range of numbers from start to stop with step |
 | Integer Range of Size | Creates a range from start to start + size with step |
 | Resize Latents | Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8. |
 | SDXL Compel Prompt | Parse prompt using compel package to conditioning. |
 | SDXL LoRA Loader | Apply selected lora to unet and text_encoder. |
 | SDXL Main Model Loader | Loads an sdxl base model, outputting its submodels. |
 | SDXL Refiner Compel Prompt | Parse prompt using compel package to conditioning. |
 | SDXL Refiner Model Loader | Loads an sdxl refiner model, outputting its submodels. |
 | Scale Latents | Scales latents by a given factor. |
 | Segment Anything Processor | Applies segment anything processing to image |
 | Show Image | Displays a provided image, and passes it forward in the pipeline. |
 | Step Param Easing | Experimental per-step parameter easing for denoising steps |
 | String Primitive Collection | A collection of string primitive values |
 | String Primitive | A string primitive value |
 | Subtract Integers | Subtracts two numbers |
 | Tile Resample Processor | Tile resampler processor |
 | Upscale (RealESRGAN) | Upscales an image using RealESRGAN. |
 | VAE Loader | Loads a VAE model, outputting a VaeLoaderOutput |
 | Zoe (Depth) Processor | Applies Zoe depth processing to image |
@@ -7,12 +7,12 @@ To use them, right click on your desired workflow, follow the link to GitHub and
 If you're interested in finding more workflows, checkout the [#share-your-workflows](https://discord.com/channels/1020123559063990373/1130291608097661000) channel in the InvokeAI Discord.
 
 * [SD1.5 / SD2 Text to Image](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/Text_to_Image.json)
-* [SDXL Text to Image](https://github.com/invoke-ai/InvokeAI/blob/docs/main/docs/workflows/SDXL_Text_to_Image.json)
-* [SDXL Text to Image with Refiner](https://github.com/invoke-ai/InvokeAI/blob/docs/main/docs/workflows/SDXL_w_Refiner_Text_to_Image.json)
-* [Multi ControlNet (Canny & Depth)](https://github.com/invoke-ai/InvokeAI/blob/docs/main/docs/workflows/Multi_ControlNet_Canny_and_Depth.json)
+* [SDXL Text to Image](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/SDXL_Text_to_Image.json)
+* [SDXL Text to Image with Refiner](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/SDXL_w_Refiner_Text_to_Image.json)
+* [Multi ControlNet (Canny & Depth)](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/Multi_ControlNet_Canny_and_Depth.json)
 * [Tiled Upscaling with ControlNet](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/ESRGAN_img2img_upscale_w_Canny_ControlNet.json)
-* [Prompt From File](https://github.com/invoke-ai/InvokeAI/blob/docs/main/docs/workflows/Prompt_from_File.json)
-* [Face Detailer with IP-Adapter & ControlNet](https://github.com/invoke-ai/InvokeAI/blob/docs/main/docs/workflows/Face_Detailer_with_IP-Adapter_and_Canny.json.json)
+* [Prompt From File](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/Prompt_from_File.json)
+* [Face Detailer with IP-Adapter & ControlNet](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/Face_Detailer_with_IP-Adapter_and_Canny.json)
 * [FaceMask](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/FaceMask.json)
 * [FaceOff with 2x Face Scaling](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/FaceOff_FaceScale2x.json)
-* [QR Code Monster](https://github.com/invoke-ai/InvokeAI/blob/docs/main/docs/workflows/QR_Code_Monster.json)
+* [QR Code Monster](https://github.com/invoke-ai/InvokeAI/blob/main/docs/workflows/QR_Code_Monster.json)
@@ -2,7 +2,6 @@
 
 from logging import Logger
 
-from invokeai.app.services.workflow_image_records.workflow_image_records_sqlite import SqliteWorkflowImageRecordsStorage
 from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.version.invokeai_version import __version__
 
@@ -30,7 +29,7 @@ from ..services.session_processor.session_processor_default import DefaultSessio
 from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
 from ..services.shared.default_graphs import create_system_graphs
 from ..services.shared.graph import GraphExecutionState, LibraryGraph
-from ..services.shared.sqlite import SqliteDatabase
+from ..services.shared.sqlite.sqlite_database import SqliteDatabase
 from ..services.urls.urls_default import LocalUrlService
 from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
 from .events import FastAPIEventService
@@ -94,7 +93,6 @@ class ApiDependencies:
         session_processor = DefaultSessionProcessor()
         session_queue = SqliteSessionQueue(db=db)
         urls = LocalUrlService()
-        workflow_image_records = SqliteWorkflowImageRecordsStorage(db=db)
         workflow_records = SqliteWorkflowRecordsStorage(db=db)
 
         services = InvocationServices(
@@ -121,14 +119,12 @@ class ApiDependencies:
             session_processor=session_processor,
             session_queue=session_queue,
             urls=urls,
-            workflow_image_records=workflow_image_records,
             workflow_records=workflow_records,
         )
 
         create_system_graphs(services.graph_library)
 
         ApiDependencies.invoker = Invoker(services)
 
         db.clean()
 
     @staticmethod
@@ -1,7 +1,11 @@
 import typing
 from enum import Enum
+from importlib.metadata import PackageNotFoundError, version
 from pathlib import Path
+from platform import python_version
+from typing import Optional
 
+import torch
 from fastapi import Body
 from fastapi.routing import APIRouter
 from pydantic import BaseModel, Field
@@ -40,6 +44,24 @@ class AppVersion(BaseModel):
     version: str = Field(description="App version")
 
 
+class AppDependencyVersions(BaseModel):
+    """App dependency versions response"""
+
+    accelerate: str = Field(description="accelerate version")
+    compel: str = Field(description="compel version")
+    cuda: Optional[str] = Field(description="CUDA version")
+    diffusers: str = Field(description="diffusers version")
+    numpy: str = Field(description="Numpy version")
+    opencv: str = Field(description="OpenCV version")
+    onnx: str = Field(description="ONNX version")
+    pillow: str = Field(description="Pillow (PIL) version")
+    python: str = Field(description="Python version")
+    torch: str = Field(description="PyTorch version")
+    torchvision: str = Field(description="PyTorch Vision version")
+    transformers: str = Field(description="transformers version")
+    xformers: Optional[str] = Field(description="xformers version")
+
+
 class AppConfig(BaseModel):
     """App Config Response"""
 
@@ -54,6 +76,29 @@ async def get_version() -> AppVersion:
     return AppVersion(version=__version__)
 
 
+@app_router.get("/app_deps", operation_id="get_app_deps", status_code=200, response_model=AppDependencyVersions)
+async def get_app_deps() -> AppDependencyVersions:
+    try:
+        xformers = version("xformers")
+    except PackageNotFoundError:
+        xformers = None
+    return AppDependencyVersions(
+        accelerate=version("accelerate"),
+        compel=version("compel"),
+        cuda=torch.version.cuda,
+        diffusers=version("diffusers"),
+        numpy=version("numpy"),
+        opencv=version("opencv-python"),
+        onnx=version("onnx"),
+        pillow=version("pillow"),
+        python=python_version(),
+        torch=torch.version.__version__,
+        torchvision=version("torchvision"),
+        transformers=version("transformers"),
+        xformers=xformers,
+    )
+
+
 @app_router.get("/config", operation_id="get_config", status_code=200, response_model=AppConfig)
 async def get_config() -> AppConfig:
     infill_methods = ["tile", "lama", "cv2"]
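A minimal sketch of a client hitting the new dependency-versions endpoint. The host, port, and `/api/v1/app` mount prefix are assumptions based on typical InvokeAI deployments; only the `/app_deps` route itself appears in this diff.

```python
import requests

# Assumed local server and mount prefix; adjust to your deployment.
resp = requests.get("http://127.0.0.1:9090/api/v1/app/app_deps", timeout=10)
resp.raise_for_status()
for package, ver in resp.json().items():
    print(f"{package}: {ver}")
```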
@@ -8,10 +8,11 @@ from fastapi.routing import APIRouter
 from PIL import Image
 from pydantic import BaseModel, Field, ValidationError
 
-from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator, WorkflowFieldValidator
+from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
 from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
 from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID, WorkflowWithoutIDValidator
 
 from ..dependencies import ApiDependencies
@@ -73,7 +74,7 @@ async def upload_image(
     workflow_raw = pil_image.info.get("invokeai_workflow", None)
     if workflow_raw is not None:
         try:
-            workflow = WorkflowFieldValidator.validate_json(workflow_raw)
+            workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw)
         except ValidationError:
             ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
             pass
@@ -184,6 +185,18 @@ async def get_image_metadata(
         raise HTTPException(status_code=404)
 
 
+@images_router.get(
+    "/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=Optional[WorkflowWithoutID]
+)
+async def get_image_workflow(
+    image_name: str = Path(description="The name of the image whose workflow to get"),
+) -> Optional[WorkflowWithoutID]:
+    try:
+        return ApiDependencies.invoker.services.images.get_workflow(image_name)
+    except Exception:
+        raise HTTPException(status_code=404)
+
+
 @images_router.api_route(
     "/i/{image_name}/full",
     methods=["GET", "HEAD"],
@@ -141,7 +141,7 @@ async def del_model_record(
     status_code=201,
 )
 async def add_model_record(
-    config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
+    config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
 ) -> AnyModelConfig:
     """
     Add a model using the configuration information appropriate for its type.
@@ -1,7 +1,19 @@
-from fastapi import APIRouter, Path
+from typing import Optional
+
+from fastapi import APIRouter, Body, HTTPException, Path, Query
 
 from invokeai.app.api.dependencies import ApiDependencies
-from invokeai.app.invocations.baseinvocation import WorkflowField
+from invokeai.app.services.shared.pagination import PaginatedResults
+from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
+from invokeai.app.services.workflow_records.workflow_records_common import (
+    Workflow,
+    WorkflowCategory,
+    WorkflowNotFoundError,
+    WorkflowRecordDTO,
+    WorkflowRecordListItemDTO,
+    WorkflowRecordOrderBy,
+    WorkflowWithoutID,
+)
 
 workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
@@ -10,11 +22,76 @@ workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
     "/i/{workflow_id}",
     operation_id="get_workflow",
     responses={
-        200: {"model": WorkflowField},
+        200: {"model": WorkflowRecordDTO},
     },
 )
 async def get_workflow(
     workflow_id: str = Path(description="The workflow to get"),
-) -> WorkflowField:
+) -> WorkflowRecordDTO:
     """Gets a workflow"""
-    return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
+    try:
+        return ApiDependencies.invoker.services.workflow_records.get(workflow_id)
+    except WorkflowNotFoundError:
+        raise HTTPException(status_code=404, detail="Workflow not found")
+
+
+@workflows_router.patch(
+    "/i/{workflow_id}",
+    operation_id="update_workflow",
+    responses={
+        200: {"model": WorkflowRecordDTO},
+    },
+)
+async def update_workflow(
+    workflow: Workflow = Body(description="The updated workflow", embed=True),
+) -> WorkflowRecordDTO:
+    """Updates a workflow"""
+    return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow)
+
+
+@workflows_router.delete(
+    "/i/{workflow_id}",
+    operation_id="delete_workflow",
+)
+async def delete_workflow(
+    workflow_id: str = Path(description="The workflow to delete"),
+) -> None:
+    """Deletes a workflow"""
+    ApiDependencies.invoker.services.workflow_records.delete(workflow_id)
+
+
+@workflows_router.post(
+    "/",
+    operation_id="create_workflow",
+    responses={
+        200: {"model": WorkflowRecordDTO},
+    },
+)
+async def create_workflow(
+    workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True),
+) -> WorkflowRecordDTO:
+    """Creates a workflow"""
+    return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow)
+
+
+@workflows_router.get(
+    "/",
+    operation_id="list_workflows",
+    responses={
+        200: {"model": PaginatedResults[WorkflowRecordListItemDTO]},
+    },
+)
+async def list_workflows(
+    page: int = Query(default=0, description="The page to get"),
+    per_page: int = Query(default=10, description="The number of workflows per page"),
+    order_by: WorkflowRecordOrderBy = Query(
+        default=WorkflowRecordOrderBy.Name, description="The attribute to order by"
+    ),
+    direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
+    category: WorkflowCategory = Query(default=WorkflowCategory.User, description="The category of workflow to get"),
+    query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
+) -> PaginatedResults[WorkflowRecordListItemDTO]:
+    """Gets a page of workflows"""
+    return ApiDependencies.invoker.services.workflow_records.get_many(
+        page=page, per_page=per_page, order_by=order_by, direction=direction, query=query, category=category
+    )
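Paging through workflows from a client might look like the sketch below. The server address, the `/api` mount prefix, and the exact field names in the `PaginatedResults` payload are assumptions, as the diff only shows the router itself.

```python
import requests

# Assumed local server and mount prefix; the defaults cover order, direction,
# and category, so only paging parameters are passed here.
base_url = "http://127.0.0.1:9090/api/v1/workflows/"
resp = requests.get(base_url, params={"page": 0, "per_page": 10}, timeout=10)
resp.raise_for_status()
page = resp.json()
# "items" and "workflow_id" follow the PaginatedResults/list-item models;
# treat these key names as assumptions.
for item in page.get("items", []):
    print(item.get("workflow_id"))
```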
@@ -1,14 +1,17 @@
-from typing import Any
-
-from fastapi.responses import HTMLResponse
-
-from .services.config import InvokeAIAppConfig
-
 # parse_args() must be called before any other imports. if it is not called first, consumers of the config
 # which are imported/used before parse_args() is called will get the default config values instead of the
 # values from the command line or config file.
 import sys
 
+from invokeai.version.invokeai_version import __version__
+
 from .services.config import InvokeAIAppConfig
 
 app_config = InvokeAIAppConfig.get_config()
 app_config.parse_args()
+if app_config.version:
+    print(f"InvokeAI version {__version__}")
+    sys.exit(0)
 
 if True:  # hack to make flake8 happy with imports coming after setting up the config
     import asyncio
@@ -16,6 +19,7 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
     import socket
     from inspect import signature
     from pathlib import Path
+    from typing import Any
 
     import uvicorn
     from fastapi import FastAPI
@@ -23,7 +27,7 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
     from fastapi.middleware.gzip import GZipMiddleware
     from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
     from fastapi.openapi.utils import get_openapi
-    from fastapi.responses import FileResponse
+    from fastapi.responses import FileResponse, HTMLResponse
    from fastapi.staticfiles import StaticFiles
     from fastapi_events.handlers.local import local_handler
     from fastapi_events.middleware import EventHandlerASGIMiddleware
@@ -34,7 +38,6 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
     # noinspection PyUnresolvedReferences
     import invokeai.backend.util.hotfixes  # noqa: F401 (monkeypatching on import)
     import invokeai.frontend.web as web_dir
-    from invokeai.version.invokeai_version import __version__
 
     from ..backend.util.logging import InvokeAILogger
     from .api.dependencies import ApiDependencies
@@ -51,7 +54,12 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
         workflows,
     )
     from .api.sockets import SocketIO
-    from .invocations.baseinvocation import BaseInvocation, UIConfigBase, _InputField, _OutputField
+    from .invocations.baseinvocation import (
+        BaseInvocation,
+        InputFieldJSONSchemaExtra,
+        OutputFieldJSONSchemaExtra,
+        UIConfigBase,
+    )
 
 if is_mps_available():
     import invokeai.backend.util.mps_fixes  # noqa: F401 (monkeypatching on import)
@@ -147,7 +155,11 @@ def custom_openapi() -> dict[str, Any]:
 
     # Add Node Editor UI helper schemas
     ui_config_schemas = models_json_schema(
-        [(UIConfigBase, "serialization"), (_InputField, "serialization"), (_OutputField, "serialization")],
+        [
+            (UIConfigBase, "serialization"),
+            (InputFieldJSONSchemaExtra, "serialization"),
+            (OutputFieldJSONSchemaExtra, "serialization"),
+        ],
         ref_template="#/components/schemas/{model}",
     )
     for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items():
@@ -155,7 +167,7 @@ def custom_openapi() -> dict[str, Any]:
 
     # Add a reference to the output type to additionalProperties of the invoker schema
     for invoker in all_invocations:
-        invoker_name = invoker.__name__
+        invoker_name = invoker.__name__  # type: ignore [attr-defined] # this is a valid attribute
         output_type = signature(obj=invoker.invoke).return_annotation
         output_type_title = output_type_titles[output_type.__name__]
         invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
@@ -273,7 +285,4 @@ def invoke_api() -> None:
 
 
 if __name__ == "__main__":
-    if app_config.version:
-        print(f"InvokeAI version {__version__}")
-    else:
-        invoke_api()
+    invoke_api()
@@ -5,7 +5,7 @@ from pathlib import Path
 
 from invokeai.app.services.config.config_default import InvokeAIAppConfig
 
-custom_nodes_path = Path(InvokeAIAppConfig.get_config().custom_nodes_path.absolute())
+custom_nodes_path = Path(InvokeAIAppConfig.get_config().custom_nodes_path.resolve())
 custom_nodes_path.mkdir(parents=True, exist_ok=True)
 
 custom_nodes_init_path = str(custom_nodes_path / "__init__.py")
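The switch from `absolute()` to `resolve()` matters because `absolute()` only prepends the working directory, while `resolve()` also normalizes `..` segments and follows symlinks. A quick standalone illustration:

```python
from pathlib import Path

p = Path("nodes/../nodes")
print(p.absolute())  # e.g. /home/user/invokeai/nodes/../nodes  (".." kept as-is)
print(p.resolve())   # e.g. /home/user/invokeai/nodes           (canonical form)
```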
@@ -1,4 +1,4 @@
-# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI team
 
 from __future__ import annotations
@@ -8,7 +8,7 @@ from abc import ABC, abstractmethod
 from enum import Enum
 from inspect import signature
 from types import UnionType
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
 
 import semver
 from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
@@ -16,12 +16,19 @@ from pydantic.fields import FieldInfo, _Unset
 from pydantic_core import PydanticUndefined
 
 from invokeai.app.services.config.config_default import InvokeAIAppConfig
+from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
 from invokeai.app.shared.fields import FieldDescriptions
+from invokeai.app.util.metaenum import MetaEnum
 from invokeai.app.util.misc import uuid_string
+from invokeai.backend.util.logging import InvokeAILogger
 
 if TYPE_CHECKING:
     from ..services.invocation_services import InvocationServices
 
+logger = InvokeAILogger.get_logger()
+
+CUSTOM_NODE_PACK_SUFFIX = "__invokeai-custom-node"
+
 
 class InvalidVersionError(ValueError):
     pass
@@ -31,7 +38,7 @@ class InvalidFieldError(TypeError):
     pass
 
 
-class Input(str, Enum):
+class Input(str, Enum, metaclass=MetaEnum):
     """
     The type of input a field accepts.
     - `Input.Direct`: The field must have its value provided directly, when the invocation and field \
@@ -45,86 +52,124 @@ class Input(str, Enum):
|
||||
Any = "any"
|
||||
|
||||
|
||||
class UIType(str, Enum):
|
||||
class FieldKind(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
Type hints for the UI.
|
||||
If a field should be provided a data type that does not exactly match the python type of the field, \
|
||||
use this to provide the type that should be used instead. See the node development docs for detail \
|
||||
on adding a new field type, which involves client-side changes.
|
||||
The kind of field.
|
||||
- `Input`: An input field on a node.
|
||||
- `Output`: An output field on a node.
|
||||
- `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is
|
||||
one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name
|
||||
"metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic,
|
||||
allowing "metadata" for that field.
|
||||
- `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs,
|
||||
but which are used to store information about the node. For example, the `id` and `type` fields are node
|
||||
attributes.
|
||||
|
||||
The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app
|
||||
startup, and when generating the OpenAPI schema for the workflow editor.
|
||||
"""
|
||||
|
||||
# region Primitives
|
||||
Boolean = "boolean"
|
||||
Color = "ColorField"
|
||||
Conditioning = "ConditioningField"
|
||||
Control = "ControlField"
|
||||
Float = "float"
|
||||
Image = "ImageField"
|
||||
Integer = "integer"
|
||||
Latents = "LatentsField"
|
||||
String = "string"
|
||||
# endregion
|
||||
Input = "input"
|
||||
Output = "output"
|
||||
Internal = "internal"
|
||||
NodeAttribute = "node_attribute"
|
||||
|
||||
# region Collection Primitives
|
||||
BooleanCollection = "BooleanCollection"
|
||||
ColorCollection = "ColorCollection"
|
||||
ConditioningCollection = "ConditioningCollection"
|
||||
ControlCollection = "ControlCollection"
|
||||
FloatCollection = "FloatCollection"
|
||||
ImageCollection = "ImageCollection"
|
||||
IntegerCollection = "IntegerCollection"
|
||||
LatentsCollection = "LatentsCollection"
|
||||
StringCollection = "StringCollection"
|
||||
# endregion
|
||||
|
||||
# region Polymorphic Primitives
|
||||
BooleanPolymorphic = "BooleanPolymorphic"
|
||||
ColorPolymorphic = "ColorPolymorphic"
|
||||
ConditioningPolymorphic = "ConditioningPolymorphic"
|
||||
ControlPolymorphic = "ControlPolymorphic"
|
||||
FloatPolymorphic = "FloatPolymorphic"
|
||||
ImagePolymorphic = "ImagePolymorphic"
|
||||
IntegerPolymorphic = "IntegerPolymorphic"
|
||||
LatentsPolymorphic = "LatentsPolymorphic"
|
||||
StringPolymorphic = "StringPolymorphic"
|
||||
# endregion

+class UIType(str, Enum, metaclass=MetaEnum):
+    """
+    Type hints for the UI for situations in which the field type is not enough to infer the correct UI type.
+
+    - Model Fields
+    The most common node-author-facing use will be for model fields. Internally, there is no difference
+    between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the
+    base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that
+    the field is an SDXL main model field.
+
+    - Any Field
+    We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to
+    indicate that the field accepts any type. Use with caution. This cannot be used on outputs.
+
+    - Scheduler Field
+    Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field.
+
+    - Internal Fields
+    Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate
+    handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These
+    should not be used by node authors.
+
+    - DEPRECATED Fields
+    These types are deprecated and should not be used by node authors. A warning will be logged if one is
+    used, and the type will be ignored. They are included here for backwards compatibility.
+    """

-    # region Models
-    MainModel = "MainModelField"
+    # region Model Field Types
     SDXLMainModel = "SDXLMainModelField"
     SDXLRefinerModel = "SDXLRefinerModelField"
     ONNXModel = "ONNXModelField"
-    VaeModel = "VaeModelField"
+    VaeModel = "VAEModelField"
     LoRAModel = "LoRAModelField"
     ControlNetModel = "ControlNetModelField"
     IPAdapterModel = "IPAdapterModelField"
-    UNet = "UNetField"
-    Vae = "VaeField"
-    CLIP = "ClipField"
     # endregion

-    # region Iterate/Collect
-    Collection = "Collection"
-    CollectionItem = "CollectionItem"
+    # region Misc Field Types
+    Scheduler = "SchedulerField"
+    Any = "AnyField"
     # endregion

-    # region Misc
-    Enum = "enum"
-    Scheduler = "Scheduler"
-    WorkflowField = "WorkflowField"
-    IsIntermediate = "IsIntermediate"
-    BoardField = "BoardField"
-    Any = "Any"
-    MetadataItem = "MetadataItem"
-    MetadataItemCollection = "MetadataItemCollection"
-    MetadataItemPolymorphic = "MetadataItemPolymorphic"
-    MetadataDict = "MetadataDict"
+    # region Internal Field Types
+    _Collection = "CollectionField"
+    _CollectionItem = "CollectionItemField"
     # endregion

+    # region DEPRECATED
+    Boolean = "DEPRECATED_Boolean"
+    Color = "DEPRECATED_Color"
+    Conditioning = "DEPRECATED_Conditioning"
+    Control = "DEPRECATED_Control"
+    Float = "DEPRECATED_Float"
+    Image = "DEPRECATED_Image"
+    Integer = "DEPRECATED_Integer"
+    Latents = "DEPRECATED_Latents"
+    String = "DEPRECATED_String"
+    BooleanCollection = "DEPRECATED_BooleanCollection"
+    ColorCollection = "DEPRECATED_ColorCollection"
+    ConditioningCollection = "DEPRECATED_ConditioningCollection"
+    ControlCollection = "DEPRECATED_ControlCollection"
+    FloatCollection = "DEPRECATED_FloatCollection"
+    ImageCollection = "DEPRECATED_ImageCollection"
+    IntegerCollection = "DEPRECATED_IntegerCollection"
+    LatentsCollection = "DEPRECATED_LatentsCollection"
+    StringCollection = "DEPRECATED_StringCollection"
+    BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic"
+    ColorPolymorphic = "DEPRECATED_ColorPolymorphic"
+    ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic"
+    ControlPolymorphic = "DEPRECATED_ControlPolymorphic"
+    FloatPolymorphic = "DEPRECATED_FloatPolymorphic"
+    ImagePolymorphic = "DEPRECATED_ImagePolymorphic"
+    IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
+    LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
+    StringPolymorphic = "DEPRECATED_StringPolymorphic"
+    MainModel = "DEPRECATED_MainModel"
+    UNet = "DEPRECATED_UNet"
+    Vae = "DEPRECATED_Vae"
+    CLIP = "DEPRECATED_CLIP"
+    Collection = "DEPRECATED_Collection"
+    CollectionItem = "DEPRECATED_CollectionItem"
+    Enum = "DEPRECATED_Enum"
+    WorkflowField = "DEPRECATED_WorkflowField"
+    IsIntermediate = "DEPRECATED_IsIntermediate"
+    BoardField = "DEPRECATED_BoardField"
+    MetadataItem = "DEPRECATED_MetadataItem"
+    MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
+    MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
+    MetadataDict = "DEPRECATED_MetadataDict"
+    # endregion
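
As a usage sketch for the hints above (hypothetical node; this mirrors how the SDXL model-loader nodes use the override):

    class SDXLModelLoaderInvocation(BaseInvocation):
        # The python type stays the generic MainModelField; the UI is told to render
        # the SDXL-specific picker. A DEPRECATED_* member here would be logged and ignored.
        model: MainModelField = InputField(
            description="SDXL main model to load",
            input=Input.Direct,
            ui_type=UIType.SDXLMainModel,
        )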


-class UIComponent(str, Enum):
+class UIComponent(str, Enum, metaclass=MetaEnum):
     """
-    The type of UI component to use for a field, used to override the default components, which are \
+    The type of UI component to use for a field, used to override the default components, which are
     inferred from the field type.
     """

@@ -133,21 +178,22 @@ class UIComponent(str, Enum):
     Slider = "slider"


-class _InputField(BaseModel):
+class InputFieldJSONSchemaExtra(BaseModel):
     """
-    *DO NOT USE*
-    This helper class is used to tell the client about our custom field attributes via OpenAPI
-    schema generation, and Typescript type generation from that schema. It serves no functional
-    purpose in the backend.
+    Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution,
+    and by the workflow editor during schema parsing and UI rendering.
     """

     input: Input
-    ui_hidden: bool
-    ui_type: Optional[UIType]
-    ui_component: Optional[UIComponent]
-    ui_order: Optional[int]
-    ui_choice_labels: Optional[dict[str, str]]
-    item_default: Optional[Any]
+    orig_required: bool
+    field_kind: FieldKind
+    default: Optional[Any] = None
+    orig_default: Optional[Any] = None
+    ui_hidden: bool = False
+    ui_type: Optional[UIType] = None
+    ui_component: Optional[UIComponent] = None
+    ui_order: Optional[int] = None
+    ui_choice_labels: Optional[dict[str, str]] = None

     model_config = ConfigDict(
         validate_assignment=True,
@@ -155,14 +201,13 @@ class _InputField(BaseModel):
     )
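
Since `InputField()` below serializes this model with `exclude_none=True`, unset optional attributes never reach the OpenAPI schema. A sketch of the expected shape, assuming the definitions above:

    extra = InputFieldJSONSchemaExtra(
        input=Input.Any,
        field_kind=FieldKind.Input,
        orig_required=True,
    )
    # ui_type, ui_component, ui_order, etc. are None and are dropped; roughly:
    # {"input": Input.Any, "orig_required": True, "field_kind": FieldKind.Input, "ui_hidden": False}
    extra.model_dump(exclude_none=True)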


-class _OutputField(BaseModel):
+class OutputFieldJSONSchemaExtra(BaseModel):
     """
-    *DO NOT USE*
-    This helper class is used to tell the client about our custom field attributes via OpenAPI
-    schema generation, and Typescript type generation from that schema. It serves no functional
-    purpose in the backend.
+    Extra attributes to be added to output fields and their OpenAPI schema. Used by the workflow editor
+    during schema parsing and UI rendering.
     """

+    field_kind: FieldKind
     ui_hidden: bool
     ui_type: Optional[UIType]
     ui_order: Optional[int]
@@ -173,13 +218,9 @@ class _OutputField(BaseModel):
     )


-def get_type(klass: BaseModel) -> str:
-    """Helper function to get an invocation or invocation output's type. This is the default value of the `type` field."""
-    return klass.model_fields["type"].default
-
-
 def InputField(
     # copied from pydantic's Field
+    # TODO: Can we support default_factory?
     default: Any = _Unset,
     default_factory: Callable[[], Any] | None = _Unset,
     title: str | None = _Unset,
@@ -203,12 +244,11 @@ def InputField(
     ui_hidden: bool = False,
     ui_order: Optional[int] = None,
     ui_choice_labels: Optional[dict[str, str]] = None,
-    item_default: Optional[Any] = None,
 ) -> Any:
     """
     Creates an input field for an invocation.

-    This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
+    This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
     that adds a few extra parameters to support graph execution and the node editor UI.

     :param Input input: [Input.Any] The kind of input this field requires. \
@@ -228,28 +268,58 @@ def InputField(
     For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
     For this case, you could provide `UIComponent.Textarea`.

-    : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
+    :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.

-    : param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
+    :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.

-    : param bool item_default: [None] Specifies the default item value, if this is a collection input. \
-    Ignored for non-collection fields.
+    :param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
     """

-    json_schema_extra_: dict[str, Any] = {
-        "input": input,
-        "ui_type": ui_type,
-        "ui_component": ui_component,
-        "ui_hidden": ui_hidden,
-        "ui_order": ui_order,
-        "item_default": item_default,
-        "ui_choice_labels": ui_choice_labels,
-        "_field_kind": "input",
-    }
+    json_schema_extra_ = InputFieldJSONSchemaExtra(
+        input=input,
+        ui_type=ui_type,
+        ui_component=ui_component,
+        ui_hidden=ui_hidden,
+        ui_order=ui_order,
+        ui_choice_labels=ui_choice_labels,
+        field_kind=FieldKind.Input,
+        orig_required=True,
+    )

+    """
+    There is a conflict between the typing of invocation definitions and the typing of an invocation's
+    `invoke()` function.
+
+    On instantiation of a node, the invocation definition is used to create the python class. At this time,
+    any number of fields may be optional, because they may be provided by connections.
+
+    On calling of `invoke()`, however, those fields may be required.
+
+    For example, consider a ResizeImageInvocation with an `image: ImageField` field.
+
+    `image` is required during the call to `invoke()`, but when the python class is instantiated,
+    the field may not be present. This is fine, because that image field will be provided by a
+    connection from an ancestor node, which outputs an image.
+
+    This means we want to type the `image` field as optional for the node class definition, but required
+    for the `invoke()` function.
+
+    If we use `typing.Optional` in the node class definition, the field will be typed as optional in the
+    `invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or
+    any static type analysis tools will complain.
+
+    To get around this, in node class definitions, we type all fields correctly for the `invoke()` function,
+    but secretly make them optional in `InputField()`. We also store the original required bool and/or default
+    value. When we call `invoke()`, we use this stored information to do an additional check on the class.
+    """

     if default_factory is not _Unset and default_factory is not None:
         default = default_factory()
         logger.warn('"default_factory" is not supported, calling it now to set "default"')

     # These are the args we may wish to pass to the pydantic `Field()` function
     field_args = {
         "default": default,
-        "default_factory": default_factory,
         "title": title,
         "description": description,
         "pattern": pattern,
@@ -266,70 +336,34 @@ def InputField(
         "max_length": max_length,
     }

-    """
-    Invocation definitions have their fields typed correctly for their `invoke()` functions.
-    This typing is often more specific than the actual invocation definition requires, because
-    fields may have values provided only by connections.
-
-    For example, consider a ResizeImageInvocation with an `image: ImageField` field.
-
-    `image` is required during the call to `invoke()`, but when the python class is instantiated,
-    the field may not be present. This is fine, because that image field will be provided by
-    an ancestor node that outputs the image.
-
-    So we'd like to type that `image` field as `Optional[ImageField]`. If we do that, however, then
-    we need to handle a lot of extra logic in the `invoke()` function to check if the field has a
-    value or not. This is very tedious.
-
-    Ideally, the invocation definition would be able to specify that the field is required during
-    invocation, but optional during instantiation. So the field would be typed as `image: ImageField`,
-    but when calling the `invoke()` function, we raise an error if the field is not present.
-
-    To do this, we need to do a bit of finagling to make the pydantic field optional, and then do
-    extra validation when calling `invoke()`.
-
-    There is some additional logic here to cleanly create the pydantic field via the wrapper.
-    """
-
     # Filter out field args not provided
     # We only want to pass the args that were provided, otherwise the `Field()` function won't work as expected
     provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}

-    if (default is not PydanticUndefined) and (default_factory is not PydanticUndefined):
-        raise ValueError("Cannot specify both default and default_factory")
+    # Because we are manually making fields optional, we need to store the original required bool for reference later
+    json_schema_extra_.orig_required = default is PydanticUndefined

-    # because we are manually making fields optional, we need to store the original required bool for reference later
-    if default is PydanticUndefined and default_factory is PydanticUndefined:
-        json_schema_extra_.update({"orig_required": True})
-    else:
-        json_schema_extra_.update({"orig_required": False})
-
-    # make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
-    if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined:
+    # Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
+    if input is Input.Any or input is Input.Connection:
         default_ = None if default is PydanticUndefined else default
         provided_args.update({"default": default_})
         if default is not PydanticUndefined:
-            # before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value
-            json_schema_extra_.update({"default": default})
-            json_schema_extra_.update({"orig_default": default})
-    elif default is not PydanticUndefined and default_factory is PydanticUndefined:
+            # Before invoking, we'll check for the original default value and set it on the field if the field has no value
+            json_schema_extra_.default = default
+            json_schema_extra_.orig_default = default
+    elif default is not PydanticUndefined:
         default_ = default
         provided_args.update({"default": default_})
-        json_schema_extra_.update({"orig_default": default_})
-    elif default_factory is not PydanticUndefined:
-        provided_args.update({"default_factory": default_factory})
-        # TODO: cannot serialize default_factory...
-        # json_schema_extra_.update(dict(orig_default_factory=default_factory))
+        json_schema_extra_.orig_default = default_

     return Field(
         **provided_args,
-        json_schema_extra=json_schema_extra_,
+        json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
     )
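
Tracing the branches above by hand for two common cases (values assumed for illustration, not from this diff):

    # 1) InputField(default=512, input=Input.Direct)
    #    -> orig_required = False; provided_args["default"] = 512; orig_default = 512
    # 2) InputField(input=Input.Connection)  # no default given
    #    -> orig_required = True; provided_args["default"] = None (field made optional);
    #       the requiredness check is deferred to invoke time.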


 def OutputField(
     # copied from pydantic's Field
     default: Any = _Unset,
     default_factory: Callable[[], Any] | None = _Unset,
     title: str | None = _Unset,
     description: str | None = _Unset,
     pattern: str | None = _Unset,
@@ -362,13 +396,12 @@ def OutputField(
     `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
     `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.

-    : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
+    :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \

-    : param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
+    :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
     """
     return Field(
         default=default,
         default_factory=default_factory,
         title=title,
         description=description,
         pattern=pattern,
@@ -383,12 +416,12 @@ def OutputField(
         decimal_places=decimal_places,
         min_length=min_length,
         max_length=max_length,
-        json_schema_extra={
-            "ui_type": ui_type,
-            "ui_hidden": ui_hidden,
-            "ui_order": ui_order,
-            "_field_kind": "output",
-        },
+        json_schema_extra=OutputFieldJSONSchemaExtra(
+            ui_type=ui_type,
+            ui_hidden=ui_hidden,
+            ui_order=ui_order,
+            field_kind=FieldKind.Output,
+        ).model_dump(exclude_none=True),
     )

@@ -401,10 +434,10 @@ class UIConfigBase(BaseModel):
     tags: Optional[list[str]] = Field(default_factory=None, description="The node's tags")
     title: Optional[str] = Field(default=None, description="The node's display name")
     category: Optional[str] = Field(default=None, description="The node's category")
-    version: Optional[str] = Field(
-        default=None,
+    version: str = Field(
         description='The node\'s version. Should be a valid semver string e.g. "1.0.0" or "3.8.13".',
     )
+    node_pack: Optional[str] = Field(default=None, description="Whether or not this is a custom node")

     model_config = ConfigDict(
         validate_assignment=True,

@@ -420,6 +453,7 @@ class InvocationContext:
     queue_id: str
     queue_item_id: int
     queue_batch_id: str
+    workflow: Optional[WorkflowWithoutID]

     def __init__(
         self,
@@ -428,12 +462,14 @@ class InvocationContext:
         queue_item_id: int,
         queue_batch_id: str,
         graph_execution_state_id: str,
+        workflow: Optional[WorkflowWithoutID],
     ):
         self.services = services
         self.graph_execution_state_id = graph_execution_state_id
         self.queue_id = queue_id
         self.queue_item_id = queue_item_id
         self.queue_batch_id = queue_batch_id
+        self.workflow = workflow
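
With the workflow carried on the context, invocations that save images forward it to the image service; the pattern repeats throughout the node changes later in this diff:

    image_dto = context.services.images.create(
        image=image,
        node_id=self.id,
        is_intermediate=self.is_intermediate,
        workflow=context.workflow,
    )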
 class BaseInvocationOutput(BaseModel):
@@ -447,29 +483,39 @@ class BaseInvocationOutput(BaseModel):

     @classmethod
     def register_output(cls, output: BaseInvocationOutput) -> None:
+        """Registers an invocation output."""
         cls._output_classes.add(output)

     @classmethod
     def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
+        """Gets all invocation outputs."""
         return cls._output_classes

     @classmethod
     def get_outputs_union(cls) -> UnionType:
+        """Gets a union of all invocation outputs."""
         outputs_union = Union[tuple(cls._output_classes)]  # type: ignore [valid-type]
         return outputs_union  # type: ignore [return-value]

     @classmethod
     def get_output_types(cls) -> Iterable[str]:
-        return (get_type(i) for i in BaseInvocationOutput.get_outputs())
+        """Gets all invocation output types."""
+        return (i.get_type() for i in BaseInvocationOutput.get_outputs())

     @staticmethod
     def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
+        """Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
         # Because we use a pydantic Literal field with default value for the invocation type,
         # it will be typed as optional in the OpenAPI schema. Make it required manually.
         if "required" not in schema or not isinstance(schema["required"], list):
             schema["required"] = []
         schema["required"].extend(["type"])

+    @classmethod
+    def get_type(cls) -> str:
+        """Gets the invocation output's type, as provided by the `@invocation_output` decorator."""
+        return cls.model_fields["type"].default
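
These registry helpers feed OpenAPI generation; a short usage sketch (assumes at least one output class has been registered via the decorator):

    OutputsUnion = BaseInvocationOutput.get_outputs_union()      # polymorphic schema type
    output_types = list(BaseInvocationOutput.get_output_types())  # e.g. ["image_output", ...]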

     model_config = ConfigDict(
         protected_namespaces=(),
         validate_assignment=True,
@@ -499,21 +545,29 @@ class BaseInvocation(ABC, BaseModel):

     _invocation_classes: ClassVar[set[BaseInvocation]] = set()

+    @classmethod
+    def get_type(cls) -> str:
+        """Gets the invocation's type, as provided by the `@invocation` decorator."""
+        return cls.model_fields["type"].default
+
     @classmethod
     def register_invocation(cls, invocation: BaseInvocation) -> None:
+        """Registers an invocation."""
         cls._invocation_classes.add(invocation)

     @classmethod
     def get_invocations_union(cls) -> UnionType:
+        """Gets a union of all invocation types."""
         invocations_union = Union[tuple(cls._invocation_classes)]  # type: ignore [valid-type]
         return invocations_union  # type: ignore [return-value]

     @classmethod
     def get_invocations(cls) -> Iterable[BaseInvocation]:
+        """Gets all invocations, respecting the allowlist and denylist."""
         app_config = InvokeAIAppConfig.get_config()
         allowed_invocations: set[BaseInvocation] = set()
         for sc in cls._invocation_classes:
-            invocation_type = get_type(sc)
+            invocation_type = sc.get_type()
             is_in_allowlist = (
                 invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True
             )
@@ -526,28 +580,32 @@ class BaseInvocation(ABC, BaseModel):

     @classmethod
     def get_invocations_map(cls) -> dict[str, BaseInvocation]:
-        # Get the type strings out of the literals and into a dictionary
-        return {get_type(i): i for i in BaseInvocation.get_invocations()}
+        """Gets a map of all invocation types to their invocation classes."""
+        return {i.get_type(): i for i in BaseInvocation.get_invocations()}

     @classmethod
     def get_invocation_types(cls) -> Iterable[str]:
-        return (get_type(i) for i in BaseInvocation.get_invocations())
+        """Gets all invocation types."""
+        return (i.get_type() for i in BaseInvocation.get_invocations())

     @classmethod
-    def get_output_type(cls) -> BaseInvocationOutput:
+    def get_output_annotation(cls) -> BaseInvocationOutput:
+        """Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method)."""
         return signature(cls.invoke).return_annotation

     @staticmethod
-    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
-        # Add the various UI-facing attributes to the schema. These are used to build the invocation templates.
-        uiconfig = getattr(model_class, "UIConfig", None)
-        if uiconfig and hasattr(uiconfig, "title"):
-            schema["title"] = uiconfig.title
-        if uiconfig and hasattr(uiconfig, "tags"):
-            schema["tags"] = uiconfig.tags
-        if uiconfig and hasattr(uiconfig, "category"):
-            schema["category"] = uiconfig.category
-        if uiconfig and hasattr(uiconfig, "version"):
+    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
+        """Adds various UI-facing attributes to the invocation's OpenAPI schema."""
+        uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
+        if uiconfig is not None:
+            if uiconfig.title is not None:
+                schema["title"] = uiconfig.title
+            if uiconfig.tags is not None:
+                schema["tags"] = uiconfig.tags
+            if uiconfig.category is not None:
+                schema["category"] = uiconfig.category
+            if uiconfig.node_pack is not None:
+                schema["node_pack"] = uiconfig.node_pack
+            schema["version"] = uiconfig.version
         if "required" not in schema or not isinstance(schema["required"], list):
             schema["required"] = []

@@ -559,6 +617,10 @@ class BaseInvocation(ABC, BaseModel):
         pass

     def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
+        """
+        Internal invoke method, calls `invoke()` after some prep.
+        Handles optional fields that are required to call `invoke()` and invocation cache.
+        """
         for field_name, field in self.model_fields.items():
             if not field.json_schema_extra or callable(field.json_schema_extra):
                 # something has gone terribly awry, we should always have this and it should be a dict
@@ -598,21 +660,20 @@ class BaseInvocation(ABC, BaseModel):
                 context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
                 return self.invoke(context)

-    def get_type(self) -> str:
-        return self.model_fields["type"].default
-
     id: str = Field(
         default_factory=uuid_string,
         description="The id of this instance of an invocation. Must be unique among all instances of invocations.",
-        json_schema_extra={"_field_kind": "internal"},
+        json_schema_extra={"field_kind": FieldKind.NodeAttribute},
     )
     is_intermediate: bool = Field(
         default=False,
         description="Whether or not this is an intermediate invocation.",
-        json_schema_extra={"ui_type": UIType.IsIntermediate, "_field_kind": "internal"},
+        json_schema_extra={"ui_type": "IsIntermediate", "field_kind": FieldKind.NodeAttribute},
     )
     use_cache: bool = Field(
-        default=True, description="Whether or not to use the cache", json_schema_extra={"_field_kind": "internal"}
+        default=True,
+        description="Whether or not to use the cache",
+        json_schema_extra={"field_kind": FieldKind.NodeAttribute},
     )

     UIConfig: ClassVar[Type[UIConfigBase]]

@@ -629,12 +690,15 @@ class BaseInvocation(ABC, BaseModel):
 TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation)


-RESERVED_INPUT_FIELD_NAMES = {
+RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
     "id",
     "is_intermediate",
     "use_cache",
     "type",
     "workflow",
 }

+RESERVED_INPUT_FIELD_NAMES = {
+    "metadata",
+}

@@ -652,40 +716,59 @@ RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())}

 def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None:
     """
     Validates the fields of an invocation or invocation output:
-    - must not override any pydantic reserved fields
-    - must be created via `InputField`, `OutputField`, or be an internal field defined in this file
+    - Must not override any pydantic reserved fields
+    - Must have a type annotation
+    - Must have a json_schema_extra dict
+    - Must have field_kind in json_schema_extra
+    - Field name must not be reserved, according to its field_kind
     """
     for name, field in model_fields.items():
         if name in RESERVED_PYDANTIC_FIELD_NAMES:
             raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)')

-        field_kind = (
-            # _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file
-            field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None
-        )
+        if not field.annotation:
+            raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)')
+
+        if not isinstance(field.json_schema_extra, dict):
+            raise InvalidFieldError(
+                f'Invalid field definition for "{name}" on "{model_type}" (missing json_schema_extra dict)'
+            )
+
+        field_kind = field.json_schema_extra.get("field_kind", None)

         # must have a field_kind
-        if field_kind is None or field_kind not in {"input", "output", "internal"}:
+        if not isinstance(field_kind, FieldKind):
             raise InvalidFieldError(
                 f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)'
             )

-        if field_kind == "input" and name in RESERVED_INPUT_FIELD_NAMES:
+        if field_kind is FieldKind.Input and (
+            name in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES or name in RESERVED_INPUT_FIELD_NAMES
+        ):
             raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)')

-        if field_kind == "output" and name in RESERVED_OUTPUT_FIELD_NAMES:
+        if field_kind is FieldKind.Output and name in RESERVED_OUTPUT_FIELD_NAMES:
             raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)')

         # internal fields *must* be in the reserved list
-        if (
-            field_kind == "internal"
-            and name not in RESERVED_INPUT_FIELD_NAMES
-            and name not in RESERVED_OUTPUT_FIELD_NAMES
-        ):
+        if (field_kind is FieldKind.Internal) and name not in RESERVED_INPUT_FIELD_NAMES:
             raise InvalidFieldError(
                 f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)'
             )

+        # node attribute fields *must* be in the reserved list
+        if (
+            field_kind is FieldKind.NodeAttribute
+            and name not in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES
+            and name not in RESERVED_OUTPUT_FIELD_NAMES
+        ):
+            raise InvalidFieldError(
+                f'Invalid field name "{name}" on "{model_type}" (node attribute field without reserved name)'
+            )
+
+        ui_type = field.json_schema_extra.get("ui_type", None)
+        if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
+            logger.warn(f"\"UIType.{ui_type.split('_')[-1]}\" is deprecated, ignoring")
+            field.json_schema_extra.pop("ui_type")
     return None
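
So, for example, a node that tries to declare its own "metadata" input now fails fast at registration time (a sketch of the failure mode; the node is hypothetical):

    @invocation("bad_node", title="Bad Node", category="test", version="1.0.0")
    class BadNodeInvocation(BaseInvocation):
        # "metadata" is reserved for the WithMetadata mixin, so validate_fields()
        # raises InvalidFieldError: 'Invalid field name "metadata" ... (reserved input field name)'
        metadata: str = InputField(description="oops")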

@@ -720,21 +803,30 @@ def invocation(
         validate_fields(cls.model_fields, invocation_type)

         # Add OpenAPI schema extras
-        uiconf_name = cls.__qualname__ + ".UIConfig"
-        if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name:
-            cls.UIConfig = type(uiconf_name, (UIConfigBase,), {})
-        if title is not None:
-            cls.UIConfig.title = title
-        if tags is not None:
-            cls.UIConfig.tags = tags
-        if category is not None:
-            cls.UIConfig.category = category
+        uiconfig_name = cls.__qualname__ + ".UIConfig"
+        if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconfig_name:
+            cls.UIConfig = type(uiconfig_name, (UIConfigBase,), {})
+        cls.UIConfig.title = title
+        cls.UIConfig.tags = tags
+        cls.UIConfig.category = category

+        # Grab the node pack's name from the module name, if it's a custom node
+        is_custom_node = cls.__module__.rsplit(".", 1)[0] == "invokeai.app.invocations"
+        if is_custom_node:
+            cls.UIConfig.node_pack = cls.__module__.split(".")[0]
+        else:
+            cls.UIConfig.node_pack = None
+
         if version is not None:
             try:
                 semver.Version.parse(version)
             except ValueError as e:
                 raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
             cls.UIConfig.version = version
+        else:
+            logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
+            cls.UIConfig.version = "1.0.0"

         if use_cache is not None:
             cls.model_fields["use_cache"].default = use_cache
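
The version gate accepts strict semver only, so a node author can pre-check a string the same way the decorator does:

    import semver

    semver.Version.parse("1.2.0")  # ok
    semver.Version.parse("1.2")    # ValueError, surfaced as InvalidVersionError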

@@ -749,7 +841,7 @@ def invocation(

         invocation_type_annotation = Literal[invocation_type]  # type: ignore
         invocation_type_field = Field(
-            title="type", default=invocation_type, json_schema_extra={"_field_kind": "internal"}
+            title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
         )

         docstring = cls.__doc__

@@ -795,7 +887,9 @@ def invocation_output(
         # Add the output type to the model.

         output_type_annotation = Literal[output_type]  # type: ignore
-        output_type_field = Field(title="type", default=output_type, json_schema_extra={"_field_kind": "internal"})
+        output_type_field = Field(
+            title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute}
+        )

         docstring = cls.__doc__
         cls = create_model(

@@ -813,24 +907,6 @@ def invocation_output(
     return wrapper


-class WorkflowField(RootModel):
-    """
-    Pydantic model for workflows with custom root of type dict[str, Any].
-    Workflows are stored without a strict schema.
-    """
-
-    root: dict[str, Any] = Field(description="The workflow")
-
-
-WorkflowFieldValidator = TypeAdapter(WorkflowField)
-
-
-class WithWorkflow(BaseModel):
-    workflow: Optional[WorkflowField] = Field(
-        default=None, description=FieldDescriptions.workflow, json_schema_extra={"_field_kind": "internal"}
-    )
-
-
 class MetadataField(RootModel):
     """
     Pydantic model for metadata with custom root of type dict[str, Any].

@@ -845,5 +921,21 @@ MetadataFieldValidator = TypeAdapter(MetadataField)

 class WithMetadata(BaseModel):
     metadata: Optional[MetadataField] = Field(
-        default=None, description=FieldDescriptions.metadata, json_schema_extra={"_field_kind": "internal"}
+        default=None,
+        description=FieldDescriptions.metadata,
+        json_schema_extra=InputFieldJSONSchemaExtra(
+            field_kind=FieldKind.Internal,
+            input=Input.Connection,
+            orig_required=False,
+        ).model_dump(exclude_none=True),
     )


+class WithWorkflow:
+    workflow = None
+
+    def __init_subclass__(cls) -> None:
+        logger.warn(
+            f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
+        )
+        super().__init_subclass__()
@@ -5,7 +5,7 @@ import numpy as np
 from pydantic import ValidationInfo, field_validator

 from invokeai.app.invocations.primitives import IntegerCollectionOutput
-from invokeai.app.util.misc import SEED_MAX, get_random_seed
+from invokeai.app.util.misc import SEED_MAX

 from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation

@@ -55,7 +55,7 @@ class RangeOfSizeInvocation(BaseInvocation):
     title="Random Range",
     tags=["range", "integer", "random", "collection"],
     category="collections",
-    version="1.0.0",
+    version="1.0.1",
     use_cache=False,
 )
 class RandomRangeInvocation(BaseInvocation):

@@ -65,10 +65,10 @@ class RandomRangeInvocation(BaseInvocation):
     high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value")
     size: int = InputField(default=1, description="The number of values to generate")
     seed: int = InputField(
+        default=0,
         ge=0,
         le=SEED_MAX,
         description="The seed for the RNG (omit for random)",
-        default_factory=get_random_seed,
     )

     def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
@@ -39,7 +39,6 @@ from .baseinvocation import (
     InvocationContext,
     OutputField,
     WithMetadata,
-    WithWorkflow,
     invocation,
     invocation_output,
 )

@@ -129,7 +128,7 @@ class ControlNetInvocation(BaseInvocation):


 # This invocation exists for other invocations to subclass it - do not register with @invocation!
-class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
+class ImageProcessorInvocation(BaseInvocation, WithMetadata):
     """Base class for invocations that preprocess images for ControlNet"""

     image: ImageField = InputField(description="The image to process")

@@ -153,7 +152,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
             node_id=self.id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )

         """Builds an ImageOutput and its ImageField"""

@@ -173,7 +172,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithWorkflow):
     title="Canny Processor",
     tags=["controlnet", "canny"],
     category="controlnet",
-    version="1.1.0",
+    version="1.2.0",
 )
 class CannyImageProcessorInvocation(ImageProcessorInvocation):
     """Canny edge detection for ControlNet"""

@@ -196,7 +195,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
     title="HED (softedge) Processor",
     tags=["controlnet", "hed", "softedge"],
     category="controlnet",
-    version="1.1.0",
+    version="1.2.0",
 )
 class HedImageProcessorInvocation(ImageProcessorInvocation):
     """Applies HED edge detection to image"""

@@ -225,7 +224,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
     title="Lineart Processor",
     tags=["controlnet", "lineart"],
     category="controlnet",
-    version="1.1.0",
+    version="1.2.0",
 )
 class LineartImageProcessorInvocation(ImageProcessorInvocation):
     """Applies line art processing to image"""

@@ -247,7 +246,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
     title="Lineart Anime Processor",
     tags=["controlnet", "lineart", "anime"],
     category="controlnet",
-    version="1.1.0",
+    version="1.2.0",
 )
 class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
     """Applies line art anime processing to image"""

@@ -270,7 +269,7 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
     title="Openpose Processor",
     tags=["controlnet", "openpose", "pose"],
     category="controlnet",
-    version="1.1.0",
+    version="1.2.0",
 )
 class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Openpose processing to image"""

@@ -295,7 +294,7 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
     title="Midas Depth Processor",
     tags=["controlnet", "midas"],
     category="controlnet",
-    version="1.1.0",
+    version="1.2.0",
 )
 class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Midas depth processing to image"""

@@ -322,7 +321,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     title="Normal BAE Processor",
     tags=["controlnet"],
     category="controlnet",
-    version="1.1.0",
+    version="1.2.0",
 )
 class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
     """Applies NormalBae processing to image"""

@@ -339,7 +338,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):


 @invocation(
-    "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.1.0"
+    "mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0"
 )
 class MlsdImageProcessorInvocation(ImageProcessorInvocation):
     """Applies MLSD processing to image"""

@@ -362,7 +361,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):


 @invocation(
-    "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.1.0"
+    "pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0"
 )
 class PidiImageProcessorInvocation(ImageProcessorInvocation):
     """Applies PIDI processing to image"""

@@ -389,7 +388,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
     title="Content Shuffle Processor",
     tags=["controlnet", "contentshuffle"],
     category="controlnet",
-    version="1.1.0",
+    version="1.2.0",
 )
 class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
     """Applies content shuffle processing to image"""

@@ -419,7 +418,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
     title="Zoe (Depth) Processor",
     tags=["controlnet", "zoe", "depth"],
     category="controlnet",
-    version="1.1.0",
+    version="1.2.0",
 )
 class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Zoe depth processing to image"""

@@ -435,7 +434,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     title="Mediapipe Face Processor",
     tags=["controlnet", "mediapipe", "face"],
     category="controlnet",
-    version="1.1.0",
+    version="1.2.0",
 )
 class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
     """Applies mediapipe face processing to image"""

@@ -458,7 +457,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
     title="Leres (Depth) Processor",
     tags=["controlnet", "leres", "depth"],
     category="controlnet",
-    version="1.1.0",
+    version="1.2.0",
 )
 class LeresImageProcessorInvocation(ImageProcessorInvocation):
     """Applies leres processing to image"""

@@ -487,7 +486,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
     title="Tile Resample Processor",
     tags=["controlnet", "tile"],
     category="controlnet",
-    version="1.1.0",
+    version="1.2.0",
 )
 class TileResamplerProcessorInvocation(ImageProcessorInvocation):
     """Tile resampler processor"""

@@ -527,7 +526,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
     title="Segment Anything Processor",
     tags=["controlnet", "segmentanything"],
     category="controlnet",
-    version="1.1.0",
+    version="1.2.0",
 )
 class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
     """Applies segment anything processing to image"""

@@ -569,7 +568,7 @@ class SamDetectorReproducibleColors(SamDetector):
     title="Color Map Processor",
     tags=["controlnet"],
     category="controlnet",
-    version="1.1.0",
+    version="1.2.0",
 )
 class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
     """Generates a color map from the provided image"""

@@ -32,13 +32,15 @@ for d in Path(__file__).parent.iterdir():
     if module_name in globals():
         continue

-    # we have a legit module to import
+    # load the module, appending a suffix to identify it as a custom node pack
     spec = spec_from_file_location(module_name, init.absolute())

     if spec is None or spec.loader is None:
         logger.warn(f"Could not load {init}")
         continue

+    logger.info(f"Loading node pack {module_name}")
+
     module = module_from_spec(spec)
     sys.modules[spec.name] = module
     spec.loader.exec_module(module)

@@ -47,5 +49,5 @@ for d in Path(__file__).parent.iterdir():

 del init, module_name

-logger.info(f"Loaded {loaded_count} modules from {Path(__file__).parent}")
+if loaded_count > 0:
+    logger.info(f"Loaded {loaded_count} node packs from {Path(__file__).parent}")
|
||||
|
||||
@@ -8,11 +8,11 @@ from PIL import Image, ImageOps
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
|
||||
|
||||
|
||||
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.1.0")
|
||||
class CvInpaintInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0")
|
||||
class CvInpaintInvocation(BaseInvocation, WithMetadata):
|
||||
"""Simple inpaint using opencv."""
|
||||
|
||||
image: ImageField = InputField(description="The image to inpaint")
|
||||
@@ -41,7 +41,7 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
|
||||
@@ -17,7 +17,6 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
WithMetadata,
|
||||
WithWorkflow,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -438,8 +437,8 @@ def get_faces_list(
|
||||
return all_faces
|
||||
|
||||
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.1.0")
|
||||
class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.0")
|
||||
class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
||||
|
||||
image: ImageField = InputField(description="Image for face detection")
|
||||
@@ -508,7 +507,7 @@ class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
mask_dto = context.services.images.create(
|
||||
@@ -532,8 +531,8 @@ class FaceOffInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
return output
|
||||
|
||||
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.1.0")
|
||||
class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.0")
|
||||
class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
"""Face mask creation using mediapipe face detection"""
|
||||
|
||||
image: ImageField = InputField(description="Image to face detect")
|
||||
@@ -627,7 +626,7 @@ class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
mask_dto = context.services.images.create(
|
||||
@@ -650,9 +649,9 @@ class FaceMaskInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
|
||||
|
||||
@invocation(
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.1.0"
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.0"
|
||||
)
|
||||
class FaceIdentifierInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
|
||||
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
||||
|
||||
image: ImageField = InputField(description="Image to face detect")
|
||||
@@ -716,7 +715,7 @@ class FaceIdentifierInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
|
||||
@@ -13,7 +13,7 @@ from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||
|
||||
from .baseinvocation import BaseInvocation, Input, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
|
||||
from .baseinvocation import BaseInvocation, Input, InputField, InvocationContext, WithMetadata, invocation
|
||||
|
||||
|
||||
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.0")
|
||||
@@ -36,8 +36,14 @@ class ShowImageInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
|
||||
@invocation("blank_image", title="Blank Image", tags=["image"], category="image", version="1.1.0")
|
||||
class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
@invocation(
|
||||
"blank_image",
|
||||
title="Blank Image",
|
||||
tags=["image"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class BlankImageInvocation(BaseInvocation, WithMetadata):
|
||||
"""Creates a blank image and forwards it to the pipeline"""
|
||||
|
||||
width: int = InputField(default=512, description="The width of the image")
|
||||
@@ -56,7 +62,7 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@@ -66,8 +72,14 @@ class BlankImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_crop", title="Crop Image", tags=["image", "crop"], category="image", version="1.1.0")
|
||||
class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation(
|
||||
"img_crop",
|
||||
title="Crop Image",
|
||||
tags=["image", "crop"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ImageCropInvocation(BaseInvocation, WithMetadata):
|
||||
"""Crops an image to a specified box. The box can be outside of the image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to crop")
|
||||
@@ -90,7 +102,7 @@ class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
@@ -100,8 +112,69 @@ class ImageCropInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
)
|
||||
|
||||
|
||||
@invocation("img_paste", title="Paste Image", tags=["image", "paste"], category="image", version="1.1.0")
|
||||
class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
@invocation(
|
||||
"img_paste",
|
||||
title="Paste Image",
|
||||
tags=["image", "paste"],
|
||||
category="image",
|
||||
version="1.2.0",
|
||||
)
|
||||
class CenterPadCropInvocation(BaseInvocation):
|
||||
"""Pad or crop an image's sides from the center by specified pixels. Positive values are outside of the image."""
|
||||
|
||||
image: ImageField = InputField(description="The image to crop")
|
||||
left: int = InputField(
|
||||
default=0,
|
||||
description="Number of pixels to pad/crop from the left (negative values crop inwards, positive values pad outwards)",
|
||||
)
|
||||
right: int = InputField(
|
||||
default=0,
|
||||
description="Number of pixels to pad/crop from the right (negative values crop inwards, positive values pad outwards)",
|
||||
)
|
||||
top: int = InputField(
|
||||
default=0,
|
||||
description="Number of pixels to pad/crop from the top (negative values crop inwards, positive values pad outwards)",
|
||||
)
|
||||
bottom: int = InputField(
|
||||
default=0,
|
||||
description="Number of pixels to pad/crop from the bottom (negative values crop inwards, positive values pad outwards)",
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# Calculate and create new image dimensions
|
||||
new_width = image.width + self.right + self.left
|
||||
new_height = image.height + self.top + self.bottom
|
||||
image_crop = Image.new(mode="RGBA", size=(new_width, new_height), color=(0, 0, 0, 0))
|
||||
|
||||
# Paste new image onto input
|
||||
image_crop.paste(image, (self.left, self.top))
|
||||
|
||||
image_dto = context.services.images.create(
|
||||
image=image_crop,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
invocation_type="img_pad_crop",
|
||||
title="Center Pad or Crop Image",
|
||||
category="image",
|
||||
tags=["image", "pad", "crop"],
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImagePasteInvocation(BaseInvocation, WithMetadata):
|
||||
"""Pastes an image into another image."""
|
||||
|
||||
base_image: ImageField = InputField(description="The base image")
|
||||
@@ -144,7 +217,7 @@ class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=self.workflow,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput(
@@ -154,8 +227,14 @@ class ImagePasteInvocation(BaseInvocation, WithWorkflow, WithMetadata):
         )
 
 
-@invocation("tomask", title="Mask from Alpha", tags=["image", "mask"], category="image", version="1.1.0")
-class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation(
+    "tomask",
+    title="Mask from Alpha",
+    tags=["image", "mask"],
+    category="image",
+    version="1.2.0",
+)
+class MaskFromAlphaInvocation(BaseInvocation, WithMetadata):
     """Extracts the alpha channel of an image as a mask."""
 
     image: ImageField = InputField(description="The image to create the mask from")
@@ -176,7 +255,7 @@ class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -186,8 +265,14 @@ class MaskFromAlphaInvocation(BaseInvocation, WithWorkflow, WithMetadata):
         )
 
 
-@invocation("img_mul", title="Multiply Images", tags=["image", "multiply"], category="image", version="1.1.0")
-class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation(
+    "img_mul",
+    title="Multiply Images",
+    tags=["image", "multiply"],
+    category="image",
+    version="1.2.0",
+)
+class ImageMultiplyInvocation(BaseInvocation, WithMetadata):
     """Multiplies two images together using `PIL.ImageChops.multiply()`."""
 
     image1: ImageField = InputField(description="The first image to multiply")
@@ -207,7 +292,7 @@ class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -220,8 +305,14 @@ class ImageMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 IMAGE_CHANNELS = Literal["A", "R", "G", "B"]
 
 
-@invocation("img_chan", title="Extract Image Channel", tags=["image", "channel"], category="image", version="1.1.0")
-class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation(
+    "img_chan",
+    title="Extract Image Channel",
+    tags=["image", "channel"],
+    category="image",
+    version="1.2.0",
+)
+class ImageChannelInvocation(BaseInvocation, WithMetadata):
     """Gets a channel from an image."""
 
     image: ImageField = InputField(description="The image to get the channel from")
@@ -240,7 +331,7 @@ class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -253,8 +344,14 @@ class ImageChannelInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
 
 
-@invocation("img_conv", title="Convert Image Mode", tags=["image", "convert"], category="image", version="1.1.0")
-class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation(
+    "img_conv",
+    title="Convert Image Mode",
+    tags=["image", "convert"],
+    category="image",
+    version="1.2.0",
+)
+class ImageConvertInvocation(BaseInvocation, WithMetadata):
     """Converts an image to a different mode."""
 
     image: ImageField = InputField(description="The image to convert")
@@ -273,7 +370,7 @@ class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -283,8 +380,14 @@ class ImageConvertInvocation(BaseInvocation, WithWorkflow, WithMetadata):
         )
 
 
-@invocation("img_blur", title="Blur Image", tags=["image", "blur"], category="image", version="1.1.0")
-class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation(
+    "img_blur",
+    title="Blur Image",
+    tags=["image", "blur"],
+    category="image",
+    version="1.2.0",
+)
+class ImageBlurInvocation(BaseInvocation, WithMetadata):
     """Blurs an image"""
 
     image: ImageField = InputField(description="The image to blur")
@@ -308,7 +411,7 @@ class ImageBlurInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -338,8 +441,14 @@ PIL_RESAMPLING_MAP = {
 }
 
 
-@invocation("img_resize", title="Resize Image", tags=["image", "resize"], category="image", version="1.1.0")
-class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
+@invocation(
+    "img_resize",
+    title="Resize Image",
+    tags=["image", "resize"],
+    category="image",
+    version="1.2.0",
+)
+class ImageResizeInvocation(BaseInvocation, WithMetadata):
     """Resizes an image to specific dimensions"""
 
     image: ImageField = InputField(description="The image to resize")
@@ -365,7 +474,7 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -375,8 +484,14 @@ class ImageResizeInvocation(BaseInvocation, WithMetadata, WithWorkflow):
         )
 
 
-@invocation("img_scale", title="Scale Image", tags=["image", "scale"], category="image", version="1.1.0")
-class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
+@invocation(
+    "img_scale",
+    title="Scale Image",
+    tags=["image", "scale"],
+    category="image",
+    version="1.2.0",
+)
+class ImageScaleInvocation(BaseInvocation, WithMetadata):
     """Scales an image by a factor"""
 
     image: ImageField = InputField(description="The image to scale")
@@ -407,7 +522,7 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -417,8 +532,14 @@ class ImageScaleInvocation(BaseInvocation, WithMetadata, WithWorkflow):
         )
 
 
-@invocation("img_lerp", title="Lerp Image", tags=["image", "lerp"], category="image", version="1.1.0")
-class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation(
+    "img_lerp",
+    title="Lerp Image",
+    tags=["image", "lerp"],
+    category="image",
+    version="1.2.0",
+)
+class ImageLerpInvocation(BaseInvocation, WithMetadata):
     """Linear interpolation of all pixels of an image"""
 
     image: ImageField = InputField(description="The image to lerp")
@@ -441,7 +562,7 @@ class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -451,8 +572,14 @@ class ImageLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
         )
 
 
-@invocation("img_ilerp", title="Inverse Lerp Image", tags=["image", "ilerp"], category="image", version="1.1.0")
-class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation(
+    "img_ilerp",
+    title="Inverse Lerp Image",
+    tags=["image", "ilerp"],
+    category="image",
+    version="1.2.0",
+)
+class ImageInverseLerpInvocation(BaseInvocation, WithMetadata):
     """Inverse linear interpolation of all pixels of an image"""
 
     image: ImageField = InputField(description="The image to lerp")
@@ -475,7 +602,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -485,8 +612,14 @@ class ImageInverseLerpInvocation(BaseInvocation, WithWorkflow, WithMetadata):
         )
 
 
-@invocation("img_nsfw", title="Blur NSFW Image", tags=["image", "nsfw"], category="image", version="1.1.0")
-class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
+@invocation(
+    "img_nsfw",
+    title="Blur NSFW Image",
+    tags=["image", "nsfw"],
+    category="image",
+    version="1.2.0",
+)
+class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata):
     """Add blur to NSFW-flagged images"""
 
     image: ImageField = InputField(description="The image to check")
@@ -511,7 +644,7 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -532,9 +665,9 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithWorkflow):
     title="Add Invisible Watermark",
     tags=["image", "watermark"],
     category="image",
-    version="1.1.0",
+    version="1.2.0",
 )
-class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
+class ImageWatermarkInvocation(BaseInvocation, WithMetadata):
     """Add an invisible watermark to an image"""
 
     image: ImageField = InputField(description="The image to check")
@@ -551,7 +684,7 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -561,8 +694,14 @@ class ImageWatermarkInvocation(BaseInvocation, WithMetadata, WithWorkflow):
         )
 
 
-@invocation("mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], category="image", version="1.1.0")
-class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation(
+    "mask_edge",
+    title="Mask Edge",
+    tags=["image", "mask", "inpaint"],
+    category="image",
+    version="1.2.0",
+)
+class MaskEdgeInvocation(BaseInvocation, WithMetadata):
     """Applies an edge mask to an image"""
 
     image: ImageField = InputField(description="The image to apply the mask to")
@@ -597,7 +736,7 @@ class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -612,9 +751,9 @@ class MaskEdgeInvocation(BaseInvocation, WithWorkflow, WithMetadata):
     title="Combine Masks",
     tags=["image", "mask", "multiply"],
     category="image",
-    version="1.1.0",
+    version="1.2.0",
 )
-class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class MaskCombineInvocation(BaseInvocation, WithMetadata):
     """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`."""
 
     mask1: ImageField = InputField(description="The first mask to combine")
@@ -634,7 +773,7 @@ class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -644,8 +783,14 @@ class MaskCombineInvocation(BaseInvocation, WithWorkflow, WithMetadata):
         )
 
 
-@invocation("color_correct", title="Color Correct", tags=["image", "color"], category="image", version="1.1.0")
-class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation(
+    "color_correct",
+    title="Color Correct",
+    tags=["image", "color"],
+    category="image",
+    version="1.2.0",
+)
+class ColorCorrectInvocation(BaseInvocation, WithMetadata):
     """
     Shifts the colors of a target image to match the reference image, optionally
     using a mask to only color-correct certain regions of the target image.
@@ -745,7 +890,7 @@ class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -755,8 +900,14 @@ class ColorCorrectInvocation(BaseInvocation, WithWorkflow, WithMetadata):
         )
 
 
-@invocation("img_hue_adjust", title="Adjust Image Hue", tags=["image", "hue"], category="image", version="1.1.0")
-class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation(
+    "img_hue_adjust",
+    title="Adjust Image Hue",
+    tags=["image", "hue"],
+    category="image",
+    version="1.2.0",
+)
+class ImageHueAdjustmentInvocation(BaseInvocation, WithMetadata):
     """Adjusts the Hue of an image."""
 
     image: ImageField = InputField(description="The image to adjust")
@@ -785,7 +936,7 @@ class ImageHueAdjustmentInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             is_intermediate=self.is_intermediate,
             session_id=context.graph_execution_state_id,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -858,9 +1009,9 @@ CHANNEL_FORMATS = {
         "value",
     ],
     category="image",
-    version="1.1.0",
+    version="1.2.0",
 )
-class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class ImageChannelOffsetInvocation(BaseInvocation, WithMetadata):
     """Add or subtract a value from a specific color channel of an image."""
 
     image: ImageField = InputField(description="The image to adjust")
@@ -895,7 +1046,7 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             is_intermediate=self.is_intermediate,
             session_id=context.graph_execution_state_id,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -929,9 +1080,9 @@ class ImageChannelOffsetInvocation(BaseInvocation, WithWorkflow, WithMetadata):
         "value",
     ],
    category="image",
-    version="1.1.0",
+    version="1.2.0",
 )
-class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata):
     """Scale a specific color channel of an image."""
 
     image: ImageField = InputField(description="The image to adjust")
@@ -970,7 +1121,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata)
             node_id=self.id,
             is_intermediate=self.is_intermediate,
             session_id=context.graph_execution_state_id,
-            workflow=self.workflow,
+            workflow=context.workflow,
             metadata=self.metadata,
         )
 
@@ -988,10 +1139,10 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithWorkflow, WithMetadata)
     title="Save Image",
     tags=["primitives", "image"],
     category="primitives",
-    version="1.1.0",
+    version="1.2.0",
     use_cache=False,
 )
-class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class SaveImageInvocation(BaseInvocation, WithMetadata):
     """Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
 
     image: ImageField = InputField(description=FieldDescriptions.image)
@@ -1009,7 +1160,7 @@ class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -1027,7 +1178,7 @@ class SaveImageInvocation(BaseInvocation, WithWorkflow, WithMetadata):
     version="1.0.1",
     use_cache=False,
 )
-class LinearUIOutputInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class LinearUIOutputInvocation(BaseInvocation, WithMetadata):
     """Handles Linear UI Image Outputting tasks."""
 
     image: ImageField = InputField(description=FieldDescriptions.image)
 
@@ -8,12 +8,12 @@ from PIL import Image, ImageOps
 
 from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
-from invokeai.app.util.misc import SEED_MAX, get_random_seed
+from invokeai.app.util.misc import SEED_MAX
 from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
 from invokeai.backend.image_util.lama import LaMA
 from invokeai.backend.image_util.patchmatch import PatchMatch
 
-from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
 from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
 
 
@@ -118,8 +118,8 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
     return si
 
 
-@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
-class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
+class InfillColorInvocation(BaseInvocation, WithMetadata):
     """Infills transparent areas of an image with a solid color"""
 
     image: ImageField = InputField(description="The image to infill")
@@ -144,7 +144,7 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -154,17 +154,17 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata):
         )
 
 
-@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
-class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
+class InfillTileInvocation(BaseInvocation, WithMetadata):
     """Infills transparent areas of an image with tiles of the image"""
 
     image: ImageField = InputField(description="The image to infill")
     tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
     seed: int = InputField(
+        default=0,
         ge=0,
         le=SEED_MAX,
         description="The seed to use for tile generation (omit for random)",
-        default_factory=get_random_seed,
     )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
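Both this hunk and the `NoiseInvocation` hunk further down swap `default_factory=get_random_seed` for a plain `default=0`. In pydantic, `default` and `default_factory` are mutually exclusive ways to supply a field default; a factory runs once per instantiation, so two nodes built back-to-back got different seeds, while a literal default is stable and serializable. Whether determinism was the motivation here is an assumption; the diff itself only shows the swap. A small pydantic-only illustration (field and class names are illustrative):

import random
from pydantic import BaseModel, Field

class WithFactory(BaseModel):
    seed: int = Field(default_factory=lambda: random.randint(0, 2**32 - 1))

class WithLiteral(BaseModel):
    seed: int = Field(default=0)

print(WithFactory().seed, WithFactory().seed)  # almost certainly two different values
print(WithLiteral().seed, WithLiteral().seed)  # always 0, 0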
@@ -181,7 +181,7 @@ class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -192,9 +192,9 @@ class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
 
 @invocation(
-    "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0"
+    "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0"
 )
-class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
     """Infills transparent areas of an image using the PatchMatch algorithm"""
 
     image: ImageField = InputField(description="The image to infill")
@@ -235,7 +235,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -245,8 +245,8 @@ class InfillPatchMatchInvocation(BaseInvocation, WithWorkflow, WithMetadata):
         )
 
 
-@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
-class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
+class LaMaInfillInvocation(BaseInvocation, WithMetadata):
     """Infills transparent areas of an image using the LaMa model"""
 
     image: ImageField = InputField(description="The image to infill")
@@ -264,7 +264,7 @@ class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -274,8 +274,8 @@ class LaMaInfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
         )
 
 
-@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0")
-class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
+class CV2InfillInvocation(BaseInvocation, WithMetadata):
     """Infills transparent areas of an image using OpenCV Inpainting"""
 
     image: ImageField = InputField(description="The image to infill")
@@ -293,7 +293,7 @@ class CV2InfillInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
 
@@ -11,7 +11,6 @@ from invokeai.app.invocations.baseinvocation import (
     InputField,
     InvocationContext,
     OutputField,
-    UIType,
     invocation,
     invocation_output,
 )
@@ -67,7 +66,7 @@ class IPAdapterInvocation(BaseInvocation):
 
     # weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
     weight: Union[float, List[float]] = InputField(
-        default=1, ge=-1, description="The weight given to the IP-Adapter", ui_type=UIType.Float, title="Weight"
+        default=1, ge=-1, description="The weight given to the IP-Adapter", title="Weight"
     )
 
     begin_step_percent: float = InputField(
@@ -64,7 +64,6 @@ from .baseinvocation import (
     OutputField,
     UIType,
     WithMetadata,
-    WithWorkflow,
     invocation,
     invocation_output,
 )
@@ -79,6 +78,12 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device())
 
 SAMPLER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
 
+# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
+# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale
+# factor is hard-coded to a literal '8' rather than using this constant.
+# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
+LATENT_SCALE_FACTOR = 8
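With the constant added above, the image-to-latent conversion used throughout these nodes is plain integer arithmetic: divide to go from pixels to latent cells, multiply to go back. A worked example of the 8:1 ratio the comment describes:

LATENT_SCALE_FACTOR = 8
image_width, image_height = 1024, 768
latent_width = image_width // LATENT_SCALE_FACTOR    # 128
latent_height = image_height // LATENT_SCALE_FACTOR  # 96
# Round-tripping is exact because the image dimensions are multiples of 8.
assert (latent_width * LATENT_SCALE_FACTOR, latent_height * LATENT_SCALE_FACTOR) == (1024, 768)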
 
 
 @invocation_output("scheduler_output")
 class SchedulerOutput(BaseInvocationOutput):
@@ -215,7 +220,7 @@ def get_scheduler(
     title="Denoise Latents",
     tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
     category="latents",
-    version="1.4.0",
+    version="1.5.0",
 )
 class DenoiseLatentsInvocation(BaseInvocation):
     """Denoises noisy latents to decodable images"""
@@ -273,8 +278,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
         input=Input.Connection,
         ui_order=7,
     )
+    cfg_rescale_multiplier: float = InputField(
+        default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
+    )
     latents: Optional[LatentsField] = InputField(
-        default=None, description=FieldDescriptions.latents, input=Input.Connection
+        default=None,
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+        ui_order=4,
     )
     denoise_mask: Optional[DenoiseMaskField] = InputField(
         default=None,
@@ -329,6 +340,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
             unconditioned_embeddings=uc,
             text_embeddings=c,
             guidance_scale=self.cfg_scale,
+            guidance_rescale_multiplier=self.cfg_rescale_multiplier,
             extra=extra_conditioning_info,
             postprocessing_settings=PostprocessingSettings(
                 threshold=0.0,  # threshold,
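The new `cfg_rescale_multiplier` feeds a CFG-rescale step of the kind described in "Common Diffusion Noise Schedules and Sample Steps are Flawed" (Lin et al., 2023), which pulls the often over-saturated classifier-free-guidance output back toward the statistics of the text-conditioned prediction. A hedged sketch of that rescale step follows the paper's formulation; whether InvokeAI's internal implementation matches it line-for-line is an assumption:

import torch

def rescale_noise_pred(noise_cfg: torch.Tensor, noise_text: torch.Tensor, multiplier: float) -> torch.Tensor:
    # Match the std of the CFG result to the std of the text-conditioned branch,
    # then blend by `multiplier` (0 = no rescale, matching the field's default).
    dims = list(range(1, noise_cfg.ndim))
    rescaled = noise_cfg * (noise_text.std(dim=dims, keepdim=True) / noise_cfg.std(dim=dims, keepdim=True))
    return multiplier * rescaled + (1 - multiplier) * noise_cfg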
@@ -387,9 +399,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
         exit_stack: ExitStack,
         do_classifier_free_guidance: bool = True,
     ) -> List[ControlNetData]:
-        # assuming fixed dimensional scaling of 8:1 for image:latents
-        control_height_resize = latents_shape[2] * 8
-        control_width_resize = latents_shape[3] * 8
+        # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
+        control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
+        control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
         if control_input is None:
             control_list = None
         elif isinstance(control_input, list) and len(control_input) == 0:
@@ -706,7 +718,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
         )
         with (
             ExitStack() as exit_stack,
-            ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
             ModelPatcher.apply_freeu(unet_info.context.model, self.unet.freeu_config),
             set_seamless(unet_info.context.model, self.unet.seamless_axes),
             unet_info as unet,
@@ -790,9 +801,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
     title="Latents to Image",
     tags=["latents", "image", "vae", "l2i"],
     category="latents",
-    version="1.1.0",
+    version="1.2.0",
 )
-class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
+class LatentsToImageInvocation(BaseInvocation, WithMetadata):
     """Generates an image from latents."""
 
     latents: LatentsField = InputField(
@@ -874,7 +885,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
@@ -903,12 +914,12 @@ class ResizeLatentsInvocation(BaseInvocation):
     )
     width: int = InputField(
         ge=64,
-        multiple_of=8,
+        multiple_of=LATENT_SCALE_FACTOR,
         description=FieldDescriptions.width,
     )
     height: int = InputField(
         ge=64,
-        multiple_of=8,
+        multiple_of=LATENT_SCALE_FACTOR,
         description=FieldDescriptions.width,
     )
     mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode)
@@ -922,7 +933,7 @@ class ResizeLatentsInvocation(BaseInvocation):
 
         resized_latents = torch.nn.functional.interpolate(
             latents.to(device),
-            size=(self.height // 8, self.width // 8),
+            size=(self.height // LATENT_SCALE_FACTOR, self.width // LATENT_SCALE_FACTOR),
             mode=self.mode,
             antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
         )
@@ -1160,3 +1171,60 @@ class BlendLatentsInvocation(BaseInvocation):
         # context.services.latents.set(name, resized_latents)
         context.services.latents.save(name, blended_latents)
         return build_latents_output(latents_name=name, latents=blended_latents)
+
+
+# The Crop Latents node was copied from @skunkworxdark's implementation here:
+# https://github.com/skunkworxdark/XYGrid_nodes/blob/74647fa9c1fa57d317a94bd43ca689af7f0aae5e/images_to_grids.py#L1117C1-L1167C80
+@invocation(
+    "crop_latents",
+    title="Crop Latents",
+    tags=["latents", "crop"],
+    category="latents",
+    version="1.0.0",
+)
+# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
+# Currently, if the class names conflict then 'GET /openapi.json' fails.
+class CropLatentsCoreInvocation(BaseInvocation):
+    """Crops a latent-space tensor to a box specified in image-space. The box dimensions and coordinates must be
+    divisible by the latent scale factor of 8.
+    """
+
+    latents: LatentsField = InputField(
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
+    x: int = InputField(
+        ge=0,
+        multiple_of=LATENT_SCALE_FACTOR,
+        description="The left x coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
+    )
+    y: int = InputField(
+        ge=0,
+        multiple_of=LATENT_SCALE_FACTOR,
+        description="The top y coordinate (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
+    )
+    width: int = InputField(
+        ge=1,
+        multiple_of=LATENT_SCALE_FACTOR,
+        description="The width (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
+    )
+    height: int = InputField(
+        ge=1,
+        multiple_of=LATENT_SCALE_FACTOR,
+        description="The height (in px) of the crop rectangle in image space. This value will be converted to a dimension in latent space.",
+    )
+
+    def invoke(self, context: InvocationContext) -> LatentsOutput:
+        latents = context.services.latents.get(self.latents.latents_name)
+
+        x1 = self.x // LATENT_SCALE_FACTOR
+        y1 = self.y // LATENT_SCALE_FACTOR
+        x2 = x1 + (self.width // LATENT_SCALE_FACTOR)
+        y2 = y1 + (self.height // LATENT_SCALE_FACTOR)
+
+        cropped_latents = latents[..., y1:y2, x1:x2]
+
+        name = f"{context.graph_execution_state_id}__{self.id}"
+        context.services.latents.save(name, cropped_latents)
+
+        return build_latents_output(latents_name=name, latents=cropped_latents)
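The coordinate arithmetic above is why every input must be a multiple of LATENT_SCALE_FACTOR: the integer division must be exact, or the crop would drift by a fraction of a latent cell. A worked example with concrete numbers:

LATENT_SCALE_FACTOR = 8
x, y, width, height = 64, 128, 256, 256  # all multiples of 8, in pixels
x1, y1 = x // LATENT_SCALE_FACTOR, y // LATENT_SCALE_FACTOR  # (8, 16)
x2 = x1 + width // LATENT_SCALE_FACTOR                       # 40
y2 = y1 + height // LATENT_SCALE_FACTOR                      # 48
# latents[..., 16:48, 8:40] -> a 32x32 latent tile backing a 256x256 pixel crop
print((x1, y1, x2, y2))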
 
@@ -127,6 +127,9 @@ class CoreMetadataInvocation(BaseInvocation):
     seed: Optional[int] = InputField(default=None, description="The seed used for noise generation")
     rand_device: Optional[str] = InputField(default=None, description="The device used for random number generation")
     cfg_scale: Optional[float] = InputField(default=None, description="The classifier-free guidance scale parameter")
+    cfg_rescale_multiplier: Optional[float] = InputField(
+        default=None, description=FieldDescriptions.cfg_rescale_multiplier
+    )
     steps: Optional[int] = InputField(default=None, description="The number of steps used for inference")
     scheduler: Optional[str] = InputField(default=None, description="The scheduler used for inference")
     seamless_x: Optional[bool] = InputField(default=None, description="Whether seamless tiling was used on the X axis")
 
@@ -14,7 +14,6 @@ from .baseinvocation import (
     InputField,
     InvocationContext,
     OutputField,
-    UIType,
     invocation,
     invocation_output,
 )
@@ -395,7 +394,6 @@ class VaeLoaderInvocation(BaseInvocation):
     vae_model: VAEModelField = InputField(
         description=FieldDescriptions.vae_model,
         input=Input.Direct,
-        ui_type=UIType.VaeModel,
         title="VAE",
     )
 
@@ -6,7 +6,7 @@ from pydantic import field_validator
 
 from invokeai.app.invocations.latent import LatentsField
 from invokeai.app.shared.fields import FieldDescriptions
-from invokeai.app.util.misc import SEED_MAX, get_random_seed
+from invokeai.app.util.misc import SEED_MAX
 
 from ...backend.util.devices import choose_torch_device, torch_dtype
 from .baseinvocation import (
@@ -83,16 +83,16 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
     title="Noise",
     tags=["latents", "noise"],
     category="latents",
-    version="1.0.0",
+    version="1.0.1",
 )
 class NoiseInvocation(BaseInvocation):
     """Generates latent noise."""
 
     seed: int = InputField(
+        default=0,
         ge=0,
         le=SEED_MAX,
         description=FieldDescriptions.seed,
-        default_factory=get_random_seed,
     )
     width: int = InputField(
         default=512,
 
@@ -31,7 +31,6 @@ from .baseinvocation import (
     UIComponent,
     UIType,
     WithMetadata,
-    WithWorkflow,
     invocation,
     invocation_output,
 )
@@ -326,9 +325,9 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
     title="ONNX Latents to Image",
     tags=["latents", "image", "vae", "onnx"],
     category="image",
-    version="1.1.0",
+    version="1.2.0",
 )
-class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
+class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
     """Generates an image from latents."""
 
     latents: LatentsField = InputField(
@@ -378,7 +377,7 @@ class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
 
@@ -62,12 +62,12 @@ class BooleanInvocation(BaseInvocation):
     title="Boolean Collection Primitive",
     tags=["primitives", "boolean", "collection"],
     category="primitives",
-    version="1.0.0",
+    version="1.0.1",
 )
 class BooleanCollectionInvocation(BaseInvocation):
     """A collection of boolean primitive values"""
 
-    collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values")
+    collection: list[bool] = InputField(default=[], description="The collection of boolean values")
 
     def invoke(self, context: InvocationContext) -> BooleanCollectionOutput:
         return BooleanCollectionOutput(collection=self.collection)
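This `default_factory=list` to `default=[]` swap (repeated for the integer, float, string, and conditioning collections below) is safe in pydantic, because pydantic deep-copies field defaults per instance rather than sharing one mutable object the way plain Python argument defaults do. A quick demonstration using pydantic directly (`InputField` in the diff wraps pydantic's `Field`):

from pydantic import BaseModel, Field

class Collector(BaseModel):
    items: list[int] = Field(default=[])

a, b = Collector(), Collector()
a.items.append(1)
print(a.items, b.items)  # [1] [] -- the default list is not shared between instances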
@@ -111,12 +111,12 @@ class IntegerInvocation(BaseInvocation):
     title="Integer Collection Primitive",
     tags=["primitives", "integer", "collection"],
     category="primitives",
-    version="1.0.0",
+    version="1.0.1",
 )
 class IntegerCollectionInvocation(BaseInvocation):
     """A collection of integer primitive values"""
 
-    collection: list[int] = InputField(default_factory=list, description="The collection of integer values")
+    collection: list[int] = InputField(default=[], description="The collection of integer values")
 
     def invoke(self, context: InvocationContext) -> IntegerCollectionOutput:
         return IntegerCollectionOutput(collection=self.collection)
@@ -158,12 +158,12 @@ class FloatInvocation(BaseInvocation):
     title="Float Collection Primitive",
     tags=["primitives", "float", "collection"],
     category="primitives",
-    version="1.0.0",
+    version="1.0.1",
 )
 class FloatCollectionInvocation(BaseInvocation):
     """A collection of float primitive values"""
 
-    collection: list[float] = InputField(default_factory=list, description="The collection of float values")
+    collection: list[float] = InputField(default=[], description="The collection of float values")
 
     def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
         return FloatCollectionOutput(collection=self.collection)
@@ -205,12 +205,12 @@ class StringInvocation(BaseInvocation):
     title="String Collection Primitive",
     tags=["primitives", "string", "collection"],
     category="primitives",
-    version="1.0.0",
+    version="1.0.1",
 )
 class StringCollectionInvocation(BaseInvocation):
     """A collection of string primitive values"""
 
-    collection: list[str] = InputField(default_factory=list, description="The collection of string values")
+    collection: list[str] = InputField(default=[], description="The collection of string values")
 
     def invoke(self, context: InvocationContext) -> StringCollectionOutput:
         return StringCollectionOutput(collection=self.collection)
@@ -467,13 +467,13 @@ class ConditioningInvocation(BaseInvocation):
     title="Conditioning Collection Primitive",
     tags=["primitives", "conditioning", "collection"],
     category="primitives",
-    version="1.0.0",
+    version="1.0.1",
 )
 class ConditioningCollectionInvocation(BaseInvocation):
     """A collection of conditioning tensor primitive values"""
 
     collection: list[ConditioningField] = InputField(
-        default_factory=list,
+        default=[],
         description="The collection of conditioning tensors",
     )
 
@@ -44,7 +44,7 @@ class DynamicPromptInvocation(BaseInvocation):
     title="Prompts from File",
     tags=["prompt", "file"],
     category="prompt",
-    version="1.0.0",
+    version="1.0.1",
 )
 class PromptsFromFileInvocation(BaseInvocation):
     """Loads prompts from a text file"""
@@ -82,7 +82,7 @@ class PromptsFromFileInvocation(BaseInvocation):
             end_line = start_line + max_prompts
         if max_prompts <= 0:
             end_line = np.iinfo(np.int32).max
-        with open(file_path) as f:
+        with open(file_path, encoding="utf-8") as f:
             for i, line in enumerate(f):
                 if i >= start_line and i < end_line:
                     prompts.append((pre_prompt or "") + line.strip() + (post_prompt or ""))
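The added `encoding="utf-8"` makes prompt files decode identically on every platform; without it, `open()` falls back to the locale's preferred encoding (e.g. cp1252 on many Windows installs), so a UTF-8 prompt file containing non-ASCII text could raise `UnicodeDecodeError` or silently mis-decode. A self-contained illustration:

# Write a UTF-8 prompt file, then read it back with an explicit encoding.
with open("prompts.txt", "w", encoding="utf-8") as f:
    f.write("café naïve 北京\n")

with open("prompts.txt", encoding="utf-8") as f:  # deterministic everywhere
    print(f.read().strip())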
 
@@ -9,7 +9,6 @@ from invokeai.app.invocations.baseinvocation import (
     InputField,
     InvocationContext,
     OutputField,
-    UIType,
     invocation,
     invocation_output,
 )
@@ -59,7 +58,7 @@ class T2IAdapterInvocation(BaseInvocation):
         ui_order=-1,
     )
     weight: Union[float, list[float]] = InputField(
-        default=1, ge=0, description="The weight given to the T2I-Adapter", ui_type=UIType.Float, title="Weight"
+        default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight"
     )
     begin_step_percent: float = InputField(
         default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)"
 
invokeai/app/invocations/tiles.py (new file, 180 lines)
@@ -0,0 +1,180 @@
+import numpy as np
+from PIL import Image
+from pydantic import BaseModel
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    InputField,
+    InvocationContext,
+    OutputField,
+    WithMetadata,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.primitives import ImageField, ImageOutput
+from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
+from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
+from invokeai.backend.tiles.utils import Tile
+
+
+class TileWithImage(BaseModel):
+    tile: Tile
+    image: ImageField
+
+
+@invocation_output("calculate_image_tiles_output")
+class CalculateImageTilesOutput(BaseInvocationOutput):
+    tiles: list[Tile] = OutputField(description="The tiles coordinates that cover a particular image shape.")
+
+
+@invocation("calculate_image_tiles", title="Calculate Image Tiles", tags=["tiles"], category="tiles", version="1.0.0")
+class CalculateImageTilesInvocation(BaseInvocation):
+    """Calculate the coordinates and overlaps of tiles that cover a target image shape."""
+
+    image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.")
+    image_height: int = InputField(
+        ge=1, default=1024, description="The image height, in pixels, to calculate tiles for."
+    )
+    tile_width: int = InputField(ge=1, default=576, description="The tile width, in pixels.")
+    tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.")
+    overlap: int = InputField(
+        ge=0,
+        default=128,
+        description="The target overlap, in pixels, between adjacent tiles. Adjacent tiles will overlap by at least this amount",
+    )
+
+    def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput:
+        tiles = calc_tiles_with_overlap(
+            image_height=self.image_height,
+            image_width=self.image_width,
+            tile_height=self.tile_height,
+            tile_width=self.tile_width,
+            overlap=self.overlap,
+        )
+        return CalculateImageTilesOutput(tiles=tiles)
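`calc_tiles_with_overlap` lives in `invokeai.backend.tiles.tiles` and is not shown in this diff; the sketch below is a simplified stand-in for the one-dimensional placement arithmetic only, to show how the defaults interact. It is not the real implementation:

def tile_starts(image_size: int, tile_size: int, overlap: int) -> list[int]:
    # Place tiles every (tile_size - overlap) pixels, keeping the last tile
    # flush with the image edge so the whole image is covered.
    stride = tile_size - overlap
    starts = list(range(0, max(image_size - tile_size, 0) + 1, stride))
    if starts[-1] + tile_size < image_size:
        starts.append(image_size - tile_size)
    return starts

# Defaults above: 1024 px image, 576 px tiles, 128 px target overlap.
print(tile_starts(1024, 576, 128))  # [0, 448]: two tiles overlapping by 128 px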
+
+
+@invocation_output("tile_to_properties_output")
+class TileToPropertiesOutput(BaseInvocationOutput):
+    coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.")
+    coords_right: int = OutputField(description="Right coordinate of the tile relative to its parent image.")
+    coords_top: int = OutputField(description="Top coordinate of the tile relative to its parent image.")
+    coords_bottom: int = OutputField(description="Bottom coordinate of the tile relative to its parent image.")
+
+    # HACK: The width and height fields are 'meta' fields that can easily be calculated from the other fields on this
+    # object. Including redundant fields that can cheaply/easily be re-calculated goes against conventional API design
+    # principles. These fields are included, because 1) they are often useful in tiled workflows, and 2) they are
+    # difficult to calculate in a workflow (even though it's just a couple of subtraction nodes the graph gets
+    # surprisingly complicated).
+    width: int = OutputField(description="The width of the tile. Equal to coords_right - coords_left.")
+    height: int = OutputField(description="The height of the tile. Equal to coords_bottom - coords_top.")
+
+    overlap_top: int = OutputField(description="Overlap between this tile and its top neighbor.")
+    overlap_bottom: int = OutputField(description="Overlap between this tile and its bottom neighbor.")
+    overlap_left: int = OutputField(description="Overlap between this tile and its left neighbor.")
+    overlap_right: int = OutputField(description="Overlap between this tile and its right neighbor.")
+
+
+@invocation("tile_to_properties", title="Tile to Properties", tags=["tiles"], category="tiles", version="1.0.0")
+class TileToPropertiesInvocation(BaseInvocation):
+    """Split a Tile into its individual properties."""
+
+    tile: Tile = InputField(description="The tile to split into properties.")
+
+    def invoke(self, context: InvocationContext) -> TileToPropertiesOutput:
+        return TileToPropertiesOutput(
+            coords_left=self.tile.coords.left,
+            coords_right=self.tile.coords.right,
+            coords_top=self.tile.coords.top,
+            coords_bottom=self.tile.coords.bottom,
+            width=self.tile.coords.right - self.tile.coords.left,
+            height=self.tile.coords.bottom - self.tile.coords.top,
+            overlap_top=self.tile.overlap.top,
+            overlap_bottom=self.tile.overlap.bottom,
+            overlap_left=self.tile.overlap.left,
+            overlap_right=self.tile.overlap.right,
+        )
+
+
+@invocation_output("pair_tile_image_output")
+class PairTileImageOutput(BaseInvocationOutput):
+    tile_with_image: TileWithImage = OutputField(description="A tile description with its corresponding image.")
+
+
+@invocation("pair_tile_image", title="Pair Tile with Image", tags=["tiles"], category="tiles", version="1.0.0")
+class PairTileImageInvocation(BaseInvocation):
+    """Pair an image with its tile properties."""
+
+    # TODO(ryand): The only reason that PairTileImage is needed is because the iterate/collect nodes don't preserve
+    # order. Can this be fixed?
+
+    image: ImageField = InputField(description="The tile image.")
+    tile: Tile = InputField(description="The tile properties.")
+
+    def invoke(self, context: InvocationContext) -> PairTileImageOutput:
+        return PairTileImageOutput(
+            tile_with_image=TileWithImage(
+                tile=self.tile,
+                image=self.image,
+            )
+        )
+
+
+@invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.1.0")
+class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
+    """Merge multiple tile images into a single image."""
+
+    # Inputs
+    tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.")
+    blend_amount: int = InputField(
+        ge=0,
+        description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.",
+    )
+
+    def invoke(self, context: InvocationContext) -> ImageOutput:
+        images = [twi.image for twi in self.tiles_with_images]
+        tiles = [twi.tile for twi in self.tiles_with_images]
+
+        # Infer the output image dimensions from the max/min tile limits.
+        height = 0
+        width = 0
+        for tile in tiles:
+            height = max(height, tile.coords.bottom)
+            width = max(width, tile.coords.right)
+
+        # Get all tile images for processing.
+        # TODO(ryand): It pains me that we spend time PNG decoding each tile from disk when they almost certainly
+        # existed in memory at an earlier point in the graph.
+        tile_np_images: list[np.ndarray] = []
+        for image in images:
+            pil_image = context.services.images.get_pil_image(image.image_name)
+            pil_image = pil_image.convert("RGB")
+            tile_np_images.append(np.array(pil_image))
+
+        # Prepare the output image buffer.
+        # Check the first tile to determine how many image channels are expected in the output.
+        channels = tile_np_images[0].shape[-1]
+        dtype = tile_np_images[0].dtype
+        np_image = np.zeros(shape=(height, width, channels), dtype=dtype)
+
+        merge_tiles_with_linear_blending(
+            dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount
+        )
+        pil_image = Image.fromarray(np_image)
+
+        image_dto = context.services.images.create(
+            image=pil_image,
+            image_origin=ResourceOrigin.INTERNAL,
+            image_category=ImageCategory.GENERAL,
+            node_id=self.id,
+            session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
+            metadata=self.metadata,
+            workflow=context.workflow,
+        )
+        return ImageOutput(
+            image=ImageField(image_name=image_dto.image_name),
+            width=image_dto.width,
+            height=image_dto.height,
+        )
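The dimension inference in `invoke` is worth spelling out: the output canvas is simply the maximum `coords.right` and `coords.bottom` over all tiles, so the tiles are assumed to start at the origin and jointly cover the canvas. With the two-tile layout from the earlier placement example:

# (left, right, top, bottom) per tile, mirroring the Tile.coords fields
tile_boxes = [(0, 576, 0, 576), (448, 1024, 0, 576)]
width = max(right for _, right, _, _ in tile_boxes)  # 1024
height = max(bottom for *_, bottom in tile_boxes)    # 576
print(width, height)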
 
@@ -2,19 +2,19 @@
 from pathlib import Path
 from typing import Literal
 
-import cv2 as cv
+import cv2
 import numpy as np
 import torch
 from basicsr.archs.rrdbnet_arch import RRDBNet
 from PIL import Image
 from pydantic import ConfigDict
-from realesrgan import RealESRGANer
 
 from invokeai.app.invocations.primitives import ImageField, ImageOutput
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
+from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
 from invokeai.backend.util.devices import choose_torch_device
 
-from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
+from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
 
 # TODO: Populate this from disk?
 # TODO: Use model manager to load?
@@ -29,8 +29,8 @@ if choose_torch_device() == torch.device("mps"):
     from torch import mps
 
 
-@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.2.0")
-class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
+@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0")
+class ESRGANInvocation(BaseInvocation, WithMetadata):
     """Upscales an image using RealESRGAN."""
 
     image: ImageField = InputField(description="The input image")
@@ -92,9 +92,9 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
         esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
 
-        upsampler = RealESRGANer(
+        upscaler = RealESRGAN(
             scale=netscale,
-            model_path=str(models_path / esrgan_model_path),
+            model_path=models_path / esrgan_model_path,
             model=rrdbnet_model,
             half=False,
             tile=self.tile_size,
@@ -102,15 +102,9 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
         # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
         # TODO: This strips the alpha... is that okay?
-        cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
-
-        # We can pass an `outscale` value here, but it just resizes the image by that factor after
-        # upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
-        # upscaling, you'll need to add a resize node after this one.
-        upscaled_image, img_mode = upsampler.enhance(cv_image)
-
-        # back to PIL
-        pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
+        cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
+        upscaled_image = upscaler.upscale(cv2_image)
+        pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
 
         torch.cuda.empty_cache()
         if choose_torch_device() == torch.device("mps"):
@@ -124,7 +118,7 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
             metadata=self.metadata,
-            workflow=self.workflow,
+            workflow=context.workflow,
         )
 
         return ImageOutput(
 
@@ -4,7 +4,7 @@ from typing import Optional, cast
 
 from invokeai.app.services.image_records.image_records_common import ImageRecord, deserialize_image_record
 from invokeai.app.services.shared.pagination import OffsetPaginatedResults
-from invokeai.app.services.shared.sqlite import SqliteDatabase
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
 
 from .board_image_records_base import BoardImageRecordStorageBase
 
@@ -3,7 +3,7 @@ import threading
 from typing import Union, cast
 
 from invokeai.app.services.shared.pagination import OffsetPaginatedResults
-from invokeai.app.services.shared.sqlite import SqliteDatabase
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
 from invokeai.app.util.misc import uuid_string
 
 from .board_records_base import BoardRecordStorageBase
 
@@ -15,7 +15,7 @@ import os
 import sys
 from argparse import ArgumentParser
 from pathlib import Path
-from typing import ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
+from typing import Any, ClassVar, Dict, List, Literal, Optional, Union, get_args, get_origin, get_type_hints
 
 from omegaconf import DictConfig, ListConfig, OmegaConf
 from pydantic_settings import BaseSettings, SettingsConfigDict
@@ -24,10 +24,7 @@ from invokeai.app.services.config.config_common import PagingArgumentParser, int
 
 
 class InvokeAISettings(BaseSettings):
-    """
-    Runtime configuration settings in which default values are
-    read from an omegaconf .yaml file.
-    """
+    """Runtime configuration settings in which default values are read from an omegaconf .yaml file."""
 
     initconf: ClassVar[Optional[DictConfig]] = None
     argparse_groups: ClassVar[Dict] = {}
@@ -35,6 +32,7 @@ class InvokeAISettings(BaseSettings):
     model_config = SettingsConfigDict(env_file_encoding="utf-8", arbitrary_types_allowed=True, case_sensitive=True)
 
     def parse_args(self, argv: Optional[list] = sys.argv[1:]):
+        """Call to parse command-line arguments."""
         parser = self.get_parser()
         opt, unknown_opts = parser.parse_known_args(argv)
         if len(unknown_opts) > 0:
@@ -49,20 +47,19 @@ class InvokeAISettings(BaseSettings):
             setattr(self, name, value)
 
     def to_yaml(self) -> str:
-        """
-        Return a YAML string representing our settings. This can be used
-        as the contents of `invokeai.yaml` to restore settings later.
-        """
+        """Return a YAML string representing our settings. This can be used as the contents of `invokeai.yaml` to restore settings later."""
         cls = self.__class__
        type = get_args(get_type_hints(cls)["type"])[0]
-        field_dict = {type: {}}
+        field_dict: Dict[str, Dict[str, Any]] = {type: {}}
         for name, field in self.model_fields.items():
             if name in cls._excluded_from_yaml():
                 continue
+            assert isinstance(field.json_schema_extra, dict)
             category = (
                 field.json_schema_extra.get("category", "Uncategorized") if field.json_schema_extra else "Uncategorized"
             )
             value = getattr(self, name)
+            assert isinstance(category, str)
             if category not in field_dict[type]:
                 field_dict[type][category] = {}
             # keep paths as strings to make it easier to read
@@ -72,6 +69,7 @@ class InvokeAISettings(BaseSettings):
|
||||
|
||||
@classmethod
|
||||
def add_parser_arguments(cls, parser):
|
||||
"""Dynamically create arguments for a settings parser."""
|
||||
if "type" in get_type_hints(cls):
|
||||
settings_stanza = get_args(get_type_hints(cls)["type"])[0]
|
||||
else:
|
||||
@@ -116,6 +114,7 @@ class InvokeAISettings(BaseSettings):
|
||||
|
||||
@classmethod
|
||||
def cmd_name(cls, command_field: str = "type") -> str:
|
||||
"""Return the category of a setting."""
|
||||
hints = get_type_hints(cls)
|
||||
if command_field in hints:
|
||||
return get_args(hints[command_field])[0]
|
||||
@@ -124,6 +123,7 @@ class InvokeAISettings(BaseSettings):
|
||||
|
||||
@classmethod
|
||||
def get_parser(cls) -> ArgumentParser:
|
||||
"""Get the command-line parser for a setting."""
|
||||
parser = PagingArgumentParser(
|
||||
prog=cls.cmd_name(),
|
||||
description=cls.__doc__,
|
||||
@@ -152,10 +152,14 @@ class InvokeAISettings(BaseSettings):
|
||||
"free_gpu_mem",
|
||||
"xformers_enabled",
|
||||
"tiled_decode",
|
||||
"lora_dir",
|
||||
"embedding_dir",
|
||||
"controlnet_dir",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def add_field_argument(cls, command_parser, name: str, field, default_override=None):
|
||||
"""Add the argparse arguments for a setting parser."""
|
||||
field_type = get_type_hints(cls).get(name)
|
||||
default = (
|
||||
default_override
|
||||
|
||||
@@ -177,6 +177,7 @@ from typing import ClassVar, Dict, List, Literal, Optional, Union, get_type_hint
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import Field, TypeAdapter
|
||||
from pydantic.config import JsonDict
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
from .config_base import InvokeAISettings
|
||||
@@ -188,28 +189,24 @@ DEFAULT_MAX_VRAM = 0.5
|
||||
|
||||
|
||||
class Categories(object):
|
||||
WebServer = {"category": "Web Server"}
|
||||
Features = {"category": "Features"}
|
||||
Paths = {"category": "Paths"}
|
||||
Logging = {"category": "Logging"}
|
||||
Development = {"category": "Development"}
|
||||
Other = {"category": "Other"}
|
||||
ModelCache = {"category": "Model Cache"}
|
||||
Device = {"category": "Device"}
|
||||
Generation = {"category": "Generation"}
|
||||
Queue = {"category": "Queue"}
|
||||
Nodes = {"category": "Nodes"}
|
||||
MemoryPerformance = {"category": "Memory/Performance"}
|
||||
"""Category headers for configuration variable groups."""
|
||||
|
||||
WebServer: JsonDict = {"category": "Web Server"}
|
||||
Features: JsonDict = {"category": "Features"}
|
||||
Paths: JsonDict = {"category": "Paths"}
|
||||
Logging: JsonDict = {"category": "Logging"}
|
||||
Development: JsonDict = {"category": "Development"}
|
||||
Other: JsonDict = {"category": "Other"}
|
||||
ModelCache: JsonDict = {"category": "Model Cache"}
|
||||
Device: JsonDict = {"category": "Device"}
|
||||
Generation: JsonDict = {"category": "Generation"}
|
||||
Queue: JsonDict = {"category": "Queue"}
|
||||
Nodes: JsonDict = {"category": "Nodes"}
|
||||
MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
|
||||
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
"""
Generate images using Stable Diffusion. Use "invokeai" to launch
the command-line client (recommended for experts only), or
"invokeai-web" to launch the web server. Global options
can be changed by editing the file "INVOKEAI_ROOT/invokeai.yaml" or by
setting environment variables INVOKEAI_<setting>.
"""
"""Configuration object for InvokeAI App."""

singleton_config: ClassVar[Optional[InvokeAIAppConfig]] = None
singleton_init: ClassVar[Optional[Dict]] = None
@@ -234,15 +231,12 @@ class InvokeAIAppConfig(InvokeAISettings):

# PATHS
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
autoimport_dir : Optional[Path] = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
conf_path : Optional[Path] = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
models_dir : Optional[Path] = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
legacy_conf_dir : Optional[Path] = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
db_dir : Optional[Path] = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
outdir : Optional[Path] = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
db_dir : Path = Field(default=Path('databases'), description='Path to InvokeAI databases directory', json_schema_extra=Categories.Paths)
outdir : Path = Field(default=Path('outputs'), description='Default folder for output images', json_schema_extra=Categories.Paths)
use_memory_db : bool = Field(default=False, description='Use in-memory database for storing image metadata', json_schema_extra=Categories.Paths)
custom_nodes_dir : Path = Field(default=Path('nodes'), description='Path to directory for custom nodes', json_schema_extra=Categories.Paths)
from_file : Optional[Path] = Field(default=None, description='Take command input from the indicated file (command-line client only)', json_schema_extra=Categories.Paths)
@@ -285,11 +279,15 @@ class InvokeAIAppConfig(InvokeAISettings):

# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAIN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
max_vram_cache_size : Optional[float] = Field(default=None, ge=0, description="Amount of VRAM reserved for model storage", json_schema_extra=Categories.MemoryPerformance)
xformers_enabled : bool = Field(default=True, description="Enable/disable memory-efficient attention", json_schema_extra=Categories.MemoryPerformance)
tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", json_schema_extra=Categories.MemoryPerformance)
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
# this is not referred to in the source code and can be removed entirely
#free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)

# See InvokeAIAppConfig subclass below for CACHE and DEVICE categories
# fmt: on
@@ -303,8 +301,8 @@ class InvokeAIAppConfig(InvokeAISettings):
clobber=False,
):
"""
Update settings with contents of init file, environment, and
command-line settings.
Update settings with contents of init file, environment, and command-line settings.

:param conf: alternate Omegaconf dictionary object
:param argv: alternate sys.argv list
:param clobber: overwrite any initialization parameters passed during initialization
@@ -337,9 +335,7 @@ class InvokeAIAppConfig(InvokeAISettings):

@classmethod
def get_config(cls, **kwargs) -> InvokeAIAppConfig:
"""
This returns a singleton InvokeAIAppConfig configuration object.
"""
"""Return a singleton InvokeAIAppConfig configuration object."""
if (
cls.singleton_config is None
or type(cls.singleton_config) is not cls
@@ -351,9 +347,7 @@ class InvokeAIAppConfig(InvokeAISettings):

@property
def root_path(self) -> Path:
"""
Path to the runtime root directory
"""
"""Path to the runtime root directory."""
if self.root:
root = Path(self.root).expanduser().absolute()
else:
@@ -363,9 +357,7 @@ class InvokeAIAppConfig(InvokeAISettings):

@property
def root_dir(self) -> Path:
"""
Alias for above.
"""
"""Alias for above."""
return self.root_path

def _resolve(self, partial_path: Path) -> Path:
@@ -373,108 +365,95 @@ class InvokeAIAppConfig(InvokeAISettings):

@property
def init_file_path(self) -> Path:
"""
Path to invokeai.yaml
"""
return self._resolve(INIT_FILE)
"""Path to invokeai.yaml."""
resolved_path = self._resolve(INIT_FILE)
assert resolved_path is not None
return resolved_path

@property
def output_path(self) -> Path:
"""
Path to defaults outputs directory.
"""
def output_path(self) -> Optional[Path]:
"""Path to defaults outputs directory."""
return self._resolve(self.outdir)

@property
def db_path(self) -> Path:
"""
Path to the invokeai.db file.
"""
return self._resolve(self.db_dir) / DB_FILE
"""Path to the invokeai.db file."""
db_dir = self._resolve(self.db_dir)
assert db_dir is not None
return db_dir / DB_FILE

@property
def model_conf_path(self) -> Path:
"""
Path to models configuration file.
"""
def model_conf_path(self) -> Optional[Path]:
"""Path to models configuration file."""
return self._resolve(self.conf_path)

@property
def legacy_conf_path(self) -> Path:
"""
Path to directory of legacy configuration files (e.g. v1-inference.yaml)
"""
def legacy_conf_path(self) -> Optional[Path]:
"""Path to directory of legacy configuration files (e.g. v1-inference.yaml)."""
return self._resolve(self.legacy_conf_dir)

@property
def models_path(self) -> Path:
"""
Path to the models directory
"""
def models_path(self) -> Optional[Path]:
"""Path to the models directory."""
return self._resolve(self.models_dir)

@property
def custom_nodes_path(self) -> Path:
"""
Path to the custom nodes directory
"""
return self._resolve(self.custom_nodes_dir)
"""Path to the custom nodes directory."""
custom_nodes_path = self._resolve(self.custom_nodes_dir)
assert custom_nodes_path is not None
return custom_nodes_path

# the following methods support legacy calls leftover from the Globals era
@property
def full_precision(self) -> bool:
"""Return true if precision set to float32"""
"""Return true if precision set to float32."""
return self.precision == "float32"

@property
def try_patchmatch(self) -> bool:
"""Return true if patchmatch true"""
"""Return true if patchmatch true."""
return self.patchmatch

@property
def nsfw_checker(self) -> bool:
"""NSFW node is always active and disabled from Web UI"""
"""Return value for NSFW checker. The NSFW node is always active and disabled from Web UI."""
return True

@property
def invisible_watermark(self) -> bool:
"""invisible watermark node is always active and disabled from Web UI"""
"""Return value of invisible watermark. It is always active and disabled from Web UI."""
return True

@property
def ram_cache_size(self) -> Union[Literal["auto"], float]:
"""Return the ram cache size using the legacy or modern setting."""
return self.max_cache_size or self.ram

@property
def vram_cache_size(self) -> Union[Literal["auto"], float]:
"""Return the vram cache size using the legacy or modern setting."""
return self.max_vram_cache_size or self.vram

@property
def use_cpu(self) -> bool:
"""Return true if the device is set to CPU or the always_use_cpu flag is set."""
return self.always_use_cpu or self.device == "cpu"

@property
def disable_xformers(self) -> bool:
"""
Return true if enable_xformers is false (reversed logic)
and attention type is not set to xformers.
"""
"""Return true if enable_xformers is false (reversed logic) and attention type is not set to xformers."""
disabled_in_config = not self.xformers_enabled
return disabled_in_config and self.attention_type != "xformers"

@staticmethod
def find_root() -> Path:
"""
Choose the runtime root directory when not specified on command line or
init file.
"""
"""Choose the runtime root directory when not specified on command line or init file."""
return _find_root()


def get_invokeai_config(**kwargs) -> InvokeAIAppConfig:
"""
Legacy function which returns InvokeAIAppConfig.get_config()
"""
"""Legacy function which returns InvokeAIAppConfig.get_config()."""
return InvokeAIAppConfig.get_config(**kwargs)
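The `ram_cache_size`/`vram_cache_size` properties above bridge deprecated pre-3.1 fields to the modern settings with an `or` fallback. A small standalone illustration of that rule (names and values are hypothetical):

```python
# Hypothetical illustration of the legacy-to-modern fallback used by
# ram_cache_size / vram_cache_size: a non-None legacy value wins, otherwise
# the modern setting is returned.
from typing import Literal, Optional, Union

def resolve_cache_size(
    legacy: Optional[float], modern: Union[Literal["auto"], float]
) -> Union[Literal["auto"], float]:
    # `legacy or modern` also treats 0.0 as "unset", consistent with the
    # gt=0 / ge=0 constraints on the deprecated fields above.
    return legacy or modern

assert resolve_cache_size(None, 7.5) == 7.5        # no legacy override
assert resolve_cache_size(4.0, 7.5) == 4.0         # legacy max_cache_size wins
assert resolve_cache_size(None, "auto") == "auto"  # modern "auto" passes through
```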
@@ -4,7 +4,8 @@ from typing import Optional

from PIL.Image import Image as PILImageType

from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID


class ImageFileStorageBase(ABC):
@@ -33,7 +34,7 @@ class ImageFileStorageBase(ABC):
image: PILImageType,
image_name: str,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None,
workflow: Optional[WorkflowWithoutID] = None,
thumbnail_size: int = 256,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
@@ -43,3 +44,8 @@ class ImageFileStorageBase(ABC):
def delete(self, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists)."""
pass

@abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
"""Gets the workflow of an image."""
pass
@@ -7,8 +7,9 @@ from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash

from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail

from .image_files_base import ImageFileStorageBase
@@ -56,7 +57,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image: PILImageType,
image_name: str,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None,
workflow: Optional[WorkflowWithoutID] = None,
thumbnail_size: int = 256,
) -> None:
try:
@@ -64,12 +65,19 @@ class DiskImageFileStorage(ImageFileStorageBase):
image_path = self.get_path(image_name)

pnginfo = PngImagePlugin.PngInfo()
info_dict = {}

if metadata is not None:
pnginfo.add_text("invokeai_metadata", metadata.model_dump_json())
metadata_json = metadata.model_dump_json()
info_dict["invokeai_metadata"] = metadata_json
pnginfo.add_text("invokeai_metadata", metadata_json)
if workflow is not None:
pnginfo.add_text("invokeai_workflow", workflow.model_dump_json())
workflow_json = workflow.model_dump_json()
info_dict["invokeai_workflow"] = workflow_json
pnginfo.add_text("invokeai_workflow", workflow_json)

# When saving the image, the image object's info field is not populated. We need to set it
image.info = info_dict
image.save(
image_path,
"PNG",
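Since `save` now embeds metadata and the workflow as PNG text chunks under the keys shown above, they can be read back with plain Pillow. A hedged read-back sketch (the file path is illustrative):

```python
# Minimal sketch (assumed file name): recover the text chunks that
# DiskImageFileStorage.save embeds. Pillow exposes PNG tEXt entries via .info.
import json

from PIL import Image

with Image.open("outputs/images/example.png") as im:
    metadata_json = im.info.get("invokeai_metadata")  # may be absent
    workflow_json = im.info.get("invokeai_workflow")  # may be absent

if workflow_json is not None:
    # Plain dict here; the app itself uses WorkflowWithoutID.model_validate_json
    workflow = json.loads(workflow_json)
    print(workflow.get("name"))
```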
@@ -121,6 +129,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
path = path if isinstance(path, Path) else Path(path)
return path.exists()

def get_workflow(self, image_name: str) -> WorkflowWithoutID | None:
image = self.get(image_name)
workflow = image.info.get("invokeai_workflow", None)
if workflow is not None:
return WorkflowWithoutID.model_validate_json(workflow)
return None

def __validate_storage_folders(self) -> None:
"""Checks if the required output folders exist and creates them if they don't"""
folders: list[Path] = [self.__output_folder, self.__thumbnails_folder]

@@ -75,6 +75,7 @@ class ImageRecordStorageBase(ABC):
image_category: ImageCategory,
width: int,
height: int,
has_workflow: bool,
is_intermediate: Optional[bool] = False,
starred: Optional[bool] = False,
session_id: Optional[str] = None,

@@ -100,6 +100,7 @@ IMAGE_DTO_COLS = ", ".join(
"height",
"session_id",
"node_id",
"has_workflow",
"is_intermediate",
"created_at",
"updated_at",
@@ -145,6 +146,7 @@ class ImageRecord(BaseModelExcludeNull):
"""The node ID that generated this image, if it is a generated image."""
starred: bool = Field(description="Whether this image is starred.")
"""Whether this image is starred."""
has_workflow: bool = Field(description="Whether this image has a workflow.")


class ImageRecordChanges(BaseModelExcludeNull, extra="allow"):
@@ -188,6 +190,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False)
starred = image_dict.get("starred", False)
has_workflow = image_dict.get("has_workflow", False)

return ImageRecord(
image_name=image_name,
@@ -202,4 +205,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
deleted_at=deleted_at,
is_intermediate=is_intermediate,
starred=starred,
has_workflow=has_workflow,
)

@@ -5,7 +5,7 @@ from typing import Optional, Union, cast

from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase

from .image_records_base import ImageRecordStorageBase
from .image_records_common import (
@@ -117,6 +117,16 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
)

self._cursor.execute("PRAGMA table_info(images)")
columns = [column[1] for column in self._cursor.fetchall()]
if "has_workflow" not in columns:
self._cursor.execute(
"""--sql
ALTER TABLE images
ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE;
"""
)
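The hunk above is an in-place schema migration: it inspects the live table and adds the column only when missing, so repeated startups are idempotent. The same pattern as a standalone sketch against a throwaway database (table and column names mirror the diff):

```python
# Self-contained sketch of the PRAGMA-based "add column if missing" pattern
# used above; runs against an in-memory database.
import sqlite3

conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
cursor.execute("CREATE TABLE images (image_name TEXT PRIMARY KEY)")

cursor.execute("PRAGMA table_info(images)")
columns = [row[1] for row in cursor.fetchall()]  # row[1] is the column name
if "has_workflow" not in columns:
    cursor.execute("ALTER TABLE images ADD COLUMN has_workflow BOOLEAN DEFAULT FALSE")
conn.commit()

cursor.execute("PRAGMA table_info(images)")
print([row[1] for row in cursor.fetchall()])  # ['image_name', 'has_workflow']
```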
def get(self, image_name: str) -> ImageRecord:
try:
self._lock.acquire()
@@ -408,6 +418,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_category: ImageCategory,
width: int,
height: int,
has_workflow: bool,
is_intermediate: Optional[bool] = False,
starred: Optional[bool] = False,
session_id: Optional[str] = None,
@@ -429,9 +440,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id,
metadata,
is_intermediate,
starred
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
@@ -444,6 +456,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
metadata_json,
is_intermediate,
starred,
has_workflow,
),
)
self._conn.commit()

@@ -3,7 +3,7 @@ from typing import Callable, Optional

from PIL.Image import Image as PILImageType

from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecord,
@@ -12,6 +12,7 @@ from invokeai.app.services.image_records.image_records_common import (
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID


class ImageServiceABC(ABC):
@@ -51,7 +52,7 @@ class ImageServiceABC(ABC):
board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None,
workflow: Optional[WorkflowWithoutID] = None,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@@ -85,6 +86,11 @@ class ImageServiceABC(ABC):
"""Gets an image's metadata."""
pass

@abstractmethod
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
"""Gets an image's workflow."""
pass

@abstractmethod
def get_path(self, image_name: str, thumbnail: bool = False) -> str:
"""Gets an image's path."""

@@ -24,11 +24,6 @@ class ImageDTO(ImageRecord, ImageUrlsDTO):
default=None, description="The id of the board the image belongs to, if one exists."
)
"""The id of the board the image belongs to, if one exists."""
workflow_id: Optional[str] = Field(
default=None,
description="The workflow that generated this image.",
)
"""The workflow that generated this image."""


def image_record_to_dto(
@@ -36,7 +31,6 @@ def image_record_to_dto(
image_url: str,
thumbnail_url: str,
board_id: Optional[str],
workflow_id: Optional[str],
) -> ImageDTO:
"""Converts an image record to an image DTO."""
return ImageDTO(
@@ -44,5 +38,4 @@ def image_record_to_dto(
image_url=image_url,
thumbnail_url=thumbnail_url,
board_id=board_id,
workflow_id=workflow_id,
)

@@ -2,9 +2,10 @@ from typing import Optional

from PIL.Image import Image as PILImageType

from invokeai.app.invocations.baseinvocation import MetadataField, WorkflowField
from invokeai.app.invocations.baseinvocation import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID

from ..image_files.image_files_common import (
ImageFileDeleteException,
@@ -42,7 +43,7 @@ class ImageService(ImageServiceABC):
board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None,
workflow: Optional[WorkflowField] = None,
workflow: Optional[WorkflowWithoutID] = None,
) -> ImageDTO:
if image_origin not in ResourceOrigin:
raise InvalidOriginException
@@ -55,12 +56,6 @@ class ImageService(ImageServiceABC):
(width, height) = image.size

try:
if workflow is not None:
created_workflow = self.__invoker.services.workflow_records.create(workflow)
workflow_id = created_workflow.model_dump()["id"]
else:
workflow_id = None

# TODO: Consider using a transaction here to ensure consistency between storage and database
self.__invoker.services.image_records.save(
# Non-nullable fields
@@ -69,6 +64,7 @@ class ImageService(ImageServiceABC):
image_category=image_category,
width=width,
height=height,
has_workflow=workflow is not None,
# Meta fields
is_intermediate=is_intermediate,
# Nullable fields
@@ -78,8 +74,6 @@ class ImageService(ImageServiceABC):
)
if board_id is not None:
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
if workflow_id is not None:
self.__invoker.services.workflow_image_records.create(workflow_id=workflow_id, image_name=image_name)
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow
)
@@ -143,7 +137,6 @@ class ImageService(ImageServiceABC):
image_url=self.__invoker.services.urls.get_image_url(image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name),
)

return image_dto
@@ -164,18 +157,15 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting image DTO")
raise e

def get_workflow(self, image_name: str) -> Optional[WorkflowField]:
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
try:
workflow_id = self.__invoker.services.workflow_image_records.get_workflow_for_image(image_name)
if workflow_id is None:
return None
return self.__invoker.services.workflow_records.get(workflow_id)
except ImageRecordNotFoundException:
self.__invoker.services.logger.error("Image record not found")
return self.__invoker.services.image_files.get_workflow(image_name)
except ImageFileNotFoundException:
self.__invoker.services.logger.error("Image file not found")
raise
except Exception:
self.__invoker.services.logger.error("Problem getting image workflow")
raise
except Exception as e:
self.__invoker.services.logger.error("Problem getting image DTO")
raise e

def get_path(self, image_name: str, thumbnail: bool = False) -> str:
try:
@@ -223,7 +213,6 @@ class ImageService(ImageServiceABC):
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
workflow_id=self.__invoker.services.workflow_image_records.get_workflow_for_image(r.image_name),
)
for r in results.items
]

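With the `workflow_images` junction table gone, `get_workflow` now asks the file storage, which parses the embedded `invokeai_workflow` chunk. A hedged call sketch (the `image_service` handle is hypothetical):

```python
# Hypothetical usage sketch: `image_service` stands in for a constructed
# ImageService wired into the invoker's services, per the hunks above.
workflow = image_service.get_workflow("example.png")
if workflow is None:
    print("image has no embedded workflow")
else:
    # WorkflowWithoutID is a pydantic model; dump it for inspection
    print(workflow.model_dump_json()[:200])
```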
@@ -108,6 +108,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
queue_batch_id=queue_item.session_queue_batch_id,
workflow=queue_item.workflow,
)
)

@@ -178,6 +179,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
session_queue_item_id=queue_item.session_queue_item_id,
session_queue_id=queue_item.session_queue_id,
graph_execution_state=graph_execution_state,
workflow=queue_item.workflow,
invoke_all=True,
)
except Exception as e:

@@ -1,9 +1,12 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import time
from typing import Optional

from pydantic import BaseModel, Field

from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID


class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
@@ -15,5 +18,6 @@ class InvocationQueueItem(BaseModel):
session_queue_batch_id: str = Field(
description="The ID of the session batch from which this invocation queue item came"
)
workflow: Optional[WorkflowWithoutID] = Field(description="The workflow associated with this queue item")
invoke_all: bool = Field(default=False)
timestamp: float = Field(default_factory=time.time)

@@ -28,7 +28,6 @@ if TYPE_CHECKING:
from .session_queue.session_queue_base import SessionQueueBase
from .shared.graph import GraphExecutionState, LibraryGraph
from .urls.urls_base import UrlServiceBase
from .workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase


@@ -59,7 +58,6 @@ class InvocationServices:
invocation_cache: "InvocationCacheBase"
names: "NameServiceBase"
urls: "UrlServiceBase"
workflow_image_records: "WorkflowImageRecordsStorageBase"
workflow_records: "WorkflowRecordsStorageBase"

def __init__(
@@ -87,7 +85,6 @@ class InvocationServices:
invocation_cache: "InvocationCacheBase",
names: "NameServiceBase",
urls: "UrlServiceBase",
workflow_image_records: "WorkflowImageRecordsStorageBase",
workflow_records: "WorkflowRecordsStorageBase",
):
self.board_images = board_images
@@ -113,5 +110,4 @@ class InvocationServices:
self.invocation_cache = invocation_cache
self.names = names
self.urls = urls
self.workflow_image_records = workflow_image_records
self.workflow_records = workflow_records

@@ -2,6 +2,8 @@

from typing import Optional

from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID

from .invocation_queue.invocation_queue_common import InvocationQueueItem
from .invocation_services import InvocationServices
from .shared.graph import Graph, GraphExecutionState
@@ -22,6 +24,7 @@ class Invoker:
session_queue_item_id: int,
session_queue_batch_id: str,
graph_execution_state: GraphExecutionState,
workflow: Optional[WorkflowWithoutID] = None,
invoke_all: bool = False,
) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed.
@@ -43,6 +46,7 @@ class Invoker:
session_queue_batch_id=session_queue_batch_id,
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
workflow=workflow,
invoke_all=invoke_all,
)
)

@@ -5,7 +5,7 @@ from typing import Generic, Optional, TypeVar, get_args
from pydantic import BaseModel, TypeAdapter

from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase

from .item_storage_base import ItemStorageABC

@@ -5,6 +5,8 @@ from typing import Union

import torch

from invokeai.app.services.invoker import Invoker

from .latents_storage_base import LatentsStorageBase


@@ -17,6 +19,10 @@ class DiskLatentsStorage(LatentsStorageBase):
self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__output_folder.mkdir(parents=True, exist_ok=True)

def start(self, invoker: Invoker) -> None:
self._invoker = invoker
self._delete_all_latents()

def get(self, name: str) -> torch.Tensor:
latent_path = self.get_path(name)
return torch.load(latent_path)
@@ -32,3 +38,21 @@ class DiskLatentsStorage(LatentsStorageBase):

def get_path(self, name: str) -> Path:
return self.__output_folder / name

def _delete_all_latents(self) -> None:
"""
Deletes all latents from disk.
Must be called after we have access to `self._invoker` (e.g. in `start()`).
"""
deleted_latents_count = 0
freed_space = 0
for latents_file in Path(self.__output_folder).glob("*"):
if latents_file.is_file():
freed_space += latents_file.stat().st_size
deleted_latents_count += 1
latents_file.unlink()
if deleted_latents_count > 0:
freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
self._invoker.services.logger.info(
f"Deleted {deleted_latents_count} latents files (freed {freed_space_in_mb}MB)"
)

@@ -5,6 +5,8 @@ from typing import Dict, Optional

import torch

from invokeai.app.services.invoker import Invoker

from .latents_storage_base import LatentsStorageBase


@@ -23,6 +25,18 @@ class ForwardCacheLatentsStorage(LatentsStorageBase):
self.__cache_ids = Queue()
self.__max_cache_size = max_cache_size

def start(self, invoker: Invoker) -> None:
self._invoker = invoker
start_op = getattr(self.__underlying_storage, "start", None)
if callable(start_op):
start_op(invoker)

def stop(self, invoker: Invoker) -> None:
self._invoker = invoker
stop_op = getattr(self.__underlying_storage, "stop", None)
if callable(stop_op):
stop_op(invoker)

def get(self, name: str) -> torch.Tensor:
cache_item = self.__get_cache(name)
if cache_item is not None:
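The cache now forwards `start`/`stop` to whatever storage it wraps, probing with `getattr` so backends without those hooks keep working. A hedged composition sketch (constructor arguments and module paths are assumptions based on the hunks in this compare):

```python
# Hypothetical wiring sketch: compose the forward cache over the disk store.
# Module paths and constructor arguments are assumed from the package layout
# visible in this diff; values are illustrative.
from invokeai.app.services.latents_storage.latents_storage_disk import DiskLatentsStorage
from invokeai.app.services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage

disk_storage = DiskLatentsStorage("outputs/latents")
latents_storage = ForwardCacheLatentsStorage(disk_storage, max_cache_size=20)

# When the invoker starts, the cache's start() duck-types into the wrapped
# storage, which is what triggers DiskLatentsStorage._delete_all_latents().
# latents_storage.start(invoker)  # `invoker` is the running app's Invoker
```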
@@ -48,12 +48,11 @@ from typing import List, Optional, Union
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelConfigFactory,
ModelType,
)

from ..shared.sqlite import SqliteDatabase
from ..shared.sqlite.sqlite_database import SqliteDatabase
from .model_records_base import (
CONFIG_FILE_VERSION,
DuplicateModelException,
@@ -158,7 +157,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
("version", CONFIG_FILE_VERSION),
)

def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> AnyModelConfig:
def add_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Add a model to the database.

@@ -255,7 +254,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db.conn.rollback()
raise e

def update_model(self, key: str, config: ModelConfigBase) -> AnyModelConfig:
def update_model(self, key: str, config: Union[dict, AnyModelConfig]) -> AnyModelConfig:
"""
Update the model, returning the updated version.

@@ -368,7 +367,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results

def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
results = []
with self._db.lock:
@@ -382,7 +381,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results

def search_by_hash(self, hash: str) -> List[ModelConfigBase]:
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated original_hash."""
results = []
with self._db.lock:

@@ -1,7 +1,6 @@
import traceback
from threading import BoundedSemaphore
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from threading import Thread
from typing import Optional

from fastapi_events.handlers.local import local_handler
@@ -115,6 +114,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
session_queue_id=queue_item.queue_id,
session_queue_item_id=queue_item.item_id,
graph_execution_state=queue_item.session,
workflow=queue_item.workflow,
invoke_all=True,
)
queue_item = None

@@ -8,6 +8,10 @@ from pydantic_core import to_jsonable_python

from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
from invokeai.app.services.workflow_records.workflow_records_common import (
WorkflowWithoutID,
WorkflowWithoutIDValidator,
)
from invokeai.app.util.misc import uuid_string

# region Errors
@@ -66,6 +70,9 @@ class Batch(BaseModel):
batch_id: str = Field(default_factory=uuid_string, description="The ID of the batch")
data: Optional[BatchDataCollection] = Field(default=None, description="The batch data collection.")
graph: Graph = Field(description="The graph to initialize the session with")
workflow: Optional[WorkflowWithoutID] = Field(
default=None, description="The workflow to initialize the session with"
)
runs: int = Field(
default=1, ge=1, description="Int stating how many times to iterate through all possible batch indices"
)
@@ -164,6 +171,14 @@ def get_session(queue_item_dict: dict) -> GraphExecutionState:
return session


def get_workflow(queue_item_dict: dict) -> Optional[WorkflowWithoutID]:
workflow_raw = queue_item_dict.get("workflow", None)
if workflow_raw is not None:
workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw, strict=False)
return workflow
return None


class SessionQueueItemWithoutGraph(BaseModel):
"""Session queue item without the full graph. Used for serialization."""

@@ -213,12 +228,16 @@ class SessionQueueItemDTO(SessionQueueItemWithoutGraph):

class SessionQueueItem(SessionQueueItemWithoutGraph):
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
workflow: Optional[WorkflowWithoutID] = Field(
default=None, description="The workflow associated with this queue item"
)

@classmethod
def queue_item_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItem":
# must parse these manually
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
queue_item_dict["session"] = get_session(queue_item_dict)
queue_item_dict["workflow"] = get_workflow(queue_item_dict)
return SessionQueueItem(**queue_item_dict)

model_config = ConfigDict(
@@ -334,7 +353,7 @@ def populate_graph(graph: Graph, node_field_values: Iterable[NodeFieldValue]) ->

def create_session_nfv_tuples(
batch: Batch, maximum: int
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue]], None, None]:
) -> Generator[tuple[GraphExecutionState, list[NodeFieldValue], Optional[WorkflowWithoutID]], None, None]:
"""
Create all graph permutations from the given batch data and graph. Yields tuples
of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
@@ -365,7 +384,7 @@ def create_session_nfv_tuples(
return
flat_node_field_values = list(chain.from_iterable(d))
graph = populate_graph(batch.graph, flat_node_field_values)
yield (GraphExecutionState(graph=graph), flat_node_field_values)
yield (GraphExecutionState(graph=graph), flat_node_field_values, batch.workflow)
count += 1

@@ -391,12 +410,14 @@ def calc_session_count(batch: Batch) -> int:
class SessionQueueValueToInsert(NamedTuple):
"""A tuple of values to insert into the session_queue table"""

# Careful with the ordering of this - it must match the insert statement
queue_id: str # queue_id
session: str # session json
session_id: str # session_id
batch_id: str # batch_id
field_values: Optional[str] # field_values json
priority: int # priority
workflow: Optional[str] # workflow json


ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]
@@ -404,7 +425,7 @@ ValuesToInsert: TypeAlias = list[SessionQueueValueToInsert]

def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> ValuesToInsert:
values_to_insert: ValuesToInsert = []
for session, field_values in create_session_nfv_tuples(batch, max_new_queue_items):
for session, field_values, workflow in create_session_nfv_tuples(batch, max_new_queue_items):
# sessions must have unique id
session.id = uuid_string()
values_to_insert.append(
@@ -416,6 +437,7 @@ def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new
# must use pydantic_encoder bc field_values is a list of models
json.dumps(field_values, default=to_jsonable_python) if field_values else None, # field_values (json)
priority, # priority
json.dumps(workflow, default=to_jsonable_python) if workflow else None, # workflow (json)
)
)
return values_to_insert

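The ordering comment above is load-bearing: `SessionQueueValueToInsert` is bound positionally to the `INSERT` statement shown further down in this compare, so the field order and column list must stay in lockstep. A toy sketch of why a NamedTuple works directly with `executemany` (schema is illustrative):

```python
# Self-contained sketch: sqlite3 treats a NamedTuple as a plain sequence,
# so executemany binds its fields positionally, in declaration order.
import sqlite3
from typing import NamedTuple, Optional

class Row(NamedTuple):
    queue_id: str
    workflow: Optional[str]  # JSON string or None, as in the diff

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE session_queue (queue_id TEXT, workflow TEXT)")
conn.executemany(
    "INSERT INTO session_queue (queue_id, workflow) VALUES (?, ?)",
    [Row("default", None), Row("default", '{"name": "t2i"}')],
)
print(conn.execute("SELECT COUNT(*) FROM session_queue").fetchone()[0])  # 2
```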
@@ -28,7 +28,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
prepare_values_to_insert,
)
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase


class SqliteSessionQueue(SessionQueueBase):
@@ -42,7 +42,8 @@ class SqliteSessionQueue(SessionQueueBase):
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")

def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
@@ -198,6 +199,15 @@ class SqliteSessionQueue(SessionQueueBase):
"""
)

self.__cursor.execute("PRAGMA table_info(session_queue)")
columns = [column[1] for column in self.__cursor.fetchall()]
if "workflow" not in columns:
self.__cursor.execute(
"""--sql
ALTER TABLE session_queue ADD COLUMN workflow TEXT;
"""
)

self.__conn.commit()
except Exception:
self.__conn.rollback()
@@ -280,8 +290,8 @@ class SqliteSessionQueue(SessionQueueBase):

self.__cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority)
VALUES (?, ?, ?, ?, ?, ?)
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
VALUES (?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)

@@ -49,7 +49,7 @@ class Edge(BaseModel):

def get_output_field(node: BaseInvocation, field: str) -> Any:
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_type())
node_outputs = get_type_hints(node_type.get_output_annotation())
node_output_field = node_outputs.get(field) or None
return node_output_field

@@ -188,7 +188,7 @@ class GraphInvocationOutput(BaseInvocationOutput):


# TODO: Fill this out and move to invocations
@invocation("graph")
@invocation("graph", version="1.0.0")
class GraphInvocation(BaseInvocation):
"""Execute a graph"""

@@ -205,29 +205,31 @@ class IterateInvocationOutput(BaseInvocationOutput):
"""Used to connect iteration outputs. Will be expanded to a specific output."""

item: Any = OutputField(
description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem
description="The item being iterated over", title="Collection Item", ui_type=UIType._CollectionItem
)
index: int = OutputField(description="The index of the item", title="Index")
total: int = OutputField(description="The total number of items", title="Total")


# TODO: Fill this out and move to invocations
@invocation("iterate", version="1.0.0")
@invocation("iterate", version="1.1.0")
class IterateInvocation(BaseInvocation):
"""Iterates over a list of items"""

collection: list[Any] = InputField(
description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection
description="The list of items to iterate over", default=[], ui_type=UIType._Collection
)
index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)

def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
"""Produces the outputs as values"""
return IterateInvocationOutput(item=self.collection[self.index])
return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection))


@invocation_output("collect_output")
class CollectInvocationOutput(BaseInvocationOutput):
collection: list[Any] = OutputField(
description="The collection of input items", title="Collection", ui_type=UIType.Collection
description="The collection of input items", title="Collection", ui_type=UIType._Collection
)


@@ -238,12 +240,12 @@ class CollectInvocation(BaseInvocation):
item: Optional[Any] = InputField(
default=None,
description="The item to collect (all inputs must be of the same type)",
ui_type=UIType.CollectionItem,
ui_type=UIType._CollectionItem,
title="Collection Item",
input=Input.Connection,
)
collection: list[Any] = InputField(
description="The collection, will be provided on execution", default_factory=list, ui_hidden=True
description="The collection, will be provided on execution", default=[], ui_hidden=True
)

def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
@@ -379,7 +381,7 @@ class Graph(BaseModel):
raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph")

# output fields are not on the node object directly, they are on the output type
if edge.source.field not in source_node.get_output_type().model_fields:
if edge.source.field not in source_node.get_output_annotation().model_fields:
raise NodeFieldNotFoundError(
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
)

@@ -1,48 +0,0 @@
import sqlite3
import threading
from logging import Logger

from invokeai.app.services.config import InvokeAIAppConfig

sqlite_memory = ":memory:"


class SqliteDatabase:
conn: sqlite3.Connection
lock: threading.RLock
_logger: Logger
_config: InvokeAIAppConfig

def __init__(self, config: InvokeAIAppConfig, logger: Logger):
self._logger = logger
self._config = config

if self._config.use_memory_db:
location = sqlite_memory
logger.info("Using in-memory database")
else:
db_path = self._config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
location = str(db_path)
self._logger.info(f"Using database at {location}")

self.conn = sqlite3.connect(location, check_same_thread=False)
self.lock = threading.RLock()
self.conn.row_factory = sqlite3.Row

if self._config.log_sql:
self.conn.set_trace_callback(self._logger.debug)

self.conn.execute("PRAGMA foreign_keys = ON;")

def clean(self) -> None:
try:
self.lock.acquire()
self.conn.execute("VACUUM;")
self.conn.commit()
self._logger.info("Cleaned database")
except Exception as e:
self._logger.error(f"Error cleaning database: {e}")
raise e
finally:
self.lock.release()
invokeai/app/services/shared/sqlite/sqlite_common.py (new file)
@@ -0,0 +1,10 @@
from enum import Enum

from invokeai.app.util.metaenum import MetaEnum

sqlite_memory = ":memory:"


class SQLiteDirection(str, Enum, metaclass=MetaEnum):
Ascending = "ASC"
Descending = "DESC"
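A usage sketch for the new enum, assuming only what the file above shows: because `SQLiteDirection` mixes in `str`, its members interpolate cleanly into SQL that the app composes itself (table and column are illustrative):

```python
# Illustrative only: SQLiteDirection values are plain strings ("ASC"/"DESC"),
# so they compose safely into app-built ORDER BY clauses.
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection

direction = SQLiteDirection.Descending
query = f"SELECT image_name FROM images ORDER BY created_at {direction.value};"
print(query)  # SELECT image_name FROM images ORDER BY created_at DESC;
```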
invokeai/app/services/shared/sqlite/sqlite_database.py (new file)
@@ -0,0 +1,47 @@
import sqlite3
import threading
from logging import Logger
from pathlib import Path

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory


class SqliteDatabase:
def __init__(self, config: InvokeAIAppConfig, logger: Logger):
self._logger = logger
self._config = config

if self._config.use_memory_db:
self.db_path = sqlite_memory
logger.info("Using in-memory database")
else:
db_path = self._config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
self.db_path = str(db_path)
self._logger.info(f"Using database at {self.db_path}")

self.conn = sqlite3.connect(self.db_path, check_same_thread=False)
self.lock = threading.RLock()
self.conn.row_factory = sqlite3.Row

if self._config.log_sql:
self.conn.set_trace_callback(self._logger.debug)

self.conn.execute("PRAGMA foreign_keys = ON;")

def clean(self) -> None:
with self.lock:
try:
if self.db_path == sqlite_memory:
return
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self._logger.error(f"Error cleaning database: {e}")
raise
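A hedged construction sketch for the relocated `SqliteDatabase`; the wiring mirrors the constructor above, and a stdlib logger suffices since only a `logging.Logger` is required:

```python
# Minimal wiring sketch: build the relocated SqliteDatabase with an in-memory
# config and run the new VACUUM-based clean().
import logging

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase

config = InvokeAIAppConfig.get_config()
config.use_memory_db = True  # avoids touching a real invokeai.db
db = SqliteDatabase(config, logging.getLogger("sqlite_demo"))
db.clean()                   # no-op for :memory:, per the guard above
```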
@@ -1,23 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional


class WorkflowImageRecordsStorageBase(ABC):
"""Abstract base class for the one-to-many workflow-image relationship record storage."""

@abstractmethod
def create(
self,
workflow_id: str,
image_name: str,
) -> None:
"""Creates a workflow-image record."""
pass

@abstractmethod
def get_workflow_for_image(
self,
image_name: str,
) -> Optional[str]:
"""Gets an image's workflow id, if it has one."""
pass
@@ -1,122 +0,0 @@
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, cast
|
||||
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
from invokeai.app.services.workflow_image_records.workflow_image_records_base import WorkflowImageRecordsStorageBase


class SqliteWorkflowImageRecordsStorage(WorkflowImageRecordsStorageBase):
    """SQLite implementation of WorkflowImageRecordsStorageBase."""

    _conn: sqlite3.Connection
    _cursor: sqlite3.Cursor
    _lock: threading.RLock

    def __init__(self, db: SqliteDatabase) -> None:
        super().__init__()
        self._lock = db.lock
        self._conn = db.conn
        self._cursor = self._conn.cursor()

        try:
            self._lock.acquire()
            self._create_tables()
            self._conn.commit()
        finally:
            self._lock.release()

    def _create_tables(self) -> None:
        # Create the `workflow_images` junction table.
        self._cursor.execute(
            """--sql
            CREATE TABLE IF NOT EXISTS workflow_images (
                workflow_id TEXT NOT NULL,
                image_name TEXT NOT NULL,
                created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- updated via trigger
                updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
                -- Soft delete, currently unused
                deleted_at DATETIME,
                -- enforce one-to-many relationship between workflows and images using PK
                -- (we can extend this to many-to-many later)
                PRIMARY KEY (image_name),
                FOREIGN KEY (workflow_id) REFERENCES workflows (workflow_id) ON DELETE CASCADE,
                FOREIGN KEY (image_name) REFERENCES images (image_name) ON DELETE CASCADE
            );
            """
        )

        # Add index for workflow id
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id ON workflow_images (workflow_id);
            """
        )

        # Add index for workflow id, sorted by created_at
        self._cursor.execute(
            """--sql
            CREATE INDEX IF NOT EXISTS idx_workflow_images_workflow_id_created_at ON workflow_images (workflow_id, created_at);
            """
        )

        # Add trigger for `updated_at`.
        self._cursor.execute(
            """--sql
            CREATE TRIGGER IF NOT EXISTS tg_workflow_images_updated_at
            AFTER UPDATE
            ON workflow_images FOR EACH ROW
            BEGIN
                UPDATE workflow_images SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
                    WHERE workflow_id = old.workflow_id AND image_name = old.image_name;
            END;
            """
        )

    def create(
        self,
        workflow_id: str,
        image_name: str,
    ) -> None:
        """Creates a workflow-image record."""
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                INSERT INTO workflow_images (workflow_id, image_name)
                VALUES (?, ?);
                """,
                (workflow_id, image_name),
            )
            self._conn.commit()
        except sqlite3.Error as e:
            self._conn.rollback()
            raise e
        finally:
            self._lock.release()

    def get_workflow_for_image(
        self,
        image_name: str,
    ) -> Optional[str]:
        """Gets an image's workflow id, if it has one."""
        try:
            self._lock.acquire()
            self._cursor.execute(
                """--sql
                SELECT workflow_id
                FROM workflow_images
                WHERE image_name = ?;
                """,
                (image_name,),
            )
            result = self._cursor.fetchone()
            if result is None:
                return None
            return cast(str, result[0])
        except sqlite3.Error as e:
            self._conn.rollback()
            raise e
        finally:
            self._lock.release()
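
A self-contained sketch of the junction-table semantics: the PRIMARY KEY on image_name is what enforces the one-to-many relationship (an image may belong to at most one workflow). The schema here is simplified to the two key columns:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE workflow_images (workflow_id TEXT NOT NULL, image_name TEXT NOT NULL, PRIMARY KEY (image_name));")
conn.execute("INSERT INTO workflow_images VALUES ('wf-1', 'cat.png');")
try:
    # A second workflow claiming the same image violates the primary key.
    conn.execute("INSERT INTO workflow_images VALUES ('wf-2', 'cat.png');")
except sqlite3.IntegrityError:
    print("an image may only be linked to a single workflow")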
@@ -0,0 +1,17 @@
# Default Workflows

Workflows placed in this directory will be synced to the `workflow_library` as
_default workflows_ on app startup.

- Default workflows are not editable by users. If they are loaded and saved,
  they will be saved as a copy of the default workflow.
- Default workflows must have the `meta.category` property set to `"default"`
  (see the sketch after this README). An exception will be raised during sync
  if this is not set correctly.
- Default workflows appear on the "Default Workflows" tab of the Workflow
  Library.

After adding or updating default workflows, you **must** start the app up and
load them to ensure:

- The workflow loads without warning or errors
- The workflow runs successfully
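
For illustration, a minimal Python dict with the shape such a file needs. Every value here is hypothetical except `meta`, and the required field set mirrors `WorkflowWithoutID` further down this page:

default_workflow = {
    "name": "My Default Workflow",      # illustrative placeholder
    "author": "Example Author",
    "description": "Illustrative only",
    "version": "1.0.0",
    "contact": "",
    "tags": "default",
    "notes": "",
    "exposedFields": [],
    "meta": {"category": "default", "version": "2.0.0"},  # category must be "default"
    "nodes": [],
    "edges": [],
}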
@@ -0,0 +1,798 @@
|
||||
{
|
||||
"name": "Text to Image - SD1.5",
|
||||
"author": "InvokeAI",
|
||||
"description": "Sample text to image workflow for Stable Diffusion 1.5/2",
|
||||
"version": "1.1.0",
|
||||
"contact": "invoke@invoke.ai",
|
||||
"tags": "text2image, SD1.5, SD2, default",
|
||||
"notes": "",
|
||||
"exposedFields": [
|
||||
{
|
||||
"nodeId": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
|
||||
"fieldName": "model"
|
||||
},
|
||||
{
|
||||
"nodeId": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
|
||||
"fieldName": "prompt"
|
||||
},
|
||||
{
|
||||
"nodeId": "55705012-79b9-4aac-9f26-c0b10309785b",
|
||||
"fieldName": "width"
|
||||
},
|
||||
{
|
||||
"nodeId": "55705012-79b9-4aac-9f26-c0b10309785b",
|
||||
"fieldName": "height"
|
||||
}
|
||||
],
|
||||
"meta": {
|
||||
"category": "default",
|
||||
"version": "2.0.0"
|
||||
},
|
||||
"nodes": [
|
||||
{
|
||||
"id": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
|
||||
"type": "compel",
|
||||
"label": "Negative Compel Prompt",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"prompt": {
|
||||
"id": "7739aff6-26cb-4016-8897-5a1fb2305e4e",
|
||||
"name": "prompt",
|
||||
"fieldKind": "input",
|
||||
"label": "Negative Prompt",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "StringField"
|
||||
},
|
||||
"value": ""
|
||||
},
|
||||
"clip": {
|
||||
"id": "48d23dce-a6ae-472a-9f8c-22a714ea5ce0",
|
||||
"name": "clip",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ClipField"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"conditioning": {
|
||||
"id": "37cf3a9d-f6b7-4b64-8ff6-2558c5ecc447",
|
||||
"name": "conditioning",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 259,
|
||||
"position": {
|
||||
"x": 1000,
|
||||
"y": 350
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "55705012-79b9-4aac-9f26-c0b10309785b",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "55705012-79b9-4aac-9f26-c0b10309785b",
|
||||
"type": "noise",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.1",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"seed": {
|
||||
"id": "6431737c-918a-425d-a3b4-5d57e2f35d4d",
|
||||
"name": "seed",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"width": {
|
||||
"id": "38fc5b66-fe6e-47c8-bba9-daf58e454ed7",
|
||||
"name": "width",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 512
|
||||
},
|
||||
"height": {
|
||||
"id": "16298330-e2bf-4872-a514-d6923df53cbb",
|
||||
"name": "height",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 512
|
||||
},
|
||||
"use_cpu": {
|
||||
"id": "c7c436d3-7a7a-4e76-91e4-c6deb271623c",
|
||||
"name": "use_cpu",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "BooleanField"
|
||||
},
|
||||
"value": true
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"noise": {
|
||||
"id": "50f650dc-0184-4e23-a927-0497a96fe954",
|
||||
"name": "noise",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"width": {
|
||||
"id": "bb8a452b-133d-42d1-ae4a-3843d7e4109a",
|
||||
"name": "width",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
},
|
||||
"height": {
|
||||
"id": "35cfaa12-3b8b-4b7a-a884-327ff3abddd9",
|
||||
"name": "height",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 388,
|
||||
"position": {
|
||||
"x": 600,
|
||||
"y": 325
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
|
||||
"type": "main_model_loader",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"model": {
|
||||
"id": "993eabd2-40fd-44fe-bce7-5d0c7075ddab",
|
||||
"name": "model",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "MainModelField"
|
||||
},
|
||||
"value": {
|
||||
"model_name": "stable-diffusion-v1-5",
|
||||
"base_model": "sd-1",
|
||||
"model_type": "main"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"unet": {
|
||||
"id": "5c18c9db-328d-46d0-8cb9-143391c410be",
|
||||
"name": "unet",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "UNetField"
|
||||
}
|
||||
},
|
||||
"clip": {
|
||||
"id": "6effcac0-ec2f-4bf5-a49e-a2c29cf921f4",
|
||||
"name": "clip",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ClipField"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"id": "57683ba3-f5f5-4f58-b9a2-4b83dacad4a1",
|
||||
"name": "vae",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "VaeField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 226,
|
||||
"position": {
|
||||
"x": 600,
|
||||
"y": 25
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
|
||||
"type": "compel",
|
||||
"label": "Positive Compel Prompt",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.0.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"prompt": {
|
||||
"id": "7739aff6-26cb-4016-8897-5a1fb2305e4e",
|
||||
"name": "prompt",
|
||||
"fieldKind": "input",
|
||||
"label": "Positive Prompt",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "StringField"
|
||||
},
|
||||
"value": "Super cute tiger cub, national geographic award-winning photograph"
|
||||
},
|
||||
"clip": {
|
||||
"id": "48d23dce-a6ae-472a-9f8c-22a714ea5ce0",
|
||||
"name": "clip",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ClipField"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"conditioning": {
|
||||
"id": "37cf3a9d-f6b7-4b64-8ff6-2558c5ecc447",
|
||||
"name": "conditioning",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 259,
|
||||
"position": {
|
||||
"x": 1000,
|
||||
"y": 25
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
|
||||
"type": "rand_int",
|
||||
"label": "Random Seed",
|
||||
"isOpen": false,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": false,
|
||||
"version": "1.0.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"low": {
|
||||
"id": "3ec65a37-60ba-4b6c-a0b2-553dd7a84b84",
|
||||
"name": "low",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"high": {
|
||||
"id": "085f853a-1a5f-494d-8bec-e4ba29a3f2d1",
|
||||
"name": "high",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 2147483647
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"value": {
|
||||
"id": "812ade4d-7699-4261-b9fc-a6c9d2ab55ee",
|
||||
"name": "value",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 32,
|
||||
"position": {
|
||||
"x": 600,
|
||||
"y": 275
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
|
||||
"type": "denoise_latents",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": true,
|
||||
"useCache": true,
|
||||
"version": "1.5.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"positive_conditioning": {
|
||||
"id": "90b7f4f8-ada7-4028-8100-d2e54f192052",
|
||||
"name": "positive_conditioning",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
},
|
||||
"negative_conditioning": {
|
||||
"id": "9393779e-796c-4f64-b740-902a1177bf53",
|
||||
"name": "negative_conditioning",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ConditioningField"
|
||||
}
|
||||
},
|
||||
"noise": {
|
||||
"id": "8e17f1e5-4f98-40b1-b7f4-86aeeb4554c1",
|
||||
"name": "noise",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"steps": {
|
||||
"id": "9b63302d-6bd2-42c9-ac13-9b1afb51af88",
|
||||
"name": "steps",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
},
|
||||
"value": 50
|
||||
},
|
||||
"cfg_scale": {
|
||||
"id": "87dd04d3-870e-49e1-98bf-af003a810109",
|
||||
"name": "cfg_scale",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 7.5
|
||||
},
|
||||
"denoising_start": {
|
||||
"id": "f369d80f-4931-4740-9bcd-9f0620719fab",
|
||||
"name": "denoising_start",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"denoising_end": {
|
||||
"id": "747d10e5-6f02-445c-994c-0604d814de8c",
|
||||
"name": "denoising_end",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 1
|
||||
},
|
||||
"scheduler": {
|
||||
"id": "1de84a4e-3a24-4ec8-862b-16ce49633b9b",
|
||||
"name": "scheduler",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "SchedulerField"
|
||||
},
|
||||
"value": "unipc"
|
||||
},
|
||||
"unet": {
|
||||
"id": "ffa6fef4-3ce2-4bdb-9296-9a834849489b",
|
||||
"name": "unet",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "UNetField"
|
||||
}
|
||||
},
|
||||
"control": {
|
||||
"id": "077b64cb-34be-4fcc-83f2-e399807a02bd",
|
||||
"name": "control",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "ControlField"
|
||||
}
|
||||
},
|
||||
"ip_adapter": {
|
||||
"id": "1d6948f7-3a65-4a65-a20c-768b287251aa",
|
||||
"name": "ip_adapter",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "IPAdapterField"
|
||||
}
|
||||
},
|
||||
"t2i_adapter": {
|
||||
"id": "75e67b09-952f-4083-aaf4-6b804d690412",
|
||||
"name": "t2i_adapter",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": true,
|
||||
"name": "T2IAdapterField"
|
||||
}
|
||||
},
|
||||
"cfg_rescale_multiplier": {
|
||||
"id": "9101f0a6-5fe0-4826-b7b3-47e5d506826c",
|
||||
"name": "cfg_rescale_multiplier",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "FloatField"
|
||||
},
|
||||
"value": 0
|
||||
},
|
||||
"latents": {
|
||||
"id": "334d4ba3-5a99-4195-82c5-86fb3f4f7d43",
|
||||
"name": "latents",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"denoise_mask": {
|
||||
"id": "0d3dbdbf-b014-4e95-8b18-ff2ff9cb0bfa",
|
||||
"name": "denoise_mask",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "DenoiseMaskField"
|
||||
}
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"latents": {
|
||||
"id": "70fa5bbc-0c38-41bb-861a-74d6d78d2f38",
|
||||
"name": "latents",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"width": {
|
||||
"id": "98ee0e6c-82aa-4e8f-8be5-dc5f00ee47f0",
|
||||
"name": "width",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
},
|
||||
"height": {
|
||||
"id": "e8cb184a-5e1a-47c8-9695-4b8979564f5d",
|
||||
"name": "height",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 703,
|
||||
"position": {
|
||||
"x": 1400,
|
||||
"y": 25
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
|
||||
"type": "invocation",
|
||||
"data": {
|
||||
"id": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
|
||||
"type": "l2i",
|
||||
"label": "",
|
||||
"isOpen": true,
|
||||
"notes": "",
|
||||
"isIntermediate": false,
|
||||
"useCache": true,
|
||||
"version": "1.2.0",
|
||||
"nodePack": "invokeai",
|
||||
"inputs": {
|
||||
"metadata": {
|
||||
"id": "ab375f12-0042-4410-9182-29e30db82c85",
|
||||
"name": "metadata",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "MetadataField"
|
||||
}
|
||||
},
|
||||
"latents": {
|
||||
"id": "3a7e7efd-bff5-47d7-9d48-615127afee78",
|
||||
"name": "latents",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "LatentsField"
|
||||
}
|
||||
},
|
||||
"vae": {
|
||||
"id": "a1f5f7a1-0795-4d58-b036-7820c0b0ef2b",
|
||||
"name": "vae",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "VaeField"
|
||||
}
|
||||
},
|
||||
"tiled": {
|
||||
"id": "da52059a-0cee-4668-942f-519aa794d739",
|
||||
"name": "tiled",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "BooleanField"
|
||||
},
|
||||
"value": false
|
||||
},
|
||||
"fp32": {
|
||||
"id": "c4841df3-b24e-4140-be3b-ccd454c2522c",
|
||||
"name": "fp32",
|
||||
"fieldKind": "input",
|
||||
"label": "",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "BooleanField"
|
||||
},
|
||||
"value": true
|
||||
}
|
||||
},
|
||||
"outputs": {
|
||||
"image": {
|
||||
"id": "72d667d0-cf85-459d-abf2-28bd8b823fe7",
|
||||
"name": "image",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "ImageField"
|
||||
}
|
||||
},
|
||||
"width": {
|
||||
"id": "c8c907d8-1066-49d1-b9a6-83bdcd53addc",
|
||||
"name": "width",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
},
|
||||
"height": {
|
||||
"id": "230f359c-b4ea-436c-b372-332d7dcdca85",
|
||||
"name": "height",
|
||||
"fieldKind": "output",
|
||||
"type": {
|
||||
"isCollection": false,
|
||||
"isCollectionOrScalar": false,
|
||||
"name": "IntegerField"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"width": 320,
|
||||
"height": 266,
|
||||
"position": {
|
||||
"x": 1800,
|
||||
"y": 25
|
||||
}
|
||||
}
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"id": "reactflow__edge-ea94bc37-d995-4a83-aa99-4af42479f2f2value-55705012-79b9-4aac-9f26-c0b10309785bseed",
|
||||
"source": "ea94bc37-d995-4a83-aa99-4af42479f2f2",
|
||||
"target": "55705012-79b9-4aac-9f26-c0b10309785b",
|
||||
"type": "default",
|
||||
"sourceHandle": "value",
|
||||
"targetHandle": "seed"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8clip-7d8bf987-284f-413a-b2fd-d825445a5d6cclip",
|
||||
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
|
||||
"target": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
|
||||
"type": "default",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8clip-93dc02a4-d05b-48ed-b99c-c9b616af3402clip",
|
||||
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
|
||||
"target": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
|
||||
"type": "default",
|
||||
"sourceHandle": "clip",
|
||||
"targetHandle": "clip"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-55705012-79b9-4aac-9f26-c0b10309785bnoise-eea2702a-19fb-45b5-9d75-56b4211ec03cnoise",
|
||||
"source": "55705012-79b9-4aac-9f26-c0b10309785b",
|
||||
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
|
||||
"type": "default",
|
||||
"sourceHandle": "noise",
|
||||
"targetHandle": "noise"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-7d8bf987-284f-413a-b2fd-d825445a5d6cconditioning-eea2702a-19fb-45b5-9d75-56b4211ec03cpositive_conditioning",
|
||||
"source": "7d8bf987-284f-413a-b2fd-d825445a5d6c",
|
||||
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
|
||||
"type": "default",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "positive_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-93dc02a4-d05b-48ed-b99c-c9b616af3402conditioning-eea2702a-19fb-45b5-9d75-56b4211ec03cnegative_conditioning",
|
||||
"source": "93dc02a4-d05b-48ed-b99c-c9b616af3402",
|
||||
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
|
||||
"type": "default",
|
||||
"sourceHandle": "conditioning",
|
||||
"targetHandle": "negative_conditioning"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8unet-eea2702a-19fb-45b5-9d75-56b4211ec03cunet",
|
||||
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
|
||||
"target": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
|
||||
"type": "default",
|
||||
"sourceHandle": "unet",
|
||||
"targetHandle": "unet"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-eea2702a-19fb-45b5-9d75-56b4211ec03clatents-58c957f5-0d01-41fc-a803-b2bbf0413d4flatents",
|
||||
"source": "eea2702a-19fb-45b5-9d75-56b4211ec03c",
|
||||
"target": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
|
||||
"type": "default",
|
||||
"sourceHandle": "latents",
|
||||
"targetHandle": "latents"
|
||||
},
|
||||
{
|
||||
"id": "reactflow__edge-c8d55139-f380-4695-b7f2-8b3d1e1e3db8vae-58c957f5-0d01-41fc-a803-b2bbf0413d4fvae",
|
||||
"source": "c8d55139-f380-4695-b7f2-8b3d1e1e3db8",
|
||||
"target": "58c957f5-0d01-41fc-a803-b2bbf0413d4f",
|
||||
"type": "default",
|
||||
"sourceHandle": "vae",
|
||||
"targetHandle": "vae"
|
||||
}
|
||||
]
|
||||
}
File diff suppressed because it is too large
@@ -1,17 +1,50 @@
from abc import ABC, abstractmethod
from typing import Optional

from invokeai.app.invocations.baseinvocation import WorkflowField
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.workflow_records.workflow_records_common import (
    Workflow,
    WorkflowCategory,
    WorkflowRecordDTO,
    WorkflowRecordListItemDTO,
    WorkflowRecordOrderBy,
    WorkflowWithoutID,
)


class WorkflowRecordsStorageBase(ABC):
    """Base class for workflow storage services."""

    @abstractmethod
    def get(self, workflow_id: str) -> WorkflowField:
    def get(self, workflow_id: str) -> WorkflowRecordDTO:
        """Get workflow by id."""
        pass

    @abstractmethod
    def create(self, workflow: WorkflowField) -> WorkflowField:
    def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
        """Creates a workflow."""
        pass

    @abstractmethod
    def update(self, workflow: Workflow) -> WorkflowRecordDTO:
        """Updates a workflow."""
        pass

    @abstractmethod
    def delete(self, workflow_id: str) -> None:
        """Deletes a workflow."""
        pass

    @abstractmethod
    def get_many(
        self,
        page: int,
        per_page: int,
        order_by: WorkflowRecordOrderBy,
        direction: SQLiteDirection,
        category: WorkflowCategory,
        query: Optional[str],
    ) -> PaginatedResults[WorkflowRecordListItemDTO]:
        """Gets many workflows."""
        pass
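
A sketch of paging through user workflows via this interface. `SQLiteDirection`'s members are not shown in this diff, but its values must be SQL keywords because the SQLite implementation splices `direction.value` straight into ORDER BY, so "ASC" is assumed here:

def list_user_workflows(storage: WorkflowRecordsStorageBase) -> list[WorkflowRecordListItemDTO]:
    items: list[WorkflowRecordListItemDTO] = []
    page = 0
    while True:
        results = storage.get_many(
            page=page,
            per_page=10,
            order_by=WorkflowRecordOrderBy.Name,
            direction=SQLiteDirection("ASC"),  # assumed value, see note above
            category=WorkflowCategory.User,
            query=None,
        )
        items.extend(results.items)
        page += 1
        if page >= results.pages:
            break
    return items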
@@ -1,2 +1,104 @@
import datetime
from enum import Enum
from typing import Any, Union

import semver
from pydantic import BaseModel, Field, JsonValue, TypeAdapter, field_validator

from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import uuid_string

__workflow_meta_version__ = semver.Version.parse("1.0.0")


class ExposedField(BaseModel):
    nodeId: str
    fieldName: str


class WorkflowNotFoundError(Exception):
    """Raised when a workflow is not found"""


class WorkflowRecordOrderBy(str, Enum, metaclass=MetaEnum):
    """The order by options for workflow records"""

    CreatedAt = "created_at"
    UpdatedAt = "updated_at"
    OpenedAt = "opened_at"
    Name = "name"


class WorkflowCategory(str, Enum, metaclass=MetaEnum):
    User = "user"
    Default = "default"
    Project = "project"


class WorkflowMeta(BaseModel):
    version: str = Field(description="The version of the workflow schema.")
    category: WorkflowCategory = Field(description="The category of the workflow (user, project, or default).")

    @field_validator("version")
    def validate_version(cls, version: str):
        try:
            semver.Version.parse(version)
            return version
        except Exception:
            raise ValueError(f"Invalid workflow meta version: {version}")

    def to_semver(self) -> semver.Version:
        return semver.Version.parse(self.version)


class WorkflowWithoutID(BaseModel):
    name: str = Field(description="The name of the workflow.")
    author: str = Field(description="The author of the workflow.")
    description: str = Field(description="The description of the workflow.")
    version: str = Field(description="The version of the workflow.")
    contact: str = Field(description="The contact of the workflow.")
    tags: str = Field(description="The tags of the workflow.")
    notes: str = Field(description="The notes of the workflow.")
    exposedFields: list[ExposedField] = Field(description="The exposed fields of the workflow.")
    meta: WorkflowMeta = Field(description="The meta of the workflow.")
    # TODO: nodes and edges are very loosely typed
    nodes: list[dict[str, JsonValue]] = Field(description="The nodes of the workflow.")
    edges: list[dict[str, JsonValue]] = Field(description="The edges of the workflow.")


WorkflowWithoutIDValidator = TypeAdapter(WorkflowWithoutID)


class Workflow(WorkflowWithoutID):
    id: str = Field(default_factory=uuid_string, description="The id of the workflow.")


WorkflowValidator = TypeAdapter(Workflow)


class WorkflowRecordDTOBase(BaseModel):
    workflow_id: str = Field(description="The id of the workflow.")
    name: str = Field(description="The name of the workflow.")
    created_at: Union[datetime.datetime, str] = Field(description="The created timestamp of the workflow.")
    updated_at: Union[datetime.datetime, str] = Field(description="The updated timestamp of the workflow.")
    opened_at: Union[datetime.datetime, str] = Field(description="The opened timestamp of the workflow.")


class WorkflowRecordDTO(WorkflowRecordDTOBase):
    workflow: Workflow = Field(description="The workflow.")

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "WorkflowRecordDTO":
        data["workflow"] = WorkflowValidator.validate_json(data.get("workflow", ""))
        return WorkflowRecordDTOValidator.validate_python(data)


WorkflowRecordDTOValidator = TypeAdapter(WorkflowRecordDTO)


class WorkflowRecordListItemDTO(WorkflowRecordDTOBase):
    description: str = Field(description="The description of the workflow.")
    category: WorkflowCategory = Field(description="The category of the workflow.")


WorkflowRecordListItemDTOValidator = TypeAdapter(WorkflowRecordListItemDTO)
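
A small self-check of the models above (all values hypothetical):

meta = WorkflowMeta(version="1.0.0", category=WorkflowCategory.Default)
assert meta.to_semver() == semver.Version.parse("1.0.0")

# from_dict expects the `workflow` column as raw JSON text, exactly as it comes out of SQLite:
workflow_json = Workflow(
    id="wf-1", name="Demo", author="", description="", version="1.0.0",
    contact="", tags="", notes="", exposedFields=[], meta=meta, nodes=[], edges=[],
).model_dump_json()
row = {
    "workflow_id": "wf-1",
    "name": "Demo",
    "created_at": "2023-11-01 00:00:00.000",
    "updated_at": "2023-11-01 00:00:00.000",
    "opened_at": "2023-11-01 00:00:00.000",
    "workflow": workflow_json,
}
dto = WorkflowRecordDTO.from_dict(row)
assert dto.workflow.id == "wf-1"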
@@ -1,20 +1,25 @@
|
||||
import math
import sqlite3
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import WorkflowField, WorkflowFieldValidator
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.sqlite import SqliteDatabase
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowNotFoundError
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
Workflow,
|
||||
WorkflowCategory,
|
||||
WorkflowNotFoundError,
|
||||
WorkflowRecordDTO,
|
||||
WorkflowRecordListItemDTO,
|
||||
WorkflowRecordListItemDTOValidator,
|
||||
WorkflowRecordOrderBy,
|
||||
WorkflowValidator,
|
||||
WorkflowWithoutID,
|
||||
)
|
||||
|
||||
|
||||
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
_invoker: Invoker
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.RLock
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._lock = db.lock
|
||||
@@ -24,14 +29,25 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
self._sync_default_workflows()
|
||||
|
||||
def get(self, workflow_id: str) -> WorkflowField:
|
||||
def get(self, workflow_id: str) -> WorkflowRecordDTO:
|
||||
"""Gets a workflow by ID. Updates the opened_at column."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow
|
||||
FROM workflows
|
||||
UPDATE workflow_library
|
||||
SET opened_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
|
||||
FROM workflow_library
|
||||
WHERE workflow_id = ?;
|
||||
""",
|
||||
(workflow_id,),
|
||||
@@ -39,25 +55,28 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
row = self._cursor.fetchone()
|
||||
if row is None:
|
||||
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
|
||||
return WorkflowFieldValidator.validate_json(row[0])
|
||||
return WorkflowRecordDTO.from_dict(dict(row))
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def create(self, workflow: WorkflowField) -> WorkflowField:
|
||||
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
|
||||
try:
|
||||
# workflows do not have ids until they are saved
|
||||
workflow_id = uuid_string()
|
||||
workflow.root["id"] = workflow_id
|
||||
# Only user workflows may be created by this method
|
||||
assert workflow.meta.category is WorkflowCategory.User
|
||||
workflow_with_id = WorkflowValidator.validate_python(workflow.model_dump())
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT INTO workflows(workflow)
|
||||
VALUES (?);
|
||||
INSERT OR IGNORE INTO workflow_library (
|
||||
workflow_id,
|
||||
workflow
|
||||
)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(workflow.model_dump_json(),),
|
||||
(workflow_with_id.id, workflow_with_id.model_dump_json()),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
@@ -65,35 +84,231 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(workflow_id)
|
||||
return self.get(workflow_with_id.id)
|
||||
|
||||
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
UPDATE workflow_library
|
||||
SET workflow = ?
|
||||
WHERE workflow_id = ? AND category = 'user';
|
||||
""",
|
||||
(workflow.model_dump_json(), workflow.id),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get(workflow.id)
|
||||
|
||||
def delete(self, workflow_id: str) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE from workflow_library
|
||||
WHERE workflow_id = ? AND category = 'user';
|
||||
""",
|
||||
(workflow_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return None
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
page: int,
|
||||
per_page: int,
|
||||
order_by: WorkflowRecordOrderBy,
|
||||
direction: SQLiteDirection,
|
||||
category: WorkflowCategory,
|
||||
query: Optional[str] = None,
|
||||
) -> PaginatedResults[WorkflowRecordListItemDTO]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
# sanitize!
|
||||
assert order_by in WorkflowRecordOrderBy
|
||||
assert direction in SQLiteDirection
|
||||
assert category in WorkflowCategory
|
||||
count_query = "SELECT COUNT(*) FROM workflow_library WHERE category = ?"
|
||||
main_query = """
|
||||
SELECT
|
||||
workflow_id,
|
||||
category,
|
||||
name,
|
||||
description,
|
||||
created_at,
|
||||
updated_at,
|
||||
opened_at
|
||||
FROM workflow_library
|
||||
WHERE category = ?
|
||||
"""
|
||||
main_params: list[int | str] = [category.value]
|
||||
count_params: list[int | str] = [category.value]
|
||||
stripped_query = query.strip() if query else None
|
||||
if stripped_query:
|
||||
wildcard_query = "%" + stripped_query + "%"
|
||||
main_query += " AND name LIKE ? OR description LIKE ? "
|
||||
count_query += " AND name LIKE ? OR description LIKE ?;"
|
||||
main_params.extend([wildcard_query, wildcard_query])
|
||||
count_params.extend([wildcard_query, wildcard_query])
|
||||
|
||||
main_query += f" ORDER BY {order_by.value} {direction.value} LIMIT ? OFFSET ?;"
|
||||
main_params.extend([per_page, page * per_page])
|
||||
self._cursor.execute(main_query, main_params)
|
||||
rows = self._cursor.fetchall()
|
||||
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
|
||||
|
||||
self._cursor.execute(count_query, count_params)
|
||||
total = self._cursor.fetchone()[0]
|
||||
pages = math.ceil(total / per_page)  # ceiling, so a trailing partial page is counted exactly once
|
||||
|
||||
return PaginatedResults(
|
||||
items=workflows,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
pages=pages,
|
||||
total=total,
|
||||
)
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
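
A note on the page-count arithmetic above: ceiling division counts a trailing partial page exactly once.

import math

total, per_page = 20, 10
assert math.ceil(total / per_page) == 2  # int(total / per_page) + 1 would report a phantom third page
assert math.ceil(21 / per_page) == 3     # a final partial page still counts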
|
||||
def _sync_default_workflows(self) -> None:
|
||||
"""Syncs default workflows to the database. Internal use only."""
|
||||
|
||||
"""
|
||||
An enhancement might be to only update workflows that have changed. This would require stable
|
||||
default workflow IDs, and properly incrementing the workflow version.
|
||||
|
||||
It's much simpler to just replace them all with whichever workflows are in the directory.
|
||||
|
||||
The downside is that the `updated_at` and `opened_at` timestamps for default workflows are
|
||||
meaningless, as they are overwritten every time the server starts.
|
||||
"""
|
||||
|
||||
try:
|
||||
self._lock.acquire()
|
||||
workflows: list[Workflow] = []
|
||||
workflows_dir = Path(__file__).parent / Path("default_workflows")
|
||||
workflow_paths = workflows_dir.glob("*.json")
|
||||
for path in workflow_paths:
|
||||
bytes_ = path.read_bytes()
|
||||
workflow = WorkflowValidator.validate_json(bytes_)
|
||||
workflows.append(workflow)
|
||||
# Only default workflows may be managed by this method
|
||||
assert all(w.meta.category is WorkflowCategory.Default for w in workflows)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM workflow_library
|
||||
WHERE category = 'default';
|
||||
"""
|
||||
)
|
||||
for w in workflows:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
INSERT OR REPLACE INTO workflow_library (
|
||||
workflow_id,
|
||||
workflow
|
||||
)
|
||||
VALUES (?, ?);
|
||||
""",
|
||||
(w.id, w.model_dump_json()),
|
||||
)
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS workflows (
|
||||
CREATE TABLE IF NOT EXISTS workflow_library (
|
||||
workflow_id TEXT NOT NULL PRIMARY KEY,
|
||||
workflow TEXT NOT NULL,
|
||||
workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE, -- gets implicit index
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')) -- updated via trigger
|
||||
-- updated via trigger
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- updated manually when retrieving workflow
|
||||
opened_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
-- Generated columns, needed for indexing and searching
|
||||
category TEXT GENERATED ALWAYS as (json_extract(workflow, '$.meta.category')) VIRTUAL NOT NULL,
|
||||
name TEXT GENERATED ALWAYS as (json_extract(workflow, '$.name')) VIRTUAL NOT NULL,
|
||||
description TEXT GENERATED ALWAYS as (json_extract(workflow, '$.description')) VIRTUAL NOT NULL
|
||||
);
|
||||
"""
|
||||
)
|
||||
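
The `workflow_library` DDL above leans on SQLite generated columns: `workflow_id`, `category`, `name`, and `description` are pulled out of the JSON blob with `json_extract`, so they can be indexed and searched without being stored twice. A self-contained illustration with a simplified schema (requires a SQLite build with JSON1 and generated-column support, 3.31+):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE demo_workflow_library (
        workflow TEXT NOT NULL,
        -- computed from the JSON blob on the fly; nothing extra to keep in sync
        workflow_id TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.id')) VIRTUAL NOT NULL UNIQUE,
        name TEXT GENERATED ALWAYS AS (json_extract(workflow, '$.name')) VIRTUAL NOT NULL
    );
    """
)
conn.execute("INSERT INTO demo_workflow_library (workflow) VALUES (?);", ('{"id": "wf-1", "name": "Demo"}',))
print(conn.execute("SELECT workflow_id, name FROM demo_workflow_library;").fetchone())  # ('wf-1', 'Demo')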
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_workflows_updated_at
|
||||
CREATE TRIGGER IF NOT EXISTS tg_workflow_library_updated_at
|
||||
AFTER UPDATE
|
||||
ON workflows FOR EACH ROW
|
||||
ON workflow_library FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE workflows
|
||||
UPDATE workflow_library
|
||||
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE workflow_id = old.workflow_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_library_created_at ON workflow_library(created_at);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_library_updated_at ON workflow_library(updated_at);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_library_opened_at ON workflow_library(opened_at);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_library_category ON workflow_library(category);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_library_name ON workflow_library(name);
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_workflow_library_description ON workflow_library(description);
|
||||
"""
|
||||
)
|
||||
|
||||
# We do not need the original `workflows` table or `workflow_images` junction table.
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DROP TABLE IF EXISTS workflow_images;
|
||||
"""
|
||||
)
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
DROP TABLE IF EXISTS workflows;
|
||||
"""
|
||||
)
|
||||
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
|
||||
@@ -2,6 +2,7 @@ class FieldDescriptions:
    denoising_start = "When to start denoising, expressed as a percentage of total steps"
    denoising_end = "When to stop denoising, expressed as a percentage of total steps"
    cfg_scale = "Classifier-Free Guidance scale"
    cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR"
    scheduler = "Scheduler to use during inference"
    positive_cond = "Positive conditioning tensor"
    negative_cond = "Negative conditioning tensor"
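
For context, `cfg_rescale_multiplier` refers to the guidance-rescale trick from "Common Diffusion Noise Schedules and Sample Steps are Flawed" (Lin et al., 2023). A schematic sketch of the idea, not InvokeAI's exact implementation:

import torch

def rescale_cfg(noise_pred: torch.Tensor, noise_pred_cond: torch.Tensor, multiplier: float) -> torch.Tensor:
    # Shrink the guided prediction's std back toward the conditioned prediction's std,
    # then blend rescaled and original predictions by the multiplier (0 = no rescale).
    rescaled = noise_pred * (noise_pred_cond.std() / noise_pred.std())
    return multiplier * rescaled + (1.0 - multiplier) * noise_pred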
29 invokeai/backend/image_util/realesrgan/LICENSE (Normal file)
@@ -0,0 +1,29 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2021, Xintao Wang
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
0 invokeai/backend/image_util/realesrgan/__init__.py (Normal file)
274 invokeai/backend/image_util/realesrgan/realesrgan.py (Normal file)
@@ -0,0 +1,274 @@
|
||||
import math
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from cv2.typing import MatLike
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
"""
|
||||
Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
|
||||
License is BSD3, copied to `LICENSE` in this directory.
|
||||
|
||||
The adaptation here has a few changes:
|
||||
- Remove print statements, use `tqdm` to show progress
|
||||
- Remove unused "outscale" logic, which simply scales the final image to a given factor
|
||||
- Remove `dni_weight` logic, which was only used when multiple models were used
|
||||
- Remove logic to fetch models from network
|
||||
- Add types, rename a few things
|
||||
"""
|
||||
|
||||
|
||||
class ImageMode(str, Enum):
|
||||
L = "L"
|
||||
RGB = "RGB"
|
||||
RGBA = "RGBA"
|
||||
|
||||
|
||||
class RealESRGAN:
|
||||
"""A helper class for upsampling images with RealESRGAN.
|
||||
|
||||
Args:
|
||||
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
|
||||
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
|
||||
model (nn.Module): The defined network. Default: None.
|
||||
tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
|
||||
input images into tiles, and then process each of them. Finally, they will be merged into one image.
|
||||
0 denotes for do not use tile. Default: 0.
|
||||
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
|
||||
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
|
||||
half (float): Whether to use half precision during inference. Default: False.
|
||||
"""
|
||||
|
||||
output: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scale: int,
|
||||
model_path: Path,
|
||||
model: RRDBNet,
|
||||
tile: int = 0,
|
||||
tile_pad: int = 10,
|
||||
pre_pad: int = 10,
|
||||
half: bool = False,
|
||||
) -> None:
|
||||
self.scale = scale
|
||||
self.tile_size = tile
|
||||
self.tile_pad = tile_pad
|
||||
self.pre_pad = pre_pad
|
||||
self.mod_scale: Optional[int] = None
|
||||
self.half = half
|
||||
self.device = choose_torch_device()
|
||||
|
||||
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
|
||||
|
||||
# prefer to use params_ema
|
||||
if "params_ema" in loadnet:
|
||||
keyname = "params_ema"
|
||||
else:
|
||||
keyname = "params"
|
||||
|
||||
model.load_state_dict(loadnet[keyname], strict=True)
|
||||
model.eval()
|
||||
self.model = model.to(self.device)
|
||||
|
||||
if self.half:
|
||||
self.model = self.model.half()
|
||||
|
||||
def pre_process(self, img: MatLike) -> None:
|
||||
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible"""
|
||||
img_tensor: torch.Tensor = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
||||
self.img = img_tensor.unsqueeze(0).to(self.device)
|
||||
if self.half:
|
||||
self.img = self.img.half()
|
||||
|
||||
# pre_pad
|
||||
if self.pre_pad != 0:
|
||||
self.img = torch.nn.functional.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
|
||||
# mod pad for divisible borders
|
||||
if self.scale == 2:
|
||||
self.mod_scale = 2
|
||||
elif self.scale == 1:
|
||||
self.mod_scale = 4
|
||||
if self.mod_scale is not None:
|
||||
self.mod_pad_h, self.mod_pad_w = 0, 0
|
||||
_, _, h, w = self.img.size()
|
||||
if h % self.mod_scale != 0:
|
||||
self.mod_pad_h = self.mod_scale - h % self.mod_scale
|
||||
if w % self.mod_scale != 0:
|
||||
self.mod_pad_w = self.mod_scale - w % self.mod_scale
|
||||
self.img = torch.nn.functional.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect")
|
||||
|
||||
def process(self) -> None:
|
||||
# model inference
|
||||
self.output = self.model(self.img)
|
||||
|
||||
def tile_process(self) -> None:
|
||||
"""It will first crop input images to tiles, and then process each tile.
|
||||
Finally, all the processed tiles are merged into one images.
|
||||
|
||||
Modified from: https://github.com/ata4/esrgan-launcher
|
||||
"""
|
||||
batch, channel, height, width = self.img.shape
|
||||
output_height = height * self.scale
|
||||
output_width = width * self.scale
|
||||
output_shape = (batch, channel, output_height, output_width)
|
||||
|
||||
# start with black image
|
||||
self.output = self.img.new_zeros(output_shape)
|
||||
tiles_x = math.ceil(width / self.tile_size)
|
||||
tiles_y = math.ceil(height / self.tile_size)
|
||||
|
||||
# loop over all tiles
|
||||
total_steps = tiles_y * tiles_x
|
||||
for i in tqdm(range(total_steps), desc="Upscaling"):
|
||||
y = i // tiles_x
|
||||
x = i % tiles_x
|
||||
# extract tile from input image
|
||||
ofs_x = x * self.tile_size
|
||||
ofs_y = y * self.tile_size
|
||||
# input tile area on total image
|
||||
input_start_x = ofs_x
|
||||
input_end_x = min(ofs_x + self.tile_size, width)
|
||||
input_start_y = ofs_y
|
||||
input_end_y = min(ofs_y + self.tile_size, height)
|
||||
|
||||
# input tile area on total image with padding
|
||||
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
|
||||
input_end_x_pad = min(input_end_x + self.tile_pad, width)
|
||||
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
|
||||
input_end_y_pad = min(input_end_y + self.tile_pad, height)
|
||||
|
||||
# input tile dimensions
|
||||
input_tile_width = input_end_x - input_start_x
|
||||
input_tile_height = input_end_y - input_start_y
|
||||
input_tile = self.img[
|
||||
:,
|
||||
:,
|
||||
input_start_y_pad:input_end_y_pad,
|
||||
input_start_x_pad:input_end_x_pad,
|
||||
]
|
||||
|
||||
# upscale tile
|
||||
with torch.no_grad():
|
||||
output_tile = self.model(input_tile)
|
||||
|
||||
# output tile area on total image
|
||||
output_start_x = input_start_x * self.scale
|
||||
output_end_x = input_end_x * self.scale
|
||||
output_start_y = input_start_y * self.scale
|
||||
output_end_y = input_end_y * self.scale
|
||||
|
||||
# output tile area without padding
|
||||
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
|
||||
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
|
||||
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
|
||||
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
|
||||
|
||||
# put tile into output image
|
||||
self.output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = output_tile[
|
||||
:,
|
||||
:,
|
||||
output_start_y_tile:output_end_y_tile,
|
||||
output_start_x_tile:output_end_x_tile,
|
||||
]
|
||||
|
||||
def post_process(self) -> torch.Tensor:
|
||||
# remove extra pad
|
||||
if self.mod_scale is not None:
|
||||
_, _, h, w = self.output.size()
|
||||
self.output = self.output[
|
||||
:,
|
||||
:,
|
||||
0 : h - self.mod_pad_h * self.scale,
|
||||
0 : w - self.mod_pad_w * self.scale,
|
||||
]
|
||||
# remove prepad
|
||||
if self.pre_pad != 0:
|
||||
_, _, h, w = self.output.size()
|
||||
self.output = self.output[
|
||||
:,
|
||||
:,
|
||||
0 : h - self.pre_pad * self.scale,
|
||||
0 : w - self.pre_pad * self.scale,
|
||||
]
|
||||
return self.output
|
||||
|
||||
@torch.no_grad()
|
||||
def upscale(self, img: MatLike, esrgan_alpha_upscale: bool = True) -> npt.NDArray[Any]:
|
||||
np_img = img.astype(np.float32)
|
||||
alpha: Optional[np.ndarray] = None
|
||||
if np.max(np_img) > 256:
|
||||
# 16-bit image
|
||||
max_range = 65535
|
||||
else:
|
||||
max_range = 255
|
||||
np_img = np_img / max_range
|
||||
if len(np_img.shape) == 2:
|
||||
# grayscale image
|
||||
img_mode = ImageMode.L
|
||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_GRAY2RGB)
|
||||
elif np_img.shape[2] == 4:
|
||||
# RGBA image with alpha channel
|
||||
img_mode = ImageMode.RGBA
|
||||
alpha = np_img[:, :, 3]
|
||||
np_img = np_img[:, :, 0:3]
|
||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
|
||||
if esrgan_alpha_upscale:
|
||||
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
|
||||
else:
|
||||
img_mode = ImageMode.RGB
|
||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# ------------------- process image (without the alpha channel) ------------------- #
|
||||
self.pre_process(np_img)
|
||||
if self.tile_size > 0:
|
||||
self.tile_process()
|
||||
else:
|
||||
self.process()
|
||||
output_tensor = self.post_process()
|
||||
output_img: npt.NDArray[Any] = output_tensor.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
|
||||
if img_mode is ImageMode.L:
|
||||
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# ------------------- process the alpha channel if necessary ------------------- #
|
||||
if img_mode is ImageMode.RGBA:
|
||||
if esrgan_alpha_upscale:
|
||||
assert alpha is not None
|
||||
self.pre_process(alpha)
|
||||
if self.tile_size > 0:
|
||||
self.tile_process()
|
||||
else:
|
||||
self.process()
|
||||
output_alpha_tensor = self.post_process()
|
||||
output_alpha: npt.NDArray[Any] = output_alpha_tensor.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
|
||||
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
|
||||
else: # use the cv2 resize for alpha channel
|
||||
assert alpha is not None
|
||||
h, w = alpha.shape[0:2]
|
||||
output_alpha = cv2.resize(
|
||||
alpha,
|
||||
(w * self.scale, h * self.scale),
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
|
||||
# merge the alpha channel
|
||||
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
|
||||
output_img[:, :, 3] = output_alpha
|
||||
|
||||
# ------------------------------ return ------------------------------ #
|
||||
if max_range == 65535: # 16-bit image
|
||||
output = (output_img * 65535.0).round().astype(np.uint16)
|
||||
else:
|
||||
output = (output_img * 255.0).round().astype(np.uint8)
|
||||
|
||||
return output
|
||||
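
A hedged usage sketch for the class above; the weight path is hypothetical, and the RRDBNet hyperparameters are the common x4 configuration rather than anything specified in this diff:

from pathlib import Path

import cv2
from basicsr.archs.rrdbnet_arch import RRDBNet

net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
upscaler = RealESRGAN(scale=4, model_path=Path("models/RealESRGAN_x4plus.pth"), model=net, tile=512)

bgr = cv2.imread("input.png", cv2.IMREAD_UNCHANGED)  # upscale() expects an OpenCV BGR/BGRA array
out = upscaler.upscale(bgr)
cv2.imwrite("output_x4.png", out)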
@@ -54,6 +54,44 @@ class ImageProjModel(torch.nn.Module):
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
class MLPProjModel(torch.nn.Module):
|
||||
"""SD model with image prompt"""
|
||||
|
||||
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
|
||||
super().__init__()
|
||||
|
||||
self.proj = torch.nn.Sequential(
|
||||
torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
|
||||
torch.nn.GELU(),
|
||||
torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
|
||||
torch.nn.LayerNorm(cross_attention_dim),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
|
||||
"""Initialize an MLPProjModel from a state_dict.
|
||||
|
||||
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
|
||||
|
||||
Args:
|
||||
state_dict (dict[str, torch.Tensor]): The state_dict of model weights.
|
||||
|
||||
Returns:
|
||||
MLPProjModel
|
||||
"""
|
||||
cross_attention_dim = state_dict["proj.3.weight"].shape[0]
|
||||
clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
|
||||
|
||||
model = cls(cross_attention_dim, clip_embeddings_dim)
|
||||
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
def forward(self, image_embeds):
|
||||
clip_extra_context_tokens = self.proj(image_embeds)
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
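
A quick shape check for from_state_dict (random placeholder tensors; only the shapes matter):

import torch

ced, cad = 1024, 768  # clip_embeddings_dim, cross_attention_dim
state_dict = {
    "proj.0.weight": torch.randn(ced, ced), "proj.0.bias": torch.randn(ced),
    "proj.2.weight": torch.randn(cad, ced), "proj.2.bias": torch.randn(cad),
    "proj.3.weight": torch.ones(cad), "proj.3.bias": torch.zeros(cad),  # LayerNorm params
}
mlp = MLPProjModel.from_state_dict(state_dict)
assert mlp(torch.randn(1, ced)).shape == (1, cad)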
class IPAdapter:
|
||||
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
|
||||
|
||||
@@ -130,6 +168,13 @@ class IPAdapterPlus(IPAdapter):
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
|
||||
|
||||
class IPAdapterFull(IPAdapterPlus):
|
||||
"""IP-Adapter Plus with full features."""
|
||||
|
||||
def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
|
||||
return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)
|
||||
|
||||
|
||||
class IPAdapterPlusXL(IPAdapterPlus):
|
||||
"""IP-Adapter Plus for SDXL."""
|
||||
|
||||
@@ -149,11 +194,9 @@ def build_ip_adapter(
|
||||
) -> Union[IPAdapter, IPAdapterPlus]:
|
||||
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
|
||||
|
||||
# Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it
|
||||
# contains.
|
||||
is_plus = "proj.weight" not in state_dict["image_proj"]
|
||||
|
||||
if is_plus:
|
||||
if "proj.weight" in state_dict["image_proj"]: # IPAdapter (with ImageProjModel).
|
||||
return IPAdapter(state_dict, device=device, dtype=dtype)
|
||||
elif "proj_in.weight" in state_dict["image_proj"]: # IPAdaterPlus or IPAdapterPlusXL (with Resampler).
|
||||
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||
if cross_attention_dim == 768:
|
||||
# SD1 IP-Adapter Plus
|
||||
@@ -163,5 +206,7 @@ def build_ip_adapter(
|
||||
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
|
||||
else:
|
||||
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
|
||||
elif "proj.0.weight" in state_dict["image_proj"]: # IPAdapterFull (with MLPProjModel).
|
||||
return IPAdapterFull(state_dict, device=device, dtype=dtype)
|
||||
else:
|
||||
return IPAdapter(state_dict, device=device, dtype=dtype)
|
||||
raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")
|
||||
|
||||
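
Condensed, the branches above dispatch on which projection weights the checkpoint's "image_proj" dict carries:

# image_proj key    -> projection model -> adapter class
# "proj.weight"     -> ImageProjModel   -> IPAdapter
# "proj_in.weight"  -> Resampler        -> IPAdapterPlus / IPAdapterPlusXL (chosen by cross-attention dim)
# "proj.0.weight"   -> MLPProjModel     -> IPAdapterFull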
@@ -1,12 +1,13 @@
+# ruff: noqa: I001, F401
 """
 Initialization file for invokeai.backend.model_management
 """
 # This import must be first
-from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType  # noqa: F401 isort: split
+from .model_manager import AddModelResult, ModelInfo, ModelManager, SchedulerPredictionType
+from .lora import ModelPatcher, ONNXModelPatcher
+from .model_cache import ModelCache
 
-from .lora import ModelPatcher, ONNXModelPatcher  # noqa: F401
-from .model_cache import ModelCache  # noqa: F401
-from .models import (  # noqa: F401
+from .models import (
     BaseModelType,
     DuplicateModelException,
     ModelNotFoundException,
@@ -16,4 +17,4 @@ from .models import (  # noqa: F401
 )
 
 # This import must be last
-from .model_merge import ModelMerger, MergeInterpolationMethod  # noqa: F401 isort: split
+from .model_merge import MergeInterpolationMethod, ModelMerger
@@ -192,20 +192,33 @@ class ModelPatcher:
                 trigger += f"-!pad-{i}"
             return f"<{trigger}>"
 
+        def _get_ti_embedding(model_embeddings, ti):
+            # for SDXL models, select the embedding that matches the text encoder's dimensions
+            if ti.embedding_2 is not None:
+                return (
+                    ti.embedding_2
+                    if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0]
+                    else ti.embedding
+                )
+            else:
+                return ti.embedding
+
         # modify tokenizer
         new_tokens_added = 0
         for ti_name, ti in ti_list:
-            for i in range(ti.embedding.shape[0]):
+            ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
+
+            for i in range(ti_embedding.shape[0]):
                 new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
 
         # modify text_encoder
         text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
         model_embeddings = text_encoder.get_input_embeddings()
 
-        for ti_name, ti in ti_list:
+        for ti_name, _ in ti_list:
             ti_tokens = []
-            for i in range(ti.embedding.shape[0]):
-                embedding = ti.embedding[i]
+            for i in range(ti_embedding.shape[0]):
+                embedding = ti_embedding[i]
                 trigger = _get_trigger(ti_name, i)
 
                 token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
@@ -273,6 +286,7 @@ class ModelPatcher:
 
 class TextualInversionModel:
     embedding: torch.Tensor  # [n, 768]|[n, 1280]
+    embedding_2: Optional[torch.Tensor] = None  # [n, 768]|[n, 1280] - for SDXL models
 
     @classmethod
     def from_checkpoint(
@@ -296,8 +310,8 @@ class TextualInversionModel:
         if "string_to_param" in state_dict:
             if len(state_dict["string_to_param"]) > 1:
                 print(
-                    f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first'
-                    " token will be used."
+                    f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
+                    " token will be used.",
                 )
 
             result.embedding = next(iter(state_dict["string_to_param"].values()))
@@ -306,6 +320,11 @@ class TextualInversionModel:
         elif "emb_params" in state_dict:
             result.embedding = state_dict["emb_params"]
 
+        # v5 (sdxl safetensors file)
+        elif "clip_g" in state_dict and "clip_l" in state_dict:
+            result.embedding = state_dict["clip_g"]
+            result.embedding_2 = state_dict["clip_l"]
+
         # v4 (diffusers bin files)
         else:
             result.embedding = next(iter(state_dict.values()))
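A minimal sketch of the "v5" SDXL layout this branch recognizes (shapes are illustrative; SDXL stores one tensor per text encoder):

    import torch

    # Illustrative SDXL textual-inversion state dict.
    state_dict = {
        "clip_g": torch.zeros(2, 1280),  # OpenCLIP bigG hidden size
        "clip_l": torch.zeros(2, 768),   # CLIP-L hidden size
    }
    assert "clip_g" in state_dict and "clip_l" in state_dict  # takes the v5 branch
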
@@ -342,6 +361,13 @@ class TextualInversionManager(BaseTextualInversionManager):
             if token_id in self.pad_tokens:
                 new_token_ids.extend(self.pad_tokens[token_id])
 
+        # Do not exceed the max model input size.
+        # The -2 here compensates for compel.embeddings_provider.get_token_ids(),
+        # which first removes and then adds back the start and end tokens.
+        max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
+        if len(new_token_ids) > max_length:
+            new_token_ids = new_token_ids[0:max_length]
+
         return new_token_ids
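Concretely, the CLIP tokenizer's max model input size is typically 77, so max_length works out to 75 content tokens; the two removed slots are the start/end tokens that compel strips and re-adds. A minimal sketch, assuming a HuggingFace CLIPTokenizer:

    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    max_length = list(tokenizer.max_model_input_sizes.values())[0] - 2  # 77 - 2 == 75
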
@@ -490,24 +516,31 @@ class ONNXModelPatcher:
                 trigger += f"-!pad-{i}"
             return f"<{trigger}>"
 
+        # modify text_encoder
+        orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
+
         # modify tokenizer
         new_tokens_added = 0
         for ti_name, ti in ti_list:
-            for i in range(ti.embedding.shape[0]):
-                new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
+            if ti.embedding_2 is not None:
+                ti_embedding = (
+                    ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
+                )
+            else:
+                ti_embedding = ti.embedding
 
-        # modify text_encoder
-        orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
+            for i in range(ti_embedding.shape[0]):
+                new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
 
         embeddings = np.concatenate(
             (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
             axis=0,
         )
 
-        for ti_name, ti in ti_list:
+        for ti_name, _ in ti_list:
             ti_tokens = []
-            for i in range(ti.embedding.shape[0]):
-                embedding = ti.embedding[i].detach().numpy()
+            for i in range(ti_embedding.shape[0]):
+                embedding = ti_embedding[i].detach().numpy()
                 trigger = _get_trigger(ti_name, i)
 
                 token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
@@ -53,6 +53,7 @@ class ModelProbe(object):
         "StableDiffusionXLPipeline": ModelType.Main,
         "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
         "StableDiffusionXLInpaintPipeline": ModelType.Main,
+        "LatentConsistencyModelPipeline": ModelType.Main,
         "AutoencoderKL": ModelType.Vae,
         "AutoencoderTiny": ModelType.Vae,
         "ControlNetModel": ModelType.ControlNet,
@@ -224,7 +225,7 @@ class ModelProbe(object):
         with SilenceWarnings():
             if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
                 cls._scan_model(model_path, model_path)
-                return torch.load(model_path)
+                return torch.load(model_path, map_location="cpu")
             else:
                 return safetensors.torch.load_file(model_path)
 
@@ -372,12 +373,16 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
             token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
         elif "emb_params" in checkpoint:
             token_dim = checkpoint["emb_params"].shape[-1]
+        elif "clip_g" in checkpoint:
+            token_dim = checkpoint["clip_g"].shape[-1]
         else:
             token_dim = list(checkpoint.values())[0].shape[0]
 
         if token_dim == 768:
             return BaseModelType.StableDiffusion1
         elif token_dim == 1024:
             return BaseModelType.StableDiffusion2
+        elif token_dim == 1280:
+            return BaseModelType.StableDiffusionXL
         else:
             return None
 
@@ -11,7 +11,7 @@ from invokeai.app.services.model_records import (
     DuplicateModelException,
     ModelRecordServiceSQL,
 )
-from invokeai.app.services.shared.sqlite import SqliteDatabase
+from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
 from invokeai.backend.model_manager.config import (
     AnyModelConfig,
     BaseModelType,
@@ -607,11 +607,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if isinstance(guidance_scale, list):
             guidance_scale = guidance_scale[step_index]
 
-        noise_pred = self.invokeai_diffuser._combine(
-            uc_noise_pred,
-            c_noise_pred,
-            guidance_scale,
-        )
+        noise_pred = self.invokeai_diffuser._combine(uc_noise_pred, c_noise_pred, guidance_scale)
+        guidance_rescale_multiplier = conditioning_data.guidance_rescale_multiplier
+        if guidance_rescale_multiplier > 0:
+            noise_pred = self._rescale_cfg(
+                noise_pred,
+                c_noise_pred,
+                guidance_rescale_multiplier,
+            )
 
         # compute the previous noisy sample x_t -> x_t-1
         step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
@@ -634,6 +637,16 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
 
         return step_output
 
+    @staticmethod
+    def _rescale_cfg(total_noise_pred, pos_noise_pred, multiplier=0.7):
+        """Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
+        ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
+        ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
+
+        x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
+        x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
+        return x_final
+
     def _unet_forward(
         self,
         latents,
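In effect, the classifier-free-guidance output is rescaled so that its per-sample standard deviation matches that of the conditioned prediction, then blended back with the unrescaled output by `multiplier`. A standalone sketch of the same arithmetic on dummy tensors (values are illustrative):

    import torch

    noise_cfg = torch.randn(2, 4, 64, 64) * 3.0  # stands in for an over-saturated CFG output
    noise_pos = torch.randn(2, 4, 64, 64)        # stands in for the conditioned prediction

    std_pos = noise_pos.std(dim=(1, 2, 3), keepdim=True)
    std_cfg = noise_cfg.std(dim=(1, 2, 3), keepdim=True)
    rescaled = noise_cfg * (std_pos / std_cfg)
    out = 0.7 * rescaled + 0.3 * noise_cfg  # multiplier = 0.7
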
@@ -67,13 +67,17 @@ class IPAdapterConditioningInfo:
 class ConditioningData:
     unconditioned_embeddings: BasicConditioningInfo
     text_embeddings: BasicConditioningInfo
-    guidance_scale: Union[float, List[float]]
     """
     Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
     `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
     Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
     images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
     """
+    guidance_scale: Union[float, List[float]]
+    """For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use a guidance_rescale_multiplier of 0.7.
+    ref [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf)
+    """
+    guidance_rescale_multiplier: float = 0
     extra: Optional[ExtraConditioningInfo] = None
     scheduler_args: dict[str, Any] = field(default_factory=dict)
     """
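A minimal sketch of enabling the rescale for a ztsnr-trained model (construction of the embedding objects is assumed to have happened elsewhere):

    conditioning_data = ConditioningData(
        unconditioned_embeddings=uncond,  # BasicConditioningInfo, assumed already built
        text_embeddings=cond,             # BasicConditioningInfo, assumed already built
        guidance_scale=7.5,
        guidance_rescale_multiplier=0.7,  # suggested value for ztsnr-trained models
    )
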
 0    invokeai/backend/tiles/__init__.py    Normal file
 201  invokeai/backend/tiles/tiles.py       Normal file
@@ -0,0 +1,201 @@
import math
from typing import Union

import numpy as np

from invokeai.backend.tiles.utils import TBLR, Tile, paste


def calc_tiles_with_overlap(
    image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int = 0
) -> list[Tile]:
    """Calculate the tile coordinates for a given image shape under a simple tiling scheme with overlaps.

    Args:
        image_height (int): The image height in px.
        image_width (int): The image width in px.
        tile_height (int): The tile height in px. All tiles will have this height.
        tile_width (int): The tile width in px. All tiles will have this width.
        overlap (int, optional): The target overlap between adjacent tiles. If the tiles do not evenly cover the image
            shape, then the last row/column of tiles will overlap more than this. Defaults to 0.

    Returns:
        list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom.
    """
    assert image_height >= tile_height
    assert image_width >= tile_width
    assert overlap < tile_height
    assert overlap < tile_width

    non_overlap_per_tile_height = tile_height - overlap
    non_overlap_per_tile_width = tile_width - overlap

    num_tiles_y = math.ceil((image_height - overlap) / non_overlap_per_tile_height)
    num_tiles_x = math.ceil((image_width - overlap) / non_overlap_per_tile_width)

    # tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
    tiles: list[Tile] = []

    # Calculate tile coordinates. (Ignore overlap values for now.)
    for tile_idx_y in range(num_tiles_y):
        for tile_idx_x in range(num_tiles_x):
            tile = Tile(
                coords=TBLR(
                    top=tile_idx_y * non_overlap_per_tile_height,
                    bottom=tile_idx_y * non_overlap_per_tile_height + tile_height,
                    left=tile_idx_x * non_overlap_per_tile_width,
                    right=tile_idx_x * non_overlap_per_tile_width + tile_width,
                ),
                overlap=TBLR(top=0, bottom=0, left=0, right=0),
            )

            if tile.coords.bottom > image_height:
                # If this tile would go off the bottom of the image, shift it so that it is aligned with the bottom
                # of the image.
                tile.coords.bottom = image_height
                tile.coords.top = image_height - tile_height

            if tile.coords.right > image_width:
                # If this tile would go off the right edge of the image, shift it so that it is aligned with the
                # right edge of the image.
                tile.coords.right = image_width
                tile.coords.left = image_width - tile_width

            tiles.append(tile)

    def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
        # The upper-bound checks must be exclusive ('>='): an index equal to num_tiles_y/num_tiles_x
        # is already past the end of 'tiles'.
        if idx_y < 0 or idx_y >= num_tiles_y or idx_x < 0 or idx_x >= num_tiles_x:
            return None
        return tiles[idx_y * num_tiles_x + idx_x]

    # Iterate over tiles again and calculate overlaps.
    for tile_idx_y in range(num_tiles_y):
        for tile_idx_x in range(num_tiles_x):
            cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
            top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
            left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)

            assert cur_tile is not None

            # Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
            if top_neighbor_tile is not None:
                cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
                top_neighbor_tile.overlap.bottom = cur_tile.overlap.top

            # Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
            if left_neighbor_tile is not None:
                cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
                left_neighbor_tile.overlap.right = cur_tile.overlap.left

    return tiles

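A quick worked example of the tiling math (values chosen for illustration): for a 768x1024 image tiled with 512x512 tiles and overlap=64, the stride is 448, so num_tiles_y = ceil((768 - 64) / 448) = 2 and num_tiles_x = ceil((1024 - 64) / 448) = 3, with the last row/column shifted back so they stay inside the image:

    tiles = calc_tiles_with_overlap(image_height=768, image_width=1024, tile_height=512, tile_width=512, overlap=64)
    assert len(tiles) == 6  # 2 rows x 3 columns
    assert tiles[0].coords == TBLR(top=0, bottom=512, left=0, right=512)
    assert tiles[-1].coords == TBLR(top=256, bottom=768, left=512, right=1024)  # shifted to fit
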
def merge_tiles_with_linear_blending(
    dst_image: np.ndarray, tiles: list[Tile], tile_images: list[np.ndarray], blend_amount: int
):
    """Merge a set of image tiles into `dst_image` with linear blending between the tiles.

    We expect every tile edge to either:
    1) have an overlap of 0, because it is aligned with the image edge, or
    2) have an overlap >= blend_amount.
    If neither of these conditions is satisfied, we raise an exception.

    The linear blending is centered at the halfway point of the overlap between adjacent tiles.

    Args:
        dst_image (np.ndarray): The destination image. Shape: (H, W, C).
        tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`.
        tile_images (list[np.ndarray]): The tile images to merge into `dst_image`.
        blend_amount (int): The amount of blending (in px) between adjacent overlapping tiles.
    """
    # Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to
    # iterate over tiles left-to-right, top-to-bottom.
    tiles_and_images = list(zip(tiles, tile_images, strict=True))
    tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.left)
    tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.top)

    # Organize tiles into rows.
    tile_and_image_rows: list[list[tuple[Tile, np.ndarray]]] = []
    cur_tile_and_image_row: list[tuple[Tile, np.ndarray]] = []
    first_tile_in_cur_row, _ = tiles_and_images[0]
    for tile_and_image in tiles_and_images:
        tile, _ = tile_and_image
        if not (
            tile.coords.top == first_tile_in_cur_row.coords.top
            and tile.coords.bottom == first_tile_in_cur_row.coords.bottom
        ):
            # Store the previous row, and start a new one.
            tile_and_image_rows.append(cur_tile_and_image_row)
            cur_tile_and_image_row = []
            first_tile_in_cur_row, _ = tile_and_image

        cur_tile_and_image_row.append(tile_and_image)
    tile_and_image_rows.append(cur_tile_and_image_row)

    # Prepare 1D linear gradients for blending.
    gradient_left_x = np.linspace(start=0.0, stop=1.0, num=blend_amount)
    gradient_top_y = np.linspace(start=0.0, stop=1.0, num=blend_amount)
    # Convert shape: (blend_amount, ) -> (blend_amount, 1). The extra dimension enables the gradient to be applied
    # to a 2D image via broadcasting. Note that no additional dimension is needed on gradient_left_x for
    # broadcasting to work correctly.
    gradient_top_y = np.expand_dims(gradient_top_y, axis=1)

    for tile_and_image_row in tile_and_image_rows:
        first_tile_in_row, _ = tile_and_image_row[0]
        row_height = first_tile_in_row.coords.bottom - first_tile_in_row.coords.top
        row_image = np.zeros((row_height, dst_image.shape[1], dst_image.shape[2]), dtype=dst_image.dtype)

        # Blend the tiles in the row horizontally.
        for tile, tile_image in tile_and_image_row:
            # We expect the tiles to be ordered left-to-right. For each tile, we construct a mask that applies linear
            # blending to the left of the current tile. The inverse linear blending is automatically applied to the
            # right of the tiles that have already been pasted by the paste(...) operation.
            tile_height, tile_width, _ = tile_image.shape
            mask = np.ones(shape=(tile_height, tile_width), dtype=np.float64)

            # Left blending:
            if tile.overlap.left > 0:
                assert tile.overlap.left >= blend_amount
                # Center the blending gradient in the middle of the overlap.
                blend_start_left = tile.overlap.left // 2 - blend_amount // 2
                # The region left of the blending region is masked completely.
                mask[:, :blend_start_left] = 0.0
                # Apply the blend gradient to the mask.
                mask[:, blend_start_left : blend_start_left + blend_amount] = gradient_left_x
                # For visual debugging:
                # tile_image[:, blend_start_left : blend_start_left + blend_amount] = 0

            paste(
                dst_image=row_image,
                src_image=tile_image,
                box=TBLR(
                    top=0, bottom=tile.coords.bottom - tile.coords.top, left=tile.coords.left, right=tile.coords.right
                ),
                mask=mask,
            )

        # Blend the row into the dst_image vertically.
        # We construct a mask that applies linear blending to the top of the current row. The inverse linear blending
        # is automatically applied to the bottom of the tiles that have already been pasted by the paste(...)
        # operation.
        mask = np.ones(shape=(row_image.shape[0], row_image.shape[1]), dtype=np.float64)
        # Top blending:
        # (See comments under 'Left blending' for an explanation of the logic.)
        # We assume that the entire row has the same vertical overlaps as the first_tile_in_row.
        if first_tile_in_row.overlap.top > 0:
            assert first_tile_in_row.overlap.top >= blend_amount
            blend_start_top = first_tile_in_row.overlap.top // 2 - blend_amount // 2
            mask[:blend_start_top, :] = 0.0
            mask[blend_start_top : blend_start_top + blend_amount, :] = gradient_top_y
            # For visual debugging:
            # row_image[blend_start_top : blend_start_top + blend_amount, :] = 0
        paste(
            dst_image=dst_image,
            src_image=row_image,
            box=TBLR(
                top=first_tile_in_row.coords.top,
                bottom=first_tile_in_row.coords.bottom,
                left=0,
                right=row_image.shape[1],
            ),
            mask=mask,
        )

 47   invokeai/backend/tiles/utils.py       Normal file
@@ -0,0 +1,47 @@
from typing import Optional

import numpy as np
from pydantic import BaseModel, Field


class TBLR(BaseModel):
    top: int
    bottom: int
    left: int
    right: int

    def __eq__(self, other):
        return (
            self.top == other.top
            and self.bottom == other.bottom
            and self.left == other.left
            and self.right == other.right
        )


class Tile(BaseModel):
    coords: TBLR = Field(description="The coordinates of this tile relative to its parent image.")
    overlap: TBLR = Field(description="The amount of overlap with adjacent tiles on each side of this tile.")

    def __eq__(self, other):
        return self.coords == other.coords and self.overlap == other.overlap

def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optional[np.ndarray] = None):
    """Paste a source image into a destination image.

    Args:
        dst_image (np.ndarray): The destination image to paste into. Shape: (H, W, C).
        src_image (np.ndarray): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'.
        box (TBLR): Box defining the region in the 'dst_image' where 'src_image' will be pasted.
        mask (Optional[np.ndarray]): A mask that defines the blending between 'src_image' and 'dst_image'.
            Range: [0.0, 1.0], Shape: (H, W). The output is calculated per-pixel according to
            `src * mask + dst * (1 - mask)`.
    """
    if mask is None:
        dst_image[box.top : box.bottom, box.left : box.right] = src_image
    else:
        mask = np.expand_dims(mask, -1)
        dst_image_box = dst_image[box.top : box.bottom, box.left : box.right]
        dst_image[box.top : box.bottom, box.left : box.right] = src_image * mask + dst_image_box * (1.0 - mask)
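Putting tiles.py and utils.py together, an end-to-end sketch of the intended workflow; the per-tile "processing" here is an identity copy, standing in for something like tiled upscaling:

    import numpy as np

    image = np.random.rand(768, 1024, 3)  # float image, shape (H, W, C)
    tiles = calc_tiles_with_overlap(768, 1024, tile_height=512, tile_width=512, overlap=64)

    # "Process" each tile; a real pipeline would run a model on each crop.
    tile_images = [image[t.coords.top : t.coords.bottom, t.coords.left : t.coords.right].copy() for t in tiles]

    merged = np.zeros_like(image)
    merge_tiles_with_linear_blending(dst_image=merged, tiles=tiles, tile_images=tile_images, blend_amount=32)
    assert np.allclose(merged, image)  # identity processing reconstructs the input
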
@@ -1,8 +1,7 @@
 # Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
 
 """invokeai.backend.util.logging
 
-Logging class for InvokeAI that produces console messages
-"""
+Logging class for InvokeAI that produces console messages.
 
 Usage:
@@ -178,8 +177,8 @@ InvokeAI:
 import logging.handlers
 import socket
 import urllib.parse
-from abc import abstractmethod
 from pathlib import Path
+from typing import Any, Dict, Optional
 
 from invokeai.app.services.config import InvokeAIAppConfig
@@ -192,36 +191,36 @@ except ImportError:
 
 
 # module level functions
-def debug(msg, *args, **kwargs):
+def debug(msg: str, *args: str, **kwargs: Any) -> None:  # noqa D103
     InvokeAILogger.get_logger().debug(msg, *args, **kwargs)
 
 
-def info(msg, *args, **kwargs):
+def info(msg: str, *args: str, **kwargs: Any) -> None:  # noqa D103
     InvokeAILogger.get_logger().info(msg, *args, **kwargs)
 
 
-def warning(msg, *args, **kwargs):
+def warning(msg: str, *args: str, **kwargs: Any) -> None:  # noqa D103
     InvokeAILogger.get_logger().warning(msg, *args, **kwargs)
 
 
-def error(msg, *args, **kwargs):
+def error(msg: str, *args: str, **kwargs: Any) -> None:  # noqa D103
    InvokeAILogger.get_logger().error(msg, *args, **kwargs)
 
 
-def critical(msg, *args, **kwargs):
+def critical(msg: str, *args: str, **kwargs: Any) -> None:  # noqa D103
     InvokeAILogger.get_logger().critical(msg, *args, **kwargs)
 
 
-def log(level, msg, *args, **kwargs):
+def log(level: int, msg: str, *args: str, **kwargs: Any) -> None:  # noqa D103
     InvokeAILogger.get_logger().log(level, msg, *args, **kwargs)
 
 
-def disable(level=logging.CRITICAL):
-    InvokeAILogger.get_logger().disable(level)
+def disable(level: int = logging.CRITICAL) -> None:  # noqa D103
+    logging.disable(level)
 
 
-def basicConfig(**kwargs):
-    InvokeAILogger.get_logger().basicConfig(**kwargs)
+def basicConfig(**kwargs: Any) -> None:  # noqa D103
+    logging.basicConfig(**kwargs)
 
 
 _FACILITY_MAP = (
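For reference, the module-level helpers above all route through a shared logger registry; a minimal usage sketch:

    from invokeai.backend.util import logging as invoke_logging

    logger = invoke_logging.InvokeAILogger.get_logger("my-component")  # hypothetical name
    logger.info("model loaded")
    invoke_logging.warning("module-level helpers use the default 'InvokeAI' logger")
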
@@ -256,33 +255,25 @@ _SOCK_MAP = {
 
 
 class InvokeAIFormatter(logging.Formatter):
-    """
-    Base class for logging formatter
-    """
+    """Base class for logging formatter."""
 
-    def format(self, record):
+    def format(self, record: logging.LogRecord) -> str:  # noqa D102
         formatter = logging.Formatter(self.log_fmt(record.levelno))
         return formatter.format(record)
 
-    @abstractmethod
-    def log_fmt(self, levelno: int) -> str:
-        pass
+    def log_fmt(self, levelno: int) -> str:  # noqa D102
+        return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
 
 
 class InvokeAISyslogFormatter(InvokeAIFormatter):
-    """
-    Formatting for syslog
-    """
+    """Formatting for syslog."""
 
-    def log_fmt(self, levelno: int) -> str:
+    def log_fmt(self, levelno: int) -> str:  # noqa D102
         return "%(name)s [%(process)d] <%(levelname)s> %(message)s"
 
 
-class InvokeAILegacyLogFormatter(InvokeAIFormatter):
-    """
-    Formatting for the InvokeAI Logger (legacy version)
-    """
+class InvokeAILegacyLogFormatter(InvokeAIFormatter):  # noqa D102
+    """Formatting for the InvokeAI Logger (legacy version)."""
 
     FORMATS = {
         logging.DEBUG: "   | %(message)s",
@@ -292,23 +283,21 @@ class InvokeAILegacyLogFormatter(InvokeAIFormatter):
         logging.CRITICAL: "### %(message)s",
     }
 
-    def log_fmt(self, levelno: int) -> str:
-        return self.FORMATS.get(levelno)
+    def log_fmt(self, levelno: int) -> str:  # noqa D102
+        format = self.FORMATS.get(levelno)
+        assert format is not None
+        return format
 
 
 class InvokeAIPlainLogFormatter(InvokeAIFormatter):
-    """
-    Custom Formatting for the InvokeAI Logger (plain version)
-    """
+    """Custom Formatting for the InvokeAI Logger (plain version)."""
 
-    def log_fmt(self, levelno: int) -> str:
+    def log_fmt(self, levelno: int) -> str:  # noqa D102
         return "[%(asctime)s]::[%(name)s]::%(levelname)s --> %(message)s"
 
 
 class InvokeAIColorLogFormatter(InvokeAIFormatter):
-    """
-    Custom Formatting for the InvokeAI Logger
-    """
+    """Custom Formatting for the InvokeAI Logger."""
 
     # Color Codes
     grey = "\x1b[38;20m"
@@ -331,8 +320,10 @@ class InvokeAIColorLogFormatter(InvokeAIFormatter):
         logging.CRITICAL: bold_red + log_format + reset,
     }
 
-    def log_fmt(self, levelno: int) -> str:
-        return self.FORMATS.get(levelno)
+    def log_fmt(self, levelno: int) -> str:  # noqa D102
+        format = self.FORMATS.get(levelno)
+        assert format is not None
+        return format
 
 
 LOG_FORMATTERS = {
@@ -343,13 +334,13 @@ LOG_FORMATTERS = {
 }
 
 
-class InvokeAILogger(object):
-    loggers = {}
+class InvokeAILogger(object):  # noqa D102
+    loggers: Dict[str, logging.Logger] = {}
 
     @classmethod
     def get_logger(
         cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
-    ) -> logging.Logger:
+    ) -> logging.Logger:  # noqa D102
         if name in cls.loggers:
             logger = cls.loggers[name]
             logger.handlers.clear()
@@ -362,7 +353,7 @@ class InvokeAILogger(object):
         return cls.loggers[name]
 
     @classmethod
-    def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:
+    def get_loggers(cls, config: InvokeAIAppConfig) -> list[logging.Handler]:  # noqa D102
         handler_strs = config.log_handlers
         handlers = []
         for handler in handler_strs:
@@ -374,7 +365,7 @@ class InvokeAILogger(object):
             # http gets no custom formatter
             formatter = LOG_FORMATTERS[config.log_format]
             if handler_name == "console":
-                ch = logging.StreamHandler()
+                ch: logging.Handler = logging.StreamHandler()
                 ch.setFormatter(formatter())
                 handlers.append(ch)
@@ -393,18 +384,18 @@ class InvokeAILogger(object):
         return handlers
 
     @staticmethod
-    def _parse_syslog_args(args: str = None) -> logging.Handler:
+    def _parse_syslog_args(args: Optional[str] = None) -> logging.Handler:
         if not SYSLOG_AVAILABLE:
             raise ValueError("syslog is not available on this system")
         if not args:
             args = "/dev/log" if Path("/dev/log").exists() else "address:localhost:514"
-        syslog_args = {}
+        syslog_args: Dict[str, Any] = {}
         try:
             for a in args.split(","):
                 arg_name, *arg_value = a.split(":", 2)
                 if arg_name == "address":
-                    host, *port = arg_value
-                    port = 514 if len(port) == 0 else int(port[0])
+                    host, *port_list = arg_value
+                    port = 514 if not port_list else int(port_list[0])
                     syslog_args["address"] = (host, port)
                 elif arg_name == "facility":
                     syslog_args["facility"] = _FACILITY_MAP[arg_value[0]]
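Based on the parsing above, a syslog handler spec is a comma-separated list of name:value pairs. A sketch (hypothetical values, and assuming syslog is available and "local0" is a key in _FACILITY_MAP):

    # Remote syslog on a non-default port, with an explicit facility.
    handler = InvokeAILogger._parse_syslog_args("address:loghost:1514,facility:local0")
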
@@ -417,13 +408,13 @@ class InvokeAILogger(object):
         return logging.handlers.SysLogHandler(**syslog_args)
 
     @staticmethod
-    def _parse_file_args(args: str = None) -> logging.Handler:
+    def _parse_file_args(args: Optional[str] = None) -> logging.Handler:  # noqa D102
         if not args:
             raise ValueError("please provide filename for file logging using format 'file=/path/to/logfile.txt'")
         return logging.FileHandler(args)
 
     @staticmethod
-    def _parse_http_args(args: str = None) -> logging.Handler:
+    def _parse_http_args(args: Optional[str] = None) -> logging.Handler:  # noqa D102
         if not args:
             raise ValueError("please provide destination for http logging using format 'http=url'")
         arg_list = args.split(",")
@@ -434,12 +425,12 @@ class InvokeAILogger(object):
         path = url.path
         port = url.port or 80
 
-        syslog_args = {}
+        syslog_args: Dict[str, Any] = {}
         for a in arg_list:
             arg_name, *arg_value = a.split(":", 2)
             if arg_name == "method":
-                arg_value = arg_value[0] if len(arg_value) > 0 else "GET"
-                syslog_args[arg_name] = arg_value
+                method = arg_value[0] if len(arg_value) > 0 else "GET"
+                syslog_args[arg_name] = method
             else:  # TODO: Provide support for SSL context and credentials
                 pass
         return logging.handlers.HTTPHandler(f"{host}:{port}", path, **syslog_args)
@@ -20,10 +20,18 @@ module.exports = {
     ecmaVersion: 2018,
     sourceType: 'module',
   },
-  plugins: ['react', '@typescript-eslint', 'eslint-plugin-react-hooks'],
+  plugins: [
+    'react',
+    '@typescript-eslint',
+    'eslint-plugin-react-hooks',
+    'i18next',
+    'path',
+  ],
   root: true,
   rules: {
+    'path/no-relative-imports': ['error', { maxDepth: 0 }],
     curly: 'error',
+    'i18next/no-literal-string': 2,
     'react/jsx-no-bind': ['error', { allowBind: true }],
     'react/jsx-curly-brace-presence': [
       'error',
@@ -1,5 +1,6 @@
|
||||
dist/
|
||||
public/locales/*.json
|
||||
!public/locales/en.json
|
||||
.husky/
|
||||
node_modules/
|
||||
patches/
|
||||
@@ -9,6 +10,5 @@ index.html
|
||||
.yalc/
|
||||
*.scss
|
||||
src/services/api/schema.d.ts
|
||||
docs/
|
||||
static/
|
||||
src/theme/css/overlayscrollbars.css
|
||||
|
||||
 171  invokeai/frontend/web/dist/assets/App-6440ab3b.js  (vendored, new file; diff suppressed because one or more lines are too long)
 171  invokeai/frontend/web/dist/assets/App-cd19f6f7.js  (vendored; diff suppressed because one or more lines are too long)
 1    invokeai/frontend/web/dist/assets/MantineProvider-a6a1d85c.js  (vendored, new file; diff suppressed because one or more lines are too long)
 invokeai/frontend/web/dist/assets/ThemeLocaleProvider-*.js (vendored, minified): bundle diff omitted. The rebuilt chunk updates its imports from index-c553e366.js to index-f820e2e3.js and from MantineProvider-094ba0de.js to MantineProvider-a6a1d85c.js; the remaining changes are renamed minifier identifiers.
 9    invokeai/frontend/web/dist/assets/ThemeLocaleProvider-f5f9aabf.css  (vendored, new file; diff suppressed because one or more lines are too long)
 156  invokeai/frontend/web/dist/assets/index-c553e366.js  (vendored; diff suppressed because one or more lines are too long)
 157  invokeai/frontend/web/dist/assets/index-f820e2e3.js  (vendored, new file; diff suppressed because one or more lines are too long)
 BIN  invokeai/frontend/web/dist/assets/inter-cyrillic-ext-wght-normal-1c3007b8.woff2  (vendored, new file; binary file not shown)
 BIN  invokeai/frontend/web/dist/assets/inter-cyrillic-wght-normal-eba94878.woff2  (vendored, new file; binary file not shown)
 BIN  invokeai/frontend/web/dist/assets/inter-greek-ext-wght-normal-81f77e51.woff2  (vendored, new file; binary file not shown)
Some files were not shown because too many files have changed in this diff.