mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 14:58:03 -05:00
Compare commits
28 Commits
next-rebas
...
refactor/m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ffe672bc1 | ||
|
|
ed2d9ae0d9 | ||
|
|
09e7d35b55 | ||
|
|
9758082dc5 | ||
|
|
5f4ce0b118 | ||
|
|
8ac4b9b32c | ||
|
|
ec77599e79 | ||
|
|
2c1b8c0bc2 | ||
|
|
d4525e1282 | ||
|
|
b0d67ea2cc | ||
|
|
bd802d1e7a | ||
|
|
433eb73d8e | ||
|
|
b71f53ba86 | ||
|
|
68064c133a | ||
|
|
411ec1ed64 | ||
|
|
40a81c358d | ||
|
|
1d724bca4a | ||
|
|
a6508d1391 | ||
|
|
1eeca48529 | ||
|
|
79d028ecbd | ||
|
|
531d2c8fd7 | ||
|
|
37675ee4f5 | ||
|
|
26f721d0ec | ||
|
|
420f6050a6 | ||
|
|
9804cb0e67 | ||
|
|
4c5aedbcba | ||
|
|
a380d1f3b2 | ||
|
|
6b8a6e12bc |
4
.github/workflows/lint-frontend.yml
vendored
4
.github/workflows/lint-frontend.yml
vendored
@@ -36,10 +36,8 @@ jobs:
|
||||
- name: Typescript
|
||||
run: 'pnpm run lint:tsc'
|
||||
- name: Madge
|
||||
run: 'pnpm run lint:dpdm'
|
||||
run: 'pnpm run lint:madge'
|
||||
- name: ESLint
|
||||
run: 'pnpm run lint:eslint'
|
||||
- name: Prettier
|
||||
run: 'pnpm run lint:prettier'
|
||||
- name: Knip
|
||||
run: 'pnpm run lint:knip'
|
||||
|
||||
39
Makefile
39
Makefile
@@ -6,44 +6,33 @@ default: help
|
||||
help:
|
||||
@echo Developer commands:
|
||||
@echo
|
||||
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
|
||||
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
|
||||
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
|
||||
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
|
||||
@echo "test" Run the unit tests.
|
||||
@echo "frontend-install" Install the pnpm modules needed for the front end
|
||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||
@echo "installer-zip Build the installer .zip file for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
@echo "ruff Run ruff, fixing any safely-fixable errors and formatting"
|
||||
@echo "ruff-unsafe Run ruff, fixing all fixable errors and formatting"
|
||||
@echo "mypy Run mypy using the config in pyproject.toml to identify type mismatches and other coding errors"
|
||||
@echo "mypy-all Run mypy ignoring the config in pyproject.tom but still ignoring missing imports"
|
||||
@echo "frontend-build Build the frontend in order to run on localhost:9090"
|
||||
@echo "frontend-dev Run the frontend in developer mode on localhost:5173"
|
||||
@echo "installer-zip Build the installer .zip file for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
|
||||
# Runs ruff, fixing any safely-fixable errors and formatting
|
||||
ruff:
|
||||
ruff check . --fix
|
||||
ruff format .
|
||||
ruff check . --fix
|
||||
ruff format .
|
||||
|
||||
# Runs ruff, fixing all errors it can fix and formatting
|
||||
ruff-unsafe:
|
||||
ruff check . --fix --unsafe-fixes
|
||||
ruff format .
|
||||
ruff check . --fix --unsafe-fixes
|
||||
ruff format .
|
||||
|
||||
# Runs mypy, using the config in pyproject.toml
|
||||
mypy:
|
||||
mypy scripts/invokeai-web.py
|
||||
mypy scripts/invokeai-web.py
|
||||
|
||||
# Runs mypy, ignoring the config in pyproject.toml but still ignoring missing (untyped) imports
|
||||
# (many files are ignored by the config, so this is useful for checking all files)
|
||||
mypy-all:
|
||||
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
|
||||
|
||||
# Run the unit tests
|
||||
test:
|
||||
pytest ./tests
|
||||
|
||||
# Install the pnpm modules needed for the front end
|
||||
frontend-install:
|
||||
rm -rf invokeai/frontend/web/node_modules
|
||||
cd invokeai/frontend/web && pnpm install
|
||||
mypy scripts/invokeai-web.py --config-file= --ignore-missing-imports
|
||||
|
||||
# Build the frontend
|
||||
frontend-build:
|
||||
|
||||
@@ -18,8 +18,8 @@ ENV INVOKEAI_SRC=/opt/invokeai
|
||||
ENV VIRTUAL_ENV=/opt/venv/invokeai
|
||||
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
ARG TORCH_VERSION=2.1.2
|
||||
ARG TORCHVISION_VERSION=0.16.2
|
||||
ARG TORCH_VERSION=2.1.0
|
||||
ARG TORCHVISION_VERSION=0.16
|
||||
ARG GPU_DRIVER=cuda
|
||||
ARG TARGETPLATFORM="linux/amd64"
|
||||
# unused but available
|
||||
@@ -35,7 +35,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ] || [ "$GPU_DRIVER" = "cpu" ]; then \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cpu"; \
|
||||
elif [ "$GPU_DRIVER" = "rocm" ]; then \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/rocm5.6"; \
|
||||
extra_index_url_arg="--index-url https://download.pytorch.org/whl/rocm5.6"; \
|
||||
else \
|
||||
extra_index_url_arg="--extra-index-url https://download.pytorch.org/whl/cu121"; \
|
||||
fi &&\
|
||||
@@ -54,7 +54,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
if [ "$GPU_DRIVER" = "cuda" ] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then \
|
||||
pip install -e ".[xformers]"; \
|
||||
else \
|
||||
pip install $extra_index_url_arg -e "."; \
|
||||
pip install -e "."; \
|
||||
fi
|
||||
|
||||
# #### Build the Web UI ------------------------------------
|
||||
|
||||
@@ -28,7 +28,7 @@ This is done via Docker Desktop preferences
|
||||
|
||||
### Configure Invoke environment
|
||||
|
||||
1. Make a copy of `.env.sample` and name it `.env` (`cp .env.sample .env` (Mac/Linux) or `copy example.env .env` (Windows)). Make changes as necessary. Set `INVOKEAI_ROOT` to an absolute path to:
|
||||
1. Make a copy of `env.sample` and name it `.env` (`cp env.sample .env` (Mac/Linux) or `copy example.env .env` (Windows)). Make changes as necessary. Set `INVOKEAI_ROOT` to an absolute path to:
|
||||
a. the desired location of the InvokeAI runtime directory, or
|
||||
b. an existing, v3.0.0 compatible runtime directory.
|
||||
1. Execute `run.sh`
|
||||
|
||||
@@ -21,7 +21,7 @@ run() {
|
||||
printf "%s\n" "$build_args"
|
||||
fi
|
||||
|
||||
docker compose build $build_args $service_name
|
||||
docker compose build $build_args
|
||||
unset build_args
|
||||
|
||||
printf "%s\n" "starting service $service_name"
|
||||
|
||||
@@ -9,15 +9,11 @@ complex functionality.
|
||||
|
||||
## Invocations Directory
|
||||
|
||||
InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These
|
||||
can be used as examples to create your own nodes.
|
||||
InvokeAI Nodes can be found in the `invokeai/app/invocations` directory. These can be used as examples to create your own nodes.
|
||||
|
||||
New nodes should be added to a subfolder in `nodes` direction found at the root
|
||||
level of the InvokeAI installation location. Nodes added to this folder will be
|
||||
able to be used upon application startup.
|
||||
|
||||
Example `nodes` subfolder structure:
|
||||
New nodes should be added to a subfolder in `nodes` direction found at the root level of the InvokeAI installation location. Nodes added to this folder will be able to be used upon application startup.
|
||||
|
||||
Example `nodes` subfolder structure:
|
||||
```py
|
||||
├── __init__.py # Invoke-managed custom node loader
|
||||
│
|
||||
@@ -34,14 +30,14 @@ Example `nodes` subfolder structure:
|
||||
└── fancy_node.py
|
||||
```
|
||||
|
||||
Each node folder must have an `__init__.py` file that imports its nodes. Only
|
||||
nodes imported in the `__init__.py` file are loaded. See the README in the nodes
|
||||
folder for more examples:
|
||||
Each node folder must have an `__init__.py` file that imports its nodes. Only nodes imported in the `__init__.py` file are loaded.
|
||||
See the README in the nodes folder for more examples:
|
||||
|
||||
```py
|
||||
from .cool_node import CoolInvocation
|
||||
```
|
||||
|
||||
|
||||
## Creating A New Invocation
|
||||
|
||||
In order to understand the process of creating a new Invocation, let us actually
|
||||
@@ -135,6 +131,7 @@ from invokeai.app.invocations.primitives import ImageField
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
'''Resizes an image'''
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The input image")
|
||||
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||
@@ -170,6 +167,7 @@ from invokeai.app.invocations.primitives import ImageField
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
'''Resizes an image'''
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The input image")
|
||||
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||
@@ -199,6 +197,7 @@ from invokeai.app.invocations.image import ImageOutput
|
||||
class ResizeInvocation(BaseInvocation):
|
||||
'''Resizes an image'''
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The input image")
|
||||
width: int = InputField(default=512, ge=64, le=2048, description="Width of the new image")
|
||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||
@@ -230,17 +229,30 @@ class ResizeInvocation(BaseInvocation):
|
||||
height: int = InputField(default=512, ge=64, le=2048, description="Height of the new image")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# Load the input image as a PIL image
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
# Load the image using InvokeAI's predefined Image Service. Returns the PIL image.
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
# Resize the image
|
||||
# Resizing the image
|
||||
resized_image = image.resize((self.width, self.height))
|
||||
|
||||
# Save the image
|
||||
image_dto = context.images.save(image=resized_image)
|
||||
# Save the image using InvokeAI's predefined Image Service. Returns the prepared PIL image.
|
||||
output_image = context.services.images.create(
|
||||
image=resized_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
# Return an ImageOutput
|
||||
return ImageOutput.build(image_dto)
|
||||
# Returning the Image
|
||||
return ImageOutput(
|
||||
image=ImageField(
|
||||
image_name=output_image.image_name,
|
||||
),
|
||||
width=output_image.width,
|
||||
height=output_image.height,
|
||||
)
|
||||
```
|
||||
|
||||
**Note:** Do not be overwhelmed by the `ImageOutput` process. InvokeAI has a
|
||||
@@ -331,25 +343,27 @@ class ImageColorStringOutput(BaseInvocationOutput):
|
||||
|
||||
That's all there is to it.
|
||||
|
||||
<!-- TODO: DANGER - we probably do not want people to create their own field types, because this requires a lot of work on the frontend to accomodate.
|
||||
|
||||
### Custom Input Fields
|
||||
|
||||
Now that you know how to create your own Invocations, let us dive into slightly
|
||||
more advanced topics.
|
||||
|
||||
While creating your own Invocations, you might run into a scenario where the
|
||||
existing fields in InvokeAI do not meet your requirements. In such cases, you
|
||||
can create your own fields.
|
||||
existing input types in InvokeAI do not meet your requirements. In such cases,
|
||||
you can create your own input types.
|
||||
|
||||
Let us create one as an example. Let us say we want to create a color input
|
||||
field that represents a color code. But before we start on that here are some
|
||||
general good practices to keep in mind.
|
||||
|
||||
### Best Practices
|
||||
**Good Practices**
|
||||
|
||||
- There is no naming convention for input fields but we highly recommend that
|
||||
you name it something appropriate like `ColorField`.
|
||||
- It is not mandatory but it is heavily recommended to add a relevant
|
||||
`docstring` to describe your field.
|
||||
`docstring` to describe your input field.
|
||||
- Keep your field in the same file as the Invocation that it is made for or in
|
||||
another file where it is relevant.
|
||||
|
||||
@@ -364,13 +378,10 @@ class ColorField(BaseModel):
|
||||
pass
|
||||
```
|
||||
|
||||
Perfect. Now let us create the properties for our field. This is similar to how
|
||||
you created input fields for your Invocation. All the same rules apply. Let us
|
||||
create four fields representing the _red(r)_, _blue(b)_, _green(g)_ and
|
||||
_alpha(a)_ channel of the color.
|
||||
|
||||
> Technically, the properties are _also_ called fields - but in this case, it
|
||||
> refers to a `pydantic` field.
|
||||
Perfect. Now let us create our custom inputs for our field. This is exactly
|
||||
similar how you created input fields for your Invocation. All the same rules
|
||||
apply. Let us create four fields representing the _red(r)_, _blue(b)_,
|
||||
_green(g)_ and _alpha(a)_ channel of the color.
|
||||
|
||||
```python
|
||||
class ColorField(BaseModel):
|
||||
@@ -385,11 +396,25 @@ That's it. We now have a new input field type that we can use in our Invocations
|
||||
like this.
|
||||
|
||||
```python
|
||||
color: ColorField = InputField(default=ColorField(r=0, g=0, b=0, a=0), description='Background color of an image')
|
||||
color: ColorField = Field(default=ColorField(r=0, g=0, b=0, a=0), description='Background color of an image')
|
||||
```
|
||||
|
||||
### Using the custom field
|
||||
### Custom Components For Frontend
|
||||
|
||||
When you start the UI, your custom field will be automatically recognized.
|
||||
Every backend input type should have a corresponding frontend component so the
|
||||
UI knows what to render when you use a particular field type.
|
||||
|
||||
Custom fields only support connection inputs in the Workflow Editor.
|
||||
If you are using existing field types, we already have components for those. So
|
||||
you don't have to worry about creating anything new. But this might not always
|
||||
be the case. Sometimes you might want to create new field types and have the
|
||||
frontend UI deal with it in a different way.
|
||||
|
||||
This is where we venture into the world of React and Javascript and create our
|
||||
own new components for our Invocations. Do not fear the world of JS. It's
|
||||
actually pretty straightforward.
|
||||
|
||||
Let us create a new component for our custom color field we created above. When
|
||||
we use a color field, let us say we want the UI to display a color picker for
|
||||
the user to pick from rather than entering values. That is what we will build
|
||||
now.
|
||||
-->
|
||||
|
||||
@@ -94,8 +94,6 @@ A model that helps generate creative QR codes that still scan. Can also be used
|
||||
**Openpose**:
|
||||
The OpenPose control model allows for the identification of the general pose of a character by pre-processing an existing image with a clear human structure. With advanced options, Openpose can also detect the face or hands in the image.
|
||||
|
||||
*Note:* The DWPose Processor has replaced the OpenPose processor in Invoke. Workflows and generations that relied on the OpenPose Processor will need to be updated to use the DWPose Processor instead.
|
||||
|
||||
**Mediapipe Face**:
|
||||
|
||||
The MediaPipe Face identification processor is able to clearly identify facial features in order to capture vivid expressions of human faces.
|
||||
|
||||
@@ -230,13 +230,13 @@ manager, please follow these steps:
|
||||
=== "local Webserver"
|
||||
|
||||
```bash
|
||||
invokeai-web
|
||||
invokeai --web
|
||||
```
|
||||
|
||||
=== "Public Webserver"
|
||||
|
||||
```bash
|
||||
invokeai-web --host 0.0.0.0
|
||||
invokeai --web --host 0.0.0.0
|
||||
```
|
||||
|
||||
=== "CLI"
|
||||
@@ -402,4 +402,4 @@ environment variable INVOKEAI_ROOT to point to the installation directory.
|
||||
Note that if you run into problems with the Conda installation, the InvokeAI
|
||||
staff will **not** be able to help you out. Caveat Emptor!
|
||||
|
||||
[dev-chat]: https://discord.com/channels/1020123559063990373/1049495067846524939
|
||||
[dev-chat]: https://discord.com/channels/1020123559063990373/1049495067846524939
|
||||
@@ -69,7 +69,7 @@ a token and copy it, since you will need in for the next step.
|
||||
|
||||
### Setup
|
||||
|
||||
Set up your environmnent variables. In the `docker` directory, make a copy of `.env.sample` and name it `.env`. Make changes as necessary.
|
||||
Set up your environmnent variables. In the `docker` directory, make a copy of `env.sample` and name it `.env`. Make changes as necessary.
|
||||
|
||||
Any environment variables supported by InvokeAI can be set here - please see the [CONFIGURATION](../features/CONFIGURATION.md) for further detail.
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ To use a community workflow, download the the `.json` node graph file and load i
|
||||
+ [Image to Character Art Image Nodes](#image-to-character-art-image-nodes)
|
||||
+ [Image Picker](#image-picker)
|
||||
+ [Image Resize Plus](#image-resize-plus)
|
||||
+ [Latent Upscale](#latent-upscale)
|
||||
+ [Load Video Frame](#load-video-frame)
|
||||
+ [Make 3D](#make-3d)
|
||||
+ [Mask Operations](#mask-operations)
|
||||
@@ -291,13 +290,6 @@ View:
|
||||
</br><img src="https://raw.githubusercontent.com/VeyDlin/image-resize-plus-node/master/.readme/node.png" width="500" />
|
||||
|
||||
|
||||
--------------------------------
|
||||
### Latent Upscale
|
||||
|
||||
**Description:** This node uses a small (~2.4mb) model to upscale the latents used in a Stable Diffusion 1.5 or Stable Diffusion XL image generation, rather than the typical interpolation method, avoiding the traditional downsides of the latent upscale technique.
|
||||
|
||||
**Node Link:** [https://github.com/gogurtenjoyer/latent-upscale](https://github.com/gogurtenjoyer/latent-upscale)
|
||||
|
||||
--------------------------------
|
||||
### Load Video Frame
|
||||
|
||||
@@ -354,21 +346,12 @@ See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/mai
|
||||
|
||||
**Description:** A set of nodes for Metadata. Collect Metadata from within an `iterate` node & extract metadata from an image.
|
||||
|
||||
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node
|
||||
- `Metadata From Image` - Provides Metadata from an image
|
||||
- `Metadata To String` - Extracts a String value of a label from metadata
|
||||
- `Metadata To Integer` - Extracts an Integer value of a label from metadata
|
||||
- `Metadata To Float` - Extracts a Float value of a label from metadata
|
||||
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata
|
||||
- `Metadata To Bool` - Extracts Bool types from metadata
|
||||
- `Metadata To Model` - Extracts model types from metadata
|
||||
- `Metadata To SDXL Model` - Extracts SDXL model types from metadata
|
||||
- `Metadata To LoRAs` - Extracts Loras from metadata.
|
||||
- `Metadata To SDXL LoRAs` - Extracts SDXL Loras from metadata
|
||||
- `Metadata To ControlNets` - Extracts ControNets from metadata
|
||||
- `Metadata To IP-Adapters` - Extracts IP-Adapters from metadata
|
||||
- `Metadata To T2I-Adapters` - Extracts T2I-Adapters from metadata
|
||||
- `Denoise Latents + Metadata` - This is an inherited version of the existing `Denoise Latents` node but with a metadata input and output.
|
||||
- `Metadata Item Linked` - Allows collecting of metadata while within an iterate node with no need for a collect node or conversion to metadata node.
|
||||
- `Metadata From Image` - Provides Metadata from an image.
|
||||
- `Metadata To String` - Extracts a String value of a label from metadata.
|
||||
- `Metadata To Integer` - Extracts an Integer value of a label from metadata.
|
||||
- `Metadata To Float` - Extracts a Float value of a label from metadata.
|
||||
- `Metadata To Scheduler` - Extracts a Scheduler value of a label from metadata.
|
||||
|
||||
**Node Link:** https://github.com/skunkworxdark/metadata-linked-nodes
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ their descriptions.
|
||||
| ONNX Text to Latents | Generates latents from conditionings. |
|
||||
| ONNX Model Loader | Loads a main model, outputting its submodels. |
|
||||
| OpenCV Inpaint | Simple inpaint using opencv. |
|
||||
| DW Openpose Processor | Applies Openpose processing to image |
|
||||
| Openpose Processor | Applies Openpose processing to image |
|
||||
| PIDI Processor | Applies PIDI processing to image |
|
||||
| Prompts from File | Loads prompts from a text file |
|
||||
| Random Integer | Outputs a single random integer. |
|
||||
|
||||
@@ -91,7 +91,8 @@ def choose_version(available_releases: tuple | None = None) -> str:
|
||||
complete_while_typing=True,
|
||||
completer=FuzzyWordCompleter(choices),
|
||||
)
|
||||
console.print(f" Version {choices[0] if response == '' else response} will be installed.")
|
||||
|
||||
console.print(f" Version {choices[0] if response == "" else response} will be installed.")
|
||||
|
||||
console.line()
|
||||
|
||||
|
||||
@@ -2,12 +2,8 @@
|
||||
|
||||
from logging import Logger
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.object_serializer.object_serializer_disk import ObjectSerializerDisk
|
||||
from invokeai.app.services.object_serializer.object_serializer_forward_cache import ObjectSerializerForwardCache
|
||||
from invokeai.app.services.item_storage.item_storage_memory import ItemStorageMemory
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
@@ -15,22 +11,26 @@ from ..services.board_image_records.board_image_records_sqlite import SqliteBoar
|
||||
from ..services.board_images.board_images_default import BoardImagesService
|
||||
from ..services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||
from ..services.boards.boards_default import BoardService
|
||||
from ..services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from ..services.config import InvokeAIAppConfig
|
||||
from ..services.download import DownloadQueueService
|
||||
from ..services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
||||
from ..services.images.images_default import ImageService
|
||||
from ..services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from ..services.invocation_processor.invocation_processor_default import DefaultInvocationProcessor
|
||||
from ..services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from ..services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from ..services.invoker import Invoker
|
||||
from ..services.latents_storage.latents_storage_disk import DiskLatentsStorage
|
||||
from ..services.latents_storage.latents_storage_forward_cache import ForwardCacheLatentsStorage
|
||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||
from ..services.model_metadata import ModelMetadataStoreSQL
|
||||
from ..services.model_records import ModelRecordServiceSQL
|
||||
from ..services.names.names_default import SimpleNameService
|
||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from ..services.shared.graph import GraphExecutionState
|
||||
from ..services.urls.urls_default import LocalUrlService
|
||||
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from .events import FastAPIEventService
|
||||
@@ -67,9 +67,6 @@ class ApiDependencies:
|
||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||
|
||||
output_folder = config.output_path
|
||||
if output_folder is None:
|
||||
raise ValueError("Output folder is not set")
|
||||
|
||||
image_files = DiskImageFileStorage(f"{output_folder}/images")
|
||||
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
@@ -82,16 +79,11 @@ class ApiDependencies:
|
||||
board_records = SqliteBoardRecordStorage(db=db)
|
||||
boards = BoardService()
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
bulk_download = BulkDownloadService()
|
||||
graph_execution_manager = ItemStorageMemory[GraphExecutionState]()
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
tensors = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[torch.Tensor](output_folder / "tensors", ephemeral=True)
|
||||
)
|
||||
conditioning = ObjectSerializerForwardCache(
|
||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||
)
|
||||
latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents"))
|
||||
download_queue_service = DownloadQueueService(event_bus=events)
|
||||
model_metadata_service = ModelMetadataStoreSQL(db=db)
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
@@ -102,6 +94,8 @@ class ApiDependencies:
|
||||
)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
processor = DefaultInvocationProcessor()
|
||||
queue = MemoryInvocationQueue()
|
||||
session_processor = DefaultSessionProcessor()
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
@@ -112,24 +106,25 @@ class ApiDependencies:
|
||||
board_images=board_images,
|
||||
board_records=board_records,
|
||||
boards=boards,
|
||||
bulk_download=bulk_download,
|
||||
configuration=configuration,
|
||||
events=events,
|
||||
graph_execution_manager=graph_execution_manager,
|
||||
image_files=image_files,
|
||||
image_records=image_records,
|
||||
images=images,
|
||||
invocation_cache=invocation_cache,
|
||||
latents=latents,
|
||||
logger=logger,
|
||||
model_manager=model_manager,
|
||||
download_queue=download_queue_service,
|
||||
names=names,
|
||||
performance_statistics=performance_statistics,
|
||||
processor=processor,
|
||||
queue=queue,
|
||||
session_processor=session_processor,
|
||||
session_queue=session_queue,
|
||||
urls=urls,
|
||||
workflow_records=workflow_records,
|
||||
tensors=tensors,
|
||||
conditioning=conditioning,
|
||||
)
|
||||
|
||||
ApiDependencies.invoker = Invoker(services)
|
||||
|
||||
@@ -2,13 +2,13 @@ import io
|
||||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
@@ -375,67 +375,16 @@ async def unstar_images_in_list(
|
||||
|
||||
class ImagesDownloaded(BaseModel):
|
||||
response: Optional[str] = Field(
|
||||
default=None, description="The message to display to the user when images begin downloading"
|
||||
)
|
||||
bulk_download_item_name: Optional[str] = Field(
|
||||
default=None, description="The name of the bulk download item for which events will be emitted"
|
||||
description="If defined, the message to display to the user when images begin downloading"
|
||||
)
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
|
||||
)
|
||||
@images_router.post("/download", operation_id="download_images_from_list", response_model=ImagesDownloaded)
|
||||
async def download_images_from_list(
|
||||
background_tasks: BackgroundTasks,
|
||||
image_names: Optional[list[str]] = Body(
|
||||
default=None, description="The list of names of images to download", embed=True
|
||||
),
|
||||
image_names: list[str] = Body(description="The list of names of images to download", embed=True),
|
||||
board_id: Optional[str] = Body(
|
||||
default=None, description="The board from which image should be downloaded", embed=True
|
||||
default=None, description="The board from which image should be downloaded from", embed=True
|
||||
),
|
||||
) -> ImagesDownloaded:
|
||||
if (image_names is None or len(image_names) == 0) and board_id is None:
|
||||
raise HTTPException(status_code=400, detail="No images or board id specified.")
|
||||
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
|
||||
|
||||
background_tasks.add_task(
|
||||
ApiDependencies.invoker.services.bulk_download.handler,
|
||||
image_names,
|
||||
board_id,
|
||||
bulk_download_item_id,
|
||||
)
|
||||
return ImagesDownloaded(bulk_download_item_name=bulk_download_item_id + ".zip")
|
||||
|
||||
|
||||
@images_router.api_route(
|
||||
"/download/{bulk_download_item_name}",
|
||||
methods=["GET"],
|
||||
operation_id="get_bulk_download_item",
|
||||
response_class=Response,
|
||||
responses={
|
||||
200: {
|
||||
"description": "Return the complete bulk download item",
|
||||
"content": {"application/zip": {}},
|
||||
},
|
||||
404: {"description": "Image not found"},
|
||||
},
|
||||
)
|
||||
async def get_bulk_download_item(
|
||||
background_tasks: BackgroundTasks,
|
||||
bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"),
|
||||
) -> FileResponse:
|
||||
"""Gets a bulk download zip file"""
|
||||
try:
|
||||
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
|
||||
|
||||
response = FileResponse(
|
||||
path,
|
||||
media_type="application/zip",
|
||||
filename=bulk_download_item_name,
|
||||
content_disposition_type="inline",
|
||||
)
|
||||
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
|
||||
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name)
|
||||
return response
|
||||
except Exception:
|
||||
raise HTTPException(status_code=404)
|
||||
# return ImagesDownloaded(response="Your images are downloading")
|
||||
raise HTTPException(status_code=501, detail="Endpoint is not yet implemented")
|
||||
|
||||
@@ -9,11 +9,11 @@ from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_install import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
@@ -32,7 +32,6 @@ from invokeai.backend.model_manager.config import (
|
||||
)
|
||||
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
@@ -165,27 +164,6 @@ async def list_model_records(
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/get_by_attrs",
|
||||
operation_id="get_model_records_by_attrs",
|
||||
response_model=AnyModelConfig,
|
||||
)
|
||||
async def get_model_records_by_attrs(
|
||||
name: str = Query(description="The name of the model"),
|
||||
type: ModelType = Query(description="The type of the model"),
|
||||
base: BaseModelType = Query(description="The base model of the model"),
|
||||
) -> AnyModelConfig:
|
||||
"""Gets a model by its attributes. The main use of this route is to provide backwards compatibility with the old
|
||||
model manager, which identified models by a combination of name, base and type."""
|
||||
configs = ApiDependencies.invoker.services.model_manager.store.search_by_attr(
|
||||
base_model=base, model_type=type, model_name=name
|
||||
)
|
||||
if not configs:
|
||||
raise HTTPException(status_code=404, detail="No model found with these attributes")
|
||||
|
||||
return configs[0]
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/i/{key}",
|
||||
operation_id="get_model_record",
|
||||
@@ -223,7 +201,7 @@ async def list_model_summary(
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/i/{key}/metadata",
|
||||
"/meta/i/{key}",
|
||||
operation_id="get_model_metadata",
|
||||
responses={
|
||||
200: {
|
||||
@@ -231,6 +209,7 @@ async def list_model_summary(
|
||||
"content": {"application/json": {"example": example_model_metadata}},
|
||||
},
|
||||
400: {"description": "Bad request"},
|
||||
404: {"description": "No metadata available"},
|
||||
},
|
||||
)
|
||||
async def get_model_metadata(
|
||||
@@ -239,7 +218,8 @@ async def get_model_metadata(
|
||||
"""Get a model metadata object."""
|
||||
record_store = ApiDependencies.invoker.services.model_manager.store
|
||||
result: Optional[AnyModelRepoMetadata] = record_store.get_metadata(key)
|
||||
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="No metadata for a model with this key")
|
||||
return result
|
||||
|
||||
|
||||
@@ -254,75 +234,6 @@ async def list_tags() -> Set[str]:
|
||||
return result
|
||||
|
||||
|
||||
class FoundModel(BaseModel):
|
||||
path: str = Field(description="Path to the model")
|
||||
is_installed: bool = Field(description="Whether or not the model is already installed")
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/scan_folder",
|
||||
operation_id="scan_for_models",
|
||||
responses={
|
||||
200: {"description": "Directory scanned successfully"},
|
||||
400: {"description": "Invalid directory path"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=List[FoundModel],
|
||||
)
|
||||
async def scan_for_models(
|
||||
scan_path: str = Query(description="Directory path to search for models", default=None),
|
||||
) -> List[FoundModel]:
|
||||
path = pathlib.Path(scan_path)
|
||||
if not scan_path or not path.is_dir():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"The search path '{scan_path}' does not exist or is not directory",
|
||||
)
|
||||
|
||||
search = ModelSearch()
|
||||
try:
|
||||
found_model_paths = search.search(path)
|
||||
models_path = ApiDependencies.invoker.services.configuration.models_path
|
||||
|
||||
# If the search path includes the main models directory, we need to exclude core models from the list.
|
||||
# TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed
|
||||
# without needing to crawl the filesystem.
|
||||
core_models_path = pathlib.Path(models_path, "core").resolve()
|
||||
non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
|
||||
|
||||
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
|
||||
resolved_installed_model_paths: list[str] = []
|
||||
installed_model_sources: list[str] = []
|
||||
|
||||
# This call lists all installed models.
|
||||
for model in installed_models:
|
||||
path = pathlib.Path(model.path)
|
||||
# If the model has a source, we need to add it to the list of installed sources.
|
||||
if model.source:
|
||||
installed_model_sources.append(model.source)
|
||||
# If the path is not absolute, that means it is in the app models directory, and we need to join it with
|
||||
# the models path before resolving.
|
||||
if not path.is_absolute():
|
||||
resolved_installed_model_paths.append(str(pathlib.Path(models_path, path).resolve()))
|
||||
continue
|
||||
resolved_installed_model_paths.append(str(path.resolve()))
|
||||
|
||||
scan_results: list[FoundModel] = []
|
||||
|
||||
# Check if the model is installed by comparing the resolved paths, appending to the scan result.
|
||||
for p in non_core_model_paths:
|
||||
path = str(p)
|
||||
is_installed = path in resolved_installed_model_paths or path in installed_model_sources
|
||||
found_model = FoundModel(path=path, is_installed=is_installed)
|
||||
scan_results.append(found_model)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"An error occurred while searching the directory: {e}",
|
||||
)
|
||||
return scan_results
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/tags/search",
|
||||
operation_id="search_by_metadata_tags",
|
||||
@@ -439,8 +350,8 @@ async def add_model_record(
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/install",
|
||||
operation_id="install_model",
|
||||
"/heuristic_import",
|
||||
operation_id="heuristic_import_model",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
@@ -449,13 +360,12 @@ async def add_model_record(
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def install_model(
|
||||
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
|
||||
# TODO(MM2): Can we type this?
|
||||
async def heuristic_import(
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
example={"name": "string", "description": "string"},
|
||||
example={"name": "modelT", "description": "antique cars"},
|
||||
),
|
||||
access_token: Optional[str] = None,
|
||||
) -> ModelInstallJob:
|
||||
@@ -492,7 +402,106 @@ async def install_model(
|
||||
result: ModelInstallJob = installer.heuristic_import(
|
||||
source=source,
|
||||
config=config,
|
||||
access_token=access_token,
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=424, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.post(
|
||||
"/install",
|
||||
operation_id="import_model",
|
||||
responses={
|
||||
201: {"description": "The model imported successfully"},
|
||||
415: {"description": "Unrecognized file/folder format"},
|
||||
424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
)
|
||||
async def import_model(
|
||||
source: ModelSource,
|
||||
config: Optional[Dict[str, Any]] = Body(
|
||||
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
|
||||
default=None,
|
||||
),
|
||||
) -> ModelInstallJob:
|
||||
"""Install a model using its local path, repo_id, or remote URL.
|
||||
|
||||
Models will be downloaded, probed, configured and installed in a
|
||||
series of background threads. The return object has `status` attribute
|
||||
that can be used to monitor progress.
|
||||
|
||||
The source object is a discriminated Union of LocalModelSource,
|
||||
HFModelSource and URLModelSource. Set the "type" field to the
|
||||
appropriate value:
|
||||
|
||||
* To install a local path using LocalModelSource, pass a source of form:
|
||||
```
|
||||
{
|
||||
"type": "local",
|
||||
"path": "/path/to/model",
|
||||
"inplace": false
|
||||
}
|
||||
```
|
||||
The "inplace" flag, if true, will register the model in place in its
|
||||
current filesystem location. Otherwise, the model will be copied
|
||||
into the InvokeAI models directory.
|
||||
|
||||
* To install a HuggingFace repo_id using HFModelSource, pass a source of form:
|
||||
```
|
||||
{
|
||||
"type": "hf",
|
||||
"repo_id": "stabilityai/stable-diffusion-2.0",
|
||||
"variant": "fp16",
|
||||
"subfolder": "vae",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}
|
||||
```
|
||||
The `variant`, `subfolder` and `access_token` fields are optional.
|
||||
|
||||
* To install a remote model using an arbitrary URL, pass:
|
||||
```
|
||||
{
|
||||
"type": "url",
|
||||
"url": "http://www.civitai.com/models/123456",
|
||||
"access_token": "f5820a918aaf01"
|
||||
}
|
||||
```
|
||||
The `access_token` field is optonal
|
||||
|
||||
The model's configuration record will be probed and filled in
|
||||
automatically. To override the default guesses, pass "metadata"
|
||||
with a Dict containing the attributes you wish to override.
|
||||
|
||||
Installation occurs in the background. Either use list_model_install_jobs()
|
||||
to poll for completion, or listen on the event bus for the following events:
|
||||
|
||||
* "model_install_running"
|
||||
* "model_install_completed"
|
||||
* "model_install_error"
|
||||
|
||||
On successful completion, the event's payload will contain the field "key"
|
||||
containing the installed ID of the model. On an error, the event's payload
|
||||
will contain the fields "error_type" and "error" describing the nature of the
|
||||
error and its traceback, respectively.
|
||||
|
||||
"""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
result: ModelInstallJob = installer.import_model(
|
||||
source=source,
|
||||
config=config,
|
||||
)
|
||||
logger.info(f"Started installation of {source}")
|
||||
except UnknownModelException as e:
|
||||
@@ -628,7 +637,6 @@ async def convert_model(
|
||||
Note that during the conversion process the key and model hash will change.
|
||||
The return value is the model configuration for the converted model.
|
||||
"""
|
||||
model_manager = ApiDependencies.invoker.services.model_manager
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
loader = ApiDependencies.invoker.services.model_manager.load
|
||||
store = ApiDependencies.invoker.services.model_manager.store
|
||||
@@ -645,7 +653,7 @@ async def convert_model(
|
||||
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
|
||||
|
||||
# loading the model will convert it into a cached diffusers file
|
||||
model_manager.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
|
||||
loader.load_model_by_config(model_config, submodel_type=SubModelType.Scheduler)
|
||||
|
||||
# Get the path of the converted model from the loader
|
||||
cache_path = loader.convert_cache.cache_path(key)
|
||||
|
||||
276
invokeai/app/api/routers/sessions.py
Normal file
276
invokeai/app/api/routers/sessions.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
|
||||
from fastapi import HTTPException, Path
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from ...services.shared.graph import GraphExecutionState
|
||||
from ..dependencies import ApiDependencies
|
||||
|
||||
session_router = APIRouter(prefix="/v1/sessions", tags=["sessions"])
|
||||
|
||||
|
||||
# @session_router.post(
|
||||
# "/",
|
||||
# operation_id="create_session",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid json"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def create_session(
|
||||
# queue_id: str = Query(default="", description="The id of the queue to associate the session with"),
|
||||
# graph: Optional[Graph] = Body(default=None, description="The graph to initialize the session with"),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Creates a new session, optionally initializing it with an invocation graph"""
|
||||
# session = ApiDependencies.invoker.create_execution_state(queue_id=queue_id, graph=graph)
|
||||
# return session
|
||||
|
||||
|
||||
# @session_router.get(
|
||||
# "/",
|
||||
# operation_id="list_sessions",
|
||||
# responses={200: {"model": PaginatedResults[GraphExecutionState]}},
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def list_sessions(
|
||||
# page: int = Query(default=0, description="The page of results to get"),
|
||||
# per_page: int = Query(default=10, description="The number of results per page"),
|
||||
# query: str = Query(default="", description="The query string to search for"),
|
||||
# ) -> PaginatedResults[GraphExecutionState]:
|
||||
# """Gets a list of sessions, optionally searching"""
|
||||
# if query == "":
|
||||
# result = ApiDependencies.invoker.services.graph_execution_manager.list(page, per_page)
|
||||
# else:
|
||||
# result = ApiDependencies.invoker.services.graph_execution_manager.search(query, page, per_page)
|
||||
# return result
|
||||
|
||||
|
||||
@session_router.get(
|
||||
"/{session_id}",
|
||||
operation_id="get_session",
|
||||
responses={
|
||||
200: {"model": GraphExecutionState},
|
||||
404: {"description": "Session not found"},
|
||||
},
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str = Path(description="The id of the session to get"),
|
||||
) -> GraphExecutionState:
|
||||
"""Gets a session"""
|
||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404)
|
||||
else:
|
||||
return session
|
||||
|
||||
|
||||
# @session_router.post(
|
||||
# "/{session_id}/nodes",
|
||||
# operation_id="add_node",
|
||||
# responses={
|
||||
# 200: {"model": str},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def add_node(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||
# description="The node to add"
|
||||
# ),
|
||||
# ) -> str:
|
||||
# """Adds a node to the graph"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
# try:
|
||||
# session.add_node(node)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session.id
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
# @session_router.put(
|
||||
# "/{session_id}/nodes/{node_path}",
|
||||
# operation_id="update_node",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def update_node(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# node_path: str = Path(description="The path to the node in the graph"),
|
||||
# node: Annotated[Union[BaseInvocation.get_invocations()], Field(discriminator="type")] = Body( # type: ignore
|
||||
# description="The new node"
|
||||
# ),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Updates a node in the graph and removes all linked edges"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
# try:
|
||||
# session.update_node(node_path, node)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
# @session_router.delete(
|
||||
# "/{session_id}/nodes/{node_path}",
|
||||
# operation_id="delete_node",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def delete_node(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# node_path: str = Path(description="The path to the node to delete"),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Deletes a node in the graph and removes all linked edges"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
# try:
|
||||
# session.delete_node(node_path)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
# @session_router.post(
|
||||
# "/{session_id}/edges",
|
||||
# operation_id="add_edge",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def add_edge(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# edge: Edge = Body(description="The edge to add"),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Adds an edge to the graph"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
# try:
|
||||
# session.add_edge(edge)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
# # TODO: the edge being in the path here is really ugly, find a better solution
|
||||
# @session_router.delete(
|
||||
# "/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}",
|
||||
# operation_id="delete_edge",
|
||||
# responses={
|
||||
# 200: {"model": GraphExecutionState},
|
||||
# 400: {"description": "Invalid node or link"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def delete_edge(
|
||||
# session_id: str = Path(description="The id of the session"),
|
||||
# from_node_id: str = Path(description="The id of the node the edge is coming from"),
|
||||
# from_field: str = Path(description="The field of the node the edge is coming from"),
|
||||
# to_node_id: str = Path(description="The id of the node the edge is going to"),
|
||||
# to_field: str = Path(description="The field of the node the edge is going to"),
|
||||
# ) -> GraphExecutionState:
|
||||
# """Deletes an edge from the graph"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
# try:
|
||||
# edge = Edge(
|
||||
# source=EdgeConnection(node_id=from_node_id, field=from_field),
|
||||
# destination=EdgeConnection(node_id=to_node_id, field=to_field),
|
||||
# )
|
||||
# session.delete_edge(edge)
|
||||
# ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||
# session
|
||||
# ) # TODO: can this be done automatically, or add node through an API?
|
||||
# return session
|
||||
# except NodeAlreadyExecutedError:
|
||||
# raise HTTPException(status_code=400)
|
||||
# except IndexError:
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
|
||||
# @session_router.put(
|
||||
# "/{session_id}/invoke",
|
||||
# operation_id="invoke_session",
|
||||
# responses={
|
||||
# 200: {"model": None},
|
||||
# 202: {"description": "The invocation is queued"},
|
||||
# 400: {"description": "The session has no invocations ready to invoke"},
|
||||
# 404: {"description": "Session not found"},
|
||||
# },
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def invoke_session(
|
||||
# queue_id: str = Query(description="The id of the queue to associate the session with"),
|
||||
# session_id: str = Path(description="The id of the session to invoke"),
|
||||
# all: bool = Query(default=False, description="Whether or not to invoke all remaining invocations"),
|
||||
# ) -> Response:
|
||||
# """Invokes a session"""
|
||||
# session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||
# if session is None:
|
||||
# raise HTTPException(status_code=404)
|
||||
|
||||
# if session.is_complete():
|
||||
# raise HTTPException(status_code=400)
|
||||
|
||||
# ApiDependencies.invoker.invoke(queue_id, session, invoke_all=all)
|
||||
# return Response(status_code=202)
|
||||
|
||||
|
||||
# @session_router.delete(
|
||||
# "/{session_id}/invoke",
|
||||
# operation_id="cancel_session_invoke",
|
||||
# responses={202: {"description": "The invocation is canceled"}},
|
||||
# deprecated=True,
|
||||
# )
|
||||
# async def cancel_session_invoke(
|
||||
# session_id: str = Path(description="The id of the session to cancel"),
|
||||
# ) -> Response:
|
||||
# """Invokes a session"""
|
||||
# ApiDependencies.invoker.cancel(session_id)
|
||||
# return Response(status_code=202)
|
||||
@@ -12,26 +12,16 @@ class SocketIO:
__sio: AsyncServer
__app: ASGIApp

__sub_queue: str = "subscribe_queue"
__unsub_queue: str = "unsubscribe_queue"

__sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download"

def __init__(self, app: FastAPI):
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
app.mount("/ws", self.__app)

self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue)
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue)
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)

self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)

async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
@@ -49,18 +39,3 @@ class SocketIO:

async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])

async def _handle_bulk_download_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["bulk_download_id"],
)

async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.enter_room(sid, data["bulk_download_id"])

async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.leave_room(sid, data["bulk_download_id"])
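For illustration only (not part of this changeset): a minimal client-side sketch of the bulk-download room subscription shown above, assuming the python-socketio AsyncClient, a local server on port 9090, and a hypothetical event name for the emitted progress events.

import asyncio

import socketio  # python-socketio client package


async def main() -> None:
    sio = socketio.AsyncClient()
    # The server mounts its ASGI socket.io app under /ws, so the client path is /ws/socket.io.
    await sio.connect("http://localhost:9090", socketio_path="/ws/socket.io")

    @sio.on("bulk_download_started")  # hypothetical event name, for illustration only
    async def on_bulk_download(data):
        print("bulk download event:", data)

    # Joining the room is just an emit; the server's _handle_sub_bulk_download
    # answers by calling enter_room(sid, data["bulk_download_id"]).
    await sio.emit("subscribe_bulk_download", {"bulk_download_id": "some-download-id"})
    await asyncio.sleep(60)
    await sio.emit("unsubscribe_bulk_download", {"bulk_download_id": "some-download-id"})
    await sio.disconnect()


asyncio.run(main())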
@@ -2,12 +2,10 @@
|
||||
# which are imported/used before parse_args() is called will get the default config values instead of the
|
||||
# values from the command line or config file.
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
from .services.config import InvokeAIAppConfig
|
||||
|
||||
app_config = InvokeAIAppConfig.get_config()
|
||||
@@ -51,12 +49,15 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
images,
|
||||
model_manager,
|
||||
session_queue,
|
||||
sessions,
|
||||
utilities,
|
||||
workflows,
|
||||
)
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
InputFieldJSONSchemaExtra,
|
||||
OutputFieldJSONSchemaExtra,
|
||||
UIConfigBase,
|
||||
)
|
||||
|
||||
@@ -72,25 +73,9 @@ logger = InvokeAILogger.get_logger(config=app_config)
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("text/css", ".css")


@asynccontextmanager
async def lifespan(app: FastAPI):
# Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
yield
# Shut down threads
ApiDependencies.shutdown()


# Create the app
# TODO: create this all in a method so configuration/etc. can be passed in?
app = FastAPI(
title="Invoke - Community Edition",
docs_url=None,
redoc_url=None,
separate_input_output_schemas=False,
lifespan=lifespan,
)
app = FastAPI(title="Invoke - Community Edition", docs_url=None, redoc_url=None, separate_input_output_schemas=False)

# Add event handler
event_handler_id: int = id(app)
@@ -113,7 +98,21 @@ app.add_middleware(
app.add_middleware(GZipMiddleware, minimum_size=1000)


# Add startup event to load dependencies
@app.on_event("startup")
async def startup_event() -> None:
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)


# Shut down threads
@app.on_event("shutdown")
async def shutdown_event() -> None:
ApiDependencies.shutdown()


# Include all routers
app.include_router(sessions.session_router, prefix="/api")

app.include_router(utilities.utilities_router, prefix="/api")
app.include_router(model_manager.model_manager_router, prefix="/api")
app.include_router(download_queue.download_queue_router, prefix="/api")
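As context for the startup/shutdown change above, a minimal sketch (illustrative only, with placeholder functions) of the two equivalent FastAPI patterns: the lifespan context manager runs its pre-yield code at startup and its post-yield code at shutdown, replacing the deprecated paired @app.on_event hooks.

from contextlib import asynccontextmanager

from fastapi import FastAPI


def init_dependencies() -> None:
    # placeholder for ApiDependencies.initialize(...)
    print("starting up")


def shutdown_dependencies() -> None:
    # placeholder for ApiDependencies.shutdown()
    print("shutting down")


@asynccontextmanager
async def lifespan(app: FastAPI):
    init_dependencies()
    yield
    shutdown_dependencies()


app = FastAPI(lifespan=lifespan)
# The deprecated equivalent registers @app.on_event("startup") and @app.on_event("shutdown") hooks instead.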
@@ -153,8 +152,6 @@ def custom_openapi() -> dict[str, Any]:
|
||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||
# This could break in some cases, figure out a better way to do it
|
||||
output_type_titles[schema_key] = output_schema["title"]
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
|
||||
|
||||
# Add Node Editor UI helper schemas
|
||||
ui_config_schemas = models_json_schema(
|
||||
@@ -177,6 +174,7 @@ def custom_openapi() -> dict[str, Any]:
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
invoker_schema["output"] = outputs_ref
|
||||
invoker_schema["class"] = "invocation"
|
||||
openapi_schema["components"]["schemas"][f"{output_type_title}"]["class"] = "output"
|
||||
|
||||
# This code no longer seems to be necessary?
|
||||
# Leave it here just in case
|
||||
|
||||
@@ -8,33 +8,17 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from inspect import signature
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Iterable,
|
||||
Literal,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from types import UnionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterable, Literal, Optional, Type, TypeVar, Union, cast
|
||||
|
||||
import semver
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter, create_model
|
||||
from pydantic.fields import FieldInfo, _Unset
|
||||
from pydantic_core import PydanticUndefined
|
||||
from typing_extensions import TypeAliasType
|
||||
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldKind,
|
||||
Input,
|
||||
)
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@@ -68,6 +52,393 @@ class Classification(str, Enum, metaclass=MetaEnum):
|
||||
Prototype = "prototype"
|
||||
|
||||
|
||||
class Input(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The type of input a field accepts.
|
||||
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
|
||||
are instantiated.
|
||||
- `Input.Connection`: The field must have its value provided by a connection.
|
||||
- `Input.Any`: The field may have its value provided either directly or by a connection.
|
||||
"""
|
||||
|
||||
Connection = "connection"
|
||||
Direct = "direct"
|
||||
Any = "any"
|
||||
|
||||
|
||||
class FieldKind(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The kind of field.
|
||||
- `Input`: An input field on a node.
|
||||
- `Output`: An output field on a node.
|
||||
- `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is
|
||||
one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name
|
||||
"metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic,
|
||||
allowing "metadata" for that field.
|
||||
- `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs,
|
||||
but which are used to store information about the node. For example, the `id` and `type` fields are node
|
||||
attributes.
|
||||
|
||||
The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app
|
||||
startup, and when generating the OpenAPI schema for the workflow editor.
|
||||
"""
|
||||
|
||||
Input = "input"
|
||||
Output = "output"
|
||||
Internal = "internal"
|
||||
NodeAttribute = "node_attribute"
|
||||
|
||||
|
||||
class UIType(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
Type hints for the UI for situations in which the field type is not enough to infer the correct UI type.
|
||||
|
||||
- Model Fields
|
||||
The most common node-author-facing use will be for model fields. Internally, there is no difference
|
||||
between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the
|
||||
base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that
|
||||
the field is an SDXL main model field.
|
||||
|
||||
- Any Field
|
||||
We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to
|
||||
indicate that the field accepts any type. Use with caution. This cannot be used on outputs.
|
||||
|
||||
- Scheduler Field
|
||||
Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field.
|
||||
|
||||
- Internal Fields
|
||||
Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate
|
||||
handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These
|
||||
should not be used by node authors.
|
||||
|
||||
- DEPRECATED Fields
|
||||
These types are deprecated and should not be used by node authors. A warning will be logged if one is
|
||||
used, and the type will be ignored. They are included here for backwards compatibility.
|
||||
"""
|
||||
|
||||
# region Model Field Types
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VaeModel = "VAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
Scheduler = "SchedulerField"
|
||||
Any = "AnyField"
|
||||
# endregion
|
||||
|
||||
# region Internal Field Types
|
||||
_Collection = "CollectionField"
|
||||
_CollectionItem = "CollectionItemField"
|
||||
# endregion
|
||||
|
||||
# region DEPRECATED
|
||||
Boolean = "DEPRECATED_Boolean"
|
||||
Color = "DEPRECATED_Color"
|
||||
Conditioning = "DEPRECATED_Conditioning"
|
||||
Control = "DEPRECATED_Control"
|
||||
Float = "DEPRECATED_Float"
|
||||
Image = "DEPRECATED_Image"
|
||||
Integer = "DEPRECATED_Integer"
|
||||
Latents = "DEPRECATED_Latents"
|
||||
String = "DEPRECATED_String"
|
||||
BooleanCollection = "DEPRECATED_BooleanCollection"
|
||||
ColorCollection = "DEPRECATED_ColorCollection"
|
||||
ConditioningCollection = "DEPRECATED_ConditioningCollection"
|
||||
ControlCollection = "DEPRECATED_ControlCollection"
|
||||
FloatCollection = "DEPRECATED_FloatCollection"
|
||||
ImageCollection = "DEPRECATED_ImageCollection"
|
||||
IntegerCollection = "DEPRECATED_IntegerCollection"
|
||||
LatentsCollection = "DEPRECATED_LatentsCollection"
|
||||
StringCollection = "DEPRECATED_StringCollection"
|
||||
BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic"
|
||||
ColorPolymorphic = "DEPRECATED_ColorPolymorphic"
|
||||
ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic"
|
||||
ControlPolymorphic = "DEPRECATED_ControlPolymorphic"
|
||||
FloatPolymorphic = "DEPRECATED_FloatPolymorphic"
|
||||
ImagePolymorphic = "DEPRECATED_ImagePolymorphic"
|
||||
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
|
||||
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
|
||||
StringPolymorphic = "DEPRECATED_StringPolymorphic"
|
||||
MainModel = "DEPRECATED_MainModel"
|
||||
UNet = "DEPRECATED_UNet"
|
||||
Vae = "DEPRECATED_Vae"
|
||||
CLIP = "DEPRECATED_CLIP"
|
||||
Collection = "DEPRECATED_Collection"
|
||||
CollectionItem = "DEPRECATED_CollectionItem"
|
||||
Enum = "DEPRECATED_Enum"
|
||||
WorkflowField = "DEPRECATED_WorkflowField"
|
||||
IsIntermediate = "DEPRECATED_IsIntermediate"
|
||||
BoardField = "DEPRECATED_BoardField"
|
||||
MetadataItem = "DEPRECATED_MetadataItem"
|
||||
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
|
||||
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
|
||||
MetadataDict = "DEPRECATED_MetadataDict"
|
||||
# endregion
|
||||
|
||||
|
||||
class UIComponent(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The type of UI component to use for a field, used to override the default components, which are
|
||||
inferred from the field type.
|
||||
"""
|
||||
|
||||
None_ = "none"
|
||||
Textarea = "textarea"
|
||||
Slider = "slider"
|
||||
|
||||
|
||||
class InputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution,
|
||||
and by the workflow editor during schema parsing and UI rendering.
|
||||
"""
|
||||
|
||||
input: Input
|
||||
orig_required: bool
|
||||
field_kind: FieldKind
|
||||
default: Optional[Any] = None
|
||||
orig_default: Optional[Any] = None
|
||||
ui_hidden: bool = False
|
||||
ui_type: Optional[UIType] = None
|
||||
ui_component: Optional[UIComponent] = None
|
||||
ui_order: Optional[int] = None
|
||||
ui_choice_labels: Optional[dict[str, str]] = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
class OutputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
Extra attributes to be added to output fields and their OpenAPI schema. Used by the workflow editor
during schema parsing and UI rendering.
|
||||
"""
|
||||
|
||||
field_kind: FieldKind
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_order: Optional[int]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
def InputField(
|
||||
# copied from pydantic's Field
|
||||
# TODO: Can we support default_factory?
|
||||
default: Any = _Unset,
|
||||
default_factory: Callable[[], Any] | None = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
strict: bool | None = _Unset,
|
||||
gt: float | None = _Unset,
|
||||
ge: float | None = _Unset,
|
||||
lt: float | None = _Unset,
|
||||
le: float | None = _Unset,
|
||||
multiple_of: float | None = _Unset,
|
||||
allow_inf_nan: bool | None = _Unset,
|
||||
max_digits: int | None = _Unset,
|
||||
decimal_places: int | None = _Unset,
|
||||
min_length: int | None = _Unset,
|
||||
max_length: int | None = _Unset,
|
||||
# custom
|
||||
input: Input = Input.Any,
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_component: Optional[UIComponent] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
ui_choice_labels: Optional[dict[str, str]] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an input field for an invocation.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param Input input: [Input.Any] The kind of input this field requires. \
|
||||
`Input.Direct` means a value must be provided on instantiation. \
|
||||
`Input.Connection` means the value must be provided by a connection. \
|
||||
`Input.Any` means either will do.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
|
||||
The UI will always render a suitable component, but sometimes you want something different than the default. \
|
||||
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
|
||||
For this case, you could provide `UIComponent.Textarea`.
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
|
||||
|
||||
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
|
||||
"""
|
||||
|
||||
json_schema_extra_ = InputFieldJSONSchemaExtra(
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
field_kind=FieldKind.Input,
|
||||
orig_required=True,
|
||||
)
|
||||
|
||||
"""
|
||||
There is a conflict between the typing of invocation definitions and the typing of an invocation's
|
||||
`invoke()` function.
|
||||
|
||||
On instantiation of a node, the invocation definition is used to create the python class. At this time,
|
||||
any number of fields may be optional, because they may be provided by connections.
|
||||
|
||||
On calling of `invoke()`, however, those fields may be required.
|
||||
|
||||
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
|
||||
|
||||
`image` is required during the call to `invoke()`, but when the python class is instantiated,
|
||||
the field may not be present. This is fine, because that image field will be provided by a
|
||||
connection from an ancestor node, which outputs an image.
|
||||
|
||||
This means we want to type the `image` field as optional for the node class definition, but required
|
||||
for the `invoke()` function.
|
||||
|
||||
If we use `typing.Optional` in the node class definition, the field will be typed as optional in the
|
||||
`invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or
|
||||
any static type analysis tools will complain.
|
||||
|
||||
To get around this, in node class definitions, we type all fields correctly for the `invoke()` function,
|
||||
but secretly make them optional in `InputField()`. We also store the original required bool and/or default
|
||||
value. When we call `invoke()`, we use this stored information to do an additional check on the class.
|
||||
"""
|
||||
|
||||
if default_factory is not _Unset and default_factory is not None:
|
||||
default = default_factory()
|
||||
logger.warn('"default_factory" is not supported, calling it now to set "default"')
|
||||
|
||||
# These are the args we may wish to pass to the pydantic `Field()` function
|
||||
field_args = {
|
||||
"default": default,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"pattern": pattern,
|
||||
"strict": strict,
|
||||
"gt": gt,
|
||||
"ge": ge,
|
||||
"lt": lt,
|
||||
"le": le,
|
||||
"multiple_of": multiple_of,
|
||||
"allow_inf_nan": allow_inf_nan,
|
||||
"max_digits": max_digits,
|
||||
"decimal_places": decimal_places,
|
||||
"min_length": min_length,
|
||||
"max_length": max_length,
|
||||
}
|
||||
|
||||
# We only want to pass the args that were provided, otherwise the `Field()` function won't work as expected
|
||||
provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}
|
||||
|
||||
# Because we are manually making fields optional, we need to store the original required bool for reference later
|
||||
json_schema_extra_.orig_required = default is PydanticUndefined
|
||||
|
||||
# Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
||||
if input is Input.Any or input is Input.Connection:
|
||||
default_ = None if default is PydanticUndefined else default
|
||||
provided_args.update({"default": default_})
|
||||
if default is not PydanticUndefined:
|
||||
# Before invoking, we'll check for the original default value and set it on the field if the field has no value
|
||||
json_schema_extra_.default = default
|
||||
json_schema_extra_.orig_default = default
|
||||
elif default is not PydanticUndefined:
|
||||
default_ = default
|
||||
provided_args.update({"default": default_})
|
||||
json_schema_extra_.orig_default = default_
|
||||
|
||||
return Field(
|
||||
**provided_args,
|
||||
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
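For illustration only (a hypothetical node, not part of this changeset): how a node author might use InputField and the UI hints described above. The import locations follow this diff's version of baseinvocation; StringOutput is assumed to be the existing primitive output with a `value` field.

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    Input,
    InputField,
    InvocationContext,
    UIComponent,
    invocation,
)
from invokeai.app.invocations.primitives import StringOutput  # assumed primitive output


@invocation("example_uppercase", title="Uppercase", tags=["example"], category="string", version="1.0.0")
class UppercaseInvocation(BaseInvocation):
    """Hypothetical node: uppercases a prompt string."""

    # Render a multi-line text box instead of the default single-line input.
    prompt: str = InputField(description="The text to transform", ui_component=UIComponent.Textarea)
    # This value must be provided by a connection from another node's output.
    prefix: str = InputField(default="", description="Optional prefix", input=Input.Connection)

    def invoke(self, context: InvocationContext) -> StringOutput:
        return StringOutput(value=(self.prefix + self.prompt).upper())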
def OutputField(
|
||||
# copied from pydantic's Field
|
||||
default: Any = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
strict: bool | None = _Unset,
|
||||
gt: float | None = _Unset,
|
||||
ge: float | None = _Unset,
|
||||
lt: float | None = _Unset,
|
||||
le: float | None = _Unset,
|
||||
multiple_of: float | None = _Unset,
|
||||
allow_inf_nan: bool | None = _Unset,
|
||||
max_digits: int | None = _Unset,
|
||||
decimal_places: int | None = _Unset,
|
||||
min_length: int | None = _Unset,
|
||||
max_length: int | None = _Unset,
|
||||
# custom
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an output field for an invocation output.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
"""
|
||||
return Field(
|
||||
default=default,
|
||||
title=title,
|
||||
description=description,
|
||||
pattern=pattern,
|
||||
strict=strict,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
json_schema_extra=OutputFieldJSONSchemaExtra(
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
field_kind=FieldKind.Output,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
class UIConfigBase(BaseModel):
|
||||
"""
|
||||
Provides additional node configuration to the UI.
|
||||
@@ -89,6 +460,33 @@ class UIConfigBase(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""Initialized and provided to on execution of invocations."""
|
||||
|
||||
services: InvocationServices
|
||||
graph_execution_state_id: str
|
||||
queue_id: str
|
||||
queue_item_id: int
|
||||
queue_batch_id: str
|
||||
workflow: Optional[WorkflowWithoutID]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
services: InvocationServices,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
workflow: Optional[WorkflowWithoutID],
|
||||
):
|
||||
self.services = services
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.queue_id = queue_id
|
||||
self.queue_item_id = queue_item_id
|
||||
self.queue_batch_id = queue_batch_id
|
||||
self.workflow = workflow
|
||||
|
||||
|
||||
class BaseInvocationOutput(BaseModel):
|
||||
"""
|
||||
Base class for all invocation outputs.
|
||||
@@ -97,7 +495,6 @@ class BaseInvocationOutput(BaseModel):
|
||||
"""
|
||||
|
||||
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
|
||||
@classmethod
|
||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||
@@ -110,14 +507,10 @@ class BaseInvocationOutput(BaseModel):
|
||||
return cls._output_classes
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
|
||||
if not cls._typeadapter:
|
||||
InvocationOutputsUnion = TypeAliasType(
|
||||
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
|
||||
return cls._typeadapter
|
||||
def get_outputs_union(cls) -> UnionType:
|
||||
"""Gets a union of all invocation outputs."""
|
||||
outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type]
|
||||
return outputs_union # type: ignore [return-value]
|
||||
|
||||
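A minimal sketch (illustrative, with toy models) of the discriminated-union TypeAdapter pattern introduced above: the "type" field selects which output class a plain dict is parsed into.

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class ImageOut(BaseModel):
    type: Literal["image_output"] = "image_output"
    image_name: str


class StringOut(BaseModel):
    type: Literal["string_output"] = "string_output"
    value: str


# Tagged union: pydantic dispatches on the "type" discriminator during validation.
OutputsUnion = Annotated[Union[ImageOut, StringOut], Field(discriminator="type")]
adapter = TypeAdapter(OutputsUnion)

parsed = adapter.validate_python({"type": "string_output", "value": "hello"})
assert isinstance(parsed, StringOut)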
@classmethod
|
||||
def get_output_types(cls) -> Iterable[str]:
|
||||
@@ -166,7 +559,6 @@ class BaseInvocation(ABC, BaseModel):
|
||||
"""
|
||||
|
||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
@@ -179,14 +571,10 @@ class BaseInvocation(ABC, BaseModel):
|
||||
cls._invocation_classes.add(invocation)
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||
if not cls._typeadapter:
|
||||
InvocationsUnion = TypeAliasType(
|
||||
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(InvocationsUnion)
|
||||
return cls._typeadapter
|
||||
def get_invocations_union(cls) -> UnionType:
|
||||
"""Gets a union of all invocation types."""
|
||||
invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type]
|
||||
return invocations_union # type: ignore [return-value]
|
||||
|
||||
@classmethod
|
||||
def get_invocations(cls) -> Iterable[BaseInvocation]:
|
||||
@@ -244,7 +632,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
"""Invoke with provided context and return outputs."""
|
||||
pass
|
||||
|
||||
def invoke_internal(self, context: InvocationContext, services: "InvocationServices") -> BaseInvocationOutput:
|
||||
def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput:
|
||||
"""
|
||||
Internal invoke method, calls `invoke()` after some prep.
|
||||
Handles optional fields that are required to call `invoke()`, as well as the invocation cache.
|
||||
@@ -269,23 +657,23 @@ class BaseInvocation(ABC, BaseModel):
|
||||
raise MissingInputException(self.model_fields["type"].default, field_name)
|
||||
|
||||
# skip node cache codepath if it's disabled
|
||||
if services.configuration.node_cache_size == 0:
|
||||
if context.services.configuration.node_cache_size == 0:
|
||||
return self.invoke(context)
|
||||
|
||||
output: BaseInvocationOutput
|
||||
if self.use_cache:
|
||||
key = services.invocation_cache.create_key(self)
|
||||
cached_value = services.invocation_cache.get(key)
|
||||
key = context.services.invocation_cache.create_key(self)
|
||||
cached_value = context.services.invocation_cache.get(key)
|
||||
if cached_value is None:
|
||||
services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
|
||||
context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
|
||||
output = self.invoke(context)
|
||||
services.invocation_cache.save(key, output)
|
||||
context.services.invocation_cache.save(key, output)
|
||||
return output
|
||||
else:
|
||||
services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
|
||||
context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
|
||||
return cached_value
|
||||
else:
|
||||
services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
|
||||
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
|
||||
return self.invoke(context)
|
||||
|
||||
id: str = Field(
|
||||
@@ -326,7 +714,9 @@ RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = {
|
||||
"workflow",
|
||||
}
|
||||
|
||||
RESERVED_INPUT_FIELD_NAMES = {"metadata", "board"}
|
||||
RESERVED_INPUT_FIELD_NAMES = {
|
||||
"metadata",
|
||||
}
|
||||
|
||||
RESERVED_OUTPUT_FIELD_NAMES = {"type"}
|
||||
|
||||
@@ -536,3 +926,37 @@ def invocation_output(
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class MetadataField(RootModel):
"""
Pydantic model for metadata with custom root of type dict[str, Any].
Metadata is stored without a strict schema.
"""

root: dict[str, Any] = Field(description="The metadata")


MetadataFieldValidator = TypeAdapter(MetadataField)


class WithMetadata(BaseModel):
metadata: Optional[MetadataField] = Field(
default=None,
description=FieldDescriptions.metadata,
json_schema_extra=InputFieldJSONSchemaExtra(
field_kind=FieldKind.Internal,
input=Input.Connection,
orig_required=False,
).model_dump(exclude_none=True),
)


class WithWorkflow:
workflow = None

def __init_subclass__(cls) -> None:
logger.warn(
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
)
super().__init_subclass__()
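A small sketch (illustrative only, standalone copy of the model above) of how the RootModel-based MetadataField accepts free-form metadata via its TypeAdapter:

from typing import Any

from pydantic import Field, RootModel, TypeAdapter


class MetadataField(RootModel):
    """Free-form metadata: any dict[str, Any] is accepted, with no fixed schema."""

    root: dict[str, Any] = Field(description="The metadata")


MetadataFieldValidator = TypeAdapter(MetadataField)

# Any JSON-like dict validates; non-dict input raises a pydantic ValidationError.
meta = MetadataFieldValidator.validate_python({"seed": 123, "prompt": "a cat"})
print(meta.root["seed"])  # -> 123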
@@ -5,11 +5,9 @@ import numpy as np
|
||||
from pydantic import ValidationInfo, field_validator
|
||||
|
||||
from invokeai.app.invocations.primitives import IntegerCollectionOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||
|
||||
|
||||
@invocation(
|
||||
|
||||
@@ -1,28 +1,45 @@
|
||||
from typing import Iterator, List, Optional, Tuple, Union, cast
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIComponent
|
||||
from invokeai.app.invocations.primitives import ConditioningOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.ti_utils import generate_ti_list
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.util.ti_utils import extract_ti_triggers_from_prompt
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import ModelType
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ConditioningFieldData,
|
||||
ExtraConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
from invokeai.backend.util.devices import torch_dtype
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .model import ClipField
|
||||
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[BasicConditioningInfo]
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
|
||||
|
||||
# class ConditioningAlgo(str, Enum):
|
||||
@@ -36,7 +53,7 @@ from .model import ClipField
|
||||
title="Prompt",
|
||||
tags=["prompt", "compel"],
|
||||
category="conditioning",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CompelInvocation(BaseInvocation):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@@ -54,27 +71,45 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.models.load(**self.clip.tokenizer.model_dump())
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(**self.clip.text_encoder.model_dump())
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||
tokenizer_info = context.services.model_manager.load_model_by_key(
|
||||
**self.clip.tokenizer.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.load_model_by_key(
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
lora_info = context.services.model_manager.load_model_by_key(
|
||||
**lora.model_dump(exclude={"weight"}), context=context
|
||||
)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
|
||||
ti_list = []
|
||||
for trigger in extract_ti_triggers_from_prompt(self.prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
loaded_model = context.services.model_manager.load_model_by_key(
|
||||
**self.clip.text_encoder.model_dump(),
|
||||
context=context,
|
||||
).model
|
||||
assert isinstance(loaded_model, TextualInversionModelRaw)
|
||||
ti_list.append((name, loaded_model))
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with (
|
||||
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
||||
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
@@ -82,9 +117,8 @@ class CompelInvocation(BaseInvocation):
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.model, self.clip.skipped_layers),
|
||||
):
|
||||
assert isinstance(text_encoder, CLIPTextModel)
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
@@ -95,7 +129,7 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
|
||||
if context.config.get().log_tokenization:
|
||||
if context.services.configuration.log_tokenization:
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
@@ -116,9 +150,14 @@ class CompelInvocation(BaseInvocation):
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data) # TODO: fix type mismatch here
|
||||
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SDXLPromptInvocationBase:
|
||||
@@ -133,12 +172,14 @@ class SDXLPromptInvocationBase:
|
||||
lora_prefix: str,
|
||||
zero_on_empty: bool,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[ExtraConditioningInfo]]:
|
||||
tokenizer_info = context.models.load(**clip_field.tokenizer.model_dump())
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(**clip_field.text_encoder.model_dump())
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
tokenizer_info = context.services.model_manager.load_model_by_key(
|
||||
**clip_field.tokenizer.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
text_encoder_info = context.services.model_manager.load_model_by_key(
|
||||
**clip_field.text_encoder.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
@@ -163,19 +204,39 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in clip_field.loras:
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
lora_info = context.services.model_manager.load_model_by_key(
|
||||
**lora.model_dump(exclude={"weight"}), context=context
|
||||
)
|
||||
lora_model = lora_info.model
|
||||
assert isinstance(lora_model, LoRAModelRaw)
|
||||
yield (lora_model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
# loras = [(context.models.get(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
# loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
|
||||
|
||||
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
|
||||
ti_list = []
|
||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||
name = trigger[1:-1]
|
||||
try:
|
||||
ti_model = context.services.model_manager.load_model_by_attr(
|
||||
model_name=name,
|
||||
base_model=text_encoder_info.config.base,
|
||||
model_type=ModelType.TextualInversion,
|
||||
context=context,
|
||||
).model
|
||||
assert isinstance(ti_model, TextualInversionModelRaw)
|
||||
ti_list.append((name, ti_model))
|
||||
except UnknownModelException:
|
||||
# print(e)
|
||||
# import traceback
|
||||
# print(traceback.format_exc())
|
||||
logger.warning(f'trigger: "{trigger}" not found')
|
||||
except ValueError:
|
||||
logger.warning(f'trigger: "{trigger}" matched more than one similarly-named textual inversion model')
|
||||
|
||||
with (
|
||||
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
||||
ModelPatcher.apply_ti(tokenizer_info.model, text_encoder_info.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
@@ -183,10 +244,8 @@ class SDXLPromptInvocationBase:
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.model, clip_field.skipped_layers),
|
||||
):
|
||||
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
text_encoder = cast(CLIPTextModel, text_encoder)
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
@@ -199,7 +258,7 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
conjunction = Compel.parse_prompt_string(prompt)
|
||||
|
||||
if context.config.get().log_tokenization:
|
||||
if context.services.configuration.log_tokenization:
|
||||
# TODO: better logging for and syntax
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
|
||||
@@ -232,7 +291,7 @@ class SDXLPromptInvocationBase:
|
||||
title="SDXL Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@@ -315,9 +374,14 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -325,7 +389,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
title="SDXL Refiner Prompt",
|
||||
tags=["sdxl", "compel", "prompt"],
|
||||
category="conditioning",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
|
||||
"""Parse prompt using compel package to conditioning."""
|
||||
@@ -364,9 +428,14 @@ class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase
|
||||
]
|
||||
)
|
||||
|
||||
conditioning_name = context.conditioning.save(conditioning_data)
|
||||
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
||||
context.services.latents.save(conditioning_name, conditioning_data)
|
||||
|
||||
return ConditioningOutput.build(conditioning_name)
|
||||
return ConditioningOutput(
|
||||
conditioning=ConditioningField(
|
||||
conditioning_name=conditioning_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("clip_skip_output")
|
||||
@@ -387,7 +456,7 @@ class ClipSkipInvocation(BaseInvocation):
|
||||
"""Skip layers in clip text_encoder model."""
|
||||
|
||||
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP")
|
||||
skipped_layers: int = InputField(default=0, ge=0, description=FieldDescriptions.skipped_layers)
|
||||
skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput:
|
||||
self.clip.skipped_layers += self.skipped_layers
|
||||
|
||||
@@ -1,17 +0,0 @@
from typing import Literal

from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP

LATENT_SCALE_FACTOR = 8
"""
HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale
factor is hard-coded to a literal '8' rather than using this constant.
The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
"""

SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
"""A literal type representing the valid scheduler names."""

IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke"""
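For illustration (not part of the diff): the 8:1 image-to-latent ratio described above, applied when converting pixel dimensions to latent dimensions.

LATENT_SCALE_FACTOR = 8


def latent_size(image_width: int, image_height: int) -> tuple[int, int]:
    # A 512x512 image maps to a 64x64 latent under the stated 8:1 ratio.
    return image_width // LATENT_SCALE_FACTOR, image_height // LATENT_SCALE_FACTOR


assert latent_size(512, 512) == (64, 64)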
@@ -17,6 +17,7 @@ from controlnet_aux import (
|
||||
MidasDetector,
|
||||
MLSDdetector,
|
||||
NormalBaeDetector,
|
||||
OpenposeDetector,
|
||||
PidiNetDetector,
|
||||
SamDetector,
|
||||
ZoeDetector,
|
||||
@@ -25,22 +26,23 @@ from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
|
||||
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
|
||||
CONTROLNET_RESIZE_VALUES = Literal[
|
||||
@@ -134,7 +136,7 @@ class ControlNetInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# This invocation exists for other invocations to subclass it - do not register with @invocation!
|
||||
class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class ImageProcessorInvocation(BaseInvocation, WithMetadata):
|
||||
"""Base class for invocations that preprocess images for ControlNet"""
|
||||
|
||||
image: ImageField = InputField(description="The image to process")
|
||||
@@ -143,18 +145,23 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# superclass just passes through image without processing
|
||||
return image
|
||||
|
||||
def load_image(self, context: InvocationContext) -> Image.Image:
|
||||
# allows override for any special formatting specific to the preprocessor
|
||||
return context.images.get_pil(self.image.image_name, "RGB")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raw_image = self.load_image(context)
|
||||
raw_image = context.services.images.get_pil_image(self.image.image_name)
|
||||
# image type should be PIL.PngImagePlugin.PngImageFile ?
|
||||
processed_image = self.run_processor(raw_image)
|
||||
|
||||
# currently can't see processed image in node UI without a showImage node,
|
||||
# so for now setting image_type to RESULT instead of INTERMEDIATE so will get saved in gallery
|
||||
image_dto = context.images.save(image=processed_image)
|
||||
image_dto = context.services.images.create(
|
||||
image=processed_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.CONTROL,
|
||||
session_id=context.graph_execution_state_id,
|
||||
node_id=self.id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
"""Builds an ImageOutput and its ImageField"""
|
||||
processed_image_field = ImageField(image_name=image_dto.image_name)
|
||||
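To illustrate the subclassing pattern described above (a hypothetical processor, not in this changeset, assuming the processors live in invokeai.app.invocations.controlnet_image_processors): a subclass only needs to override run_processor, while loading and saving the image is handled by the base class's invoke.

from PIL import Image, ImageOps

from invokeai.app.invocations.baseinvocation import invocation
from invokeai.app.invocations.controlnet_image_processors import ImageProcessorInvocation


@invocation(
    "grayscale_image_processor",
    title="Grayscale Processor",
    tags=["controlnet"],
    category="controlnet",
    version="1.0.0",
)
class GrayscaleImageProcessorInvocation(ImageProcessorInvocation):
    """Hypothetical example: converts the control image to grayscale."""

    def run_processor(self, image: Image.Image) -> Image.Image:
        # ImageOps.grayscale returns a single-channel image; convert back to RGB for downstream nodes.
        return ImageOps.grayscale(image).convert("RGB")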
@@ -173,7 +180,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
title="Canny Processor",
|
||||
tags=["controlnet", "canny"],
|
||||
category="controlnet",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Canny edge detection for ControlNet"""
|
||||
@@ -185,10 +192,6 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)"
|
||||
)
|
||||
|
||||
def load_image(self, context: InvocationContext) -> Image.Image:
|
||||
# Keep alpha channel for Canny processing to detect edges of transparent areas
|
||||
return context.images.get_pil(self.image.image_name, "RGBA")
|
||||
|
||||
def run_processor(self, image):
|
||||
canny_processor = CannyDetector()
|
||||
processed_image = canny_processor(image, self.low_threshold, self.high_threshold)
|
||||
@@ -200,7 +203,7 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="HED (softedge) Processor",
|
||||
tags=["controlnet", "hed", "softedge"],
|
||||
category="controlnet",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies HED edge detection to image"""
|
||||
@@ -229,7 +232,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Lineart Processor",
|
||||
tags=["controlnet", "lineart"],
|
||||
category="controlnet",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art processing to image"""
|
||||
@@ -251,7 +254,7 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Lineart Anime Processor",
|
||||
tags=["controlnet", "lineart", "anime"],
|
||||
category="controlnet",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies line art anime processing to image"""
|
||||
@@ -269,12 +272,37 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"openpose_image_processor",
|
||||
title="Openpose Processor",
|
||||
tags=["controlnet", "openpose", "pose"],
|
||||
category="controlnet",
|
||||
version="1.2.0",
|
||||
)
|
||||
class OpenposeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Openpose processing to image"""
|
||||
|
||||
hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode")
|
||||
detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res)
|
||||
image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image):
|
||||
openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
processed_image = openpose_processor(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
hand_and_face=self.hand_and_face,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
"midas_depth_image_processor",
|
||||
title="Midas Depth Processor",
|
||||
tags=["controlnet", "midas"],
|
||||
category="controlnet",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Midas depth processing to image"""
|
||||
@@ -301,7 +329,7 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Normal BAE Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies NormalBae processing to image"""
|
||||
@@ -318,7 +346,7 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.1"
|
||||
"mlsd_image_processor", title="MLSD Processor", tags=["controlnet", "mlsd"], category="controlnet", version="1.2.0"
|
||||
)
|
||||
class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies MLSD processing to image"""
|
||||
@@ -341,7 +369,7 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
||||
|
||||
|
||||
@invocation(
|
||||
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.1"
|
||||
"pidi_image_processor", title="PIDI Processor", tags=["controlnet", "pidi"], category="controlnet", version="1.2.0"
|
||||
)
|
||||
class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies PIDI processing to image"""
|
||||
@@ -368,7 +396,7 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Content Shuffle Processor",
|
||||
tags=["controlnet", "contentshuffle"],
|
||||
category="controlnet",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies content shuffle processing to image"""
|
||||
@@ -398,7 +426,7 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Zoe (Depth) Processor",
|
||||
tags=["controlnet", "zoe", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies Zoe depth processing to image"""
|
||||
@@ -414,7 +442,7 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Mediapipe Face Processor",
|
||||
tags=["controlnet", "mediapipe", "face"],
|
||||
category="controlnet",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies mediapipe face processing to image"""
|
||||
@@ -423,6 +451,10 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection")
|
||||
|
||||
def run_processor(self, image):
|
||||
# MediaPipeFaceDetector throws an error if image has alpha channel
|
||||
# so convert to RGB if needed
|
||||
if image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
mediapipe_face_processor = MediapipeFaceDetector()
|
||||
processed_image = mediapipe_face_processor(image, max_faces=self.max_faces, min_confidence=self.min_confidence)
|
||||
return processed_image
|
||||
@@ -433,7 +465,7 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Leres (Depth) Processor",
|
||||
tags=["controlnet", "leres", "depth"],
|
||||
category="controlnet",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies leres processing to image"""
|
||||
@@ -462,7 +494,7 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Tile Resample Processor",
|
||||
tags=["controlnet", "tile"],
|
||||
category="controlnet",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Tile resampler processor"""
|
||||
@@ -502,7 +534,7 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation):
|
||||
title="Segment Anything Processor",
|
||||
tags=["controlnet", "segmentanything"],
|
||||
category="controlnet",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Applies segment anything processing to image"""
|
||||
@@ -544,7 +576,7 @@ class SamDetectorReproducibleColors(SamDetector):
|
||||
title="Color Map Processor",
|
||||
tags=["controlnet"],
|
||||
category="controlnet",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
"""Generates a color map from the provided image"""
|
||||
@@ -552,6 +584,7 @@ class ColorMapImageProcessorInvocation(ImageProcessorInvocation):
|
||||
color_map_tile_size: int = InputField(default=64, ge=0, description=FieldDescriptions.tile_size)
|
||||
|
||||
def run_processor(self, image: Image.Image):
|
||||
image = image.convert("RGB")
|
||||
np_image = np.array(image, dtype=np.uint8)
|
||||
height, width = np_image.shape[:2]
|
||||
|
||||
@@ -587,36 +620,12 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
    resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
    offload: bool = InputField(default=False)

    def run_processor(self, image: Image.Image):
    def run_processor(self, image):
        depth_anything_detector = DepthAnythingDetector()
        depth_anything_detector.load_model(model_size=self.model_size)

        if image.mode == "RGBA":
            image = image.convert("RGB")

        processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
        return processed_image


@invocation(
    "dw_openpose_image_processor",
    title="DW Openpose Image Processor",
    tags=["controlnet", "dwpose", "openpose"],
    category="controlnet",
    version="1.0.0",
)
class DWOpenposeImageProcessorInvocation(ImageProcessorInvocation):
    """Generates an openpose pose from an image using DWPose"""

    draw_body: bool = InputField(default=True)
    draw_face: bool = InputField(default=False)
    draw_hands: bool = InputField(default=False)
    image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res)

    def run_processor(self, image: Image.Image):
        dw_openpose = DWOpenposeDetector()
        processed_image = dw_openpose(
            image,
            draw_face=self.draw_face,
            draw_hands=self.draw_hands,
            draw_body=self.draw_body,
            resolution=self.image_resolution,
        )
        return processed_image

@@ -5,24 +5,22 @@ import cv2 as cv
import numpy
from PIL import Image, ImageOps

from invokeai.app.invocations.fields import ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin

from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithBoard, WithMetadata
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation


@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.1")
class CvInpaintInvocation(BaseInvocation, WithMetadata, WithBoard):
@invocation("cv_inpaint", title="OpenCV Inpaint", tags=["opencv", "inpaint"], category="inpaint", version="1.2.0")
class CvInpaintInvocation(BaseInvocation, WithMetadata):
    """Simple inpaint using opencv."""

    image: ImageField = InputField(description="The image to inpaint")
    mask: ImageField = InputField(description="The mask to use when inpainting")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name)
        mask = context.images.get_pil(self.mask.image_name)
        image = context.services.images.get_pil_image(self.image.image_name)
        mask = context.services.images.get_pil_image(self.mask.image_name)

        # Convert to cv image/mask
        # TODO: consider making these utility functions
@@ -36,6 +34,18 @@ class CvInpaintInvocation(BaseInvocation, WithMetadata, WithBoard):
        # TODO: consider making a utility function
        image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB))

        image_dto = context.images.save(image=image_inpainted)
        image_dto = context.services.images.create(
            image=image_inpainted,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
            workflow=context.workflow,
        )

        return ImageOutput.build(image_dto)
        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )

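# Illustrative sketch, not part of this compare: the cv_inpaint hunks above show the
# shape of the invocation-context refactor - nodes now call context.images.get_pil()
# and context.images.save() and build outputs with ImageOutput.build(), instead of
# reaching into context.services.images and filling in origin/category/session by hand.
# A hypothetical node written against the new surface might look like this:
from PIL import ImageOps

from invokeai.app.invocations.fields import ImageField, InputField, WithBoard, WithMetadata
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext

from .baseinvocation import BaseInvocation, invocation


@invocation("invert_demo", title="Invert (demo)", tags=["image"], category="image", version="1.0.0")
class InvertDemoInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Demo node: inverts an image using the refactored context helpers."""

    image: ImageField = InputField(description="The image to invert")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.images.get_pil(self.image.image_name).convert("RGB")
        inverted = ImageOps.invert(image)
        # Saving returns an image DTO; board and metadata handling come from the
        # WithBoard / WithMetadata mixins rather than per-call keyword arguments.
        image_dto = context.images.save(image=inverted)
        return ImageOutput.build(image_dto)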
@@ -13,13 +13,15 @@ from pydantic import field_validator
|
||||
import invokeai.assets.fonts as font_assets
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
WithMetadata,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import ImageField, InputField, OutputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
|
||||
|
||||
@invocation_output("face_mask_output")
|
||||
@@ -304,37 +306,37 @@ def extract_face(
|
||||
|
||||
# Adjust the crop boundaries to stay within the original image's dimensions
|
||||
if x_min < 0:
|
||||
context.logger.warning("FaceTools --> -X-axis padding reached image edge.")
|
||||
context.services.logger.warning("FaceTools --> -X-axis padding reached image edge.")
|
||||
x_max -= x_min
|
||||
x_min = 0
|
||||
elif x_max > mask.width:
|
||||
context.logger.warning("FaceTools --> +X-axis padding reached image edge.")
|
||||
context.services.logger.warning("FaceTools --> +X-axis padding reached image edge.")
|
||||
x_min -= x_max - mask.width
|
||||
x_max = mask.width
|
||||
|
||||
if y_min < 0:
|
||||
context.logger.warning("FaceTools --> +Y-axis padding reached image edge.")
|
||||
context.services.logger.warning("FaceTools --> +Y-axis padding reached image edge.")
|
||||
y_max -= y_min
|
||||
y_min = 0
|
||||
elif y_max > mask.height:
|
||||
context.logger.warning("FaceTools --> -Y-axis padding reached image edge.")
|
||||
context.services.logger.warning("FaceTools --> -Y-axis padding reached image edge.")
|
||||
y_min -= y_max - mask.height
|
||||
y_max = mask.height
|
||||
|
||||
# Ensure the crop is square and adjust the boundaries if needed
|
||||
if x_max - x_min != crop_size:
|
||||
context.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.")
|
||||
context.services.logger.warning("FaceTools --> Limiting x-axis padding to constrain bounding box to a square.")
|
||||
diff = crop_size - (x_max - x_min)
|
||||
x_min -= diff // 2
|
||||
x_max += diff - diff // 2
|
||||
|
||||
if y_max - y_min != crop_size:
|
||||
context.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.")
|
||||
context.services.logger.warning("FaceTools --> Limiting y-axis padding to constrain bounding box to a square.")
|
||||
diff = crop_size - (y_max - y_min)
|
||||
y_min -= diff // 2
|
||||
y_max += diff - diff // 2
|
||||
|
||||
context.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}")
|
||||
context.services.logger.info(f"FaceTools --> Calculated bounding box (8 multiple): {crop_size}")
|
||||
|
||||
# Crop the output image to the specified size with the center of the face mesh as the center.
|
||||
mask = mask.crop((x_min, y_min, x_max, y_max))
|
||||
@@ -366,7 +368,7 @@ def get_faces_list(
|
||||
|
||||
# Generate the face box mask and get the center of the face.
|
||||
if not should_chunk:
|
||||
context.logger.info("FaceTools --> Attempting full image face detection.")
|
||||
context.services.logger.info("FaceTools --> Attempting full image face detection.")
|
||||
result = generate_face_box_mask(
|
||||
context=context,
|
||||
minimum_confidence=minimum_confidence,
|
||||
@@ -378,7 +380,7 @@ def get_faces_list(
|
||||
draw_mesh=draw_mesh,
|
||||
)
|
||||
if should_chunk or len(result) == 0:
|
||||
context.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
|
||||
context.services.logger.info("FaceTools --> Chunking image (chunk toggled on, or no face found in full image).")
|
||||
width, height = image.size
|
||||
image_chunks = []
|
||||
x_offsets = []
|
||||
@@ -397,7 +399,7 @@ def get_faces_list(
|
||||
x_offsets.append(x)
|
||||
y_offsets.append(0)
|
||||
fx += increment
|
||||
context.logger.info(f"FaceTools --> Chunk starting at x = {x}")
|
||||
context.services.logger.info(f"FaceTools --> Chunk starting at x = {x}")
|
||||
elif height > width:
|
||||
# Portrait - slice the image vertically
|
||||
fy = 0.0
|
||||
@@ -409,10 +411,10 @@ def get_faces_list(
|
||||
x_offsets.append(0)
|
||||
y_offsets.append(y)
|
||||
fy += increment
|
||||
context.logger.info(f"FaceTools --> Chunk starting at y = {y}")
|
||||
context.services.logger.info(f"FaceTools --> Chunk starting at y = {y}")
|
||||
|
||||
for idx in range(len(image_chunks)):
|
||||
context.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}")
|
||||
context.services.logger.info(f"FaceTools --> Evaluating faces in chunk {idx}")
|
||||
result = result + generate_face_box_mask(
|
||||
context=context,
|
||||
minimum_confidence=minimum_confidence,
|
||||
@@ -426,7 +428,7 @@ def get_faces_list(
|
||||
|
||||
if len(result) == 0:
|
||||
# Give up
|
||||
context.logger.warning(
|
||||
context.services.logger.warning(
|
||||
"FaceTools --> No face detected in chunked input image. Passing through original image."
|
||||
)
|
||||
|
||||
@@ -435,7 +437,7 @@ def get_faces_list(
|
||||
return all_faces
|
||||
|
||||
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.1")
|
||||
@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.0")
|
||||
class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
"""Bound, extract, and mask a face from an image using MediaPipe detection"""
|
||||
|
||||
@@ -468,11 +470,11 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
)
|
||||
|
||||
if len(all_faces) == 0:
|
||||
context.logger.warning("FaceOff --> No faces detected. Passing through original image.")
|
||||
context.services.logger.warning("FaceOff --> No faces detected. Passing through original image.")
|
||||
return None
|
||||
|
||||
if self.face_id > len(all_faces) - 1:
|
||||
context.logger.warning(
|
||||
context.services.logger.warning(
|
||||
f"FaceOff --> Face ID {self.face_id} is outside of the number of faces detected ({len(all_faces)}). Passing through original image."
|
||||
)
|
||||
return None
|
||||
@@ -484,7 +486,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
return face_data
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FaceOffOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
result = self.faceoff(context=context, image=image)
|
||||
|
||||
if result is None:
|
||||
@@ -498,9 +500,24 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
x = result["x_min"]
|
||||
y = result["y_min"]
|
||||
|
||||
image_dto = context.images.save(image=result_image)
|
||||
image_dto = context.services.images.create(
|
||||
image=result_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
mask_dto = context.images.save(image=result_mask, image_category=ImageCategory.MASK)
|
||||
mask_dto = context.services.images.create(
|
||||
image=result_mask,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.MASK,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
output = FaceOffOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
@@ -514,7 +531,7 @@ class FaceOffInvocation(BaseInvocation, WithMetadata):
|
||||
return output
|
||||
|
||||
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.1")
|
||||
@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.0")
|
||||
class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
"""Face mask creation using mediapipe face detection"""
|
||||
|
||||
@@ -563,7 +580,7 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
if len(intersected_face_ids) == 0:
|
||||
id_range_str = ",".join([str(id) for id in id_range])
|
||||
context.logger.warning(
|
||||
context.services.logger.warning(
|
||||
f"Face IDs must be in range of detected faces - requested {self.face_ids}, detected {id_range_str}. Passing through original image."
|
||||
)
|
||||
return FaceMaskResult(
|
||||
@@ -599,12 +616,27 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FaceMaskOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
result = self.facemask(context=context, image=image)
|
||||
|
||||
image_dto = context.images.save(image=result["image"])
|
||||
image_dto = context.services.images.create(
|
||||
image=result["image"],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
mask_dto = context.images.save(image=result["mask"], image_category=ImageCategory.MASK)
|
||||
mask_dto = context.services.images.create(
|
||||
image=result["mask"],
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.MASK,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
)
|
||||
|
||||
output = FaceMaskOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
@@ -617,9 +649,9 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
|
||||
@invocation(
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.1"
|
||||
"face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.0"
|
||||
)
|
||||
class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class FaceIdentifierInvocation(BaseInvocation, WithMetadata):
|
||||
"""Outputs an image with detected face IDs printed on each face. For use with other FaceTools."""
|
||||
|
||||
image: ImageField = InputField(description="Image to face detect")
|
||||
@@ -673,9 +705,21 @@ class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
return image
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
result_image = self.faceidentifier(context=context, image=image)
|
||||
|
||||
image_dto = context.images.save(image=result_image)
|
||||
image_dto = context.services.images.create(
|
||||
image=result_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -1,566 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, TypeAdapter
|
||||
from pydantic.fields import _Unset
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class UIType(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
Type hints for the UI for situations in which the field type is not enough to infer the correct UI type.
|
||||
|
||||
- Model Fields
|
||||
The most common node-author-facing use will be for model fields. Internally, there is no difference
|
||||
between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the
|
||||
base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that
|
||||
the field is an SDXL main model field.
|
||||
|
||||
- Any Field
|
||||
We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to
|
||||
indicate that the field accepts any type. Use with caution. This cannot be used on outputs.
|
||||
|
||||
- Scheduler Field
|
||||
Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field.
|
||||
|
||||
- Internal Fields
|
||||
Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate
|
||||
handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These
|
||||
should not be used by node authors.
|
||||
|
||||
- DEPRECATED Fields
|
||||
These types are deprecated and should not be used by node authors. A warning will be logged if one is
|
||||
used, and the type will be ignored. They are included here for backwards compatibility.
|
||||
"""
|
||||
|
||||
# region Model Field Types
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VaeModel = "VAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
ControlNetModel = "ControlNetModelField"
|
||||
IPAdapterModel = "IPAdapterModelField"
|
||||
# endregion
|
||||
|
||||
# region Misc Field Types
|
||||
Scheduler = "SchedulerField"
|
||||
Any = "AnyField"
|
||||
# endregion
|
||||
|
||||
# region Internal Field Types
|
||||
_Collection = "CollectionField"
|
||||
_CollectionItem = "CollectionItemField"
|
||||
# endregion
|
||||
|
||||
# region DEPRECATED
|
||||
Boolean = "DEPRECATED_Boolean"
|
||||
Color = "DEPRECATED_Color"
|
||||
Conditioning = "DEPRECATED_Conditioning"
|
||||
Control = "DEPRECATED_Control"
|
||||
Float = "DEPRECATED_Float"
|
||||
Image = "DEPRECATED_Image"
|
||||
Integer = "DEPRECATED_Integer"
|
||||
Latents = "DEPRECATED_Latents"
|
||||
String = "DEPRECATED_String"
|
||||
BooleanCollection = "DEPRECATED_BooleanCollection"
|
||||
ColorCollection = "DEPRECATED_ColorCollection"
|
||||
ConditioningCollection = "DEPRECATED_ConditioningCollection"
|
||||
ControlCollection = "DEPRECATED_ControlCollection"
|
||||
FloatCollection = "DEPRECATED_FloatCollection"
|
||||
ImageCollection = "DEPRECATED_ImageCollection"
|
||||
IntegerCollection = "DEPRECATED_IntegerCollection"
|
||||
LatentsCollection = "DEPRECATED_LatentsCollection"
|
||||
StringCollection = "DEPRECATED_StringCollection"
|
||||
BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic"
|
||||
ColorPolymorphic = "DEPRECATED_ColorPolymorphic"
|
||||
ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic"
|
||||
ControlPolymorphic = "DEPRECATED_ControlPolymorphic"
|
||||
FloatPolymorphic = "DEPRECATED_FloatPolymorphic"
|
||||
ImagePolymorphic = "DEPRECATED_ImagePolymorphic"
|
||||
IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic"
|
||||
LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic"
|
||||
StringPolymorphic = "DEPRECATED_StringPolymorphic"
|
||||
MainModel = "DEPRECATED_MainModel"
|
||||
UNet = "DEPRECATED_UNet"
|
||||
Vae = "DEPRECATED_Vae"
|
||||
CLIP = "DEPRECATED_CLIP"
|
||||
Collection = "DEPRECATED_Collection"
|
||||
CollectionItem = "DEPRECATED_CollectionItem"
|
||||
Enum = "DEPRECATED_Enum"
|
||||
WorkflowField = "DEPRECATED_WorkflowField"
|
||||
IsIntermediate = "DEPRECATED_IsIntermediate"
|
||||
BoardField = "DEPRECATED_BoardField"
|
||||
MetadataItem = "DEPRECATED_MetadataItem"
|
||||
MetadataItemCollection = "DEPRECATED_MetadataItemCollection"
|
||||
MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic"
|
||||
MetadataDict = "DEPRECATED_MetadataDict"
|
||||
|
||||
|
||||
class UIComponent(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The type of UI component to use for a field, used to override the default components, which are
|
||||
inferred from the field type.
|
||||
"""
|
||||
|
||||
None_ = "none"
|
||||
Textarea = "textarea"
|
||||
Slider = "slider"
|
||||
|
||||
|
||||
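# Illustrative sketch, not part of this file: a node author attaches these hints via
# the InputField() wrapper defined further down. Without ui_type the SDXL field below
# would get the generic main-model picker; without ui_component the prompt would be a
# single-line input. The node and field names here are hypothetical.
class StylePromptInvocation(BaseInvocation):
    model: MainModelField = InputField(
        description="SDXL main model",
        ui_type=UIType.SDXLMainModel,  # render the SDXL-specific model dropdown
    )
    prompt: str = InputField(
        description="Style prompt",
        ui_component=UIComponent.Textarea,  # multi-line textarea instead of the default
    )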
class FieldDescriptions:
|
||||
denoising_start = "When to start denoising, expressed a percentage of total steps"
|
||||
denoising_end = "When to stop denoising, expressed a percentage of total steps"
|
||||
cfg_scale = "Classifier-Free Guidance scale"
|
||||
cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR"
|
||||
scheduler = "Scheduler to use during inference"
|
||||
positive_cond = "Positive conditioning tensor"
|
||||
negative_cond = "Negative conditioning tensor"
|
||||
noise = "Noise tensor"
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
vae = "VAE"
|
||||
cond = "Conditioning tensor"
|
||||
controlnet_model = "ControlNet model to load"
|
||||
vae_model = "VAE model to load"
|
||||
lora_model = "LoRA model to load"
|
||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||
raw_prompt = "Raw prompt text (no parsing)"
|
||||
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
|
||||
skipped_layers = "Number of layers to skip in text encoder"
|
||||
seed = "Seed for random number generation"
|
||||
steps = "Number of steps to run"
|
||||
width = "Width of output (px)"
|
||||
height = "Height of output (px)"
|
||||
control = "ControlNet(s) to apply"
|
||||
ip_adapter = "IP-Adapter to apply"
|
||||
t2i_adapter = "T2I-Adapter(s) to apply"
|
||||
denoised_latents = "Denoised latents tensor"
|
||||
latents = "Latents tensor"
|
||||
strength = "Strength of denoising (proportional to steps)"
|
||||
metadata = "Optional metadata to be saved with the image"
|
||||
metadata_collection = "Collection of Metadata"
|
||||
metadata_item_polymorphic = "A single metadata item or collection of metadata items"
|
||||
metadata_item_label = "Label for this metadata item"
|
||||
metadata_item_value = "The value for this metadata item (may be any type)"
|
||||
workflow = "Optional workflow to be saved with the image"
|
||||
interp_mode = "Interpolation mode"
|
||||
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||
fp32 = "Whether or not to use full float32 precision"
|
||||
precision = "Precision to use"
|
||||
tiled = "Processing using overlapping tiles (reduce memory consumption)"
|
||||
detect_res = "Pixel resolution for detection"
|
||||
image_res = "Pixel resolution for output image"
|
||||
safe_mode = "Whether or not to use safe mode"
|
||||
scribble_mode = "Whether or not to use scribble mode"
|
||||
scale_factor = "The factor by which to scale"
|
||||
blend_alpha = (
|
||||
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
|
||||
)
|
||||
num_1 = "The first number"
|
||||
num_2 = "The second number"
|
||||
mask = "The mask to use for the operation"
|
||||
board = "The board to save the image to"
|
||||
image = "The image to process"
|
||||
tile_size = "Tile size"
|
||||
inclusive_low = "The inclusive low value"
|
||||
exclusive_high = "The exclusive high value"
|
||||
decimal_places = "The number of decimal places to round to"
|
||||
freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
|
||||
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
|
||||
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
|
||||
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image primitive field"""
|
||||
|
||||
image_name: str = Field(description="The name of the image")
|
||||
|
||||
|
||||
class BoardField(BaseModel):
|
||||
"""A board primitive field"""
|
||||
|
||||
board_id: str = Field(description="The id of the board")
|
||||
|
||||
|
||||
class DenoiseMaskField(BaseModel):
|
||||
"""An inpaint mask field"""
|
||||
|
||||
mask_name: str = Field(description="The name of the mask image")
|
||||
masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
|
||||
gradient: bool = Field(default=False, description="Used for gradient inpainting")
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents tensor primitive field"""
|
||||
|
||||
latents_name: str = Field(description="The name of the latents")
|
||||
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
"""A color primitive field"""
|
||||
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
# endregion
|
||||
|
||||
|
||||
class MetadataField(RootModel):
|
||||
"""
|
||||
Pydantic model for metadata with custom root of type dict[str, Any].
|
||||
Metadata is stored without a strict schema.
|
||||
"""
|
||||
|
||||
root: dict[str, Any] = Field(description="The metadata")
|
||||
|
||||
|
||||
MetadataFieldValidator = TypeAdapter(MetadataField)
|
||||
|
||||
|
||||
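# Illustrative sketch, not part of this file: metadata has no fixed schema, so the
# TypeAdapter above is the validation entry point for arbitrary dicts. The keys shown
# here are hypothetical.
meta = MetadataFieldValidator.validate_python({"generation_mode": "txt2img", "seed": 123})
assert meta.root["seed"] == 123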
class Input(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The type of input a field accepts.
|
||||
- `Input.Direct`: The field must have its value provided directly, when the invocation and field \
|
||||
are instantiated.
|
||||
- `Input.Connection`: The field must have its value provided by a connection.
|
||||
- `Input.Any`: The field may have its value provided either directly or by a connection.
|
||||
"""
|
||||
|
||||
Connection = "connection"
|
||||
Direct = "direct"
|
||||
Any = "any"
|
||||
|
||||
|
||||
class FieldKind(str, Enum, metaclass=MetaEnum):
|
||||
"""
|
||||
The kind of field.
|
||||
- `Input`: An input field on a node.
|
||||
- `Output`: An output field on a node.
|
||||
- `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is
|
||||
one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name
|
||||
"metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic,
|
||||
allowing "metadata" for that field.
|
||||
- `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs,
|
||||
but which are used to store information about the node. For example, the `id` and `type` fields are node
|
||||
attributes.
|
||||
|
||||
The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app
|
||||
startup, and when generating the OpenAPI schema for the workflow editor.
|
||||
"""
|
||||
|
||||
Input = "input"
|
||||
Output = "output"
|
||||
Internal = "internal"
|
||||
NodeAttribute = "node_attribute"
|
||||
|
||||
|
||||
class InputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution,
|
||||
and by the workflow editor during schema parsing and UI rendering.
|
||||
"""
|
||||
|
||||
input: Input
|
||||
orig_required: bool
|
||||
field_kind: FieldKind
|
||||
default: Optional[Any] = None
|
||||
orig_default: Optional[Any] = None
|
||||
ui_hidden: bool = False
|
||||
ui_type: Optional[UIType] = None
|
||||
ui_component: Optional[UIComponent] = None
|
||||
ui_order: Optional[int] = None
|
||||
ui_choice_labels: Optional[dict[str, str]] = None
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
class WithMetadata(BaseModel):
|
||||
"""
|
||||
Inherit from this class if your node needs a metadata input field.
|
||||
"""
|
||||
|
||||
metadata: Optional[MetadataField] = Field(
|
||||
default=None,
|
||||
description=FieldDescriptions.metadata,
|
||||
json_schema_extra=InputFieldJSONSchemaExtra(
|
||||
field_kind=FieldKind.Internal,
|
||||
input=Input.Connection,
|
||||
orig_required=False,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
class WithWorkflow:
|
||||
workflow = None
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
logger.warn(
|
||||
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
|
||||
)
|
||||
super().__init_subclass__()
|
||||
|
||||
|
||||
class WithBoard(BaseModel):
|
||||
"""
|
||||
Inherit from this class if your node needs a board input field.
|
||||
"""
|
||||
|
||||
board: Optional[BoardField] = Field(
|
||||
default=None,
|
||||
description=FieldDescriptions.board,
|
||||
json_schema_extra=InputFieldJSONSchemaExtra(
|
||||
field_kind=FieldKind.Internal,
|
||||
input=Input.Direct,
|
||||
orig_required=False,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
|
||||
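# Illustrative sketch, not part of this file: nodes opt into these extra inputs simply
# by inheriting the mixins, as the invocation diffs elsewhere in this compare do
# (e.g. CvInpaintInvocation(BaseInvocation, WithMetadata, WithBoard)). Hypothetical node:
@invocation("save_demo", title="Save (demo)", tags=["image"], category="image", version="1.0.0")
class SaveDemoInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Demo node: gains optional `metadata` and `board` inputs from the mixins above."""

    image: ImageField = InputField(description="The image to save")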
class OutputFieldJSONSchemaExtra(BaseModel):
|
||||
"""
|
||||
Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor
|
||||
during schema parsing and UI rendering.
|
||||
"""
|
||||
|
||||
field_kind: FieldKind
|
||||
ui_hidden: bool
|
||||
ui_type: Optional[UIType]
|
||||
ui_order: Optional[int]
|
||||
|
||||
model_config = ConfigDict(
|
||||
validate_assignment=True,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
)
|
||||
|
||||
|
||||
def InputField(
|
||||
# copied from pydantic's Field
|
||||
# TODO: Can we support default_factory?
|
||||
default: Any = _Unset,
|
||||
default_factory: Callable[[], Any] | None = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
strict: bool | None = _Unset,
|
||||
gt: float | None = _Unset,
|
||||
ge: float | None = _Unset,
|
||||
lt: float | None = _Unset,
|
||||
le: float | None = _Unset,
|
||||
multiple_of: float | None = _Unset,
|
||||
allow_inf_nan: bool | None = _Unset,
|
||||
max_digits: int | None = _Unset,
|
||||
decimal_places: int | None = _Unset,
|
||||
min_length: int | None = _Unset,
|
||||
max_length: int | None = _Unset,
|
||||
# custom
|
||||
input: Input = Input.Any,
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_component: Optional[UIComponent] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
ui_choice_labels: Optional[dict[str, str]] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an input field for an invocation.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param Input input: [Input.Any] The kind of input this field requires. \
|
||||
`Input.Direct` means a value must be provided on instantiation. \
|
||||
`Input.Connection` means the value must be provided by a connection. \
|
||||
`Input.Any` means either will do.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
:param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \
|
||||
The UI will always render a suitable component, but sometimes you want something different than the default. \
|
||||
For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \
|
||||
For this case, you could provide `UIComponent.Textarea`.
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI.
|
||||
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI.
|
||||
|
||||
:param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field.
|
||||
"""
|
||||
|
||||
json_schema_extra_ = InputFieldJSONSchemaExtra(
|
||||
input=input,
|
||||
ui_type=ui_type,
|
||||
ui_component=ui_component,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
ui_choice_labels=ui_choice_labels,
|
||||
field_kind=FieldKind.Input,
|
||||
orig_required=True,
|
||||
)
|
||||
|
||||
"""
|
||||
There is a conflict between the typing of invocation definitions and the typing of an invocation's
|
||||
`invoke()` function.
|
||||
|
||||
On instantiation of a node, the invocation definition is used to create the python class. At this time,
|
||||
any number of fields may be optional, because they may be provided by connections.
|
||||
|
||||
On calling of `invoke()`, however, those fields may be required.
|
||||
|
||||
For example, consider an ResizeImageInvocation with an `image: ImageField` field.
|
||||
|
||||
`image` is required during the call to `invoke()`, but when the python class is instantiated,
|
||||
the field may not be present. This is fine, because that image field will be provided by a
|
||||
connection from an ancestor node, which outputs an image.
|
||||
|
||||
This means we want to type the `image` field as optional for the node class definition, but required
|
||||
for the `invoke()` function.
|
||||
|
||||
If we use `typing.Optional` in the node class definition, the field will be typed as optional in the
|
||||
`invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or
|
||||
any static type analysis tools will complain.
|
||||
|
||||
To get around this, in node class definitions, we type all fields correctly for the `invoke()` function,
|
||||
but secretly make them optional in `InputField()`. We also store the original required bool and/or default
|
||||
value. When we call `invoke()`, we use this stored information to do an additional check on the class.
|
||||
"""
|
||||
|
||||
if default_factory is not _Unset and default_factory is not None:
|
||||
default = default_factory()
|
||||
logger.warn('"default_factory" is not supported, calling it now to set "default"')
|
||||
|
||||
# These are the args we may wish pass to the pydantic `Field()` function
|
||||
field_args = {
|
||||
"default": default,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"pattern": pattern,
|
||||
"strict": strict,
|
||||
"gt": gt,
|
||||
"ge": ge,
|
||||
"lt": lt,
|
||||
"le": le,
|
||||
"multiple_of": multiple_of,
|
||||
"allow_inf_nan": allow_inf_nan,
|
||||
"max_digits": max_digits,
|
||||
"decimal_places": decimal_places,
|
||||
"min_length": min_length,
|
||||
"max_length": max_length,
|
||||
}
|
||||
|
||||
# We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected
|
||||
provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined}
|
||||
|
||||
# Because we are manually making fields optional, we need to store the original required bool for reference later
|
||||
json_schema_extra_.orig_required = default is PydanticUndefined
|
||||
|
||||
# Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one
|
||||
if input is Input.Any or input is Input.Connection:
|
||||
default_ = None if default is PydanticUndefined else default
|
||||
provided_args.update({"default": default_})
|
||||
if default is not PydanticUndefined:
|
||||
# Before invoking, we'll check for the original default value and set it on the field if the field has no value
|
||||
json_schema_extra_.default = default
|
||||
json_schema_extra_.orig_default = default
|
||||
elif default is not PydanticUndefined:
|
||||
default_ = default
|
||||
provided_args.update({"default": default_})
|
||||
json_schema_extra_.orig_default = default_
|
||||
|
||||
return Field(
|
||||
**provided_args,
|
||||
json_schema_extra=json_schema_extra_.model_dump(exclude_none=True),
|
||||
)
|
||||
|
||||
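# Illustrative sketch of the pattern described in the comment block above, using the
# ResizeImageInvocation example it mentions (hypothetical node, not part of this file):
class ResizeImageInvocation(BaseInvocation):
    # Annotated as a plain ImageField so invoke() needs no None checks, but
    # InputField(input=Input.Connection) defaults it to None at instantiation time
    # and records orig_required=True for the pre-invoke validation pass.
    image: ImageField = InputField(description="Image to resize", input=Input.Connection)
    width: int = InputField(default=512, ge=64, description="Target width (px)")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        # By the time invoke() runs, the executor has populated `image` from the
        # upstream connection, so it can be used directly.
        ...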
|
||||
def OutputField(
|
||||
# copied from pydantic's Field
|
||||
default: Any = _Unset,
|
||||
title: str | None = _Unset,
|
||||
description: str | None = _Unset,
|
||||
pattern: str | None = _Unset,
|
||||
strict: bool | None = _Unset,
|
||||
gt: float | None = _Unset,
|
||||
ge: float | None = _Unset,
|
||||
lt: float | None = _Unset,
|
||||
le: float | None = _Unset,
|
||||
multiple_of: float | None = _Unset,
|
||||
allow_inf_nan: bool | None = _Unset,
|
||||
max_digits: int | None = _Unset,
|
||||
decimal_places: int | None = _Unset,
|
||||
min_length: int | None = _Unset,
|
||||
max_length: int | None = _Unset,
|
||||
# custom
|
||||
ui_type: Optional[UIType] = None,
|
||||
ui_hidden: bool = False,
|
||||
ui_order: Optional[int] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Creates an output field for an invocation output.
|
||||
|
||||
This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \
|
||||
that adds a few extra parameters to support graph execution and the node editor UI.
|
||||
|
||||
:param UIType ui_type: [None] Optionally provides an extra type hint for the UI. \
|
||||
In some situations, the field's type is not enough to infer the correct UI type. \
|
||||
For example, model selection fields should render a dropdown UI component to select a model. \
|
||||
Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \
|
||||
`MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \
|
||||
`UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field.
|
||||
|
||||
:param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \
|
||||
|
||||
:param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \
|
||||
"""
|
||||
return Field(
|
||||
default=default,
|
||||
title=title,
|
||||
description=description,
|
||||
pattern=pattern,
|
||||
strict=strict,
|
||||
gt=gt,
|
||||
ge=ge,
|
||||
lt=lt,
|
||||
le=le,
|
||||
multiple_of=multiple_of,
|
||||
allow_inf_nan=allow_inf_nan,
|
||||
max_digits=max_digits,
|
||||
decimal_places=decimal_places,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
json_schema_extra=OutputFieldJSONSchemaExtra(
|
||||
ui_type=ui_type,
|
||||
ui_hidden=ui_hidden,
|
||||
ui_order=ui_order,
|
||||
field_kind=FieldKind.Output,
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
File diff suppressed because it is too large
@@ -6,16 +6,14 @@ from typing import Literal, Optional, get_args
|
||||
import numpy as np
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from invokeai.app.invocations.fields import ColorField, ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
|
||||
from invokeai.backend.image_util.lama import LaMA
|
||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, WithBoard, WithMetadata
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
|
||||
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES
|
||||
|
||||
|
||||
@@ -120,8 +118,8 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int]
|
||||
return si
|
||||
|
||||
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||
class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
|
||||
class InfillColorInvocation(BaseInvocation, WithMetadata):
|
||||
"""Infills transparent areas of an image with a solid color"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@@ -131,20 +129,33 @@ class InfillColorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
|
||||
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
|
||||
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
|
||||
class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||
class InfillTileInvocation(BaseInvocation, WithMetadata):
|
||||
"""Infills transparent areas of an image with tiles of the image"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@@ -157,20 +168,33 @@ class InfillTileInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
infilled = tile_fill_missing(image.copy(), seed=self.seed, tile_size=self.tile_size)
|
||||
infilled.paste(image, (0, 0), image.split()[-1])
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1"
|
||||
"infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0"
|
||||
)
|
||||
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
|
||||
"""Infills transparent areas of an image using the PatchMatch algorithm"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
@@ -178,7 +202,7 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name).convert("RGBA")
|
||||
image = context.services.images.get_pil_image(self.image.image_name).convert("RGBA")
|
||||
|
||||
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
|
||||
|
||||
@@ -203,38 +227,77 @@ class InfillPatchMatchInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
infilled.paste(image, (0, 0), mask=image.split()[-1])
|
||||
# image.paste(infilled, (0, 0), mask=image.split()[-1])
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||
class LaMaInfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
|
||||
class LaMaInfillInvocation(BaseInvocation, WithMetadata):
|
||||
"""Infills transparent areas of an image using the LaMa model"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
infilled = infill_lama(image.copy())
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
|
||||
class CV2InfillInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.0")
|
||||
class CV2InfillInvocation(BaseInvocation, WithMetadata):
|
||||
"""Infills transparent areas of an image using OpenCV Inpainting"""
|
||||
|
||||
image: ImageField = InputField(description="The image to infill")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
infilled = infill_cv2(image.copy())
|
||||
|
||||
image_dto = context.images.save(image=infilled)
|
||||
image_dto = context.services.images.create(
|
||||
image=infilled,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -7,14 +7,17 @@ from typing_extensions import Self
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType
|
||||
|
||||
|
||||
# LS: Consider moving these two classes into model.py
|
||||
@@ -56,7 +59,7 @@ class IPAdapterOutput(BaseInvocationOutput):
|
||||
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
|
||||
|
||||
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.2")
|
||||
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.1.1")
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
"""Collects IP-Adapter info to pass to other nodes."""
|
||||
|
||||
@@ -89,10 +92,10 @@ class IPAdapterInvocation(BaseInvocation):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
|
||||
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
|
||||
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
|
||||
ip_adapter_info = context.services.model_manager.store.get_model(self.ip_adapter_model.key)
|
||||
image_encoder_model_id = ip_adapter_info.image_encoder_model_id
|
||||
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
|
||||
image_encoder_models = context.models.search_by_attrs(
|
||||
image_encoder_models = context.services.model_manager.store.search_by_attr(
|
||||
model_name=image_encoder_model_name, base_model=BaseModelType.Any, model_type=ModelType.CLIPVision
|
||||
)
|
||||
assert len(image_encoder_models) == 1
|
||||
|
||||
@@ -23,33 +23,25 @@ from diffusers.models.attention_processor import (
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers import DPMSolverSDEScheduler
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from PIL import Image, ImageFilter
|
||||
from PIL import Image
|
||||
from pydantic import field_validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
OutputField,
|
||||
UIType,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||
from invokeai.app.invocations.primitives import (
|
||||
DenoiseMaskField,
|
||||
DenoiseMaskOutput,
|
||||
ImageField,
|
||||
ImageOutput,
|
||||
LatentsField,
|
||||
LatentsOutput,
|
||||
build_latents_output,
|
||||
)
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||
@@ -71,9 +63,16 @@ from ...backend.util.devices import choose_precision, choose_torch_device
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
WithMetadata,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .compel import ConditioningField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
|
||||
@@ -82,10 +81,20 @@ if choose_torch_device() == torch.device("mps"):

DEFAULT_PRECISION = choose_precision(choose_torch_device())

SAMPLER_NAME_VALUES = Literal[
    tuple(SCHEDULER_MAP.keys())
] # FIXME: "Invalid type alias". This defeats static type checking.

# HACK: Many nodes are currently hard-coded to use a fixed latent scale factor of 8. This is fragile, and will need to
# be addressed if future models use a different latent scale factor. Also, note that there may be places where the scale
# factor is hard-coded to a literal '8' rather than using this constant.
# The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
LATENT_SCALE_FACTOR = 8


@invocation_output("scheduler_output")
|
||||
class SchedulerOutput(BaseInvocationOutput):
|
||||
scheduler: SCHEDULER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
||||
scheduler: SAMPLER_NAME_VALUES = OutputField(description=FieldDescriptions.scheduler, ui_type=UIType.Scheduler)
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -98,7 +107,7 @@ class SchedulerOutput(BaseInvocationOutput):
|
||||
class SchedulerInvocation(BaseInvocation):
|
||||
"""Selects a scheduler."""
|
||||
|
||||
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
description=FieldDescriptions.scheduler,
|
||||
ui_type=UIType.Scheduler,
|
||||
@@ -113,7 +122,7 @@ class SchedulerInvocation(BaseInvocation):
|
||||
title="Create Denoise Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
@@ -128,7 +137,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
ui_order=4,
|
||||
)
|
||||
|
||||
def prep_mask_tensor(self, mask_image: Image.Image) -> torch.Tensor:
|
||||
def prep_mask_tensor(self, mask_image: Image) -> torch.Tensor:
|
||||
if mask_image.mode != "L":
|
||||
mask_image = mask_image.convert("L")
|
||||
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
@@ -141,7 +150,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
if self.image is not None:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
image_tensor = image_tensor.unsqueeze(0)
|
||||
@@ -149,82 +158,33 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
|
||||
image_tensor = None
|
||||
|
||||
mask = self.prep_mask_tensor(
|
||||
context.images.get_pil(self.mask.image_name),
|
||||
context.services.images.get_pil_image(self.mask.image_name),
|
||||
)
|
||||
|
||||
if image_tensor is not None:
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
vae_info = context.services.model_manager.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
|
||||
# TODO:
|
||||
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
|
||||
|
||||
masked_latents_name = context.tensors.save(tensor=masked_latents)
|
||||
masked_latents_name = f"{context.graph_execution_state_id}__{self.id}_masked_latents"
|
||||
context.services.latents.save(masked_latents_name, masked_latents)
|
||||
else:
|
||||
masked_latents_name = None
|
||||
|
||||
mask_name = context.tensors.save(tensor=mask)
|
||||
mask_name = f"{context.graph_execution_state_id}__{self.id}_mask"
|
||||
context.services.latents.save(mask_name, mask)
|
||||
|
||||
return DenoiseMaskOutput.build(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=masked_latents_name,
|
||||
gradient=False,
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"create_gradient_mask",
|
||||
title="Create Gradient Mask",
|
||||
tags=["mask", "denoise"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class CreateGradientMaskInvocation(BaseInvocation):
|
||||
"""Creates mask for denoising model run."""
|
||||
|
||||
mask: ImageField = InputField(default=None, description="Image which will be masked", ui_order=1)
|
||||
edge_radius: int = InputField(
|
||||
default=16, ge=0, description="How far to blur/expand the edges of the mask", ui_order=2
|
||||
)
|
||||
coherence_mode: Literal["Gaussian Blur", "Box Blur", "Staged"] = InputField(default="Gaussian Blur", ui_order=3)
|
||||
minimum_denoise: float = InputField(
|
||||
default=0.0, ge=0, le=1, description="Minimum denoise level for the coherence region", ui_order=4
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
|
||||
mask_image = context.images.get_pil(self.mask.image_name, mode="L")
|
||||
if self.coherence_mode == "Box Blur":
|
||||
blur_mask = mask_image.filter(ImageFilter.BoxBlur(self.edge_radius))
|
||||
else: # Gaussian Blur OR Staged
|
||||
# Gaussian Blur uses standard deviation. 1/2 radius is a good approximation
|
||||
blur_mask = mask_image.filter(ImageFilter.GaussianBlur(self.edge_radius / 2))
|
||||
|
||||
mask_tensor: torch.Tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
|
||||
blur_tensor: torch.Tensor = image_resized_to_grid_as_tensor(blur_mask, normalize=False)
|
||||
|
||||
# redistribute blur so that the edges are 0 and blur out to 1
|
||||
blur_tensor = (blur_tensor - 0.5) * 2
|
||||
|
||||
threshold = 1 - self.minimum_denoise
|
||||
|
||||
if self.coherence_mode == "Staged":
|
||||
# wherever the blur_tensor is masked to any degree, convert it to threshold
|
||||
blur_tensor = torch.where((blur_tensor < 1), threshold, blur_tensor)
|
||||
else:
|
||||
# wherever the blur_tensor is above threshold but less than 1, drop it to threshold
|
||||
blur_tensor = torch.where((blur_tensor > threshold) & (blur_tensor < 1), threshold, blur_tensor)
|
||||
|
||||
# multiply original mask to force actually masked regions to 0
|
||||
blur_tensor = mask_tensor * blur_tensor
|
||||
|
||||
mask_name = context.tensors.save(tensor=blur_tensor.unsqueeze(1))
|
||||
|
||||
return DenoiseMaskOutput.build(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=None,
|
||||
gradient=True,
|
||||
return DenoiseMaskOutput(
|
||||
denoise_mask=DenoiseMaskField(
|
||||
mask_name=mask_name,
|
||||
masked_latents_name=masked_latents_name,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -235,7 +195,10 @@ def get_scheduler(
|
||||
seed: int,
|
||||
) -> Scheduler:
|
||||
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
|
||||
orig_scheduler_info = context.models.load(**scheduler_info.model_dump())
|
||||
orig_scheduler_info = context.services.model_manager.load_model_by_key(
|
||||
**scheduler_info.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
with orig_scheduler_info as orig_scheduler:
|
||||
scheduler_config = orig_scheduler.config
|
||||
|
||||
@@ -265,7 +228,7 @@ def get_scheduler(
|
||||
title="Denoise Latents",
|
||||
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
|
||||
category="latents",
|
||||
version="1.5.2",
|
||||
version="1.5.1",
|
||||
)
|
||||
class DenoiseLatentsInvocation(BaseInvocation):
|
||||
"""Denoises noisy latents to decodable images"""
|
||||
@@ -293,7 +256,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||
scheduler: SAMPLER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
description=FieldDescriptions.scheduler,
|
||||
ui_type=UIType.Scheduler,
|
||||
@@ -351,6 +314,22 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
# TODO: pass this an emitter method or something? or a session for dispatching?
|
||||
def dispatch_progress(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
source_node_id: str,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
base_model: BaseModelType,
|
||||
) -> None:
|
||||
stable_diffusion_step_callback(
|
||||
context=context,
|
||||
intermediate_state=intermediate_state,
|
||||
node=self.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
base_model=base_model,
|
||||
)
|
||||
|
||||
def get_conditioning_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
@@ -358,11 +337,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
unet: UNet2DConditionModel,
|
||||
seed: int,
|
||||
) -> ConditioningData:
|
||||
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
extra_conditioning_info = c.extra_conditioning
|
||||
|
||||
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||
|
||||
conditioning_data = ConditioningData(
|
||||
@@ -449,11 +428,16 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# and if weight is None, populate with default 1.0?
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(context.models.load(key=control_info.control_model.key))
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.load_model_by_key(
|
||||
key=control_info.control_model.key,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
# control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.images.get_pil(control_image_field.image_name)
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
@@ -511,17 +495,25 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
conditioning_data.ip_adapter_conditioning = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.models.load(key=single_ip_adapter.ip_adapter_model.key)
|
||||
context.services.model_manager.load_model_by_key(
|
||||
key=single_ip_adapter.ip_adapter_model.key,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
image_encoder_model_info = context.models.load(key=single_ip_adapter.image_encoder_model.key)
|
||||
image_encoder_model_info = context.services.model_manager.load_model_by_key(
|
||||
key=single_ip_adapter.image_encoder_model.key,
|
||||
context=context,
|
||||
)
|
||||
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
single_ipa_image_fields = [single_ipa_image_fields]
|
||||
|
||||
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
|
||||
single_ipa_images = [
|
||||
context.services.images.get_pil_image(image.image_name) for image in single_ipa_image_fields
|
||||
]
|
||||
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
@@ -565,20 +557,22 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
t2i_adapter_data = []
|
||||
for t2i_adapter_field in t2i_adapter:
|
||||
t2i_adapter_model_config = context.models.get_config(key=t2i_adapter_field.t2i_adapter_model.key)
|
||||
t2i_adapter_loaded_model = context.models.load(key=t2i_adapter_field.t2i_adapter_model.key)
|
||||
image = context.images.get_pil(t2i_adapter_field.image.image_name)
|
||||
t2i_adapter_model_info = context.services.model_manager.load_model_by_key(
|
||||
key=t2i_adapter_field.t2i_adapter_model.key,
|
||||
context=context,
|
||||
)
|
||||
image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
|
||||
|
||||
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
|
||||
if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1:
|
||||
if t2i_adapter_model_info.base == BaseModelType.StableDiffusion1:
|
||||
max_unet_downscale = 8
|
||||
elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL:
|
||||
elif t2i_adapter_model_info.base == BaseModelType.StableDiffusionXL:
|
||||
max_unet_downscale = 4
|
||||
else:
|
||||
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
|
||||
raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_info.base}'.")
|
||||
|
||||
t2i_adapter_model: T2IAdapter
|
||||
with t2i_adapter_loaded_model as t2i_adapter_model:
|
||||
with t2i_adapter_model_info as t2i_adapter_model:
|
||||
total_downscale_factor = t2i_adapter_model.total_downscale_factor
|
||||
|
||||
# Resize the T2I-Adapter input image.
|
||||
@@ -662,18 +656,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
def prep_inpaint_mask(
|
||||
self, context: InvocationContext, latents: torch.Tensor
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], bool]:
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
if self.denoise_mask is None:
|
||||
return None, None, False
|
||||
return None, None
|
||||
|
||||
mask = context.tensors.load(self.denoise_mask.mask_name)
|
||||
mask = context.services.latents.get(self.denoise_mask.mask_name)
|
||||
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
|
||||
if self.denoise_mask.masked_latents_name is not None:
|
||||
masked_latents = context.tensors.load(self.denoise_mask.masked_latents_name)
|
||||
masked_latents = context.services.latents.get(self.denoise_mask.masked_latents_name)
|
||||
else:
|
||||
masked_latents = None
|
||||
|
||||
return 1 - mask, masked_latents, self.denoise_mask.gradient
|
||||
return 1 - mask, masked_latents
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
@@ -681,11 +675,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed = None
|
||||
noise = None
|
||||
if self.noise is not None:
|
||||
noise = context.tensors.load(self.noise.latents_name)
|
||||
noise = context.services.latents.get(self.noise.latents_name)
|
||||
seed = self.noise.seed
|
||||
|
||||
if self.latents is not None:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
if seed is None:
|
||||
seed = self.latents.seed
|
||||
|
||||
@@ -700,7 +694,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if seed is None:
|
||||
seed = 0
|
||||
|
||||
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
|
||||
mask, masked_latents = self.prep_inpaint_mask(context, latents)
|
||||
|
||||
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
|
||||
# below. Investigate whether this is appropriate.
|
||||
@@ -711,20 +705,30 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
# Get the source node id (we are invoking the prepared node)
|
||||
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
|
||||
|
||||
# get the unet's config so that we can pass the base to dispatch_progress()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
unet_config = context.services.model_manager.store.get_model(self.unet.unet.key)
|
||||
|
||||
def step_callback(state: PipelineIntermediateState) -> None:
|
||||
context.util.sd_step_callback(state, unet_config.base)
|
||||
self.dispatch_progress(context, source_node_id, state, unet_config.base)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(**lora.model_dump(exclude={"weight"}))
|
||||
lora_info = context.services.model_manager.load_model_by_key(
|
||||
**lora.model_dump(exclude={"weight"}),
|
||||
context=context,
|
||||
)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
return
|
||||
|
||||
unet_info = context.models.load(**self.unet.unet.model_dump())
|
||||
unet_info = context.services.model_manager.load_model_by_key(
|
||||
**self.unet.unet.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
@@ -788,7 +792,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
seed=seed,
|
||||
mask=mask,
|
||||
masked_latents=masked_latents,
|
||||
gradient_mask=gradient_mask,
|
||||
num_inference_steps=num_inference_steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=controlnet_data,
|
||||
@@ -803,8 +806,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=result_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=seed)
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, result_latents)
|
||||
return build_latents_output(latents_name=name, latents=result_latents, seed=seed)
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -812,9 +816,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
title="Latents to Image",
|
||||
tags=["latents", "image", "vae", "l2i"],
|
||||
category="latents",
|
||||
version="1.2.1",
|
||||
version="1.2.0",
|
||||
)
|
||||
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class LatentsToImageInvocation(BaseInvocation, WithMetadata):
|
||||
"""Generates an image from latents."""
|
||||
|
||||
latents: LatentsField = InputField(
|
||||
@@ -830,9 +834,12 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
vae_info = context.services.model_manager.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
|
||||
assert isinstance(vae, torch.nn.Module)
|
||||
@@ -862,7 +869,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
vae.to(dtype=torch.float16)
|
||||
latents = latents.half()
|
||||
|
||||
if self.tiled or context.config.get().tiled_decode:
|
||||
if self.tiled or context.services.configuration.tiled_decode:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
@@ -886,9 +893,22 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
image_dto = context.images.save(image=image)
|
||||
image_dto = context.services.images.create(
|
||||
image=image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"]
|
||||
@@ -899,7 +919,7 @@ LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic",
|
||||
title="Resize Latents",
|
||||
tags=["latents", "resize"],
|
||||
category="latents",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ResizeLatentsInvocation(BaseInvocation):
|
||||
"""Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8."""
|
||||
@@ -922,7 +942,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
@@ -940,8 +960,10 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=resized_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -949,7 +971,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
||||
title="Scale Latents",
|
||||
tags=["latents", "resize"],
|
||||
category="latents",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ScaleLatentsInvocation(BaseInvocation):
|
||||
"""Scales latents by a given factor."""
|
||||
@@ -963,7 +985,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
# TODO:
|
||||
device = choose_torch_device()
|
||||
@@ -982,8 +1004,10 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=resized_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, resized_latents)
|
||||
return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed)
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -991,7 +1015,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
||||
title="Image to Latents",
|
||||
tags=["latents", "image", "vae", "i2l"],
|
||||
category="latents",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class ImageToLatentsInvocation(BaseInvocation):
|
||||
"""Encodes an image into latents."""
|
||||
@@ -1053,9 +1077,12 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
vae_info = context.models.load(**self.vae.vae.model_dump())
|
||||
vae_info = context.services.model_manager.load_model_by_key(
|
||||
**self.vae.vae.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
|
||||
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
|
||||
if image_tensor.dim() == 3:
|
||||
@@ -1063,9 +1090,10 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
|
||||
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
|
||||
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
latents = latents.to("cpu")
|
||||
name = context.tensors.save(tensor=latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
|
||||
context.services.latents.save(name, latents)
|
||||
return build_latents_output(latents_name=name, latents=latents, seed=None)
|
||||
|
||||
@singledispatchmethod
|
||||
@staticmethod
|
||||
@@ -1090,7 +1118,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
title="Blend Latents",
|
||||
tags=["latents", "blend"],
|
||||
category="latents",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class BlendLatentsInvocation(BaseInvocation):
|
||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||
@@ -1106,8 +1134,8 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
alpha: float = InputField(default=0.5, description=FieldDescriptions.blend_alpha)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents_a = context.tensors.load(self.latents_a.latents_name)
|
||||
latents_b = context.tensors.load(self.latents_b.latents_name)
|
||||
latents_a = context.services.latents.get(self.latents_a.latents_name)
|
||||
latents_b = context.services.latents.get(self.latents_b.latents_name)
|
||||
|
||||
if latents_a.shape != latents_b.shape:
|
||||
raise Exception("Latents to blend must be the same size.")
|
||||
@@ -1170,8 +1198,10 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
if device == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=blended_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=blended_latents)
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
# context.services.latents.set(name, resized_latents)
|
||||
context.services.latents.save(name, blended_latents)
|
||||
return build_latents_output(latents_name=name, latents=blended_latents)
|
||||
|
||||
|
||||
# The Crop Latents node was copied from @skunkworxdark's implementation here:
|
||||
@@ -1181,7 +1211,7 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
title="Crop Latents",
|
||||
tags=["latents", "crop"],
|
||||
category="latents",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
# TODO(ryand): Named `CropLatentsCoreInvocation` to prevent a conflict with custom node `CropLatentsInvocation`.
|
||||
# Currently, if the class names conflict then 'GET /openapi.json' fails.
|
||||
@@ -1216,7 +1246,7 @@ class CropLatentsCoreInvocation(BaseInvocation):
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
x1 = self.x // LATENT_SCALE_FACTOR
|
||||
y1 = self.y // LATENT_SCALE_FACTOR
|
||||
@@ -1225,9 +1255,10 @@ class CropLatentsCoreInvocation(BaseInvocation):
|
||||
|
||||
cropped_latents = latents[..., y1:y2, x1:x2]
|
||||
|
||||
name = context.tensors.save(tensor=cropped_latents)
|
||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||
context.services.latents.save(name, cropped_latents)
|
||||
|
||||
return LatentsOutput.build(latents_name=name, latents=cropped_latents)
|
||||
return build_latents_output(latents_name=name, latents=cropped_latents)
|
||||
|
||||
|
||||
@invocation_output("ideal_size_output")
|
||||
@@ -1259,7 +1290,10 @@ class IdealSizeInvocation(BaseInvocation):
|
||||
return tuple((x - x % multiple_of) for x in args)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IdealSizeOutput:
|
||||
unet_config = context.models.get_config(**self.unet.unet.model_dump())
|
||||
unet_config = context.services.model_manager.load_model_by_key(
|
||||
**self.unet.unet.model_dump(),
|
||||
context=context,
|
||||
)
|
||||
aspect = self.width / self.height
|
||||
dimension: float = 512
|
||||
if unet_config.base == BaseModelType.StableDiffusion2:
|
||||
|
||||
@@ -5,11 +5,10 @@ from typing import Literal
import numpy as np
from pydantic import ValidationInfo, field_validator

from invokeai.app.invocations.fields import FieldDescriptions, InputField
from invokeai.app.invocations.primitives import FloatOutput, IntegerOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.fields import FieldDescriptions

from .baseinvocation import BaseInvocation, invocation
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation


@invocation("add", title="Add Integers", tags=["math", "add"], category="math", version="1.0.0")

@@ -5,22 +5,20 @@ from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InputField,
InvocationContext,
MetadataField,
OutputField,
UIType,
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
MetadataField,
OutputField,
UIType,
)
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.fields import FieldDescriptions

from ...version import __version__

@@ -33,7 +31,7 @@ class MetadataItemField(BaseModel):
class LoRAMetadataField(BaseModel):
"""LoRA Metadata Field"""

model: LoRAModelField = Field(description=FieldDescriptions.lora_model)
lora: LoRAModelField = Field(description=FieldDescriptions.lora_model)
weight: float = Field(description=FieldDescriptions.lora_weight)


@@ -114,7 +112,7 @@ GENERATION_MODES = Literal[
]


@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.1.1")
@invocation("core_metadata", title="Core Metadata", tags=["metadata"], category="metadata", version="1.0.1")
class CoreMetadataInvocation(BaseInvocation):
"""Collects core generation metadata into a MetadataField"""

@@ -3,14 +3,17 @@ from typing import List, Optional

from pydantic import BaseModel, Field

from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.shared.models import FreeUConfig

from ...backend.model_manager import SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Input,
InputField,
InvocationContext,
OutputField,
invocation,
invocation_output,
)
@@ -91,7 +94,7 @@ class LoRAModelField(BaseModel):
title="Main Model",
tags=["model"],
category="model",
version="1.0.1",
version="1.0.0",
)
class MainModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""
@@ -103,7 +106,7 @@ class MainModelLoaderInvocation(BaseInvocation):
key = self.model.key

# TODO: not found exceptions
if not context.models.exists(key):
if not context.services.model_manager.store.exists(key):
raise Exception(f"Unknown model {key}")

return ModelLoaderOutput(
@@ -147,7 +150,7 @@ class LoraLoaderOutput(BaseInvocationOutput):
clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")


@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.1")
@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.0")
class LoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""

@@ -172,7 +175,7 @@ class LoraLoaderInvocation(BaseInvocation):

lora_key = self.lora.key

if not context.models.exists(lora_key):
if not context.services.model_manager.store.exists(lora_key):
raise Exception(f"Unkown lora: {lora_key}!")

if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
@@ -220,7 +223,7 @@ class SDXLLoraLoaderOutput(BaseInvocationOutput):
title="SDXL LoRA",
tags=["lora", "model"],
category="model",
version="1.0.1",
version="1.0.0",
)
class SDXLLoraLoaderInvocation(BaseInvocation):
"""Apply selected lora to unet and text_encoder."""
@@ -252,7 +255,7 @@ class SDXLLoraLoaderInvocation(BaseInvocation):

lora_key = self.lora.key

if not context.models.exists(lora_key):
if not context.services.model_manager.store.exists(lora_key):
raise Exception(f"Unknown lora: {lora_key}!")

if self.unet is not None and any(lora.key == lora_key for lora in self.unet.loras):
@@ -305,7 +308,7 @@ class VAEModelField(BaseModel):
key: str = Field(description="Model's key")


@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.1")
@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.0")
class VaeLoaderInvocation(BaseInvocation):
"""Loads a VAE model, outputting a VaeLoaderOutput"""

@@ -318,7 +321,7 @@ class VaeLoaderInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> VAEOutput:
key = self.vae_model.key

if not context.models.exists(key):
if not context.services.model_manager.store.exists(key):
raise Exception(f"Unkown vae: {key}!")

return VAEOutput(vae=VaeField(vae=ModelInfo(key=key)))

@@ -4,15 +4,17 @@
import torch
from pydantic import field_validator

from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.invocations.latent import LatentsField
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.misc import SEED_MAX

from ...backend.util.devices import choose_torch_device, torch_dtype
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InputField,
InvocationContext,
OutputField,
invocation,
invocation_output,
)
@@ -67,13 +69,13 @@ class NoiseOutput(BaseInvocationOutput):
width: int = OutputField(description=FieldDescriptions.width)
height: int = OutputField(description=FieldDescriptions.height)

@classmethod
def build(cls, latents_name: str, latents: torch.Tensor, seed: int) -> "NoiseOutput":
return cls(
noise=LatentsField(latents_name=latents_name, seed=seed),
width=latents.size()[3] * LATENT_SCALE_FACTOR,
height=latents.size()[2] * LATENT_SCALE_FACTOR,
)

def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int):
return NoiseOutput(
noise=LatentsField(latents_name=latents_name, seed=seed),
width=latents.size()[3] * 8,
height=latents.size()[2] * 8,
)


@invocation(
@@ -94,13 +96,13 @@ class NoiseInvocation(BaseInvocation):
)
width: int = InputField(
default=512,
multiple_of=LATENT_SCALE_FACTOR,
multiple_of=8,
gt=0,
description=FieldDescriptions.width,
)
height: int = InputField(
default=512,
multiple_of=LATENT_SCALE_FACTOR,
multiple_of=8,
gt=0,
description=FieldDescriptions.height,
)
@@ -122,5 +124,6 @@ class NoiseInvocation(BaseInvocation):
seed=self.seed,
use_cpu=self.use_cpu,
)
name = context.tensors.save(tensor=noise)
return NoiseOutput.build(latents_name=name, latents=noise, seed=self.seed)
name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, noise)
return build_noise_output(latents_name=name, latents=noise, seed=self.seed)

457 invokeai/app/invocations/onnx.py (Normal file)
@@ -0,0 +1,457 @@
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)

import inspect

# from contextlib import ExitStack
from typing import List, Literal, Union

import numpy as np
import torch
from diffusers.image_processor import VaeImageProcessor
from pydantic import BaseModel, Field, field_validator
from tqdm import tqdm

from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.shared.fields import FieldDescriptions
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager import ModelType, SubModelType
from invokeai.backend.model_patcher import ONNXModelPatcher

from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util import choose_torch_device
from ..util.ti_utils import extract_ti_triggers_from_prompt
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Input,
InputField,
InvocationContext,
OutputField,
UIComponent,
UIType,
WithMetadata,
invocation,
invocation_output,
)
from .controlnet_image_processors import ControlField
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
from .model import ClipField, ModelInfo, UNetField, VaeField

ORT_TO_NP_TYPE = {
"tensor(bool)": np.bool_,
"tensor(int8)": np.int8,
"tensor(uint8)": np.uint8,
"tensor(int16)": np.int16,
"tensor(uint16)": np.uint16,
"tensor(int32)": np.int32,
"tensor(uint32)": np.uint32,
"tensor(int64)": np.int64,
"tensor(uint64)": np.uint64,
"tensor(float16)": np.float16,
"tensor(float)": np.float32,
"tensor(double)": np.float64,
}

PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())]


@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0")
class ONNXPromptInvocation(BaseInvocation):
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)

def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.services.model_manager.load_model_by_key(
**self.clip.tokenizer.model_dump(),
)
text_encoder_info = context.services.model_manager.load_model_by_key(
**self.clip.text_encoder.model_dump(),
)
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder:  # , ExitStack() as stack:
loras = [
(
context.services.model_manager.load_model_by_key(**lora.model_dump(exclude={"weight"})).model,
lora.weight,
)
for lora in self.clip.loras
]

ti_list = []
for trigger in extract_ti_triggers_from_prompt(self.prompt):
name = trigger[1:-1]
try:
ti_list.append(
(
name,
context.services.model_manager.load_model_by_attr(
model_name=name,
base_model=text_encoder_info.config.base,
model_type=ModelType.TextualInversion,
).model,
)
)
except Exception:
# print(e)
# import traceback
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
if loras or ti_list:
text_encoder.release_session()
with (
ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager),
):
text_encoder.create_session()

# copy from
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153
text_inputs = tokenizer(
self.prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
text_input_ids = text_inputs.input_ids
"""
untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

if not np.array_equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
"""

prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]

conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

# TODO: hacky but works ;D maybe rename latents somehow?
context.services.latents.save(conditioning_name, (prompt_embeds, None))

return ConditioningOutput(
conditioning=ConditioningField(
conditioning_name=conditioning_name,
),
)


# Text to image
@invocation(
"t2l_onnx",
title="ONNX Text to Latents",
tags=["latents", "inference", "txt2img", "onnx"],
category="latents",
version="1.0.0",
)
class ONNXTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings."""

positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond,
input=Input.Connection,
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond,
input=Input.Connection,
)
noise: LatentsField = InputField(
description=FieldDescriptions.noise,
input=Input.Connection,
)
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
cfg_scale: Union[float, List[float]] = InputField(
default=7.5,
ge=1,
description=FieldDescriptions.cfg_scale,
)
scheduler: SAMPLER_NAME_VALUES = InputField(
default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler
)
precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
)
control: Union[ControlField, list[ControlField]] = InputField(
default=None,
description=FieldDescriptions.control,
)
# seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")

@field_validator("cfg_scale")
def ge_one(cls, v):
"""validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v

# based on
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
def invoke(self, context: InvocationContext) -> LatentsOutput:
c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
if isinstance(c, torch.Tensor):
c = c.cpu().numpy()
if isinstance(uc, torch.Tensor):
uc = uc.cpu().numpy()
device = torch.device(choose_torch_device())
prompt_embeds = np.concatenate([uc, c])

latents = context.services.latents.get(self.noise.latents_name)
if isinstance(latents, torch.Tensor):
latents = latents.cpu().numpy()

# TODO: better execution device handling
latents = latents.astype(ORT_TO_NP_TYPE[self.precision])

# get the initial random noise unless the user supplied it
do_classifier_free_guidance = True
# latents_dtype = prompt_embeds.dtype
# latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
# if latents.shape != latents_shape:
#     raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=0,  # TODO: refactor this node
)

def torch2numpy(latent: torch.Tensor):
return latent.cpu().numpy()

def numpy2torch(latent, device):
return torch.from_numpy(latent).to(device)

def dispatch_progress(
self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
) -> None:
stable_diffusion_step_callback(
context=context,
intermediate_state=intermediate_state,
node=self.model_dump(),
source_node_id=source_node_id,
)

scheduler.set_timesteps(self.steps)
latents = latents * np.float64(scheduler.init_noise_sigma)

extra_step_kwargs = {}
if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
extra_step_kwargs.update(
eta=0.0,
)

unet_info = context.services.model_manager.load_model_by_key(**self.unet.unet.model_dump())

with unet_info as unet:  # , ExitStack() as stack:
# loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
loras = [
(
context.services.model_manager.load_model_by_key(**lora.model_dump(exclude={"weight"})).model,
lora.weight,
)
for lora in self.unet.loras
]

if loras:
unet.release_session()
with ONNXModelPatcher.apply_lora_unet(unet, loras):
# TODO:
_, _, h, w = latents.shape
unet.create_session(h, w)

timestep_dtype = next(
(input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
)
timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
for i in tqdm(range(len(scheduler.timesteps))):
t = scheduler.timesteps[i]
# expand the latents if we are doing classifier free guidance
latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
latent_model_input = latent_model_input.cpu().numpy()

# predict the noise residual
timestep = np.array([t], dtype=timestep_dtype)
noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
noise_pred = noise_pred[0]

# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
scheduler_output = scheduler.step(
numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
)
latents = torch2numpy(scheduler_output.prev_sample)

state = PipelineIntermediateState(
run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample
)
dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state)

# call the callback, if provided
# if callback is not None and i % callback_steps == 0:
#     callback(i, t, latents)

torch.cuda.empty_cache()

name = f"{context.graph_execution_state_id}__{self.id}"
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))


# Latent to image
@invocation(
"l2i_onnx",
title="ONNX Latents to Image",
tags=["latents", "image", "vae", "onnx"],
category="image",
version="1.2.0",
)
class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata):
"""Generates an image from latents."""

latents: LatentsField = InputField(
description=FieldDescriptions.denoised_latents,
input=Input.Connection,
)
vae: VaeField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
# tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)")

def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.services.latents.get(self.latents.latents_name)

if self.vae.vae.submodel != SubModelType.VaeDecoder:
raise Exception(f"Expected vae_decoder, found: {self.vae.vae.submodel}")

vae_info = context.services.model_manager.load_model_by_key(
**self.vae.vae.model_dump(),
)

# clear memory as vae decode can request a lot
torch.cuda.empty_cache()

with vae_info as vae:
vae.create_session()

# copied from
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427
latents = 1 / 0.18215 * latents
# image = self.vae_decoder(latent_sample=latents)[0]
# it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])])

image = np.clip(image / 2 + 0.5, 0, 1)
image = image.transpose((0, 2, 3, 1))
image = VaeImageProcessor.numpy_to_pil(image)[0]

torch.cuda.empty_cache()

image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata,
workflow=context.workflow,
)

return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)


@invocation_output("model_loader_output_onnx")
class ONNXModelLoaderOutput(BaseInvocationOutput):
"""Model loader output"""

unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")


class OnnxModelField(BaseModel):
"""Onnx model field"""

key: str = Field(description="Model ID")


@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0")
class OnnxModelLoaderInvocation(BaseInvocation):
"""Loads a main model, outputting its submodels."""

model: OnnxModelField = InputField(
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel
)

def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
model_key = self.model.key

# TODO: not found exceptions
if not context.services.model_manager.store.exists(model_key):
raise Exception(f"Unknown model: {model_key}")

return ONNXModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
key=model_key,
submodel_type=SubModelType.UNet,
),
scheduler=ModelInfo(
key=model_key,
submodel_type=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
key=model_key,
submodel_type=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
key=model_key,
submodel_type=SubModelType.TextEncoder,
),
loras=[],
skipped_layers=0,
),
vae_decoder=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.VaeDecoder,
),
),
vae_encoder=VaeField(
vae=ModelInfo(
key=model_key,
submodel_type=SubModelType.VaeEncoder,
),
),
)
@@ -40,10 +40,8 @@ from easing_functions import (
from matplotlib.ticker import MaxNLocator

from invokeai.app.invocations.primitives import FloatCollectionOutput
from invokeai.app.services.shared.invocation_context import InvocationContext

from .baseinvocation import BaseInvocation, invocation
from .fields import InputField
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation


@invocation(
@@ -111,7 +109,7 @@ EASING_FUNCTION_KEYS = Literal[tuple(EASING_FUNCTIONS_MAP.keys())]
title="Step Param Easing",
tags=["step", "easing"],
category="step",
version="1.0.1",
version="1.0.0",
)
class StepParamEasingInvocation(BaseInvocation):
"""Experimental per-step parameter easing for denoising steps"""
@@ -150,19 +148,19 @@ class StepParamEasingInvocation(BaseInvocation):
postlist = list(num_poststeps * [self.post_end_value])

if log_diagnostics:
context.logger.debug("start_step: " + str(start_step))
context.logger.debug("end_step: " + str(end_step))
context.logger.debug("num_easing_steps: " + str(num_easing_steps))
context.logger.debug("num_presteps: " + str(num_presteps))
context.logger.debug("num_poststeps: " + str(num_poststeps))
context.logger.debug("prelist size: " + str(len(prelist)))
context.logger.debug("postlist size: " + str(len(postlist)))
context.logger.debug("prelist: " + str(prelist))
context.logger.debug("postlist: " + str(postlist))
context.services.logger.debug("start_step: " + str(start_step))
context.services.logger.debug("end_step: " + str(end_step))
context.services.logger.debug("num_easing_steps: " + str(num_easing_steps))
context.services.logger.debug("num_presteps: " + str(num_presteps))
context.services.logger.debug("num_poststeps: " + str(num_poststeps))
context.services.logger.debug("prelist size: " + str(len(prelist)))
context.services.logger.debug("postlist size: " + str(len(postlist)))
context.services.logger.debug("prelist: " + str(prelist))
context.services.logger.debug("postlist: " + str(postlist))

easing_class = EASING_FUNCTIONS_MAP[self.easing]
if log_diagnostics:
context.logger.debug("easing class: " + str(easing_class))
context.services.logger.debug("easing class: " + str(easing_class))
easing_list = []
if self.mirror:  # "expected" mirroring
# if number of steps is even, squeeze duration down to (number_of_steps)/2
@@ -173,7 +171,7 @@ class StepParamEasingInvocation(BaseInvocation):

base_easing_duration = int(np.ceil(num_easing_steps / 2.0))
if log_diagnostics:
context.logger.debug("base easing duration: " + str(base_easing_duration))
context.services.logger.debug("base easing duration: " + str(base_easing_duration))
even_num_steps = num_easing_steps % 2 == 0  # even number of steps
easing_function = easing_class(
start=self.start_value,
@@ -185,14 +183,14 @@ class StepParamEasingInvocation(BaseInvocation):
easing_val = easing_function.ease(step_index)
base_easing_vals.append(easing_val)
if log_diagnostics:
context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(easing_val))
if even_num_steps:
mirror_easing_vals = list(reversed(base_easing_vals))
else:
mirror_easing_vals = list(reversed(base_easing_vals[0:-1]))
if log_diagnostics:
context.logger.debug("base easing vals: " + str(base_easing_vals))
context.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
context.services.logger.debug("base easing vals: " + str(base_easing_vals))
context.services.logger.debug("mirror easing vals: " + str(mirror_easing_vals))
easing_list = base_easing_vals + mirror_easing_vals

# FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely
@@ -227,12 +225,12 @@ class StepParamEasingInvocation(BaseInvocation):
step_val = easing_function.ease(step_index)
easing_list.append(step_val)
if log_diagnostics:
context.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))
context.services.logger.debug("step_index: " + str(step_index) + ", easing_val: " + str(step_val))

if log_diagnostics:
context.logger.debug("prelist size: " + str(len(prelist)))
context.logger.debug("easing_list size: " + str(len(easing_list)))
context.logger.debug("postlist size: " + str(len(postlist)))
context.services.logger.debug("prelist size: " + str(len(prelist)))
|
||||
context.services.logger.debug("easing_list size: " + str(len(easing_list)))
|
||||
context.services.logger.debug("postlist size: " + str(len(postlist)))
|
||||
|
||||
param_list = prelist + easing_list + postlist
|
||||
|
||||
|
||||
@@ -1,28 +1,20 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.fields import (
|
||||
ColorField,
|
||||
ConditioningField,
|
||||
DenoiseMaskField,
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
)
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -229,6 +221,18 @@ class StringCollectionInvocation(BaseInvocation):
|
||||
# region Image
|
||||
|
||||
|
||||
class ImageField(BaseModel):
|
||||
"""An image primitive field"""
|
||||
|
||||
image_name: str = Field(description="The name of the image")
|
||||
|
||||
|
||||
class BoardField(BaseModel):
|
||||
"""A board primitive field"""
|
||||
|
||||
board_id: str = Field(description="The id of the board")
|
||||
|
||||
|
||||
@invocation_output("image_output")
|
||||
class ImageOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single image"""
|
||||
@@ -237,14 +241,6 @@ class ImageOutput(BaseInvocationOutput):
|
||||
width: int = OutputField(description="The width of the image in pixels")
|
||||
height: int = OutputField(description="The height of the image in pixels")
|
||||
|
||||
@classmethod
|
||||
def build(cls, image_dto: ImageDTO) -> "ImageOutput":
|
||||
return cls(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("image_collection_output")
|
||||
class ImageCollectionOutput(BaseInvocationOutput):
|
||||
@@ -255,14 +251,16 @@ class ImageCollectionOutput(BaseInvocationOutput):
|
||||
)
|
||||
|
||||
|
||||
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.1")
|
||||
class ImageInvocation(BaseInvocation):
|
||||
@invocation("image", title="Image Primitive", tags=["primitives", "image"], category="primitives", version="1.0.0")
|
||||
class ImageInvocation(
|
||||
BaseInvocation,
|
||||
):
|
||||
"""An image primitive value"""
|
||||
|
||||
image: ImageField = InputField(description="The image to load")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=self.image.image_name),
|
||||
@@ -292,44 +290,42 @@ class ImageCollectionInvocation(BaseInvocation):
|
||||
# region DenoiseMask
|
||||
|
||||
|
||||
class DenoiseMaskField(BaseModel):
|
||||
"""An inpaint mask field"""
|
||||
|
||||
mask_name: str = Field(description="The name of the mask image")
|
||||
masked_latents_name: Optional[str] = Field(default=None, description="The name of the masked image latents")
|
||||
|
||||
|
||||
@invocation_output("denoise_mask_output")
|
||||
class DenoiseMaskOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single image"""
|
||||
|
||||
denoise_mask: DenoiseMaskField = OutputField(description="Mask for denoise model run")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, mask_name: str, masked_latents_name: Optional[str] = None, gradient: bool = False
|
||||
) -> "DenoiseMaskOutput":
|
||||
return cls(
|
||||
denoise_mask=DenoiseMaskField(
|
||||
mask_name=mask_name, masked_latents_name=masked_latents_name, gradient=gradient
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Latents
|
||||
|
||||
|
||||
class LatentsField(BaseModel):
|
||||
"""A latents tensor primitive field"""
|
||||
|
||||
latents_name: str = Field(description="The name of the latents")
|
||||
seed: Optional[int] = Field(default=None, description="Seed used to generate this latents")
|
||||
|
||||
|
||||
@invocation_output("latents_output")
|
||||
class LatentsOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single latents tensor"""
|
||||
|
||||
latents: LatentsField = OutputField(description=FieldDescriptions.latents)
|
||||
latents: LatentsField = OutputField(
|
||||
description=FieldDescriptions.latents,
|
||||
)
|
||||
width: int = OutputField(description=FieldDescriptions.width)
|
||||
height: int = OutputField(description=FieldDescriptions.height)
|
||||
|
||||
@classmethod
|
||||
def build(cls, latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> "LatentsOutput":
|
||||
return cls(
|
||||
latents=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * LATENT_SCALE_FACTOR,
|
||||
height=latents.size()[2] * LATENT_SCALE_FACTOR,
|
||||
)
|
||||
|
||||
|
||||
@invocation_output("latents_collection_output")
|
||||
class LatentsCollectionOutput(BaseInvocationOutput):
|
||||
@@ -341,7 +337,7 @@ class LatentsCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
|
||||
@invocation(
|
||||
"latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.1"
|
||||
"latents", title="Latents Primitive", tags=["primitives", "latents"], category="primitives", version="1.0.0"
|
||||
)
|
||||
class LatentsInvocation(BaseInvocation):
|
||||
"""A latents tensor primitive value"""
|
||||
@@ -349,9 +345,9 @@ class LatentsInvocation(BaseInvocation):
|
||||
latents: LatentsField = InputField(description="The latents tensor", input=Input.Connection)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
latents = context.tensors.load(self.latents.latents_name)
|
||||
latents = context.services.latents.get(self.latents.latents_name)
|
||||
|
||||
return LatentsOutput.build(self.latents.latents_name, latents)
|
||||
return build_latents_output(self.latents.latents_name, latents)
|
||||
|
||||
|
||||
@invocation(
|
||||
@@ -372,11 +368,31 @@ class LatentsCollectionInvocation(BaseInvocation):
|
||||
return LatentsCollectionOutput(collection=self.collection)
|
||||
|
||||
|
||||
def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int] = None) -> LatentsOutput:
|
||||
return LatentsOutput(
|
||||
latents=LatentsField(latents_name=latents_name, seed=seed),
|
||||
width=latents.size()[3] * 8,
|
||||
height=latents.size()[2] * 8,
|
||||
)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Color
|
||||
|
||||
|
||||
class ColorField(BaseModel):
|
||||
"""A color primitive field"""
|
||||
|
||||
r: int = Field(ge=0, le=255, description="The red component")
|
||||
g: int = Field(ge=0, le=255, description="The green component")
|
||||
b: int = Field(ge=0, le=255, description="The blue component")
|
||||
a: int = Field(ge=0, le=255, description="The alpha component")
|
||||
|
||||
def tuple(self) -> Tuple[int, int, int, int]:
|
||||
return (self.r, self.g, self.b, self.a)
|
||||
|
||||
|
||||
@invocation_output("color_output")
|
||||
class ColorOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single color"""
|
||||
@@ -408,16 +424,18 @@ class ColorInvocation(BaseInvocation):
|
||||
# region Conditioning
|
||||
|
||||
|
||||
class ConditioningField(BaseModel):
|
||||
"""A conditioning tensor primitive value"""
|
||||
|
||||
conditioning_name: str = Field(description="The name of conditioning tensor")
|
||||
|
||||
|
||||
@invocation_output("conditioning_output")
|
||||
class ConditioningOutput(BaseInvocationOutput):
|
||||
"""Base class for nodes that output a single conditioning tensor"""
|
||||
|
||||
conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond)
|
||||
|
||||
@classmethod
|
||||
def build(cls, conditioning_name: str) -> "ConditioningOutput":
|
||||
return cls(conditioning=ConditioningField(conditioning_name=conditioning_name))
|
||||
|
||||
|
||||
@invocation_output("conditioning_collection_output")
|
||||
class ConditioningCollectionOutput(BaseInvocationOutput):
|
||||
|
||||
@@ -6,10 +6,8 @@ from dynamicprompts.generators import CombinatorialPromptGenerator, RandomPrompt
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.primitives import StringCollectionOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, UIComponent
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
|
||||
|
||||
|
||||
@invocation(
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
from invokeai.backend.model_manager import SubModelType
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -30,7 +34,7 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||
vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.1")
|
||||
@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.0")
|
||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl base model, outputting its submodels."""
|
||||
|
||||
@@ -43,7 +47,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
model_key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.models.exists(model_key):
|
||||
if not context.services.model_manager.store.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
return SDXLModelLoaderOutput(
|
||||
@@ -96,7 +100,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
||||
title="SDXL Refiner Model",
|
||||
tags=["model", "sdxl", "refiner"],
|
||||
category="model",
|
||||
version="1.0.1",
|
||||
version="1.0.0",
|
||||
)
|
||||
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||
@@ -112,7 +116,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||
model_key = self.model.key
|
||||
|
||||
# TODO: not found exceptions
|
||||
if not context.models.exists(model_key):
|
||||
if not context.services.model_manager.store.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
return SDXLRefinerModelLoaderOutput(
|
||||
|
||||
@@ -2,15 +2,16 @@
|
||||
|
||||
import re
|
||||
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .fields import InputField, OutputField, UIComponent
|
||||
from .primitives import StringOutput
|
||||
|
||||
|
||||
|
||||
@@ -5,13 +5,17 @@ from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
|
||||
|
||||
class T2IAdapterModelField(BaseModel):
|
||||
|
||||
@@ -8,12 +8,16 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
WithMetadata,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import ImageField, Input, InputField, OutputField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.backend.tiles.tiles import (
|
||||
calc_tiles_even_split,
|
||||
calc_tiles_min_overlap,
|
||||
@@ -232,7 +236,7 @@ BLEND_MODES = Literal["Linear", "Seam"]
|
||||
version="1.1.0",
|
||||
classification=Classification.Beta,
|
||||
)
|
||||
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
class MergeTilesToImageInvocation(BaseInvocation, WithMetadata):
|
||||
"""Merge multiple tile images into a single image."""
|
||||
|
||||
# Inputs
|
||||
@@ -264,7 +268,7 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# existed in memory at an earlier point in the graph.
|
||||
tile_np_images: list[np.ndarray] = []
|
||||
for image in images:
|
||||
pil_image = context.images.get_pil(image.image_name)
|
||||
pil_image = context.services.images.get_pil_image(image.image_name)
|
||||
pil_image = pil_image.convert("RGB")
|
||||
tile_np_images.append(np.array(pil_image))
|
||||
|
||||
@@ -287,5 +291,18 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
# Convert into a PIL image and save
|
||||
pil_image = Image.fromarray(np_image)
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
return ImageOutput.build(image_dto)
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -8,15 +8,13 @@ import torch
|
||||
from PIL import Image
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.backend.image_util.basicsr.rrdbnet_arch import RRDBNet
|
||||
from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
from .baseinvocation import BaseInvocation, invocation
|
||||
from .fields import InputField, WithBoard, WithMetadata
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation
|
||||
|
||||
# TODO: Populate this from disk?
|
||||
# TODO: Use model manager to load?
|
||||
@@ -31,8 +29,8 @@ if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
||||
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.1")
|
||||
class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0")
|
||||
class ESRGANInvocation(BaseInvocation, WithMetadata):
|
||||
"""Upscales an image using RealESRGAN."""
|
||||
|
||||
image: ImageField = InputField(description="The input image")
|
||||
@@ -44,8 +42,8 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
image = context.images.get_pil(self.image.image_name)
|
||||
models_path = context.config.get().models_path
|
||||
image = context.services.images.get_pil_image(self.image.image_name)
|
||||
models_path = context.services.configuration.models_path
|
||||
|
||||
rrdbnet_model = None
|
||||
netscale = None
|
||||
@@ -89,7 +87,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
netscale = 2
|
||||
else:
|
||||
msg = f"Invalid RealESRGAN model: {self.model_name}"
|
||||
context.logger.error(msg)
|
||||
context.services.logger.error(msg)
|
||||
raise ValueError(msg)
|
||||
|
||||
esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
|
||||
@@ -112,6 +110,19 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
mps.empty_cache()
|
||||
|
||||
image_dto = context.images.save(image=pil_image)
|
||||
image_dto = context.services.images.create(
|
||||
image=pil_image,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
image_category=ImageCategory.GENERAL,
|
||||
node_id=self.id,
|
||||
session_id=context.graph_execution_state_id,
|
||||
is_intermediate=self.is_intermediate,
|
||||
metadata=self.metadata,
|
||||
workflow=context.workflow,
|
||||
)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_name=image_dto.image_name),
|
||||
width=image_dto.width,
|
||||
height=image_dto.height,
|
||||
)
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class BulkDownloadBase(ABC):
|
||||
"""Responsible for creating a zip file containing the images specified by the given image names or board id."""
|
||||
|
||||
@abstractmethod
|
||||
def handler(
|
||||
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
|
||||
) -> None:
|
||||
"""
|
||||
Create a zip file containing the images specified by the given image names or board id.
|
||||
|
||||
:param image_names: A list of image names to include in the zip file.
|
||||
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
|
||||
:param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_path(self, bulk_download_item_name: str) -> str:
|
||||
"""
|
||||
Get the path to the bulk download file.
|
||||
|
||||
:param bulk_download_item_name: The name of the bulk download item.
|
||||
:return: The path to the bulk download file.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def generate_item_id(self, board_id: Optional[str]) -> str:
|
||||
"""
|
||||
Generate an item ID for a bulk download item.
|
||||
|
||||
:param board_id: The ID of the board whose name is to be included in the item id.
|
||||
:return: The generated item ID.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, bulk_download_item_name: str) -> None:
|
||||
"""
|
||||
Delete the bulk download file.
|
||||
|
||||
:param bulk_download_item_name: The name of the bulk download item.
|
||||
"""
|
||||
@@ -1,25 +0,0 @@
|
||||
DEFAULT_BULK_DOWNLOAD_ID = "default"
|
||||
|
||||
|
||||
class BulkDownloadException(Exception):
|
||||
"""Exception raised when a bulk download fails."""
|
||||
|
||||
def __init__(self, message="Bulk download failed"):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
|
||||
class BulkDownloadTargetException(BulkDownloadException):
|
||||
"""Exception raised when a bulk download target is not found."""
|
||||
|
||||
def __init__(self, message="The bulk download target was not found"):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
|
||||
class BulkDownloadParametersException(BulkDownloadException):
|
||||
"""Exception raised when a bulk download parameter is invalid."""
|
||||
|
||||
def __init__(self, message="No image names or board ID provided"):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
@@ -1,157 +0,0 @@
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional, Union
|
||||
from zipfile import ZipFile
|
||||
|
||||
from invokeai.app.services.board_records.board_records_common import BoardRecordNotFoundException
|
||||
from invokeai.app.services.bulk_download.bulk_download_common import (
|
||||
DEFAULT_BULK_DOWNLOAD_ID,
|
||||
BulkDownloadException,
|
||||
BulkDownloadParametersException,
|
||||
BulkDownloadTargetException,
|
||||
)
|
||||
from invokeai.app.services.image_records.image_records_common import ImageRecordNotFoundException
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
from .bulk_download_base import BulkDownloadBase
|
||||
|
||||
|
||||
class BulkDownloadService(BulkDownloadBase):
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
def __init__(self):
|
||||
self._temp_directory = TemporaryDirectory()
|
||||
self._bulk_downloads_folder = Path(self._temp_directory.name) / "bulk_downloads"
|
||||
self._bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def handler(
|
||||
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
|
||||
) -> None:
|
||||
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
|
||||
bulk_download_item_id = bulk_download_item_id or uuid_string()
|
||||
bulk_download_item_name = bulk_download_item_id + ".zip"
|
||||
|
||||
self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
|
||||
|
||||
try:
|
||||
image_dtos: list[ImageDTO] = []
|
||||
|
||||
if board_id:
|
||||
image_dtos = self._board_handler(board_id)
|
||||
elif image_names:
|
||||
image_dtos = self._image_handler(image_names)
|
||||
else:
|
||||
raise BulkDownloadParametersException()
|
||||
|
||||
bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
|
||||
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
|
||||
except (
|
||||
ImageRecordNotFoundException,
|
||||
BoardRecordNotFoundException,
|
||||
BulkDownloadException,
|
||||
BulkDownloadParametersException,
|
||||
) as e:
|
||||
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
|
||||
except Exception as e:
|
||||
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
|
||||
self._invoker.services.logger.error("Problem bulk downloading images.")
|
||||
raise e
|
||||
|
||||
def _image_handler(self, image_names: list[str]) -> list[ImageDTO]:
|
||||
return [self._invoker.services.images.get_dto(image_name) for image_name in image_names]
|
||||
|
||||
def _board_handler(self, board_id: str) -> list[ImageDTO]:
|
||||
image_names = self._invoker.services.board_image_records.get_all_board_image_names_for_board(board_id)
|
||||
return self._image_handler(image_names)
|
||||
|
||||
def generate_item_id(self, board_id: Optional[str]) -> str:
|
||||
return uuid_string() if board_id is None else self._get_clean_board_name(board_id) + "_" + uuid_string()
|
||||
|
||||
def _get_clean_board_name(self, board_id: str) -> str:
|
||||
if board_id == "none":
|
||||
return "Uncategorized"
|
||||
|
||||
return self._clean_string_to_path_safe(self._invoker.services.board_records.get(board_id).board_name)
|
||||
|
||||
def _create_zip_file(self, image_dtos: list[ImageDTO], bulk_download_item_id: str) -> str:
|
||||
"""
|
||||
Create a zip file containing the images specified by the given image names or board id.
|
||||
If download with the same bulk_download_id already exists, it will be overwritten.
|
||||
|
||||
:return: The name of the zip file.
|
||||
"""
|
||||
zip_file_name = bulk_download_item_id + ".zip"
|
||||
zip_file_path = self._bulk_downloads_folder / (zip_file_name)
|
||||
|
||||
with ZipFile(zip_file_path, "w") as zip_file:
|
||||
for image_dto in image_dtos:
|
||||
image_zip_path = Path(image_dto.image_category.value) / image_dto.image_name
|
||||
image_disk_path = self._invoker.services.images.get_path(image_dto.image_name)
|
||||
zip_file.write(image_disk_path, arcname=image_zip_path)
|
||||
|
||||
return str(zip_file_name)
|
||||
|
||||
# from https://stackoverflow.com/questions/7406102/create-sane-safe-filename-from-any-unsafe-string
|
||||
def _clean_string_to_path_safe(self, s: str) -> str:
|
||||
"""Clean a string to be path safe."""
|
||||
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip()
|
||||
|
||||
def _signal_job_started(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> None:
|
||||
"""Signal that a bulk download job has started."""
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
self._invoker.services.events.emit_bulk_download_started(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
)
|
||||
|
||||
def _signal_job_completed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> None:
|
||||
"""Signal that a bulk download job has completed."""
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
assert bulk_download_item_name is not None
|
||||
self._invoker.services.events.emit_bulk_download_completed(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
)
|
||||
|
||||
def _signal_job_failed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception
|
||||
) -> None:
|
||||
"""Signal that a bulk download job has failed."""
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
assert exception is not None
|
||||
self._invoker.services.events.emit_bulk_download_failed(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
error=str(exception),
|
||||
)
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self._temp_directory.cleanup()
|
||||
|
||||
def delete(self, bulk_download_item_name: str) -> None:
|
||||
path = self.get_path(bulk_download_item_name)
|
||||
Path(path).unlink()
|
||||
|
||||
def get_path(self, bulk_download_item_name: str) -> str:
|
||||
path = str(self._bulk_downloads_folder / bulk_download_item_name)
|
||||
if not self._is_valid_path(path):
|
||||
raise BulkDownloadTargetException()
|
||||
return path
|
||||
|
||||
def _is_valid_path(self, path: Union[str, Path]) -> bool:
|
||||
"""Validates the path given for a bulk download."""
|
||||
path = path if isinstance(path, Path) else Path(path)
|
||||
return path.exists()
|
||||
@@ -156,7 +156,6 @@ class InvokeAISettings(BaseSettings):
|
||||
"lora_dir",
|
||||
"embedding_dir",
|
||||
"controlnet_dir",
|
||||
"conf_path",
|
||||
]
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -30,6 +30,7 @@ InvokeAI:
|
||||
lora_dir: null
|
||||
embedding_dir: null
|
||||
controlnet_dir: null
|
||||
conf_path: configs/models.yaml
|
||||
models_dir: models
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
db_dir: databases
|
||||
@@ -122,6 +123,7 @@ a Path object:
|
||||
|
||||
root_path - path to InvokeAI root
|
||||
output_path - path to default outputs directory
|
||||
model_conf_path - path to models.yaml
|
||||
conf - alias for the above
|
||||
embedding_path - path to the embeddings directory
|
||||
lora_path - path to the LoRA directory
|
||||
@@ -161,6 +163,7 @@ two configs are kept in separate sections of the config file:
|
||||
InvokeAI:
|
||||
Paths:
|
||||
root: /home/lstein/invokeai-main
|
||||
conf_path: configs/models.yaml
|
||||
legacy_conf_dir: configs/stable-diffusion
|
||||
outdir: outputs
|
||||
...
|
||||
@@ -234,6 +237,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
# PATHS
|
||||
root : Optional[Path] = Field(default=None, description='InvokeAI runtime root directory', json_schema_extra=Categories.Paths)
|
||||
autoimport_dir : Path = Field(default=Path('autoimport'), description='Path to a directory of models files to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
|
||||
models_dir : Path = Field(default=Path('models'), description='Path to the models directory', json_schema_extra=Categories.Paths)
|
||||
convert_cache_dir : Path = Field(default=Path('models/.cache'), description='Path to the converted models cache directory', json_schema_extra=Categories.Paths)
|
||||
legacy_conf_dir : Path = Field(default=Path('configs/stable-diffusion'), description='Path to directory of legacy checkpoint config files', json_schema_extra=Categories.Paths)
|
||||
@@ -297,7 +301,6 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
lora_dir : Optional[Path] = Field(default=None, description='Path to a directory of LoRA/LyCORIS models to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
embedding_dir : Optional[Path] = Field(default=None, description='Path to a directory of Textual Inversion embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
controlnet_dir : Optional[Path] = Field(default=None, description='Path to a directory of ControlNet embeddings to be imported on startup.', json_schema_extra=Categories.Paths)
|
||||
conf_path : Path = Field(default=Path('configs/models.yaml'), description='Path to models definition file', json_schema_extra=Categories.Paths)
|
||||
|
||||
# this is not referred to in the source code and can be removed entirely
|
||||
#free_gpu_mem : Optional[bool] = Field(default=None, description="If true, purge model from GPU after each generation.", json_schema_extra=Categories.MemoryPerformance)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
@@ -16,7 +16,6 @@ from invokeai.backend.model_manager import AnyModelConfig
|
||||
|
||||
class EventServiceBase:
|
||||
queue_event: str = "queue_event"
|
||||
bulk_download_event: str = "bulk_download_event"
|
||||
download_event: str = "download_event"
|
||||
model_event: str = "model_event"
|
||||
|
||||
@@ -25,14 +24,6 @@ class EventServiceBase:
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
pass
|
||||
|
||||
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Bulk download events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.bulk_download_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Queue events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
@@ -63,7 +54,7 @@ class EventServiceBase:
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
progress_image: Optional[ProgressImage],
|
||||
step: int,
|
||||
@@ -78,7 +69,7 @@ class EventServiceBase:
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node_id": node_id,
|
||||
"node_id": node.get("id"),
|
||||
"source_node_id": source_node_id,
|
||||
"progress_image": progress_image.model_dump() if progress_image is not None else None,
|
||||
"step": step,
|
||||
@@ -213,6 +204,52 @@ class EventServiceBase:
|
||||
},
|
||||
)
|
||||
|
||||
def emit_session_retrieval_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when session retrieval fails"""
|
||||
self.__emit_queue_event(
|
||||
event_name="session_retrieval_error",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_invocation_retrieval_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node_id: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when invocation retrieval fails"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_retrieval_error",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node_id": node_id,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_session_canceled(
|
||||
self,
|
||||
queue_id: str,
|
||||
@@ -357,7 +394,6 @@ class EventServiceBase:
|
||||
bytes: int,
|
||||
total_bytes: int,
|
||||
parts: List[Dict[str, Union[str, int]]],
|
||||
id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Emit at intervals while the install job is in progress (remote models only).
|
||||
@@ -377,7 +413,6 @@ class EventServiceBase:
|
||||
"bytes": bytes,
|
||||
"total_bytes": total_bytes,
|
||||
"parts": parts,
|
||||
"id": id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -392,7 +427,7 @@ class EventServiceBase:
|
||||
payload={"source": source},
|
||||
)
|
||||
|
||||
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
|
||||
def emit_model_install_completed(self, source: str, key: str, total_bytes: Optional[int] = None) -> None:
|
||||
"""
|
||||
Emit when an install job is completed successfully.
|
||||
|
||||
@@ -402,7 +437,11 @@ class EventServiceBase:
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_completed",
|
||||
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
|
||||
payload={
|
||||
"source": source,
|
||||
"total_bytes": total_bytes,
|
||||
"key": key,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_install_cancelled(self, source: str) -> None:
|
||||
@@ -416,7 +455,12 @@ class EventServiceBase:
|
||||
payload={"source": source},
|
||||
)
|
||||
|
||||
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
|
||||
def emit_model_install_error(
|
||||
self,
|
||||
source: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""
|
||||
Emit when an install job encounters an exception.
|
||||
|
||||
@@ -426,45 +470,9 @@ class EventServiceBase:
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_error",
|
||||
payload={"source": source, "error_type": error_type, "error": error, "id": id},
|
||||
)
|
||||
|
||||
def emit_bulk_download_started(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk download starts"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_started",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_bulk_download_completed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk download completes"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_completed",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_bulk_download_failed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk download fails"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_failed",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
"source": source,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from PIL import Image, PngImagePlugin
|
||||
from PIL.Image import Image as PILImageType
|
||||
from send2trash import send2trash
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
|
||||
|
||||
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.invocations.metadata import MetadataField
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
|
||||
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
|
||||
|
||||
@@ -3,7 +3,7 @@ import threading
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Callable, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageRecord,
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.invocations.baseinvocation import MetadataField
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
@@ -154,7 +154,7 @@ class ImageService(ImageServiceABC):
|
||||
self.__invoker.services.logger.error("Image record not found")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting image metadata")
|
||||
self.__invoker.services.logger.error("Problem getting image DTO")
|
||||
raise e
|
||||
|
||||
def get_workflow(self, image_name: str) -> Optional[WorkflowWithoutID]:
|
||||
|
||||
@@ -37,8 +37,7 @@ class MemoryInvocationCache(InvocationCacheBase):
|
||||
if self._max_cache_size == 0:
|
||||
return
|
||||
self._invoker.services.images.on_deleted(self._delete_by_match)
|
||||
self._invoker.services.tensors.on_deleted(self._delete_by_match)
|
||||
self._invoker.services.conditioning.on_deleted(self._delete_by_match)
|
||||
self._invoker.services.latents.on_deleted(self._delete_by_match)
|
||||
|
||||
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
|
||||
with self._lock:
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
from abc import ABC
|
||||
|
||||
|
||||
class InvocationProcessorABC(ABC): # noqa: B024
|
||||
pass
|
||||
@@ -0,0 +1,15 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ProgressImage(BaseModel):
|
||||
"""The progress image sent intermittently during processing"""
|
||||
|
||||
width: int = Field(description="The effective width of the image in pixels")
|
||||
height: int = Field(description="The effective height of the image in pixels")
|
||||
dataURL: str = Field(description="The image data as a b64 data URL")
|
||||
|
||||
|
||||
class CanceledException(Exception):
|
||||
"""Execution canceled by user."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,237 @@
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Event, Thread
|
||||
from typing import Optional
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import (
|
||||
GESStatsNotFoundError,
|
||||
)
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .invocation_processor_base import InvocationProcessorABC
|
||||
from .invocation_processor_common import CanceledException
|
||||
|
||||
|
||||
class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker_thread: Thread
|
||||
__stop_event: Event
|
||||
__invoker: Invoker
|
||||
__threadLimit: BoundedSemaphore
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
# if we do want multithreading at some point, we could make this configurable
|
||||
self.__threadLimit = BoundedSemaphore(1)
|
||||
self.__invoker = invoker
|
||||
self.__stop_event = Event()
|
||||
self.__invoker_thread = Thread(
|
||||
name="invoker_processor",
|
||||
target=self.__process,
|
||||
kwargs={"stop_event": self.__stop_event},
|
||||
)
|
||||
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
|
||||
self.__invoker_thread.start()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self.__stop_event.set()
|
||||
|
||||
def __process(self, stop_event: Event):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
queue_item: Optional[InvocationQueueItem] = None
|
||||
|
||||
profiler = (
|
||||
Profiler(
|
||||
logger=self.__invoker.services.logger,
|
||||
output_dir=self.__invoker.services.configuration.profiles_path,
|
||||
prefix=self.__invoker.services.configuration.profile_prefix,
|
||||
)
|
||||
if self.__invoker.services.configuration.profile_graphs
|
||||
else None
|
||||
)
|
||||
|
||||
def stats_cleanup(graph_execution_state_id: str) -> None:
|
||||
if profiler:
|
||||
profile_path = profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self.__invoker.services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=graph_execution_state_id, output_path=stats_path
|
||||
)
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self.__invoker.services.performance_statistics.log_stats(graph_execution_state_id)
|
||||
self.__invoker.services.performance_statistics.reset_stats(graph_execution_state_id)
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
queue_item = self.__invoker.services.queue.get()
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
|
||||
|
||||
if not queue_item: # Probably stopping
|
||||
# do not hammer the queue
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
if profiler and profiler.profile_id != queue_item.graph_execution_state_id:
|
||||
profiler.start(profile_id=queue_item.graph_execution_state_id)
|
||||
|
||||
try:
|
||||
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
|
||||
queue_item.graph_execution_state_id
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
|
||||
self.__invoker.services.events.emit_session_retrieval_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=queue_item.graph_execution_state_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
invocation = graph_execution_state.execution_graph.get_node(queue_item.invocation_id)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
|
||||
self.__invoker.services.events.emit_invocation_retrieval_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=queue_item.graph_execution_state_id,
|
||||
node_id=queue_item.invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
continue
|
||||
|
||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
||||
source_node_id = graph_execution_state.prepared_source_mapping[invocation.id]
|
||||
|
||||
# Send starting event
|
||||
self.__invoker.services.events.emit_invocation_started(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
)
|
||||
|
||||
# Invoke
|
||||
try:
|
||||
graph_id = graph_execution_state.id
|
||||
with self.__invoker.services.performance_statistics.collect_stats(invocation, graph_id):
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
||||
# which handles a few things:
|
||||
# - nodes that require a value, but get it only from a connection
|
||||
# - referencing the invocation cache instead of executing the node
|
||||
outputs = invocation.invoke_internal(
|
||||
InvocationContext(
|
||||
services=self.__invoker.services,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
workflow=queue_item.workflow,
|
||||
)
|
||||
)
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||
continue
|
||||
|
||||
# Save outputs and history
|
||||
graph_execution_state.complete(invocation.id, outputs)
|
||||
|
||||
# Save the state changes
|
||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
result=outputs.model_dump(),
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
except CanceledException:
|
||||
stats_cleanup(graph_execution_state.id)
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
error = traceback.format_exc()
|
||||
logger.error(error)
|
||||
|
||||
# Save error
|
||||
graph_execution_state.set_node_error(invocation.id, error)
|
||||
|
||||
# Save the state changes
|
||||
self.__invoker.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||
# Send error event
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=error,
|
||||
)
|
||||
pass
|
||||
|
||||
# Check queue to see if this is canceled, and skip if so
|
||||
if self.__invoker.services.queue.is_canceled(graph_execution_state.id):
|
||||
continue
|
||||
|
||||
# Queue any further commands if invoking all
|
||||
is_complete = graph_execution_state.is_complete()
|
||||
if queue_item.invoke_all and not is_complete:
|
||||
try:
|
||||
self.__invoker.invoke(
|
||||
session_queue_batch_id=queue_item.session_queue_batch_id,
|
||||
session_queue_item_id=queue_item.session_queue_item_id,
|
||||
session_queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state=graph_execution_state,
|
||||
workflow=queue_item.workflow,
|
||||
invoke_all=True,
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=source_node_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=queue_item.session_queue_batch_id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
)
|
||||
stats_cleanup(graph_execution_state.id)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||
finally:
|
||||
self.__threadLimit.release()
|
||||
@@ -0,0 +1,26 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

from abc import ABC, abstractmethod
from typing import Optional

from .invocation_queue_common import InvocationQueueItem


class InvocationQueueABC(ABC):
    """Abstract base class for all invocation queues"""

    @abstractmethod
    def get(self) -> InvocationQueueItem:
        pass

    @abstractmethod
    def put(self, item: Optional[InvocationQueueItem]) -> None:
        pass

    @abstractmethod
    def cancel(self, graph_execution_state_id: str) -> None:
        pass

    @abstractmethod
    def is_canceled(self, graph_execution_state_id: str) -> bool:
        pass
@@ -0,0 +1,23 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import time
from typing import Optional

from pydantic import BaseModel, Field

from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID


class InvocationQueueItem(BaseModel):
    graph_execution_state_id: str = Field(description="The ID of the graph execution state")
    invocation_id: str = Field(description="The ID of the node being invoked")
    session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
    session_queue_item_id: int = Field(
        description="The ID of session queue item from which this invocation queue item came"
    )
    session_queue_batch_id: str = Field(
        description="The ID of the session batch from which this invocation queue item came"
    )
    workflow: Optional[WorkflowWithoutID] = Field(description="The workflow associated with this queue item")
    invoke_all: bool = Field(default=False)
    timestamp: float = Field(default_factory=time.time)
@@ -0,0 +1,44 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import time
from queue import Queue
from typing import Optional

from .invocation_queue_base import InvocationQueueABC
from .invocation_queue_common import InvocationQueueItem


class MemoryInvocationQueue(InvocationQueueABC):
    __queue: Queue
    __cancellations: dict[str, float]

    def __init__(self):
        self.__queue = Queue()
        self.__cancellations = {}

    def get(self) -> InvocationQueueItem:
        item = self.__queue.get()

        while (
            isinstance(item, InvocationQueueItem)
            and item.graph_execution_state_id in self.__cancellations
            and self.__cancellations[item.graph_execution_state_id] > item.timestamp
        ):
            item = self.__queue.get()

        # Clear old items
        for graph_execution_state_id in list(self.__cancellations.keys()):
            if self.__cancellations[graph_execution_state_id] < item.timestamp:
                del self.__cancellations[graph_execution_state_id]

        return item

    def put(self, item: Optional[InvocationQueueItem]) -> None:
        self.__queue.put(item)

    def cancel(self, graph_execution_state_id: str) -> None:
        if graph_execution_state_id not in self.__cancellations:
            self.__cancellations[graph_execution_state_id] = time.time()

    def is_canceled(self, graph_execution_state_id: str) -> bool:
        return graph_execution_state_id in self.__cancellations
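
For orientation, a minimal usage sketch of the queue classes introduced above. It is not part of the diff; the module paths are assumptions inferred from the relative imports in the new files, and all IDs are placeholders.

    # Assumed module paths, inferred from the relative imports shown above.
    from invokeai.app.services.invocation_queue.invocation_queue_common import InvocationQueueItem
    from invokeai.app.services.invocation_queue.invocation_queue_memory import MemoryInvocationQueue

    queue = MemoryInvocationQueue()

    # Placeholder IDs, purely for illustration.
    item = InvocationQueueItem(
        graph_execution_state_id="ges-123",
        invocation_id="node-1",
        session_queue_id="default",
        session_queue_item_id=1,
        session_queue_batch_id="batch-abc",
        workflow=None,
    )
    queue.put(item)

    # cancel() records time.time() for the graph; is_canceled() only checks membership.
    queue.cancel("ges-123")
    assert queue.is_canceled("ges-123")

    # A subsequent get() would discard `item`, because its timestamp predates the
    # cancellation, and then block until a newer (non-cancelled) item or None arrives,
    # so it is not called in this sketch.
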
@@ -3,20 +3,13 @@ from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
from .board_image_records.board_image_records_base import BoardImageRecordStorageBase
|
||||
from .board_images.board_images_base import BoardImagesServiceABC
|
||||
from .board_records.board_records_base import BoardRecordStorageBase
|
||||
from .boards.boards_base import BoardServiceABC
|
||||
from .bulk_download.bulk_download_base import BulkDownloadBase
|
||||
from .config import InvokeAIAppConfig
|
||||
from .download import DownloadQueueServiceBase
|
||||
from .events.events_base import EventServiceBase
|
||||
@@ -24,11 +17,16 @@ if TYPE_CHECKING:
|
||||
from .image_records.image_records_base import ImageRecordStorageBase
|
||||
from .images.images_base import ImageServiceABC
|
||||
from .invocation_cache.invocation_cache_base import InvocationCacheBase
|
||||
from .invocation_processor.invocation_processor_base import InvocationProcessorABC
|
||||
from .invocation_queue.invocation_queue_base import InvocationQueueABC
|
||||
from .invocation_stats.invocation_stats_base import InvocationStatsServiceBase
|
||||
from .item_storage.item_storage_base import ItemStorageABC
|
||||
from .latents_storage.latents_storage_base import LatentsStorageBase
|
||||
from .model_manager.model_manager_base import ModelManagerServiceBase
|
||||
from .names.names_base import NameServiceBase
|
||||
from .session_processor.session_processor_base import SessionProcessorBase
|
||||
from .session_queue.session_queue_base import SessionQueueBase
|
||||
from .shared.graph import GraphExecutionState
|
||||
from .urls.urls_base import UrlServiceBase
|
||||
from .workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
|
||||
@@ -36,50 +34,77 @@ if TYPE_CHECKING:
|
||||
class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
|
||||
# TODO: Just forward-declared everything due to circular dependencies. Fix structure.
|
||||
board_images: "BoardImagesServiceABC"
|
||||
board_image_record_storage: "BoardImageRecordStorageBase"
|
||||
boards: "BoardServiceABC"
|
||||
board_records: "BoardRecordStorageBase"
|
||||
configuration: "InvokeAIAppConfig"
|
||||
events: "EventServiceBase"
|
||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]"
|
||||
images: "ImageServiceABC"
|
||||
image_records: "ImageRecordStorageBase"
|
||||
image_files: "ImageFileStorageBase"
|
||||
latents: "LatentsStorageBase"
|
||||
logger: "Logger"
|
||||
model_manager: "ModelManagerServiceBase"
|
||||
download_queue: "DownloadQueueServiceBase"
|
||||
processor: "InvocationProcessorABC"
|
||||
performance_statistics: "InvocationStatsServiceBase"
|
||||
queue: "InvocationQueueABC"
|
||||
session_queue: "SessionQueueBase"
|
||||
session_processor: "SessionProcessorBase"
|
||||
invocation_cache: "InvocationCacheBase"
|
||||
names: "NameServiceBase"
|
||||
urls: "UrlServiceBase"
|
||||
workflow_records: "WorkflowRecordsStorageBase"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
board_images: "BoardImagesServiceABC",
|
||||
board_image_records: "BoardImageRecordStorageBase",
|
||||
boards: "BoardServiceABC",
|
||||
board_records: "BoardRecordStorageBase",
|
||||
bulk_download: "BulkDownloadBase",
|
||||
configuration: "InvokeAIAppConfig",
|
||||
events: "EventServiceBase",
|
||||
graph_execution_manager: "ItemStorageABC[GraphExecutionState]",
|
||||
images: "ImageServiceABC",
|
||||
image_files: "ImageFileStorageBase",
|
||||
image_records: "ImageRecordStorageBase",
|
||||
latents: "LatentsStorageBase",
|
||||
logger: "Logger",
|
||||
model_manager: "ModelManagerServiceBase",
|
||||
download_queue: "DownloadQueueServiceBase",
|
||||
processor: "InvocationProcessorABC",
|
||||
performance_statistics: "InvocationStatsServiceBase",
|
||||
queue: "InvocationQueueABC",
|
||||
session_queue: "SessionQueueBase",
|
||||
session_processor: "SessionProcessorBase",
|
||||
invocation_cache: "InvocationCacheBase",
|
||||
names: "NameServiceBase",
|
||||
urls: "UrlServiceBase",
|
||||
workflow_records: "WorkflowRecordsStorageBase",
|
||||
tensors: "ObjectSerializerBase[torch.Tensor]",
|
||||
conditioning: "ObjectSerializerBase[ConditioningFieldData]",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.board_image_records = board_image_records
|
||||
self.boards = boards
|
||||
self.board_records = board_records
|
||||
self.bulk_download = bulk_download
|
||||
self.configuration = configuration
|
||||
self.events = events
|
||||
self.graph_execution_manager = graph_execution_manager
|
||||
self.images = images
|
||||
self.image_files = image_files
|
||||
self.image_records = image_records
|
||||
self.latents = latents
|
||||
self.logger = logger
|
||||
self.model_manager = model_manager
|
||||
self.download_queue = download_queue
|
||||
self.processor = processor
|
||||
self.performance_statistics = performance_statistics
|
||||
self.queue = queue
|
||||
self.session_queue = session_queue
|
||||
self.session_processor = session_processor
|
||||
self.invocation_cache = invocation_cache
|
||||
self.names = names
|
||||
self.urls = urls
|
||||
self.workflow_records = workflow_records
|
||||
self.tensors = tensors
|
||||
self.conditioning = conditioning
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
Usage:
|
||||
|
||||
statistics = InvocationStatsService()
|
||||
statistics = InvocationStatsService(graph_execution_manager)
|
||||
with statistics.collect_stats(invocation, graph_execution_state.id):
|
||||
... execute graphs...
|
||||
statistics.log_stats()
|
||||
@@ -30,7 +30,7 @@ writes to the system log is stored in InvocationServices.performance_statistics.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import ContextManager
|
||||
from typing import Iterator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import InvocationStatsSummary
|
||||
@@ -50,7 +50,7 @@ class InvocationStatsServiceBase(ABC):
|
||||
self,
|
||||
invocation: BaseInvocation,
|
||||
graph_execution_state_id: str,
|
||||
) -> ContextManager[None]:
|
||||
) -> Iterator[None]:
|
||||
"""
|
||||
Return a context object that will capture the statistics on the execution
|
||||
of the invocation. Use it as a `with` context around the part of the code that executes the invocation.
|
||||
@@ -60,8 +60,12 @@ class InvocationStatsServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset_stats(self):
|
||||
"""Reset all stored statistics."""
|
||||
def reset_stats(self, graph_execution_state_id: str) -> None:
|
||||
"""
|
||||
Reset all statistics for the indicated graph.
|
||||
:param graph_execution_state_id: The id of the session whose stats to reset.
|
||||
:raises GESStatsNotFoundError: if the graph isn't tracked in the stats.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
from typing import Iterator
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
@@ -10,6 +10,7 @@ import torch
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.item_storage.item_storage_common import ItemNotFoundError
|
||||
from invokeai.backend.model_manager.load.model_cache import CacheStats
|
||||
|
||||
from .invocation_stats_base import InvocationStatsServiceBase
|
||||
@@ -41,7 +42,7 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self._invoker = invoker
|
||||
|
||||
@contextmanager
|
||||
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Generator[None, None, None]:
|
||||
def collect_stats(self, invocation: BaseInvocation, graph_execution_state_id: str) -> Iterator[None]:
|
||||
# This is to handle case of the model manager not being initialized, which happens
|
||||
# during some tests.
|
||||
services = self._invoker.services
|
||||
@@ -50,6 +51,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
self._stats[graph_execution_state_id] = GraphExecutionStats()
|
||||
self._cache_stats[graph_execution_state_id] = CacheStats()
|
||||
|
||||
# Prune stale stats. There should be none since we're starting a new graph, but just in case.
|
||||
self._prune_stale_stats()
|
||||
|
||||
# Record state before the invocation.
|
||||
start_time = time.time()
|
||||
start_ram = psutil.Process().memory_info().rss
|
||||
@@ -74,9 +78,42 @@ class InvocationStatsService(InvocationStatsServiceBase):
|
||||
)
|
||||
self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)
|
||||
|
||||
def reset_stats(self):
|
||||
self._stats = {}
|
||||
self._cache_stats = {}
|
||||
def _prune_stale_stats(self) -> None:
|
||||
"""Check all graphs being tracked and prune any that have completed/errored.
|
||||
|
||||
This shouldn't be necessary, but we don't have totally robust upstream handling of graph completions/errors, so
|
||||
for now we call this function periodically to prevent them from accumulating.
|
||||
"""
|
||||
to_prune: list[str] = []
|
||||
for graph_execution_state_id in self._stats:
|
||||
try:
|
||||
graph_execution_state = self._invoker.services.graph_execution_manager.get(graph_execution_state_id)
|
||||
except ItemNotFoundError:
|
||||
# TODO(ryand): What would cause this? Should this exception just be allowed to propagate?
|
||||
logger.warning(f"Failed to get graph state for {graph_execution_state_id}.")
|
||||
continue
|
||||
|
||||
if not graph_execution_state.is_complete():
|
||||
# The graph is still running, don't prune it.
|
||||
continue
|
||||
|
||||
to_prune.append(graph_execution_state_id)
|
||||
|
||||
for graph_execution_state_id in to_prune:
|
||||
del self._stats[graph_execution_state_id]
|
||||
del self._cache_stats[graph_execution_state_id]
|
||||
|
||||
if len(to_prune) > 0:
|
||||
logger.info(f"Pruned stale graph stats for {to_prune}.")
|
||||
|
||||
def reset_stats(self, graph_execution_state_id: str):
|
||||
try:
|
||||
del self._stats[graph_execution_state_id]
|
||||
del self._cache_stats[graph_execution_state_id]
|
||||
except KeyError as e:
|
||||
raise GESStatsNotFoundError(
|
||||
f"Attempted to clear statistics for unknown graph {graph_execution_state_id}: {e}."
|
||||
) from e
|
||||
|
||||
def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
|
||||
graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
|
||||
|
||||
from .invocation_queue.invocation_queue_common import InvocationQueueItem
|
||||
from .invocation_services import InvocationServices
|
||||
from .shared.graph import Graph, GraphExecutionState
|
||||
|
||||
|
||||
class Invoker:
|
||||
@@ -13,6 +18,51 @@ class Invoker:
|
||||
self.services = services
|
||||
self._start()
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
session_queue_id: str,
|
||||
session_queue_item_id: int,
|
||||
session_queue_batch_id: str,
|
||||
graph_execution_state: GraphExecutionState,
|
||||
workflow: Optional[WorkflowWithoutID] = None,
|
||||
invoke_all: bool = False,
|
||||
) -> Optional[str]:
|
||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||
|
||||
# Get the next invocation
|
||||
invocation = graph_execution_state.next()
|
||||
if not invocation:
|
||||
return None
|
||||
|
||||
# Save the execution state
|
||||
self.services.graph_execution_manager.set(graph_execution_state)
|
||||
|
||||
# Queue the invocation
|
||||
self.services.queue.put(
|
||||
InvocationQueueItem(
|
||||
session_queue_id=session_queue_id,
|
||||
session_queue_item_id=session_queue_item_id,
|
||||
session_queue_batch_id=session_queue_batch_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_id=invocation.id,
|
||||
workflow=workflow,
|
||||
invoke_all=invoke_all,
|
||||
)
|
||||
)
|
||||
|
||||
return invocation.id
|
||||
|
||||
def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
|
||||
"""Creates a new execution state for the given graph"""
|
||||
new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
|
||||
self.services.graph_execution_manager.set(new_state)
|
||||
return new_state
|
||||
|
||||
def cancel(self, graph_execution_state_id: str) -> None:
|
||||
"""Cancels the given execution state"""
|
||||
self.services.queue.cancel(graph_execution_state_id)
|
||||
|
||||
def __start_service(self, service) -> None:
|
||||
# Call start() method on any services that have it
|
||||
start_op = getattr(service, "start", None)
|
||||
@@ -35,3 +85,5 @@ class Invoker:
|
||||
# First stop all services
|
||||
for service in vars(self.services):
|
||||
self.__stop_service(getattr(self.services, service))
|
||||
|
||||
self.services.queue.put(None)
|
||||
|
||||
@@ -30,7 +30,7 @@ class ItemStorageABC(ABC, Generic[T]):
    @abstractmethod
    def set(self, item: T) -> None:
        """
        Sets the item.
        Sets the item. The id will be extracted based on id_field.
        :param item: the item to set
        """
        pass
0 invokeai/app/services/latents_storage/__init__.py Normal file
@@ -0,0 +1,49 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from abc import ABC, abstractmethod
from typing import Callable, Union

import torch

from invokeai.app.invocations.compel import ConditioningFieldData


class LatentsStorageBase(ABC):
    """Responsible for storing and retrieving latents."""

    _on_changed_callbacks: list[Callable[[torch.Tensor], None]]
    _on_deleted_callbacks: list[Callable[[str], None]]

    def __init__(self) -> None:
        self._on_changed_callbacks = []
        self._on_deleted_callbacks = []

    @abstractmethod
    def get(self, name: str) -> torch.Tensor:
        pass

    # (LS) Added a Union with ConditioningFieldData to fix type mismatch errors in compel.py
    # Not 100% sure this isn't an existing bug.
    @abstractmethod
    def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None:
        pass

    @abstractmethod
    def delete(self, name: str) -> None:
        pass

    def on_changed(self, on_changed: Callable[[torch.Tensor], None]) -> None:
        """Register a callback for when an item is changed"""
        self._on_changed_callbacks.append(on_changed)

    def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
        """Register a callback for when an item is deleted"""
        self._on_deleted_callbacks.append(on_deleted)

    def _on_changed(self, item: torch.Tensor) -> None:
        for callback in self._on_changed_callbacks:
            callback(item)

    def _on_deleted(self, item_id: str) -> None:
        for callback in self._on_deleted_callbacks:
            callback(item_id)
@@ -0,0 +1,59 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

from pathlib import Path
from typing import Union

import torch

from invokeai.app.invocations.compel import ConditioningFieldData
from invokeai.app.services.invoker import Invoker

from .latents_storage_base import LatentsStorageBase


class DiskLatentsStorage(LatentsStorageBase):
    """Stores latents in a folder on disk without caching"""

    __output_folder: Path

    def __init__(self, output_folder: Union[str, Path]):
        self.__output_folder = output_folder if isinstance(output_folder, Path) else Path(output_folder)
        self.__output_folder.mkdir(parents=True, exist_ok=True)

    def start(self, invoker: Invoker) -> None:
        self._invoker = invoker
        self._delete_all_latents()

    def get(self, name: str) -> torch.Tensor:
        latent_path = self.get_path(name)
        return torch.load(latent_path)

    def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None:
        self.__output_folder.mkdir(parents=True, exist_ok=True)
        latent_path = self.get_path(name)
        torch.save(data, latent_path)

    def delete(self, name: str) -> None:
        latent_path = self.get_path(name)
        latent_path.unlink()

    def get_path(self, name: str) -> Path:
        return self.__output_folder / name

    def _delete_all_latents(self) -> None:
        """
        Deletes all latents from disk.
        Must be called after we have access to `self._invoker` (e.g. in `start()`).
        """
        deleted_latents_count = 0
        freed_space = 0
        for latents_file in Path(self.__output_folder).glob("*"):
            if latents_file.is_file():
                freed_space += latents_file.stat().st_size
                deleted_latents_count += 1
                latents_file.unlink()
        if deleted_latents_count > 0:
            freed_space_in_mb = round(freed_space / 1024 / 1024, 2)
            self._invoker.services.logger.info(
                f"Deleted {deleted_latents_count} latents files (freed {freed_space_in_mb}MB)"
            )
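
A rough wiring sketch for the disk-backed latents storage defined above, not part of the diff. The module path and the output folder are assumptions; note that this class does not fire the on_changed/on_deleted callbacks itself (the forward-cache wrapper in the next file does).

    import torch

    # Assumed module path, based on the new invokeai/app/services/latents_storage package.
    from invokeai.app.services.latents_storage.latents_storage_disk import DiskLatentsStorage

    storage = DiskLatentsStorage("/tmp/invokeai-latents")  # placeholder folder

    latents = torch.zeros(1, 4, 64, 64)
    storage.save("demo-latents", latents)    # torch.save() to <output_folder>/demo-latents
    roundtrip = storage.get("demo-latents")  # torch.load() from the same path
    assert torch.equal(roundtrip, latents)

    storage.delete("demo-latents")           # unlinks the file
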
@@ -0,0 +1,71 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from queue import Queue
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.invocations.compel import ConditioningFieldData
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
from .latents_storage_base import LatentsStorageBase
|
||||
|
||||
|
||||
class ForwardCacheLatentsStorage(LatentsStorageBase):
|
||||
"""Caches the latest N latents in memory, writing-thorugh to and reading from underlying storage"""
|
||||
|
||||
__cache: Dict[str, torch.Tensor]
|
||||
__cache_ids: Queue
|
||||
__max_cache_size: int
|
||||
__underlying_storage: LatentsStorageBase
|
||||
|
||||
def __init__(self, underlying_storage: LatentsStorageBase, max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self.__underlying_storage = underlying_storage
|
||||
self.__cache = {}
|
||||
self.__cache_ids = Queue()
|
||||
self.__max_cache_size = max_cache_size
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
start_op = getattr(self.__underlying_storage, "start", None)
|
||||
if callable(start_op):
|
||||
start_op(invoker)
|
||||
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
stop_op = getattr(self.__underlying_storage, "stop", None)
|
||||
if callable(stop_op):
|
||||
stop_op(invoker)
|
||||
|
||||
def get(self, name: str) -> torch.Tensor:
|
||||
cache_item = self.__get_cache(name)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
latent = self.__underlying_storage.get(name)
|
||||
self.__set_cache(name, latent)
|
||||
return latent
|
||||
|
||||
# TODO: (LS) ConditioningFieldData added as Union because of type-checking errors
|
||||
# in compel.py. Unclear whether this is a long-standing bug, but seems to run.
|
||||
def save(self, name: str, data: Union[torch.Tensor, ConditioningFieldData]) -> None:
|
||||
self.__underlying_storage.save(name, data)
|
||||
self.__set_cache(name, data)
|
||||
self._on_changed(data)
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self.__underlying_storage.delete(name)
|
||||
if name in self.__cache:
|
||||
del self.__cache[name]
|
||||
self._on_deleted(name)
|
||||
|
||||
def __get_cache(self, name: str) -> Optional[torch.Tensor]:
|
||||
return None if name not in self.__cache else self.__cache[name]
|
||||
|
||||
def __set_cache(self, name: str, data: torch.Tensor):
|
||||
if name not in self.__cache:
|
||||
self.__cache[name] = data
|
||||
self.__cache_ids.put(name)
|
||||
if self.__cache_ids.qsize() > self.__max_cache_size:
|
||||
self.__cache.pop(self.__cache_ids.get())
|
||||
@@ -156,7 +156,6 @@ class ModelInstallJob(BaseModel):
|
||||
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
)
|
||||
@@ -178,12 +177,6 @@ class ModelInstallJob(BaseModel):
|
||||
download_parts: Set[DownloadJob] = Field(
|
||||
default_factory=set, description="Download jobs contributing to this install"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||
)
|
||||
error_traceback: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the exception traceback"
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
@@ -191,10 +184,7 @@ class ModelInstallJob(BaseModel):
|
||||
def set_error(self, e: Exception) -> None:
|
||||
"""Record the error and traceback from an exception."""
|
||||
self._exception = e
|
||||
self.error = str(e)
|
||||
self.error_traceback = self._format_error(e)
|
||||
self.status = InstallStatus.ERROR
|
||||
self.error_reason = self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Call to cancel the job."""
|
||||
@@ -205,9 +195,10 @@ class ModelInstallJob(BaseModel):
|
||||
"""Class name of the exception that led to status==ERROR."""
|
||||
return self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def _format_error(self, exception: Exception) -> str:
|
||||
@property
|
||||
def error(self) -> Optional[str]:
|
||||
"""Error traceback."""
|
||||
return "".join(traceback.format_exception(exception))
|
||||
return "".join(traceback.format_exception(self._exception)) if self._exception else None
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
|
||||
@@ -154,12 +154,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
info: AnyModelConfig = self._probe_model(Path(model_path), config)
|
||||
old_hash = info.current_hash
|
||||
|
||||
if preferred_name := config.get("name"):
|
||||
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
|
||||
|
||||
dest_path = (
|
||||
self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
|
||||
self.app_config.models_path / info.base.value / info.type.value / (config.get("name") or model_path.name)
|
||||
)
|
||||
try:
|
||||
new_path = self._copy_model(model_path, dest_path)
|
||||
@@ -542,10 +538,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def _register(
|
||||
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
|
||||
) -> str:
|
||||
key = self._create_key()
|
||||
if config and not config.get("key", None):
|
||||
config["key"] = key
|
||||
info = info or ModelProbe.probe(model_path, config)
|
||||
key = self._create_key()
|
||||
|
||||
model_path = model_path.absolute()
|
||||
if model_path.is_relative_to(self.app_config.models_path):
|
||||
@@ -558,8 +552,8 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# make config relative to our root
|
||||
legacy_conf = (self.app_config.root_dir / self.app_config.legacy_conf_dir / info.config).resolve()
|
||||
info.config = legacy_conf.relative_to(self.app_config.root_dir).as_posix()
|
||||
self.record_store.add_model(info.key, info)
|
||||
return info.key
|
||||
self.record_store.add_model(key, info)
|
||||
return key
|
||||
|
||||
def _next_id(self) -> int:
|
||||
with self._lock:
|
||||
@@ -743,7 +737,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._signal_job_downloading(install_job)
|
||||
|
||||
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"{download_job.source}: model download complete")
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
self._download_cache.pop(download_job.source, None)
|
||||
@@ -776,7 +769,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
if not install_job:
|
||||
return
|
||||
self._downloads_changed_event.set()
|
||||
self._logger.warning(f"{download_job.source}: model download cancelled")
|
||||
self._logger.warning(f"Download {download_job.source} cancelled.")
|
||||
# if install job has already registered an error, then do not replace its status with cancelled
|
||||
if not install_job.errored:
|
||||
install_job.cancel()
|
||||
@@ -823,7 +816,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
parts=parts,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
id=job.id,
|
||||
)
|
||||
|
||||
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
||||
@@ -836,7 +828,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
assert job.local_path is not None
|
||||
assert job.config_out is not None
|
||||
key = job.config_out.key
|
||||
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
|
||||
self._event_bus.emit_model_install_completed(str(job.source), key)
|
||||
|
||||
def _signal_job_errored(self, job: ModelInstallJob) -> None:
|
||||
self._logger.info(f"{job.source}: model installation encountered an exception: {job.error_type}\n{job.error}")
|
||||
@@ -845,7 +837,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
error = job.error
|
||||
assert error_type is not None
|
||||
assert error is not None
|
||||
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
|
||||
self._event_bus.emit_model_install_error(str(job.source), error_type, error)
|
||||
|
||||
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
|
||||
self._logger.info(f"{job.source}: model installation was cancelled")
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
@@ -19,14 +19,14 @@ class ModelLoadServiceBase(ABC):
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context_data: Invocation context data used for event reporting
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
|
||||
@property
|
||||
|
||||
@@ -3,15 +3,11 @@
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import (
|
||||
LoadedModel,
|
||||
ModelLoaderRegistry,
|
||||
ModelLoaderRegistryBase,
|
||||
)
|
||||
from invokeai.backend.model_manager.load import LoadedModel, ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@@ -38,9 +34,6 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self._convert_cache = convert_cache
|
||||
self._registry = registry
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
|
||||
@property
|
||||
def ram_cache(self) -> ModelCacheBase[AnyModel]:
|
||||
"""Return the RAM cache used by this loader."""
|
||||
@@ -55,7 +48,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
@@ -64,9 +57,9 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
if context_data:
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context_data=context_data,
|
||||
context=context,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
@@ -78,9 +71,9 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
convert_cache=self._convert_cache,
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
if context_data:
|
||||
if context:
|
||||
self._emit_load_event(
|
||||
context_data=context_data,
|
||||
context=context,
|
||||
model_config=model_config,
|
||||
loaded=True,
|
||||
)
|
||||
@@ -88,26 +81,26 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context_data: InvocationContextData,
|
||||
context: InvocationContext,
|
||||
model_config: AnyModelConfig,
|
||||
loaded: Optional[bool] = False,
|
||||
) -> None:
|
||||
if not self._invoker:
|
||||
return
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException()
|
||||
|
||||
if not loaded:
|
||||
self._invoker.services.events.emit_model_load_started(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_config=model_config,
|
||||
)
|
||||
else:
|
||||
self._invoker.services.events.emit_model_load_completed(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
from ..download import DownloadQueueServiceBase
|
||||
@@ -17,6 +12,7 @@ from ..events.events_base import EventServiceBase
|
||||
from ..model_install import ModelInstallServiceBase
|
||||
from ..model_load import ModelLoadServiceBase
|
||||
from ..model_records import ModelRecordServiceBase
|
||||
from ..shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class ModelManagerServiceBase(ABC):
|
||||
@@ -32,10 +28,9 @@ class ModelManagerServiceBase(ABC):
|
||||
def build_model_manager(
|
||||
cls,
|
||||
app_config: InvokeAIAppConfig,
|
||||
model_record_service: ModelRecordServiceBase,
|
||||
db: SqliteDatabase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
execution_device: torch.device,
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
@@ -70,32 +65,3 @@ class ModelManagerServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def stop(self, invoker: Invoker) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_config(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_key(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
pass
|
||||
|
||||
@@ -3,14 +3,12 @@
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import InvocationContext
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, LoadedModel, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from ..config import InvokeAIAppConfig
|
||||
@@ -68,18 +66,18 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
return self.load.load_model(model_config, submodel_type, context_data)
|
||||
return self.load.load_model(model_config, submodel_type, context)
|
||||
|
||||
def load_model_by_key(
|
||||
self,
|
||||
key: str,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
config = self.store.get_model(key)
|
||||
return self.load.load_model(config, submodel_type, context_data)
|
||||
return self.load.load_model(config, submodel_type, context)
|
||||
|
||||
def load_model_by_attr(
|
||||
self,
|
||||
@@ -87,7 +85,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
submodel: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
context: Optional[InvocationContext] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's attributes, search the database for it, and if found, load and return the LoadedModel object.
|
||||
@@ -112,7 +110,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
elif len(configs) > 1:
|
||||
raise ValueError(f"{base_model}/{model_type}/{model_name}: More than one model matches.")
|
||||
else:
|
||||
return self.load.load_model(configs[0], submodel, context_data)
|
||||
return self.load.load_model(configs[0], submodel, context)
|
||||
|
||||
@classmethod
|
||||
def build_model_manager(
|
||||
@@ -121,7 +119,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
model_record_service: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
events: EventServiceBase,
|
||||
execution_device: torch.device = choose_torch_device(),
|
||||
) -> Self:
|
||||
"""
|
||||
Construct the model manager service instance.
|
||||
@@ -132,10 +129,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
logger.setLevel(app_config.log_level.upper())
|
||||
|
||||
ram_cache = ModelCache(
|
||||
max_cache_size=app_config.ram_cache_size,
|
||||
max_vram_cache_size=app_config.vram_cache_size,
|
||||
logger=logger,
|
||||
execution_device=execution_device,
|
||||
max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size, logger=logger
|
||||
)
|
||||
convert_cache = ModelConvertCache(
|
||||
cache_path=app_config.models_convert_cache_path, max_size=app_config.convert_cache_size
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ObjectSerializerBase(ABC, Generic[T]):
|
||||
"""Saves and loads arbitrary python objects."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._on_deleted_callbacks: list[Callable[[str], None]] = []
|
||||
|
||||
@abstractmethod
|
||||
def load(self, name: str) -> T:
|
||||
"""
|
||||
Loads the object.
|
||||
:param name: The name of the object to load.
|
||||
:raises ObjectNotFoundError: if the object is not found
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, obj: T) -> str:
|
||||
"""
|
||||
Saves the object, returning its name.
|
||||
:param obj: The object to save.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, name: str) -> None:
|
||||
"""
|
||||
Deletes the object, if it exists.
|
||||
:param name: The name of the object to delete.
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_deleted(self, on_deleted: Callable[[str], None]) -> None:
|
||||
"""Register a callback for when an object is deleted"""
|
||||
self._on_deleted_callbacks.append(on_deleted)
|
||||
|
||||
def _on_deleted(self, name: str) -> None:
|
||||
for callback in self._on_deleted_callbacks:
|
||||
callback(name)
|
||||
@@ -1,5 +0,0 @@
|
||||
class ObjectNotFoundError(KeyError):
|
||||
"""Raised when an object is not found while loading"""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(f"Object with name {name} not found")
|
||||
@@ -1,85 +0,0 @@
|
||||
import tempfile
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeleteAllResult:
|
||||
deleted_count: int
|
||||
freed_space_bytes: float
|
||||
|
||||
|
||||
class ObjectSerializerDisk(ObjectSerializerBase[T]):
|
||||
"""Disk-backed storage for arbitrary python objects. Serialization is handled by `torch.save` and `torch.load`.
|
||||
|
||||
:param output_dir: The folder where the serialized objects will be stored
|
||||
:param ephemeral: If True, objects will be stored in a temporary directory inside the given output_dir and cleaned up on exit
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir: Path, ephemeral: bool = False):
|
||||
super().__init__()
|
||||
self._ephemeral = ephemeral
|
||||
self._base_output_dir = output_dir
|
||||
self._base_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
# Must specify `ignore_cleanup_errors` to avoid fatal errors during cleanup on Windows
|
||||
self._tempdir = (
|
||||
tempfile.TemporaryDirectory(dir=self._base_output_dir, ignore_cleanup_errors=True) if ephemeral else None
|
||||
)
|
||||
self._output_dir = Path(self._tempdir.name) if self._tempdir else self._base_output_dir
|
||||
self.__obj_class_name: Optional[str] = None
|
||||
|
||||
def load(self, name: str) -> T:
|
||||
file_path = self._get_path(name)
|
||||
try:
|
||||
return torch.load(file_path) # pyright: ignore [reportUnknownMemberType]
|
||||
except FileNotFoundError as e:
|
||||
raise ObjectNotFoundError(name) from e
|
||||
|
||||
def save(self, obj: T) -> str:
|
||||
name = self._new_name()
|
||||
file_path = self._get_path(name)
|
||||
torch.save(obj, file_path) # pyright: ignore [reportUnknownMemberType]
|
||||
return name
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
file_path = self._get_path(name)
|
||||
file_path.unlink()
|
||||
|
||||
@property
|
||||
def _obj_class_name(self) -> str:
|
||||
if not self.__obj_class_name:
|
||||
# `__orig_class__` is not available in the constructor for some technical, undoubtedly very pythonic reason
|
||||
self.__obj_class_name = typing.get_args(self.__orig_class__)[0].__name__ # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue]
|
||||
return self.__obj_class_name
|
||||
|
||||
def _get_path(self, name: str) -> Path:
|
||||
return self._output_dir / name
|
||||
|
||||
def _new_name(self) -> str:
|
||||
return f"{self._obj_class_name}_{uuid_string()}"
|
||||
|
||||
def _tempdir_cleanup(self) -> None:
|
||||
"""Calls `cleanup` on the temporary directory, if it exists."""
|
||||
if self._tempdir:
|
||||
self._tempdir.cleanup()
|
||||
|
||||
def __del__(self) -> None:
|
||||
# In case the service is not properly stopped, clean up the temporary directory when the class instance is GC'd.
|
||||
self._tempdir_cleanup()
|
||||
|
||||
def stop(self, invoker: "Invoker") -> None:
|
||||
self._tempdir_cleanup()
|
||||
@@ -1,65 +0,0 @@
|
||||
from queue import Queue
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
|
||||
from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
|
||||
|
||||
class ObjectSerializerForwardCache(ObjectSerializerBase[T]):
|
||||
"""
|
||||
Provides a LRU cache for an instance of `ObjectSerializerBase`.
|
||||
Saving an object to the cache always writes through to the underlying storage.
|
||||
"""
|
||||
|
||||
def __init__(self, underlying_storage: ObjectSerializerBase[T], max_cache_size: int = 20):
|
||||
super().__init__()
|
||||
self._underlying_storage = underlying_storage
|
||||
self._cache: dict[str, T] = {}
|
||||
self._cache_ids = Queue[str]()
|
||||
self._max_cache_size = max_cache_size
|
||||
|
||||
def start(self, invoker: "Invoker") -> None:
|
||||
self._invoker = invoker
|
||||
start_op = getattr(self._underlying_storage, "start", None)
|
||||
if callable(start_op):
|
||||
start_op(invoker)
|
||||
|
||||
def stop(self, invoker: "Invoker") -> None:
|
||||
self._invoker = invoker
|
||||
stop_op = getattr(self._underlying_storage, "stop", None)
|
||||
if callable(stop_op):
|
||||
stop_op(invoker)
|
||||
|
||||
def load(self, name: str) -> T:
|
||||
cache_item = self._get_cache(name)
|
||||
if cache_item is not None:
|
||||
return cache_item
|
||||
|
||||
obj = self._underlying_storage.load(name)
|
||||
self._set_cache(name, obj)
|
||||
return obj
|
||||
|
||||
def save(self, obj: T) -> str:
|
||||
name = self._underlying_storage.save(obj)
|
||||
self._set_cache(name, obj)
|
||||
return name
|
||||
|
||||
def delete(self, name: str) -> None:
|
||||
self._underlying_storage.delete(name)
|
||||
if name in self._cache:
|
||||
del self._cache[name]
|
||||
self._on_deleted(name)
|
||||
|
||||
def _get_cache(self, name: str) -> Optional[T]:
|
||||
return None if name not in self._cache else self._cache[name]
|
||||
|
||||
def _set_cache(self, name: str, data: T):
|
||||
if name not in self._cache:
|
||||
self._cache[name] = data
|
||||
self._cache_ids.put(name)
|
||||
if self._cache_ids.qsize() > self._max_cache_size:
|
||||
self._cache.pop(self._cache_ids.get())
|
||||
@@ -4,17 +4,3 @@ from pydantic import BaseModel, Field
class SessionProcessorStatus(BaseModel):
    is_started: bool = Field(description="Whether the session processor is started")
    is_processing: bool = Field(description="Whether a session is being processed")


class CanceledException(Exception):
    """Execution canceled by user."""

    pass


class ProgressImage(BaseModel):
    """The progress image sent intermittently during processing"""

    width: int = Field(description="The effective width of the image in pixels")
    height: int = Field(description="The effective height of the image in pixels")
    dataURL: str = Field(description="The image data as a b64 data URL")
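
Purely to illustrate the payload shapes in the hunk above (not part of the diff; the import path and all field values are placeholders and may not match where these classes live after this refactor):

    # Assumed import path; adjust to wherever these classes end up after the refactor.
    from invokeai.app.services.session_processor.session_processor_common import (
        CanceledException,
        ProgressImage,
    )

    progress = ProgressImage(width=512, height=512, dataURL="data:image/png;base64,AAAA")
    payload = progress.model_dump()  # pydantic v2 serialization, as used for progress events

    # Long-running invocations raise CanceledException from a step callback to stop early.
    def maybe_cancel(is_canceled: bool) -> None:
        if is_canceled:
            raise CanceledException()
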
@@ -1,5 +1,4 @@
|
||||
import traceback
|
||||
from contextlib import suppress
|
||||
from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Optional
|
||||
@@ -7,270 +6,136 @@ from typing import Optional
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||
from invokeai.app.util.profiler import Profiler
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .session_processor_base import SessionProcessorBase
|
||||
from .session_processor_common import SessionProcessorStatus
|
||||
|
||||
POLLING_INTERVAL = 1
|
||||
THREAD_LIMIT = 1
|
||||
|
||||
|
||||
class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
|
||||
self._invoker: Invoker = invoker
|
||||
self._queue_item: Optional[SessionQueueItem] = None
|
||||
self._invocation: Optional[BaseInvocation] = None
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker: Invoker = invoker
|
||||
self.__queue_item: Optional[SessionQueueItem] = None
|
||||
|
||||
self._resume_event = ThreadEvent()
|
||||
self._stop_event = ThreadEvent()
|
||||
self._poll_now_event = ThreadEvent()
|
||||
self._cancel_event = ThreadEvent()
|
||||
self.__resume_event = ThreadEvent()
|
||||
self.__stop_event = ThreadEvent()
|
||||
self.__poll_now_event = ThreadEvent()
|
||||
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||
|
||||
self._thread_limit = thread_limit
|
||||
self._thread_semaphore = BoundedSemaphore(thread_limit)
|
||||
self._polling_interval = polling_interval
|
||||
|
||||
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
|
||||
# the profiler will create a new profile for each session.
|
||||
self._profiler = (
|
||||
Profiler(
|
||||
logger=self._invoker.services.logger,
|
||||
output_dir=self._invoker.services.configuration.profiles_path,
|
||||
prefix=self._invoker.services.configuration.profile_prefix,
|
||||
)
|
||||
if self._invoker.services.configuration.profile_graphs
|
||||
else None
|
||||
)
|
||||
|
||||
self._thread = Thread(
|
||||
self.__threadLimit = BoundedSemaphore(THREAD_LIMIT)
|
||||
self.__thread = Thread(
|
||||
name="session_processor",
|
||||
target=self._process,
|
||||
target=self.__process,
|
||||
kwargs={
|
||||
"stop_event": self._stop_event,
|
||||
"poll_now_event": self._poll_now_event,
|
||||
"resume_event": self._resume_event,
|
||||
"cancel_event": self._cancel_event,
|
||||
"stop_event": self.__stop_event,
|
||||
"poll_now_event": self.__poll_now_event,
|
||||
"resume_event": self.__resume_event,
|
||||
},
|
||||
)
|
||||
self._thread.start()
|
||||
self.__thread.start()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self._stop_event.set()
|
||||
self.__stop_event.set()
|
||||
|
||||
def _poll_now(self) -> None:
|
||||
self._poll_now_event.set()
|
||||
self.__poll_now_event.set()
|
||||
|
||||
async def _on_queue_event(self, event: FastAPIEvent) -> None:
|
||||
event_name = event[1]["event"]
|
||||
|
||||
if event_name == "session_canceled" or event_name == "queue_cleared":
|
||||
# These both mean we should cancel the current session.
|
||||
self._cancel_event.set()
|
||||
# This was a match statement, but match is not supported on python 3.9
|
||||
if event_name in [
|
||||
"graph_execution_state_complete",
|
||||
"invocation_error",
|
||||
"session_retrieval_error",
|
||||
"invocation_retrieval_error",
|
||||
]:
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
elif (
|
||||
event_name == "session_canceled"
|
||||
and self.__queue_item is not None
|
||||
and self.__queue_item.session_id == event[1]["data"]["graph_execution_state_id"]
|
||||
):
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
elif event_name == "batch_enqueued":
|
||||
self._poll_now()
|
||||
elif event_name == "queue_cleared":
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
|
||||
def resume(self) -> SessionProcessorStatus:
|
||||
if not self._resume_event.is_set():
|
||||
self._resume_event.set()
|
||||
if not self.__resume_event.is_set():
|
||||
self.__resume_event.set()
|
||||
return self.get_status()
|
||||
|
||||
def pause(self) -> SessionProcessorStatus:
|
||||
if self._resume_event.is_set():
|
||||
self._resume_event.clear()
|
||||
if self.__resume_event.is_set():
|
||||
self.__resume_event.clear()
|
||||
return self.get_status()
|
||||
|
||||
def get_status(self) -> SessionProcessorStatus:
|
||||
return SessionProcessorStatus(
|
||||
is_started=self._resume_event.is_set(),
|
||||
is_processing=self._queue_item is not None,
|
||||
is_started=self.__resume_event.is_set(),
|
||||
is_processing=self.__queue_item is not None,
|
||||
)
|
||||
|
||||
    def _process(
    def __process(
        self,
        stop_event: ThreadEvent,
        poll_now_event: ThreadEvent,
        resume_event: ThreadEvent,
        cancel_event: ThreadEvent,
    ):
        # Outermost processor try block; any unhandled exception is a fatal processor error
        try:
            self._thread_semaphore.acquire()
            stop_event.clear()
            resume_event.set()
            cancel_event.clear()

            self.__threadLimit.acquire()
            queue_item: Optional[SessionQueueItem] = None
            while not stop_event.is_set():
                poll_now_event.clear()
                # Middle processor try block; any unhandled exception is a non-fatal processor error
                try:
                    # Get the next session to process
                    self._queue_item = self._invoker.services.session_queue.dequeue()
                    if self._queue_item is not None and resume_event.is_set():
                        self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
                        cancel_event.clear()
                    # do not dequeue if there is already a session running
                    if self.__queue_item is None and resume_event.is_set():
                        queue_item = self.__invoker.services.session_queue.dequeue()

                        # If profiling is enabled, start the profiler
                        if self._profiler is not None:
                            self._profiler.start(profile_id=self._queue_item.session_id)

                        # Prepare invocations and take the first
                        self._invocation = self._queue_item.session.next()

                        # Loop over invocations until the session is complete or canceled
                        while self._invocation is not None and not cancel_event.is_set():
                            # get the source node id to provide to clients (the prepared node id is not as useful)
                            source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]

                            # Send starting event
                            self._invoker.services.events.emit_invocation_started(
                                queue_batch_id=self._queue_item.batch_id,
                                queue_item_id=self._queue_item.item_id,
                                queue_id=self._queue_item.queue_id,
                                graph_execution_state_id=self._queue_item.session_id,
                                node=self._invocation.model_dump(),
                                source_node_id=source_invocation_id,
                    if queue_item is not None:
                        self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
                        self.__queue_item = queue_item
                        self.__invoker.services.graph_execution_manager.set(queue_item.session)
                        self.__invoker.invoke(
                            session_queue_batch_id=queue_item.batch_id,
                            session_queue_id=queue_item.queue_id,
                            session_queue_item_id=queue_item.item_id,
                            graph_execution_state=queue_item.session,
                            workflow=queue_item.workflow,
                            invoke_all=True,
                        )
                        queue_item = None

                            # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
                            try:
                                with self._invoker.services.performance_statistics.collect_stats(
                                    self._invocation, self._queue_item.session.id
                                ):
                                    # Build invocation context (the node-facing API)
                                    data = InvocationContextData(
                                        invocation=self._invocation,
                                        source_invocation_id=source_invocation_id,
                                        queue_item=self._queue_item,
                                    )
                                    context = build_invocation_context(
                                        data=data,
                                        services=self._invoker.services,
                                        cancel_event=self._cancel_event,
                                    )

                                    # Invoke the node
                                    outputs = self._invocation.invoke_internal(
                                        context=context, services=self._invoker.services
                                    )

                                    # Save outputs and history
                                    self._queue_item.session.complete(self._invocation.id, outputs)

                                    # Send complete event
                                    self._invoker.services.events.emit_invocation_complete(
                                        queue_batch_id=self._queue_item.batch_id,
                                        queue_item_id=self._queue_item.item_id,
                                        queue_id=self._queue_item.queue_id,
                                        graph_execution_state_id=self._queue_item.session.id,
                                        node=self._invocation.model_dump(),
                                        source_node_id=source_invocation_id,
                                        result=outputs.model_dump(),
                                    )

                            except KeyboardInterrupt:
                                # TODO(MM2): Create an event for this
                                pass

                            except CanceledException:
                                # When the user cancels the graph, we first set the cancel event. The event is checked
                                # between invocations, in this loop. Some invocations are long-running, and we need to
                                # be able to cancel them mid-execution.
                                #
                                # For example, denoising is a long-running invocation with many steps. A step callback
                                # is executed after each step. This step callback checks if the canceled event is set,
                                # then raises a CanceledException to stop execution immediately.
                                #
                                # When we get a CanceledException, we don't need to do anything - just pass and let the
                                # loop go to its next iteration, and the cancel event will be handled correctly.
                                pass

                            except Exception as e:
                                error = traceback.format_exc()

                                # Save error
                                self._queue_item.session.set_node_error(self._invocation.id, error)
                                self._invoker.services.logger.error(
                                    f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
                                )

                                # Send error event
                                self._invoker.services.events.emit_invocation_error(
                                    queue_batch_id=self._queue_item.session_id,
                                    queue_item_id=self._queue_item.item_id,
                                    queue_id=self._queue_item.queue_id,
                                    graph_execution_state_id=self._queue_item.session.id,
                                    node=self._invocation.model_dump(),
                                    source_node_id=source_invocation_id,
                                    error_type=e.__class__.__name__,
                                    error=error,
                                )
                                pass

                            # The session is complete if the all invocations are complete or there was an error
                            if self._queue_item.session.is_complete() or cancel_event.is_set():
                                # Send complete event
                                self._invoker.services.events.emit_graph_execution_complete(
                                    queue_batch_id=self._queue_item.batch_id,
                                    queue_item_id=self._queue_item.item_id,
                                    queue_id=self._queue_item.queue_id,
                                    graph_execution_state_id=self._queue_item.session.id,
                                )
                                # If we are profiling, stop the profiler and dump the profile & stats
                                if self._profiler:
                                    profile_path = self._profiler.stop()
                                    stats_path = profile_path.with_suffix(".json")
                                    self._invoker.services.performance_statistics.dump_stats(
                                        graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
                                    )
                                # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
                                # we don't care about that - suppress the error.
                                with suppress(GESStatsNotFoundError):
                                    self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
                                    self._invoker.services.performance_statistics.reset_stats()

                                # Set the invocation to None to prepare for the next session
                                self._invocation = None
                            else:
                                # Prepare the next invocation
                                self._invocation = self._queue_item.session.next()

                        # The session is complete, immediately poll for next session
                        self._queue_item = None
                        poll_now_event.set()
                    else:
                        # The queue was empty, wait for next polling interval or event to try again
                        self._invoker.services.logger.debug("Waiting for next polling interval or event")
                        poll_now_event.wait(self._polling_interval)
                    if queue_item is None:
                        self.__invoker.services.logger.debug("Waiting for next polling interval or event")
                        poll_now_event.wait(POLLING_INTERVAL)
                        continue
                except Exception:
                    # Non-fatal error in processor
                    self._invoker.services.logger.error(
                        f"Non-fatal error in session processor:\n{traceback.format_exc()}"
                    )
                    # Cancel the queue item
                    if self._queue_item is not None:
                        self._invoker.services.session_queue.cancel_queue_item(
                            self._queue_item.item_id, error=traceback.format_exc()
                except Exception as e:
                    self.__invoker.services.logger.error(f"Error in session processor: {e}")
                    if queue_item is not None:
                        self.__invoker.services.session_queue.cancel_queue_item(
                            queue_item.item_id, error=traceback.format_exc()
                        )
                    # Reset the invocation to None to prepare for the next session
                    self._invocation = None
                    # Immediately poll for next queue item
                    poll_now_event.wait(self._polling_interval)
                    poll_now_event.wait(POLLING_INTERVAL)
                    continue
        except Exception:
            # Fatal error in processor, log and pass - we're done here
            self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
        except Exception as e:
            self.__invoker.services.logger.error(f"Fatal Error in session processor: {e}")
            pass
        finally:
            stop_event.clear()
            poll_now_event.clear()
            self._queue_item = None
            self._thread_semaphore.release()
            self.__queue_item = None
            self.__threadLimit.release()
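The CanceledException branch inside the processor above describes a cooperative-cancellation pattern: a threading.Event is set when the user cancels, long-running invocations check it from a step callback, and the callback raises to unwind immediately. A minimal, self-contained sketch of that pattern follows; the names here (CanceledException, step_callback, run_denoise) are illustrative placeholders defined locally, not the project's API.

    import threading

    class CanceledException(Exception):
        """Raised inside a long-running invocation when the cancel event is set."""

    cancel_event = threading.Event()

    def step_callback(step: int) -> None:
        # Called after every step; bail out as soon as cancellation is requested.
        if cancel_event.is_set():
            raise CanceledException()

    def run_denoise(num_steps: int) -> None:
        for step in range(num_steps):
            ...  # do one denoising step
            step_callback(step)

    try:
        run_denoise(num_steps=30)
    except CanceledException:
        # Nothing else to do: the processor loop sees the cancel event and moves on.
        pass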
@@ -60,7 +60,7 @@ class SqliteSessionQueue(SessionQueueBase):
        # This was a match statement, but match is not supported on python 3.9
        if event_name == "graph_execution_state_complete":
            await self._handle_complete_event(event)
        elif event_name == "invocation_error":
        elif event_name in ["invocation_error", "session_retrieval_error", "invocation_retrieval_error"]:
            await self._handle_error_event(event)
        elif event_name == "session_canceled":
            await self._handle_cancel_event(event)
@@ -429,6 +429,7 @@ class SqliteSessionQueue(SessionQueueBase):
        if queue_item.status not in ["canceled", "failed", "completed"]:
            status = "failed" if error is not None else "canceled"
            queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error)  # type: ignore [arg-type] # mypy seems to not narrow the Literals here
            self.__invoker.services.queue.cancel(queue_item.session_id)
            self.__invoker.services.events.emit_session_canceled(
                queue_item_id=queue_item.item_id,
                queue_id=queue_item.queue_id,
@@ -470,6 +471,7 @@ class SqliteSessionQueue(SessionQueueBase):
        )
        self.__conn.commit()
        if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
            self.__invoker.services.queue.cancel(current_queue_item.session_id)
            self.__invoker.services.events.emit_session_canceled(
                queue_item_id=current_queue_item.item_id,
                queue_id=current_queue_item.queue_id,
@@ -521,6 +523,7 @@ class SqliteSessionQueue(SessionQueueBase):
        )
        self.__conn.commit()
        if current_queue_item is not None and current_queue_item.queue_id == queue_id:
            self.__invoker.services.queue.cancel(current_queue_item.session_id)
            self.__invoker.services.events.emit_session_canceled(
                queue_item_id=current_queue_item.item_id,
                queue_id=current_queue_item.queue_id,
invokeai/app/services/shared/default_graphs.py (new file, 92 lines)
@@ -0,0 +1,92 @@
from invokeai.app.services.item_storage.item_storage_base import ItemStorageABC

from ...invocations.compel import CompelInvocation
from ...invocations.image import ImageNSFWBlurInvocation
from ...invocations.latent import DenoiseLatentsInvocation, LatentsToImageInvocation
from ...invocations.noise import NoiseInvocation
from ...invocations.primitives import IntegerInvocation
from .graph import Edge, EdgeConnection, ExposedNodeInput, ExposedNodeOutput, Graph, LibraryGraph

default_text_to_image_graph_id = "539b2af5-2b4d-4d8c-8071-e54a3255fc74"


def create_text_to_image() -> LibraryGraph:
    graph = Graph(
        nodes={
            "width": IntegerInvocation(id="width", value=512),
            "height": IntegerInvocation(id="height", value=512),
            "seed": IntegerInvocation(id="seed", value=-1),
            "3": NoiseInvocation(id="3"),
            "4": CompelInvocation(id="4"),
            "5": CompelInvocation(id="5"),
            "6": DenoiseLatentsInvocation(id="6"),
            "7": LatentsToImageInvocation(id="7"),
            "8": ImageNSFWBlurInvocation(id="8"),
        },
        edges=[
            Edge(
                source=EdgeConnection(node_id="width", field="value"),
                destination=EdgeConnection(node_id="3", field="width"),
            ),
            Edge(
                source=EdgeConnection(node_id="height", field="value"),
                destination=EdgeConnection(node_id="3", field="height"),
            ),
            Edge(
                source=EdgeConnection(node_id="seed", field="value"),
                destination=EdgeConnection(node_id="3", field="seed"),
            ),
            Edge(
                source=EdgeConnection(node_id="3", field="noise"),
                destination=EdgeConnection(node_id="6", field="noise"),
            ),
            Edge(
                source=EdgeConnection(node_id="6", field="latents"),
                destination=EdgeConnection(node_id="7", field="latents"),
            ),
            Edge(
                source=EdgeConnection(node_id="4", field="conditioning"),
                destination=EdgeConnection(node_id="6", field="positive_conditioning"),
            ),
            Edge(
                source=EdgeConnection(node_id="5", field="conditioning"),
                destination=EdgeConnection(node_id="6", field="negative_conditioning"),
            ),
            Edge(
                source=EdgeConnection(node_id="7", field="image"),
                destination=EdgeConnection(node_id="8", field="image"),
            ),
        ],
    )
    return LibraryGraph(
        id=default_text_to_image_graph_id,
        name="t2i",
        description="Converts text to an image",
        graph=graph,
        exposed_inputs=[
            ExposedNodeInput(node_path="4", field="prompt", alias="positive_prompt"),
            ExposedNodeInput(node_path="5", field="prompt", alias="negative_prompt"),
            ExposedNodeInput(node_path="width", field="value", alias="width"),
            ExposedNodeInput(node_path="height", field="value", alias="height"),
            ExposedNodeInput(node_path="seed", field="value", alias="seed"),
        ],
        exposed_outputs=[ExposedNodeOutput(node_path="8", field="image", alias="image")],
    )


def create_system_graphs(graph_library: ItemStorageABC[LibraryGraph]) -> list[LibraryGraph]:
    """Creates the default system graphs, or adds new versions if the old ones don't match"""

    # TODO: Uncomment this when we are ready to fix this up to prevent breaking changes
    graphs: list[LibraryGraph] = []

    text_to_image = graph_library.get(default_text_to_image_graph_id)

    # TODO: Check if the graph is the same as the default one, and if not, update it
    # if text_to_image is None:
    text_to_image = create_text_to_image()
    graph_library.set(text_to_image)

    graphs.append(text_to_image)

    return graphs
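As a rough illustration of how the factory above might be exercised, here is a short hypothetical script; `create_text_to_image`, `LibraryGraph.graph`, and `Graph.validate_self` all appear elsewhere in this diff, but the script itself is only a sketch, not part of the change.

    # Illustrative only: build the default text-to-image graph and check that it
    # is internally consistent before registering it anywhere.
    from invokeai.app.services.shared.default_graphs import create_text_to_image

    library_graph = create_text_to_image()
    library_graph.graph.validate_self()  # raises if nodes or edges are inconsistent
    assert library_graph.exposed_outputs[0].alias == "image"
    print(f"default graph id: {library_graph.id}")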
@@ -5,25 +5,22 @@ import itertools
|
||||
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
GetJsonSchemaHandler,
|
||||
field_validator,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
|
||||
from pydantic.fields import Field
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
from pydantic_core import CoreSchema
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from invokeai.app.invocations import * # noqa: F401 F403
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
# in 3.10 this would be "from types import NoneType"
|
||||
@@ -182,6 +179,10 @@ class NodeIdMismatchError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidSubGraphError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class CyclicalGraphError(ValueError):
|
||||
pass
|
||||
|
||||
@@ -190,6 +191,25 @@ class UnknownGraphValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
# TODO: Create and use an Empty output?
|
||||
@invocation_output("graph_output")
|
||||
class GraphInvocationOutput(BaseInvocationOutput):
|
||||
pass
|
||||
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
@invocation("graph", version="1.0.0")
|
||||
class GraphInvocation(BaseInvocation):
|
||||
"""Execute a graph"""
|
||||
|
||||
# TODO: figure out how to create a default here
|
||||
graph: "Graph" = InputField(description="The graph to run", default=None)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
|
||||
"""Invoke with provided services and return outputs."""
|
||||
return GraphInvocationOutput()
|
||||
|
||||
|
||||
@invocation_output("iterate_output")
|
||||
class IterateInvocationOutput(BaseInvocationOutput):
|
||||
"""Used to connect iteration outputs. Will be expanded to a specific output."""
|
||||
@@ -243,73 +263,21 @@ class CollectInvocation(BaseInvocation):
|
||||
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
||||
|
||||
|
||||
InvocationsUnion: Any = BaseInvocation.get_invocations_union()
|
||||
InvocationOutputsUnion: Any = BaseInvocationOutput.get_outputs_union()
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The nodes in this graph", default_factory=dict
|
||||
)
|
||||
edges: list[Edge] = Field(
|
||||
description="The connections between nodes and their fields in this graph",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
@field_validator("nodes", mode="plain")
|
||||
@classmethod
|
||||
def validate_nodes(cls, v: dict[str, Any]):
|
||||
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
|
||||
|
||||
# Invocations register themselves as their python modules are executed. The union of all invocations is
|
||||
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
|
||||
#
|
||||
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
|
||||
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
|
||||
# invocations will cause a graph to fail if they are used.
|
||||
#
|
||||
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
|
||||
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
|
||||
#
|
||||
# This same pattern is used in `GraphExecutionState`.
|
||||
|
||||
nodes: dict[str, BaseInvocation] = {}
|
||||
typeadapter = BaseInvocation.get_typeadapter()
|
||||
for node_id, node in v.items():
|
||||
nodes[node_id] = typeadapter.validate_python(node)
|
||||
return nodes
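The comment in `validate_nodes` above describes a general pydantic v2 technique: when the set of valid member types is only known at runtime, a `mode="plain"` field validator can hand validation to a `TypeAdapter` built on demand. Below is a small standalone sketch of the same idea, using a toy registry in place of the invocation union; the names (Cat, Dog, Zoo, REGISTRY) are illustrative only.

    from typing import Annotated, Any, Literal, Union

    from pydantic import BaseModel, Field, TypeAdapter, field_validator

    class Cat(BaseModel):
        type: Literal["cat"] = "cat"
        name: str

    class Dog(BaseModel):
        type: Literal["dog"] = "dog"
        name: str

    # In a real app this registry is only complete once all modules have imported.
    REGISTRY: tuple[type[BaseModel], ...] = (Cat, Dog)

    class Zoo(BaseModel):
        animals: dict[str, Any] = Field(default_factory=dict)

        @field_validator("animals", mode="plain")
        @classmethod
        def validate_animals(cls, v: dict[str, Any]) -> dict[str, BaseModel]:
            # Build the union's TypeAdapter at validation time, so members that
            # registered after this module was imported are still accepted.
            adapter = TypeAdapter(Annotated[Union[REGISTRY], Field(discriminator="type")])
            return {k: adapter.validate_python(item) for k, item in v.items()}

    zoo = Zoo.model_validate({"animals": {"a": {"type": "cat", "name": "Misha"}}})
    assert isinstance(zoo.animals["a"], Cat)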
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
|
||||
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
|
||||
# the generated schema as options for the `nodes` field.
|
||||
#
|
||||
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
|
||||
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
|
||||
# expected.
|
||||
#
|
||||
# You might be tempted to do something like this:
|
||||
#
|
||||
# ```py
|
||||
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
|
||||
# delattr(cloned_model, "validate_nodes")
|
||||
# cloned_model.model_rebuild(force=True)
|
||||
# json_schema = handler(cloned_model.__pydantic_core_schema__)
|
||||
# ```
|
||||
#
|
||||
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
|
||||
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
|
||||
#
|
||||
# This same pattern is used in `GraphExecutionState`.
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: Optional[str] = Field(default=None, description="The id of this graph")
|
||||
nodes: dict[
|
||||
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
|
||||
] = Field(description="The nodes in this graph")
|
||||
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
|
||||
|
||||
json_schema = handler(Graph.__pydantic_core_schema__)
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
return json_schema
|
||||
|
||||
def add_node(self, node: BaseInvocation) -> None:
|
||||
"""Adds a node to a graph
|
||||
|
||||
@@ -321,21 +289,41 @@ class Graph(BaseModel):
|
||||
|
||||
self.nodes[node.id] = node
|
||||
|
||||
def delete_node(self, node_id: str) -> None:
|
||||
def _get_graph_and_node(self, node_path: str) -> tuple["Graph", str]:
|
||||
"""Returns the graph and node id for a node path."""
|
||||
# Materialized graphs may have nodes at the top level
|
||||
if node_path in self.nodes:
|
||||
return (self, node_path)
|
||||
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
if node_id not in self.nodes:
|
||||
raise NodeNotFoundError(f"Node {node_path} not found in graph")
|
||||
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if not isinstance(node, GraphInvocation):
|
||||
# There's more node path left but this isn't a graph - failure
|
||||
raise NodeNotFoundError("Node path terminated early at a non-graph node")
|
||||
|
||||
return node.graph._get_graph_and_node(node_path[node_path.index(".") + 1 :])
|
||||
|
||||
def delete_node(self, node_path: str) -> None:
|
||||
"""Deletes a node from a graph"""
|
||||
|
||||
try:
|
||||
graph, node_id = self._get_graph_and_node(node_path)
|
||||
|
||||
# Delete edges for this node
|
||||
input_edges = self._get_input_edges(node_id)
|
||||
output_edges = self._get_output_edges(node_id)
|
||||
input_edges = self._get_input_edges_and_graphs(node_path)
|
||||
output_edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
for edge in input_edges:
|
||||
self.delete_edge(edge)
|
||||
for edge_graph, _, edge in input_edges:
|
||||
edge_graph.delete_edge(edge)
|
||||
|
||||
for edge in output_edges:
|
||||
self.delete_edge(edge)
|
||||
for edge_graph, _, edge in output_edges:
|
||||
edge_graph.delete_edge(edge)
|
||||
|
||||
del self.nodes[node_id]
|
||||
del graph.nodes[node_id]
|
||||
|
||||
except NodeNotFoundError:
|
||||
pass # Ignore, not doesn't exist (should this throw?)
|
||||
@@ -385,6 +373,13 @@ class Graph(BaseModel):
|
||||
if k != v.id:
|
||||
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
|
||||
|
||||
# Validate all subgraphs
|
||||
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
|
||||
try:
|
||||
gn.graph.validate_self()
|
||||
except Exception as e:
|
||||
raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e
|
||||
|
||||
# Validate that all edges match nodes and fields in the graph
|
||||
for edge in self.edges:
|
||||
source_node = self.nodes.get(edge.source.node_id, None)
|
||||
@@ -446,6 +441,7 @@ class Graph(BaseModel):
|
||||
except (
|
||||
DuplicateNodeIdError,
|
||||
NodeIdMismatchError,
|
||||
InvalidSubGraphError,
|
||||
NodeNotFoundError,
|
||||
NodeFieldNotFoundError,
|
||||
CyclicalGraphError,
|
||||
@@ -466,7 +462,7 @@ class Graph(BaseModel):
|
||||
def _validate_edge(self, edge: Edge):
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
|
||||
# Validate that the nodes exist
|
||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||
try:
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
@@ -533,90 +529,171 @@ class Graph(BaseModel):
|
||||
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
def has_node(self, node_id: str) -> bool:
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
"""Determines whether or not a node exists in the graph."""
|
||||
try:
|
||||
_ = self.get_node(node_id)
|
||||
return True
|
||||
n = self.get_node(node_path)
|
||||
if n is not None:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except NodeNotFoundError:
|
||||
return False
|
||||
|
||||
def get_node(self, node_id: str) -> BaseInvocation:
|
||||
"""Gets a node from the graph."""
|
||||
try:
|
||||
return self.nodes[node_id]
|
||||
except KeyError as e:
|
||||
raise NodeNotFoundError(f"Node {node_id} not found in graph") from e
|
||||
def get_node(self, node_path: str) -> InvocationsUnion:
|
||||
"""Gets a node from the graph using a node path."""
|
||||
# Materialized graphs may have nodes at the top level
|
||||
graph, node_id = self._get_graph_and_node(node_path)
|
||||
return graph.nodes[node_id]
|
||||
|
||||
def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
|
||||
def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str:
|
||||
return node_id if prefix is None or prefix == "" else f"{prefix}.{node_id}"
|
||||
|
||||
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
|
||||
"""Updates a node in the graph."""
|
||||
node = self.nodes[node_id]
|
||||
graph, node_id = self._get_graph_and_node(node_path)
|
||||
node = graph.nodes[node_id]
|
||||
|
||||
# Ensure the node type matches the new node
|
||||
if type(node) is not type(new_node):
|
||||
raise TypeError(f"Node {node_id} is type {type(node)} but new node is type {type(new_node)}")
|
||||
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
|
||||
|
||||
# Ensure the new id is either the same or is not in the graph
|
||||
if new_node.id != node.id and self.has_node(new_node.id):
|
||||
raise NodeAlreadyInGraphError(f"Node with id {new_node.id} already exists in graph")
|
||||
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
|
||||
new_path = self._get_node_path(new_node.id, prefix=prefix)
|
||||
if new_node.id != node.id and self.has_node(new_path):
|
||||
raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph")
|
||||
|
||||
# Set the new node in the graph
|
||||
self.nodes[new_node.id] = new_node
|
||||
graph.nodes[new_node.id] = new_node
|
||||
if new_node.id != node.id:
|
||||
input_edges = self._get_input_edges(node_id)
|
||||
output_edges = self._get_output_edges(node_id)
|
||||
input_edges = self._get_input_edges_and_graphs(node_path)
|
||||
output_edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
# Delete node and all edges
|
||||
self.delete_node(node_id)
|
||||
graph.delete_node(node_path)
|
||||
|
||||
# Create new edges for each input and output
|
||||
for edge in input_edges:
|
||||
self.add_edge(
|
||||
for graph, _, edge in input_edges:
|
||||
# Remove the graph prefix from the node path
|
||||
new_graph_node_path = (
|
||||
new_node.id
|
||||
if "." not in edge.destination.node_id
|
||||
else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
|
||||
)
|
||||
graph.add_edge(
|
||||
Edge(
|
||||
source=edge.source,
|
||||
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
|
||||
destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field),
|
||||
)
|
||||
)
|
||||
|
||||
for edge in output_edges:
|
||||
self.add_edge(
|
||||
for graph, _, edge in output_edges:
|
||||
# Remove the graph prefix from the node path
|
||||
new_graph_node_path = (
|
||||
new_node.id
|
||||
if "." not in edge.source.node_id
|
||||
else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
|
||||
)
|
||||
graph.add_edge(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=new_node.id, field=edge.source.field),
|
||||
source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field),
|
||||
destination=edge.destination,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_input_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
|
||||
"""Gets all input edges for a node. If field is provided, only edges to that field are returned."""
|
||||
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]:
|
||||
"""Gets all input edges for a node"""
|
||||
edges = self._get_input_edges_and_graphs(node_path)
|
||||
|
||||
edges = [e for e in self.edges if e.destination.node_id == node_id]
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
|
||||
|
||||
if field is None:
|
||||
return edges
|
||||
# Create full node paths for each edge
|
||||
return [
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
|
||||
field=e.source.field,
|
||||
),
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
),
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
|
||||
filtered_edges = [e for e in edges if e.destination.field == field]
|
||||
def _get_input_edges_and_graphs(
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
||||
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = []
|
||||
|
||||
return filtered_edges
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
|
||||
|
||||
def _get_output_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
|
||||
"""Gets all output edges for a node. If field is provided, only edges from that field are returned."""
|
||||
edges = [e for e in self.edges if e.source.node_id == node_id]
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if field is None:
|
||||
return edges
|
||||
if isinstance(node, GraphInvocation):
|
||||
graph = node.graph
|
||||
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
|
||||
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
|
||||
edges.extend(graph_edges)
|
||||
|
||||
filtered_edges = [e for e in edges if e.source.field == field]
|
||||
return edges
|
||||
|
||||
return filtered_edges
|
||||
def _get_output_edges(self, node_path: str, field: str) -> list[Edge]:
|
||||
"""Gets all output edges for a node"""
|
||||
edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if e[2].source.field == field)
|
||||
|
||||
# Create full node paths for each edge
|
||||
return [
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
|
||||
field=e.source.field,
|
||||
),
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
),
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
|
||||
def _get_output_edges_and_graphs(
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
||||
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = []
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
|
||||
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if isinstance(node, GraphInvocation):
|
||||
graph = node.graph
|
||||
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
|
||||
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
|
||||
edges.extend(graph_edges)
|
||||
|
||||
return edges
|
||||
|
||||
def _is_iterator_connection_valid(
|
||||
self,
|
||||
node_id: str,
|
||||
node_path: str,
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_id, "item")]
|
||||
inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_path, "item")]
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@@ -644,12 +721,12 @@ class Graph(BaseModel):
|
||||
|
||||
def _is_collector_connection_valid(
|
||||
self,
|
||||
node_id: str,
|
||||
node_path: str,
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = [e.source for e in self._get_input_edges(node_id, "item")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]
|
||||
inputs = [e.source for e in self._get_input_edges(node_path, "item")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@@ -705,17 +782,27 @@ class Graph(BaseModel):
|
||||
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
|
||||
return g
|
||||
|
||||
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph:
|
||||
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
|
||||
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
|
||||
g = nx_graph or nx.DiGraph()
|
||||
|
||||
# Add all nodes from this graph except graph/iteration nodes
|
||||
g.add_nodes_from([n.id for n in self.nodes.values() if not isinstance(n, IterateInvocation)])
|
||||
g.add_nodes_from(
|
||||
[
|
||||
self._get_node_path(n.id, prefix)
|
||||
for n in self.nodes.values()
|
||||
if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)
|
||||
]
|
||||
)
|
||||
|
||||
# Expand graph nodes
|
||||
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
|
||||
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
|
||||
g.add_edges_from([(e[0], e[1]) for e in unique_edges])
|
||||
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
|
||||
return g
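A toy sketch of the flattening idea used by `nx_graph_flat` above: nested graphs are walked recursively and node ids are prefixed with their subgraph path before being added to a single `networkx.DiGraph`. The dict-based graph structure here is illustrative only, not the project's data model.

    from typing import Optional

    import networkx as nx

    # Toy nested-graph structure: a leaf node is None, a subgraph is another dict.
    outer = {"a": None, "sub": {"b": None, "c": None}}

    def flatten(graph: dict, g: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
        g = g or nx.DiGraph()
        for node_id, child in graph.items():
            path = node_id if prefix is None else f"{prefix}.{node_id}"
            if isinstance(child, dict):
                flatten(child, g, path)  # recurse into the subgraph, extending the prefix
            else:
                g.add_node(path)
        return g

    print(sorted(flatten(outer).nodes))  # ['a', 'sub.b', 'sub.c']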
|
||||
|
||||
|
||||
@@ -740,7 +827,9 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
|
||||
# The results of executed nodes
|
||||
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
||||
results: dict[str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]] = Field(
|
||||
description="The results of node executions", default_factory=dict
|
||||
)
|
||||
|
||||
# Errors raised when executing nodes
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||
@@ -757,51 +846,27 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
@field_validator("results", mode="plain")
|
||||
@classmethod
|
||||
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
|
||||
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
|
||||
|
||||
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
|
||||
results: dict[str, BaseInvocationOutput] = {}
|
||||
typeadapter = BaseInvocationOutput.get_typeadapter()
|
||||
for result_id, result in v.items():
|
||||
results[result_id] = typeadapter.validate_python(result)
|
||||
return results
|
||||
|
||||
@field_validator("graph")
|
||||
def graph_is_valid(cls, v: Graph):
|
||||
"""Validates that the graph is valid"""
|
||||
v.validate_self()
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(description="The id of the execution state")
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
|
||||
executed: set[str] = Field(description="The set of node ids that have been executed")
|
||||
executed_history: list[str] = Field(
|
||||
description="The list of node ids that have been executed, in order of execution"
|
||||
)
|
||||
results: dict[
|
||||
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
|
||||
] = Field(description="The results of node executions")
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
|
||||
prepared_source_mapping: dict[str, str] = Field(
|
||||
description="The map of prepared nodes to original graph nodes"
|
||||
)
|
||||
source_prepared_mapping: dict[str, set[str]] = Field(
|
||||
description="The map of original graph nodes to prepared nodes"
|
||||
)
|
||||
|
||||
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
return json_schema
|
||||
model_config = ConfigDict(
|
||||
json_schema_extra={
|
||||
"required": [
|
||||
"id",
|
||||
"graph",
|
||||
"execution_graph",
|
||||
"executed",
|
||||
"executed_history",
|
||||
"results",
|
||||
"errors",
|
||||
"prepared_source_mapping",
|
||||
"source_prepared_mapping",
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
def next(self) -> Optional[BaseInvocation]:
|
||||
"""Gets the next node ready to execute."""
|
||||
@@ -826,7 +891,7 @@ class GraphExecutionState(BaseModel):
|
||||
# If next is still none, there's no next node, return None
|
||||
return next_node
|
||||
|
||||
def complete(self, node_id: str, output: BaseInvocationOutput) -> None:
|
||||
def complete(self, node_id: str, output: InvocationOutputsUnion):
|
||||
"""Marks a node as complete"""
|
||||
|
||||
if node_id not in self.execution_graph.nodes:
|
||||
@@ -857,17 +922,17 @@ class GraphExecutionState(BaseModel):
|
||||
"""Returns true if the graph has any errors"""
|
||||
return len(self.errors) > 0
|
||||
|
||||
def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
||||
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
||||
"""Prepares an iteration node and connects all edges, returning the new node id"""
|
||||
|
||||
node = self.graph.get_node(node_id)
|
||||
node = self.graph.get_node(node_path)
|
||||
|
||||
self_iteration_count = -1
|
||||
|
||||
# If this is an iterator node, we must create a copy for each iteration
|
||||
if isinstance(node, IterateInvocation):
|
||||
# Get input collection edge (should error if there are no inputs)
|
||||
input_collection_edge = next(iter(self.graph._get_input_edges(node_id, "collection")))
|
||||
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection")))
|
||||
input_collection_prepared_node_id = next(
|
||||
n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
|
||||
)
|
||||
@@ -881,7 +946,7 @@ class GraphExecutionState(BaseModel):
|
||||
return new_nodes
|
||||
|
||||
# Get all input edges
|
||||
input_edges = self.graph._get_input_edges(node_id)
|
||||
input_edges = self.graph._get_input_edges(node_path)
|
||||
|
||||
# Create new edges for this iteration
|
||||
# For collect nodes, this may contain multiple inputs to the same field
|
||||
@@ -908,10 +973,10 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Add to execution graph
|
||||
self.execution_graph.add_node(new_node)
|
||||
self.prepared_source_mapping[new_node.id] = node_id
|
||||
if node_id not in self.source_prepared_mapping:
|
||||
self.source_prepared_mapping[node_id] = set()
|
||||
self.source_prepared_mapping[node_id].add(new_node.id)
|
||||
self.prepared_source_mapping[new_node.id] = node_path
|
||||
if node_path not in self.source_prepared_mapping:
|
||||
self.source_prepared_mapping[node_path] = set()
|
||||
self.source_prepared_mapping[node_path].add(new_node.id)
|
||||
|
||||
# Add new edges to execution graph
|
||||
for edge in new_edges:
|
||||
@@ -1015,13 +1080,13 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def _get_iteration_node(
|
||||
self,
|
||||
source_node_id: str,
|
||||
source_node_path: str,
|
||||
graph: nx.DiGraph,
|
||||
execution_graph: nx.DiGraph,
|
||||
prepared_iterator_nodes: list[str],
|
||||
) -> Optional[str]:
|
||||
"""Gets the prepared version of the specified source node that matches every iteration specified"""
|
||||
prepared_nodes = self.source_prepared_mapping[source_node_id]
|
||||
prepared_nodes = self.source_prepared_mapping[source_node_path]
|
||||
if len(prepared_nodes) == 1:
|
||||
return next(iter(prepared_nodes))
|
||||
|
||||
@@ -1032,7 +1097,7 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
|
||||
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
|
||||
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)]
|
||||
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
|
||||
|
||||
return next(
|
||||
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
|
||||
@@ -1101,19 +1166,19 @@ class GraphExecutionState(BaseModel):
|
||||
def add_node(self, node: BaseInvocation) -> None:
|
||||
self.graph.add_node(node)
|
||||
|
||||
def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
|
||||
if not self._is_node_updatable(node_id):
|
||||
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
|
||||
if not self._is_node_updatable(node_path):
|
||||
raise NodeAlreadyExecutedError(
|
||||
f"Node {node_id} has already been prepared or executed and cannot be updated"
|
||||
f"Node {node_path} has already been prepared or executed and cannot be updated"
|
||||
)
|
||||
self.graph.update_node(node_id, new_node)
|
||||
self.graph.update_node(node_path, new_node)
|
||||
|
||||
def delete_node(self, node_id: str) -> None:
|
||||
if not self._is_node_updatable(node_id):
|
||||
def delete_node(self, node_path: str) -> None:
|
||||
if not self._is_node_updatable(node_path):
|
||||
raise NodeAlreadyExecutedError(
|
||||
f"Node {node_id} has already been prepared or executed and cannot be deleted"
|
||||
f"Node {node_path} has already been prepared or executed and cannot be deleted"
|
||||
)
|
||||
self.graph.delete_node(node_id)
|
||||
self.graph.delete_node(node_path)
|
||||
|
||||
def add_edge(self, edge: Edge) -> None:
|
||||
if not self._is_node_updatable(edge.destination.node_id):
|
||||
@@ -1128,3 +1193,63 @@ class GraphExecutionState(BaseModel):
|
||||
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
|
||||
)
|
||||
self.graph.delete_edge(edge)
|
||||
|
||||
|
||||
class ExposedNodeInput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the input")
|
||||
field: str = Field(description="The field name of the input")
|
||||
alias: str = Field(description="The alias of the input")
|
||||
|
||||
|
||||
class ExposedNodeOutput(BaseModel):
|
||||
node_path: str = Field(description="The node path to the node with the output")
|
||||
field: str = Field(description="The field name of the output")
|
||||
alias: str = Field(description="The alias of the output")
|
||||
|
||||
|
||||
class LibraryGraph(BaseModel):
|
||||
id: str = Field(description="The unique identifier for this library graph", default_factory=uuid_string)
|
||||
graph: Graph = Field(description="The graph")
|
||||
name: str = Field(description="The name of the graph")
|
||||
description: str = Field(description="The description of the graph")
|
||||
exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
|
||||
exposed_outputs: list[ExposedNodeOutput] = Field(
|
||||
description="The outputs exposed by this graph", default_factory=list
|
||||
)
|
||||
|
||||
@field_validator("exposed_inputs", "exposed_outputs")
|
||||
def validate_exposed_aliases(cls, v: list[Union[ExposedNodeInput, ExposedNodeOutput]]):
|
||||
if len(v) != len({i.alias for i in v}):
|
||||
raise ValueError("Duplicate exposed alias")
|
||||
return v
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_exposed_nodes(cls, values):
|
||||
graph = values.graph
|
||||
|
||||
# Validate exposed inputs
|
||||
for exposed_input in values.exposed_inputs:
|
||||
if not graph.has_node(exposed_input.node_path):
|
||||
raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
|
||||
node = graph.get_node(exposed_input.node_path)
|
||||
if get_input_field(node, exposed_input.field) is None:
|
||||
raise ValueError(
|
||||
f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}"
|
||||
)
|
||||
|
||||
# Validate exposed outputs
|
||||
for exposed_output in values.exposed_outputs:
|
||||
if not graph.has_node(exposed_output.node_path):
|
||||
raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
|
||||
node = graph.get_node(exposed_output.node_path)
|
||||
if get_output_field(node, exposed_output.field) is None:
|
||||
raise ValueError(
|
||||
f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
GraphInvocation.model_rebuild(force=True)
|
||||
Graph.model_rebuild(force=True)
|
||||
GraphExecutionState.model_rebuild(force=True)
|
||||
|
||||
@@ -1,470 +0,0 @@
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
if TYPE_CHECKING:
    from invokeai.app.invocations.baseinvocation import BaseInvocation
    from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem

"""
The InvocationContext provides access to various services and data about the current invocation.

We do not provide the invocation services directly, as their methods are both dangerous and
inconvenient to use.

For example:
- The `images` service allows nodes to delete or unsafely modify existing images.
- The `configuration` service allows nodes to change the app's config at runtime.
- The `events` service allows nodes to emit arbitrary events.

Wrapping these services provides a simpler and safer interface for nodes to use.

When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere
with each other.

Many of the wrappers have the same signature as the methods they wrap. This allows us to write
user-facing docstrings and not need to go and update the internal services to match.

Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them.
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvocationContextData:
|
||||
queue_item: "SessionQueueItem"
|
||||
"""The queue item that is being executed."""
|
||||
invocation: "BaseInvocation"
|
||||
"""The invocation that is being executed."""
|
||||
source_invocation_id: str
|
||||
"""The ID of the invocation from which the currently executing invocation was prepared."""
|
||||
|
||||
|
||||
class InvocationContextInterface:
|
||||
def __init__(self, services: InvocationServices, data: InvocationContextData) -> None:
|
||||
self._services = services
|
||||
self._data = data
|
||||
|
||||
|
||||
class BoardsInterface(InvocationContextInterface):
|
||||
def create(self, board_name: str) -> BoardDTO:
|
||||
"""
|
||||
Creates a board.
|
||||
|
||||
:param board_name: The name of the board to create.
|
||||
"""
|
||||
return self._services.boards.create(board_name)
|
||||
|
||||
def get_dto(self, board_id: str) -> BoardDTO:
|
||||
"""
|
||||
Gets a board DTO.
|
||||
|
||||
:param board_id: The ID of the board to get.
|
||||
"""
|
||||
return self._services.boards.get_dto(board_id)
|
||||
|
||||
def get_all(self) -> list[BoardDTO]:
|
||||
"""
|
||||
Gets all boards.
|
||||
"""
|
||||
return self._services.boards.get_all()
|
||||
|
||||
def add_image_to_board(self, board_id: str, image_name: str) -> None:
|
||||
"""
|
||||
Adds an image to a board.
|
||||
|
||||
:param board_id: The ID of the board to add the image to.
|
||||
:param image_name: The name of the image to add to the board.
|
||||
"""
|
||||
return self._services.board_images.add_image_to_board(board_id, image_name)
|
||||
|
||||
def get_all_image_names_for_board(self, board_id: str) -> list[str]:
|
||||
"""
|
||||
Gets all image names for a board.
|
||||
|
||||
:param board_id: The ID of the board to get the image names for.
|
||||
"""
|
||||
return self._services.board_images.get_all_board_image_names_for_board(board_id)
|
||||
|
||||
|
||||
class LoggerInterface(InvocationContextInterface):
|
||||
def debug(self, message: str) -> None:
|
||||
"""
|
||||
Logs a debug message.
|
||||
|
||||
:param message: The message to log.
|
||||
"""
|
||||
self._services.logger.debug(message)
|
||||
|
||||
def info(self, message: str) -> None:
|
||||
"""
|
||||
Logs an info message.
|
||||
|
||||
:param message: The message to log.
|
||||
"""
|
||||
self._services.logger.info(message)
|
||||
|
||||
def warning(self, message: str) -> None:
|
||||
"""
|
||||
Logs a warning message.
|
||||
|
||||
:param message: The message to log.
|
||||
"""
|
||||
self._services.logger.warning(message)
|
||||
|
||||
def error(self, message: str) -> None:
|
||||
"""
|
||||
Logs an error message.
|
||||
|
||||
:param message: The message to log.
|
||||
"""
|
||||
self._services.logger.error(message)
|
||||
|
||||
|
||||
class ImagesInterface(InvocationContextInterface):
|
||||
def save(
|
||||
self,
|
||||
image: Image,
|
||||
board_id: Optional[str] = None,
|
||||
image_category: ImageCategory = ImageCategory.GENERAL,
|
||||
metadata: Optional[MetadataField] = None,
|
||||
) -> ImageDTO:
|
||||
"""
|
||||
Saves an image, returning its DTO.
|
||||
|
||||
If the current queue item has a workflow or metadata, it is automatically saved with the image.
|
||||
|
||||
:param image: The image to save, as a PIL image.
|
||||
:param board_id: The board ID to add the image to, if it should be added. It the invocation \
|
||||
inherits from `WithBoard`, that board will be used automatically. **Use this only if \
|
||||
you want to override or provide a board manually!**
|
||||
:param image_category: The category of the image. Only the GENERAL category is added \
|
||||
to the gallery.
|
||||
:param metadata: The metadata to save with the image, if it should have any. If the \
|
||||
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
|
||||
**Use this only if you want to override or provide metadata manually!**
|
||||
"""
|
||||
|
||||
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
|
||||
metadata_ = None
|
||||
if metadata:
|
||||
metadata_ = metadata
|
||||
elif isinstance(self._data.invocation, WithMetadata):
|
||||
metadata_ = self._data.invocation.metadata
|
||||
|
||||
# If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None.
|
||||
board_id_ = None
|
||||
if board_id:
|
||||
board_id_ = board_id
|
||||
elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board:
|
||||
board_id_ = self._data.invocation.board.board_id
|
||||
|
||||
return self._services.images.create(
|
||||
image=image,
|
||||
is_intermediate=self._data.invocation.is_intermediate,
|
||||
image_category=image_category,
|
||||
board_id=board_id_,
|
||||
metadata=metadata_,
|
||||
image_origin=ResourceOrigin.INTERNAL,
|
||||
workflow=self._data.queue_item.workflow,
|
||||
session_id=self._data.queue_item.session_id,
|
||||
node_id=self._data.invocation.id,
|
||||
)
|
||||
|
||||
def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
|
||||
"""
|
||||
Gets an image as a PIL Image object.
|
||||
|
||||
:param image_name: The name of the image to get.
|
||||
:param mode: The color mode to convert the image to. If None, the original mode is used.
|
||||
"""
|
||||
image = self._services.images.get_pil_image(image_name)
|
||||
if mode and mode != image.mode:
|
||||
try:
|
||||
image = image.convert(mode)
|
||||
except ValueError:
|
||||
self._services.logger.warning(
|
||||
f"Could not convert image from {image.mode} to {mode}. Using original mode instead."
|
||||
)
|
||||
return image
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||
"""
|
||||
Gets an image's metadata, if it has any.
|
||||
|
||||
:param image_name: The name of the image to get the metadata for.
|
||||
"""
|
||||
return self._services.images.get_metadata(image_name)
|
||||
|
||||
def get_dto(self, image_name: str) -> ImageDTO:
|
||||
"""
|
||||
Gets an image as an ImageDTO object.
|
||||
|
||||
:param image_name: The name of the image to get.
|
||||
"""
|
||||
return self._services.images.get_dto(image_name)
|
||||
|
||||
|
||||
class TensorsInterface(InvocationContextInterface):
|
||||
def save(self, tensor: Tensor) -> str:
|
||||
"""
|
||||
Saves a tensor, returning its name.
|
||||
|
||||
:param tensor: The tensor to save.
|
||||
"""
|
||||
|
||||
name = self._services.tensors.save(obj=tensor)
|
||||
return name
|
||||
|
||||
def load(self, name: str) -> Tensor:
|
||||
"""
|
||||
Loads a tensor by name.
|
||||
|
||||
:param name: The name of the tensor to load.
|
||||
"""
|
||||
return self._services.tensors.load(name)
|
||||
|
||||
|
||||
class ConditioningInterface(InvocationContextInterface):
|
||||
def save(self, conditioning_data: ConditioningFieldData) -> str:
|
||||
"""
|
||||
Saves a conditioning data object, returning its name.
|
||||
|
||||
:param conditioning_data: The conditioning data to save.
|
||||
"""
|
||||
|
||||
name = self._services.conditioning.save(obj=conditioning_data)
|
||||
return name
|
||||
|
||||
def load(self, name: str) -> ConditioningFieldData:
|
||||
"""
|
||||
Loads conditioning data by name.
|
||||
|
||||
:param name: The name of the conditioning data to load.
|
||||
"""
|
||||
|
||||
return self._services.conditioning.load(name)
|
||||
|
||||
|
||||
class ModelsInterface(InvocationContextInterface):
    def exists(self, key: str) -> bool:
        """
        Checks if a model exists.

        :param key: The key of the model.
        """
        return self._services.model_manager.store.exists(key)

    def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
        """
        Loads a model.

        :param key: The key of the model.
        :param submodel_type: The submodel of the model to get.
        :returns: An object representing the loaded model.
        """

        # The model manager emits events as it loads the model. It needs the context data to build
        # the event payloads.

        return self._services.model_manager.load_model_by_key(
            key=key, submodel_type=submodel_type, context_data=self._data
        )

    def load_by_attrs(
        self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
    ) -> LoadedModel:
        """
        Loads a model by its attributes.

        :param model_name: Name of the model to be fetched.
        :param base_model: The base model of the model.
        :param model_type: The type of the model.
        :param submodel: For main (pipeline) models, the submodel to fetch.
        """
        return self._services.model_manager.load_model_by_attr(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
            submodel=submodel,
            context_data=self._data,
        )

    def get_config(self, key: str) -> AnyModelConfig:
        """
        Gets a model's info, a dict-like object.

        :param key: The key of the model.
        """
        return self._services.model_manager.store.get_model(key=key)

    def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
        """
        Gets a model's metadata, if it has any.

        :param key: The key of the model.
        """
        return self._services.model_manager.store.get_metadata(key=key)

    def search_by_path(self, path: Path) -> list[AnyModelConfig]:
        """
        Searches for models by path.

        :param path: The path to search for.
        """
        return self._services.model_manager.store.search_by_path(path)

    def search_by_attrs(
        self,
        model_name: Optional[str] = None,
        base_model: Optional[BaseModelType] = None,
        model_type: Optional[ModelType] = None,
        model_format: Optional[ModelFormat] = None,
    ) -> list[AnyModelConfig]:
        """
        Searches for models by attributes.

        :param model_name: Name of the model to search for.
        :param base_model: The base model to filter by.
        :param model_type: The type of model to filter by.
        :param model_format: The format of model to filter by.
        """

        return self._services.model_manager.store.search_by_attr(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
            model_format=model_format,
        )


class ConfigInterface(InvocationContextInterface):
    def get(self) -> InvokeAIAppConfig:
        """Gets the app's config."""

        return self._services.configuration.get_config()


class UtilInterface(InvocationContextInterface):
    def __init__(
        self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event
    ) -> None:
        super().__init__(services, data)
        self._cancel_event = cancel_event

    def is_canceled(self) -> bool:
        """Checks if the current invocation has been canceled."""
        return self._cancel_event.is_set()

    def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
        """
        The step callback emits a progress event with the current step, the total number of
        steps, a preview image, and some other internal metadata.

        This should be called after each denoising step.

        :param intermediate_state: The intermediate state of the diffusion pipeline.
        :param base_model: The base model for the current denoising step.
        """

        stable_diffusion_step_callback(
            context_data=self._data,
            intermediate_state=intermediate_state,
            base_model=base_model,
            events=self._services.events,
            is_canceled=self.is_canceled,
        )


class InvocationContext:
    """
    The `InvocationContext` provides access to various services and data for the current invocation.
    """

    def __init__(
        self,
        images: ImagesInterface,
        tensors: TensorsInterface,
        conditioning: ConditioningInterface,
        models: ModelsInterface,
        logger: LoggerInterface,
        config: ConfigInterface,
        util: UtilInterface,
        boards: BoardsInterface,
        data: InvocationContextData,
        services: InvocationServices,
    ) -> None:
        self.images = images
        """Methods to save, get and update images and their metadata."""
        self.tensors = tensors
        """Methods to save and get tensors, including image, noise, masks, and masked images."""
        self.conditioning = conditioning
        """Methods to save and get conditioning data."""
        self.models = models
        """Methods to check if a model exists, get a model, and get a model's info."""
        self.logger = logger
        """The app logger."""
        self.config = config
        """The app config."""
        self.util = util
        """Utility methods, including a method to check if an invocation was canceled and step callbacks."""
        self.boards = boards
        """Methods to interact with boards."""
        self._data = data
        """An internal API providing access to data about the current queue item and invocation. You probably shouldn't use this. It may change without warning."""
        self._services = services
        """An internal API providing access to all application services. You probably shouldn't use this. It may change without warning."""


def build_invocation_context(
    services: InvocationServices,
    data: InvocationContextData,
    cancel_event: threading.Event,
) -> InvocationContext:
    """
    Builds the invocation context for a specific invocation execution.

    :param services: The invocation services to wrap.
    :param data: The invocation context data.
    """

    logger = LoggerInterface(services=services, data=data)
    images = ImagesInterface(services=services, data=data)
    tensors = TensorsInterface(services=services, data=data)
    models = ModelsInterface(services=services, data=data)
    config = ConfigInterface(services=services, data=data)
    util = UtilInterface(services=services, data=data, cancel_event=cancel_event)
    conditioning = ConditioningInterface(services=services, data=data)
    boards = BoardsInterface(services=services, data=data)

    ctx = InvocationContext(
        images=images,
        logger=logger,
        config=config,
        tensors=tensors,
        models=models,
        data=data,
        util=util,
        conditioning=conditioning,
        services=services,
        boards=boards,
    )

    return ctx
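
To make the new API concrete, here is a minimal sketch, not part of this diff, of an `invoke` method written against the context built above. The node's `image` field and the `ImageField`/`ImageOutput` helpers are assumed from InvokeAI's existing invocation conventions.

# Sketch only: a hypothetical node body using the context interfaces defined above.
def invoke(self, context: InvocationContext) -> ImageOutput:
    # Load the input image as RGB; get_pil falls back to the original mode if conversion fails.
    image = context.images.get_pil(self.image.image_name, mode="RGB")

    context.logger.info(f"Processing {self.image.image_name} ({image.width}x{image.height})")
    # ... transform `image` here ...

    # Save the result. The board and metadata are picked up from the WithBoard / WithMetadata
    # mixins by ImagesInterface.save, as implemented above.
    image_dto = context.images.save(image=image)
    return ImageOutput(
        image=ImageField(image_name=image_dto.image_name),
        width=image_dto.width,
        height=image_dto.height,
    )
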
67
invokeai/app/shared/fields.py
Normal file
@@ -0,0 +1,67 @@
|
||||
class FieldDescriptions:
|
||||
denoising_start = "When to start denoising, expressed as a percentage of total steps"
denoising_end = "When to stop denoising, expressed as a percentage of total steps"
|
||||
cfg_scale = "Classifier-Free Guidance scale"
|
||||
cfg_rescale_multiplier = "Rescale multiplier for CFG guidance, used for models trained with zero-terminal SNR"
|
||||
scheduler = "Scheduler to use during inference"
|
||||
positive_cond = "Positive conditioning tensor"
|
||||
negative_cond = "Negative conditioning tensor"
|
||||
noise = "Noise tensor"
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
vae = "VAE"
|
||||
cond = "Conditioning tensor"
|
||||
controlnet_model = "ControlNet model to load"
|
||||
vae_model = "VAE model to load"
|
||||
lora_model = "LoRA model to load"
|
||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Model (UNet, VAE, CLIP2) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||
raw_prompt = "Raw prompt text (no parsing)"
|
||||
sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor"
|
||||
skipped_layers = "Number of layers to skip in text encoder"
|
||||
seed = "Seed for random number generation"
|
||||
steps = "Number of steps to run"
|
||||
width = "Width of output (px)"
|
||||
height = "Height of output (px)"
|
||||
control = "ControlNet(s) to apply"
|
||||
ip_adapter = "IP-Adapter to apply"
|
||||
t2i_adapter = "T2I-Adapter(s) to apply"
|
||||
denoised_latents = "Denoised latents tensor"
|
||||
latents = "Latents tensor"
|
||||
strength = "Strength of denoising (proportional to steps)"
|
||||
metadata = "Optional metadata to be saved with the image"
|
||||
metadata_collection = "Collection of Metadata"
|
||||
metadata_item_polymorphic = "A single metadata item or collection of metadata items"
|
||||
metadata_item_label = "Label for this metadata item"
|
||||
metadata_item_value = "The value for this metadata item (may be any type)"
|
||||
workflow = "Optional workflow to be saved with the image"
|
||||
interp_mode = "Interpolation mode"
|
||||
torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)"
|
||||
fp32 = "Whether or not to use full float32 precision"
|
||||
precision = "Precision to use"
|
||||
tiled = "Processing using overlapping tiles (reduce memory consumption)"
|
||||
detect_res = "Pixel resolution for detection"
|
||||
image_res = "Pixel resolution for output image"
|
||||
safe_mode = "Whether or not to use safe mode"
|
||||
scribble_mode = "Whether or not to use scribble mode"
|
||||
scale_factor = "The factor by which to scale"
|
||||
blend_alpha = (
|
||||
"Blending factor. 0.0 = use input A only, 1.0 = use input B only, 0.5 = 50% mix of input A and input B."
|
||||
)
|
||||
num_1 = "The first number"
|
||||
num_2 = "The second number"
|
||||
mask = "The mask to use for the operation"
|
||||
board = "The board to save the image to"
|
||||
image = "The image to process"
|
||||
tile_size = "Tile size"
|
||||
inclusive_low = "The inclusive low value"
|
||||
exclusive_high = "The exclusive high value"
|
||||
decimal_places = "The number of decimal places to round to"
|
||||
freeu_s1 = 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
|
||||
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
|
||||
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
|
||||
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
|
||||
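
These constants are plain strings intended to be reused as pydantic field descriptions across invocations. A small sketch of that usage (the settings model below is hypothetical):

from pydantic import BaseModel, Field

from invokeai.app.shared.fields import FieldDescriptions


class ExampleDenoiseSettings(BaseModel):
    # Hypothetical settings model reusing the shared description strings.
    steps: int = Field(default=30, description=FieldDescriptions.steps)
    cfg_scale: float = Field(default=7.5, description=FieldDescriptions.cfg_scale)
    denoising_start: float = Field(default=0.0, description=FieldDescriptions.denoising_start)
    denoising_end: float = Field(default=1.0, description=FieldDescriptions.denoising_end)
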
@@ -1,6 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from invokeai.app.invocations.fields import FieldDescriptions
|
||||
from invokeai.app.shared.fields import FieldDescriptions
|
||||
|
||||
|
||||
class FreeUConfig(BaseModel):
|
||||
|
||||
@@ -1,17 +1,12 @@
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException, ProgressImage
|
||||
from invokeai.backend.model_manager.config import BaseModelType
|
||||
from invokeai.app.services.invocation_processor.invocation_processor_common import CanceledException, ProgressImage
|
||||
|
||||
from ...backend.model_manager import BaseModelType
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.util.util import image_to_dataURL
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from ..invocations.baseinvocation import InvocationContext
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
||||
@@ -30,13 +25,13 @@ def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=
|
||||
|
||||
|
||||
def stable_diffusion_step_callback(
|
||||
context_data: "InvocationContextData",
|
||||
context: InvocationContext,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
base_model: BaseModelType,
|
||||
events: "EventServiceBase",
|
||||
is_canceled: Callable[[], bool],
|
||||
) -> None:
|
||||
if is_canceled():
|
||||
):
|
||||
if context.services.queue.is_canceled(context.graph_execution_state_id):
|
||||
raise CanceledException
|
||||
|
||||
# Some schedulers report not only the noisy latents at the current timestep,
|
||||
@@ -113,13 +108,13 @@ def stable_diffusion_step_callback(
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
events.emit_generator_progress(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
node_id=context_data.invocation.id,
|
||||
source_node_id=context_data.source_invocation_id,
|
||||
context.services.events.emit_generator_progress(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
queue_batch_id=context.queue_batch_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=intermediate_state.step,
|
||||
order=intermediate_state.order,
|
||||
|
||||
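
In the new arrangement, invocations do not call this function directly; they go through `context.util.sd_step_callback` (defined in the invocation context above), which supplies the context data, event service, and cancellation check. A rough sketch of that wiring inside a denoise invocation, with the pipeline call elided and the base model chosen arbitrarily:

# Sketch only (not part of this diff): forwarding progress from a denoise loop.
def step_callback(state: PipelineIntermediateState) -> None:
    # Emits the progress event and raises CanceledException if the queue item was canceled.
    context.util.sd_step_callback(state, BaseModelType.StableDiffusion1)

# The callback is then handed to the diffusion pipeline so it fires after every denoising step,
# e.g. pipeline.latents_from_embeddings(..., callback=step_callback)  # assumed call site
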
@@ -1,47 +1,8 @@
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.model_records import UnknownModelException
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelType
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
|
||||
def extract_ti_triggers_from_prompt(prompt: str) -> List[str]:
|
||||
ti_triggers: List[str] = []
|
||||
def extract_ti_triggers_from_prompt(prompt: str) -> list[str]:
|
||||
ti_triggers = []
|
||||
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", prompt):
|
||||
ti_triggers.append(str(trigger))
|
||||
ti_triggers.append(trigger)
|
||||
return ti_triggers
|
||||
|
||||
|
||||
def generate_ti_list(
|
||||
prompt: str, base: BaseModelType, context: InvocationContext
|
||||
) -> List[Tuple[str, TextualInversionModelRaw]]:
|
||||
ti_list: List[Tuple[str, TextualInversionModelRaw]] = []
|
||||
for trigger in extract_ti_triggers_from_prompt(prompt):
|
||||
name_or_key = trigger[1:-1]
|
||||
try:
|
||||
loaded_model = context.models.load(key=name_or_key)
|
||||
model = loaded_model.model
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
assert loaded_model.config.base == base
|
||||
ti_list.append((name_or_key, model))
|
||||
except UnknownModelException:
|
||||
try:
|
||||
loaded_model = context.models.load_by_attrs(
|
||||
model_name=name_or_key, base_model=base, model_type=ModelType.TextualInversion
|
||||
)
|
||||
model = loaded_model.model
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
assert loaded_model.config.base == base
|
||||
ti_list.append((name_or_key, model))
|
||||
except UnknownModelException:
|
||||
pass
|
||||
except ValueError:
|
||||
logger.warning(f'trigger: "{trigger}" more than one similarly-named textual inversion models')
|
||||
except AssertionError:
|
||||
logger.warning(f'trigger: "{trigger}" not a valid textual inversion model for this graph')
|
||||
except Exception:
|
||||
logger.warning(f'Failed to load TI model for trigger: "{trigger}"')
|
||||
return ti_list
|
||||
|
||||
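
The trigger extraction above is a plain regex scan over the prompt; for example (illustrative):

# Illustrative: angle-bracketed trigger phrases are returned verbatim.
prompt = "a portrait in <my-style>, highly detailed, <easynegative>"
extract_ti_triggers_from_prompt(prompt)  # -> ['<my-style>', '<easynegative>']
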
@@ -1,81 +0,0 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from controlnet_aux.util import resize_image
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.dw_openpose.utils import draw_bodypose, draw_facepose, draw_handpose
|
||||
from invokeai.backend.image_util.dw_openpose.wholebody import Wholebody
|
||||
|
||||
|
||||
def draw_pose(pose, H, W, draw_face=True, draw_body=True, draw_hands=True, resolution=512):
|
||||
bodies = pose["bodies"]
|
||||
faces = pose["faces"]
|
||||
hands = pose["hands"]
|
||||
candidate = bodies["candidate"]
|
||||
subset = bodies["subset"]
|
||||
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
|
||||
|
||||
if draw_body:
|
||||
canvas = draw_bodypose(canvas, candidate, subset)
|
||||
|
||||
if draw_hands:
|
||||
canvas = draw_handpose(canvas, hands)
|
||||
|
||||
if draw_face:
|
||||
canvas = draw_facepose(canvas, faces)
|
||||
|
||||
dwpose_image = resize_image(
|
||||
canvas,
|
||||
resolution,
|
||||
)
|
||||
dwpose_image = Image.fromarray(dwpose_image)
|
||||
|
||||
return dwpose_image
|
||||
|
||||
|
||||
class DWOpenposeDetector:
|
||||
"""
|
||||
Code from the original implementation of the DW Openpose Detector.
|
||||
Credits: https://github.com/IDEA-Research/DWPose
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.pose_estimation = Wholebody()
|
||||
|
||||
def __call__(
|
||||
self, image: Image.Image, draw_face=False, draw_body=True, draw_hands=False, resolution=512
|
||||
) -> Image.Image:
|
||||
np_image = np.array(image)
|
||||
H, W, C = np_image.shape
|
||||
|
||||
with torch.no_grad():
|
||||
candidate, subset = self.pose_estimation(np_image)
|
||||
nums, keys, locs = candidate.shape
|
||||
candidate[..., 0] /= float(W)
|
||||
candidate[..., 1] /= float(H)
|
||||
body = candidate[:, :18].copy()
|
||||
body = body.reshape(nums * 18, locs)
|
||||
score = subset[:, :18]
|
||||
for i in range(len(score)):
|
||||
for j in range(len(score[i])):
|
||||
if score[i][j] > 0.3:
|
||||
score[i][j] = int(18 * i + j)
|
||||
else:
|
||||
score[i][j] = -1
|
||||
|
||||
un_visible = subset < 0.3
|
||||
candidate[un_visible] = -1
|
||||
|
||||
# foot = candidate[:, 18:24]
|
||||
|
||||
faces = candidate[:, 24:92]
|
||||
|
||||
hands = candidate[:, 92:113]
|
||||
hands = np.vstack([hands, candidate[:, 113:]])
|
||||
|
||||
bodies = {"candidate": body, "subset": score}
|
||||
pose = {"bodies": bodies, "hands": hands, "faces": faces}
|
||||
|
||||
return draw_pose(
|
||||
pose, H, W, draw_face=draw_face, draw_hands=draw_hands, draw_body=draw_body, resolution=resolution
|
||||
)
|
||||
@@ -1,128 +0,0 @@
|
||||
# Code from the original DWPose Implementation: https://github.com/IDEA-Research/DWPose
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def nms(boxes, scores, nms_thr):
|
||||
"""Single class NMS implemented in Numpy."""
|
||||
x1 = boxes[:, 0]
|
||||
y1 = boxes[:, 1]
|
||||
x2 = boxes[:, 2]
|
||||
y2 = boxes[:, 3]
|
||||
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
order = scores.argsort()[::-1]
|
||||
|
||||
keep = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
inter = w * h
|
||||
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
||||
|
||||
inds = np.where(ovr <= nms_thr)[0]
|
||||
order = order[inds + 1]
|
||||
|
||||
return keep
|
||||
|
||||
|
||||
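
As a quick sanity check of the helper above (illustrative numbers): with two heavily overlapping boxes and one distant box, an IoU threshold of 0.5 keeps the higher-scoring box of the overlapping pair plus the distant box.

import numpy as np

# Illustrative input: boxes as (x1, y1, x2, y2) rows with one score per box.
boxes = np.array(
    [
        [10, 10, 50, 50],  # box 0
        [12, 12, 52, 52],  # box 1: IoU with box 0 is ~0.83
        [100, 100, 140, 140],  # box 2: no overlap with the others
    ],
    dtype=float,
)
scores = np.array([0.9, 0.8, 0.7])

keep = nms(boxes, scores, nms_thr=0.5)  # -> indices [0, 2]; box 1 is suppressed by box 0
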
def multiclass_nms(boxes, scores, nms_thr, score_thr):
|
||||
"""Multiclass NMS implemented in Numpy. Class-aware version."""
|
||||
final_dets = []
|
||||
num_classes = scores.shape[1]
|
||||
for cls_ind in range(num_classes):
|
||||
cls_scores = scores[:, cls_ind]
|
||||
valid_score_mask = cls_scores > score_thr
|
||||
if valid_score_mask.sum() == 0:
|
||||
continue
|
||||
else:
|
||||
valid_scores = cls_scores[valid_score_mask]
|
||||
valid_boxes = boxes[valid_score_mask]
|
||||
keep = nms(valid_boxes, valid_scores, nms_thr)
|
||||
if len(keep) > 0:
|
||||
cls_inds = np.ones((len(keep), 1)) * cls_ind
|
||||
dets = np.concatenate([valid_boxes[keep], valid_scores[keep, None], cls_inds], 1)
|
||||
final_dets.append(dets)
|
||||
if len(final_dets) == 0:
|
||||
return None
|
||||
return np.concatenate(final_dets, 0)
|
||||
|
||||
|
||||
def demo_postprocess(outputs, img_size, p6=False):
|
||||
grids = []
|
||||
expanded_strides = []
|
||||
strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
|
||||
|
||||
hsizes = [img_size[0] // stride for stride in strides]
|
||||
wsizes = [img_size[1] // stride for stride in strides]
|
||||
|
||||
for hsize, wsize, stride in zip(hsizes, wsizes, strides, strict=False):
|
||||
xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
|
||||
grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
|
||||
grids.append(grid)
|
||||
shape = grid.shape[:2]
|
||||
expanded_strides.append(np.full((*shape, 1), stride))
|
||||
|
||||
grids = np.concatenate(grids, 1)
|
||||
expanded_strides = np.concatenate(expanded_strides, 1)
|
||||
outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
|
||||
outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def preprocess(img, input_size, swap=(2, 0, 1)):
|
||||
if len(img.shape) == 3:
|
||||
padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
|
||||
else:
|
||||
padded_img = np.ones(input_size, dtype=np.uint8) * 114
|
||||
|
||||
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
|
||||
resized_img = cv2.resize(
|
||||
img,
|
||||
(int(img.shape[1] * r), int(img.shape[0] * r)),
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
).astype(np.uint8)
|
||||
padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
|
||||
|
||||
padded_img = padded_img.transpose(swap)
|
||||
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
|
||||
return padded_img, r
|
||||
|
||||
|
||||
def inference_detector(session, oriImg):
|
||||
input_shape = (640, 640)
|
||||
img, ratio = preprocess(oriImg, input_shape)
|
||||
|
||||
ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
|
||||
output = session.run(None, ort_inputs)
|
||||
predictions = demo_postprocess(output[0], input_shape)[0]
|
||||
|
||||
boxes = predictions[:, :4]
|
||||
scores = predictions[:, 4:5] * predictions[:, 5:]
|
||||
|
||||
boxes_xyxy = np.ones_like(boxes)
|
||||
boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.0
|
||||
boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.0
|
||||
boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.0
|
||||
boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.0
|
||||
boxes_xyxy /= ratio
|
||||
dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
|
||||
if dets is not None:
|
||||
final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
|
||||
isscore = final_scores > 0.3
|
||||
iscat = final_cls_inds == 0
|
||||
isbbox = [i and j for (i, j) in zip(isscore, iscat, strict=False)]
|
||||
final_boxes = final_boxes[isbbox]
|
||||
else:
|
||||
final_boxes = np.array([])
|
||||
|
||||
return final_boxes
|
||||
@@ -1,361 +0,0 @@
|
||||
# Code from the original DWPose Implementation: https://github.com/IDEA-Research/DWPose
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
def preprocess(
|
||||
img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Do preprocessing for RTMPose model inference.
|
||||
|
||||
Args:
|
||||
img (np.ndarray): Input image in shape.
|
||||
input_size (tuple): Input image size in shape (w, h).
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- resized_img (np.ndarray): Preprocessed image.
|
||||
- center (np.ndarray): Center of image.
|
||||
- scale (np.ndarray): Scale of image.
|
||||
"""
|
||||
# get shape of image
|
||||
img_shape = img.shape[:2]
|
||||
out_img, out_center, out_scale = [], [], []
|
||||
if len(out_bbox) == 0:
|
||||
out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
|
||||
for i in range(len(out_bbox)):
|
||||
x0 = out_bbox[i][0]
|
||||
y0 = out_bbox[i][1]
|
||||
x1 = out_bbox[i][2]
|
||||
y1 = out_bbox[i][3]
|
||||
bbox = np.array([x0, y0, x1, y1])
|
||||
|
||||
# get center and scale
|
||||
center, scale = bbox_xyxy2cs(bbox, padding=1.25)
|
||||
|
||||
# do affine transformation
|
||||
resized_img, scale = top_down_affine(input_size, scale, center, img)
|
||||
|
||||
# normalize image
|
||||
mean = np.array([123.675, 116.28, 103.53])
|
||||
std = np.array([58.395, 57.12, 57.375])
|
||||
resized_img = (resized_img - mean) / std
|
||||
|
||||
out_img.append(resized_img)
|
||||
out_center.append(center)
|
||||
out_scale.append(scale)
|
||||
|
||||
return out_img, out_center, out_scale
|
||||
|
||||
|
||||
def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
|
||||
"""Inference RTMPose model.
|
||||
|
||||
Args:
|
||||
sess (ort.InferenceSession): ONNXRuntime session.
|
||||
img (np.ndarray): Input image in shape.
|
||||
|
||||
Returns:
|
||||
outputs (np.ndarray): Output of RTMPose model.
|
||||
"""
|
||||
all_out = []
|
||||
# build input
|
||||
for i in range(len(img)):
|
||||
input = [img[i].transpose(2, 0, 1)]
|
||||
|
||||
# build output
|
||||
sess_input = {sess.get_inputs()[0].name: input}
|
||||
sess_output = []
|
||||
for out in sess.get_outputs():
|
||||
sess_output.append(out.name)
|
||||
|
||||
# run model
|
||||
outputs = sess.run(sess_output, sess_input)
|
||||
all_out.append(outputs)
|
||||
|
||||
return all_out
|
||||
|
||||
|
||||
def postprocess(
|
||||
outputs: List[np.ndarray],
|
||||
model_input_size: Tuple[int, int],
|
||||
center: Tuple[int, int],
|
||||
scale: Tuple[int, int],
|
||||
simcc_split_ratio: float = 2.0,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Postprocess for RTMPose model output.
|
||||
|
||||
Args:
|
||||
outputs (np.ndarray): Output of RTMPose model.
|
||||
model_input_size (tuple): RTMPose model Input image size.
|
||||
center (tuple): Center of bbox in shape (x, y).
|
||||
scale (tuple): Scale of bbox in shape (w, h).
|
||||
simcc_split_ratio (float): Split ratio of simcc.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- keypoints (np.ndarray): Rescaled keypoints.
|
||||
- scores (np.ndarray): Model predict scores.
|
||||
"""
|
||||
all_key = []
|
||||
all_score = []
|
||||
for i in range(len(outputs)):
|
||||
# use simcc to decode
|
||||
simcc_x, simcc_y = outputs[i]
|
||||
keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
|
||||
|
||||
# rescale keypoints
|
||||
keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
|
||||
all_key.append(keypoints[0])
|
||||
all_score.append(scores[0])
|
||||
|
||||
return np.array(all_key), np.array(all_score)
|
||||
|
||||
|
||||
def bbox_xyxy2cs(bbox: np.ndarray, padding: float = 1.0) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Transform the bbox format from (x,y,w,h) into (center, scale)
|
||||
|
||||
Args:
|
||||
bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
|
||||
as (left, top, right, bottom)
|
||||
padding (float): BBox padding factor that will be multiplied to scale.
|
||||
Default: 1.0
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing center and scale.
|
||||
- np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
|
||||
(n, 2)
|
||||
- np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
|
||||
(n, 2)
|
||||
"""
|
||||
# convert single bbox from (4, ) to (1, 4)
|
||||
dim = bbox.ndim
|
||||
if dim == 1:
|
||||
bbox = bbox[None, :]
|
||||
|
||||
# get bbox center and scale
|
||||
x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
|
||||
center = np.hstack([x1 + x2, y1 + y2]) * 0.5
|
||||
scale = np.hstack([x2 - x1, y2 - y1]) * padding
|
||||
|
||||
if dim == 1:
|
||||
center = center[0]
|
||||
scale = scale[0]
|
||||
|
||||
return center, scale
|
||||
|
||||
|
||||
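
For instance (illustrative numbers), a box from (10, 20) to (110, 220) with 1.25 padding yields:

import numpy as np

bbox = np.array([10, 20, 110, 220], dtype=float)
center, scale = bbox_xyxy2cs(bbox, padding=1.25)
# center -> [ 60. 120.]  (midpoint of the box)
# scale  -> [125. 250.]  (width and height, scaled by the padding factor)
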
def _fix_aspect_ratio(bbox_scale: np.ndarray, aspect_ratio: float) -> np.ndarray:
|
||||
"""Extend the scale to match the given aspect ratio.
|
||||
|
||||
Args:
|
||||
scale (np.ndarray): The image scale (w, h) in shape (2, )
|
||||
aspect_ratio (float): The ratio of ``w/h``
|
||||
|
||||
Returns:
|
||||
np.ndarray: The reshaped image scale in (2, )
|
||||
"""
|
||||
w, h = np.hsplit(bbox_scale, [1])
|
||||
bbox_scale = np.where(w > h * aspect_ratio, np.hstack([w, w / aspect_ratio]), np.hstack([h * aspect_ratio, h]))
|
||||
return bbox_scale
|
||||
|
||||
|
||||
def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
|
||||
"""Rotate a point by an angle.
|
||||
|
||||
Args:
|
||||
pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
|
||||
angle_rad (float): rotation angle in radian
|
||||
|
||||
Returns:
|
||||
np.ndarray: Rotated point in shape (2, )
|
||||
"""
|
||||
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
|
||||
rot_mat = np.array([[cs, -sn], [sn, cs]])
|
||||
return rot_mat @ pt
|
||||
|
||||
|
||||
def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
||||
"""To calculate the affine matrix, three pairs of points are required. This
|
||||
function is used to get the 3rd point, given 2D points a & b.
|
||||
|
||||
The 3rd point is defined by rotating vector `a - b` by 90 degrees
|
||||
anticlockwise, using b as the rotation center.
|
||||
|
||||
Args:
|
||||
a (np.ndarray): The 1st point (x,y) in shape (2, )
|
||||
b (np.ndarray): The 2nd point (x,y) in shape (2, )
|
||||
|
||||
Returns:
|
||||
np.ndarray: The 3rd point.
|
||||
"""
|
||||
direction = a - b
|
||||
c = b + np.r_[-direction[1], direction[0]]
|
||||
return c
|
||||
|
||||
|
||||
def get_warp_matrix(
|
||||
center: np.ndarray,
|
||||
scale: np.ndarray,
|
||||
rot: float,
|
||||
output_size: Tuple[int, int],
|
||||
shift: Tuple[float, float] = (0.0, 0.0),
|
||||
inv: bool = False,
|
||||
) -> np.ndarray:
|
||||
"""Calculate the affine transformation matrix that can warp the bbox area
|
||||
in the input image to the output size.
|
||||
|
||||
Args:
|
||||
center (np.ndarray[2, ]): Center of the bounding box (x, y).
|
||||
scale (np.ndarray[2, ]): Scale of the bounding box
|
||||
wrt [width, height].
|
||||
rot (float): Rotation angle (degree).
|
||||
output_size (np.ndarray[2, ] | list(2,)): Size of the
|
||||
destination heatmaps.
|
||||
shift (0-100%): Shift translation ratio wrt the width/height.
|
||||
Default (0., 0.).
|
||||
inv (bool): Option to inverse the affine transform direction.
|
||||
(inv=False: src->dst or inv=True: dst->src)
|
||||
|
||||
Returns:
|
||||
np.ndarray: A 2x3 transformation matrix
|
||||
"""
|
||||
shift = np.array(shift)
|
||||
src_w = scale[0]
|
||||
dst_w = output_size[0]
|
||||
dst_h = output_size[1]
|
||||
|
||||
# compute transformation matrix
|
||||
rot_rad = np.deg2rad(rot)
|
||||
src_dir = _rotate_point(np.array([0.0, src_w * -0.5]), rot_rad)
|
||||
dst_dir = np.array([0.0, dst_w * -0.5])
|
||||
|
||||
# get four corners of the src rectangle in the original image
|
||||
src = np.zeros((3, 2), dtype=np.float32)
|
||||
src[0, :] = center + scale * shift
|
||||
src[1, :] = center + src_dir + scale * shift
|
||||
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
|
||||
|
||||
# get four corners of the dst rectangle in the input image
|
||||
dst = np.zeros((3, 2), dtype=np.float32)
|
||||
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
||||
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
|
||||
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
|
||||
|
||||
if inv:
|
||||
warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
||||
else:
|
||||
warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
||||
|
||||
return warp_mat
|
||||
|
||||
|
||||
def top_down_affine(
|
||||
input_size: dict, bbox_scale: dict, bbox_center: dict, img: np.ndarray
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Get the bbox image as the model input by affine transform.
|
||||
|
||||
Args:
|
||||
input_size (dict): The input size of the model.
|
||||
bbox_scale (dict): The bbox scale of the img.
|
||||
bbox_center (dict): The bbox center of the img.
|
||||
img (np.ndarray): The original image.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing center and scale.
|
||||
- np.ndarray[float32]: img after affine transform.
|
||||
- np.ndarray[float32]: bbox scale after affine transform.
|
||||
"""
|
||||
w, h = input_size
|
||||
warp_size = (int(w), int(h))
|
||||
|
||||
# reshape bbox to fixed aspect ratio
|
||||
bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
|
||||
|
||||
# get the affine matrix
|
||||
center = bbox_center
|
||||
scale = bbox_scale
|
||||
rot = 0
|
||||
warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
|
||||
|
||||
# do affine transform
|
||||
img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
|
||||
|
||||
return img, bbox_scale
|
||||
|
||||
|
||||
def get_simcc_maximum(simcc_x: np.ndarray, simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Get maximum response location and value from simcc representations.
|
||||
|
||||
Note:
|
||||
instance number: N
|
||||
num_keypoints: K
|
||||
heatmap height: H
|
||||
heatmap width: W
|
||||
|
||||
Args:
|
||||
simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
|
||||
simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- locs (np.ndarray): locations of maximum heatmap responses in shape
|
||||
(K, 2) or (N, K, 2)
|
||||
- vals (np.ndarray): values of maximum heatmap responses in shape
|
||||
(K,) or (N, K)
|
||||
"""
|
||||
N, K, Wx = simcc_x.shape
|
||||
simcc_x = simcc_x.reshape(N * K, -1)
|
||||
simcc_y = simcc_y.reshape(N * K, -1)
|
||||
|
||||
# get maximum value locations
|
||||
x_locs = np.argmax(simcc_x, axis=1)
|
||||
y_locs = np.argmax(simcc_y, axis=1)
|
||||
locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
|
||||
max_val_x = np.amax(simcc_x, axis=1)
|
||||
max_val_y = np.amax(simcc_y, axis=1)
|
||||
|
||||
# get maximum value across x and y axis
|
||||
mask = max_val_x > max_val_y
|
||||
max_val_x[mask] = max_val_y[mask]
|
||||
vals = max_val_x
|
||||
locs[vals <= 0.0] = -1
|
||||
|
||||
# reshape
|
||||
locs = locs.reshape(N, K, 2)
|
||||
vals = vals.reshape(N, K)
|
||||
|
||||
return locs, vals
|
||||
|
||||
|
||||
def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Modulate simcc distribution with Gaussian.
|
||||
|
||||
Args:
|
||||
simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
|
||||
simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
|
||||
simcc_split_ratio (int): The split ratio of simcc.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing center and scale.
|
||||
- np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
|
||||
- np.ndarray[float32]: scores in shape (K,) or (n, K)
|
||||
"""
|
||||
keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
|
||||
keypoints /= simcc_split_ratio
|
||||
|
||||
return keypoints, scores
|
||||
|
||||
|
||||
def inference_pose(session, out_bbox, oriImg):
|
||||
h, w = session.get_inputs()[0].shape[2:]
|
||||
model_input_size = (w, h)
|
||||
resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
|
||||
outputs = inference(session, resized_img)
|
||||
keypoints, scores = postprocess(outputs, model_input_size, center, scale)
|
||||
|
||||
return keypoints, scores
|
||||
@@ -1,155 +0,0 @@
|
||||
# Code from the original DWPose Implementation: https://github.com/IDEA-Research/DWPose
|
||||
|
||||
import math
|
||||
|
||||
import cv2
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
|
||||
eps = 0.01
|
||||
|
||||
|
||||
def draw_bodypose(canvas, candidate, subset):
|
||||
H, W, C = canvas.shape
|
||||
candidate = np.array(candidate)
|
||||
subset = np.array(subset)
|
||||
|
||||
stickwidth = 4
|
||||
|
||||
limbSeq = [
|
||||
[2, 3],
|
||||
[2, 6],
|
||||
[3, 4],
|
||||
[4, 5],
|
||||
[6, 7],
|
||||
[7, 8],
|
||||
[2, 9],
|
||||
[9, 10],
|
||||
[10, 11],
|
||||
[2, 12],
|
||||
[12, 13],
|
||||
[13, 14],
|
||||
[2, 1],
|
||||
[1, 15],
|
||||
[15, 17],
|
||||
[1, 16],
|
||||
[16, 18],
|
||||
[3, 17],
|
||||
[6, 18],
|
||||
]
|
||||
|
||||
colors = [
|
||||
[255, 0, 0],
|
||||
[255, 85, 0],
|
||||
[255, 170, 0],
|
||||
[255, 255, 0],
|
||||
[170, 255, 0],
|
||||
[85, 255, 0],
|
||||
[0, 255, 0],
|
||||
[0, 255, 85],
|
||||
[0, 255, 170],
|
||||
[0, 255, 255],
|
||||
[0, 170, 255],
|
||||
[0, 85, 255],
|
||||
[0, 0, 255],
|
||||
[85, 0, 255],
|
||||
[170, 0, 255],
|
||||
[255, 0, 255],
|
||||
[255, 0, 170],
|
||||
[255, 0, 85],
|
||||
]
|
||||
|
||||
for i in range(17):
|
||||
for n in range(len(subset)):
|
||||
index = subset[n][np.array(limbSeq[i]) - 1]
|
||||
if -1 in index:
|
||||
continue
|
||||
Y = candidate[index.astype(int), 0] * float(W)
|
||||
X = candidate[index.astype(int), 1] * float(H)
|
||||
mX = np.mean(X)
|
||||
mY = np.mean(Y)
|
||||
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
|
||||
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
||||
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
||||
cv2.fillConvexPoly(canvas, polygon, colors[i])
|
||||
|
||||
canvas = (canvas * 0.6).astype(np.uint8)
|
||||
|
||||
for i in range(18):
|
||||
for n in range(len(subset)):
|
||||
index = int(subset[n][i])
|
||||
if index == -1:
|
||||
continue
|
||||
x, y = candidate[index][0:2]
|
||||
x = int(x * W)
|
||||
y = int(y * H)
|
||||
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
|
||||
|
||||
return canvas
|
||||
|
||||
|
||||
def draw_handpose(canvas, all_hand_peaks):
|
||||
H, W, C = canvas.shape
|
||||
|
||||
edges = [
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
[2, 3],
|
||||
[3, 4],
|
||||
[0, 5],
|
||||
[5, 6],
|
||||
[6, 7],
|
||||
[7, 8],
|
||||
[0, 9],
|
||||
[9, 10],
|
||||
[10, 11],
|
||||
[11, 12],
|
||||
[0, 13],
|
||||
[13, 14],
|
||||
[14, 15],
|
||||
[15, 16],
|
||||
[0, 17],
|
||||
[17, 18],
|
||||
[18, 19],
|
||||
[19, 20],
|
||||
]
|
||||
|
||||
for peaks in all_hand_peaks:
|
||||
peaks = np.array(peaks)
|
||||
|
||||
for ie, e in enumerate(edges):
|
||||
x1, y1 = peaks[e[0]]
|
||||
x2, y2 = peaks[e[1]]
|
||||
x1 = int(x1 * W)
|
||||
y1 = int(y1 * H)
|
||||
x2 = int(x2 * W)
|
||||
y2 = int(y2 * H)
|
||||
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
|
||||
cv2.line(
|
||||
canvas,
|
||||
(x1, y1),
|
||||
(x2, y2),
|
||||
matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
|
||||
thickness=2,
|
||||
)
|
||||
|
||||
for _, keypoint in enumerate(peaks):
x, y = keypoint
|
||||
x = int(x * W)
|
||||
y = int(y * H)
|
||||
if x > eps and y > eps:
|
||||
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
|
||||
return canvas
|
||||
|
||||
|
||||
def draw_facepose(canvas, all_lmks):
|
||||
H, W, C = canvas.shape
|
||||
for lmks in all_lmks:
|
||||
lmks = np.array(lmks)
|
||||
for lmk in lmks:
|
||||
x, y = lmk
|
||||
x = int(x * W)
|
||||
y = int(y * H)
|
||||
if x > eps and y > eps:
|
||||
cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
|
||||
return canvas
|
||||
@@ -1,67 +0,0 @@
|
||||
# Code from the original DWPose Implementation: https://github.com/IDEA-Research/DWPose
|
||||
# Modified pathing to suit Invoke
|
||||
|
||||
import pathlib
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.util import download_with_progress_bar
|
||||
|
||||
from .onnxdet import inference_detector
|
||||
from .onnxpose import inference_pose
|
||||
|
||||
DWPOSE_MODELS = {
|
||||
"yolox_l.onnx": {
|
||||
"local": "any/annotators/dwpose/yolox_l.onnx",
|
||||
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/yolox_l.onnx?download=true",
|
||||
},
|
||||
"dw-ll_ucoco_384.onnx": {
|
||||
"local": "any/annotators/dwpose/dw-ll_ucoco_384.onnx",
|
||||
"url": "https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.onnx?download=true",
|
||||
},
|
||||
}
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
|
||||
|
||||
class Wholebody:
|
||||
def __init__(self):
|
||||
device = choose_torch_device()
|
||||
|
||||
providers = ["CUDAExecutionProvider"] if device == "cuda" else ["CPUExecutionProvider"]
|
||||
|
||||
DET_MODEL_PATH = pathlib.Path(config.models_path / DWPOSE_MODELS["yolox_l.onnx"]["local"])
|
||||
if not DET_MODEL_PATH.exists():
|
||||
download_with_progress_bar(DWPOSE_MODELS["yolox_l.onnx"]["url"], DET_MODEL_PATH)
|
||||
|
||||
POSE_MODEL_PATH = pathlib.Path(config.models_path / DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["local"])
|
||||
if not POSE_MODEL_PATH.exists():
|
||||
download_with_progress_bar(DWPOSE_MODELS["dw-ll_ucoco_384.onnx"]["url"], POSE_MODEL_PATH)
|
||||
|
||||
onnx_det = DET_MODEL_PATH
|
||||
onnx_pose = POSE_MODEL_PATH
|
||||
|
||||
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
|
||||
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
|
||||
|
||||
def __call__(self, oriImg):
|
||||
det_result = inference_detector(self.session_det, oriImg)
|
||||
keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
|
||||
|
||||
keypoints_info = np.concatenate((keypoints, scores[..., None]), axis=-1)
|
||||
# compute neck joint
|
||||
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
|
||||
# neck score when visualizing pred
|
||||
neck[:, 2:4] = np.logical_and(keypoints_info[:, 5, 2:4] > 0.3, keypoints_info[:, 6, 2:4] > 0.3).astype(int)
|
||||
new_keypoints_info = np.insert(keypoints_info, 17, neck, axis=1)
|
||||
mmpose_idx = [17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]
|
||||
openpose_idx = [1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]
|
||||
new_keypoints_info[:, openpose_idx] = new_keypoints_info[:, mmpose_idx]
|
||||
keypoints_info = new_keypoints_info
|
||||
|
||||
keypoints, scores = keypoints_info[..., :2], keypoints_info[..., 2]
|
||||
|
||||
return keypoints, scores
|
||||
@@ -8,6 +8,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
|
||||
def check_invokeai_root(config: InvokeAIAppConfig):
|
||||
try:
|
||||
assert config.model_conf_path.exists(), f"{config.model_conf_path} not found"
|
||||
assert config.db_path.parent.exists(), f"{config.db_path.parent} not found"
|
||||
assert config.models_path.exists(), f"{config.models_path} not found"
|
||||
if not config.ignore_missing_core_models:
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
"""Utility (backend) functions used by model_install.py"""
|
||||
import re
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import omegaconf
|
||||
from huggingface_hub import HfFolder
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -15,8 +18,12 @@ from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from invokeai.app.services.model_install import (
|
||||
HFModelSource,
|
||||
LocalModelSource,
|
||||
ModelInstallService,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
@@ -24,6 +31,7 @@ from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.backend.model_manager import (
|
||||
BaseModelType,
|
||||
InvalidModelConfigException,
|
||||
ModelRepoVariant,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import UnknownMetadataException
|
||||
@@ -218,13 +226,37 @@ class InstallHelper(object):
|
||||
additional_models.append(reverse_source[requirement])
|
||||
model_list.extend(additional_models)
|
||||
|
||||
def _make_install_source(self, model_info: UnifiedModelInfo) -> ModelSource:
|
||||
assert model_info.source
|
||||
model_path_id_or_url = model_info.source.strip("\"' ")
|
||||
model_path = Path(model_path_id_or_url)
|
||||
|
||||
if model_path.exists(): # local file on disk
|
||||
return LocalModelSource(path=model_path.absolute(), inplace=True)
|
||||
|
||||
# parsing huggingface repo ids
|
||||
# we're going to do a little trick that allows for extended repo_ids of form "foo/bar:fp16"
|
||||
variants = "|".join([x.lower() for x in ModelRepoVariant.__members__])
|
||||
if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url):
|
||||
repo_id = match.group(1)
|
||||
repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None
|
||||
subfolder = Path(model_info.subfolder) if model_info.subfolder else None
|
||||
return HFModelSource(
|
||||
repo_id=repo_id,
|
||||
access_token=HfFolder.get_token(),
|
||||
subfolder=subfolder,
|
||||
variant=repo_variant,
|
||||
)
|
||||
if re.match(r"^(http|https):", model_path_id_or_url):
|
||||
return URLModelSource(url=AnyHttpUrl(model_path_id_or_url))
|
||||
raise ValueError(f"Unsupported model source: {model_path_id_or_url}")
|
||||
|
||||
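
The repo-id branch above also accepts an optional `:variant` suffix. A rough recreation of that match (the variant alternatives are written out by hand here as an assumption; the real code derives them from `ModelRepoVariant`):

import re

variants = "fp16|fp32|onnx|openvino|flax"  # assumed stand-in for the ModelRepoVariant member names
m = re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", "stabilityai/stable-diffusion-xl-base-1.0:fp16")
assert m is not None
m.group(1)  # 'stabilityai/stable-diffusion-xl-base-1.0'
m.group(2)  # 'fp16'
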
def add_or_delete(self, selections: InstallSelections) -> None:
|
||||
"""Add or delete selected models."""
|
||||
installer = self._installer
|
||||
self._add_required_models(selections.install_models)
|
||||
for model in selections.install_models:
|
||||
assert model.source
|
||||
model_path_id_or_url = model.source.strip("\"' ")
|
||||
source = self._make_install_source(model)
|
||||
config = (
|
||||
{
|
||||
"description": model.description,
|
||||
@@ -235,12 +267,12 @@ class InstallHelper(object):
|
||||
)
|
||||
|
||||
try:
|
||||
installer.heuristic_import(
|
||||
source=model_path_id_or_url,
|
||||
installer.import_model(
|
||||
source=source,
|
||||
config=config,
|
||||
)
|
||||
except (UnknownMetadataException, InvalidModelConfigException, HTTPError, OSError) as e:
|
||||
self._logger.warning(f"{model.source}: {e}")
|
||||
self._logger.warning(f"{source}: {e}")
|
||||
|
||||
for model_to_remove in selections.remove_models:
|
||||
parts = model_to_remove.split("/")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.