Compare commits

1 commit

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| psychedelicious | ff68ae7710 | Update LOCAL_DEVELOPMENT.md / Add debugger config for UI | 2023-10-04 20:21:26 +11:00 |

210 changed files with 3154 additions and 6169 deletions


@@ -47,9 +47,34 @@ pip install ".[dev,test]"
These are optional groups of packages defined within `pyproject.toml`; you will
need them to test the changes you make to the code.
### Tests
### Running Tests
We use [pytest](https://docs.pytest.org/en/7.2.x/) for our test suite. Tests can
be found under the `./tests` folder and can be run with a single `pytest`
command. Optionally, to review test coverage, you can append `--cov`.
```zsh
pytest --cov
```
Test outcomes and coverage will be reported in the terminal. In addition, a more
detailed report is created in both XML and HTML format in the `./coverage`
folder. The HTML report in particular helps identify statements that are missing
test coverage. View it by opening `./coverage/html/index.html`.
For example:
```zsh
pytest --cov; open ./coverage/html/index.html
```
??? info "HTML coverage report output"
![html-overview](../assets/contributing/html-overview.png)
![html-detail](../assets/contributing/html-detail.png)
See the [tests documentation](./TESTS.md) for information about running and writing tests.
### Reloading Changes
Experimenting with changes to the Python source code is a drag if you have to re-start the server —
@@ -200,6 +225,14 @@ Now we can create the InvokeAI debugging configs:
"program": "scripts/invokeai-cli.py",
"justMyCode": true
},
{
"type": "chrome",
"request": "launch",
"name": "InvokeAI UI",
// You have to run the UI with `yarn dev` for this to work
"url": "http://localhost:5173",
"webRoot": "${workspaceFolder}/invokeai/frontend/web"
},
{
// Run tests
"name": "InvokeAI Test",
@@ -235,7 +268,8 @@ Now we can create the InvokeAI debugging configs:
You'll see these configs in the debugging configs drop down. Running them will
start InvokeAI with the debugger attached, in the correct environment, and work just
like the normal app.
like the normal app, though the UI debugger requires you to run the UI in dev
mode. See the [frontend guide](contribution_guides/contributingToFrontend.md) for setting that up.
Enjoy debugging InvokeAI with ease (not that we have any bugs of course).


@@ -1,89 +0,0 @@
# InvokeAI Backend Tests
We use `pytest` to run the backend python tests. (See [pyproject.toml](/pyproject.toml) for the default `pytest` options.)
## Fast vs. Slow
All tests are categorized as either 'fast' (no test annotation) or 'slow' (annotated with the `@pytest.mark.slow` decorator).
'Fast' tests are run to validate every PR, and are fast enough that they can be run routinely during development.
'Slow' tests are currently only run manually on an ad-hoc basis. In the future, they may be automated to run nightly. Most developers are only expected to run the 'slow' tests that directly relate to the feature(s) that they are working on.
As a rule of thumb, tests should be marked as 'slow' if there is a chance that they take >1s (e.g. on a CPU-only machine with slow internet connection). Common examples of slow tests are tests that depend on downloading a model, or running model inference.
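As a minimal sketch of what that looks like in practice (the test name and body below are placeholders, not a real test):
```python
import time

import pytest


@pytest.mark.slow
def test_expensive_operation():
    # Anything that may take >1s (e.g. model downloads or inference) gets the
    # `slow` marker, so it is excluded from the default `-m "not slow"` run.
    time.sleep(2)
    assert True
```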
## Running Tests
Below are some common test commands:
```bash
# Run the fast tests. (This implicitly uses the configured default option: `-m "not slow"`.)
pytest tests/
# Equivalent command to run the fast tests.
pytest tests/ -m "not slow"
# Run the slow tests.
pytest tests/ -m "slow"
# Run the slow tests from a specific file.
pytest tests/path/to/slow_test.py -m "slow"
# Run all tests (fast and slow).
pytest tests -m ""
```
## Test Organization
All backend tests are in the [`tests/`](/tests/) directory. This directory mirrors the organization of the `invokeai/` directory. For example, tests for `invokeai/model_management/model_manager.py` would be found in `tests/model_management/test_model_manager.py`.
TODO: The above statement is aspirational. A re-organization of legacy tests is required to make it true.
## Tests that depend on models
There are a few things to keep in mind when adding tests that depend on models.
1. If a required model is not already present, it should automatically be downloaded as part of the test setup.
2. If a model is already downloaded, it should not be re-downloaded unnecessarily.
3. Take reasonable care to keep the total number of models required for the tests low. Whenever possible, re-use models that are already required for other tests. If you are adding a new model, consider including a comment to explain why it is required/unique.
There are several utilities to help with model setup for tests. Here is a sample test that depends on a model:
```python
import pytest
import torch
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.util.test_utils import install_and_load_model
@pytest.mark.slow
def test_model(model_installer, torch_device):
model_info = install_and_load_model(
model_installer=model_installer,
model_path_id_or_url="HF/dummy_model_id",
model_name="dummy_model",
base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.Dummy,
)
dummy_input = build_dummy_input(torch_device)
with torch.no_grad(), model_info as model:
model.to(torch_device, dtype=torch.float32)
output = model(dummy_input)
# Validate output...
```
## Test Coverage
To review test coverage, append `--cov` to your pytest command:
```bash
pytest tests/ --cov
```
Test outcomes and coverage will be reported in the terminal. In addition, a more detailed report is created in both XML and HTML format in the `./coverage` folder. The HTML output is particularly helpful in identifying untested statements where coverage should be improved. The HTML report can be viewed by opening `./coverage/html/index.html`.
??? info "HTML coverage report output"
![html-overview](../assets/contributing/html-overview.png)
![html-detail](../assets/contributing/html-detail.png)


@@ -12,7 +12,7 @@ To get started, take a look at our [new contributors checklist](newContributorCh
Once you're set up, you can review the documentation specific to your area of interest for more information:
* #### [InvokeAI Architecture](../ARCHITECTURE.md)
* #### [Frontend Documentation](./contributingToFrontend.md)
* #### [Frontend Documentation](development_guides/contributingToFrontend.md)
* #### [Node Documentation](../INVOCATIONS.md)
* #### [Local Development](../LOCAL_DEVELOPMENT.md)


@@ -256,10 +256,6 @@ manager, please follow these steps:
**highly recommended** if your virtual environment is located outside of
your runtime directory.
!!! tip
On Linux, it is recommended to run `invokeai` with the following env var: `MALLOC_MMAP_THRESHOLD_=1048576`. For example: `MALLOC_MMAP_THRESHOLD_=1048576 invokeai --web`. This helps to prevent memory fragmentation that can lead to memory accumulation over time. This env var is set automatically when running via `invoke.sh`.
10. Render away!
Browse the [features](../features/index.md) section to learn about all the


@@ -10,20 +10,6 @@ To use a community workflow, download the `.json` node graph file and load i
--------------------------------
--------------------------------
### Make 3D
**Description:** Create compelling 3D stereo images from 2D originals.
**Node Link:** [https://gitlab.com/srcrr/shift3d/-/raw/main/make3d.py](https://gitlab.com/srcrr/shift3d)
**Example Node Graph:** https://gitlab.com/srcrr/shift3d/-/raw/main/example-workflow.json?ref_type=heads&inline=false
**Output Examples**
![Painting of a cozy dilapidated house](https://gitlab.com/srcrr/shift3d/-/raw/main/example-1.png){: style="height:512px;width:512px"}
![Photo of cute puppies](https://gitlab.com/srcrr/shift3d/-/raw/main/example-2.png){: style="height:512px;width:512px"}
--------------------------------
### Ideal Size


@@ -46,9 +46,6 @@ if [ "$(uname -s)" == "Darwin" ]; then
export PYTORCH_ENABLE_MPS_FALLBACK=1
fi
# Avoid glibc memory fragmentation. See invokeai/backend/model_management/README.md for details.
export MALLOC_MMAP_THRESHOLD_=1048576
# Primary function for the case statement to determine user input
do_choice() {
case $1 in


@@ -68,7 +68,6 @@ class FieldDescriptions:
height = "Height of output (px)"
control = "ControlNet(s) to apply"
ip_adapter = "IP-Adapter to apply"
t2i_adapter = "T2I-Adapter(s) to apply"
denoised_latents = "Denoised latents tensor"
latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)"


@@ -10,7 +10,7 @@ import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.adapter import FullAdapterXL, T2IAdapter
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
@@ -33,7 +33,6 @@ from invokeai.app.invocations.primitives import (
LatentsOutput,
build_latents_output,
)
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
@@ -48,7 +47,6 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
IPAdapterData,
StableDiffusionGeneratorPipeline,
T2IAdapterData,
image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
@@ -198,7 +196,7 @@ def get_scheduler(
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.3.0",
version="1.1.0",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
@@ -225,15 +223,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
input=Input.Connection,
ui_order=5,
)
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
ip_adapter: Optional[IPAdapterField] = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
)
t2i_adapter: Union[T2IAdapterField, list[T2IAdapterField]] = InputField(
description=FieldDescriptions.t2i_adapter, title="T2I-Adapter", default=None, input=Input.Connection, ui_order=7
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=8
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=7
)
@validator("cfg_scale")
@@ -409,150 +404,52 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_ip_adapter_data(
self,
context: InvocationContext,
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
ip_adapter: Optional[IPAdapterField],
conditioning_data: ConditioningData,
unet: UNet2DConditionModel,
exit_stack: ExitStack,
) -> Optional[list[IPAdapterData]]:
) -> Optional[IPAdapterData]:
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
to the `conditioning_data` (in-place).
"""
if ip_adapter is None:
return None
# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
if not isinstance(ip_adapter, list):
ip_adapter = [ip_adapter]
image_encoder_model_info = context.services.model_manager.get_model(
model_name=ip_adapter.image_encoder_model.model_name,
model_type=ModelType.CLIPVision,
base_model=ip_adapter.image_encoder_model.base_model,
context=context,
)
if len(ip_adapter) == 0:
return None
ip_adapter_data_list = []
conditioning_data.ip_adapter_conditioning = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=single_ip_adapter.ip_adapter_model.model_name,
model_type=ModelType.IPAdapter,
base_model=single_ip_adapter.ip_adapter_model.base_model,
context=context,
)
)
image_encoder_model_info = context.services.model_manager.get_model(
model_name=single_ip_adapter.image_encoder_model.model_name,
model_type=ModelType.CLIPVision,
base_model=single_ip_adapter.image_encoder_model.base_model,
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=ip_adapter.ip_adapter_model.model_name,
model_type=ModelType.IPAdapter,
base_model=ip_adapter.ip_adapter_model.base_model,
context=context,
)
)
input_image = context.services.images.get_pil_image(single_ip_adapter.image.image_name)
input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
input_image, image_encoder_model
)
conditioning_data.ip_adapter_conditioning.append(
IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds)
)
ip_adapter_data_list.append(
IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=single_ip_adapter.weight,
begin_step_percent=single_ip_adapter.begin_step_percent,
end_step_percent=single_ip_adapter.end_step_percent,
)
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
input_image, image_encoder_model
)
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
image_prompt_embeds, uncond_image_prompt_embeds
)
return ip_adapter_data_list
def run_t2i_adapters(
self,
context: InvocationContext,
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
latents_shape: list[int],
do_classifier_free_guidance: bool,
) -> Optional[list[T2IAdapterData]]:
if t2i_adapter is None:
return None
# Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField.
if isinstance(t2i_adapter, T2IAdapterField):
t2i_adapter = [t2i_adapter]
if len(t2i_adapter) == 0:
return None
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_info = context.services.model_manager.get_model(
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
model_type=ModelType.T2IAdapter,
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
context=context,
)
image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
max_unet_downscale = 8
elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL:
max_unet_downscale = 4
else:
raise ValueError(
f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
)
t2i_adapter_model: T2IAdapter
with t2i_adapter_model_info as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor
if isinstance(t2i_adapter_model.adapter, FullAdapterXL):
# HACK(ryand): Work around a bug in FullAdapterXL. This is being addressed upstream in diffusers by
# this PR: https://github.com/huggingface/diffusers/pull/5134.
total_downscale_factor = total_downscale_factor // 2
# Resize the T2I-Adapter input image.
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
# result will match the latent image's dimensions after max_unet_downscale is applied.
t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
# T2I-Adapter model.
#
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
# of the same requirements (e.g. preserving binary masks during resize).
t2i_image = prepare_control_image(
image=image,
do_classifier_free_guidance=False,
width=t2i_input_width,
height=t2i_input_height,
num_channels=t2i_adapter_model.config.in_channels,
device=t2i_adapter_model.device,
dtype=t2i_adapter_model.dtype,
resize_mode=t2i_adapter_field.resize_mode,
)
adapter_state = t2i_adapter_model(t2i_image)
if do_classifier_free_guidance:
for idx, value in enumerate(adapter_state):
adapter_state[idx] = torch.cat([value] * 2, dim=0)
t2i_adapter_data.append(
T2IAdapterData(
adapter_state=adapter_state,
weight=t2i_adapter_field.weight,
begin_step_percent=t2i_adapter_field.begin_step_percent,
end_step_percent=t2i_adapter_field.end_step_percent,
)
)
return t2i_adapter_data
return IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=ip_adapter.weight,
begin_step_percent=ip_adapter.begin_step_percent,
end_step_percent=ip_adapter.end_step_percent,
)
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
@@ -625,12 +522,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
mask, masked_latents = self.prep_inpaint_mask(context, latents)
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate.
t2i_adapter_data = self.run_t2i_adapters(
context, self.t2i_adapter, latents.shape, do_classifier_free_guidance=True
)
# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]
@@ -689,6 +580,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
context=context,
ip_adapter=self.ip_adapter,
conditioning_data=conditioning_data,
unet=unet,
exit_stack=exit_stack,
)
@@ -710,9 +602,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
masked_latents=masked_latents,
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
control_data=controlnet_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
control_data=controlnet_data, # list[ControlNetData],
ip_adapter_data=ip_adapter_data, # IPAdapterData,
callback=step_callback,
)


@@ -15,7 +15,6 @@ from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.ip_adapter import IPAdapterModelField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from ...version import __version__
@@ -64,7 +63,6 @@ class CoreMetadata(BaseModelExcludeNull):
model: MainModelField = Field(description="The main model used for inference")
controlnets: list[ControlField] = Field(description="The ControlNets used for inference")
ipAdapters: list[IPAdapterMetadataField] = Field(description="The IP Adapters used for inference")
t2iAdapters: list[T2IAdapterField] = Field(description="The T2I Adapters used for inference")
loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference")
vae: Optional[VAEModelField] = Field(
default=None,
@@ -141,7 +139,6 @@ class MetadataAccumulatorInvocation(BaseInvocation):
model: MainModelField = InputField(description="The main model used for inference")
controlnets: list[ControlField] = InputField(description="The ControlNets used for inference")
ipAdapters: list[IPAdapterMetadataField] = InputField(description="The IP Adapters used for inference")
t2iAdapters: list[T2IAdapterField] = Field(description="The T2I Adapters used for inference")
loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference")
strength: Optional[float] = InputField(
default=None,


@@ -1,83 +0,0 @@
from typing import Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.primitives import ImageField
from invokeai.backend.model_management.models.base import BaseModelType
class T2IAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the T2I-Adapter model")
base_model: BaseModelType = Field(description="Base model")
class T2IAdapterField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@invocation_output("t2i_adapter_output")
class T2IAdapterOutput(BaseInvocationOutput):
t2i_adapter: T2IAdapterField = OutputField(description=FieldDescriptions.t2i_adapter, title="T2I Adapter")
@invocation(
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.0"
)
class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes."""
# Inputs
image: ImageField = InputField(description="The T2I-Adapter image prompt.")
t2i_adapter_model: T2IAdapterModelField = InputField(
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
input=Input.Direct,
ui_order=-1,
)
weight: Union[float, list[float]] = InputField(
default=1, ge=0, description="The weight given to the T2I-Adapter", ui_type=UIType.Float, title="Weight"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(
default="just_resize",
description="The resize mode applied to the T2I-Adapter input image so that it matches the target output size.",
)
def invoke(self, context: InvocationContext) -> T2IAdapterOutput:
return T2IAdapterOutput(
t2i_adapter=T2IAdapterField(
image=self.image,
t2i_adapter_model=self.t2i_adapter_model,
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
resize_mode=self.resize_mode,
)
)
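For orientation, a hedged sketch of constructing these pydantic models directly, in a checkout where this module exists (the image name and model name below are purely illustrative):
```python
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.t2i_adapter import T2IAdapterField, T2IAdapterModelField
from invokeai.backend.model_management.models.base import BaseModelType

# Illustrative values only; "example-sketch.png" and "t2i-adapter-canny" are
# hypothetical asset/model names.
field = T2IAdapterField(
    image=ImageField(image_name="example-sketch.png"),
    t2i_adapter_model=T2IAdapterModelField(
        model_name="t2i-adapter-canny",
        base_model=BaseModelType.StableDiffusion1,
    ),
    weight=0.75,
    begin_step_percent=0.0,
    end_step_percent=1.0,
    resize_mode="just_resize",
)
```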


@@ -4,14 +4,12 @@ from typing import Literal
import cv2 as cv
import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from realesrgan import RealESRGANer
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.backend.util.devices import choose_torch_device
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
@@ -24,19 +22,13 @@ ESRGAN_MODELS = Literal[
"RealESRGAN_x2plus.pth",
]
if choose_torch_device() == torch.device("mps"):
from torch import mps
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.1.0")
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.0")
class ESRGANInvocation(BaseInvocation):
"""Upscales an image using RealESRGAN."""
image: ImageField = InputField(description="The input image")
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
tile_size: int = InputField(
default=400, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
@@ -94,11 +86,9 @@ class ESRGANInvocation(BaseInvocation):
model_path=str(models_path / esrgan_model_path),
model=rrdbnet_model,
half=False,
tile=self.tile_size,
)
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
# TODO: This strips the alpha... is that okay?
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
# We can pass an `outscale` value here, but it just resizes the image by that factor after
@@ -109,10 +99,6 @@ class ESRGANInvocation(BaseInvocation):
# back to PIL
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,


@@ -255,7 +255,6 @@ class InvokeAIAppConfig(InvokeAISettings):
attention_slice_size: Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
force_tiled_decode : bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
png_compress_level : int = Field(default=6, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = fastest, largest filesize, 9 = slowest, smallest filesize", category="Generation", )
# QUEUE
max_queue_size : int = Field(default=10000, gt=0, description="Maximum number of items in the session queue", category="Queue", )


@@ -4,12 +4,7 @@ from typing import Any, Optional
from invokeai.app.models.image import ProgressImage
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem
from invokeai.app.util.misc import get_timestamp
@@ -267,31 +262,21 @@ class EventServiceBase:
),
)
def emit_queue_item_status_changed(
self,
session_queue_item: SessionQueueItem,
batch_status: BatchStatus,
queue_status: SessionQueueStatus,
) -> None:
def emit_queue_item_status_changed(self, session_queue_item: SessionQueueItem) -> None:
"""Emitted when a queue item's status changes"""
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload=dict(
queue_id=queue_status.queue_id,
queue_item=dict(
queue_id=session_queue_item.queue_id,
item_id=session_queue_item.item_id,
status=session_queue_item.status,
batch_id=session_queue_item.batch_id,
session_id=session_queue_item.session_id,
error=session_queue_item.error,
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
),
batch_status=batch_status.dict(),
queue_status=queue_status.dict(),
queue_id=session_queue_item.queue_id,
queue_item_id=session_queue_item.item_id,
status=session_queue_item.status,
batch_id=session_queue_item.batch_id,
session_id=session_queue_item.session_id,
error=session_queue_item.error,
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
started_at=str(session_queue_item.started_at) if session_queue_item.started_at else None,
completed_at=str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
),
)


@@ -2,7 +2,7 @@
import copy
import itertools
from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
from typing import Annotated, Any, Optional, Union, cast, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import BaseModel, root_validator, validator
@@ -170,18 +170,6 @@ class NodeIdMismatchError(ValueError):
pass
class InvalidSubGraphError(ValueError):
pass
class CyclicalGraphError(ValueError):
pass
class UnknownGraphValidationError(ValueError):
pass
# TODO: Create and use an Empty output?
@invocation_output("graph_output")
class GraphInvocationOutput(BaseInvocationOutput):
@@ -266,6 +254,59 @@ class Graph(BaseModel):
default_factory=list,
)
@root_validator
def validate_nodes_and_edges(cls, values):
"""Validates that all edges match nodes in the graph"""
nodes = cast(Optional[dict[str, BaseInvocation]], values.get("nodes"))
edges = cast(Optional[list[Edge]], values.get("edges"))
if nodes is not None:
# Validate that all node ids are unique
node_ids = [n.id for n in nodes.values()]
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
if duplicate_node_ids:
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
# Validate that all node ids match the keys in the nodes dict
for k, v in nodes.items():
if k != v.id:
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
if edges is not None and nodes is not None:
# Validate that all edges match nodes in the graph
node_ids = set([e.source.node_id for e in edges] + [e.destination.node_id for e in edges])
missing_node_ids = [node_id for node_id in node_ids if node_id not in nodes]
if missing_node_ids:
raise NodeNotFoundError(
f"All edges must reference nodes in the graph, missing nodes: {missing_node_ids}"
)
# Validate that all edge fields match node fields in the graph
for edge in edges:
source_node = nodes.get(edge.source.node_id, None)
if source_node is None:
raise NodeFieldNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
destination_node = nodes.get(edge.destination.node_id, None)
if destination_node is None:
raise NodeFieldNotFoundError(
f"Edge destination node {edge.destination.node_id} does not exist in the graph"
)
# output fields are not on the node object directly, they are on the output type
if edge.source.field not in source_node.get_output_type().__fields__:
raise NodeFieldNotFoundError(
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
)
# input fields are on the node
if edge.destination.field not in destination_node.__fields__:
raise NodeFieldNotFoundError(
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
)
return values
def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph
@@ -336,108 +377,53 @@ class Graph(BaseModel):
except KeyError:
pass
def validate_self(self) -> None:
"""
Validates the graph.
Raises an exception if the graph is invalid:
- `DuplicateNodeIdError`
- `NodeIdMismatchError`
- `InvalidSubGraphError`
- `NodeNotFoundError`
- `NodeFieldNotFoundError`
- `CyclicalGraphError`
- `InvalidEdgeError`
"""
# Validate that all node ids are unique
node_ids = [n.id for n in self.nodes.values()]
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
if duplicate_node_ids:
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
# Validate that all node ids match the keys in the nodes dict
for k, v in self.nodes.items():
if k != v.id:
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
def is_valid(self) -> bool:
"""Validates the graph."""
# Validate all subgraphs
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
try:
gn.graph.validate_self()
except Exception as e:
raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e
if not gn.graph.is_valid():
return False
# Validate that all edges match nodes and fields in the graph
for edge in self.edges:
source_node = self.nodes.get(edge.source.node_id, None)
if source_node is None:
raise NodeNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
destination_node = self.nodes.get(edge.destination.node_id, None)
if destination_node is None:
raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph")
# output fields are not on the node object directly, they are on the output type
if edge.source.field not in source_node.get_output_type().__fields__:
raise NodeFieldNotFoundError(
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
)
# input fields are on the node
if edge.destination.field not in destination_node.__fields__:
raise NodeFieldNotFoundError(
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
)
# Validate all edges reference nodes in the graph
node_ids = set([e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges])
if not all((self.has_node(node_id) for node_id in node_ids)):
return False
# Validate there are no cycles
g = self.nx_graph_flat()
if not nx.is_directed_acyclic_graph(g):
raise CyclicalGraphError("Graph contains cycles")
return False
# Validate all edge connections are valid
for e in self.edges:
if not are_connections_compatible(
self.get_node(e.source.node_id),
e.source.field,
self.get_node(e.destination.node_id),
e.destination.field,
):
raise InvalidEdgeError(
f"Invalid edge from {e.source.node_id}.{e.source.field} to {e.destination.node_id}.{e.destination.field}"
if not all(
(
are_connections_compatible(
self.get_node(e.source.node_id),
e.source.field,
self.get_node(e.destination.node_id),
e.destination.field,
)
# Validate all iterators & collectors
# TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available
for n in self.nodes.values():
if isinstance(n, IterateInvocation) and not self._is_iterator_connection_valid(n.id):
raise InvalidEdgeError(f"Invalid iterator node {n.id}")
if isinstance(n, CollectInvocation) and not self._is_collector_connection_valid(n.id):
raise InvalidEdgeError(f"Invalid collector node {n.id}")
return None
def is_valid(self) -> bool:
"""
Checks if the graph is valid.
Raises `UnknownGraphValidationError` if there is a problem validating the graph (not a validation error).
"""
try:
self.validate_self()
return True
except (
DuplicateNodeIdError,
NodeIdMismatchError,
InvalidSubGraphError,
NodeNotFoundError,
NodeFieldNotFoundError,
CyclicalGraphError,
InvalidEdgeError,
for e in self.edges
)
):
return False
except Exception as e:
raise UnknownGraphValidationError(f"Problem validating graph {e}") from e
# Validate all iterators
# TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
if not all(
(self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation))
):
return False
# Validate all collectors
# TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
if not all(
(self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation))
):
return False
return True
def _validate_edge(self, edge: Edge):
"""Validates that a new edge doesn't create a cycle in the graph"""
@@ -818,12 +804,6 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
@validator("graph")
def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid"""
v.validate_self()
return v
class Config:
schema_extra = {
"required": [


@@ -9,7 +9,6 @@ from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType
from send2trash import send2trash
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@@ -80,7 +79,6 @@ class DiskImageFileStorage(ImageFileStorageBase):
__cache_ids: Queue # TODO: this is an incredibly naive cache
__cache: Dict[Path, PILImageType]
__max_cache_size: int
__compress_level: int
def __init__(self, output_folder: Union[str, Path]):
self.__cache = dict()
@@ -89,7 +87,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
self.__output_folder: Path = output_folder if isinstance(output_folder, Path) else Path(output_folder)
self.__thumbnails_folder = self.__output_folder / "thumbnails"
self.__compress_level = InvokeAIAppConfig.get_config().png_compress_level
# Validate required output folders at launch
self.__validate_storage_folders()
@@ -136,7 +134,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
if original_workflow is not None:
pnginfo.add_text("invokeai_workflow", original_workflow)
image.save(image_path, "PNG", pnginfo=pnginfo, compress_level=self.__compress_level)
image.save(image_path, "PNG", pnginfo=pnginfo)
thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(thumbnail_name, thumbnail=True)


@@ -1,4 +1,3 @@
import traceback
from threading import BoundedSemaphore
from threading import Event as ThreadEvent
from threading import Thread
@@ -124,10 +123,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
continue
except Exception as e:
self.__invoker.services.logger.error(f"Error in session processor: {e}")
if queue_item is not None:
self.__invoker.services.session_queue.cancel_queue_item(
queue_item.item_id, error=traceback.format_exc()
)
poll_now_event.wait(POLLING_INTERVAL)
continue
except Exception as e:


@@ -80,7 +80,7 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
"""Cancels a session queue item"""
pass


@@ -123,11 +123,6 @@ class Batch(BaseModel):
raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}")
return values
@validator("graph")
def validate_graph(cls, v: Graph):
v.validate_self()
return v
class Config:
schema_extra = {
"required": [


@@ -427,13 +427,7 @@ class SqliteSessionQueue(SessionQueueBase):
finally:
self.__lock.release()
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
@@ -561,11 +555,10 @@ class SqliteSessionQueue(SessionQueueBase):
self.__lock.release()
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["canceled", "failed", "completed"]:
status = "failed" if error is not None else "canceled"
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error)
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
self.__invoker.services.queue.cancel(queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id,
@@ -615,13 +608,7 @@ class SqliteSessionQueue(SessionQueueBase):
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=current_queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
except Exception:
self.__conn.rollback()
raise
@@ -667,13 +654,7 @@ class SqliteSessionQueue(SessionQueueBase):
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=current_queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
except Exception:
self.__conn.rollback()
raise


@@ -265,41 +265,22 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
def prepare_control_image(
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
# but now should be able to assume that image is a single PIL.Image, which simplifies things
image: Image,
width: int,
height: int,
num_channels: int = 3,
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
# latents_to_match_resolution, # TorchTensor of shape (batch_size, 3, height, width)
width=512, # should be 8 * latent.shape[3]
height=512, # should be 8 * latent height[2]
# batch_size=1, # currently no batching
# num_images_per_prompt=1, # currently only single image
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced",
resize_mode="just_resize_simple",
):
"""Pre-process images for ControlNets or T2I-Adapters.
Args:
image (Image): The PIL image to pre-process.
width (int): The target width in pixels.
height (int): The target height in pixels.
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
device (str, optional): The target device for the output image. Defaults to "cuda".
dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
Defaults to True.
control_mode (str, optional): Defaults to "balanced".
resize_mode (str, optional): Defaults to "just_resize_simple".
Raises:
NotImplementedError: If resize_mode == "crop_resize_simple".
NotImplementedError: If resize_mode == "fill_resize_simple".
ValueError: If `resize_mode` is not recognized.
ValueError: If `num_channels` is out of range.
Returns:
torch.Tensor: The pre-processed input tensor.
"""
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
if (
resize_mode == "just_resize_simple"
or resize_mode == "crop_resize_simple"
@@ -308,10 +289,10 @@ def prepare_control_image(
image = image.convert("RGB")
if resize_mode == "just_resize_simple":
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
elif resize_mode == "crop_resize_simple":
raise NotImplementedError(f"prepare_control_image is not implemented for resize_mode='{resize_mode}'.")
elif resize_mode == "fill_resize_simple":
raise NotImplementedError(f"prepare_control_image is not implemented for resize_mode='{resize_mode}'.")
elif resize_mode == "crop_resize_simple": # not yet implemented
pass
elif resize_mode == "fill_resize_simple": # not yet implemented
pass
nimage = np.array(image)
nimage = nimage[None, :]
nimage = np.concatenate([nimage], axis=0)
@@ -332,11 +313,9 @@ def prepare_control_image(
device=device,
)
else:
raise ValueError(f"Unsupported resize_mode: '{resize_mode}'.")
if timage.shape[1] < num_channels or num_channels <= 0:
raise ValueError(f"Cannot achieve the target of num_channels={num_channels}.")
timage = timage[:, :num_channels, :, :]
pass
print("ERROR: invalid resize_mode ==> ", resize_mode)
exit(1)
timage = timage.to(device=device, dtype=dtype)
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
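A rough usage sketch that follows the documented signature in this hunk (the condition image and target size are illustrative; `device`/`dtype` are chosen so the call can run on CPU):
```python
import torch
from PIL import Image

from invokeai.app.util.controlnet_utils import prepare_control_image

# Illustrative stand-in for a real ControlNet/T2I-Adapter condition image.
condition_image = Image.new("RGB", (768, 768))

control_tensor = prepare_control_image(
    image=condition_image,
    width=512,
    height=512,
    num_channels=3,
    device="cpu",
    dtype=torch.float32,
    do_classifier_free_guidance=True,  # per the docstring, repeats the batch dim
    control_mode="balanced",
    resize_mode="just_resize_simple",
)
# A torch.Tensor sized for the UNet; with CFG enabled the batch dimension is doubled.
print(control_tensor.shape)
```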


@@ -335,7 +335,7 @@ class ModelInstall(object):
# list all the files in the repo
files = [x.rfilename for x in hinfo.siblings]
if subfolder:
files = [x for x in files if x.startswith(f"{subfolder}/")]
files = [x for x in files if x.startswith("v2/")]
prefix = f"{subfolder}/" if subfolder else ""
location = None


@@ -8,8 +8,6 @@ import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionProcessorWeights
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
# loading.
@@ -47,16 +45,18 @@ class IPAttnProcessor2_0(torch.nn.Module):
the weight scale of image prompt.
"""
def __init__(self, weights: list[IPAttentionProcessorWeights], scales: list[float]):
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
assert len(weights) == len(scales)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self._weights = weights
self._scales = scales
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
@@ -67,6 +67,16 @@ class IPAttnProcessor2_0(torch.nn.Module):
temb=None,
ip_adapter_image_prompt_embeds=None,
):
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
# The batch dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
ip_hidden_states = ip_adapter_image_prompt_embeds
residual = hidden_states
if attn.spatial_norm is not None:
@@ -118,36 +128,23 @@ class IPAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
assert len(ip_adapter_image_prompt_embeds) == len(self._weights)
if ip_hidden_states is not None:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ipa_embed.shape[2] == encoder_hidden_states.shape[2]
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_hidden_states = ipa_embed
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# The output of sdpa has shape: (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + scale * ip_hidden_states
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)


@@ -1,15 +1,17 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
from contextlib import contextmanager
from typing import Optional, Union
import torch
from diffusers.models import UNet2DConditionModel
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.model_management.models.base import calc_model_size_by_data
from .attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from .resampler import Resampler
@@ -59,7 +61,7 @@ class IPAdapter:
def __init__(
self,
state_dict: dict[str, torch.Tensor],
state_dict: dict[torch.Tensor],
device: torch.device,
dtype: torch.dtype = torch.float16,
num_tokens: int = 4,
@@ -71,11 +73,12 @@ class IPAdapter:
self._clip_image_processor = CLIPImageProcessor()
self._image_proj_model = self._init_image_proj_model(state_dict["image_proj"])
self._state_dict = state_dict
self.attn_weights = IPAttentionWeights.from_state_dict(state_dict["ip_adapter"]).to(
self.device, dtype=self.dtype
)
self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
# The _attn_processors will be initialized later when we have access to the UNet.
self._attn_processors = None
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
self.device = device
@@ -83,14 +86,99 @@ class IPAdapter:
self.dtype = dtype
self._image_proj_model.to(device=self.device, dtype=self.dtype)
self.attn_weights.to(device=self.device, dtype=self.dtype)
if self._attn_processors is not None:
torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)
def calc_size(self):
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
if self._state_dict is not None:
image_proj_size = sum(
[tensor.nelement() * tensor.element_size() for tensor in self._state_dict["image_proj"].values()]
)
ip_adapter_size = sum(
[tensor.nelement() * tensor.element_size() for tensor in self._state_dict["ip_adapter"].values()]
)
return image_proj_size + ip_adapter_size
else:
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(
torch.nn.ModuleList(self._attn_processors.values())
)
def _init_image_proj_model(self, state_dict):
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can later be injected into a unet, and load the IP-Adapter
attention weights into them.
Note that the `unet` param is only used to determine attention block dimensions and naming.
TODO(ryand): As a future improvement, this could all be inferred from the state_dict when the IPAdapter is
initialized.
"""
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor2_0()
else:
attn_procs[name] = IPAttnProcessor2_0(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
).to(self.device, dtype=self.dtype)
ip_layers = torch.nn.ModuleList(attn_procs.values())
ip_layers.load_state_dict(self._state_dict["ip_adapter"])
self._attn_processors = attn_procs
self._state_dict = None
# @genomancer: pushed scaling back out into its own method (like original Tencent implementation)
# which makes implementing begin_step_percent and end_step_percent easier
# but based on self._attn_processors (ala @Ryan) instead of original Tencent unet.attn_processors,
# which should make it easier to implement multiple IPAdapters
def set_scale(self, scale):
if self._attn_processors is not None:
for attn_processor in self._attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor2_0):
attn_processor.scale = scale
@contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel, scale: float):
"""A context manager that patches `unet` with this IP-Adapter's attention processors while it is active.
Yields:
None
"""
if self._attn_processors is None:
# We only have to call _prepare_attention_processors(...) once, and then the result is cached and can be
# used on any UNet model (with the same dimensions).
self._prepare_attention_processors(unet)
# Set scale
self.set_scale(scale)
# for attn_processor in self._attn_processors.values():
# if isinstance(attn_processor, IPAttnProcessor2_0):
# attn_processor.scale = scale
orig_attn_processors = unet.attn_processors
# Make a (moderately-) shallow copy of the self._attn_processors dict, because unet.set_attn_processor(...)
# actually pops elements from the passed dict.
ip_adapter_attn_processors = {k: v for k, v in self._attn_processors.items()}
try:
unet.set_attn_processor(ip_adapter_attn_processors)
yield None
finally:
unet.set_attn_processor(orig_attn_processors)
@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
@@ -130,20 +218,6 @@ class IPAdapterPlus(IPAdapter):
return image_prompt_embeds, uncond_image_prompt_embeds
class IPAdapterPlusXL(IPAdapterPlus):
"""IP-Adapter Plus for SDXL."""
def _init_image_proj_model(self, state_dict):
return Resampler.from_state_dict(
state_dict=state_dict,
depth=4,
dim_head=64,
heads=20,
num_queries=self._num_tokens,
ff_mult=4,
).to(self.device, dtype=self.dtype)
def build_ip_adapter(
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus]:
@@ -154,14 +228,6 @@ def build_ip_adapter(
is_plus = "proj.weight" not in state_dict["image_proj"]
if is_plus:
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
# SD1 IP-Adapter Plus
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
elif cross_attention_dim == 2048:
# SDXL IP-Adapter Plus
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
else:
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
else:
return IPAdapter(state_dict, device=device, dtype=dtype)


@@ -1,46 +0,0 @@
import torch
class IPAttentionProcessorWeights(torch.nn.Module):
"""The IP-Adapter weights for a single attention processor.
This class is a torch.nn.Module sub-class to facilitate loading from a state_dict. It does not have a forward(...)
method.
"""
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.to_k_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
self.to_v_ip = torch.nn.Linear(in_dim, out_dim, bias=False)
class IPAttentionWeights(torch.nn.Module):
"""A collection of all the `IPAttentionProcessorWeights` objects for an IP-Adapter model.
This class is a torch.nn.Module sub-class so that it inherits the `.to(...)` functionality. It does not have a
forward(...) method.
"""
def __init__(self, weights: torch.nn.ModuleDict):
super().__init__()
self._weights = weights
def get_attention_processor_weights(self, idx: int) -> IPAttentionProcessorWeights:
"""Get the `IPAttentionProcessorWeights` for the idx'th attention processor."""
# Cast to int first, because we expect the key to represent an int. Then cast back to str, because
# `torch.nn.ModuleDict` only supports str keys.
return self._weights[str(int(idx))]
@classmethod
def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
attn_proc_weights: dict[str, IPAttentionProcessorWeights] = {}
for tensor_name, tensor in state_dict.items():
if "to_k_ip.weight" in tensor_name:
index = str(int(tensor_name.split(".")[0]))
attn_proc_weights[index] = IPAttentionProcessorWeights(tensor.shape[1], tensor.shape[0])
attn_proc_weights_module = torch.nn.ModuleDict(attn_proc_weights)
attn_proc_weights_module.load_state_dict(state_dict)
return cls(attn_proc_weights_module)

View File

@@ -1,53 +0,0 @@
from contextlib import contextmanager
from diffusers.models import UNet2DConditionModel
from invokeai.backend.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
class UNetPatcher:
"""A class that contains multiple IP-Adapters and can apply them to a UNet."""
def __init__(self, ip_adapters: list[IPAdapter]):
self._ip_adapters = ip_adapters
self._scales = [1.0] * len(self._ip_adapters)
def set_scale(self, idx: int, value: float):
self._scales[idx] = value
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention
weights into them.
Note that the `unet` param is only used to determine attention block dimensions and naming.
"""
# Construct a dict of attention processors based on the UNet's architecture.
attn_procs = {}
for idx, name in enumerate(unet.attn_processors.keys()):
if name.endswith("attn1.processor"):
attn_procs[name] = AttnProcessor2_0()
else:
# Collect the weights from each IP Adapter for the idx'th attention processor.
attn_procs[name] = IPAttnProcessor2_0(
[ip_adapter.attn_weights.get_attention_processor_weights(idx) for ip_adapter in self._ip_adapters],
self._scales,
)
return attn_procs
@contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel):
"""A context manager that patches `unet` with IP-Adapter attention processors."""
attn_procs = self._prepare_attention_processors(unet)
orig_attn_processors = unet.attn_processors
try:
# Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from the
# passed dict. So, if you wanted to keep the dict for future use, you'd have to make a moderately-shallow copy
# of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
unet.set_attn_processor(attn_procs)
yield None
finally:
unet.set_attn_processor(orig_attn_processors)

View File

@@ -1,27 +0,0 @@
# Model Cache
## `glibc` Memory Allocator Fragmentation
Python (and PyTorch) relies on the memory allocator from the C Standard Library (`libc`). On Linux, with the GNU C Standard Library implementation (`glibc`), our memory access patterns have been observed to cause severe memory fragmentation. This fragmentation results in large amounts of memory that has been freed but can't be released back to the OS. Loading models from disk and moving them between CPU/CUDA seem to be the operations that contribute most to the fragmentation. This memory fragmentation issue can result in OOM crashes during frequent model switching, even if `max_cache_size` is set to a reasonable value (e.g. an OOM crash with `max_cache_size=16` on a system with 32GB of RAM).
This problem may also exist on other OSes and other `libc` implementations, but at the time of writing it has only been investigated on Linux with `glibc`.
To better understand how the `glibc` memory allocator works, see these references:
- Basics: https://www.gnu.org/software/libc/manual/html_node/The-GNU-Allocator.html
- Details: https://sourceware.org/glibc/wiki/MallocInternals
Note the differences between memory allocated as chunks in an arena vs. memory allocated with `mmap`. Under `glibc`'s default configuration, most model tensors get allocated as chunks in an arena, making them vulnerable to fragmentation.
We can work around this memory fragmentation issue by setting the following env var:
```bash
# Force blocks >1MB to be allocated with `mmap` so that they are released to the system immediately when they are freed.
MALLOC_MMAP_THRESHOLD_=1048576
```
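For example, a minimal way to apply this when launching the server from a shell (assuming the `invokeai-web` entry point; substitute whatever command you normally use to start InvokeAI):
```bash
# Hypothetical launch command; adjust to however you normally start InvokeAI.
MALLOC_MMAP_THRESHOLD_=1048576 invokeai-web
```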
See the following references for more information about the `malloc` tunable parameters:
- https://www.gnu.org/software/libc/manual/html_node/Malloc-Tunable-Parameters.html
- https://www.gnu.org/software/libc/manual/html_node/Memory-Allocation-Tunables.html
- https://man7.org/linux/man-pages/man3/mallopt.3.html
The model cache emits debug logs that provide visibility into the state of the `libc` memory allocator. See the `LibcUtil` class for more info on how these `libc` malloc stats are collected.
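As a rough sketch of how these stats are gathered (assuming a checkout where the `MemorySnapshot` helpers shown later in this diff are present), you can capture snapshots around an operation and print the allocator diff yourself:
```python
# Minimal sketch: capture libc/VRAM stats before and after an operation and print the diff.
# Assumes the memory_snapshot module shown elsewhere in this diff exists in your checkout.
from invokeai.backend.model_management.memory_snapshot import (
    MemorySnapshot,
    get_pretty_snapshot_diff,
)

snapshot_before = MemorySnapshot.capture()
# ... load a model or move it between devices here ...
snapshot_after = MemorySnapshot.capture()
print(get_pretty_snapshot_diff(snapshot_before, snapshot_after))
```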

View File

@@ -1,75 +0,0 @@
import ctypes
class Struct_mallinfo2(ctypes.Structure):
"""A ctypes Structure that matches the libc mallinfo2 struct.
Docs:
- https://man7.org/linux/man-pages/man3/mallinfo.3.html
- https://www.gnu.org/software/libc/manual/html_node/Statistics-of-Malloc.html
struct mallinfo2 {
size_t arena; /* Non-mmapped space allocated (bytes) */
size_t ordblks; /* Number of free chunks */
size_t smblks; /* Number of free fastbin blocks */
size_t hblks; /* Number of mmapped regions */
size_t hblkhd; /* Space allocated in mmapped regions (bytes) */
size_t usmblks; /* See below */
size_t fsmblks; /* Space in freed fastbin blocks (bytes) */
size_t uordblks; /* Total allocated space (bytes) */
size_t fordblks; /* Total free space (bytes) */
size_t keepcost; /* Top-most, releasable space (bytes) */
};
"""
_fields_ = [
("arena", ctypes.c_size_t),
("ordblks", ctypes.c_size_t),
("smblks", ctypes.c_size_t),
("hblks", ctypes.c_size_t),
("hblkhd", ctypes.c_size_t),
("usmblks", ctypes.c_size_t),
("fsmblks", ctypes.c_size_t),
("uordblks", ctypes.c_size_t),
("fordblks", ctypes.c_size_t),
("keepcost", ctypes.c_size_t),
]
def __str__(self):
s = ""
s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n"
s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n"
s += f"{'smblks': <10}= {(self.smblks): >15} # Number of free fastbin blocks \n"
s += f"{'hblks': <10}= {(self.hblks): >15} # Number of mmapped regions \n"
s += f"{'hblkhd': <10}= {(self.hblkhd/2**30):15.5f} # Space allocated in mmapped regions (GB)\n"
s += f"{'usmblks': <10}= {(self.usmblks): >15} # Unused\n"
s += f"{'fsmblks': <10}= {(self.fsmblks/2**30):15.5f} # Space in freed fastbin blocks (GB)\n"
s += (
f"{'uordblks': <10}= {(self.uordblks/2**30):15.5f} # Space used by in-use allocations (non-mmapped)"
" (GB)\n"
)
s += f"{'fordblks': <10}= {(self.fordblks/2**30):15.5f} # Space in free blocks (non-mmapped) (GB)\n"
s += f"{'keepcost': <10}= {(self.keepcost/2**30):15.5f} # Top-most, releasable space (GB)\n"
return s
class LibcUtil:
"""A utility class for interacting with the C Standard Library (`libc`) via ctypes.
Note that this class will raise on __init__() if 'libc.so.6' can't be found. Take care to handle environments where
this shared library is not available.
TODO: Improve cross-OS compatibility of this class.
"""
def __init__(self):
self._libc = ctypes.cdll.LoadLibrary("libc.so.6")
def mallinfo2(self) -> Struct_mallinfo2:
"""Calls `libc` `mallinfo2`.
Docs: https://man7.org/linux/man-pages/man3/mallinfo.3.html
"""
mallinfo2 = self._libc.mallinfo2
mallinfo2.restype = Struct_mallinfo2
return mallinfo2()

View File

@@ -1,94 +0,0 @@
import gc
from typing import Optional
import psutil
import torch
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
GB = 2**30 # 1 GB
class MemorySnapshot:
"""A snapshot of RAM and VRAM usage. All values are in bytes."""
def __init__(self, process_ram: int, vram: Optional[int], malloc_info: Optional[Struct_mallinfo2]):
"""Initialize a MemorySnapshot.
Most of the time, `MemorySnapshot` will be constructed with `MemorySnapshot.capture()`.
Args:
process_ram (int): CPU RAM used by the current process.
vram (Optional[int]): VRAM used by torch.
malloc_info (Optional[Struct_mallinfo2]): Malloc info obtained from LibcUtil.
"""
self.process_ram = process_ram
self.vram = vram
self.malloc_info = malloc_info
@classmethod
def capture(cls, run_garbage_collector: bool = True):
"""Capture and return a MemorySnapshot.
Note: This function has significant overhead, particularly if `run_garbage_collector == True`.
Args:
run_garbage_collector (bool, optional): If true, gc.collect() will be run before checking the process RAM
usage. Defaults to True.
Returns:
MemorySnapshot
"""
if run_garbage_collector:
gc.collect()
# According to the psutil docs (https://psutil.readthedocs.io/en/latest/#psutil.Process.memory_info), rss is
# supported on all platforms.
process_ram = psutil.Process().memory_info().rss
if torch.cuda.is_available():
vram = torch.cuda.memory_allocated()
else:
# TODO: We could add support for mps.current_allocated_memory() as well. Leaving out for now until we have
# time to test it properly.
vram = None
try:
malloc_info = LibcUtil().mallinfo2()
except OSError:
# This is expected in environments that do not have the 'libc.so.6' shared library.
malloc_info = None
return cls(process_ram, vram, malloc_info)
def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str:
"""Get a pretty string describing the difference between two `MemorySnapshot`s."""
def get_msg_line(prefix: str, val1: int, val2: int):
diff = val2 - val1
return f"{prefix: <30} ({(diff/GB):+5.3f}): {(val1/GB):5.3f}GB -> {(val2/GB):5.3f}GB\n"
msg = ""
msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram)
if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
msg += get_msg_line("libc mmap allocated", snapshot_1.malloc_info.hblkhd, snapshot_2.malloc_info.hblkhd)
msg += get_msg_line("libc arena used", snapshot_1.malloc_info.uordblks, snapshot_2.malloc_info.uordblks)
msg += get_msg_line("libc arena free", snapshot_1.malloc_info.fordblks, snapshot_2.malloc_info.fordblks)
libc_total_allocated_1 = snapshot_1.malloc_info.arena + snapshot_1.malloc_info.hblkhd
libc_total_allocated_2 = snapshot_2.malloc_info.arena + snapshot_2.malloc_info.hblkhd
msg += get_msg_line("libc total allocated", libc_total_allocated_1, libc_total_allocated_2)
libc_total_used_1 = snapshot_1.malloc_info.uordblks + snapshot_1.malloc_info.hblkhd
libc_total_used_2 = snapshot_2.malloc_info.uordblks + snapshot_2.malloc_info.hblkhd
msg += get_msg_line("libc total used", libc_total_used_1, libc_total_used_2)
if snapshot_1.vram is not None and snapshot_2.vram is not None:
msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
return msg

View File

@@ -18,10 +18,8 @@ context. Use like this:
import gc
import hashlib
import math
import os
import sys
import time
from contextlib import suppress
from dataclasses import dataclass, field
from pathlib import Path
@@ -30,8 +28,6 @@ from typing import Any, Dict, Optional, Type, Union, types
import torch
import invokeai.backend.util.logging as logger
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_management.model_load_optimizations import skip_torch_weight_init
from ..util.devices import choose_torch_device
from .models import BaseModelType, ModelBase, ModelType, SubModelType
@@ -48,8 +44,6 @@ DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
# actual size of a gig
GIG = 1073741824
# Size of a MB in bytes.
MB = 2**20
@dataclass
@@ -211,44 +205,22 @@ class ModelCache(object):
cache_entry = self._cached_models.get(key, None)
if cache_entry is None:
self.logger.info(
f"Loading model {model_path}, type"
f" {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
f"Loading model {model_path}, type {base_model.value}:{model_type.value}{':'+submodel.value if submodel else ''}"
)
if self.stats:
self.stats.misses += 1
self_reported_model_size_before_load = model_info.get_size(submodel)
# Remove old models from the cache to make room for the new model.
self._make_cache_room(self_reported_model_size_before_load)
# this will remove older cached models until
# there is sufficient room to load the requested model
self._make_cache_room(model_info.get_size(submodel))
# Load the model from disk and capture a memory snapshot before/after.
start_load_time = time.time()
snapshot_before = MemorySnapshot.capture()
with skip_torch_weight_init():
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
snapshot_after = MemorySnapshot.capture()
end_load_time = time.time()
# clean memory to make MemoryUsage() more accurate
gc.collect()
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
if mem_used := model_info.get_size(submodel):
self.logger.debug(f"CPU RAM used for load: {(mem_used/GIG):.2f} GB")
self_reported_model_size_after_load = model_info.get_size(submodel)
self.logger.debug(
f"Moved model '{key}' from disk to cpu in {(end_load_time-start_load_time):.2f}s.\n"
f"Self-reported size before/after load: {(self_reported_model_size_before_load/GIG):.3f}GB /"
f" {(self_reported_model_size_after_load/GIG):.3f}GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
# We only log a warning for over-reported (not under-reported) model sizes before load. There is a known
# issue where models report their fp32 size before load, and are then loaded as fp16. Once this issue is
# addressed, it would make sense to log a warning for both over-reported and under-reported model sizes.
if (self_reported_model_size_after_load - self_reported_model_size_before_load) > 10 * MB:
self.logger.warning(
f"Model '{key}' mis-reported its size before load. Self-reported size before/after load:"
f" {(self_reported_model_size_before_load/GIG):.2f}GB /"
f" {(self_reported_model_size_after_load/GIG):.2f}GB."
)
cache_entry = _CacheRecord(self, model, self_reported_model_size_after_load)
cache_entry = _CacheRecord(self, model, mem_used)
self._cached_models[key] = cache_entry
else:
if self.stats:
@@ -268,45 +240,6 @@ class ModelCache(object):
return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
def _move_model_to_device(self, key: str, target_device: torch.device):
cache_entry = self._cached_models[key]
source_device = cache_entry.model.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'. This would need to be revised to support
# multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
start_model_to_time = time.time()
snapshot_before = MemorySnapshot.capture()
cache_entry.model.to(target_device)
snapshot_after = MemorySnapshot.capture()
end_model_to_time = time.time()
self.logger.debug(
f"Moved model '{key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n"
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if snapshot_before.vram is not None and snapshot_after.vram is not None:
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.
if not math.isclose(
vram_change,
cache_entry.size,
rel_tol=0.1,
abs_tol=10 * MB,
):
self.logger.warning(
f"Moving model '{key}' from {source_device} to"
f" {target_device} caused an unexpected change in VRAM usage. The model's"
" estimated size may be incorrect. Estimated model size:"
f" {(cache_entry.size/GIG):.3f} GB.\n"
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
class ModelLocker(object):
def __init__(self, cache, key, model, gpu_load, size_needed):
"""
@@ -336,7 +269,11 @@ class ModelCache(object):
if self.cache.lazy_offloading:
self.cache._offload_unlocked_models(self.size_needed)
self.cache._move_model_to_device(self.key, self.cache.execution_device)
if self.model.device != self.cache.execution_device:
self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
with VRAMUsage() as mem:
self.model.to(self.cache.execution_device) # move into GPU
self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
self.cache._print_cuda_stats()
@@ -349,7 +286,7 @@ class ModelCache(object):
# in the event that the caller wants the model in RAM, we
# move it into CPU if it is in GPU and not locked
elif self.cache_entry.loaded and not self.cache_entry.locked:
self.cache._move_model_to_device(self.key, self.cache.storage_device)
self.model.to(self.cache.storage_device)
return self.model
@@ -402,8 +339,7 @@ class ModelCache(object):
locked_models += 1
self.logger.debug(
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ ="
f" {cached_models}/{loaded_models}/{locked_models}"
f"Current VRAM/RAM usage: {vram}/{ram}; cached_models/loaded_models/locked_models/ = {cached_models}/{loaded_models}/{locked_models}"
)
def _cache_size(self) -> int:
@@ -418,8 +354,7 @@ class ModelCache(object):
if current_size + bytes_needed > maximum_size:
self.logger.debug(
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional"
f" {(bytes_needed/GIG):.2f} GB"
f"Max cache size exceeded: {(current_size/GIG):.2f}/{self.max_cache_size:.2f} GB, need an additional {(bytes_needed/GIG):.2f} GB"
)
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
@@ -452,8 +387,7 @@ class ModelCache(object):
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
f" refs: {refs}"
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}, refs: {refs}"
)
# 2 refs:
@@ -489,9 +423,11 @@ class ModelCache(object):
if vram_in_use <= reserved:
break
if not cache_entry.locked and cache_entry.loaded:
self._move_model_to_device(model_key, self.storage_device)
vram_in_use = torch.cuda.memory_allocated()
self.logger.debug(f"Offloading {model_key} from {self.execution_device} into {self.storage_device}")
with VRAMUsage() as mem:
cache_entry.model.to(self.storage_device)
self.logger.debug(f"GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB")
vram_in_use += mem.vram_used # note vram_used is negative
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
gc.collect()
@@ -518,3 +454,16 @@ class ModelCache(object):
with open(hashpath, "w") as f:
f.write(hash)
return hash
class VRAMUsage(object):
def __init__(self):
self.vram = None
self.vram_used = 0
def __enter__(self):
self.vram = torch.cuda.memory_allocated()
return self
def __exit__(self, *args):
self.vram_used = torch.cuda.memory_allocated() - self.vram

View File

@@ -1,30 +0,0 @@
from contextlib import contextmanager
import torch
def _no_op(*args, **kwargs):
pass
@contextmanager
def skip_torch_weight_init():
"""A context manager that monkey-patches several of the common torch layers (torch.nn.Linear, torch.nn.Conv1d, etc.)
to skip weight initialization.
By default, `torch.nn.Linear` and `torch.nn.ConvNd` layers initialize their weights (according to a particular
distribution) when __init__ is called. This weight initialization step can take a significant amount of time, and is
completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
monkey-patches common torch layers to skip the weight initialization step.
"""
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
saved_functions = [m.reset_parameters for m in torch_modules]
try:
for torch_module in torch_modules:
torch_module.reset_parameters = _no_op
yield None
finally:
for torch_module, saved_function in zip(torch_modules, saved_functions):
torch_module.reset_parameters = saved_function

View File

@@ -57,7 +57,6 @@ class ModelProbe(object):
"AutoencoderTiny": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
}
@classmethod
@@ -409,11 +408,6 @@ class CLIPVisionCheckpointProbe(CheckpointProbeBase):
raise NotImplementedError()
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
########################################################
# classes for probing folders
#######################################################
@@ -601,26 +595,6 @@ class CLIPVisionFolderProbe(FolderProbeBase):
return BaseModelType.Any
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.folder_path / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file, "r") as file:
config = json.load(file)
adapter_type = config.get("adapter_type", None)
if adapter_type == "full_adapter_xl":
return BaseModelType.StableDiffusionXL
elif adapter_type == "full_adapter" or "light_adapter":
# I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
return BaseModelType.StableDiffusion1
else:
raise InvalidModelException(
f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})."
)
############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
@@ -629,7 +603,6 @@ ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInvers
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
@@ -638,6 +611,5 @@ ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInver
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@@ -25,7 +25,6 @@ from .lora import LoRAModel
from .sdxl import StableDiffusionXLModel
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
from .t2i_adapter import T2IAdapterModel
from .textual_inversion import TextualInversionModel
from .vae import VaeModel
@@ -39,7 +38,6 @@ MODEL_CLASSES = {
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.StableDiffusion2: {
ModelType.ONNX: ONNXStableDiffusion2Model,
@@ -50,7 +48,6 @@ MODEL_CLASSES = {
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.StableDiffusionXL: {
ModelType.Main: StableDiffusionXLModel,
@@ -62,7 +59,6 @@ MODEL_CLASSES = {
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.StableDiffusionXLRefiner: {
ModelType.Main: StableDiffusionXLModel,
@@ -74,7 +70,6 @@ MODEL_CLASSES = {
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
BaseModelType.Any: {
ModelType.CLIPVision: CLIPVisionModel,
@@ -86,7 +81,6 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.T2IAdapter: T2IAdapterModel,
},
# BaseModelType.Kandinsky2_1: {
# ModelType.Main: Kandinsky2_1Model,

View File

@@ -53,7 +53,6 @@ class ModelType(str, Enum):
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
class SubModelType(str, Enum):

View File

@@ -1,102 +0,0 @@
import os
from enum import Enum
from typing import Literal, Optional
import torch
from diffusers import T2IAdapter
from invokeai.backend.model_management.models.base import (
BaseModelType,
EmptyConfigLoader,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelNotFoundException,
ModelType,
SubModelType,
calc_model_size_by_data,
calc_model_size_by_fs,
classproperty,
)
class T2IAdapterModelFormat(str, Enum):
Diffusers = "diffusers"
class T2IAdapterModel(ModelBase):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[T2IAdapterModelFormat.Diffusers]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.T2IAdapter
super().__init__(model_path, base_model, model_type)
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
model_class_name = config.get("_class_name", None)
if model_class_name not in {"T2IAdapter"}:
raise InvalidModelException(f"Invalid T2I-Adapter model. Unknown _class_name: '{model_class_name}'.")
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> T2IAdapter:
if child_type is not None:
raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.")
model = None
for variant in ["fp16", None]:
try:
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
variant=variant,
)
break
except Exception:
pass
if not model:
raise ModelNotFoundException()
# Calculate a more accurate size after loading the model into memory.
self.model_size = calc_model_size_by_data(model)
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException(f"Model not found at '{path}'.")
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):
return T2IAdapterModelFormat.Diffusers
raise InvalidModelException(f"Unsupported T2I-Adapter format: '{path}'.")
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == T2IAdapterModelFormat.Diffusers:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")

View File

@@ -24,7 +24,6 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from ..util import auto_detect_slice_size, normalize_device
@@ -174,16 +173,6 @@ class IPAdapterData:
end_step_percent: float = Field(default=1.0)
@dataclass
class T2IAdapterData:
"""A structure containing the information required to apply conditioning from a single T2I-Adapter model."""
adapter_state: dict[torch.Tensor] = Field()
weight: Union[float, list[float]] = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
@dataclass
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
r"""
@@ -337,8 +326,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
seed: Optional[int] = None,
@@ -391,7 +379,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
callback=callback,
)
finally:
@@ -411,8 +398,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
*,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
):
self._adjust_memory_efficient_attention(latents)
@@ -425,7 +411,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if timesteps.shape[0] == 0:
return latents, attention_map_saver
ip_adapter_unet_patcher = None
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
@@ -436,8 +421,11 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
weight = ip_adapter_data.weight[0] if isinstance(ip_adapter_data.weight, List) else ip_adapter_data.weight
attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention(
unet=self.invokeai_diffuser.model,
scale=weight,
)
self.use_ip_adapter = True
else:
attn_ctx = nullcontext()
@@ -466,8 +454,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
ip_adapter_unet_patcher=ip_adapter_unet_patcher,
)
latents = step_output.prev_sample
@@ -513,9 +499,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count: int,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[list[IPAdapterData]] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
ip_adapter_unet_patcher: Optional[UNetPatcher] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
@@ -528,30 +512,26 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# handle IP-Adapter
if self.use_ip_adapter and ip_adapter_data is not None: # somewhat redundant but logic is clearer
for i, single_ip_adapter_data in enumerate(ip_adapter_data):
first_adapter_step = math.floor(single_ip_adapter_data.begin_step_percent * total_step_count)
last_adapter_step = math.ceil(single_ip_adapter_data.end_step_percent * total_step_count)
weight = (
single_ip_adapter_data.weight[step_index]
if isinstance(single_ip_adapter_data.weight, List)
else single_ip_adapter_data.weight
)
if step_index >= first_adapter_step and step_index <= last_adapter_step:
# Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
ip_adapter_unet_patcher.set_scale(i, weight)
else:
# Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
ip_adapter_unet_patcher.set_scale(i, 0.0)
first_adapter_step = math.floor(ip_adapter_data.begin_step_percent * total_step_count)
last_adapter_step = math.ceil(ip_adapter_data.end_step_percent * total_step_count)
weight = (
ip_adapter_data.weight[step_index]
if isinstance(ip_adapter_data.weight, List)
else ip_adapter_data.weight
)
if step_index >= first_adapter_step and step_index <= last_adapter_step:
# only apply IP-Adapter if current step is within the IP-Adapter's begin/end step range
# ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight)
ip_adapter_data.ip_adapter_model.set_scale(weight)
else:
# otherwise, set IP-Adapter scale to 0, so it has no effect
ip_adapter_data.ip_adapter_model.set_scale(0.0)
# Handle ControlNet(s) and T2I-Adapter(s)
down_block_additional_residuals = None
mid_block_additional_residual = None
if control_data is not None and t2i_adapter_data is not None:
# TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility
# between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers.
raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
elif control_data is not None:
down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step(
# handle ControlNet(s)
# default is no controlnet, so set controlnet processing output to None
controlnet_down_block_samples, controlnet_mid_block_sample = None, None
if control_data is not None:
controlnet_down_block_samples, controlnet_mid_block_sample = self.invokeai_diffuser.do_controlnet_step(
control_data=control_data,
sample=latent_model_input,
timestep=timestep,
@@ -559,32 +539,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=total_step_count,
conditioning_data=conditioning_data,
)
elif t2i_adapter_data is not None:
accum_adapter_state = None
for single_t2i_adapter_data in t2i_adapter_data:
# Determine the T2I-Adapter weights for the current denoising step.
first_t2i_adapter_step = math.floor(single_t2i_adapter_data.begin_step_percent * total_step_count)
last_t2i_adapter_step = math.ceil(single_t2i_adapter_data.end_step_percent * total_step_count)
t2i_adapter_weight = (
single_t2i_adapter_data.weight[step_index]
if isinstance(single_t2i_adapter_data.weight, list)
else single_t2i_adapter_data.weight
)
if step_index < first_t2i_adapter_step or step_index > last_t2i_adapter_step:
# If the current step is outside of the T2I-Adapter's begin/end step range, then set its weight to 0
# so it has no effect.
t2i_adapter_weight = 0.0
# Apply the t2i_adapter_weight, and accumulate.
if accum_adapter_state is None:
# Handle the first T2I-Adapter.
accum_adapter_state = [val * t2i_adapter_weight for val in single_t2i_adapter_data.adapter_state]
else:
# Add to the previous adapter states.
for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
accum_adapter_state[idx] += value * t2i_adapter_weight
down_block_additional_residuals = accum_adapter_state
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
sample=latent_model_input,
@@ -593,8 +547,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=total_step_count,
conditioning_data=conditioning_data,
# extra:
down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=mid_block_additional_residual,
down_block_additional_residuals=controlnet_down_block_samples, # from controlnet(s)
mid_block_additional_residual=controlnet_mid_block_sample, # from controlnet(s)
)
guidance_scale = conditioning_data.guidance_scale

View File

@@ -81,7 +81,7 @@ class ConditioningData:
"""
postprocessing_settings: Optional[PostprocessingSettings] = None
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
ip_adapter_conditioning: Optional[IPAdapterConditioningInfo] = None
@property
def dtype(self):

View File

@@ -346,10 +346,12 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.cat([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
"ip_adapter_image_prompt_embeds": torch.cat(
[
conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds,
conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds,
]
)
}
added_cond_kwargs = None
@@ -416,10 +418,7 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
ipa_conditioning.uncond_image_prompt_embeds
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds
}
added_cond_kwargs = None
@@ -445,10 +444,7 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
ipa_conditioning.cond_image_prompt_embeds
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds
}
added_cond_kwargs = None

View File

@@ -1,67 +0,0 @@
import contextlib
from pathlib import Path
from typing import Optional, Union
import pytest
import torch
from invokeai.app.services.config.invokeai_config import InvokeAIAppConfig
from invokeai.backend.install.model_install_backend import ModelInstall
from invokeai.backend.model_management.model_manager import ModelInfo
from invokeai.backend.model_management.models.base import BaseModelType, ModelNotFoundException, ModelType, SubModelType
@pytest.fixture(scope="session")
def torch_device():
return "cuda" if torch.cuda.is_available() else "cpu"
@pytest.fixture(scope="module")
def model_installer():
"""A global ModelInstall pytest fixture to be used by many tests."""
# HACK(ryand): InvokeAIAppConfig.get_config() returns a singleton config object. This can lead to weird interactions
# between tests that need to alter the config. For example, some tests change the 'root' directory in the config,
# which can cause `install_and_load_model(...)` to re-download the model unnecessarily. As a temporary workaround,
# we pass a kwarg to get_config, which causes the config to be re-loaded. To fix this properly, we should stop using
# a singleton.
return ModelInstall(InvokeAIAppConfig.get_config(log_level="info"))
def install_and_load_model(
model_installer: ModelInstall,
model_path_id_or_url: Union[str, Path],
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
) -> ModelInfo:
"""Install a model if it is not already installed, then get the ModelInfo for that model.
This is intended as a utility function for tests.
Args:
model_installer (ModelInstall): The model installer.
model_path_id_or_url (Union[str, Path]): The path, HF ID, URL, etc. where the model can be installed from if it
is not already installed.
model_name (str): The model name, forwarded to ModelManager.get_model(...).
base_model (BaseModelType): The base model, forwarded to ModelManager.get_model(...).
model_type (ModelType): The model type, forwarded to ModelManager.get_model(...).
submodel_type (Optional[SubModelType]): The submodel type, forwarded to ModelManager.get_model(...).
Returns:
ModelInfo
"""
# If the requested model is already installed, return its ModelInfo.
with contextlib.suppress(ModelNotFoundException):
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
# Install the requested model.
model_installer.heuristic_import(model_path_id_or_url)
try:
return model_installer.mgr.get_model(model_name, base_model, model_type, submodel_type)
except ModelNotFoundException as e:
raise Exception(
"Failed to get model info after installing it. There could be a mismatch between the requested model and"
f" the installation id ('{model_path_id_or_url}'). Error: {e}"
)

View File

@@ -96,22 +96,6 @@ sd-1/controlnet/tile:
repo_id: lllyasviel/control_v11f1e_sd15_tile
sd-1/controlnet/ip2p:
repo_id: lllyasviel/control_v11e_sd15_ip2p
sd-1/t2i_adapter/canny-sd15:
repo_id: TencentARC/t2iadapter_canny_sd15v2
sd-1/t2i_adapter/sketch-sd15:
repo_id: TencentARC/t2iadapter_sketch_sd15v2
sd-1/t2i_adapter/depth-sd15:
repo_id: TencentARC/t2iadapter_depth_sd15v2
sd-1/t2i_adapter/zoedepth-sd15:
repo_id: TencentARC/t2iadapter_zoedepth_sd15v1
sdxl/t2i_adapter/canny-sdxl:
repo_id: TencentARC/t2i-adapter-canny-sdxl-1.0
sdxl/t2i_adapter/zoedepth-sdxl:
repo_id: TencentARC/t2i-adapter-depth-zoe-sdxl-1.0
sdxl/t2i_adapter/lineart-sdxl:
repo_id: TencentARC/t2i-adapter-lineart-sdxl-1.0
sdxl/t2i_adapter/sketch-sdxl:
repo_id: TencentARC/t2i-adapter-sketch-sdxl-1.0
sd-1/embedding/EasyNegative:
path: https://huggingface.co/embed/EasyNegative/resolve/main/EasyNegative.safetensors
recommended: True

View File

@@ -98,16 +98,15 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.tabs = self.add_widget_intelligent(
SingleSelectColumns,
values=[
"STARTERS",
"MAINS",
"STARTER MODELS",
"MAIN MODELS",
"CONTROLNETS",
"T2I-ADAPTERS",
"IP-ADAPTERS",
"LORAS",
"TI EMBEDDINGS",
"LORA/LYCORIS",
"TEXTUAL INVERSION",
],
value=[self.current_tab],
columns=7,
columns=6,
max_height=2,
relx=8,
scroll_exit=True,
@@ -132,12 +131,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
)
bottom_of_table = max(bottom_of_table, self.nextrely)
self.nextrely = top_of_table
self.t2i_models = self.add_model_widgets(
model_type=ModelType.T2IAdapter,
window_width=window_width,
)
bottom_of_table = max(bottom_of_table, self.nextrely)
self.nextrely = top_of_table
self.ipadapter_models = self.add_model_widgets(
model_type=ModelType.IPAdapter,
@@ -358,7 +351,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.starter_pipelines,
self.pipeline_models,
self.controlnet_models,
self.t2i_models,
self.ipadapter_models,
self.lora_models,
self.ti_models,
@@ -549,7 +541,6 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
self.starter_pipelines,
self.pipeline_models,
self.controlnet_models,
self.t2i_models,
self.ipadapter_models,
self.lora_models,
self.ti_models,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,4 +1,4 @@
import{w as s,h_ as T,v as l,a2 as I,h$ as R,ae as V,i0 as z,i1 as j,i2 as D,i3 as F,i4 as G,i5 as W,i6 as K,aG as H,i7 as U,i8 as Y}from"./index-90346e33.js";import{M as Z}from"./MantineProvider-486d4834.js";var P=String.raw,E=P`
import{w as s,hY as T,v as l,a2 as I,hZ as R,ae as V,h_ as z,h$ as j,i0 as D,i1 as F,i2 as G,i3 as W,i4 as K,aG as Y,i5 as Z,i6 as H}from"./index-94062f76.js";import{M as U}from"./MantineProvider-a057bfc9.js";var P=String.raw,E=P`
:root,
:host {
--chakra-vh: 100vh;
@@ -277,4 +277,4 @@ import{w as s,h_ as T,v as l,a2 as I,h$ as R,ae as V,i0 as z,i1 as j,i2 as D,i3
}
${E}
`}),g={light:"chakra-ui-light",dark:"chakra-ui-dark"};function Q(e={}){const{preventTransition:o=!0}=e,n={setDataset:r=>{const t=o?n.preventTransition():void 0;document.documentElement.dataset.theme=r,document.documentElement.style.colorScheme=r,t==null||t()},setClassName(r){document.body.classList.add(r?g.dark:g.light),document.body.classList.remove(r?g.light:g.dark)},query(){return window.matchMedia("(prefers-color-scheme: dark)")},getSystemTheme(r){var t;return((t=n.query().matches)!=null?t:r==="dark")?"dark":"light"},addListener(r){const t=n.query(),i=a=>{r(a.matches?"dark":"light")};return typeof t.addListener=="function"?t.addListener(i):t.addEventListener("change",i),()=>{typeof t.removeListener=="function"?t.removeListener(i):t.removeEventListener("change",i)}},preventTransition(){const r=document.createElement("style");return r.appendChild(document.createTextNode("*{-webkit-transition:none!important;-moz-transition:none!important;-o-transition:none!important;-ms-transition:none!important;transition:none!important}")),document.head.appendChild(r),()=>{window.getComputedStyle(document.body),requestAnimationFrame(()=>{requestAnimationFrame(()=>{document.head.removeChild(r)})})}}};return n}var X="chakra-ui-color-mode";function L(e){return{ssr:!1,type:"localStorage",get(o){if(!(globalThis!=null&&globalThis.document))return o;let n;try{n=localStorage.getItem(e)||o}catch{}return n||o},set(o){try{localStorage.setItem(e,o)}catch{}}}}var ee=L(X),M=()=>{};function S(e,o){return e.type==="cookie"&&e.ssr?e.get(o):o}function O(e){const{value:o,children:n,options:{useSystemColorMode:r,initialColorMode:t,disableTransitionOnChange:i}={},colorModeManager:a=ee}=e,d=t==="dark"?"dark":"light",[u,p]=l.useState(()=>S(a,d)),[y,b]=l.useState(()=>S(a)),{getSystemTheme:w,setClassName:k,setDataset:x,addListener:$}=l.useMemo(()=>Q({preventTransition:i}),[i]),v=t==="system"&&!u?y:u,c=l.useCallback(h=>{const f=h==="system"?w():h;p(f),k(f==="dark"),x(f),a.set(f)},[a,w,k,x]);I(()=>{t==="system"&&b(w())},[]),l.useEffect(()=>{const h=a.get();if(h){c(h);return}if(t==="system"){c("system");return}c(d)},[a,d,t,c]);const C=l.useCallback(()=>{c(v==="dark"?"light":"dark")},[v,c]);l.useEffect(()=>{if(r)return $(c)},[r,$,c]);const A=l.useMemo(()=>({colorMode:o??v,toggleColorMode:o?M:C,setColorMode:o?M:c,forced:o!==void 0}),[v,C,c,o]);return s.jsx(R.Provider,{value:A,children:n})}O.displayName="ColorModeProvider";var te=["borders","breakpoints","colors","components","config","direction","fonts","fontSizes","fontWeights","letterSpacings","lineHeights","radii","shadows","sizes","space","styles","transition","zIndices"];function re(e){return V(e)?te.every(o=>Object.prototype.hasOwnProperty.call(e,o)):!1}function m(e){return typeof e=="function"}function oe(...e){return o=>e.reduce((n,r)=>r(n),o)}var ne=e=>function(...n){let r=[...n],t=n[n.length-1];return re(t)&&r.length>1?r=r.slice(0,r.length-1):t=e,oe(...r.map(i=>a=>m(i)?i(a):ae(a,i)))(t)},ie=ne(j);function ae(...e){return z({},...e,_)}function _(e,o,n,r){if((m(e)||m(o))&&Object.prototype.hasOwnProperty.call(r,n))return(...t)=>{const i=m(e)?e(...t):e,a=m(o)?o(...t):o;return z({},i,a,_)}}var q=l.createContext({getDocument(){return document},getWindow(){return window}});q.displayName="EnvironmentContext";function N(e){const{children:o,environment:n,disabled:r}=e,t=l.useRef(null),i=l.useMemo(()=>n||{getDocument:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument)!=null?u:document},getWindow:()=>{var d,u;return(u=(d=t.current)==null?void 
0:d.ownerDocument.defaultView)!=null?u:window}},[n]),a=!r||!n;return s.jsxs(q.Provider,{value:i,children:[o,a&&s.jsx("span",{id:"__chakra_env",hidden:!0,ref:t})]})}N.displayName="EnvironmentProvider";var se=e=>{const{children:o,colorModeManager:n,portalZIndex:r,resetScope:t,resetCSS:i=!0,theme:a={},environment:d,cssVarsRoot:u,disableEnvironment:p,disableGlobalStyle:y}=e,b=s.jsx(N,{environment:d,disabled:p,children:o});return s.jsx(D,{theme:a,cssVarsRoot:u,children:s.jsxs(O,{colorModeManager:n,options:a.config,children:[i?s.jsx(J,{scope:t}):s.jsx(B,{}),!y&&s.jsx(F,{}),r?s.jsx(G,{zIndex:r,children:b}):b]})})},le=e=>function({children:n,theme:r=e,toastOptions:t,...i}){return s.jsxs(se,{theme:r,...i,children:[s.jsx(W,{value:t==null?void 0:t.defaultOptions,children:n}),s.jsx(K,{...t})]})},de=le(j);const ue=()=>l.useMemo(()=>({colorScheme:"dark",fontFamily:"'Inter Variable', sans-serif",components:{ScrollArea:{defaultProps:{scrollbarSize:10},styles:{scrollbar:{"&:hover":{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}},thumb:{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}}}}}),[]),ce=L("@@invokeai-color-mode");function he({children:e}){const{i18n:o}=H(),n=o.dir(),r=l.useMemo(()=>ie({...U,direction:n}),[n]);l.useEffect(()=>{document.body.dir=n},[n]);const t=ue();return s.jsx(Z,{theme:t,children:s.jsx(de,{theme:r,colorModeManager:ce,toastOptions:Y,children:e})})}const ve=l.memo(he);export{ve as default};
`}),g={light:"chakra-ui-light",dark:"chakra-ui-dark"};function Q(e={}){const{preventTransition:o=!0}=e,n={setDataset:r=>{const t=o?n.preventTransition():void 0;document.documentElement.dataset.theme=r,document.documentElement.style.colorScheme=r,t==null||t()},setClassName(r){document.body.classList.add(r?g.dark:g.light),document.body.classList.remove(r?g.light:g.dark)},query(){return window.matchMedia("(prefers-color-scheme: dark)")},getSystemTheme(r){var t;return((t=n.query().matches)!=null?t:r==="dark")?"dark":"light"},addListener(r){const t=n.query(),i=a=>{r(a.matches?"dark":"light")};return typeof t.addListener=="function"?t.addListener(i):t.addEventListener("change",i),()=>{typeof t.removeListener=="function"?t.removeListener(i):t.removeEventListener("change",i)}},preventTransition(){const r=document.createElement("style");return r.appendChild(document.createTextNode("*{-webkit-transition:none!important;-moz-transition:none!important;-o-transition:none!important;-ms-transition:none!important;transition:none!important}")),document.head.appendChild(r),()=>{window.getComputedStyle(document.body),requestAnimationFrame(()=>{requestAnimationFrame(()=>{document.head.removeChild(r)})})}}};return n}var X="chakra-ui-color-mode";function L(e){return{ssr:!1,type:"localStorage",get(o){if(!(globalThis!=null&&globalThis.document))return o;let n;try{n=localStorage.getItem(e)||o}catch{}return n||o},set(o){try{localStorage.setItem(e,o)}catch{}}}}var ee=L(X),M=()=>{};function S(e,o){return e.type==="cookie"&&e.ssr?e.get(o):o}function O(e){const{value:o,children:n,options:{useSystemColorMode:r,initialColorMode:t,disableTransitionOnChange:i}={},colorModeManager:a=ee}=e,d=t==="dark"?"dark":"light",[u,p]=l.useState(()=>S(a,d)),[y,b]=l.useState(()=>S(a)),{getSystemTheme:w,setClassName:k,setDataset:x,addListener:$}=l.useMemo(()=>Q({preventTransition:i}),[i]),v=t==="system"&&!u?y:u,c=l.useCallback(h=>{const f=h==="system"?w():h;p(f),k(f==="dark"),x(f),a.set(f)},[a,w,k,x]);I(()=>{t==="system"&&b(w())},[]),l.useEffect(()=>{const h=a.get();if(h){c(h);return}if(t==="system"){c("system");return}c(d)},[a,d,t,c]);const C=l.useCallback(()=>{c(v==="dark"?"light":"dark")},[v,c]);l.useEffect(()=>{if(r)return $(c)},[r,$,c]);const A=l.useMemo(()=>({colorMode:o??v,toggleColorMode:o?M:C,setColorMode:o?M:c,forced:o!==void 0}),[v,C,c,o]);return s.jsx(R.Provider,{value:A,children:n})}O.displayName="ColorModeProvider";var te=["borders","breakpoints","colors","components","config","direction","fonts","fontSizes","fontWeights","letterSpacings","lineHeights","radii","shadows","sizes","space","styles","transition","zIndices"];function re(e){return V(e)?te.every(o=>Object.prototype.hasOwnProperty.call(e,o)):!1}function m(e){return typeof e=="function"}function oe(...e){return o=>e.reduce((n,r)=>r(n),o)}var ne=e=>function(...n){let r=[...n],t=n[n.length-1];return re(t)&&r.length>1?r=r.slice(0,r.length-1):t=e,oe(...r.map(i=>a=>m(i)?i(a):ae(a,i)))(t)},ie=ne(j);function ae(...e){return z({},...e,_)}function _(e,o,n,r){if((m(e)||m(o))&&Object.prototype.hasOwnProperty.call(r,n))return(...t)=>{const i=m(e)?e(...t):e,a=m(o)?o(...t):o;return z({},i,a,_)}}var q=l.createContext({getDocument(){return document},getWindow(){return window}});q.displayName="EnvironmentContext";function N(e){const{children:o,environment:n,disabled:r}=e,t=l.useRef(null),i=l.useMemo(()=>n||{getDocument:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument)!=null?u:document},getWindow:()=>{var d,u;return(u=(d=t.current)==null?void 
0:d.ownerDocument.defaultView)!=null?u:window}},[n]),a=!r||!n;return s.jsxs(q.Provider,{value:i,children:[o,a&&s.jsx("span",{id:"__chakra_env",hidden:!0,ref:t})]})}N.displayName="EnvironmentProvider";var se=e=>{const{children:o,colorModeManager:n,portalZIndex:r,resetScope:t,resetCSS:i=!0,theme:a={},environment:d,cssVarsRoot:u,disableEnvironment:p,disableGlobalStyle:y}=e,b=s.jsx(N,{environment:d,disabled:p,children:o});return s.jsx(D,{theme:a,cssVarsRoot:u,children:s.jsxs(O,{colorModeManager:n,options:a.config,children:[i?s.jsx(J,{scope:t}):s.jsx(B,{}),!y&&s.jsx(F,{}),r?s.jsx(G,{zIndex:r,children:b}):b]})})},le=e=>function({children:n,theme:r=e,toastOptions:t,...i}){return s.jsxs(se,{theme:r,...i,children:[s.jsx(W,{value:t==null?void 0:t.defaultOptions,children:n}),s.jsx(K,{...t})]})},de=le(j);const ue=()=>l.useMemo(()=>({colorScheme:"dark",fontFamily:"'Inter Variable', sans-serif",components:{ScrollArea:{defaultProps:{scrollbarSize:10},styles:{scrollbar:{"&:hover":{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}},thumb:{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}}}}}),[]),ce=L("@@invokeai-color-mode");function he({children:e}){const{i18n:o}=Y(),n=o.dir(),r=l.useMemo(()=>ie({...Z,direction:n}),[n]);l.useEffect(()=>{document.body.dir=n},[n]);const t=ue();return s.jsx(U,{theme:t,children:s.jsx(de,{theme:r,colorModeManager:ce,toastOptions:H,children:e})})}const ve=l.memo(he);export{ve as default};

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -3,9 +3,6 @@
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate">
<meta http-equiv="Pragma" content="no-cache">
<meta http-equiv="Expires" content="0">
<title>InvokeAI - A Stable Diffusion Toolkit</title>
<link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" />
<style>
@@ -15,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-90346e33.js"></script>
<script type="module" crossorigin src="./assets/index-94062f76.js"></script>
</head>
<body dir="ltr">

View File

@@ -49,10 +49,8 @@
"cancel": "Cancel",
"close": "Close",
"communityLabel": "Community",
"controlNet": "ControlNet",
"controlAdapter": "Control Adapter",
"controlNet": "Controlnet",
"ipAdapter": "IP Adapter",
"t2iAdapter": "T2I Adapter",
"darkMode": "Dark Mode",
"discordLabel": "Discord",
"dontAskMeAgain": "Don't ask me again",
@@ -132,16 +130,6 @@
"upload": "Upload"
},
"controlnet": {
"controlAdapter": "Control Adapter",
"controlnet": "$t(controlnet.controlAdapter) #{{number}} ($t(common.controlNet))",
"ip_adapter": "$t(controlnet.controlAdapter) #{{number}} ($t(common.ipAdapter))",
"t2i_adapter": "$t(controlnet.controlAdapter) #{{number}} ($t(common.t2iAdapter))",
"addControlNet": "Add $t(common.controlNet)",
"addIPAdapter": "Add $t(common.ipAdapter)",
"addT2IAdapter": "Add $t(common.t2iAdapter)",
"controlNetEnabledT2IDisabled": "$t(common.controlNet) enabled, $t(common.t2iAdapter)s disabled",
"t2iEnabledControlNetDisabled": "$t(common.t2iAdapter) enabled, $t(common.controlNet)s disabled",
"controlNetT2IMutexDesc": "$t(common.controlNet) and $t(common.t2iAdapter) at same time is currently unsupported.",
"amult": "a_mult",
"autoConfigure": "Auto configure processor",
"balanced": "Balanced",
@@ -709,7 +697,7 @@
"noLoRAsAvailable": "No LoRAs available",
"noMatchingLoRAs": "No matching LoRAs",
"noMatchingModels": "No matching Models",
"noModelsAvailable": "No models available",
"noModelsAvailable": "No Modelss available",
"selectLoRA": "Select a LoRA",
"selectModel": "Select a Model"
},
@@ -797,14 +785,6 @@
"integerPolymorphic": "Integer Polymorphic",
"integerPolymorphicDescription": "A collection of integers.",
"invalidOutputSchema": "Invalid output schema",
"ipAdapter": "IP-Adapter",
"ipAdapterCollection": "IP-Adapters Collection",
"ipAdapterCollectionDescription": "A collection of IP-Adapters.",
"ipAdapterDescription": "An Image Prompt Adapter (IP-Adapter).",
"ipAdapterModel": "IP-Adapter Model",
"ipAdapterModelDescription": "IP-Adapter Model Field",
"ipAdapterPolymorphic": "IP-Adapter Polymorphic",
"ipAdapterPolymorphicDescription": "A collection of IP-Adapters.",
"latentsCollection": "Latents Collection",
"latentsCollectionDescription": "Latents may be passed between nodes.",
"latentsField": "Latents",
@@ -964,10 +944,9 @@
"missingFieldTemplate": "Missing field template",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} missing input",
"missingNodeTemplate": "Missing node template",
"noControlImageForControlAdapter": "Control Adapter {{number}} has no control image",
"noControlImageForControlNet": "ControlNet {{index}} has no control image",
"noInitialImageSelected": "No initial image selected",
"noModelForControlAdapter": "Control Adapter {{number}} has no model selected.",
"incompatibleBaseModelForControlAdapter": "Control Adapter {{number}} model is invalid with main model.",
"noModelForControlNet": "ControlNet {{index}} has no model selected.",
"noModelSelected": "No model selected",
"noPrompts": "No prompts generated",
"noNodesInGraph": "No nodes in graph",
@@ -1106,8 +1085,7 @@
},
"toast": {
"addedToBoard": "Added to board",
"baseModelChangedCleared_one": "Base model changed, cleared or disabled {{number}} incompatible submodel",
"baseModelChangedCleared_many": "$t(toast.baseModelChangedCleared_one)s",
"baseModelChangedCleared": "Base model changed, cleared",
"canceled": "Processing Canceled",
"canvasCopiedClipboard": "Canvas Copied to Clipboard",
"canvasDownloaded": "Canvas Downloaded",
@@ -1127,6 +1105,7 @@
"imageSavingFailed": "Image Saving Failed",
"imageUploaded": "Image Uploaded",
"imageUploadFailed": "Image Upload Failed",
"incompatibleSubmodel": "incompatible submodel",
"initialImageNotSet": "Initial Image Not Set",
"initialImageNotSetDesc": "Could not load initial image",
"initialImageSet": "Initial Image Set",

View File

@@ -3,9 +3,6 @@
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate">
<meta http-equiv="Pragma" content="no-cache">
<meta http-equiv="Expires" content="0">
<title>InvokeAI - A Stable Diffusion Toolkit</title>
<link rel="shortcut icon" type="icon" href="favicon.ico" />
<style>

View File

@@ -157,7 +157,6 @@
"prettier": "^3.0.2",
"rollup-plugin-visualizer": "^5.9.2",
"ts-toolbelt": "^9.6.0",
"typescript": "^5.2.2",
"vite": "^4.4.9",
"vite-plugin-css-injected-by-js": "^3.3.0",
"vite-plugin-dts": "^3.5.2",

View File

@@ -49,10 +49,8 @@
"cancel": "Cancel",
"close": "Close",
"communityLabel": "Community",
"controlNet": "ControlNet",
"controlAdapter": "Control Adapter",
"controlNet": "Controlnet",
"ipAdapter": "IP Adapter",
"t2iAdapter": "T2I Adapter",
"darkMode": "Dark Mode",
"discordLabel": "Discord",
"dontAskMeAgain": "Don't ask me again",
@@ -132,16 +130,6 @@
"upload": "Upload"
},
"controlnet": {
"controlAdapter": "Control Adapter",
"controlnet": "$t(controlnet.controlAdapter) #{{number}} ($t(common.controlNet))",
"ip_adapter": "$t(controlnet.controlAdapter) #{{number}} ($t(common.ipAdapter))",
"t2i_adapter": "$t(controlnet.controlAdapter) #{{number}} ($t(common.t2iAdapter))",
"addControlNet": "Add $t(common.controlNet)",
"addIPAdapter": "Add $t(common.ipAdapter)",
"addT2IAdapter": "Add $t(common.t2iAdapter)",
"controlNetEnabledT2IDisabled": "$t(common.controlNet) enabled, $t(common.t2iAdapter)s disabled",
"t2iEnabledControlNetDisabled": "$t(common.t2iAdapter) enabled, $t(common.controlNet)s disabled",
"controlNetT2IMutexDesc": "$t(common.controlNet) and $t(common.t2iAdapter) at same time is currently unsupported.",
"amult": "a_mult",
"autoConfigure": "Auto configure processor",
"balanced": "Balanced",
@@ -797,14 +785,6 @@
"integerPolymorphic": "Integer Polymorphic",
"integerPolymorphicDescription": "A collection of integers.",
"invalidOutputSchema": "Invalid output schema",
"ipAdapter": "IP-Adapter",
"ipAdapterCollection": "IP-Adapters Collection",
"ipAdapterCollectionDescription": "A collection of IP-Adapters.",
"ipAdapterDescription": "An Image Prompt Adapter (IP-Adapter).",
"ipAdapterModel": "IP-Adapter Model",
"ipAdapterModelDescription": "IP-Adapter Model Field",
"ipAdapterPolymorphic": "IP-Adapter Polymorphic",
"ipAdapterPolymorphicDescription": "A collection of IP-Adapters.",
"latentsCollection": "Latents Collection",
"latentsCollectionDescription": "Latents may be passed between nodes.",
"latentsField": "Latents",
@@ -964,10 +944,9 @@
"missingFieldTemplate": "Missing field template",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}} missing input",
"missingNodeTemplate": "Missing node template",
"noControlImageForControlAdapter": "Control Adapter {{number}} has no control image",
"noControlImageForControlNet": "ControlNet {{index}} has no control image",
"noInitialImageSelected": "No initial image selected",
"noModelForControlAdapter": "Control Adapter {{number}} has no model selected.",
"incompatibleBaseModelForControlAdapter": "Control Adapter {{number}} model is invalid with main model.",
"noModelForControlNet": "ControlNet {{index}} has no model selected.",
"noModelSelected": "No model selected",
"noPrompts": "No prompts generated",
"noNodesInGraph": "No nodes in graph",
@@ -1106,8 +1085,7 @@
},
"toast": {
"addedToBoard": "Added to board",
"baseModelChangedCleared_one": "Base model changed, cleared or disabled {{number}} incompatible submodel",
"baseModelChangedCleared_many": "$t(toast.baseModelChangedCleared_one)s",
"baseModelChangedCleared": "Base model changed, cleared",
"canceled": "Processing Canceled",
"canvasCopiedClipboard": "Canvas Copied to Clipboard",
"canvasDownloaded": "Canvas Downloaded",
@@ -1127,6 +1105,7 @@
"imageSavingFailed": "Image Saving Failed",
"imageUploaded": "Image Uploaded",
"imageUploadFailed": "Image Upload Failed",
"incompatibleSubmodel": "incompatible submodel",
"initialImageNotSet": "Initial Image Not Set",
"initialImageNotSetDesc": "Could not load initial image",
"initialImageSet": "Initial Image Set",

View File

@@ -1,5 +1,5 @@
import { canvasPersistDenylist } from 'features/canvas/store/canvasPersistDenylist';
import { controlAdaptersPersistDenylist } from 'features/controlAdapters/store/controlAdaptersPersistDenylist';
import { controlNetDenylist } from 'features/controlNet/store/controlNetDenylist';
import { dynamicPromptsPersistDenylist } from 'features/dynamicPrompts/store/dynamicPromptsPersistDenylist';
import { galleryPersistDenylist } from 'features/gallery/store/galleryPersistDenylist';
import { nodesPersistDenylist } from 'features/nodes/store/nodesPersistDenylist';
@@ -20,7 +20,7 @@ const serializationDenylist: {
postprocessing: postprocessingPersistDenylist,
system: systemPersistDenylist,
ui: uiPersistDenylist,
controlNet: controlAdaptersPersistDenylist,
controlNet: controlNetDenylist,
dynamicPrompts: dynamicPromptsPersistDenylist,
};

View File

@@ -1,5 +1,5 @@
import { initialCanvasState } from 'features/canvas/store/canvasSlice';
import { initialControlAdapterState } from 'features/controlAdapters/store/controlAdaptersSlice';
import { initialControlNetState } from 'features/controlNet/store/controlNetSlice';
import { initialDynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { initialGalleryState } from 'features/gallery/store/gallerySlice';
import { initialNodesState } from 'features/nodes/store/nodesSlice';
@@ -25,7 +25,7 @@ const initialStates: {
config: initialConfigState,
ui: initialUIState,
hotkeys: initialHotkeysState,
controlAdapters: initialControlAdapterState,
controlNet: initialControlNetState,
dynamicPrompts: initialDynamicPromptsState,
sdxl: initialSDXLState,
};

View File

@@ -72,7 +72,6 @@ import { addTabChangedListener } from './listeners/tabChanged';
import { addUpscaleRequestedListener } from './listeners/upscaleRequested';
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
import { addBatchEnqueuedListener } from './listeners/batchEnqueued';
import { addControlAdapterAddedOrEnabledListener } from './listeners/controlAdapterAddedOrEnabled';
export const listenerMiddleware = createListenerMiddleware();
@@ -200,7 +199,3 @@ addTabChangedListener();
// Dynamic prompts
addDynamicPromptsListener();
// Display toast when controlnet or t2i adapter enabled
// TODO: Remove when they can both be enabled at same time
addControlAdapterAddedOrEnabledListener();

View File

@@ -1,5 +1,8 @@
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlAdaptersReset } from 'features/controlAdapters/store/controlAdaptersSlice';
import {
controlNetReset,
ipAdapterStateReset,
} from 'features/controlNet/store/controlNetSlice';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
@@ -17,7 +20,8 @@ export const addDeleteBoardAndImagesFulfilledListener = () => {
let wasInitialImageReset = false;
let wasCanvasReset = false;
let wasNodeEditorReset = false;
let wereControlAdaptersReset = false;
let wasControlNetReset = false;
let wasIPAdapterReset = false;
const state = getState();
deleted_images.forEach((image_name) => {
@@ -38,9 +42,14 @@ export const addDeleteBoardAndImagesFulfilledListener = () => {
wasNodeEditorReset = true;
}
if (imageUsage.isControlImage && !wereControlAdaptersReset) {
dispatch(controlAdaptersReset());
wereControlAdaptersReset = true;
if (imageUsage.isControlNetImage && !wasControlNetReset) {
dispatch(controlNetReset());
wasControlNetReset = true;
}
if (imageUsage.isIPAdapterImage && !wasIPAdapterReset) {
dispatch(ipAdapterStateReset());
wasIPAdapterReset = true;
}
});
},

View File

@@ -1,21 +1,20 @@
import { logger } from 'app/logging/logger';
import { canvasImageToControlNet } from 'features/canvas/store/actions';
import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '..';
import { canvasImageToControlAdapter } from 'features/canvas/store/actions';
export const addCanvasImageToControlNetListener = () => {
startAppListening({
actionCreator: canvasImageToControlAdapter,
actionCreator: canvasImageToControlNet,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { id } = action.payload;
let blob: Blob;
let blob;
try {
blob = await getBaseLayerBlob(state, true);
} catch (err) {
@@ -51,8 +50,8 @@ export const addCanvasImageToControlNetListener = () => {
const { image_name } = imageDTO;
dispatch(
controlAdapterImageChanged({
id,
controlNetImageChanged({
controlNetId: action.payload.controlNet.controlNetId,
controlImage: image_name,
})
);

View File

@@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import { canvasMaskToControlAdapter } from 'features/canvas/store/actions';
import { canvasMaskToControlNet } from 'features/canvas/store/actions';
import { getCanvasData } from 'features/canvas/util/getCanvasData';
import { controlAdapterImageChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
@@ -9,11 +9,11 @@ import { startAppListening } from '..';
export const addCanvasMaskToControlNetListener = () => {
startAppListening({
actionCreator: canvasMaskToControlAdapter,
actionCreator: canvasMaskToControlNet,
effect: async (action, { dispatch, getState }) => {
const log = logger('canvas');
const state = getState();
const { id } = action.payload;
const canvasBlobsAndImageData = await getCanvasData(
state.canvas.layerState,
state.canvas.boundingBoxCoordinates,
@@ -61,8 +61,8 @@ export const addCanvasMaskToControlNetListener = () => {
const { image_name } = imageDTO;
dispatch(
controlAdapterImageChanged({
id,
controlNetImageChanged({
controlNetId: action.payload.controlNet.controlNetId,
controlImage: image_name,
})
);

View File

@@ -1,87 +0,0 @@
import { isAnyOf } from '@reduxjs/toolkit';
import {
controlAdapterAdded,
controlAdapterAddedFromImage,
controlAdapterIsEnabledChanged,
controlAdapterRecalled,
selectControlAdapterAll,
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { ControlAdapterType } from 'features/controlAdapters/store/types';
import { addToast } from 'features/system/store/systemSlice';
import i18n from 'i18n';
import { startAppListening } from '..';
const isAnyControlAdapterAddedOrEnabled = isAnyOf(
controlAdapterAdded,
controlAdapterAddedFromImage,
controlAdapterRecalled,
controlAdapterIsEnabledChanged
);
/**
* Until we can have both controlnet and t2i adapter enabled at once, they are mutually exclusive
* This displays a toast when one is enabled and the other is already enabled, or one is added
* with the other enabled
*/
export const addControlAdapterAddedOrEnabledListener = () => {
startAppListening({
matcher: isAnyControlAdapterAddedOrEnabled,
effect: async (action, { dispatch, getOriginalState }) => {
const controlAdapters = getOriginalState().controlAdapters;
const hasEnabledControlNets = selectControlAdapterAll(
controlAdapters
).some((ca) => ca.isEnabled && ca.type === 'controlnet');
const hasEnabledT2IAdapters = selectControlAdapterAll(
controlAdapters
).some((ca) => ca.isEnabled && ca.type === 't2i_adapter');
let caType: ControlAdapterType | null = null;
if (controlAdapterAdded.match(action)) {
caType = action.payload.type;
}
if (controlAdapterAddedFromImage.match(action)) {
caType = action.payload.type;
}
if (controlAdapterRecalled.match(action)) {
caType = action.payload.type;
}
if (controlAdapterIsEnabledChanged.match(action)) {
const _caType = selectControlAdapterById(
controlAdapters,
action.payload.id
)?.type;
if (!_caType) {
return;
}
caType = _caType;
}
if (
(caType === 'controlnet' && hasEnabledT2IAdapters) ||
(caType === 't2i_adapter' && hasEnabledControlNets)
) {
const title =
caType === 'controlnet'
? i18n.t('controlnet.controlNetEnabledT2IDisabled')
: i18n.t('controlnet.t2iEnabledControlNetDisabled');
const description = i18n.t('controlnet.controlNetT2IMutexDesc');
dispatch(
addToast({
title,
description,
status: 'warning',
})
);
}
},
});
};
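
The file removed above is a Redux Toolkit listener: an `isAnyOf` matcher plus an effect that inspects the pre-action state and dispatches a warning toast. A stripped-down sketch of that pattern, with invented action names standing in for the control-adapter actions:

```ts
import { createAction, createListenerMiddleware, isAnyOf } from '@reduxjs/toolkit';

// Hypothetical actions standing in for the control-adapter actions in the removed file.
const adapterAdded = createAction<{ type: 'controlnet' | 't2i_adapter' }>('adapters/added');
const adapterEnabled = createAction<{ id: string }>('adapters/enabledChanged');

export const listenerMiddleware = createListenerMiddleware();

listenerMiddleware.startListening({
  // isAnyOf builds a matcher that fires for any of the listed action creators.
  matcher: isAnyOf(adapterAdded, adapterEnabled),
  effect: async (action, { dispatch, getOriginalState }) => {
    // getOriginalState() returns the state from *before* this action was reduced,
    // which is how the removed listener checked for an already-enabled conflicting adapter.
    const previous = getOriginalState();
    console.log('handling', action.type, 'with prior state', previous);
    dispatch({ type: 'toast/added', payload: { status: 'warning' } });
  },
});
```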

View File

@@ -1,24 +1,15 @@
import { AnyListenerPredicate } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { RootState } from 'app/store/store';
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import {
controlAdapterAutoConfigToggled,
controlAdapterImageChanged,
controlAdapterModelChanged,
controlAdapterProcessorParamsChanged,
controlAdapterProcessortTypeChanged,
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
controlNetAutoConfigToggled,
controlNetImageChanged,
controlNetModelChanged,
controlNetProcessorParamsChanged,
controlNetProcessorTypeChanged,
} from 'features/controlNet/store/controlNetSlice';
import { startAppListening } from '..';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
type AnyControlAdapterParamChangeAction =
| ReturnType<typeof controlAdapterProcessorParamsChanged>
| ReturnType<typeof controlAdapterModelChanged>
| ReturnType<typeof controlAdapterImageChanged>
| ReturnType<typeof controlAdapterProcessortTypeChanged>
| ReturnType<typeof controlAdapterAutoConfigToggled>;
const predicate: AnyListenerPredicate<RootState> = (
action,
@@ -26,37 +17,35 @@ const predicate: AnyListenerPredicate<RootState> = (
prevState
) => {
const isActionMatched =
controlAdapterProcessorParamsChanged.match(action) ||
controlAdapterModelChanged.match(action) ||
controlAdapterImageChanged.match(action) ||
controlAdapterProcessortTypeChanged.match(action) ||
controlAdapterAutoConfigToggled.match(action);
controlNetProcessorParamsChanged.match(action) ||
controlNetModelChanged.match(action) ||
controlNetImageChanged.match(action) ||
controlNetProcessorTypeChanged.match(action) ||
controlNetAutoConfigToggled.match(action);
if (!isActionMatched) {
return false;
}
const { id } = action.payload;
const prevCA = selectControlAdapterById(prevState.controlAdapters, id);
const ca = selectControlAdapterById(state.controlAdapters, id);
if (
!prevCA ||
!isControlNetOrT2IAdapter(prevCA) ||
!ca ||
!isControlNetOrT2IAdapter(ca)
) {
return false;
}
if (controlAdapterAutoConfigToggled.match(action)) {
if (controlNetAutoConfigToggled.match(action)) {
// do not process if the user just disabled auto-config
if (prevCA.shouldAutoConfig === true) {
if (
prevState.controlNet.controlNets[action.payload.controlNetId]
?.shouldAutoConfig === true
) {
return false;
}
}
const { controlImage, processorType, shouldAutoConfig } = ca;
if (controlAdapterModelChanged.match(action) && !shouldAutoConfig) {
const cn = state.controlNet.controlNets[action.payload.controlNetId];
if (!cn) {
// something is wrong, the controlNet should exist
return false;
}
const { controlImage, processorType, shouldAutoConfig } = cn;
if (controlNetModelChanged.match(action) && !shouldAutoConfig) {
// do not process if the action is a model change but the processor settings are dirty
return false;
}
@@ -78,7 +67,7 @@ export const addControlNetAutoProcessListener = () => {
predicate,
effect: async (action, { dispatch, cancelActiveListeners, delay }) => {
const log = logger('session');
const { id } = (action as AnyControlAdapterParamChangeAction).payload;
const { controlNetId } = action.payload;
// Cancel any in-progress instances of this listener
cancelActiveListeners();
@@ -86,7 +75,7 @@ export const addControlNetAutoProcessListener = () => {
// Delay before starting actual work
await delay(300);
dispatch(controlAdapterImageProcessed({ id }));
dispatch(controlNetImageProcessed({ controlNetId }));
},
});
};
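
Both the old and new versions of this auto-process listener debounce their work with the listener API's `cancelActiveListeners()` and `delay()`. A minimal sketch of that debounce idiom, with placeholder action and payload names:

```ts
import { createAction, createListenerMiddleware } from '@reduxjs/toolkit';

// Placeholder action; in the diff this is the set of processor/model/image change actions.
const processorParamsChanged = createAction<{ id: string }>('example/processorParamsChanged');

const listenerMiddleware = createListenerMiddleware();

listenerMiddleware.startListening({
  actionCreator: processorParamsChanged,
  effect: async (action, { dispatch, cancelActiveListeners, delay }) => {
    // Cancel any in-flight run of this effect so only the latest change proceeds.
    cancelActiveListeners();
    // Wait out a quiet period; if a newer matching action arrives, this run is
    // cancelled while it sits in the delay.
    await delay(300);
    dispatch({ type: 'example/imageProcessed', payload: { id: action.payload.id } });
  },
});
```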

View File

@@ -1,11 +1,11 @@
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { controlNetImageProcessed } from 'features/controlNet/store/actions';
import {
pendingControlImagesCleared,
controlAdapterImageChanged,
selectControlAdapterById,
controlAdapterProcessedImageChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
clearPendingControlImages,
controlNetImageChanged,
controlNetProcessedImageChanged,
} from 'features/controlNet/store/controlNetSlice';
import { SAVE_IMAGE } from 'features/nodes/util/graphBuilders/constants';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
@@ -15,34 +15,28 @@ import { isImageOutput } from 'services/api/guards';
import { Graph, ImageDTO } from 'services/api/types';
import { socketInvocationComplete } from 'services/events/actions';
import { startAppListening } from '..';
import { controlAdapterImageProcessed } from 'features/controlAdapters/store/actions';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
export const addControlNetImageProcessedListener = () => {
startAppListening({
actionCreator: controlAdapterImageProcessed,
actionCreator: controlNetImageProcessed,
effect: async (action, { dispatch, getState, take }) => {
const log = logger('session');
const { id } = action.payload;
const ca = selectControlAdapterById(getState().controlAdapters, id);
const { controlNetId } = action.payload;
const controlNet = getState().controlNet.controlNets[controlNetId];
if (!ca?.controlImage || !isControlNetOrT2IAdapter(ca)) {
if (!controlNet?.controlImage) {
log.error('Unable to process ControlNet image');
return;
}
if (ca.processorType === 'none' || ca.processorNode.type === 'none') {
return;
}
// ControlNet one-off procressing graph is just the processor node, no edges.
// Also we need to grab the image.
const graph: Graph = {
nodes: {
[ca.processorNode.id]: {
...ca.processorNode,
[controlNet.processorNode.id]: {
...controlNet.processorNode,
is_intermediate: true,
image: { image_name: ca.controlImage },
image: { image_name: controlNet.controlImage },
},
[SAVE_IMAGE]: {
id: SAVE_IMAGE,
@@ -54,7 +48,7 @@ export const addControlNetImageProcessedListener = () => {
edges: [
{
source: {
node_id: ca.processorNode.id,
node_id: controlNet.processorNode.id,
field: 'image',
},
destination: {
@@ -109,8 +103,8 @@ export const addControlNetImageProcessedListener = () => {
// Update the processed image in the store
dispatch(
controlAdapterProcessedImageChanged({
id,
controlNetProcessedImageChanged({
controlNetId,
processedControlImage: processedControlImage.image_name,
})
);
@@ -132,8 +126,10 @@ export const addControlNetImageProcessedListener = () => {
duration: 15000,
})
);
dispatch(pendingControlImagesCleared());
dispatch(controlAdapterImageChanged({ id, controlImage: null }));
dispatch(clearPendingControlImages());
dispatch(
controlNetImageChanged({ controlNetId, controlImage: null })
);
return;
}
}

View File

@@ -1,10 +1,10 @@
import { logger } from 'app/logging/logger';
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import {
controlAdapterImageChanged,
controlAdapterProcessedImageChanged,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
controlNetImageChanged,
controlNetProcessedImageChanged,
ipAdapterImageChanged,
} from 'features/controlNet/store/controlNetSlice';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
@@ -17,7 +17,6 @@ import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesAdapter } from 'services/api/util';
import { startAppListening } from '..';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
export const addRequestedSingleImageDeletionListener = () => {
startAppListening({
@@ -91,28 +90,35 @@ export const addRequestedSingleImageDeletionListener = () => {
dispatch(clearInitialImage());
}
// reset control adapters that use the deleted images
forEach(selectControlAdapterAll(getState().controlAdapters), (ca) => {
// reset controlNets that use the deleted images
forEach(getState().controlNet.controlNets, (controlNet) => {
if (
ca.controlImage === imageDTO.image_name ||
(isControlNetOrT2IAdapter(ca) &&
ca.processedControlImage === imageDTO.image_name)
controlNet.controlImage === imageDTO.image_name ||
controlNet.processedControlImage === imageDTO.image_name
) {
dispatch(
controlAdapterImageChanged({
id: ca.id,
controlNetImageChanged({
controlNetId: controlNet.controlNetId,
controlImage: null,
})
);
dispatch(
controlAdapterProcessedImageChanged({
id: ca.id,
controlNetProcessedImageChanged({
controlNetId: controlNet.controlNetId,
processedControlImage: null,
})
);
}
});
// Remove IP Adapter Set Image if image is deleted.
if (
getState().controlNet.ipAdapterInfo.adapterImage ===
imageDTO.image_name
) {
dispatch(ipAdapterImageChanged(null));
}
// reset nodes that use the deleted images
getState().nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
@@ -209,28 +215,35 @@ export const addRequestedMultipleImageDeletionListener = () => {
dispatch(clearInitialImage());
}
// reset control adapters that use the deleted images
forEach(selectControlAdapterAll(getState().controlAdapters), (ca) => {
// reset controlNets that use the deleted images
forEach(getState().controlNet.controlNets, (controlNet) => {
if (
ca.controlImage === imageDTO.image_name ||
(isControlNetOrT2IAdapter(ca) &&
ca.processedControlImage === imageDTO.image_name)
controlNet.controlImage === imageDTO.image_name ||
controlNet.processedControlImage === imageDTO.image_name
) {
dispatch(
controlAdapterImageChanged({
id: ca.id,
controlNetImageChanged({
controlNetId: controlNet.controlNetId,
controlImage: null,
})
);
dispatch(
controlAdapterProcessedImageChanged({
id: ca.id,
controlNetProcessedImageChanged({
controlNetId: controlNet.controlNetId,
processedControlImage: null,
})
);
}
});
// Remove IP Adapter Set Image if image is deleted.
if (
getState().controlNet.ipAdapterInfo.adapterImage ===
imageDTO.image_name
) {
dispatch(ipAdapterImageChanged(null));
}
// reset nodes that use the deleted images
getState().nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {

View File

@@ -3,9 +3,11 @@ import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
controlAdapterImageChanged,
controlAdapterIsEnabledChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
controlNetImageChanged,
controlNetIsEnabledChanged,
ipAdapterImageChanged,
isIPAdapterEnabledChanged,
} from 'features/controlNet/store/controlNetSlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
@@ -88,26 +90,39 @@ export const addImageDroppedListener = () => {
* Image dropped on ControlNet
*/
if (
overData.actionType === 'SET_CONTROL_ADAPTER_IMAGE' &&
overData.actionType === 'SET_CONTROLNET_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { id } = overData.context;
const { controlNetId } = overData.context;
dispatch(
controlAdapterImageChanged({
id,
controlNetImageChanged({
controlImage: activeData.payload.imageDTO.image_name,
controlNetId,
})
);
dispatch(
controlAdapterIsEnabledChanged({
id,
controlNetIsEnabledChanged({
controlNetId,
isEnabled: true,
})
);
return;
}
/**
* Image dropped on IP Adapter image
*/
if (
overData.actionType === 'SET_IP_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(ipAdapterImageChanged(activeData.payload.imageDTO.image_name));
dispatch(isIPAdapterEnabledChanged(true));
return;
}
/**
* Image dropped on Canvas
*/

View File

@@ -18,7 +18,8 @@ export const addImageToDeleteSelectedListener = () => {
const isImageInUse =
imagesUsage.some((i) => i.isCanvasImage) ||
imagesUsage.some((i) => i.isInitialImage) ||
imagesUsage.some((i) => i.isControlImage) ||
imagesUsage.some((i) => i.isControlNetImage) ||
imagesUsage.some((i) => i.isIPAdapterImage) ||
imagesUsage.some((i) => i.isNodesImage);
if (shouldConfirmOnDelete || isImageInUse) {

View File

@@ -2,9 +2,11 @@ import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import {
controlAdapterImageChanged,
controlAdapterIsEnabledChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
controlNetImageChanged,
controlNetIsEnabledChanged,
ipAdapterImageChanged,
isIPAdapterEnabledChanged,
} from 'features/controlNet/store/controlNetSlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
@@ -85,17 +87,17 @@ export const addImageUploadedFulfilledListener = () => {
return;
}
if (postUploadAction?.type === 'SET_CONTROL_ADAPTER_IMAGE') {
const { id } = postUploadAction;
if (postUploadAction?.type === 'SET_CONTROLNET_IMAGE') {
const { controlNetId } = postUploadAction;
dispatch(
controlAdapterIsEnabledChanged({
id,
controlNetIsEnabledChanged({
controlNetId,
isEnabled: true,
})
);
dispatch(
controlAdapterImageChanged({
id,
controlNetImageChanged({
controlNetId,
controlImage: imageDTO.image_name,
})
);
@@ -108,6 +110,18 @@ export const addImageUploadedFulfilledListener = () => {
return;
}
if (postUploadAction?.type === 'SET_IP_ADAPTER_IMAGE') {
dispatch(ipAdapterImageChanged(imageDTO.image_name));
dispatch(isIPAdapterEnabledChanged(true));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setIPAdapterImage'),
})
);
return;
}
if (postUploadAction?.type === 'SET_INITIAL_IMAGE') {
dispatch(initialImageChanged(imageDTO));
dispatch(

View File

@@ -1,9 +1,9 @@
import { logger } from 'app/logging/logger';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import {
controlAdapterIsEnabledChanged,
selectControlAdapterAll,
} from 'features/controlAdapters/store/controlAdaptersSlice';
controlNetRemoved,
ipAdapterStateReset,
} from 'features/controlNet/store/controlNetSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions';
import {
@@ -15,9 +15,9 @@ import {
import { zMainOrOnnxModel } from 'features/parameters/types/parameterSchemas';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
import { startAppListening } from '..';
import { t } from 'i18next';
export const addModelSelectedListener = () => {
startAppListening({
@@ -60,27 +60,33 @@ export const addModelSelectedListener = () => {
}
// handle incompatible controlnets
selectControlAdapterAll(state.controlAdapters).forEach((ca) => {
if (ca.model?.base_model !== base_model) {
dispatch(
controlAdapterIsEnabledChanged({ id: ca.id, isEnabled: false })
);
const { controlNets } = state.controlNet;
forEach(controlNets, (controlNet, controlNetId) => {
if (controlNet.model?.base_model !== base_model) {
dispatch(controlNetRemoved({ controlNetId }));
modelsCleared += 1;
}
});
// handle incompatible IP-Adapter
const { ipAdapterInfo } = state.controlNet;
if (
ipAdapterInfo.model &&
ipAdapterInfo.model.base_model !== base_model
) {
dispatch(ipAdapterStateReset());
modelsCleared += 1;
}
if (modelsCleared > 0) {
dispatch(
addToast(
makeToast({
title: t(
modelsCleared === 1
? 'toast.baseModelChangedCleared_one'
: 'toast.baseModelChangedCleared_many',
{
number: modelsCleared,
}
),
title: `${t(
'toast.baseModelChangedCleared'
)} ${modelsCleared} ${t('toast.incompatibleSubmodel')}${
modelsCleared === 1 ? '' : 's'
}`,
status: 'warning',
})
)

View File

@@ -1,10 +1,8 @@
import { logger } from 'app/logging/logger';
import {
controlAdapterModelCleared,
selectAllControlNets,
selectAllIPAdapters,
selectAllT2IAdapters,
} from 'features/controlAdapters/store/controlAdaptersSlice';
controlNetRemoved,
ipAdapterModelChanged,
} from 'features/controlNet/store/controlNetSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import {
modelChanged,
@@ -21,12 +19,14 @@ import {
} from 'features/sdxl/store/sdxlSlice';
import { forEach, some } from 'lodash-es';
import {
ipAdapterModelsAdapter,
mainModelsAdapter,
modelsApi,
vaeModelsAdapter,
} from 'services/api/endpoints/models';
import { TypeGuardFor } from 'services/api/types';
import { startAppListening } from '..';
import { zIPAdapterModel } from 'features/nodes/types/types';
export const addModelsLoadedListener = () => {
startAppListening({
@@ -221,45 +221,21 @@ export const addModelsLoadedListener = () => {
`ControlNet models loaded (${action.payload.ids.length})`
);
selectAllControlNets(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
const controlNets = getState().controlNet.controlNets;
forEach(controlNets, (controlNet, controlNetId) => {
const isControlNetAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === ca?.model?.model_name &&
m?.base_model === ca?.model?.base_model
m?.model_name === controlNet?.model?.model_name &&
m?.base_model === controlNet?.model?.base_model
);
if (isModelAvailable) {
if (isControlNetAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
},
});
startAppListening({
matcher: modelsApi.endpoints.getT2IAdapterModels.matchFulfilled,
effect: async (action, { getState, dispatch }) => {
// ControlNet models loaded - need to remove missing ControlNets from state
const log = logger('models');
log.info(
{ models: action.payload.entities },
`ControlNet models loaded (${action.payload.ids.length})`
);
selectAllT2IAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === ca?.model?.model_name &&
m?.base_model === ca?.model?.base_model
);
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
dispatch(controlNetRemoved({ controlNetId }));
});
},
});
@@ -273,20 +249,38 @@ export const addModelsLoadedListener = () => {
`IP Adapter models loaded (${action.payload.ids.length})`
);
selectAllIPAdapters(getState().controlAdapters).forEach((ca) => {
const isModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === ca?.model?.model_name &&
m?.base_model === ca?.model?.base_model
const { model } = getState().controlNet.ipAdapterInfo;
const isModelAvailable = some(
action.payload.entities,
(m) =>
m?.model_name === model?.model_name &&
m?.base_model === model?.base_model
);
if (isModelAvailable) {
return;
}
const firstModel = ipAdapterModelsAdapter
.getSelectors()
.selectAll(action.payload)[0];
if (!firstModel) {
dispatch(ipAdapterModelChanged(null));
}
const result = zIPAdapterModel.safeParse(firstModel);
if (!result.success) {
log.error(
{ error: result.error.format() },
'Failed to parse IP Adapter model'
);
return;
}
if (isModelAvailable) {
return;
}
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
dispatch(ipAdapterModelChanged(result.data));
},
});
startAppListening({

View File

@@ -8,7 +8,6 @@ import {
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { CANVAS_OUTPUT } from 'features/nodes/util/graphBuilders/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { isImageOutput } from 'services/api/guards';
import { imagesAdapter } from 'services/api/util';
@@ -71,21 +70,11 @@ export const addInvocationCompleteEventListener = () => {
)
);
// update the total images for the board
dispatch(
boardsApi.util.updateQueryData(
'getBoardImagesTotal',
imageDTO.board_id ?? 'none',
(draft) => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
draft.total += 1;
}
)
);
dispatch(
imagesApi.util.invalidateTags([
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
{ type: 'BoardImagesTotal', id: imageDTO.board_id },
{ type: 'BoardAssetsTotal', id: imageDTO.board_id },
{ type: 'Board', id: imageDTO.board_id },
])
);

View File

@@ -11,70 +11,44 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
actionCreator: socketQueueItemStatusChanged,
effect: async (action, { dispatch }) => {
const log = logger('socketio');
// we've got new status for the queue item, batch and queue
const { queue_item, batch_status, queue_status } = action.payload.data;
const {
queue_item_id: item_id,
queue_batch_id,
status,
} = action.payload.data;
log.debug(
action.payload,
`Queue item ${queue_item.item_id} status updated: ${queue_item.status}`
`Queue item ${item_id} status updated: ${status}`
);
dispatch(appSocketQueueItemStatusChanged(action.payload));
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
dispatch(
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
queueItemsAdapter.updateOne(draft, {
id: queue_item.item_id,
changes: queue_item,
id: item_id,
changes: action.payload.data,
});
})
);
// Update the queue status (we do not get the processor status here)
dispatch(
queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => {
if (!draft) {
return;
}
Object.assign(draft.queue, queue_status);
})
);
// Update the batch status
dispatch(
queueApi.util.updateQueryData(
'getBatchStatus',
{ batch_id: batch_status.batch_id },
() => batch_status
)
);
// Update the queue item status (this is the full queue item, including the session)
dispatch(
queueApi.util.updateQueryData(
'getQueueItem',
queue_item.item_id,
(draft) => {
if (!draft) {
return;
}
Object.assign(draft, queue_item);
}
)
);
// Invalidate caches for things we cannot update
// TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again
dispatch(
queueApi.util.invalidateTags([
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'InvocationCacheStatus',
{ type: 'SessionQueueItem', id: item_id },
{ type: 'SessionQueueItemDTO', id: item_id },
{ type: 'BatchStatus', id: queue_batch_id },
])
);
// Pass the event along
dispatch(appSocketQueueItemStatusChanged(action.payload));
const req = dispatch(
queueApi.endpoints.getQueueStatus.initiate(undefined, {
forceRefetch: true,
})
);
await req.unwrap();
req.unsubscribe();
},
});
};
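
Both sides of this hunk use the two RTK Query cache tools visible above: `util.updateQueryData` to patch a cached result in place and `util.invalidateTags` to force a refetch. A self-contained sketch of the difference against a made-up api slice (the real `queueApi` endpoints and tag types differ):

```ts
import { createApi, fakeBaseQuery } from '@reduxjs/toolkit/query';

// Minimal stand-in for the app's queue api slice.
const queueApi = createApi({
  reducerPath: 'queueApi',
  baseQuery: fakeBaseQuery(),
  tagTypes: ['CurrentSessionQueueItem'],
  endpoints: (build) => ({
    getQueueStatus: build.query<{ queue: { pending: number } }, void>({
      queryFn: () => ({ data: { queue: { pending: 0 } } }),
    }),
  }),
});

// updateQueryData patches an existing cache entry in place, with no network request.
const patchStatus = queueApi.util.updateQueryData('getQueueStatus', undefined, (draft) => {
  draft.queue.pending += 1;
});

// invalidateTags marks matching cache entries stale so subscribed components refetch them.
const refetchItem = queueApi.util.invalidateTags(['CurrentSessionQueueItem']);

// Both return thunks/actions meant for store.dispatch(patchStatus) / store.dispatch(refetchItem).
```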

View File

@@ -7,7 +7,7 @@ import {
} from '@reduxjs/toolkit';
import canvasReducer from 'features/canvas/store/canvasSlice';
import changeBoardModalReducer from 'features/changeBoardModal/store/slice';
import controlAdaptersReducer from 'features/controlAdapters/store/controlAdaptersSlice';
import controlNetReducer from 'features/controlNet/store/controlNetSlice';
import deleteImageModalReducer from 'features/deleteImageModal/store/slice';
import dynamicPromptsReducer from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import galleryReducer from 'features/gallery/store/gallerySlice';
@@ -44,7 +44,7 @@ const allReducers = {
config: configReducer,
ui: uiReducer,
hotkeys: hotkeysReducer,
controlAdapters: controlAdaptersReducer,
controlNet: controlNetReducer,
dynamicPrompts: dynamicPromptsReducer,
deleteImageModal: deleteImageModalReducer,
changeBoardModal: changeBoardModalReducer,
@@ -68,7 +68,7 @@ const rememberedKeys: (keyof typeof allReducers)[] = [
'postprocessing',
'system',
'ui',
'controlAdapters',
'controlNet',
'dynamicPrompts',
'lora',
'modelmanager',

View File

@@ -1,4 +1,4 @@
import { CONTROLNET_PROCESSORS } from 'features/controlAdapters/store/constants';
import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants';
import { InvokeTabName } from 'features/ui/store/tabMap';
import { O } from 'ts-toolbelt';

View File

@@ -15,7 +15,7 @@ type UseImageUploadButtonArgs = {
* @example
* const { getUploadButtonProps, getUploadInputProps, openUploader } = useImageUploadButton({
* postUploadAction: {
* type: 'SET_CONTROL_ADAPTER_IMAGE',
* type: 'SET_CONTROLNET_IMAGE',
* controlNetId: '12345',
* },
* isDisabled: getIsUploadDisabled(),

View File

@@ -2,18 +2,16 @@ import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { selectControlAdapterAll } from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { isInvocationNode } from 'features/nodes/types/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import i18n from 'i18next';
import { forEach } from 'lodash-es';
import { forEach, map } from 'lodash-es';
import { getConnectedEdges } from 'reactflow';
const selector = createSelector(
[stateSelector, activeTabNameSelector],
(
{ controlAdapters, generation, system, nodes, dynamicPrompts },
{ controlNet, generation, system, nodes, dynamicPrompts },
activeTabName
) => {
const { initialImage, model } = generation;
@@ -89,39 +87,30 @@ const selector = createSelector(
reasons.push(i18n.t('parameters.invoke.noModelSelected'));
}
selectControlAdapterAll(controlAdapters).forEach((ca, i) => {
if (!ca.isEnabled) {
return;
}
if (controlNet.isEnabled) {
map(controlNet.controlNets).forEach((controlNet, i) => {
if (!controlNet.isEnabled) {
return;
}
if (!controlNet.model) {
reasons.push(
i18n.t('parameters.invoke.noModelForControlNet', { index: i + 1 })
);
}
if (!ca.model) {
reasons.push(
i18n.t('parameters.invoke.noModelForControlAdapter', {
number: i + 1,
})
);
} else if (ca.model.base_model !== model?.base_model) {
// This should never happen, just a sanity check
reasons.push(
i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', {
number: i + 1,
})
);
}
if (
!ca.controlImage ||
(isControlNetOrT2IAdapter(ca) &&
!ca.processedControlImage &&
ca.processorType !== 'none')
) {
reasons.push(
i18n.t('parameters.invoke.noControlImageForControlAdapter', {
number: i + 1,
})
);
}
});
if (
!controlNet.controlImage ||
(!controlNet.processedControlImage &&
controlNet.processorType !== 'none')
) {
reasons.push(
i18n.t('parameters.invoke.noControlImageForControlNet', {
index: i + 1,
})
);
}
});
}
}
return { isReady: !reasons.length, reasons };

View File

@@ -1,4 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import { ControlNetConfig } from 'features/controlNet/store/controlNetSlice';
import { ImageDTO } from 'services/api/types';
export const canvasSavedToGallery = createAction('canvas/canvasSavedToGallery');
@@ -21,10 +22,10 @@ export const stagingAreaImageSaved = createAction<{ imageDTO: ImageDTO }>(
'canvas/stagingAreaImageSaved'
);
export const canvasMaskToControlAdapter = createAction<{ id: string }>(
'canvas/canvasMaskToControlAdapter'
);
export const canvasMaskToControlNet = createAction<{
controlNet: ControlNetConfig;
}>('canvas/canvasMaskToControlNet');
export const canvasImageToControlAdapter = createAction<{ id: string }>(
'canvas/canvasImageToControlAdapter'
);
export const canvasImageToControlNet = createAction<{
controlNet: ControlNetConfig;
}>('canvas/canvasImageToControlNet');

View File

@@ -29,7 +29,6 @@ import {
isCanvasBaseImage,
isCanvasMaskLine,
} from './canvasTypes';
import { appSocketQueueItemStatusChanged } from 'services/events/actions';
export const initialLayerState: CanvasLayerState = {
objects: [],
@@ -787,18 +786,6 @@ export const canvasSlice = createSlice({
},
},
extraReducers: (builder) => {
builder.addCase(appSocketQueueItemStatusChanged, (state, action) => {
const batch_status = action.payload.data.batch_status;
if (!state.batchIds.includes(batch_status.batch_id)) {
return;
}
if (batch_status.in_progress === 0 && batch_status.pending === 0) {
state.batchIds = state.batchIds.filter(
(id) => id !== batch_status.batch_id
);
}
});
builder.addCase(setAspectRatio, (state, action) => {
const ratio = action.payload;
if (ratio) {

View File

@@ -1,39 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAISwitch from 'common/components/IAISwitch';
import { controlAdapterAutoConfigToggled } from 'features/controlAdapters/store/controlAdaptersSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useControlAdapterIsEnabled } from '../hooks/useControlAdapterIsEnabled';
import { useControlAdapterShouldAutoConfig } from '../hooks/useControlAdapterShouldAutoConfig';
import { isNil } from 'lodash-es';
type Props = {
id: string;
};
const ControlAdapterShouldAutoConfig = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const shouldAutoConfig = useControlAdapterShouldAutoConfig(id);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleShouldAutoConfigChanged = useCallback(() => {
dispatch(controlAdapterAutoConfigToggled({ id }));
}, [id, dispatch]);
if (isNil(shouldAutoConfig)) {
return null;
}
return (
<IAISwitch
label={t('controlnet.autoConfigure')}
aria-label={t('controlnet.autoConfigure')}
isChecked={shouldAutoConfig}
onChange={handleShouldAutoConfigChanged}
isDisabled={!isEnabled}
/>
);
};
export default memo(ControlAdapterShouldAutoConfig);

View File

@@ -1,114 +0,0 @@
import { ButtonGroup, Divider, Flex } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAICollapse from 'common/components/IAICollapse';
import ControlAdapterConfig from 'features/controlAdapters/components/ControlAdapterConfig';
import {
selectAllControlNets,
selectAllIPAdapters,
selectAllT2IAdapters,
selectControlAdapterIds,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { Fragment, memo } from 'react';
import { FaPlus } from 'react-icons/fa';
import { useAddControlAdapter } from '../hooks/useAddControlAdapter';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
[stateSelector],
({ controlAdapters }) => {
const activeLabel: string[] = [];
const ipAdapterCount = selectAllIPAdapters(controlAdapters).length;
if (ipAdapterCount > 0) {
activeLabel.push(`${ipAdapterCount} IP`);
}
const controlNetCount = selectAllControlNets(controlAdapters).length;
if (controlNetCount > 0) {
activeLabel.push(`${controlNetCount} ControlNet`);
}
const t2iAdapterCount = selectAllT2IAdapters(controlAdapters).length;
if (t2iAdapterCount > 0) {
activeLabel.push(`${t2iAdapterCount} T2I`);
}
const controlAdapterIds =
selectControlAdapterIds(controlAdapters).map(String);
return {
controlAdapterIds,
activeLabel: activeLabel.join(', '),
};
},
defaultSelectorOptions
);
const ControlAdaptersCollapse = () => {
const { t } = useTranslation();
const { controlAdapterIds, activeLabel } = useAppSelector(selector);
const isControlNetDisabled = useFeatureStatus('controlNet').isFeatureDisabled;
const [addControlNet, isAddControlNetDisabled] =
useAddControlAdapter('controlnet');
const [addIPAdapter, isAddIPAdapterDisabled] =
useAddControlAdapter('ip_adapter');
const [addT2IAdapter, isAddT2IAdapterDisabled] =
useAddControlAdapter('t2i_adapter');
if (isControlNetDisabled) {
return null;
}
return (
<IAICollapse label="Control Adapters" activeLabel={activeLabel}>
<Flex sx={{ flexDir: 'column', gap: 2 }}>
<ButtonGroup size="sm" w="full" justifyContent="space-between">
<IAIButton
tooltip={t('controlnet.addControlNet')}
leftIcon={<FaPlus />}
onClick={addControlNet}
data-testid="add controlnet"
flexGrow={1}
isDisabled={isAddControlNetDisabled}
>
{t('common.controlNet')}
</IAIButton>
<IAIButton
tooltip={t('controlnet.addIPAdapter')}
leftIcon={<FaPlus />}
onClick={addIPAdapter}
data-testid="add ip adapter"
flexGrow={1}
isDisabled={isAddIPAdapterDisabled}
>
{t('common.ipAdapter')}
</IAIButton>
<IAIButton
tooltip={t('controlnet.addT2IAdapter')}
leftIcon={<FaPlus />}
onClick={addT2IAdapter}
data-testid="add t2i adapter"
flexGrow={1}
isDisabled={isAddT2IAdapterDisabled}
>
{t('common.t2iAdapter')}
</IAIButton>
</ButtonGroup>
{controlAdapterIds.map((id, i) => (
<Fragment key={id}>
<Divider />
<ControlAdapterConfig id={id} number={i + 1} />
</Fragment>
))}
</Flex>
</IAICollapse>
);
};
export default memo(ControlAdaptersCollapse);

View File

@@ -1,20 +0,0 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { controlAdapterProcessorParamsChanged } from 'features/controlAdapters/store/controlAdaptersSlice';
import { ControlAdapterProcessorNode } from 'features/controlAdapters/store/types';
import { useCallback } from 'react';
export const useProcessorNodeChanged = () => {
const dispatch = useAppDispatch();
const handleProcessorNodeChanged = useCallback(
(id: string, params: Partial<ControlAdapterProcessorNode>) => {
dispatch(
controlAdapterProcessorParamsChanged({
id,
params,
})
);
},
[dispatch]
);
return handleProcessorNodeChanged;
};

View File

@@ -1,43 +0,0 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { controlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice';
import { useCallback, useMemo } from 'react';
import { ControlAdapterType } from '../store/types';
import { useControlAdapterModels } from './useControlAdapterModels';
export const useAddControlAdapter = (type: ControlAdapterType) => {
const baseModel = useAppSelector(
(state) => state.generation.model?.base_model
);
const dispatch = useAppDispatch();
const models = useControlAdapterModels(type);
const firstModel = useMemo(() => {
// prefer to use a model that matches the base model
const firstCompatibleModel = models.filter((m) =>
baseModel ? m.base_model === baseModel : true
)[0];
if (firstCompatibleModel) {
return firstCompatibleModel;
}
return models[0];
}, [baseModel, models]);
const isDisabled = useMemo(() => !firstModel, [firstModel]);
const addControlAdapter = useCallback(() => {
if (isDisabled) {
return;
}
dispatch(
controlAdapterAdded({
type,
overrides: { model: firstModel },
})
);
}, [dispatch, firstModel, isDisabled, type]);
return [addControlAdapter, isDisabled] as const;
};

View File

@@ -1,22 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapter = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => selectControlAdapterById(controlAdapters, id),
defaultSelectorOptions
),
[id]
);
const controlAdapter = useAppSelector(selector);
return controlAdapter;
};

View File

@@ -1,30 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterBeginEndStepPct = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const cn = selectControlAdapterById(controlAdapters, id);
return cn
? {
beginStepPct: cn.beginStepPct,
endStepPct: cn.endStepPct,
}
: undefined;
},
defaultSelectorOptions
),
[id]
);
const stepPcts = useAppSelector(selector);
return stepPcts;
};

View File

@@ -1,23 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterControlImage = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.controlImage,
defaultSelectorOptions
),
[id]
);
const weight = useAppSelector(selector);
return weight;
};

View File

@@ -1,29 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isControlNet } from '../store/types';
export const useControlAdapterControlMode = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
if (ca && isControlNet(ca)) {
return ca.controlMode;
}
return undefined;
},
defaultSelectorOptions
),
[id]
);
const controlMode = useAppSelector(selector);
return controlMode;
};

View File

@@ -1,23 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterIsEnabled = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.isEnabled ?? false,
defaultSelectorOptions
),
[id]
);
const isEnabled = useAppSelector(selector);
return isEnabled;
};

View File

@@ -1,23 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterModel = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.model,
defaultSelectorOptions
),
[id]
);
const model = useAppSelector(selector);
return model;
};

View File

@@ -1,49 +0,0 @@
import { useMemo } from 'react';
import {
controlNetModelsAdapter,
ipAdapterModelsAdapter,
t2iAdapterModelsAdapter,
useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
} from 'services/api/endpoints/models';
import { ControlAdapterType } from '../store/types';
export const useControlAdapterModels = (type?: ControlAdapterType) => {
const { data: controlNetModelsData } = useGetControlNetModelsQuery();
const controlNetModels = useMemo(
() =>
controlNetModelsData
? controlNetModelsAdapter.getSelectors().selectAll(controlNetModelsData)
: [],
[controlNetModelsData]
);
const { data: t2iAdapterModelsData } = useGetT2IAdapterModelsQuery();
const t2iAdapterModels = useMemo(
() =>
t2iAdapterModelsData
? t2iAdapterModelsAdapter.getSelectors().selectAll(t2iAdapterModelsData)
: [],
[t2iAdapterModelsData]
);
const { data: ipAdapterModelsData } = useGetIPAdapterModelsQuery();
const ipAdapterModels = useMemo(
() =>
ipAdapterModelsData
? ipAdapterModelsAdapter.getSelectors().selectAll(ipAdapterModelsData)
: [],
[ipAdapterModelsData]
);
if (type === 'controlnet') {
return controlNetModels;
}
if (type === 't2i_adapter') {
return t2iAdapterModels;
}
if (type === 'ip_adapter') {
return ipAdapterModels;
}
return [];
};
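
The deleted `useControlAdapterModels` hook reads RTK Query results through entity-adapter selectors. A small sketch of `createEntityAdapter` with `getSelectors().selectAll`, using made-up model records:

```ts
import { createEntityAdapter } from '@reduxjs/toolkit';

// Shape loosely based on the model records in the diff; the fields here are illustrative.
type ControlNetModelConfig = { id: string; model_name: string; base_model: string };

const controlNetModelsAdapter = createEntityAdapter<ControlNetModelConfig>();

// RTK Query stores these lists in the normalized { ids, entities } form.
const data = controlNetModelsAdapter.setAll(controlNetModelsAdapter.getInitialState(), [
  { id: 'sd-1/controlnet/canny', model_name: 'canny', base_model: 'sd-1' },
  { id: 'sdxl/controlnet/depth', model_name: 'depth', base_model: 'sdxl' },
]);

// selectAll flattens the normalized state back into an array, which is what the hook memoized.
const models = controlNetModelsAdapter.getSelectors().selectAll(data);
console.log(models.map((m) => m.model_name)); // ['canny', 'depth']
```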

View File

@@ -1,29 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isControlNetOrT2IAdapter } from '../store/types';
export const useControlAdapterProcessedControlImage = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
return ca && isControlNetOrT2IAdapter(ca)
? ca.processedControlImage
: undefined;
},
defaultSelectorOptions
),
[id]
);
const weight = useAppSelector(selector);
return weight;
};

View File

@@ -1,29 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isControlNetOrT2IAdapter } from '../store/types';
export const useControlAdapterProcessorNode = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
return ca && isControlNetOrT2IAdapter(ca)
? ca.processorNode
: undefined;
},
defaultSelectorOptions
),
[id]
);
const processorNode = useAppSelector(selector);
return processorNode;
};

View File

@@ -1,29 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isControlNetOrT2IAdapter } from '../store/types';
export const useControlAdapterProcessorType = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
return ca && isControlNetOrT2IAdapter(ca)
? ca.processorType
: undefined;
},
defaultSelectorOptions
),
[id]
);
const processorType = useAppSelector(selector);
return processorType;
};
@@ -1,29 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from '../store/types';
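// Returns the resize mode for a ControlNet or T2I-Adapter; IP-Adapters have no resize mode, so undefined is returned.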
export const useControlAdapterResizeMode = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
if (ca && isControlNetOrT2IAdapter(ca)) {
return ca.resizeMode;
}
return undefined;
},
defaultSelectorOptions
),
[id]
);
const resizeMode = useAppSelector(selector);
return resizeMode;
};
@@ -1,29 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from '../store/types';
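// Returns whether processor auto-configuration is enabled for a ControlNet or T2I-Adapter; undefined for other adapter types.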
export const useControlAdapterShouldAutoConfig = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) => {
const ca = selectControlAdapterById(controlAdapters, id);
if (ca && isControlNetOrT2IAdapter(ca)) {
return ca.shouldAutoConfig;
}
return undefined;
},
defaultSelectorOptions
),
[id]
);
const shouldAutoConfig = useAppSelector(selector);
return shouldAutoConfig;
};
@@ -1,23 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterType = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.type,
defaultSelectorOptions
),
[id]
);
const type = useAppSelector(selector);
return type;
};
@@ -1,23 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useMemo } from 'react';
import { selectControlAdapterById } from '../store/controlAdaptersSlice';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
export const useControlAdapterWeight = (id: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ controlAdapters }) =>
selectControlAdapterById(controlAdapters, id)?.weight,
defaultSelectorOptions
),
[id]
);
const weight = useAppSelector(selector);
return weight;
};
@@ -1,5 +0,0 @@
import { createAction } from '@reduxjs/toolkit';
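// Dispatched to request preprocessing of a control adapter's control image.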
export const controlAdapterImageProcessed = createAction<{
id: string;
}>('controlAdapters/imageProcessed');
@@ -1,8 +0,0 @@
import { ControlAdaptersState } from './types';
/**
* ControlNet slice persist denylist
*/
export const controlAdaptersPersistDenylist: (keyof ControlAdaptersState)[] = [
'pendingControlImages',
];
@@ -1,546 +0,0 @@
import {
PayloadAction,
Update,
createEntityAdapter,
createSlice,
} from '@reduxjs/toolkit';
import {
ControlNetModelParam,
IPAdapterModelParam,
T2IAdapterModelParam,
} from 'features/parameters/types/parameterSchemas';
import { cloneDeep, merge, uniq } from 'lodash-es';
import { appSocketInvocationError } from 'services/events/actions';
import { v4 as uuidv4 } from 'uuid';
import { buildControlAdapter } from '../util/buildControlAdapter';
import { controlAdapterImageProcessed } from './actions';
import {
CONTROLNET_MODEL_DEFAULT_PROCESSORS as CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS,
CONTROLNET_PROCESSORS,
} from './constants';
import {
ControlAdapterConfig,
ControlAdapterProcessorType,
ControlAdapterType,
ControlAdaptersState,
ControlMode,
ControlNetConfig,
RequiredControlAdapterProcessorNode,
ResizeMode,
T2IAdapterConfig,
isControlNet,
isControlNetOrT2IAdapter,
isIPAdapter,
isT2IAdapter,
} from './types';
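// Normalized entity state for all control adapters (ControlNet, T2I-Adapter and IP-Adapter configs), keyed by id.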
export const caAdapter = createEntityAdapter<ControlAdapterConfig>();
export const {
selectById: selectControlAdapterById,
selectAll: selectControlAdapterAll,
selectEntities: selectControlAdapterEntities,
selectIds: selectControlAdapterIds,
selectTotal: selectControlAdapterTotal,
} = caAdapter.getSelectors();
export const initialControlAdapterState: ControlAdaptersState =
caAdapter.getInitialState<{
pendingControlImages: string[];
}>({
pendingControlImages: [],
});
export const selectAllControlNets = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters).filter(isControlNet);
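// A ControlNet is valid for generation when it is enabled, has a model, and has either a processed image or a raw control image with the processor set to 'none'.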
export const selectValidControlNets = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters)
.filter(isControlNet)
.filter(
(ca) =>
ca.isEnabled &&
ca.model &&
(Boolean(ca.processedControlImage) ||
(ca.processorType === 'none' && Boolean(ca.controlImage)))
);
export const selectAllIPAdapters = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters).filter(isIPAdapter);
export const selectValidIPAdapters = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters)
.filter(isIPAdapter)
.filter((ca) => ca.isEnabled && ca.model && Boolean(ca.controlImage));
export const selectAllT2IAdapters = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters).filter(isT2IAdapter);
export const selectValidT2IAdapters = (controlAdapters: ControlAdaptersState) =>
selectControlAdapterAll(controlAdapters)
.filter(isT2IAdapter)
.filter(
(ca) =>
ca.isEnabled &&
ca.model &&
(Boolean(ca.processedControlImage) ||
(ca.processorType === 'none' && Boolean(ca.controlImage)))
);
// TODO: I think we can safely remove this?
// const disableAllIPAdapters = (
// state: ControlAdaptersState,
// exclude?: string
// ) => {
// const updates: Update<ControlAdapterConfig>[] = selectAllIPAdapters(state)
// .filter((ca) => ca.id !== exclude)
// .map((ca) => ({
// id: ca.id,
// changes: { isEnabled: false },
// }));
// caAdapter.updateMany(state, updates);
// };
const disableAllControlNets = (
state: ControlAdaptersState,
exclude?: string
) => {
const updates: Update<ControlAdapterConfig>[] = selectAllControlNets(state)
.filter((ca) => ca.id !== exclude)
.map((ca) => ({
id: ca.id,
changes: { isEnabled: false },
}));
caAdapter.updateMany(state, updates);
};
const disableAllT2IAdapters = (
state: ControlAdaptersState,
exclude?: string
) => {
const updates: Update<ControlAdapterConfig>[] = selectAllT2IAdapters(state)
.filter((ca) => ca.id !== exclude)
.map((ca) => ({
id: ca.id,
changes: { isEnabled: false },
}));
caAdapter.updateMany(state, updates);
};
const disableIncompatibleControlAdapters = (
state: ControlAdaptersState,
type: ControlAdapterType,
exclude?: string
) => {
if (type === 'controlnet') {
// ControlNet and T2I-Adapter cannot be used together; when enabling a ControlNet, disable all T2I-Adapters
disableAllT2IAdapters(state, exclude);
}
if (type === 't2i_adapter') {
// ControlNet and T2I-Adapter cannot be used together; when enabling a T2I-Adapter, disable all ControlNets
disableAllControlNets(state, exclude);
}
};
export const controlAdaptersSlice = createSlice({
name: 'controlAdapters',
initialState: initialControlAdapterState,
reducers: {
controlAdapterAdded: {
reducer: (
state,
action: PayloadAction<{
id: string;
type: ControlAdapterType;
overrides?: Partial<ControlAdapterConfig>;
}>
) => {
const { id, type, overrides } = action.payload;
caAdapter.addOne(state, buildControlAdapter(id, type, overrides));
disableIncompatibleControlAdapters(state, type, id);
},
prepare: ({
type,
overrides,
}: {
type: ControlAdapterType;
overrides?: Partial<ControlAdapterConfig>;
}) => {
return { payload: { id: uuidv4(), type, overrides } };
},
},
controlAdapterRecalled: (
state,
action: PayloadAction<ControlAdapterConfig>
) => {
caAdapter.addOne(state, action.payload);
const { type, id } = action.payload;
disableIncompatibleControlAdapters(state, type, id);
},
controlAdapterDuplicated: {
reducer: (
state,
action: PayloadAction<{
id: string;
newId: string;
}>
) => {
const { id, newId } = action.payload;
const controlAdapter = selectControlAdapterById(state, id);
if (!controlAdapter) {
return;
}
const newControlAdapter = merge(cloneDeep(controlAdapter), {
id: newId,
isEnabled: true,
});
caAdapter.addOne(state, newControlAdapter);
const { type } = newControlAdapter;
disableIncompatibleControlAdapters(state, type, newId);
},
prepare: (id: string) => {
return { payload: { id, newId: uuidv4() } };
},
},
controlAdapterAddedFromImage: {
reducer: (
state,
action: PayloadAction<{
id: string;
type: ControlAdapterType;
controlImage: string;
}>
) => {
const { id, type, controlImage } = action.payload;
caAdapter.addOne(
state,
buildControlAdapter(id, type, { controlImage })
);
disableIncompatibleControlAdapters(state, type, id);
},
prepare: (payload: {
type: ControlAdapterType;
controlImage: string;
}) => {
return { payload: { ...payload, id: uuidv4() } };
},
},
controlAdapterRemoved: (state, action: PayloadAction<{ id: string }>) => {
caAdapter.removeOne(state, action.payload.id);
},
controlAdapterIsEnabledChanged: (
state,
action: PayloadAction<{ id: string; isEnabled: boolean }>
) => {
const { id, isEnabled } = action.payload;
caAdapter.updateOne(state, { id, changes: { isEnabled } });
if (isEnabled) {
// we are enabling a control adapter. due to limitations in the current system, we may need to disable other adapters
// TODO: disable when multiple IP adapters are supported
const ca = selectControlAdapterById(state, id);
ca && disableIncompatibleControlAdapters(state, ca.type, id);
}
},
controlAdapterImageChanged: (
state,
action: PayloadAction<{
id: string;
controlImage: string | null;
}>
) => {
const { id, controlImage } = action.payload;
const ca = selectControlAdapterById(state, id);
if (!ca) {
return;
}
caAdapter.updateOne(state, {
id,
changes: { controlImage, processedControlImage: null },
});
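// If the adapter has a processor configured, mark the new image as pending preprocessing.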
if (
controlImage !== null &&
isControlNetOrT2IAdapter(ca) &&
ca.processorType !== 'none'
) {
state.pendingControlImages.push(id);
}
},
controlAdapterProcessedImageChanged: (
state,
action: PayloadAction<{
id: string;
processedControlImage: string | null;
}>
) => {
const { id, processedControlImage } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn) {
return;
}
if (!isControlNetOrT2IAdapter(cn)) {
return;
}
caAdapter.updateOne(state, {
id,
changes: {
processedControlImage,
},
});
state.pendingControlImages = state.pendingControlImages.filter(
(pendingId) => pendingId !== id
);
},
controlAdapterModelCleared: (
state,
action: PayloadAction<{ id: string }>
) => {
caAdapter.updateOne(state, {
id: action.payload.id,
changes: { model: null },
});
},
controlAdapterModelChanged: (
state,
action: PayloadAction<{
id: string;
model:
| ControlNetModelParam
| T2IAdapterModelParam
| IPAdapterModelParam;
}>
) => {
const { id, model } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn) {
return;
}
if (!isControlNetOrT2IAdapter(cn)) {
caAdapter.updateOne(state, { id, changes: { model } });
return;
}
const update: Update<ControlNetConfig | T2IAdapterConfig> = {
id,
changes: { model },
};
update.changes.processedControlImage = null;
if (cn.shouldAutoConfig) {
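// Derive the default processor by matching known substrings against the new model's name, falling back to 'none'.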
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
if (model.model_name.includes(modelSubstring)) {
processorType =
CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
update.changes.processorType = processorType;
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlAdapterProcessorNode;
} else {
update.changes.processorType = 'none';
update.changes.processorNode = CONTROLNET_PROCESSORS.none
.default as RequiredControlAdapterProcessorNode;
}
}
caAdapter.updateOne(state, update);
},
controlAdapterWeightChanged: (
state,
action: PayloadAction<{ id: string; weight: number }>
) => {
const { id, weight } = action.payload;
caAdapter.updateOne(state, { id, changes: { weight } });
},
controlAdapterBeginStepPctChanged: (
state,
action: PayloadAction<{ id: string; beginStepPct: number }>
) => {
const { id, beginStepPct } = action.payload;
caAdapter.updateOne(state, { id, changes: { beginStepPct } });
},
controlAdapterEndStepPctChanged: (
state,
action: PayloadAction<{ id: string; endStepPct: number }>
) => {
const { id, endStepPct } = action.payload;
caAdapter.updateOne(state, { id, changes: { endStepPct } });
},
controlAdapterControlModeChanged: (
state,
action: PayloadAction<{ id: string; controlMode: ControlMode }>
) => {
const { id, controlMode } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNet(cn)) {
return;
}
caAdapter.updateOne(state, { id, changes: { controlMode } });
},
controlAdapterResizeModeChanged: (
state,
action: PayloadAction<{
id: string;
resizeMode: ResizeMode;
}>
) => {
const { id, resizeMode } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn)) {
return;
}
caAdapter.updateOne(state, { id, changes: { resizeMode } });
},
controlAdapterProcessorParamsChanged: (
state,
action: PayloadAction<{
id: string;
params: Partial<RequiredControlAdapterProcessorNode>;
}>
) => {
const { id, params } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn) || !cn.processorNode) {
return;
}
const processorNode = merge(cloneDeep(cn.processorNode), params);
caAdapter.updateOne(state, {
id,
changes: {
shouldAutoConfig: false,
processorNode,
},
});
},
controlAdapterProcessortTypeChanged: (
state,
action: PayloadAction<{
id: string;
processorType: ControlAdapterProcessorType;
}>
) => {
const { id, processorType } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn)) {
return;
}
const processorNode = cloneDeep(
CONTROLNET_PROCESSORS[processorType].default
) as RequiredControlAdapterProcessorNode;
caAdapter.updateOne(state, {
id,
changes: {
processorType,
processedControlImage: null,
processorNode,
shouldAutoConfig: false,
},
});
},
controlAdapterAutoConfigToggled: (
state,
action: PayloadAction<{
id: string;
}>
) => {
const { id } = action.payload;
const cn = selectControlAdapterById(state, id);
if (!cn || !isControlNetOrT2IAdapter(cn)) {
return;
}
const update: Update<ControlNetConfig | T2IAdapterConfig> = {
id,
changes: { shouldAutoConfig: !cn.shouldAutoConfig },
};
if (update.changes.shouldAutoConfig) {
// manage the processor for the user
let processorType: ControlAdapterProcessorType | undefined = undefined;
for (const modelSubstring in CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS) {
if (cn.model?.model_name.includes(modelSubstring)) {
processorType =
CONTROLADAPTER_MODEL_DEFAULT_PROCESSORS[modelSubstring];
break;
}
}
if (processorType) {
update.changes.processorType = processorType;
update.changes.processorNode = CONTROLNET_PROCESSORS[processorType]
.default as RequiredControlAdapterProcessorNode;
} else {
update.changes.processorType = 'none';
update.changes.processorNode = CONTROLNET_PROCESSORS.none
.default as RequiredControlAdapterProcessorNode;
}
}
caAdapter.updateOne(state, update);
},
controlAdaptersReset: () => {
return cloneDeep(initialControlAdapterState);
},
pendingControlImagesCleared: (state) => {
state.pendingControlImages = [];
},
},
extraReducers: (builder) => {
builder.addCase(controlAdapterImageProcessed, (state, action) => {
const cn = selectControlAdapterById(state, action.payload.id);
if (!cn) {
return;
}
if (cn.controlImage !== null) {
state.pendingControlImages = uniq(
state.pendingControlImages.concat(action.payload.id)
);
}
});
builder.addCase(appSocketInvocationError, (state) => {
state.pendingControlImages = [];
});
},
});
export const {
controlAdapterAdded,
controlAdapterRecalled,
controlAdapterDuplicated,
controlAdapterAddedFromImage,
controlAdapterRemoved,
controlAdapterImageChanged,
controlAdapterProcessedImageChanged,
controlAdapterIsEnabledChanged,
controlAdapterModelChanged,
controlAdapterWeightChanged,
controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged,
controlAdapterControlModeChanged,
controlAdapterResizeModeChanged,
controlAdapterProcessorParamsChanged,
controlAdapterProcessortTypeChanged,
controlAdaptersReset,
controlAdapterAutoConfigToggled,
pendingControlImagesCleared,
controlAdapterModelCleared,
} = controlAdaptersSlice.actions;
export default controlAdaptersSlice.reducer;
@@ -1,70 +0,0 @@
import { cloneDeep, merge } from 'lodash-es';
import {
ControlAdapterConfig,
ControlAdapterType,
ControlNetConfig,
IPAdapterConfig,
RequiredCannyImageProcessorInvocation,
T2IAdapterConfig,
} from '../store/types';
import { CONTROLNET_PROCESSORS } from '../store/constants';
export const initialControlNet: Omit<ControlNetConfig, 'id'> = {
type: 'controlnet',
isEnabled: true,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
controlMode: 'balanced',
resizeMode: 'just_resize',
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,
};
export const initialT2IAdapter: Omit<T2IAdapterConfig, 'id'> = {
type: 't2i_adapter',
isEnabled: true,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
resizeMode: 'just_resize',
controlImage: null,
processedControlImage: null,
processorType: 'canny_image_processor',
processorNode: CONTROLNET_PROCESSORS.canny_image_processor
.default as RequiredCannyImageProcessorInvocation,
shouldAutoConfig: true,
};
export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
type: 'ip_adapter',
isEnabled: true,
controlImage: null,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
};
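// Builds a new control adapter config of the requested type, merging any overrides into the per-type defaults above.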
export const buildControlAdapter = (
id: string,
type: ControlAdapterType,
overrides: Partial<ControlAdapterConfig> = {}
): ControlAdapterConfig => {
switch (type) {
case 'controlnet':
return merge(cloneDeep(initialControlNet), { id, ...overrides });
case 't2i_adapter':
return merge(cloneDeep(initialT2IAdapter), { id, ...overrides });
case 'ip_adapter':
return merge(cloneDeep(initialIPAdapter), { id, ...overrides });
default:
throw new Error(`Unknown control adapter type: ${type}`);
}
};
@@ -3,64 +3,92 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { ChangeEvent, memo, useCallback } from 'react';
import { FaCopy, FaTrash } from 'react-icons/fa';
import {
controlAdapterDuplicated,
controlAdapterIsEnabledChanged,
controlAdapterRemoved,
} from '../store/controlAdaptersSlice';
import ParamControlAdapterModel from './parameters/ParamControlAdapterModel';
import ParamControlAdapterWeight from './parameters/ParamControlAdapterWeight';
ControlNetConfig,
controlNetDuplicated,
controlNetRemoved,
controlNetIsEnabledChanged,
} from '../store/controlNetSlice';
import ParamControlNetModel from './parameters/ParamControlNetModel';
import ParamControlNetWeight from './parameters/ParamControlNetWeight';
import { ChevronUpIcon } from '@chakra-ui/icons';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useTranslation } from 'react-i18next';
import { useToggle } from 'react-use';
import { useControlAdapterIsEnabled } from '../hooks/useControlAdapterIsEnabled';
import { useControlAdapterType } from '../hooks/useControlAdapterType';
import ControlAdapterImagePreview from './ControlAdapterImagePreview';
import ControlAdapterProcessorComponent from './ControlAdapterProcessorComponent';
import ControlAdapterShouldAutoConfig from './ControlAdapterShouldAutoConfig';
import { v4 as uuidv4 } from 'uuid';
import ControlNetImagePreview from './ControlNetImagePreview';
import ControlNetProcessorComponent from './ControlNetProcessorComponent';
import ParamControlNetShouldAutoConfig from './ParamControlNetShouldAutoConfig';
import ControlNetCanvasImageImports from './imports/ControlNetCanvasImageImports';
import ParamControlAdapterBeginEnd from './parameters/ParamControlAdapterBeginEnd';
import ParamControlAdapterControlMode from './parameters/ParamControlAdapterControlMode';
import ParamControlAdapterProcessorSelect from './parameters/ParamControlAdapterProcessorSelect';
import ParamControlAdapterResizeMode from './parameters/ParamControlAdapterResizeMode';
import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
const ControlAdapterConfig = (props: { id: string; number: number }) => {
const { id, number } = props;
const controlAdapterType = useControlAdapterType(id);
type ControlNetProps = {
controlNet: ControlNetConfig;
};
const ControlNet = (props: ControlNetProps) => {
const { controlNet } = props;
const { controlNetId } = controlNet;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const activeTabName = useAppSelector(activeTabNameSelector);
const isEnabled = useControlAdapterIsEnabled(id);
const selector = createSelector(
stateSelector,
({ controlNet }) => {
const cn = controlNet.controlNets[controlNetId];
if (!cn) {
return {
isEnabled: false,
shouldAutoConfig: false,
};
}
const { isEnabled, shouldAutoConfig } = cn;
return { isEnabled, shouldAutoConfig };
},
defaultSelectorOptions
);
const { isEnabled, shouldAutoConfig } = useAppSelector(selector);
const [isExpanded, toggleIsExpanded] = useToggle(false);
const handleDelete = useCallback(() => {
dispatch(controlAdapterRemoved({ id }));
}, [id, dispatch]);
dispatch(controlNetRemoved({ controlNetId }));
}, [controlNetId, dispatch]);
const handleDuplicate = useCallback(() => {
dispatch(controlAdapterDuplicated(id));
}, [id, dispatch]);
dispatch(
controlNetDuplicated({
sourceControlNetId: controlNetId,
newControlNetId: uuidv4(),
})
);
}, [controlNetId, dispatch]);
const handleToggleIsEnabled = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
controlAdapterIsEnabledChanged({
id,
controlNetIsEnabledChanged({
controlNetId,
isEnabled: e.target.checked,
})
);
},
[id, dispatch]
[controlNetId, dispatch]
);
if (!controlAdapterType) {
return null;
}
return (
<Flex
sx={{
@@ -75,31 +103,27 @@ const ControlAdapterConfig = (props: { id: string; number: number }) => {
},
}}
>
<Flex
sx={{ gap: 2, alignItems: 'center', justifyContent: 'space-between' }}
>
<Flex sx={{ gap: 2, alignItems: 'center' }}>
<IAISwitch
label={t(`controlnet.${controlAdapterType}`, { number })}
tooltip={t('controlnet.toggleControlNet')}
aria-label={t('controlnet.toggleControlNet')}
isChecked={isEnabled}
onChange={handleToggleIsEnabled}
formControlProps={{ w: 'full' }}
formLabelProps={{ fontWeight: 600 }}
/>
</Flex>
<Flex sx={{ gap: 2, alignItems: 'center' }}>
<Box
sx={{
w: 'full',
minW: 0,
// opacity: isEnabled ? 1 : 0.5,
// pointerEvents: isEnabled ? 'auto' : 'none',
transitionProperty: 'common',
transitionDuration: '0.1s',
}}
>
<ParamControlAdapterModel id={id} />
<ParamControlNetModel controlNet={controlNet} />
</Box>
{activeTabName === 'unifiedCanvas' && (
<ControlNetCanvasImageImports id={id} />
<ControlNetCanvasImageImports controlNet={controlNet} />
)}
<IAIIconButton
size="sm"
@@ -150,6 +174,23 @@ const ControlAdapterConfig = (props: { id: string; number: number }) => {
/>
}
/>
{!shouldAutoConfig && (
<Box
sx={{
position: 'absolute',
w: 1.5,
h: 1.5,
borderRadius: 'full',
top: 4,
insetInlineEnd: 4,
bg: 'accent.700',
_dark: {
bg: 'accent.400',
},
}}
/>
)}
</Flex>
<Flex sx={{ w: 'full', flexDirection: 'column', gap: 3 }}>
@@ -166,8 +207,8 @@ const ControlAdapterConfig = (props: { id: string; number: number }) => {
justifyContent: 'space-between',
}}
>
<ParamControlAdapterWeight id={id} />
<ParamControlAdapterBeginEnd id={id} />
<ParamControlNetWeight controlNet={controlNet} />
<ParamControlNetBeginEnd controlNet={controlNet} />
</Flex>
{!isExpanded && (
<Flex
@@ -179,26 +220,26 @@ const ControlAdapterConfig = (props: { id: string; number: number }) => {
aspectRatio: '1/1',
}}
>
<ControlAdapterImagePreview id={id} isSmall />
<ControlNetImagePreview controlNet={controlNet} isSmall />
</Flex>
)}
</Flex>
<Flex sx={{ gap: 2 }}>
<ParamControlAdapterControlMode id={id} />
<ParamControlAdapterResizeMode id={id} />
<ParamControlNetControlMode controlNet={controlNet} />
<ParamControlNetResizeMode controlNet={controlNet} />
</Flex>
<ParamControlAdapterProcessorSelect id={id} />
<ParamControlNetProcessorSelect controlNet={controlNet} />
</Flex>
{isExpanded && (
<>
<ControlAdapterImagePreview id={id} />
<ControlAdapterShouldAutoConfig id={id} />
<ControlAdapterProcessorComponent id={id} />
<ControlNetImagePreview controlNet={controlNet} />
<ParamControlNetShouldAutoConfig controlNet={controlNet} />
<ControlNetProcessorComponent controlNet={controlNet} />
</>
)}
</Flex>
);
};
export default memo(ControlAdapterConfig);
export default memo(ControlNet);
@@ -23,20 +23,20 @@ import {
} from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
import IAIDndImageIcon from '../../../common/components/IAIDndImageIcon';
import { controlAdapterImageChanged } from '../store/controlAdaptersSlice';
import { useControlAdapterControlImage } from '../hooks/useControlAdapterControlImage';
import { useControlAdapterProcessedControlImage } from '../hooks/useControlAdapterProcessedControlImage';
import { useControlAdapterProcessorType } from '../hooks/useControlAdapterProcessorType';
import {
ControlNetConfig,
controlNetImageChanged,
} from '../store/controlNetSlice';
type Props = {
id: string;
controlNet: ControlNetConfig;
isSmall?: boolean;
};
const selector = createSelector(
stateSelector,
({ controlAdapters, gallery }) => {
const { pendingControlImages } = controlAdapters;
({ controlNet, gallery }) => {
const { pendingControlImages } = controlNet;
const { autoAddBoardId } = gallery;
return {
@@ -47,10 +47,13 @@ const selector = createSelector(
defaultSelectorOptions
);
const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
const controlImageName = useControlAdapterControlImage(id);
const processedControlImageName = useControlAdapterProcessedControlImage(id);
const processorType = useControlAdapterProcessorType(id);
const ControlNetImagePreview = ({ isSmall, controlNet }: Props) => {
const {
controlImage: controlImageName,
processedControlImage: processedControlImageName,
processorType,
controlNetId,
} = controlNet;
const dispatch = useAppDispatch();
const { t } = useTranslation();
@@ -72,8 +75,8 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
const [addToBoard] = useAddImageToBoardMutation();
const [removeFromBoard] = useRemoveImageFromBoardMutation();
const handleResetControlImage = useCallback(() => {
dispatch(controlAdapterImageChanged({ id, controlImage: null }));
}, [id, dispatch]);
dispatch(controlNetImageChanged({ controlNetId, controlImage: null }));
}, [controlNetId, dispatch]);
const handleSaveControlImage = useCallback(async () => {
if (!processedControlImage) {
@@ -130,32 +133,32 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (controlImage) {
return {
id,
id: controlNetId,
payloadType: 'IMAGE_DTO',
payload: { imageDTO: controlImage },
};
}
}, [controlImage, id]);
}, [controlImage, controlNetId]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id,
actionType: 'SET_CONTROL_ADAPTER_IMAGE',
context: { id },
id: controlNetId,
actionType: 'SET_CONTROLNET_IMAGE',
context: { controlNetId },
}),
[id]
[controlNetId]
);
const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'SET_CONTROL_ADAPTER_IMAGE', id }),
[id]
() => ({ type: 'SET_CONTROLNET_IMAGE', controlNetId }),
[controlNetId]
);
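// Show the processed image only when both images exist, a processor is selected, processing has finished, and the pointer is not over the preview.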
const shouldShowProcessedImage =
controlImage &&
processedControlImage &&
!isMouseOverImage &&
!pendingControlImages.includes(id) &&
!pendingControlImages.includes(controlNetId) &&
processorType !== 'none';
return (
@@ -219,7 +222,7 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
/>
</>
{pendingControlImages.includes(id) && (
{pendingControlImages.includes(controlNetId) && (
<Flex
sx={{
position: 'absolute',
@@ -247,4 +250,4 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
);
};
export default memo(ControlAdapterImagePreview);
export default memo(ControlNetImagePreview);
@@ -1,26 +1,26 @@
import { useAppDispatch } from 'app/store/storeHooks';
import IAIButton from 'common/components/IAIButton';
import { useIsReadyToEnqueue } from 'common/hooks/useIsReadyToEnqueue';
import { memo, useCallback } from 'react';
import { useControlAdapterControlImage } from '../hooks/useControlAdapterControlImage';
import { controlAdapterImageProcessed } from '../store/actions';
import { ControlNetConfig } from '../store/controlNetSlice';
import { useAppDispatch } from 'app/store/storeHooks';
import { controlNetImageProcessed } from '../store/actions';
import { useIsReadyToEnqueue } from 'common/hooks/useIsReadyToEnqueue';
type Props = {
id: string;
controlNet: ControlNetConfig;
};
const ControlAdapterPreprocessButton = ({ id }: Props) => {
const controlImage = useControlAdapterControlImage(id);
const ControlNetPreprocessButton = (props: Props) => {
const { controlNetId, controlImage } = props.controlNet;
const dispatch = useAppDispatch();
const isReady = useIsReadyToEnqueue();
const handleProcess = useCallback(() => {
dispatch(
controlAdapterImageProcessed({
id,
controlNetImageProcessed({
controlNetId,
})
);
}, [id, dispatch]);
}, [controlNetId, dispatch]);
return (
<IAIButton
@@ -33,4 +33,4 @@ const ControlAdapterPreprocessButton = ({ id }: Props) => {
);
};
export default memo(ControlAdapterPreprocessButton);
export default memo(ControlNetPreprocessButton);
@@ -1,4 +1,5 @@
import { memo } from 'react';
import { ControlNetConfig } from '../store/controlNetSlice';
import CannyProcessor from './processors/CannyProcessor';
import ColorMapProcessor from './processors/ColorMapProcessor';
import ContentShuffleProcessor from './processors/ContentShuffleProcessor';
@@ -12,25 +13,18 @@ import NormalBaeProcessor from './processors/NormalBaeProcessor';
import OpenposeProcessor from './processors/OpenposeProcessor';
import PidiProcessor from './processors/PidiProcessor';
import ZoeDepthProcessor from './processors/ZoeDepthProcessor';
import { useControlAdapterIsEnabled } from '../hooks/useControlAdapterIsEnabled';
import { useControlAdapterProcessorNode } from '../hooks/useControlAdapterProcessorNode';
export type Props = {
id: string;
export type ControlNetProcessorProps = {
controlNet: ControlNetConfig;
};
const ControlAdapterProcessorComponent = ({ id }: Props) => {
const isEnabled = useControlAdapterIsEnabled(id);
const processorNode = useControlAdapterProcessorNode(id);
if (!processorNode) {
return null;
}
const ControlNetProcessorComponent = (props: ControlNetProcessorProps) => {
const { controlNetId, isEnabled, processorNode } = props.controlNet;
if (processorNode.type === 'canny_image_processor') {
return (
<CannyProcessor
controlNetId={id}
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -40,7 +34,7 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
if (processorNode.type === 'color_map_image_processor') {
return (
<ColorMapProcessor
controlNetId={id}
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -50,7 +44,7 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
if (processorNode.type === 'hed_image_processor') {
return (
<HedProcessor
controlNetId={id}
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -60,7 +54,7 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
if (processorNode.type === 'lineart_image_processor') {
return (
<LineartProcessor
controlNetId={id}
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -70,7 +64,7 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
if (processorNode.type === 'content_shuffle_image_processor') {
return (
<ContentShuffleProcessor
controlNetId={id}
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -80,7 +74,7 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
if (processorNode.type === 'lineart_anime_image_processor') {
return (
<LineartAnimeProcessor
controlNetId={id}
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -90,7 +84,7 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
if (processorNode.type === 'mediapipe_face_processor') {
return (
<MediapipeFaceProcessor
controlNetId={id}
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -100,7 +94,7 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
if (processorNode.type === 'midas_depth_image_processor') {
return (
<MidasDepthProcessor
controlNetId={id}
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -110,7 +104,7 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
if (processorNode.type === 'mlsd_image_processor') {
return (
<MlsdImageProcessor
controlNetId={id}
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -120,7 +114,7 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
if (processorNode.type === 'normalbae_image_processor') {
return (
<NormalBaeProcessor
controlNetId={id}
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -130,7 +124,7 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
if (processorNode.type === 'openpose_image_processor') {
return (
<OpenposeProcessor
controlNetId={id}
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -140,7 +134,7 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
if (processorNode.type === 'pidi_image_processor') {
return (
<PidiProcessor
controlNetId={id}
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -150,7 +144,7 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
if (processorNode.type === 'zoe_depth_image_processor') {
return (
<ZoeDepthProcessor
controlNetId={id}
controlNetId={controlNetId}
processorNode={processorNode}
isEnabled={isEnabled}
/>
@@ -160,4 +154,4 @@ const ControlAdapterProcessorComponent = ({ id }: Props) => {
return null;
};
export default memo(ControlAdapterProcessorComponent);
export default memo(ControlNetProcessorComponent);

Some files were not shown because too many files have changed in this diff