Compare commits


106 Commits

Author SHA1 Message Date
Sergey Borisov
18956a6186 Expose seamless variables to node 2023-09-20 01:10:37 +03:00
Kent Keirsey
864f2270c3 feat: Add IP Adapter to InvokeAI (Node & Linear) (#4429)
## What type of PR is this? (check all applicable)

- [ ] Refactor
- [x] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

      
## Have you updated all relevant documentation?
- [ ] Yes
- [ ] No


## Description (edited by @blessedcoolant, @RyanJDick)

This PR adds support for IP-Adapters (a technique for image-based
prompts) in Invoke AI. Currently only available in the Node UI.

IP-Adapter Paper: [IP-Adapter: Text Compatible Image Prompt Adapter for
Text-to-Image Diffusion Models](https://arxiv.org/abs/2308.06721)
IP-Adapter reference code: https://github.com/tencent-ailab/IP-Adapter
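
For orientation, a minimal conceptual sketch of the decoupled cross-attention idea that IP-Adapter is built on (the full implementation lands later in this diff as `IPAttnProcessor2_0` in `invokeai/backend/ip_adapter/attention_processor.py`); the function name and tensor shapes here are illustrative only:

```python
import torch.nn.functional as F

def decoupled_cross_attention(query, text_k, text_v, image_k, image_v, scale=1.0):
    """Illustrative only: the text prompt and the image prompt use separate
    key/value projections, and their attention outputs are summed."""
    text_out = F.scaled_dot_product_attention(query, text_k, text_v)
    image_out = F.scaled_dot_product_attention(query, image_k, image_v)
    return text_out + scale * image_out
```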

In order to test, install the following models via the InvokeAI UI:

Image Encoders:

[InvokeAI/ip_adapter_sd_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder)

[InvokeAI/ip_adapter_sdxl_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sdxl_image_encoder)

IP-Adapters:

[InvokeAI/ip_adapter_sd15](https://huggingface.co/InvokeAI/ip_adapter_sd15)

[InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)

[InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)

[InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)

Old instructions (for reference only):

> In order to test, you need to download and place the following models in your InvokeAI models directory.
>
> - SD 1.5 - https://huggingface.co/h94/IP-Adapter/tree/main/models --> Download the models and the `image_encoder` folder to `models/core/ip_adapters/sd-1`
> - SDXL - https://huggingface.co/h94/IP-Adapter/tree/main/sdxl_models --> Download the models and the `image_encoder` folder to `models/core/ip_adapters/sdxl`
>
> This is only temporary and needs to be handled differently. I outlined the issues here:
> https://github.com/invoke-ai/InvokeAI/pull/4429#issuecomment-1705776570

## Examples using this PR

### Image variations, no text prompt
The leftmost image in each row is the original image used as input to the IP-Adapter. The other images are example outputs with different seeds; all other parameters are identical.

![ipadapter_invokai_example1](https://github.com/invoke-ai/InvokeAI/assets/303100/cae18b97-14a9-4499-8d87-f07faa8ad13a)







## Related Tickets & Documents

<!--
For pull requests that relate or close an issue, please include them
below. 

For example having the text: "closes #1234" would connect the current
pull
request to issue 1234.  And when we merge the pull request, Github will
automatically close the issue.
-->

- Related Issue #
- Closes #

## QA Instructions, Screenshots, Recordings

<!-- 
Please provide steps on how to test changes, any hardware or 
software specifications as well as any other pertinent information. 
-->

## Added/updated tests?

- [ ] Yes
- [ ] No : _please replace this line with details on why tests
      have not been included_

## [optional] Are there any post deployment tasks we need to perform?
2023-09-19 14:31:08 -04:00
Ryan Dick
8b44d83859 yarn build 2023-09-19 14:03:22 -04:00
Kent Keirsey
0b6315de71 Merge branch 'main' into feat/ip-adapter 2023-09-19 13:49:20 -04:00
Ryan Dick
92b49e45bb Address flake8 error. 2023-09-18 16:33:16 -04:00
Ryan Dick
b05b8ef677 Switch to using torch 2.0 attention for IP-Adapter (more memory-efficient). 2023-09-18 16:30:53 -04:00
Ryan Dick
382e2139bd Clear incompatible IP-Adapter when base model changes in the Linear UI. 2023-09-18 12:57:23 -04:00
blessedcoolant
2a3909da94 isort: fix issues 2023-09-17 12:14:58 +12:00
blessedcoolant
e0dddbd38e chore: fix isort issues 2023-09-17 12:13:03 +12:00
blessedcoolant
231b7a5000 fix: Upload not working correctly on the ip Adapter image upload 2023-09-17 12:08:35 +12:00
blessedcoolant
b7773c9962 chore: black & lint fixes 2023-09-17 12:00:21 +12:00
blessedcoolant
11c501fc80 fix: Upload issue with the ip adapter image uploader 2023-09-17 11:58:15 +12:00
blessedcoolant
7be5743011 feat: Add IP Adapter Begin & End Percent to Linear UI 2023-09-17 11:53:05 +12:00
user1
c48e648cbb Added per-step setting of IP-Adapter weights (for param easing, etc.) 2023-09-16 12:36:16 -07:00
user1
29b4ddcc7f Merge branch 'feat/ip-adapter' of github.com:invoke-ai/InvokeAI into feat/ip-adapter 2023-09-16 09:32:41 -07:00
user1
7ee13879e3 Added check in IP-Adapter to avoid begin/end step percent handling if use of IP-Adapter is already turned off due to potential clash with other cross attention control. 2023-09-16 09:29:50 -07:00
user1
ced297ed21 Initial implementation of IP-Adapter "begin_step_percent" and "end_step_percent" for controlling on which steps IP-Adapter is applied in the denoising loop. 2023-09-16 08:24:12 -07:00
blessedcoolant
3e813ead1f chore: extract the adapter info initial state 2023-09-16 10:59:19 -04:00
blessedcoolant
820ec08e9a feat: Update Control Adapter Collapse active status to reflect IP Adapter 2023-09-16 10:59:19 -04:00
blessedcoolant
4dd289b337 feat: Handle IP Adapter Image being reset on being deleted. 2023-09-16 10:59:19 -04:00
blessedcoolant
b60b1e359e fix: Decrease the size of the IP Adapter Image Reset Button 2023-09-16 10:59:19 -04:00
blessedcoolant
208286e97a wip: Improve the IP Adapter UI 2023-09-16 10:59:19 -04:00
blessedcoolant
f7b64304ae wip: Add IP Adapter To Linear UI 2023-09-16 10:59:19 -04:00
blessedcoolant
834751e877 Merge branch 'main' into feat/ip-adapter 2023-09-16 07:06:46 +12:00
Ryan Dick
343df03a92 isort 2023-09-15 13:18:00 -04:00
Ryan Dick
b57acb7353 Merge branch 'main' into feat/ip-adapter 2023-09-15 13:15:25 -04:00
Ryan Dick
56340c24c8 IP-Adapter Model Management (#4540)
Note: The target branch is `feat/ip-adapter`, not `main`. After a
cursory review here, I'll merge for an in-depth review as part of
https://github.com/invoke-ai/InvokeAI/pull/4429.

## Description

This branch adds model management support for IP-Adapter models. There
are a few notable/unusual aspects to how it is implemented:
- We have defined a model format that works better with our model manager than the format used in the 'official' IP-Adapter repo, and we will be hosting the IP-Adapter models ourselves (see `invokeai/backend/ip_adapter/README.md` for a description of the expected model formats).
- The CLIP Vision models and IP-Adapter models are handled independently in the model manager. The IP-Adapter model info has a reference to the CLIP Vision model that it is intended to be run with (see the sketch after this list).
- The `BaseModelType.Any` field was added for CLIP Vision models, as
they don't have a clear 1-to-1 association with a particular base model.
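
A minimal sketch of how the IP-Adapter-to-CLIP-Vision link can be resolved from the on-disk format (`image_encoder.txt`, described in `invokeai/backend/ip_adapter/README.md`). The real helper is `get_ip_adapter_image_encoder_model_id` in the model management code; the body below is an assumption based on that README, not the actual implementation:

```python
import os

def get_ip_adapter_image_encoder_model_id(model_dir: str) -> str:
    """Assumed behavior: return the CLIP Vision model identifier that the
    IP-Adapter directory declares in its image_encoder.txt file."""
    with open(os.path.join(model_dir, "image_encoder.txt")) as f:
        return f.read().strip()
```

The IP-Adapter node added in this PR then takes the last path segment of the returned identifier (`model_id.split("/")[-1]`) to look the encoder up in the model manager under `BaseModelType.Any`.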

## QA Instructions, Screenshots, Recordings

Install the following models via the InvokeAI UI:

Image Encoders:
- [InvokeAI/ip_adapter_sd_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder)
- [InvokeAI/ip_adapter_sdxl_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sdxl_image_encoder)

IP-Adapters:
- [InvokeAI/ip_adapter_sd15](https://huggingface.co/InvokeAI/ip_adapter_sd15)
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
2023-09-15 12:42:02 -04:00
Ryan Dick
16664da5b6 black 2023-09-14 23:49:02 -04:00
Ryan Dick
c104807201 Update list of supported IP-Adapters. 2023-09-14 23:43:19 -04:00
Ryan Dick
990ce9a1da Lookup IP-Adapter linked image encoder from disk instead of storing in model config metadata. 2023-09-14 23:06:57 -04:00
Ryan Dick
18095ecc44 yarn build 2023-09-14 16:56:51 -04:00
Ryan Dick
fe19f11abf Bump DenoiseLatentsInvocation minor version. 2023-09-14 16:54:07 -04:00
Ryan Dick
c2f074dc2f Fix python static checks. 2023-09-14 16:48:47 -04:00
Ryan Dick
e02a557454 Fix frontend typescript errors. 2023-09-14 16:43:43 -04:00
Ryan Dick
fca60862e2 Add README.md describing IP-Adapter model formats. 2023-09-14 16:02:07 -04:00
Ryan Dick
94c186bb4c Fix bug in IPAdapter.to(...). 2023-09-14 15:45:25 -04:00
Ryan Dick
a22c8cb3a1 Improve robustness of check for IPAdapter vs IPAdapterPlus. 2023-09-14 15:25:41 -04:00
Ryan Dick
781e8521d5 Eliminate the need for IPAdapter.initialize(). 2023-09-14 15:02:59 -04:00
Ryan Dick
d114d0ba95 Remove need for the image_encoder param in IPAdapter.initialize(). 2023-09-14 14:14:35 -04:00
Ryan Dick
cc8b7a74da (minor) Delete minor TODO. 2023-09-14 13:04:34 -04:00
Ryan Dick
388554448a Add CLIP Vision model to IP-Adapter info and use this to infer which model to use. 2023-09-14 11:57:53 -04:00
Ryan Dick
cadc0839a6 typegen 2023-09-14 11:19:52 -04:00
Ryan Dick
d5160648d0 Add support for downloading IP-Adapter models from HF. 2023-09-14 11:18:43 -04:00
Ryan Dick
6d0ea42a94 Get CLIPVision model download from HF working. 2023-09-14 09:54:10 -04:00
Ryan Dick
2c1100509f Add BaseModelType.Any to be used by CLIPVisionModel. 2023-09-14 08:19:55 -04:00
Ryan Dick
c34b359c36 (minor) Remove duplicate TODO. 2023-09-13 21:25:20 -04:00
Ryan Dick
77d135967f Update IPAdapterModel to respect requested torch_dtype. 2023-09-13 21:06:42 -04:00
Ryan Dick
ebf26687cb (minor) Remove unnecessary TODO. 2023-09-13 21:03:42 -04:00
Ryan Dick
1c8991a3df Use CLIPVisionModel under model management for IP-Adapter. 2023-09-13 19:10:02 -04:00
Ryan Dick
3d52656176 Add CLIPVisionModel to model management. 2023-09-13 17:14:20 -04:00
Ryan Dick
a2777decd4 Add a IPAdapterModelField for passing passing IP-Adapter models between nodes. 2023-09-13 13:40:59 -04:00
Ryan Dick
468253aa14 typegen 2023-09-13 08:27:24 -04:00
Ryan Dick
3ee9a21647 Initial (barely) working version of IP-Adapter model management. 2023-09-13 08:27:24 -04:00
Ryan Dick
0d823901ef Add IPAdapter to model_management __init__.py 2023-09-13 08:27:24 -04:00
Ryan Dick
7ee55489bb Improve model search warning messages. 2023-09-13 08:27:24 -04:00
Ryan Dick
163ece9aee Initial skeleton for IPAdapter model management. 2023-09-13 08:27:24 -04:00
Ryan Dick
aa7d945b23 IP-Adapter Re-Factor (#4496)
## What type of PR is this? (check all applicable)

- [x] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

## Description

**NOTE!!!** This PR is against `feat/ip-adapter`, not `main`. I created
a PR because I made some pretty significant changes that I thought might
spark discussion.

I don't think it makes sense to do a full in-depth review here. If
possible, let's try to agree on the high-level approach and then merge
this and do an in-depth review on the original PR.

High-level changes:
- Split `IPAdapterField` from the `ControlField` and make them separate
inputs on the `DenoiseLatentsInvocation`
- Create a context manager that handles patching/un-patching the UNet with IP-Adapter attention blocks (`IPAdapter.apply_ip_adapter_attention()`); see the sketch after this list.
- Pass IP-Adapter conditioning via `cross_attention_kwargs` rather than
concatenating it to the text embedding. This helps avoid breaking other
features (like long prompts).
- Remove unused blocks of the IP-Adapter implementation and do some
general tidying.
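
A minimal sketch of the patch/un-patch context manager and the `cross_attention_kwargs` hand-off described above. It assumes a pre-built dict of `IPAttnProcessor2_0` processors and uses the public `diffusers` attention-processor API; it is not the actual InvokeAI implementation:

```python
from contextlib import contextmanager

@contextmanager
def apply_ip_adapter_attention(unet, ip_attn_procs):
    """Temporarily install IP-Adapter attention processors on a diffusers
    UNet2DConditionModel, restoring the original processors on exit."""
    orig_procs = unet.attn_processors
    try:
        unet.set_attn_processor(ip_attn_procs)
        yield
    finally:
        unet.set_attn_processor(orig_procs)

# Inside the denoising loop, the image-prompt embeddings travel via
# cross_attention_kwargs rather than being concatenated to the text embedding:
#   noise_pred = unet(
#       latents, t, encoder_hidden_states=text_embeds,
#       cross_attention_kwargs={"ip_adapter_image_prompt_embeds": image_prompt_embeds},
#   ).sample
```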

Out of scope:
- I haven't looked at model management yet. I'd like to get this merged
into `feat/ip-adapter` and then look at model management separately.
2023-09-11 18:51:10 -04:00
Ryan Dick
50a0691514 flake8 2023-09-08 18:05:31 -04:00
Ryan Dick
a255624984 black 2023-09-08 17:55:23 -04:00
Ryan Dick
2630fe3608 Remove unused ip_adapter/utils.py file. 2023-09-08 16:25:34 -04:00
Ryan Dick
dee6f86d5e Set 'title' for IP-Adapter fields with non-default names. 2023-09-08 16:14:17 -04:00
Ryan Dick
6ca6cf713c Tidy IPAdapter. Add types, improve field/method naming. 2023-09-08 16:00:58 -04:00
Ryan Dick
3f7d5b4e0f Remove redundant IPAdapterXL class. 2023-09-08 15:46:10 -04:00
Ryan Dick
91596d9527 Re-factor IPAdapter to patch UNet in a context manager. 2023-09-08 15:39:22 -04:00
Ryan Dick
d669f0855d Comment unused IPAdapter generate(...) methods. 2023-09-08 13:12:42 -04:00
Ryan Dick
b2d5b53b5f Pass IP-Adapter conditioning via cross_attention_kwargs instead of concatenating to the text embedding. This avoids interference with other features that manipulate the text embedding (e.g. long prompts). 2023-09-08 11:47:36 -04:00
Ryan Dick
ddc148b70b Move ConditioningData and its field classes to their own file. This will allow new conditioning types to be added more cleanly without introducing circular dependencies. 2023-09-08 11:00:11 -04:00
Ryan Dick
c2d43f007b Specify the image_embedding_len in the IPAttnProcessor rather than the text embedding length. This enables the IPAttnProcessor to handle text embeddings of varying lengths. 2023-09-07 18:20:21 -04:00
Ryan Dick
7703bf2ca1 Delete IP-Adapter copies of AttnProcessor and AttnProcessor2_0, which were unmodified from diffusers. 2023-09-07 15:00:13 -04:00
Ryan Dick
23fdf0156f Clean up IP-Adapter in diffusers_pipeline.py - WIP 2023-09-06 20:42:20 -04:00
Ryan Dick
cdbf40c9b2 Revert ControlNetInvocation changes. 2023-09-06 19:30:30 -04:00
Ryan Dick
46c9dcb113 Run yarn build. 2023-09-06 17:16:01 -04:00
Ryan Dick
6df79045fa Run typegen. 2023-09-06 17:03:37 -04:00
Ryan Dick
d776e0a0a9 Split ControlField and IpAdapterField. 2023-09-06 17:03:37 -04:00
blessedcoolant
94ec3da7b5 chore: regen scheme merge 2023-09-05 15:23:16 +12:00
blessedcoolant
f44496a579 Merge branch 'main' into feat/ip-adapter 2023-09-05 15:22:15 +12:00
blessedcoolant
99fe95ab03 fix: Add validation for image_encoder model too 2023-09-05 14:49:41 +12:00
psychedelicious
95ecb1a0c1 fix(ip_adapter): add None to types 2023-09-05 12:30:00 +10:00
psychedelicious
bd15874cf6 feat(nodes): add control_type validation & fix types 2023-09-05 12:24:54 +10:00
blessedcoolant
30ab81b6bb fix: Update paths so they are serializable in the nodes 2023-09-05 13:50:21 +12:00
blessedcoolant
78195491bc fix: Make the adapter models use new local paths 2023-09-05 13:39:54 +12:00
blessedcoolant
c63390f6e1 fix: Temporarily update the ControlField zod model
While we decide how to go ahead with this.
2023-09-05 12:29:05 +12:00
blessedcoolant
cbd451c610 chore: Regen Schema 2023-09-05 12:13:08 +12:00
blessedcoolant
b0f91f2e75 fix: Remove types on adapter nodes. Superseded by the decorator 2023-09-05 12:12:19 +12:00
blessedcoolant
3ac68cde66 chore: flake8 cleanup 2023-09-05 12:07:12 +12:00
blessedcoolant
a69b1cd598 chore: Add Versioning data to new adapters + update model paths 2023-09-05 11:54:50 +12:00
blessedcoolant
65a76a086b cleanup: Some basic cleanup 2023-09-05 11:54:28 +12:00
blessedcoolant
07381e5a26 cleanup: merge conflicts 2023-09-05 11:37:12 +12:00
blessedcoolant
6bb378a101 Merge branch 'main' into feat/ip-adapter 2023-09-05 11:35:19 +12:00
psychedelicious
b761807219 Merge branch 'main' into feat/ip-adapter 2023-09-02 11:31:08 +10:00
user1
fb1b03960e Added IP-Adapter SDXL support. Added IP-Adapter "Plus" (more detail) model support. 2023-09-01 04:40:30 -07:00
user1
74bfb5e1f9 First commit of separate node for IP-Adapter.
And its own dataclasses for passing info.
2023-08-31 23:07:15 -07:00
user1
942ecbbde4 Merge branch 'feat/ip-adapter' of github.com:invoke-ai/InvokeAI into feat/ip-adapter 2023-08-30 18:35:53 -07:00
user1
79db0e9e93 More cleanup after rebasing to main. 2023-08-30 18:29:06 -07:00
user1
0c17f8604f Resolving rebase conflict, redirecting control imports to invocations/control_adapter 2023-08-30 17:35:31 -07:00
user1
054edc4077 Oops, forgot to add control_adapter.py for control nodes in last refactor commit 2023-08-30 17:31:46 -07:00
user1
5a9993772d Added ip_adapter_strength parameter to adjust weighting of IP-Adapter's added cross-attention layers 2023-08-30 17:28:30 -07:00
user1
f2cd9e9ae2 Working POC for IP-Adapters. Not fully nodified yet, lots of caveats, hardwired model paths, etc. 2023-08-30 17:28:30 -07:00
user1
9f86cfa471 Working POC of IP-Adapters. Not fully nodified yet. 2023-08-30 17:28:30 -07:00
user1
8c1390166f Modifying code from https://github.com/tencent-ailab/IP-Adapter. Also adding license notice at top. 2023-08-30 17:28:30 -07:00
user1
1ad98ce999 Core ip_adapter files from https://github.com/tencent-ailab/IP-Adapter
Copied into InvokeAI since IP-Adapter repo is not a package. Is there a better way to do this for non-packaged Python code while still keeping InvokeAI install easy?
2023-08-30 17:28:30 -07:00
user1
5f4a62810e Added ip_adapter_strength parameter to adjust weighting of IP-Adapter's added cross-attention layers 2023-08-29 10:47:37 -07:00
user1
35b7ae90ae Working POC for IP-Adapters. Not fully nodified yet, lots of caveats, hardwired model paths, etc. 2023-08-29 10:47:37 -07:00
user1
9ed4d487d2 Working POC of IP-Adapters. Not fully nodified yet. 2023-08-29 10:47:37 -07:00
user1
69d37217b8 Modifying code from https://github.com/tencent-ailab/IP-Adapter. Also adding license notice at top. 2023-08-29 10:47:37 -07:00
user1
7afdefb0e5 Core ip_adapter files from https://github.com/tencent-ailab/IP-Adapter
Copied into InvokeAI since IP-Adapter repo is not a package. Is there a better way to do this for non-packaged Python code while still keeping InvokeAI install easy?
2023-08-29 10:47:37 -07:00
111 changed files with 4220 additions and 2562 deletions

View File

@@ -9,7 +9,6 @@ from invokeai.app.services.boards import BoardService, BoardServiceDependencies
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
@@ -127,7 +126,6 @@ class ApiDependencies:
configuration=config,
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
)
create_system_graphs(services.graph_library)

View File

@@ -1,7 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from .services.config import InvokeAIAppConfig
# parse_args() must be called before any other imports. if it is not called first, consumers of the config
@@ -311,7 +309,6 @@ def invoke_cli():
performance_statistics=InvocationStatsService(graph_execution_manager),
logger=logger,
configuration=config,
invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size),
)
system_graphs = create_system_graphs(services.graph_library)

View File

@@ -67,6 +67,7 @@ class FieldDescriptions:
width = "Width of output (px)"
height = "Height of output (px)"
control = "ControlNet(s) to apply"
ip_adapter = "IP-Adapter to apply"
denoised_latents = "Denoised latents tensor"
latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)"
@@ -155,6 +156,7 @@ class UIType(str, Enum):
VaeModel = "VaeModelField"
LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
UNet = "UNetField"
Vae = "VaeField"
CLIP = "ClipField"
@@ -568,24 +570,7 @@ class BaseInvocation(ABC, BaseModel):
raise RequiredConnectionException(self.__fields__["type"].default, field_name)
elif _input == Input.Any:
raise MissingInputException(self.__fields__["type"].default, field_name)
output: BaseInvocationOutput
if self.use_cache:
key = context.services.invocation_cache.create_key(self)
cached_value = context.services.invocation_cache.get(key)
if cached_value is None:
context.services.logger.debug(f'Invocation cache miss for type "{self.get_type()}": {self.id}')
output = self.invoke(context)
context.services.invocation_cache.save(key, output)
return output
else:
context.services.logger.debug(f'Invocation cache hit for type "{self.get_type()}": {self.id}')
return cached_value
else:
context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}')
return self.invoke(context)
def get_type(self) -> str:
return self.__fields__["type"].default
return self.invoke(context)
id: str = Field(
description="The id of this instance of an invocation. Must be unique among all instances of invocations."
@@ -598,7 +583,6 @@ class BaseInvocation(ABC, BaseModel):
description="The workflow to save with the image",
ui_type=UIType.WorkflowField,
)
use_cache: bool = InputField(default=True, description="Whether or not to use the cache")
@validator("workflow", pre=True)
def validate_workflow_is_json(cls, v):
@@ -622,7 +606,6 @@ def invocation(
tags: Optional[list[str]] = None,
category: Optional[str] = None,
version: Optional[str] = None,
use_cache: Optional[bool] = True,
) -> Callable[[Type[GenericBaseInvocation]], Type[GenericBaseInvocation]]:
"""
Adds metadata to an invocation.
@@ -655,8 +638,6 @@ def invocation(
except ValueError as e:
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
cls.UIConfig.version = version
if use_cache is not None:
cls.__fields__["use_cache"].default = use_cache
# Add the invocation type to the pydantic model of the invocation
invocation_type_annotation = Literal[invocation_type] # type: ignore

View File

@@ -56,7 +56,6 @@ class RangeOfSizeInvocation(BaseInvocation):
tags=["range", "integer", "random", "collection"],
category="collections",
version="1.0.0",
use_cache=False,
)
class RandomRangeInvocation(BaseInvocation):
"""Creates a collection of random numbers"""

View File

@@ -7,14 +7,14 @@ from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ExtraConditioningInfo,
SDXLConditioningInfo,
)
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException, ModelType
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.util.devices import torch_dtype
from .baseinvocation import (
BaseInvocation,
@@ -99,14 +99,15 @@ class CompelInvocation(BaseInvocation):
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
), ModelPatcher.apply_clip_skip(
text_encoder_info.context.model, self.clip.skipped_layers
), text_encoder_info as text_encoder:
with (
ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
text_encoder_info as text_encoder,
):
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
@@ -122,7 +123,7 @@ class CompelInvocation(BaseInvocation):
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
ec = ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
@@ -213,14 +214,15 @@ class SDXLPromptInvocationBase:
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora(
text_encoder_info.context.model, _lora_loader(), lora_prefix
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
), ModelPatcher.apply_clip_skip(
text_encoder_info.context.model, clip_field.skipped_layers
), text_encoder_info as text_encoder:
with (
ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
text_encoder_info as text_encoder,
):
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
@@ -244,7 +246,7 @@ class SDXLPromptInvocationBase:
else:
c_pooled = None
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
ec = ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
@@ -436,9 +438,11 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
text_fragments = [
x.text
if type(x) is Fragment
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
(
x.text
if type(x) is Fragment
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
)
for x in parsed_prompt.children
]
text = " ".join(text_fragments)

View File

@@ -965,42 +965,3 @@ class ImageChannelMultiplyInvocation(BaseInvocation):
width=image_dto.width,
height=image_dto.height,
)
@invocation(
"save_image",
title="Save Image",
tags=["primitives", "image"],
category="primitives",
version="1.0.0",
use_cache=False,
)
class SaveImageInvocation(BaseInvocation):
"""Saves an image. Unlike an image primitive, this invocation stores a copy of the image."""
image: ImageField = InputField(description="The image to load")
metadata: CoreMetadata = InputField(
default=None,
description=FieldDescriptions.core_metadata,
ui_hidden=True,
)
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
workflow=self.workflow,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)

View File

@@ -0,0 +1,105 @@
import os
from builtins import float
from typing import List, Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import get_ip_adapter_image_encoder_model_id
class IPAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the IP-Adapter model")
base_model: BaseModelType = Field(description="Base model")
class CLIPVisionModelField(BaseModel):
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
base_model: BaseModelType = Field(description="Base model (usually 'Any')")
class IPAdapterField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
# weight: float = Field(default=1.0, ge=0, description="The weight of the IP-Adapter.")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the IP-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
@invocation_output("ip_adapter_output")
class IPAdapterOutput(BaseInvocationOutput):
# Outputs
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.0.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = InputField(
description="The IP-Adapter model.",
title="IP-Adapter Model",
input=Input.Direct,
)
# weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
weight: Union[float, List[float]] = InputField(
default=1, ge=0, description="The weight given to the IP-Adapter", ui_type=UIType.Float, title="Weight"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the IP-Adapter is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the IP-Adapter is last applied (% of total steps)"
)
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.services.model_manager.model_info(
self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
)
# HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
# directly, and 2) we are reading from disk every time this invocation is called without caching the result.
# A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this
# is currently messy due to differences between how the model info is generated when installing a model from
# disk vs. downloading the model.
image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"])
)
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_model = CLIPVisionModelField(
model_name=image_encoder_model_name,
base_model=BaseModelType.Any,
)
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=self.ip_adapter_model,
image_encoder_model=image_encoder_model,
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
),
)

View File

@@ -8,6 +8,7 @@ import numpy as np
import torch
import torchvision.transforms as T
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
@@ -19,6 +20,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import validator
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import (
DenoiseMaskField,
@@ -31,15 +33,17 @@ from invokeai.app.invocations.primitives import (
)
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import BaseModelType
from ...backend.model_management.seamless import set_seamless
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,
ControlNetData,
IPAdapterData,
StableDiffusionGeneratorPipeline,
image_resized_to_grid_as_tensor,
)
@@ -68,7 +72,6 @@ if choose_torch_device() == torch.device("mps"):
DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@@ -191,7 +194,7 @@ def get_scheduler(
title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents",
version="1.0.0",
version="1.1.0",
)
class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images"""
@@ -219,9 +222,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
input=Input.Connection,
ui_order=5,
)
ip_adapter: Optional[IPAdapterField] = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=7
)
@validator("cfg_scale")
@@ -323,8 +329,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_control_data(
self,
context: InvocationContext,
# really only need model for dtype and device
model: StableDiffusionGeneratorPipeline,
control_input: Union[ControlField, List[ControlField]],
latents_shape: List[int],
exit_stack: ExitStack,
@@ -344,57 +348,107 @@ class DenoiseLatentsInvocation(BaseInvocation):
else:
control_list = None
if control_list is None:
control_data = None
# from above handling, any control that is not None should now be of type list[ControlField]
else:
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
control_data = []
control_models = []
for control_info in control_list:
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
context=context,
)
)
return None
# After above handling, any control that is not None should now be of type list[ControlField].
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
context=context,
)
control_item = ControlNetData(
model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data
)
# control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model, # model object
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
)
controlnet_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return controlnet_data
def prep_ip_adapter_data(
self,
context: InvocationContext,
ip_adapter: Optional[IPAdapterField],
conditioning_data: ConditioningData,
unet: UNet2DConditionModel,
exit_stack: ExitStack,
) -> Optional[IPAdapterData]:
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
to the `conditioning_data` (in-place).
"""
if ip_adapter is None:
return None
image_encoder_model_info = context.services.model_manager.get_model(
model_name=ip_adapter.image_encoder_model.model_name,
model_type=ModelType.CLIPVision,
base_model=ip_adapter.image_encoder_model.base_model,
context=context,
)
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=ip_adapter.ip_adapter_model.model_name,
model_type=ModelType.IPAdapter,
base_model=ip_adapter.ip_adapter_model.base_model,
context=context,
)
)
input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
input_image, image_encoder_model
)
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
image_prompt_embeds, uncond_image_prompt_embeds
)
return IPAdapterData(
ip_adapter_model=ip_adapter_model,
weight=ip_adapter.weight,
begin_step_percent=ip_adapter.begin_step_percent,
end_step_percent=ip_adapter.end_step_percent,
)
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
@@ -488,9 +542,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
**self.unet.unet.dict(),
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), set_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet:
with (
ExitStack() as exit_stack,
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
set_seamless(unet_info.context.model, **self.unet.seamless.dict()),
unet_info as unet,
):
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
@@ -509,8 +566,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
control_data = self.prep_control_data(
model=pipeline,
controlnet_data = self.prep_control_data(
context=context,
control_input=self.control,
latents_shape=latents.shape,
@@ -519,6 +575,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
exit_stack=exit_stack,
)
ip_adapter_data = self.prep_ip_adapter_data(
context=context,
ip_adapter=self.ip_adapter,
conditioning_data=conditioning_data,
unet=unet,
exit_stack=exit_stack,
)
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
scheduler,
device=unet.device,
@@ -537,7 +601,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
masked_latents=masked_latents,
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
control_data=controlnet_data, # list[ControlNetData],
ip_adapter_data=ip_adapter_data, # IPAdapterData,
callback=step_callback,
)
@@ -583,7 +648,7 @@ class LatentsToImageInvocation(BaseInvocation):
context=context,
)
with set_seamless(vae_info.context.model, self.vae.seamless_axes), vae_info as vae:
with set_seamless(vae_info.context.model, **self.vae.seamless.dict()), vae_info as vae:
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)

View File

@@ -54,14 +54,7 @@ class DivideInvocation(BaseInvocation):
return IntegerOutput(value=int(self.a / self.b))
@invocation(
"rand_int",
title="Random Integer",
tags=["math", "random"],
category="math",
version="1.0.0",
use_cache=False,
)
@invocation("rand_int", title="Random Integer", tags=["math", "random"], category="math", version="1.0.0")
class RandomIntInvocation(BaseInvocation):
"""Outputs a single random integer."""

View File

@@ -18,6 +18,13 @@ from .baseinvocation import (
)
class SeamlessSettings(BaseModel):
axes: List[str] = Field(description="Axes('x' and 'y') to which apply seamless")
skipped_layers: int = Field(description="How much down layers skip when applying seamless")
skip_second_resnet: bool = Field(description="Skip or not second resnet in down blocks when applying seamless")
skip_conv2: bool = Field(description="Skip or not conv2 in down blocks when applying seamless")
class ModelInfo(BaseModel):
model_name: str = Field(description="Info to load submodel")
base_model: BaseModelType = Field(description="Base model")
@@ -33,7 +40,7 @@ class UNetField(BaseModel):
unet: ModelInfo = Field(description="Info to load unet submodel")
scheduler: ModelInfo = Field(description="Info to load scheduler submodel")
loras: List[LoraInfo] = Field(description="Loras to apply on model loading")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model")
class ClipField(BaseModel):
@@ -46,7 +53,7 @@ class ClipField(BaseModel):
class VaeField(BaseModel):
# TODO: better naming?
vae: ModelInfo = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
seamless: Optional[SeamlessSettings] = Field(default=None, description="Seamless settings applied to model")
@invocation_output("model_loader_output")
@@ -388,6 +395,11 @@ class SeamlessModeInvocation(BaseInvocation):
)
seamless_y: bool = InputField(default=True, input=Input.Any, description="Specify whether Y axis is seamless")
seamless_x: bool = InputField(default=True, input=Input.Any, description="Specify whether X axis is seamless")
skipped_layers: int = InputField(default=0, input=Input.Any, description="How much model's down layers to skip")
skip_second_resnet: bool = InputField(
default=True, input=Input.Any, description="Skip or not second resnet in down layers"
)
skip_conv2: bool = InputField(default=True, input=Input.Any, description="Skip or not conv2 in down layers")
def invoke(self, context: InvocationContext) -> SeamlessModeOutput:
# Conditionally append 'x' and 'y' based on seamless_x and seamless_y
@@ -402,8 +414,18 @@ class SeamlessModeInvocation(BaseInvocation):
seamless_axes_list.append("y")
if unet is not None:
unet.seamless_axes = seamless_axes_list
unet.seamless = SeamlessSettings(
axes=seamless_axes_list,
skipped_layers=self.skipped_layers,
skip_second_resnet=self.skip_second_resnet,
skip_conv2=self.skip_conv2,
)
if vae is not None:
vae.seamless_axes = seamless_axes_list
vae.seamless = SeamlessSettings(
axes=seamless_axes_list,
skipped_layers=self.skipped_layers,
skip_second_resnet=self.skip_second_resnet,
skip_conv2=self.skip_conv2,
)
return SeamlessModeOutput(unet=unet, vae=vae)

View File

@@ -95,9 +95,10 @@ class ONNXPromptInvocation(BaseInvocation):
print(f'Warn: trigger: "{trigger}" not found')
if loras or ti_list:
text_encoder.release_session()
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti(
orig_tokenizer, text_encoder, ti_list
) as (tokenizer, ti_manager):
with (
ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager),
):
text_encoder.create_session()
# copy from

View File

@@ -10,14 +10,7 @@ from invokeai.app.invocations.primitives import StringCollectionOutput
from .baseinvocation import BaseInvocation, InputField, InvocationContext, UIComponent, invocation
@invocation(
"dynamic_prompt",
title="Dynamic Prompt",
tags=["prompt", "collection"],
category="prompt",
version="1.0.0",
use_cache=False,
)
@invocation("dynamic_prompt", title="Dynamic Prompt", tags=["prompt", "collection"], category="prompt", version="1.0.0")
class DynamicPromptInvocation(BaseInvocation):
"""Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator"""

View File

@@ -253,7 +253,6 @@ class InvokeAIAppConfig(InvokeAISettings):
attention_type : Literal[tuple(["auto", "normal", "xformers", "sliced", "torch-sdp"])] = Field(default="auto", description="Attention type", category="Generation", )
attention_slice_size: Literal[tuple(["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8])] = Field(default="auto", description='Slice size, valid when attention_type=="sliced"', category="Generation", )
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty)", category="Generation",)
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", category="Generation", )
# NODES
allow_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.", category="Nodes")

View File

@@ -1,29 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
class InvocationCacheBase(ABC):
"""Base class for invocation caches."""
@abstractmethod
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
"""Retrieves and invocation output from the cache"""
pass
@abstractmethod
def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None:
"""Stores an invocation output in the cache"""
pass
@abstractmethod
def delete(self, key: Union[int, str]) -> None:
"""Deleted an invocation output from the cache"""
pass
@classmethod
@abstractmethod
def create_key(cls, value: BaseInvocation) -> Union[int, str]:
"""Creates the cache key for an invocation"""
pass

View File

@@ -1,34 +0,0 @@
from queue import Queue
from typing import Optional, Union
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
class MemoryInvocationCache(InvocationCacheBase):
__cache: dict[Union[int, str], BaseInvocationOutput]
__max_cache_size: int
__cache_ids: Queue
def __init__(self, max_cache_size: int = 512) -> None:
self.__cache = dict()
self.__max_cache_size = max_cache_size
self.__cache_ids = Queue()
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
return self.__cache.get(key, None)
def save(self, key: Union[int, str], value: BaseInvocationOutput) -> None:
if key not in self.__cache:
self.__cache[key] = value
self.__cache_ids.put(key)
if self.__cache_ids.qsize() > self.__max_cache_size:
self.__cache.pop(self.__cache_ids.get())
def delete(self, key: Union[int, str]) -> None:
if key in self.__cache:
del self.__cache[key]
@classmethod
def create_key(cls, value: BaseInvocation) -> Union[int, str]:
return hash(value.json(exclude={"id"}))

View File

@@ -12,7 +12,6 @@ if TYPE_CHECKING:
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import GraphExecutionState, LibraryGraph
from invokeai.app.services.images import ImageServiceABC
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
from invokeai.app.services.invocation_queue import InvocationQueueABC
from invokeai.app.services.invocation_stats import InvocationStatsServiceBase
from invokeai.app.services.invoker import InvocationProcessorABC
@@ -38,7 +37,6 @@ class InvocationServices:
processor: "InvocationProcessorABC"
performance_statistics: "InvocationStatsServiceBase"
queue: "InvocationQueueABC"
invocation_cache: "InvocationCacheBase"
def __init__(
self,
@@ -55,7 +53,6 @@ class InvocationServices:
processor: "InvocationProcessorABC",
performance_statistics: "InvocationStatsServiceBase",
queue: "InvocationQueueABC",
invocation_cache: "InvocationCacheBase",
):
self.board_images = board_images
self.boards = boards
@@ -71,4 +68,3 @@ class InvocationServices:
self.processor = processor
self.performance_statistics = performance_statistics
self.queue = queue
self.invocation_cache = invocation_cache

View File

@@ -326,6 +326,16 @@ class ModelInstall(object):
elif f"learned_embeds.{suffix}" in files:
location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
break
elif "image_encoder.txt" in files and f"ip_adapter.{suffix}" in files: # IP-Adapter
files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
location = self._download_hf_model(repo_id, files, staging)
break
elif f"model.{suffix}" in files and "config.json" in files:
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
# by InvokeAI for use with IP-Adapters.
files = ["config.json", f"model.{suffix}"]
location = self._download_hf_model(repo_id, files, staging)
break
if not location:
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
return {}
@@ -534,14 +544,17 @@ def hf_download_with_resume(
logger.info(f"{model_name}: Downloading...")
try:
with open(model_dest, open_mode) as file, tqdm(
desc=model_name,
initial=exist_size,
total=total + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
) as bar:
with (
open(model_dest, open_mode) as file,
tqdm(
desc=model_name,
initial=exist_size,
total=total + exist_size,
unit="iB",
unit_scale=True,
unit_divisor=1000,
) as bar,
):
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)

View File

@@ -0,0 +1,45 @@
# IP-Adapter Model Formats
The official IP-Adapter models are released here: [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter)
This official model repo does not integrate well with InvokeAI's current approach to model management, so we have defined a new file structure for IP-Adapter models. The InvokeAI format is described below.
## CLIP Vision Models
CLIP Vision models are organized in `diffusers` format. The expected directory structure is:
```bash
ip_adapter_sd_image_encoder/
├── config.json
└── model.safetensors
```
## IP-Adapter Models
IP-Adapter models are stored in a directory containing two files
- `image_encoder.txt`: A text file containing the model identifier for the CLIP Vision encoder that is intended to be used with this IP-Adapter model.
- `ip_adapter.bin`: The IP-Adapter weights.
Sample directory structure:
```bash
ip_adapter_sd15/
├── image_encoder.txt
└── ip_adapter.bin
```
### Why not save the weights in a `.safetensors` file?
The weights in `ip_adapter.bin` are stored in a nested dict, which is not supported by `safetensors`. This could be solved by splitting `ip_adapter.bin` into multiple files, but for now we have decided to maintain consistency with the checkpoint structure used in the official [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repo.
## InvokeAI Hosted IP-Adapters
Image Encoders:
- [InvokeAI/ip_adapter_sd_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder)
- [InvokeAI/ip_adapter_sdxl_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sdxl_image_encoder)
IP-Adapters:
- [InvokeAI/ip_adapter_sd15](https://huggingface.co/InvokeAI/ip_adapter_sd15)
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
- Not yet supported: [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)

View File

@@ -0,0 +1,162 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
# tencent-ailab comment:
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor2_0 as DiffusersAttnProcessor2_0
# Create a version of AttnProcessor2_0 that is a sub-class of nn.Module. This is required for IP-Adapter state_dict
# loading.
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
def __init__(self):
DiffusersAttnProcessor2_0.__init__(self)
nn.Module.__init__(self)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
ip_adapter_image_prompt_embeds parameter.
"""
return DiffusersAttnProcessor2_0.__call__(
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
)
class IPAttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapter for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
# The batch dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
ip_hidden_states = ip_adapter_image_prompt_embeds
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if ip_hidden_states is not None:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states

View File

@@ -0,0 +1,217 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
from contextlib import contextmanager
from typing import Optional, Union
import torch
from diffusers.models import UNet2DConditionModel
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from .attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
from .resampler import Resampler
class ImageProjModel(torch.nn.Module):
"""Image Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim)
@classmethod
def from_state_dict(cls, state_dict: dict[str, torch.Tensor], clip_extra_context_tokens=4):
"""Initialize an ImageProjModel from a state_dict.
The cross_attention_dim and clip_embeddings_dim are inferred from the shapes of the tensors in the state_dict.
Args:
state_dict (dict[str, torch.Tensor]): The state_dict of model weights.
clip_extra_context_tokens (int, optional): Defaults to 4.
Returns:
ImageProjModel
"""
cross_attention_dim = state_dict["norm.weight"].shape[0]
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
model = cls(cross_attention_dim, clip_embeddings_dim, clip_extra_context_tokens)
model.load_state_dict(state_dict)
return model
def forward(self, image_embeds):
embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim
)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
class IPAdapter:
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
def __init__(
self,
state_dict: dict[str, torch.Tensor],
device: torch.device,
dtype: torch.dtype = torch.float16,
num_tokens: int = 4,
):
self.device = device
self.dtype = dtype
self._num_tokens = num_tokens
self._clip_image_processor = CLIPImageProcessor()
self._state_dict = state_dict
self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
# The _attn_processors will be initialized later when we have access to the UNet.
self._attn_processors = None
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
self.device = device
if dtype is not None:
self.dtype = dtype
self._image_proj_model.to(device=self.device, dtype=self.dtype)
if self._attn_processors is not None:
torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)
def _init_image_proj_model(self, state_dict):
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can later be injected into a unet, and load the IP-Adapter
attention weights into them.
Note that the `unet` param is only used to determine attention block dimensions and naming.
TODO(ryand): As a future improvement, this could all be inferred from the state_dict when the IPAdapter is
initialized.
"""
attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor2_0()
else:
attn_procs[name] = IPAttnProcessor2_0(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
).to(self.device, dtype=self.dtype)
ip_layers = torch.nn.ModuleList(attn_procs.values())
ip_layers.load_state_dict(self._state_dict["ip_adapter"])
self._attn_processors = attn_procs
self._state_dict = None
# @genomancer: pushed scaling back out into its own method (like original Tencent implementation)
# which makes implementing begin_step_percent and end_step_percent easier
# but based on self._attn_processors (à la @Ryan) instead of the original Tencent unet.attn_processors,
# which should make it easier to implement multiple IPAdapters
def set_scale(self, scale):
if self._attn_processors is not None:
for attn_processor in self._attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor2_0):
attn_processor.scale = scale
@contextmanager
def apply_ip_adapter_attention(self, unet: UNet2DConditionModel, scale: float):
"""A context manager that patches `unet` with this IP-Adapter's attention processors while it is active.
Yields:
None
"""
if self._attn_processors is None:
# We only have to call _prepare_attention_processors(...) once, and then the result is cached and can be
# used on any UNet model (with the same dimensions).
self._prepare_attention_processors(unet)
# Set scale
self.set_scale(scale)
# for attn_processor in self._attn_processors.values():
# if isinstance(attn_processor, IPAttnProcessor2_0):
# attn_processor.scale = scale
orig_attn_processors = unet.attn_processors
# Make a (moderately-) shallow copy of the self._attn_processors dict, because unet.set_attn_processor(...)
# actually pops elements from the passed dict.
ip_adapter_attn_processors = {k: v for k, v in self._attn_processors.items()}
try:
unet.set_attn_processor(ip_adapter_attn_processors)
yield None
finally:
unet.set_attn_processor(orig_attn_processors)
@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
class IPAdapterPlus(IPAdapter):
"""IP-Adapter with fine-grained features"""
def _init_image_proj_model(self, state_dict):
return Resampler.from_state_dict(
state_dict=state_dict,
depth=4,
dim_head=64,
heads=12,
num_queries=self._num_tokens,
ff_mult=4,
).to(self.device, dtype=self.dtype)
@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=self.dtype)
clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
-2
]
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
def build_ip_adapter(
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus]:
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
# Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it
# contains.
is_plus = "proj.weight" not in state_dict["image_proj"]
if is_plus:
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
else:
return IPAdapter(state_dict, device=device, dtype=dtype)
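For context, here is a minimal usage sketch of the API added above. It is hedged: the checkpoint path, image path, device, and base UNet repo are illustrative placeholders; inside InvokeAI these pieces are wired together by the model manager and denoising pipeline shown in the diffs below.

```python
import torch
from diffusers import UNet2DConditionModel
from PIL import Image
from transformers import CLIPVisionModelWithProjection

from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter

device, dtype = torch.device("cuda"), torch.float16

# Load the IP-Adapter checkpoint; build_ip_adapter picks IPAdapter or IPAdapterPlus automatically.
ip_adapter = build_ip_adapter("path/to/ip_adapter.bin", device=device, dtype=dtype)

# Encode the prompt image with a matching CLIP image encoder (hosted encoder used here for illustration).
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "InvokeAI/ip_adapter_sd_image_encoder", torch_dtype=dtype
).to(device)
cond_embeds, uncond_embeds = ip_adapter.get_image_embeds(Image.open("prompt_image.png"), image_encoder)

# Patch a UNet's attention processors only while denoising runs.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=dtype
).to(device)
with ip_adapter.apply_ip_adapter_attention(unet, scale=1.0):
    # ... run the denoising loop, passing the embeddings to the UNet via
    # cross_attention_kwargs={"ip_adapter_image_prompt_embeds": ...}
    pass
```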

View File

@@ -0,0 +1,158 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# tencent ailab comment: modified from
# https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
import math
import torch
import torch.nn as nn
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# the reshape keeps the shape as (bs, n_heads, length, dim_per_head); heads are not merged into the batch dim
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latents (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class Resampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
):
super().__init__()
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
@classmethod
def from_state_dict(cls, state_dict: dict[str, torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
"""A convenience function that initializes a Resampler from a state_dict.
Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
writing, we did not have a need to infer ALL of the shape parameters from the state_dict, but this would be
possible if needed in the future.
Args:
state_dict (dict[str, torch.Tensor]): The state_dict to load.
depth (int, optional): Defaults to 8.
dim_head (int, optional): Defaults to 64.
heads (int, optional): Defaults to 16.
ff_mult (int, optional): Defaults to 4.
Returns:
Resampler
"""
dim = state_dict["latents"].shape[2]
num_queries = state_dict["latents"].shape[1]
embedding_dim = state_dict["proj_in.weight"].shape[-1]
output_dim = state_dict["norm_out.weight"].shape[0]
model = cls(
dim=dim,
depth=depth,
dim_head=dim_head,
heads=heads,
num_queries=num_queries,
embedding_dim=embedding_dim,
output_dim=output_dim,
ff_mult=ff_mult,
)
model.load_state_dict(state_dict)
return model
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
return self.norm_out(latents)

View File

@@ -25,6 +25,7 @@ Models are described using four attributes:
ModelType.Lora -- a LoRA or LyCORIS fine-tune
ModelType.TextualInversion -- a textual inversion embedding
ModelType.ControlNet -- a ControlNet model
ModelType.IPAdapter -- an IPAdapter model
3) BaseModelType -- an enum indicating the stable diffusion base model, one of:
BaseModelType.StableDiffusion1
@@ -1000,8 +1001,8 @@ class ModelManager(object):
new_models_found = True
except DuplicateModelException as e:
self.logger.warning(e)
except InvalidModelException:
self.logger.warning(f"Not a valid model: {model_path}")
except InvalidModelException as e:
self.logger.warning(f"Not a valid model: {model_path}. {e}")
except NotImplementedError as e:
self.logger.warning(e)

View File

@@ -8,6 +8,8 @@ import torch
from diffusers import ConfigMixin, ModelMixin
from picklescan.scanner import scan_file_path
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
from .models import (
BaseModelType,
InvalidModelException,
@@ -52,6 +54,7 @@ class ModelProbe(object):
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
}
@classmethod
@@ -118,14 +121,18 @@ class ModelProbe(object):
and prediction_type == SchedulerPredictionType.VPrediction
),
format=format,
image_size=1024
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
else 768
if (
base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction
)
else 512,
image_size=(
1024
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
else (
768
if (
base_type == BaseModelType.StableDiffusion2
and prediction_type == SchedulerPredictionType.VPrediction
)
else 512
)
),
)
except Exception:
raise
@@ -177,9 +184,10 @@ class ModelProbe(object):
return ModelType.ONNX
if (folder_path / "learned_embeds.bin").exists():
return ModelType.TextualInversion
if (folder_path / "pytorch_lora_weights.bin").exists():
return ModelType.Lora
if (folder_path / "image_encoder.txt").exists():
return ModelType.IPAdapter
i = folder_path / "model_index.json"
c = folder_path / "config.json"
@@ -188,7 +196,12 @@ class ModelProbe(object):
if config_path:
with open(config_path, "r") as file:
conf = json.load(file)
class_name = conf["_class_name"]
if "_class_name" in conf:
class_name = conf["_class_name"]
elif "architectures" in conf:
class_name = conf["architectures"][0]
else:
class_name = None
if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type
@@ -366,6 +379,16 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
class IPAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
########################################################
# classes for probing folders
#######################################################
@@ -485,11 +508,13 @@ class ControlNetFolderProbe(FolderProbeBase):
base_model = (
BaseModelType.StableDiffusion1
if dimension == 768
else BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
else (
BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
)
)
if not base_model:
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
@@ -509,15 +534,47 @@ class LoRAFolderProbe(FolderProbeBase):
return LoRACheckpointProbe(model_file, None).get_base_type()
class IPAdapterFolderProbe(FolderProbeBase):
def get_format(self) -> str:
return IPAdapterModelFormat.InvokeAI.value
def get_base_type(self) -> BaseModelType:
model_file = self.folder_path / "ip_adapter.bin"
if not model_file.exists():
raise InvalidModelException("Unknown IP-Adapter model format.")
state_dict = torch.load(model_file, map_location="cpu")
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
return BaseModelType.StableDiffusion1
elif cross_attention_dim == 1024:
return BaseModelType.StableDiffusion2
elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
class CLIPVisionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@@ -79,7 +79,7 @@ class ModelSearch(ABC):
self._models_found += 1
self._scanned_dirs.add(path)
except Exception as e:
self.logger.warning(str(e))
self.logger.warning(f"Failed to process '{path}': {e}")
for f in files:
path = Path(root) / f
@@ -90,7 +90,7 @@ class ModelSearch(ABC):
self.on_model_found(path)
self._models_found += 1
except Exception as e:
self.logger.warning(str(e))
self.logger.warning(f"Failed to process '{path}': {e}")
class FindModels(ModelSearch):

View File

@@ -18,7 +18,9 @@ from .base import ( # noqa: F401
SilenceWarnings,
SubModelType,
)
from .clip_vision import CLIPVisionModel
from .controlnet import ControlNetModel # TODO:
from .ip_adapter import IPAdapterModel
from .lora import LoRAModel
from .sdxl import StableDiffusionXLModel
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
@@ -34,6 +36,8 @@ MODEL_CLASSES = {
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
},
BaseModelType.StableDiffusion2: {
ModelType.ONNX: ONNXStableDiffusion2Model,
@@ -42,6 +46,8 @@ MODEL_CLASSES = {
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
},
BaseModelType.StableDiffusionXL: {
ModelType.Main: StableDiffusionXLModel,
@@ -51,6 +57,8 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
},
BaseModelType.StableDiffusionXLRefiner: {
ModelType.Main: StableDiffusionXLModel,
@@ -60,6 +68,19 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
},
BaseModelType.Any: {
ModelType.CLIPVision: CLIPVisionModel,
# The following model types are not expected to be used with BaseModelType.Any.
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.Main: StableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
},
# BaseModelType.Kandinsky2_1: {
# ModelType.Main: Kandinsky2_1Model,

View File

@@ -36,6 +36,7 @@ class ModelNotFoundException(Exception):
class BaseModelType(str, Enum):
Any = "any" # For models that are not associated with any particular base model.
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
@@ -50,6 +51,8 @@ class ModelType(str, Enum):
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
class SubModelType(str, Enum):

View File

@@ -0,0 +1,82 @@
import os
from enum import Enum
from typing import Literal, Optional
import torch
from transformers import CLIPVisionModelWithProjection
from invokeai.backend.model_management.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelType,
SubModelType,
calc_model_size_by_data,
calc_model_size_by_fs,
classproperty,
)
class CLIPVisionModelFormat(str, Enum):
Diffusers = "diffusers"
class CLIPVisionModel(ModelBase):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[CLIPVisionModelFormat.Diffusers]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.CLIPVision
super().__init__(model_path, base_model, model_type)
self.model_size = calc_model_size_by_fs(self.model_path)
@classmethod
def detect_format(cls, path: str) -> str:
if not os.path.exists(path):
raise ModuleNotFoundError(f"No CLIP Vision model at path '{path}'.")
if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")):
return CLIPVisionModelFormat.Diffusers
raise InvalidModelException(f"Unexpected CLIP Vision model format: {path}")
@classproperty
def save_to_config(cls) -> bool:
return True
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
if child_type is not None:
raise ValueError("There are no child models in a CLIP Vision model.")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> CLIPVisionModelWithProjection:
if child_type is not None:
raise ValueError("There are no child models in a CLIP Vision model.")
model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype)
# Calculate a more accurate model size.
self.model_size = calc_model_size_by_data(model)
return model
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == CLIPVisionModelFormat.Diffusers:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")

View File

@@ -0,0 +1,92 @@
import os
import typing
from enum import Enum
from typing import Literal, Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, build_ip_adapter
from invokeai.backend.model_management.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelType,
SubModelType,
classproperty,
)
class IPAdapterModelFormat(str, Enum):
# The custom IP-Adapter model format defined by InvokeAI.
InvokeAI = "invokeai"
class IPAdapterModel(ModelBase):
class InvokeAIConfig(ModelConfigBase):
model_format: Literal[IPAdapterModelFormat.InvokeAI]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.IPAdapter
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
@classmethod
def detect_format(cls, path: str) -> str:
if not os.path.exists(path):
raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")
if os.path.isdir(path):
model_file = os.path.join(path, "ip_adapter.bin")
image_encoder_config_file = os.path.join(path, "image_encoder.txt")
if os.path.exists(model_file) and os.path.exists(image_encoder_config_file):
return IPAdapterModelFormat.InvokeAI
raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
@classproperty
def save_to_config(cls) -> bool:
return True
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> typing.Union[IPAdapter, IPAdapterPlus]:
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
return build_ip_adapter(
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
)
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == IPAdapterModelFormat.InvokeAI:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")
def get_ip_adapter_image_encoder_model_id(model_path: str):
"""Read the ID of the image encoder associated with the IP-Adapter at `model_path`."""
image_encoder_config_file = os.path.join(model_path, "image_encoder.txt")
with open(image_encoder_config_file, "r") as f:
image_encoder_model = f.readline().strip()
return image_encoder_model

View File

@@ -25,71 +25,55 @@ def _conv_forward_asymmetric(self, input, weight, bias):
@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
def set_seamless(
model: Union[UNet2DConditionModel, AutoencoderKL],
axes: List[str],
skipped_layers: int,
skip_second_resnet: bool,
skip_conv2: bool,
):
try:
to_restore = []
for m_name, m in model.named_modules():
if isinstance(model, UNet2DConditionModel):
if ".attentions." in m_name:
if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
continue
if isinstance(model, UNet2DConditionModel) and m_name.startswith("down_blocks.") and ".resnets." in m_name:
# down_blocks.1.resnets.1.conv1
_, block_num, _, resnet_num, submodule_name = m_name.split(".")
block_num = int(block_num)
resnet_num = int(resnet_num)
# if block_num >= seamless_down_blocks:
if block_num >= len(model.down_blocks) - skipped_layers:
continue
if ".resnets." in m_name:
if ".conv2" in m_name:
continue
if ".conv_shortcut" in m_name:
continue
"""
if isinstance(model, UNet2DConditionModel):
if False and ".upsamplers." in m_name:
if resnet_num > 0 and skip_second_resnet:
continue
if False and ".downsamplers." in m_name:
if submodule_name == "conv2" and skip_conv2:
continue
if True and ".resnets." in m_name:
if True and ".conv1" in m_name:
if False and "down_blocks" in m_name:
continue
if False and "mid_block" in m_name:
continue
if False and "up_blocks" in m_name:
continue
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
if True and ".conv2" in m_name:
continue
if True and ".conv_shortcut" in m_name:
continue
if True and ".attentions." in m_name:
continue
if False and m_name in ["conv_in", "conv_out"]:
continue
"""
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
yield

View File

@@ -1,15 +1,6 @@
"""
Initialization file for the invokeai.backend.stable_diffusion package
"""
from .diffusers_pipeline import ( # noqa: F401
ConditioningData,
PipelineIntermediateState,
StableDiffusionGeneratorPipeline,
)
from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .diffusion.shared_invokeai_diffusion import ( # noqa: F401
BasicConditioningInfo,
PostprocessingSettings,
SDXLConditioningInfo,
)

View File

@@ -1,8 +1,8 @@
from __future__ import annotations
import dataclasses
import inspect
from dataclasses import dataclass, field
import math
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Union
import einops
@@ -23,9 +23,11 @@ from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData
from ..util import auto_detect_slice_size, normalize_device
from .diffusion import AttentionMapSaver, BasicConditioningInfo, InvokeAIDiffuserComponent, PostprocessingSettings
from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent
@dataclass
@@ -95,7 +97,7 @@ class AddsMaskGuidance:
# Mask anything that has the same shape as prev_sample, return others as-is.
return output_class(
{
k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v)
k: self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v
for k, v in step_output.items()
}
)
@@ -162,39 +164,13 @@ class ControlNetData:
@dataclass
class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo
text_embeddings: BasicConditioningInfo
guidance_scale: Union[float, List[float]]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
"""
extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict)
"""
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
"""
postprocessing_settings: Optional[PostprocessingSettings] = None
@property
def dtype(self):
return self.text_embeddings.dtype
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
scheduler_args = dict(self.scheduler_args)
step_method = inspect.signature(scheduler.step)
for name, value in kwargs.items():
try:
step_method.bind_partial(**{name: value})
except TypeError:
# FIXME: don't silently discard arguments
pass # debug("%s does not accept argument named %r", scheduler, name)
else:
scheduler_args[name] = value
return dataclasses.replace(self, scheduler_args=scheduler_args)
class IPAdapterData:
ip_adapter_model: IPAdapter = Field(default=None)
# TODO: change to polymorphic so can do different weights per step (once implemented...)
weight: Union[float, List[float]] = Field(default=1.0)
# weight: float = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
@dataclass
@@ -277,6 +253,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
self.control_model = control_model
self.use_ip_adapter = False
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
"""
@@ -349,6 +326,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
seed: Optional[int] = None,
@@ -400,6 +378,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data,
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
callback=callback,
)
finally:
@@ -419,6 +398,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
*,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
):
self._adjust_memory_efficient_attention(latents)
@@ -431,12 +411,26 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if timesteps.shape[0] == 0:
return latents, attention_map_saver
extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
):
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=conditioning_data.extra,
step_count=len(self.scheduler.timesteps),
)
self.use_ip_adapter = False
elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
weight = ip_adapter_data.weight[0] if isinstance(ip_adapter_data.weight, List) else ip_adapter_data.weight
attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention(
unet=self.invokeai_diffuser.model,
scale=weight,
)
self.use_ip_adapter = True
else:
attn_ctx = nullcontext()
with attn_ctx:
if callback is not None:
callback(
PipelineIntermediateState(
@@ -459,6 +453,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=len(timesteps),
additional_guidance=additional_guidance,
control_data=control_data,
ip_adapter_data=ip_adapter_data,
)
latents = step_output.prev_sample
@@ -504,6 +499,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count: int,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0]
@@ -514,6 +510,24 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# i.e. before or after passing it to InvokeAIDiffuserComponent
latent_model_input = self.scheduler.scale_model_input(latents, timestep)
# handle IP-Adapter
if self.use_ip_adapter and ip_adapter_data is not None: # somewhat redundant but logic is clearer
first_adapter_step = math.floor(ip_adapter_data.begin_step_percent * total_step_count)
last_adapter_step = math.ceil(ip_adapter_data.end_step_percent * total_step_count)
weight = (
ip_adapter_data.weight[step_index]
if isinstance(ip_adapter_data.weight, List)
else ip_adapter_data.weight
)
if step_index >= first_adapter_step and step_index <= last_adapter_step:
# only apply IP-Adapter if current step is within the IP-Adapter's begin/end step range
# ip_adapter_data.ip_adapter_model.set_scale(ip_adapter_data.weight)
ip_adapter_data.ip_adapter_model.set_scale(weight)
else:
# otherwise, set IP-Adapter scale to 0, so it has no effect
ip_adapter_data.ip_adapter_model.set_scale(0.0)
# handle ControlNet(s)
# default is no controlnet, so set controlnet processing output to None
controlnet_down_block_samples, controlnet_mid_block_sample = None, None
if control_data is not None:

View File

@@ -3,9 +3,4 @@ Initialization file for invokeai.models.diffusion
"""
from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .shared_invokeai_diffusion import ( # noqa: F401
BasicConditioningInfo,
InvokeAIDiffuserComponent,
PostprocessingSettings,
SDXLConditioningInfo,
)
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent # noqa: F401

View File

@@ -0,0 +1,101 @@
import dataclasses
import inspect
from dataclasses import dataclass, field
from typing import Any, List, Optional, Union
import torch
from .cross_attention_control import Arguments
@dataclass
class ExtraConditioningInfo:
tokens_count_including_eos_bos: int
cross_attention_control_args: Optional[Arguments] = None
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
@dataclass
class BasicConditioningInfo:
embeds: torch.Tensor
# TODO(ryand): Right now we awkwardly copy the extra conditioning info from here up to `ConditioningData`. This
# should only be stored in one place.
extra_conditioning: Optional[ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype)
return self
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
return super().to(device=device, dtype=dtype)
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
warmup: float
h_symmetry_time_pct: Optional[float]
v_symmetry_time_pct: Optional[float]
@dataclass
class IPAdapterConditioningInfo:
cond_image_prompt_embeds: torch.Tensor
"""IP-Adapter image encoder conditioning embeddings.
Shape: (batch_size, num_tokens, encoding_dim).
"""
uncond_image_prompt_embeds: torch.Tensor
"""IP-Adapter image encoding embeddings to use for unconditional generation.
Shape: (batch_size, num_tokens, encoding_dim).
"""
@dataclass
class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo
text_embeddings: BasicConditioningInfo
guidance_scale: Union[float, List[float]]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
"""
extra: Optional[ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict)
"""
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
"""
postprocessing_settings: Optional[PostprocessingSettings] = None
ip_adapter_conditioning: Optional[IPAdapterConditioningInfo] = None
@property
def dtype(self):
return self.text_embeddings.dtype
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
scheduler_args = dict(self.scheduler_args)
step_method = inspect.signature(scheduler.step)
for name, value in kwargs.items():
try:
step_method.bind_partial(**{name: value})
except TypeError:
# FIXME: don't silently discard arguments
pass # debug("%s does not accept argument named %r", scheduler, name)
else:
scheduler_args[name] = value
return dataclasses.replace(self, scheduler_args=scheduler_args)

View File

@@ -376,11 +376,11 @@ def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[
# non-fatal error but .swap() won't work.
logger.error(
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
+ f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
+ "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+ "work properly until it is fixed."
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching "
"failed or some assumption has changed about the structure of the model itself. Please fix the "
f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and "
"inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and "
"attention map display will not work properly until it is fixed."
)
return attention_module_tuples
@@ -577,6 +577,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
attention_mask=None,
# kwargs
swap_cross_attn_context: SwapCrossAttnContext = None,
**kwargs,
):
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import math
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Optional, Union
import torch
@@ -10,9 +9,14 @@ from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData,
ExtraConditioningInfo,
PostprocessingSettings,
SDXLConditioningInfo,
)
from .cross_attention_control import (
Arguments,
Context,
CrossAttentionType,
SwapCrossAttnContext,
@@ -31,37 +35,6 @@ ModelForwardCallback: TypeAlias = Union[
]
@dataclass
class BasicConditioningInfo:
embeds: torch.Tensor
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype)
return self
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
return super().to(device=device, dtype=dtype)
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
warmup: float
h_symmetry_time_pct: Optional[float]
v_symmetry_time_pct: Optional[float]
class InvokeAIDiffuserComponent:
"""
The aim of this component is to provide a single place for code that can be applied identically to
@@ -75,15 +48,6 @@ class InvokeAIDiffuserComponent:
debug_thresholding = False
sequential_guidance = False
@dataclass
class ExtraConditioningInfo:
tokens_count_including_eos_bos: int
cross_attention_control_args: Optional[Arguments] = None
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
def __init__(
self,
model,
@@ -103,30 +67,26 @@ class InvokeAIDiffuserComponent:
@contextmanager
def custom_attention_context(
self,
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
unet: UNet2DConditionModel,
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int,
):
old_attn_processors = None
if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
old_attn_processors = unet.attn_processors
# Load lora conditions into the model
if extra_conditioning_info.wants_cross_attention_control:
self.cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
self.cross_attention_control_context,
)
old_attn_processors = unet.attn_processors
try:
self.cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
self.cross_attention_control_context,
)
yield None
finally:
self.cross_attention_control_context = None
if old_attn_processors is not None:
unet.set_attn_processor(old_attn_processors)
unet.set_attn_processor(old_attn_processors)
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()
@@ -376,11 +336,24 @@ class InvokeAIDiffuserComponent:
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
# fast batched path
def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs):
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
the cost of higher memory usage.
"""
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": torch.cat(
[
conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds,
conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds,
]
)
}
added_cond_kwargs = None
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
added_cond_kwargs = {
@@ -408,6 +381,7 @@ class InvokeAIDiffuserComponent:
x_twice,
sigma_twice,
both_conditionings,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
@@ -419,9 +393,12 @@ class InvokeAIDiffuserComponent:
self,
x: torch.Tensor,
sigma,
conditioning_data,
conditioning_data: ConditioningData,
**kwargs,
):
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
slower execution speed.
"""
# low-memory sequential path
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
@@ -437,6 +414,13 @@ class InvokeAIDiffuserComponent:
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
# Run unconditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds
}
added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl:
@@ -449,12 +433,21 @@ class InvokeAIDiffuserComponent:
x,
sigma,
conditioning_data.unconditioned_embeddings.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
# Run conditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds
}
added_cond_kwargs = None
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
@@ -465,6 +458,7 @@ class InvokeAIDiffuserComponent:
x,
sigma,
conditioning_data.text_embeddings.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
added_cond_kwargs=added_cond_kwargs,

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,4 +1,4 @@
import{v as m,h5 as Je,u as y,Y as Xa,h6 as Ja,a7 as ua,ab as d,h7 as b,h8 as o,h9 as Qa,ha as h,hb as fa,hc as Za,hd as eo,aE as ro,he as ao,a4 as oo,hf as to}from"./index-f83c2c5c.js";import{s as ha,n as t,t as io,o as ma,p as no,q as ga,v as ya,w as pa,x as lo,y as Sa,z as xa,A as xr,B as so,D as co,E as bo,F as $a,G as ka,H as _a,J as vo,K as wa,L as uo,M as fo,N as ho,O as mo,Q as za,R as go,S as yo,T as po,U as So,V as xo,W as $o,e as ko,X as _o}from"./menu-31376327.js";var Ca=String.raw,Aa=Ca`
import{v as m,hj as Je,u as y,Y as Xa,hk as Ja,a7 as ua,ab as d,hl as b,hm as o,hn as Qa,ho as h,hp as fa,hq as Za,hr as eo,aE as ro,hs as ao,a4 as oo,ht as to}from"./index-f6c3f475.js";import{s as ha,n as t,t as io,o as ma,p as no,q as ga,v as ya,w as pa,x as lo,y as Sa,z as xa,A as xr,B as so,D as co,E as bo,F as $a,G as ka,H as _a,J as vo,K as wa,L as uo,M as fo,N as ho,O as mo,Q as za,R as go,S as yo,T as po,U as So,V as xo,W as $o,e as ko,X as _o}from"./menu-c9cc8c3d.js";var Ca=String.raw,Aa=Ca`
:root,
:host {
--chakra-vh: 100vh;

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-f83c2c5c.js"></script>
<script type="module" crossorigin src="./assets/index-f6c3f475.js"></script>
</head>
<body dir="ltr">

File diff suppressed because it is too large Load Diff

View File

@@ -49,6 +49,7 @@
"close": "Close",
"communityLabel": "Community",
"controlNet": "Controlnet",
"ipAdapter": "IP Adapter",
"darkMode": "Dark Mode",
"discordLabel": "Discord",
"dontAskMeAgain": "Don't ask me again",
@@ -191,7 +192,11 @@
"showAdvanced": "Show Advanced",
"toggleControlNet": "Toggle this ControlNet",
"w": "W",
"weight": "Weight"
"weight": "Weight",
"enableIPAdapter": "Enable IP Adapter",
"ipAdapterModel": "Adapter Model",
"resetIPAdapterImage": "Reset IP Adapter Image",
"ipAdapterImageFallback": "No IP Adapter Image Selected"
},
"embedding": {
"addEmbedding": "Add Embedding",
@@ -1036,6 +1041,7 @@
"serverError": "Server Error",
"setCanvasInitialImage": "Set as canvas initial image",
"setControlImage": "Set as control image",
"setIPAdapterImage": "Set as IP Adapter Image",
"setInitialImage": "Set as initial image",
"setNodeField": "Set as node field",
"tempFoldersEmptied": "Temp Folder Emptied",

View File

@@ -1,5 +1,8 @@
import { resetCanvas } from 'features/canvas/store/canvasSlice';
import { controlNetReset } from 'features/controlNet/store/controlNetSlice';
import {
controlNetReset,
ipAdapterStateReset,
} from 'features/controlNet/store/controlNetSlice';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { clearInitialImage } from 'features/parameters/store/generationSlice';
@@ -18,6 +21,7 @@ export const addDeleteBoardAndImagesFulfilledListener = () => {
let wasCanvasReset = false;
let wasNodeEditorReset = false;
let wasControlNetReset = false;
let wasIPAdapterReset = false;
const state = getState();
deleted_images.forEach((image_name) => {
@@ -42,6 +46,11 @@ export const addDeleteBoardAndImagesFulfilledListener = () => {
dispatch(controlNetReset());
wasControlNetReset = true;
}
if (imageUsage.isIPAdapterImage && !wasIPAdapterReset) {
dispatch(ipAdapterStateReset());
wasIPAdapterReset = true;
}
});
},
});

View File

@@ -3,6 +3,7 @@ import { resetCanvas } from 'features/canvas/store/canvasSlice';
import {
controlNetImageChanged,
controlNetProcessedImageChanged,
ipAdapterImageChanged,
} from 'features/controlNet/store/controlNetSlice';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
@@ -110,6 +111,14 @@ export const addRequestedSingleImageDeletionListener = () => {
}
});
// Remove IP Adapter Set Image if image is deleted.
if (
getState().controlNet.ipAdapterInfo.adapterImage?.image_name ===
imageDTO.image_name
) {
dispatch(ipAdapterImageChanged(null));
}
// reset nodes that use the deleted images
getState().nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
@@ -227,6 +236,14 @@ export const addRequestedMultipleImageDeletionListener = () => {
}
});
// Remove IP Adapter Set Image if image is deleted.
if (
getState().controlNet.ipAdapterInfo.adapterImage?.image_name ===
imageDTO.image_name
) {
dispatch(ipAdapterImageChanged(null));
}
// reset nodes that use the deleted images
getState().nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {

View File

@@ -1,7 +1,11 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import {
controlNetImageChanged,
ipAdapterImageChanged,
} from 'features/controlNet/store/controlNetSlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
@@ -14,7 +18,6 @@ import {
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
import { startAppListening } from '../';
import { parseify } from 'common/util/serialize';
export const dndDropped = createAction<{
overData: TypesafeDroppableData;
@@ -99,6 +102,18 @@ export const addImageDroppedListener = () => {
return;
}
/**
* Image dropped on IP Adapter image
*/
if (
overData.actionType === 'SET_IP_ADAPTER_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
dispatch(ipAdapterImageChanged(activeData.payload.imageDTO));
return;
}
/**
* Image dropped on Canvas
*/

View File

@@ -19,6 +19,7 @@ export const addImageToDeleteSelectedListener = () => {
imagesUsage.some((i) => i.isCanvasImage) ||
imagesUsage.some((i) => i.isInitialImage) ||
imagesUsage.some((i) => i.isControlNetImage) ||
imagesUsage.some((i) => i.isIPAdapterImage) ||
imagesUsage.some((i) => i.isNodesImage);
if (shouldConfirmOnDelete || isImageInUse) {

View File

@@ -1,15 +1,18 @@
import { UseToastOptions } from '@chakra-ui/react';
import { logger } from 'app/logging/logger';
import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice';
import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice';
import {
controlNetImageChanged,
ipAdapterImageChanged,
} from 'features/controlNet/store/controlNetSlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { initialImageChanged } from 'features/parameters/store/generationSlice';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import { startAppListening } from '..';
import { imagesApi } from '../../../../../services/api/endpoints/images';
import { t } from 'i18next';
const DEFAULT_UPLOADED_TOAST: UseToastOptions = {
title: t('toast.imageUploaded'),
@@ -99,6 +102,17 @@ export const addImageUploadedFulfilledListener = () => {
return;
}
if (postUploadAction?.type === 'SET_IP_ADAPTER_IMAGE') {
dispatch(ipAdapterImageChanged(imageDTO));
dispatch(
addToast({
...DEFAULT_UPLOADED_TOAST,
description: t('toast.setIPAdapterImage'),
})
);
return;
}
if (postUploadAction?.type === 'SET_INITIAL_IMAGE') {
dispatch(initialImageChanged(imageDTO));
dispatch(

View File

@@ -1,6 +1,9 @@
import { logger } from 'app/logging/logger';
import { setBoundingBoxDimensions } from 'features/canvas/store/canvasSlice';
import { controlNetRemoved } from 'features/controlNet/store/controlNetSlice';
import {
controlNetRemoved,
ipAdapterStateReset,
} from 'features/controlNet/store/controlNetSlice';
import { loraRemoved } from 'features/lora/store/loraSlice';
import { modelSelected } from 'features/parameters/store/actions';
import {
@@ -56,6 +59,7 @@ export const addModelSelectedListener = () => {
modelsCleared += 1;
}
// handle incompatible controlnets
const { controlNets } = state.controlNet;
forEach(controlNets, (controlNet, controlNetId) => {
if (controlNet.model?.base_model !== base_model) {
@@ -64,6 +68,16 @@ export const addModelSelectedListener = () => {
}
});
// handle incompatible IP-Adapter
const { ipAdapterInfo } = state.controlNet;
if (
ipAdapterInfo.model &&
ipAdapterInfo.model.base_model !== base_model
) {
dispatch(ipAdapterStateReset());
modelsCleared += 1;
}
if (modelsCleared > 0) {
dispatch(
addToast(

View File

@@ -1,5 +0,0 @@
import { Store } from '@reduxjs/toolkit';
import { atom } from 'nanostores';
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const $store = atom<Store<any> | undefined>();

View File

@@ -31,7 +31,6 @@ import { actionSanitizer } from './middleware/devtools/actionSanitizer';
import { actionsDenylist } from './middleware/devtools/actionsDenylist';
import { stateSanitizer } from './middleware/devtools/stateSanitizer';
import { listenerMiddleware } from './middleware/listenerMiddleware';
import { $store } from './nanostores/store';
const allReducers = {
canvas: canvasReducer,
@@ -87,7 +86,10 @@ export const store = configureStore({
.concat(autoBatchEnhancer());
},
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({ immutableCheck: false })
getDefaultMiddleware({
serializableCheck: false,
immutableCheck: false,
})
.concat(api.middleware)
.concat(dynamicMiddlewares)
.prepend(listenerMiddleware.middleware),
@@ -122,4 +124,3 @@ export type RootState = ReturnType<typeof store.getState>;
export type AppThunkDispatch = ThunkDispatch<RootState, any, AnyAction>;
export type AppDispatch = typeof store.dispatch;
export const stateSelector = (state: RootState) => state;
$store.set(store);

View File

@@ -18,6 +18,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import IAISwitch from 'common/components/IAISwitch';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useTranslation } from 'react-i18next';
import { useToggle } from 'react-use';
import { v4 as uuidv4 } from 'uuid';
import ControlNetImagePreview from './ControlNetImagePreview';
@@ -28,7 +29,6 @@ import ParamControlNetBeginEnd from './parameters/ParamControlNetBeginEnd';
import ParamControlNetControlMode from './parameters/ParamControlNetControlMode';
import ParamControlNetProcessorSelect from './parameters/ParamControlNetProcessorSelect';
import ParamControlNetResizeMode from './parameters/ParamControlNetResizeMode';
import { useTranslation } from 'react-i18next';
type ControlNetProps = {
controlNet: ControlNetConfig;

View File

@@ -0,0 +1,35 @@
import { Flex } from '@chakra-ui/react';
import { memo } from 'react';
import ParamIPAdapterBeginEnd from './ParamIPAdapterBeginEnd';
import ParamIPAdapterFeatureToggle from './ParamIPAdapterFeatureToggle';
import ParamIPAdapterImage from './ParamIPAdapterImage';
import ParamIPAdapterModelSelect from './ParamIPAdapterModelSelect';
import ParamIPAdapterWeight from './ParamIPAdapterWeight';
const IPAdapterPanel = () => {
return (
<Flex
sx={{
flexDir: 'column',
gap: 3,
paddingInline: 3,
paddingBlock: 2,
paddingBottom: 5,
borderRadius: 'base',
position: 'relative',
bg: 'base.250',
_dark: {
bg: 'base.750',
},
}}
>
<ParamIPAdapterFeatureToggle />
<ParamIPAdapterImage />
<ParamIPAdapterModelSelect />
<ParamIPAdapterWeight />
<ParamIPAdapterBeginEnd />
</Flex>
);
};
export default memo(IPAdapterPanel);

View File

@@ -0,0 +1,100 @@
import {
FormControl,
FormLabel,
HStack,
RangeSlider,
RangeSliderFilledTrack,
RangeSliderMark,
RangeSliderThumb,
RangeSliderTrack,
Tooltip,
} from '@chakra-ui/react';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
ipAdapterBeginStepPctChanged,
ipAdapterEndStepPctChanged,
} from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const formatPct = (v: number) => `${Math.round(v * 100)}%`;
const ParamIPAdapterBeginEnd = () => {
const isEnabled = useAppSelector(
(state: RootState) => state.controlNet.isIPAdapterEnabled
);
const beginStepPct = useAppSelector(
(state: RootState) => state.controlNet.ipAdapterInfo.beginStepPct
);
const endStepPct = useAppSelector(
(state: RootState) => state.controlNet.ipAdapterInfo.endStepPct
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleStepPctChanged = useCallback(
(v: number[]) => {
dispatch(ipAdapterBeginStepPctChanged(v[0] as number));
dispatch(ipAdapterEndStepPctChanged(v[1] as number));
},
[dispatch]
);
return (
<FormControl isDisabled={!isEnabled}>
<FormLabel>{t('controlnet.beginEndStepPercent')}</FormLabel>
<HStack w="100%" gap={2} alignItems="center">
<RangeSlider
aria-label={['Begin Step %', 'End Step %']}
value={[beginStepPct, endStepPct]}
onChange={handleStepPctChanged}
min={0}
max={1}
step={0.01}
minStepsBetweenThumbs={5}
isDisabled={!isEnabled}
>
<RangeSliderTrack>
<RangeSliderFilledTrack />
</RangeSliderTrack>
<Tooltip label={formatPct(beginStepPct)} placement="top" hasArrow>
<RangeSliderThumb index={0} />
</Tooltip>
<Tooltip label={formatPct(endStepPct)} placement="top" hasArrow>
<RangeSliderThumb index={1} />
</Tooltip>
<RangeSliderMark
value={0}
sx={{
insetInlineStart: '0 !important',
insetInlineEnd: 'unset !important',
}}
>
0%
</RangeSliderMark>
<RangeSliderMark
value={0.5}
sx={{
insetInlineStart: '50% !important',
transform: 'translateX(-50%)',
}}
>
50%
</RangeSliderMark>
<RangeSliderMark
value={1}
sx={{
insetInlineStart: 'unset !important',
insetInlineEnd: '0 !important',
}}
>
100%
</RangeSliderMark>
</RangeSlider>
</HStack>
</FormControl>
);
};
export default memo(ParamIPAdapterBeginEnd);
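A note on the slider above, as a minimal sketch rather than part of the diff: both thumbs hold fractions of the denoising schedule in [0, 1], and with `step={0.01}` the `minStepsBetweenThumbs={5}` prop keeps them at least 5% apart. The label math is just:

```ts
// Mirrors the formatPct helper above: fraction in, whole percentage out.
const formatPct = (v: number) => `${Math.round(v * 100)}%`;

console.log(formatPct(0)); // '0%'
console.log(formatPct(0.349)); // '35%'
console.log(formatPct(1)); // '100%'
```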

View File

@@ -0,0 +1,41 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAISwitch from 'common/components/IAISwitch';
import { isIPAdapterEnableToggled } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
stateSelector,
(state) => {
const { isIPAdapterEnabled } = state.controlNet;
return { isIPAdapterEnabled };
},
defaultSelectorOptions
);
const ParamIPAdapterFeatureToggle = () => {
const { isIPAdapterEnabled } = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(() => {
dispatch(isIPAdapterEnableToggled());
}, [dispatch]);
return (
<IAISwitch
label={t('controlnet.enableIPAdapter')}
isChecked={isIPAdapterEnabled}
onChange={handleChange}
formControlProps={{
width: '100%',
}}
/>
);
};
export default memo(ParamIPAdapterFeatureToggle);

View File

@@ -0,0 +1,93 @@
import { Flex } from '@chakra-ui/react';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import IAIDndImageIcon from 'common/components/IAIDndImageIcon';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { ipAdapterImageChanged } from 'features/controlNet/store/controlNetSlice';
import {
TypesafeDraggableData,
TypesafeDroppableData,
} from 'features/dnd/types';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaUndo } from 'react-icons/fa';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { PostUploadAction } from 'services/api/types';
const ParamIPAdapterImage = () => {
const ipAdapterInfo = useAppSelector(
(state: RootState) => state.controlNet.ipAdapterInfo
);
const isIPAdapterEnabled = useAppSelector(
(state: RootState) => state.controlNet.isIPAdapterEnabled
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { currentData: imageDTO } = useGetImageDTOQuery(
ipAdapterInfo.adapterImage?.image_name ?? skipToken
);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
if (imageDTO) {
return {
id: 'ip-adapter-image',
payloadType: 'IMAGE_DTO',
payload: { imageDTO },
};
}
}, [imageDTO]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id: 'ip-adapter-image',
actionType: 'SET_IP_ADAPTER_IMAGE',
}),
[]
);
const postUploadAction = useMemo<PostUploadAction>(
() => ({
type: 'SET_IP_ADAPTER_IMAGE',
}),
[]
);
return (
<Flex
sx={{
position: 'relative',
w: 'full',
alignItems: 'center',
justifyContent: 'center',
}}
>
<IAIDndImage
imageDTO={imageDTO}
droppableData={droppableData}
draggableData={draggableData}
postUploadAction={postUploadAction}
isUploadDisabled={!isIPAdapterEnabled}
isDropDisabled={!isIPAdapterEnabled}
dropLabel={t('toast.setIPAdapterImage')}
noContentFallback={
<IAINoContentFallback
label={t('controlnet.ipAdapterImageFallback')}
/>
}
/>
<IAIDndImageIcon
onClick={() => dispatch(ipAdapterImageChanged(null))}
icon={ipAdapterInfo.adapterImage ? <FaUndo /> : undefined}
tooltip={t('controlnet.resetIPAdapterImage')}
/>
</Flex>
);
};
export default memo(ParamIPAdapterImage);

View File

@@ -0,0 +1,97 @@
import { SelectItem } from '@mantine/core';
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { ipAdapterModelChanged } from 'features/controlNet/store/controlNetSlice';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
const ParamIPAdapterModelSelect = () => {
const ipAdapterModel = useAppSelector(
(state: RootState) => state.controlNet.ipAdapterInfo.model
);
const model = useAppSelector((state: RootState) => state.generation.model);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
ipAdapterModels?.entities[
`${ipAdapterModel?.base_model}/ip_adapter/${ipAdapterModel?.model_name}`
] ?? null,
[
ipAdapterModel?.base_model,
ipAdapterModel?.model_name,
ipAdapterModels?.entities,
]
);
const data = useMemo(() => {
if (!ipAdapterModels) {
return [];
}
const data: SelectItem[] = [];
forEach(ipAdapterModels.entities, (ipAdapterModel, id) => {
if (!ipAdapterModel) {
return;
}
const disabled = model?.base_model !== ipAdapterModel.base_model;
data.push({
value: id,
label: ipAdapterModel.model_name,
group: MODEL_TYPE_MAP[ipAdapterModel.base_model],
disabled,
tooltip: disabled
? `Incompatible base model: ${ipAdapterModel.base_model}`
: undefined,
});
});
return data.sort((a, b) => (a.disabled && !b.disabled ? 1 : -1));
}, [ipAdapterModels, model?.base_model]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newIPAdapterModel = modelIdToIPAdapterModelParam(v);
if (!newIPAdapterModel) {
return;
}
dispatch(ipAdapterModelChanged(newIPAdapterModel));
},
[dispatch]
);
return (
<IAIMantineSelect
label={t('controlnet.ipAdapterModel')}
className="nowheel nodrag"
tooltip={selectedModel?.description}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}
data={data}
onChange={handleValueChanged}
sx={{ width: '100%' }}
/>
);
};
export default memo(ParamIPAdapterModelSelect);
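The select above keys its options by the RTK Query entity id, which has the shape `${base_model}/ip_adapter/${model_name}`, and passes the chosen id to `modelIdToIPAdapterModelParam`. That util is not shown in this diff; a hypothetical stand-in, included only to illustrate the id round-trip, could look like:

```ts
// Hypothetical sketch only — not the real modelIdToIPAdapterModelParam.
type IPAdapterModelParamSketch = { base_model: string; model_name: string };

const parseIPAdapterModelId = (id: string): IPAdapterModelParamSketch | null => {
  const [base_model, type, ...rest] = id.split('/');
  if (!base_model || type !== 'ip_adapter' || rest.length === 0) {
    return null; // not an IP-Adapter entity id
  }
  return { base_model, model_name: rest.join('/') };
};

parseIPAdapterModelId('sd-1/ip_adapter/ip_adapter_sd15');
// -> { base_model: 'sd-1', model_name: 'ip_adapter_sd15' }
```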

View File

@@ -0,0 +1,46 @@
import { RootState } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAISlider from 'common/components/IAISlider';
import { ipAdapterWeightChanged } from 'features/controlNet/store/controlNetSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const ParamIPAdapterWeight = () => {
const isIpAdapterEnabled = useAppSelector(
(state: RootState) => state.controlNet.isIPAdapterEnabled
);
const ipAdapterWeight = useAppSelector(
(state: RootState) => state.controlNet.ipAdapterInfo.weight
);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleWeightChanged = useCallback(
(weight: number) => {
dispatch(ipAdapterWeightChanged(weight));
},
[dispatch]
);
const handleWeightReset = useCallback(() => {
dispatch(ipAdapterWeightChanged(1));
}, [dispatch]);
return (
<IAISlider
isDisabled={!isIpAdapterEnabled}
label={t('controlnet.weight')}
value={ipAdapterWeight}
onChange={handleWeightChanged}
min={0}
max={2}
step={0.01}
withSliderMarks
sliderMarks={[0, 1, 2]}
withReset
handleReset={handleWeightReset}
/>
);
};
export default memo(ParamIPAdapterWeight);

View File

@@ -1,9 +1,13 @@
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
import { ControlNetModelParam } from 'features/parameters/types/parameterSchemas';
import {
ControlNetModelParam,
IPAdapterModelParam,
} from 'features/parameters/types/parameterSchemas';
import { cloneDeep, forEach } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import { components } from 'services/api/schema';
import { isAnySessionRejected } from 'services/api/thunks/session';
import { ImageDTO } from 'services/api/types';
import { appSocketInvocationError } from 'services/events/actions';
import { controlNetImageProcessed } from './actions';
import {
@@ -56,16 +60,36 @@ export type ControlNetConfig = {
shouldAutoConfig: boolean;
};
export type IPAdapterConfig = {
adapterImage: ImageDTO | null;
model: IPAdapterModelParam | null;
weight: number;
beginStepPct: number;
endStepPct: number;
};
export type ControlNetState = {
controlNets: Record<string, ControlNetConfig>;
isEnabled: boolean;
pendingControlImages: string[];
isIPAdapterEnabled: boolean;
ipAdapterInfo: IPAdapterConfig;
};
export const initialIPAdapterState: IPAdapterConfig = {
adapterImage: null,
model: null,
weight: 1,
beginStepPct: 0,
endStepPct: 1,
};
export const initialControlNetState: ControlNetState = {
controlNets: {},
isEnabled: false,
pendingControlImages: [],
isIPAdapterEnabled: false,
ipAdapterInfo: { ...initialIPAdapterState },
};
export const controlNetSlice = createSlice({
@@ -353,6 +377,31 @@ export const controlNetSlice = createSlice({
controlNetReset: () => {
return { ...initialControlNetState };
},
isIPAdapterEnableToggled: (state) => {
state.isIPAdapterEnabled = !state.isIPAdapterEnabled;
},
ipAdapterImageChanged: (state, action: PayloadAction<ImageDTO | null>) => {
state.ipAdapterInfo.adapterImage = action.payload;
},
ipAdapterWeightChanged: (state, action: PayloadAction<number>) => {
state.ipAdapterInfo.weight = action.payload;
},
ipAdapterModelChanged: (
state,
action: PayloadAction<IPAdapterModelParam | null>
) => {
state.ipAdapterInfo.model = action.payload;
},
ipAdapterBeginStepPctChanged: (state, action: PayloadAction<number>) => {
state.ipAdapterInfo.beginStepPct = action.payload;
},
ipAdapterEndStepPctChanged: (state, action: PayloadAction<number>) => {
state.ipAdapterInfo.endStepPct = action.payload;
},
ipAdapterStateReset: (state) => {
state.isIPAdapterEnabled = false;
state.ipAdapterInfo = { ...initialIPAdapterState };
},
},
extraReducers: (builder) => {
builder.addCase(controlNetImageProcessed, (state, action) => {
@@ -412,6 +461,13 @@ export const {
controlNetProcessorTypeChanged,
controlNetReset,
controlNetAutoConfigToggled,
isIPAdapterEnableToggled,
ipAdapterImageChanged,
ipAdapterWeightChanged,
ipAdapterModelChanged,
ipAdapterBeginStepPctChanged,
ipAdapterEndStepPctChanged,
ipAdapterStateReset,
} = controlNetSlice.actions;
export default controlNetSlice.reducer;
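To see the new slice surface end to end, here is a minimal sketch (not part of the diff) that exercises the IP-Adapter actions against a store built from this reducer alone; the logged values follow directly from `initialIPAdapterState` and the reducers above.

```ts
import { configureStore } from '@reduxjs/toolkit';
import controlNetReducer, {
  ipAdapterBeginStepPctChanged,
  ipAdapterEndStepPctChanged,
  ipAdapterStateReset,
  ipAdapterWeightChanged,
  isIPAdapterEnableToggled,
} from 'features/controlNet/store/controlNetSlice';

// Standalone store for illustration only; the app mounts this reducer inside RootState.
const store = configureStore({ reducer: { controlNet: controlNetReducer } });

store.dispatch(isIPAdapterEnableToggled()); // false -> true
store.dispatch(ipAdapterWeightChanged(0.75));
store.dispatch(ipAdapterBeginStepPctChanged(0.1));
store.dispatch(ipAdapterEndStepPctChanged(0.9));

console.log(store.getState().controlNet.ipAdapterInfo);
// { adapterImage: null, model: null, weight: 0.75, beginStepPct: 0.1, endStepPct: 0.9 }

store.dispatch(ipAdapterStateReset()); // disables the adapter and restores initialIPAdapterState
```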

View File

@@ -10,20 +10,20 @@ import {
Text,
} from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIButton from 'common/components/IAIButton';
import IAISwitch from 'common/components/IAISwitch';
import { setShouldConfirmOnDelete } from 'features/system/store/systemSlice';
import { stateSelector } from 'app/store/store';
import { some } from 'lodash-es';
import { ChangeEvent, memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { imageDeletionConfirmed } from '../store/actions';
import { getImageUsage, selectImageUsage } from '../store/selectors';
import { imageDeletionCanceled, isModalOpenChanged } from '../store/slice';
import ImageUsageMessage from './ImageUsageMessage';
import { ImageUsage } from '../store/types';
import ImageUsageMessage from './ImageUsageMessage';
const selector = createSelector(
[stateSelector, selectImageUsage],
@@ -42,6 +42,7 @@ const selector = createSelector(
isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage),
isNodesImage: some(allImageUsage, (i) => i.isNodesImage),
isControlNetImage: some(allImageUsage, (i) => i.isControlNetImage),
isIPAdapterImage: some(allImageUsage, (i) => i.isIPAdapterImage),
};
return {

View File

@@ -1,8 +1,8 @@
import { ListItem, Text, UnorderedList } from '@chakra-ui/react';
import { some } from 'lodash-es';
import { memo } from 'react';
import { ImageUsage } from '../store/types';
import { useTranslation } from 'react-i18next';
import { ImageUsage } from '../store/types';
type Props = {
imageUsage?: ImageUsage;
@@ -38,6 +38,9 @@ const ImageUsageMessage = (props: Props) => {
{imageUsage.isControlNetImage && (
<ListItem>{t('common.controlNet')}</ListItem>
)}
{imageUsage.isIPAdapterImage && (
<ListItem>{t('common.ipAdapter')}</ListItem>
)}
{imageUsage.isNodesImage && (
<ListItem>{t('common.nodeEditor')}</ListItem>
)}

View File

@@ -1,9 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { RootState } from 'app/store/store';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { isInvocationNode } from 'features/nodes/types/types';
import { some } from 'lodash-es';
import { ImageUsage } from './types';
import { isInvocationNode } from 'features/nodes/types/types';
export const getImageUsage = (state: RootState, image_name: string) => {
const { generation, canvas, nodes, controlNet } = state;
@@ -27,11 +27,15 @@ export const getImageUsage = (state: RootState, image_name: string) => {
c.controlImage === image_name || c.processedControlImage === image_name
);
const isIPAdapterImage =
controlNet.ipAdapterInfo.adapterImage?.image_name === image_name;
const imageUsage: ImageUsage = {
isInitialImage,
isCanvasImage,
isNodesImage,
isControlNetImage,
isIPAdapterImage,
};
return imageUsage;

View File

@@ -10,4 +10,5 @@ export type ImageUsage = {
isCanvasImage: boolean;
isNodesImage: boolean;
isControlNetImage: boolean;
isIPAdapterImage: boolean;
};

View File

@@ -35,6 +35,10 @@ export type ControlNetDropData = BaseDropData & {
};
};
export type IPAdapterImageDropData = BaseDropData & {
actionType: 'SET_IP_ADAPTER_IMAGE';
};
export type CanvasInitialImageDropData = BaseDropData & {
actionType: 'SET_CANVAS_INITIAL_IMAGE';
};
@@ -73,6 +77,7 @@ export type TypesafeDroppableData =
| CurrentImageDropData
| InitialImageDropData
| ControlNetDropData
| IPAdapterImageDropData
| CanvasInitialImageDropData
| NodesImageDropData
| AddToBatchDropData

View File

@@ -24,6 +24,8 @@ export const isValidDrop = (
return payloadType === 'IMAGE_DTO';
case 'SET_CONTROLNET_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_IP_ADAPTER_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_CANVAS_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_NODES_IMAGE':

View File

@@ -53,6 +53,7 @@ const DeleteBoardModal = (props: Props) => {
isCanvasImage: some(allImageUsage, (i) => i.isCanvasImage),
isNodesImage: some(allImageUsage, (i) => i.isNodesImage),
isControlNetImage: some(allImageUsage, (i) => i.isControlNetImage),
isIPAdapterImage: some(allImageUsage, (i) => i.isIPAdapterImage),
};
return { imageUsageSummary };
}),

View File

@@ -27,7 +27,7 @@ const EmbedWorkflowCheckbox = ({ nodeId }: { nodeId: string }) => {
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Workflow</FormLabel>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Embed Workflow</FormLabel>
<Checkbox
className="nopan"
size="sm"

View File

@@ -1,13 +1,14 @@
import { Flex, Grid, GridItem } from '@chakra-ui/react';
import { useAnyOrDirectInputFieldNames } from 'features/nodes/hooks/useAnyOrDirectInputFieldNames';
import { useConnectionInputFieldNames } from 'features/nodes/hooks/useConnectionInputFieldNames';
import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames';
import { memo } from 'react';
import NodeWrapper from '../common/NodeWrapper';
import InvocationNodeFooter from './InvocationNodeFooter';
import InvocationNodeHeader from './InvocationNodeHeader';
import InputField from './fields/InputField';
import NodeWrapper from '../common/NodeWrapper';
import OutputField from './fields/OutputField';
import InputField from './fields/InputField';
import { useOutputFieldNames } from 'features/nodes/hooks/useOutputFieldNames';
import { useWithFooter } from 'features/nodes/hooks/useWithFooter';
import { useConnectionInputFieldNames } from 'features/nodes/hooks/useConnectionInputFieldNames';
import { useAnyOrDirectInputFieldNames } from 'features/nodes/hooks/useAnyOrDirectInputFieldNames';
type Props = {
nodeId: string;
@@ -21,6 +22,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
const inputConnectionFieldNames = useConnectionInputFieldNames(nodeId);
const inputAnyOrDirectFieldNames = useAnyOrDirectInputFieldNames(nodeId);
const outputFieldNames = useOutputFieldNames(nodeId);
const withFooter = useWithFooter(nodeId);
return (
<NodeWrapper nodeId={nodeId} selected={selected}>
@@ -41,7 +43,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
h: 'full',
py: 2,
gap: 1,
borderBottomRadius: 0,
borderBottomRadius: withFooter ? 0 : 'base',
}}
>
<Flex sx={{ flexDir: 'column', px: 2, w: 'full', h: 'full' }}>
@@ -74,7 +76,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
))}
</Flex>
</Flex>
<InvocationNodeFooter nodeId={nodeId} />
{withFooter && <InvocationNodeFooter nodeId={nodeId} />}
</>
)}
</NodeWrapper>

View File

@@ -3,15 +3,12 @@ import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import { memo } from 'react';
import EmbedWorkflowCheckbox from './EmbedWorkflowCheckbox';
import SaveToGalleryCheckbox from './SaveToGalleryCheckbox';
import UseCacheCheckbox from './UseCacheCheckbox';
import { useHasImageOutput } from 'features/nodes/hooks/useHasImageOutput';
type Props = {
nodeId: string;
};
const InvocationNodeFooter = ({ nodeId }: Props) => {
const hasImageOutput = useHasImageOutput(nodeId);
return (
<Flex
className={DRAG_HANDLE_CLASSNAME}
@@ -25,9 +22,8 @@ const InvocationNodeFooter = ({ nodeId }: Props) => {
justifyContent: 'space-between',
}}
>
{hasImageOutput && <EmbedWorkflowCheckbox nodeId={nodeId} />}
<UseCacheCheckbox nodeId={nodeId} />
{hasImageOutput && <SaveToGalleryCheckbox nodeId={nodeId} />}
<EmbedWorkflowCheckbox nodeId={nodeId} />
<SaveToGalleryCheckbox nodeId={nodeId} />
</Flex>
);
};

View File

@@ -1,35 +0,0 @@
import { Checkbox, Flex, FormControl, FormLabel } from '@chakra-ui/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useUseCache } from 'features/nodes/hooks/useUseCache';
import { nodeUseCacheChanged } from 'features/nodes/store/nodesSlice';
import { ChangeEvent, memo, useCallback } from 'react';
const UseCacheCheckbox = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const useCache = useUseCache(nodeId);
const handleChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(
nodeUseCacheChanged({
nodeId,
useCache: e.target.checked,
})
);
},
[dispatch, nodeId]
);
return (
<FormControl as={Flex} sx={{ alignItems: 'center', gap: 2, w: 'auto' }}>
<FormLabel sx={{ fontSize: 'xs', mb: '1px' }}>Use Cache</FormLabel>
<Checkbox
className="nopan"
size="sm"
onChange={handleChange}
isChecked={useCache}
/>
</FormControl>
);
};
export default memo(UseCacheCheckbox);

View File

@@ -15,6 +15,7 @@ import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
import SchedulerInputField from './inputs/SchedulerInputField';
import StringInputField from './inputs/StringInputField';
import VaeModelInputField from './inputs/VaeModelInputField';
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';
type InputFieldProps = {
nodeId: string;
@@ -147,6 +148,19 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
);
}
if (
field?.type === 'IPAdapterModelField' &&
fieldTemplate?.type === 'IPAdapterModelField'
) {
return (
<IPAdapterModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
return (
<ColorInputField

View File

@@ -0,0 +1,17 @@
import {
IPAdapterInputFieldTemplate,
IPAdapterInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const IPAdapterInputFieldComponent = (
_props: FieldComponentProps<
IPAdapterInputFieldValue,
IPAdapterInputFieldTemplate
>
) => {
return null;
};
export default memo(IPAdapterInputFieldComponent);

View File

@@ -0,0 +1,100 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
IPAdapterModelInputFieldTemplate,
IPAdapterModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
const IPAdapterModelInputFieldComponent = (
props: FieldComponentProps<
IPAdapterModelInputFieldValue,
IPAdapterModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const ipAdapterModel = field.value;
const dispatch = useAppDispatch();
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
ipAdapterModels?.entities[
`${ipAdapterModel?.base_model}/ip_adapter/${ipAdapterModel?.model_name}`
] ?? null,
[
ipAdapterModel?.base_model,
ipAdapterModel?.model_name,
ipAdapterModels?.entities,
]
);
const data = useMemo(() => {
if (!ipAdapterModels) {
return [];
}
const data: SelectItem[] = [];
forEach(ipAdapterModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [ipAdapterModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newIPAdapterModel = modelIdToIPAdapterModelParam(v);
if (!newIPAdapterModel) {
return;
}
dispatch(
fieldIPAdapterModelValueChanged({
nodeId,
fieldName: field.name,
value: newIPAdapterModel,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<IAIMantineSelect
className="nowheel nodrag"
tooltip={selectedModel?.description}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}
data={data}
onChange={handleValueChanged}
sx={{ width: '100%' }}
/>
);
};
export default memo(IPAdapterModelInputFieldComponent);

View File

@@ -146,7 +146,6 @@ export const useBuildNodeData = () => {
isIntermediate: true,
inputs,
outputs,
useCache: template.useCache,
},
};

View File

@@ -1,29 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import { useMemo } from 'react';
import { isInvocationNode } from '../types/types';
export const useUseCache = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(
stateSelector,
({ nodes }) => {
const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) {
return false;
}
// cast to boolean to support older workflows that didn't have useCache
// TODO: handle this better somehow
return node.data.useCache;
},
defaultSelectorOptions
),
[nodeId]
);
const useCache = useAppSelector(selector);
return useCache;
};

View File

@@ -7,7 +7,7 @@ import { useMemo } from 'react';
import { FOOTER_FIELDS } from '../types/constants';
import { isInvocationNode } from '../types/types';
export const useHasImageOutputs = (nodeId: string) => {
export const useWithFooter = (nodeId: string) => {
const selector = useMemo(
() =>
createSelector(

View File

@@ -41,6 +41,7 @@ import {
IntegerInputFieldValue,
InvocationNodeData,
InvocationTemplate,
IPAdapterModelInputFieldValue,
isInvocationNode,
isNotesNode,
LoRAModelInputFieldValue,
@@ -260,20 +261,6 @@ const nodesSlice = createSlice({
}
node.data.embedWorkflow = embedWorkflow;
},
nodeUseCacheChanged: (
state,
action: PayloadAction<{ nodeId: string; useCache: boolean }>
) => {
const { nodeId, useCache } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
node.data.useCache = useCache;
},
nodeIsIntermediateChanged: (
state,
action: PayloadAction<{ nodeId: string; isIntermediate: boolean }>
@@ -534,6 +521,12 @@ const nodesSlice = createSlice({
) => {
fieldValueReducer(state, action);
},
fieldIPAdapterModelValueChanged: (
state,
action: FieldValueAction<IPAdapterModelInputFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldEnumModelValueChanged: (
state,
action: FieldValueAction<EnumInputFieldValue>
@@ -880,6 +873,7 @@ export const {
fieldLoRAModelValueChanged,
fieldEnumModelValueChanged,
fieldControlNetModelValueChanged,
fieldIPAdapterModelValueChanged,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
nodeIsOpenChanged,
@@ -918,7 +912,6 @@ export const {
nodeIsIntermediateChanged,
mouseOverNodeChanged,
nodeExclusivelySelected,
nodeUseCacheChanged,
} = nodesSlice.actions;
export default nodesSlice.reducer;

View File

@@ -41,6 +41,7 @@ export const POLYMORPHIC_TYPES = [
];
export const MODEL_TYPES = [
'IPAdapterModelField',
'ControlNetModelField',
'LoRAModelField',
'MainModelField',
@@ -236,6 +237,16 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
description: t('nodes.integerPolymorphicDescription'),
title: t('nodes.integerPolymorphic'),
},
IPAdapterField: {
color: 'green.300',
description: 'IP-Adapter info passed between nodes.',
title: 'IP-Adapter',
},
IPAdapterModelField: {
color: 'teal.500',
description: 'IP-Adapter model',
title: 'IP-Adapter Model',
},
LatentsCollection: {
color: 'pink.500',
description: t('nodes.latentsCollectionDescription'),

View File

@@ -1,4 +1,3 @@
import { $store } from 'app/store/nanostores/store';
import {
SchedulerParam,
zBaseModel,
@@ -8,8 +7,7 @@ import {
zSDXLRefinerModel,
zScheduler,
} from 'features/parameters/types/parameterSchemas';
import i18n from 'i18next';
import { has, keyBy } from 'lodash-es';
import { keyBy } from 'lodash-es';
import { OpenAPIV3 } from 'openapi-types';
import { RgbaColor } from 'react-colorful';
import { Node } from 'reactflow';
@@ -22,6 +20,7 @@ import {
import { O } from 'ts-toolbelt';
import { JsonObject } from 'type-fest';
import { z } from 'zod';
import i18n from 'i18next';
export type NonNullableGraph = O.Required<Graph, 'nodes' | 'edges'>;
@@ -58,10 +57,6 @@ export type InvocationTemplate = {
* The invocation's version.
*/
version?: string;
/**
* Whether or not this node should use the cache
*/
useCache: boolean;
};
export type FieldUIConfig = {
@@ -99,6 +94,8 @@ export const zFieldType = z.enum([
'integer',
'IntegerCollection',
'IntegerPolymorphic',
'IPAdapterField',
'IPAdapterModelField',
'LatentsCollection',
'LatentsField',
'LatentsPolymorphic',
@@ -394,6 +391,25 @@ export type ControlCollectionInputFieldValue = z.infer<
typeof zControlCollectionInputFieldValue
>;
export const zIPAdapterModel = zModelIdentifier;
export type IPAdapterModel = z.infer<typeof zIPAdapterModel>;
export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: zIPAdapterModel,
image_encoder_model: z.string().trim().min(1),
weight: z.number(),
});
export type IPAdapterField = z.infer<typeof zIPAdapterField>;
export const zIPAdapterInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IPAdapterField'),
value: zIPAdapterField.optional(),
});
export type IPAdapterInputFieldValue = z.infer<
typeof zIPAdapterInputFieldValue
>;
export const zModelType = z.enum([
'onnx',
'main',
@@ -543,6 +559,17 @@ export type ControlNetModelInputFieldValue = z.infer<
typeof zControlNetModelInputFieldValue
>;
export const zIPAdapterModelField = zModelIdentifier;
export type IPAdapterModelField = z.infer<typeof zIPAdapterModelField>;
export const zIPAdapterModelInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IPAdapterModelField'),
value: zIPAdapterModelField.optional(),
});
export type IPAdapterModelInputFieldValue = z.infer<
typeof zIPAdapterModelInputFieldValue
>;
export const zCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('Collection'),
value: z.array(z.any()).optional(), // TODO: should this field ever have a value?
@@ -625,6 +652,8 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
zIntegerCollectionInputFieldValue,
zIntegerPolymorphicInputFieldValue,
zIntegerInputFieldValue,
zIPAdapterInputFieldValue,
zIPAdapterModelInputFieldValue,
zLatentsInputFieldValue,
zLatentsCollectionInputFieldValue,
zLatentsPolymorphicInputFieldValue,
@@ -827,6 +856,11 @@ export type ControlPolymorphicInputFieldTemplate = Omit<
type: 'ControlPolymorphic';
};
export type IPAdapterInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'IPAdapterField';
};
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'enum';
@@ -864,6 +898,11 @@ export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'ControlNetModelField';
};
export type IPAdapterModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'IPAdapterModelField';
};
export type CollectionInputFieldTemplate = InputFieldTemplateBase & {
default: [];
type: 'Collection';
@@ -935,6 +974,8 @@ export type InputFieldTemplate =
| IntegerCollectionInputFieldTemplate
| IntegerPolymorphicInputFieldTemplate
| IntegerInputFieldTemplate
| IPAdapterInputFieldTemplate
| IPAdapterModelInputFieldTemplate
| LatentsInputFieldTemplate
| LatentsCollectionInputFieldTemplate
| LatentsPolymorphicInputFieldTemplate
@@ -981,9 +1022,6 @@ export type InvocationSchemaExtra = {
type: Omit<OpenAPIV3.SchemaObject, 'default'> & {
default: AnyInvocationType;
};
use_cache: Omit<OpenAPIV3.SchemaObject, 'default'> & {
default: boolean;
};
};
};
@@ -1147,37 +1185,9 @@ export const zInvocationNodeData = z.object({
version: zSemVer.optional(),
});
export const zInvocationNodeDataV2 = z.preprocess(
(arg) => {
try {
const data = zInvocationNodeData.parse(arg);
if (!has(data, 'useCache')) {
const nodeTemplates = $store.get()?.getState().nodes.nodeTemplates as
| Record<string, InvocationTemplate>
| undefined;
const template = nodeTemplates?.[data.type];
let useCache = true;
if (template) {
useCache = template.useCache;
}
Object.assign(data, { useCache });
}
return data;
} catch {
return arg;
}
},
zInvocationNodeData.extend({
useCache: z.boolean(),
})
);
// Massage this to get better type safety while developing
export type InvocationNodeData = Omit<
z.infer<typeof zInvocationNodeDataV2>,
z.infer<typeof zInvocationNodeData>,
'type'
> & {
type: AnyInvocationType;
@@ -1205,7 +1215,7 @@ const zDimension = z.number().gt(0).nullish();
export const zWorkflowInvocationNode = z.object({
id: z.string().trim().min(1),
type: z.literal('invocation'),
data: zInvocationNodeDataV2,
data: zInvocationNodeData,
width: zDimension,
height: zDimension,
position: zPosition,
@@ -1267,8 +1277,6 @@ export type WorkflowWarning = {
data: JsonObject;
};
const CURRENT_WORKFLOW_VERSION = '1.0.0';
export const zWorkflow = z.object({
name: z.string().default(''),
author: z.string().default(''),
@@ -1284,7 +1292,7 @@ export const zWorkflow = z.object({
.object({
version: zSemVer,
})
.default({ version: CURRENT_WORKFLOW_VERSION }),
.default({ version: '1.0.0' }),
});
export const zValidatedWorkflow = zWorkflow.transform((workflow) => {

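For reference, a minimal sketch of a value the new `zIPAdapterField` schema is intended to accept. The inner shapes of `zImageField` and `zModelIdentifier` (an `image_name` string, and `model_name`/`base_model`) are assumed here rather than shown in this diff:

```ts
import { zIPAdapterField } from 'features/nodes/types/types';

const result = zIPAdapterField.safeParse({
  image: { image_name: 'adapter-reference.png' }, // assumed zImageField shape
  ip_adapter_model: { base_model: 'sd-1', model_name: 'ip_adapter_sd15' }, // assumed zModelIdentifier shape
  image_encoder_model: 'InvokeAI/ip_adapter_sd_image_encoder',
  weight: 0.8,
});

// safeParse never throws; success is false if either assumed shape is off
console.log(result.success);
```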
View File

@@ -60,6 +60,8 @@ import {
ImageField,
LatentsField,
ConditioningField,
IPAdapterInputFieldTemplate,
IPAdapterModelInputFieldTemplate,
} from '../types/types';
import { ControlField } from 'services/api/types';
@@ -435,6 +437,19 @@ const buildControlNetModelInputFieldTemplate = ({
return template;
};
const buildIPAdapterModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IPAdapterModelInputFieldTemplate => {
const template: IPAdapterModelInputFieldTemplate = {
...baseField,
type: 'IPAdapterModelField',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageInputFieldTemplate = ({
schemaObject,
baseField,
@@ -648,6 +663,19 @@ const buildControlCollectionInputFieldTemplate = ({
return template;
};
const buildIPAdapterInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IPAdapterInputFieldTemplate => {
const template: IPAdapterInputFieldTemplate = {
...baseField,
type: 'IPAdapterField',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildEnumInputFieldTemplate = ({
schemaObject,
baseField,
@@ -851,6 +879,8 @@ const TEMPLATE_BUILDER_MAP = {
integer: buildIntegerInputFieldTemplate,
IntegerCollection: buildIntegerCollectionInputFieldTemplate,
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
IPAdapterField: buildIPAdapterInputFieldTemplate,
IPAdapterModelField: buildIPAdapterModelInputFieldTemplate,
LatentsCollection: buildLatentsCollectionInputFieldTemplate,
LatentsField: buildLatentsInputFieldTemplate,
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,

View File

@@ -28,6 +28,8 @@ const FIELD_VALUE_FALLBACK_MAP = {
integer: 0,
IntegerCollection: [],
IntegerPolymorphic: 0,
IPAdapterField: undefined,
IPAdapterModelField: undefined,
LatentsCollection: [],
LatentsField: undefined,
LatentsPolymorphic: undefined,

View File

@@ -0,0 +1,59 @@
import { RootState } from 'app/store/store';
import { IPAdapterInvocation } from 'services/api/types';
import { NonNullableGraph } from '../../types/types';
import { IP_ADAPTER } from './constants';
export const addIPAdapterToLinearGraph = (
state: RootState,
graph: NonNullableGraph,
baseNodeId: string
): void => {
const { isIPAdapterEnabled, ipAdapterInfo } = state.controlNet;
// const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
// | MetadataAccumulatorInvocation
// | undefined;
if (isIPAdapterEnabled && ipAdapterInfo.model) {
const ipAdapterNode: IPAdapterInvocation = {
id: IP_ADAPTER,
type: 'ip_adapter',
is_intermediate: true,
weight: ipAdapterInfo.weight,
ip_adapter_model: {
base_model: ipAdapterInfo.model?.base_model,
model_name: ipAdapterInfo.model?.model_name,
},
begin_step_percent: ipAdapterInfo.beginStepPct,
end_step_percent: ipAdapterInfo.endStepPct,
};
if (ipAdapterInfo.adapterImage) {
ipAdapterNode.image = {
image_name: ipAdapterInfo.adapterImage.image_name,
};
} else {
return;
}
graph.nodes[ipAdapterNode.id] = ipAdapterNode as IPAdapterInvocation;
// if (metadataAccumulator?.ip_adapters) {
// // metadata accumulator only needs the ip_adapter field - not the whole node
// // extract what we need and add to the accumulator
// const ipAdapterField = omit(ipAdapterNode, [
// 'id',
// 'type',
// ]) as IPAdapterField;
// metadataAccumulator.ip_adapters.push(ipAdapterField);
// }
graph.edges.push({
source: { node_id: ipAdapterNode.id, field: 'ip_adapter' },
destination: {
node_id: baseNodeId,
field: 'ip_adapter',
},
});
}
};
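Roughly, when the adapter is enabled with a model selected and an image set, this helper adds a single `ip_adapter` node and wires its output into the denoise node; if no image is set it returns without touching the graph. A sketch of the resulting fragment, with placeholder strings standing in for the `IP_ADAPTER` and `DENOISE_LATENTS` constants and illustrative field values:

```ts
// Illustrative shape only; real ids come from './constants' and values from state.
const ipAdapterGraphFragment = {
  nodes: {
    ip_adapter: {
      id: 'ip_adapter',
      type: 'ip_adapter',
      is_intermediate: true,
      weight: 1,
      ip_adapter_model: { base_model: 'sd-1', model_name: 'ip_adapter_sd15' },
      begin_step_percent: 0,
      end_step_percent: 1,
      image: { image_name: 'adapter-reference.png' },
    },
  },
  edges: [
    {
      source: { node_id: 'ip_adapter', field: 'ip_adapter' },
      destination: { node_id: 'denoise_latents', field: 'ip_adapter' },
    },
  ],
};
```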

View File

@@ -1,32 +1,46 @@
import { RootState } from 'app/store/store';
import { NonNullableGraph } from 'features/nodes/types/types';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import {
ImageNSFWBlurInvocation,
LatentsToImageInvocation,
MetadataAccumulatorInvocation,
} from 'services/api/types';
import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants';
import {
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NSFW_CHECKER,
} from './constants';
export const addNSFWCheckerToGraph = (
state: RootState,
graph: NonNullableGraph,
nodeIdToAddTo = LATENTS_TO_IMAGE
): void => {
const activeTabName = activeTabNameSelector(state);
const is_intermediate =
activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false;
const nodeToAddTo = graph.nodes[nodeIdToAddTo] as
| LatentsToImageInvocation
| undefined;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (!nodeToAddTo) {
// something has gone terribly awry
return;
}
nodeToAddTo.is_intermediate = true;
nodeToAddTo.use_cache = true;
const nsfwCheckerNode: ImageNSFWBlurInvocation = {
id: NSFW_CHECKER,
type: 'img_nsfw',
is_intermediate: true,
is_intermediate,
};
graph.nodes[NSFW_CHECKER] = nsfwCheckerNode as ImageNSFWBlurInvocation;
@@ -40,4 +54,17 @@ export const addNSFWCheckerToGraph = (
field: 'image',
},
});
if (metadataAccumulator) {
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: NSFW_CHECKER,
field: 'metadata',
},
});
}
};

View File

@@ -1,92 +0,0 @@
import { NonNullableGraph } from 'features/nodes/types/types';
import {
CANVAS_OUTPUT,
LATENTS_TO_IMAGE,
METADATA_ACCUMULATOR,
NSFW_CHECKER,
SAVE_IMAGE,
WATERMARKER,
} from './constants';
import {
MetadataAccumulatorInvocation,
SaveImageInvocation,
} from 'services/api/types';
import { RootState } from 'app/store/store';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
/**
* Set the `use_cache` field on the linear/canvas graph's final image output node to False.
*/
export const addSaveImageNode = (
state: RootState,
graph: NonNullableGraph
): void => {
const activeTabName = activeTabNameSelector(state);
const is_intermediate =
activeTabName === 'unifiedCanvas' ? !state.canvas.shouldAutoSave : false;
const saveImageNode: SaveImageInvocation = {
id: SAVE_IMAGE,
type: 'save_image',
is_intermediate,
use_cache: false,
};
graph.nodes[SAVE_IMAGE] = saveImageNode;
const metadataAccumulator = graph.nodes[METADATA_ACCUMULATOR] as
| MetadataAccumulatorInvocation
| undefined;
if (metadataAccumulator) {
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: SAVE_IMAGE,
field: 'metadata',
},
});
}
const destination = {
node_id: SAVE_IMAGE,
field: 'image',
};
if (WATERMARKER in graph.nodes) {
graph.edges.push({
source: {
node_id: WATERMARKER,
field: 'image',
},
destination,
});
} else if (NSFW_CHECKER in graph.nodes) {
graph.edges.push({
source: {
node_id: NSFW_CHECKER,
field: 'image',
},
destination,
});
} else if (CANVAS_OUTPUT in graph.nodes) {
graph.edges.push({
source: {
node_id: CANVAS_OUTPUT,
field: 'image',
},
destination,
});
} else if (LATENTS_TO_IMAGE in graph.nodes) {
graph.edges.push({
source: {
node_id: LATENTS_TO_IMAGE,
field: 'image',
},
destination,
});
}
};

View File

@@ -51,7 +51,6 @@ export const addWatermarkerToGraph = (
// no matter the situation, we want the l2i node to be intermediate
nodeToAddTo.is_intermediate = true;
nodeToAddTo.use_cache = true;
if (nsfwCheckerNode) {
// if we are using the NSFW checker, we need to "disable" its output by marking it intermediate,

View File

@@ -5,6 +5,7 @@ import { initialGenerationState } from 'features/parameters/store/generationSlic
import { ImageDTO, ImageToLatentsInvocation } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@@ -25,7 +26,6 @@ import {
POSITIVE_CONDITIONING,
SEAMLESS,
} from './constants';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Image to Image graph.
@@ -54,10 +54,14 @@ export const buildCanvasImageToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
@@ -89,31 +93,31 @@ export const buildCanvasImageToImageGraph = (
[modelLoaderNodeId]: {
type: 'main_model_loader',
id: modelLoaderNodeId,
is_intermediate,
is_intermediate: true,
model,
},
[CLIP_SKIP]: {
type: 'clip_skip',
id: CLIP_SKIP,
is_intermediate,
is_intermediate: true,
skipped_layers: clipSkip,
},
[POSITIVE_CONDITIONING]: {
type: 'compel',
id: POSITIVE_CONDITIONING,
is_intermediate,
is_intermediate: true,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
id: NEGATIVE_CONDITIONING,
is_intermediate,
is_intermediate: true,
prompt: negativePrompt,
},
[NOISE]: {
type: 'noise',
id: NOISE,
is_intermediate,
is_intermediate: true,
use_cpu,
width: !isUsingScaledDimensions
? width
@@ -125,12 +129,12 @@ export const buildCanvasImageToImageGraph = (
[IMAGE_TO_LATENTS]: {
type: 'i2l',
id: IMAGE_TO_LATENTS,
is_intermediate,
is_intermediate: true,
},
[DENOISE_LATENTS]: {
type: 'denoise_latents',
id: DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
cfg_scale,
scheduler,
steps,
@@ -140,7 +144,7 @@ export const buildCanvasImageToImageGraph = (
[CANVAS_OUTPUT]: {
type: 'l2i',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: !shouldAutoSave,
},
},
edges: [
@@ -235,7 +239,7 @@ export const buildCanvasImageToImageGraph = (
graph.nodes[IMG2IMG_RESIZE] = {
id: IMG2IMG_RESIZE,
type: 'img_resize',
is_intermediate,
is_intermediate: true,
image: initialImage,
width: scaledBoundingBoxDimensions.width,
height: scaledBoundingBoxDimensions.height,
@@ -243,13 +247,13 @@ export const buildCanvasImageToImageGraph = (
graph.nodes[LATENTS_TO_IMAGE] = {
id: LATENTS_TO_IMAGE,
type: 'l2i',
is_intermediate,
is_intermediate: true,
fp32,
};
graph.nodes[CANVAS_OUTPUT] = {
id: CANVAS_OUTPUT,
type: 'img_resize',
is_intermediate,
is_intermediate: !shouldAutoSave,
width: width,
height: height,
};
@@ -290,7 +294,7 @@ export const buildCanvasImageToImageGraph = (
graph.nodes[CANVAS_OUTPUT] = {
type: 'l2i',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: !shouldAutoSave,
fp32,
};
@@ -334,6 +338,17 @@ export const buildCanvasImageToImageGraph = (
init_image: initialImage.image_name,
};
graph.edges.push({
source: {
node_id: METADATA_ACCUMULATOR,
field: 'metadata',
},
destination: {
node_id: CANVAS_OUTPUT,
field: 'metadata',
},
});
// Add Seamless To Graph
if (seamlessXAxis || seamlessYAxis) {
addSeamlessToLinearGraph(state, graph, modelLoaderNodeId);
@@ -352,6 +367,9 @@ export const buildCanvasImageToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -363,7 +381,5 @@ export const buildCanvasImageToImageGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};

View File

@@ -12,6 +12,7 @@ import {
RangeOfSizeInvocation,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@@ -44,7 +45,6 @@ import {
RANGE_OF_SIZE,
SEAMLESS,
} from './constants';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Inpaint graph.
@@ -88,8 +88,12 @@ export const buildCanvasInpaintGraph = (
const { width, height } = state.canvas.boundingBoxDimensions;
// We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const is_intermediate = true;
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const fp32 = vaePrecision === 'fp32';
const isUsingScaledDimensions = ['auto', 'manual'].includes(
@@ -108,56 +112,56 @@ export const buildCanvasInpaintGraph = (
[modelLoaderNodeId]: {
type: 'main_model_loader',
id: modelLoaderNodeId,
is_intermediate,
is_intermediate: true,
model,
},
[CLIP_SKIP]: {
type: 'clip_skip',
id: CLIP_SKIP,
is_intermediate,
is_intermediate: true,
skipped_layers: clipSkip,
},
[POSITIVE_CONDITIONING]: {
type: 'compel',
id: POSITIVE_CONDITIONING,
is_intermediate,
is_intermediate: true,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
id: NEGATIVE_CONDITIONING,
is_intermediate,
is_intermediate: true,
prompt: negativePrompt,
},
[MASK_BLUR]: {
type: 'img_blur',
id: MASK_BLUR,
is_intermediate,
is_intermediate: true,
radius: maskBlur,
blur_type: maskBlurMethod,
},
[INPAINT_IMAGE]: {
type: 'i2l',
id: INPAINT_IMAGE,
is_intermediate,
is_intermediate: true,
fp32,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate,
is_intermediate: true,
},
[INPAINT_CREATE_MASK]: {
type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK,
is_intermediate,
is_intermediate: true,
fp32,
},
[DENOISE_LATENTS]: {
type: 'denoise_latents',
id: DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
steps: steps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -168,18 +172,18 @@ export const buildCanvasInpaintGraph = (
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate,
is_intermediate: true,
},
[CANVAS_COHERENCE_NOISE_INCREMENT]: {
type: 'add',
id: CANVAS_COHERENCE_NOISE_INCREMENT,
b: 1,
is_intermediate,
is_intermediate: true,
},
[CANVAS_COHERENCE_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: CANVAS_COHERENCE_DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
steps: canvasCoherenceSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -189,19 +193,19 @@ export const buildCanvasInpaintGraph = (
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate,
is_intermediate: true,
fp32,
},
[CANVAS_OUTPUT]: {
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: !shouldAutoSave,
reference: canvasInitImage,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
is_intermediate,
is_intermediate: true,
// seed - must be connected manually
// start: 0,
size: iterations,
@@ -210,7 +214,7 @@ export const buildCanvasInpaintGraph = (
[ITERATE]: {
type: 'iterate',
id: ITERATE,
is_intermediate,
is_intermediate: true,
},
},
edges: [
@@ -433,7 +437,7 @@ export const buildCanvasInpaintGraph = (
graph.nodes[INPAINT_IMAGE_RESIZE_UP] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_UP,
is_intermediate,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
image: canvasInitImage,
@@ -441,7 +445,7 @@ export const buildCanvasInpaintGraph = (
graph.nodes[MASK_RESIZE_UP] = {
type: 'img_resize',
id: MASK_RESIZE_UP,
is_intermediate,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
image: canvasMaskImage,
@@ -449,14 +453,14 @@ export const buildCanvasInpaintGraph = (
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
is_intermediate,
is_intermediate: true,
width: width,
height: height,
};
graph.nodes[MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: MASK_RESIZE_DOWN,
is_intermediate,
is_intermediate: true,
width: width,
height: height,
};
@@ -594,7 +598,7 @@ export const buildCanvasInpaintGraph = (
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
type: 'create_denoise_mask',
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
is_intermediate,
is_intermediate: true,
fp32,
};
@@ -647,7 +651,7 @@ export const buildCanvasInpaintGraph = (
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
type: 'mask_edge',
id: CANVAS_COHERENCE_MASK_EDGE,
is_intermediate,
is_intermediate: true,
edge_blur: maskBlur,
edge_size: maskBlur * 2,
low_threshold: 100,
@@ -733,6 +737,9 @@ export const buildCanvasInpaintGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -744,7 +751,5 @@ export const buildCanvasInpaintGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};

View File

@@ -11,6 +11,7 @@ import {
RangeOfSizeInvocation,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@@ -46,7 +47,6 @@ import {
RANGE_OF_SIZE,
SEAMLESS,
} from './constants';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Outpaint graph.
@@ -92,10 +92,14 @@ export const buildCanvasOutpaintGraph = (
const { width, height } = state.canvas.boundingBoxDimensions;
// We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
@@ -112,61 +116,61 @@ export const buildCanvasOutpaintGraph = (
[modelLoaderNodeId]: {
type: 'main_model_loader',
id: modelLoaderNodeId,
is_intermediate,
is_intermediate: true,
model,
},
[CLIP_SKIP]: {
type: 'clip_skip',
id: CLIP_SKIP,
is_intermediate,
is_intermediate: true,
skipped_layers: clipSkip,
},
[POSITIVE_CONDITIONING]: {
type: 'compel',
id: POSITIVE_CONDITIONING,
is_intermediate,
is_intermediate: true,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
id: NEGATIVE_CONDITIONING,
is_intermediate,
is_intermediate: true,
prompt: negativePrompt,
},
[MASK_FROM_ALPHA]: {
type: 'tomask',
id: MASK_FROM_ALPHA,
is_intermediate,
is_intermediate: true,
image: canvasInitImage,
},
[MASK_COMBINE]: {
type: 'mask_combine',
id: MASK_COMBINE,
is_intermediate,
is_intermediate: true,
mask2: canvasMaskImage,
},
[INPAINT_IMAGE]: {
type: 'i2l',
id: INPAINT_IMAGE,
is_intermediate,
is_intermediate: true,
fp32,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate,
is_intermediate: true,
},
[INPAINT_CREATE_MASK]: {
type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK,
is_intermediate,
is_intermediate: true,
fp32,
},
[DENOISE_LATENTS]: {
type: 'denoise_latents',
id: DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
steps: steps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -177,18 +181,18 @@ export const buildCanvasOutpaintGraph = (
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate,
is_intermediate: true,
},
[CANVAS_COHERENCE_NOISE_INCREMENT]: {
type: 'add',
id: CANVAS_COHERENCE_NOISE_INCREMENT,
b: 1,
is_intermediate,
is_intermediate: true,
},
[CANVAS_COHERENCE_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: CANVAS_COHERENCE_DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
steps: canvasCoherenceSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -198,18 +202,18 @@ export const buildCanvasOutpaintGraph = (
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate,
is_intermediate: true,
fp32,
},
[CANVAS_OUTPUT]: {
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: !shouldAutoSave,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
is_intermediate,
is_intermediate: true,
// seed - must be connected manually
// start: 0,
size: iterations,
@@ -218,7 +222,7 @@ export const buildCanvasOutpaintGraph = (
[ITERATE]: {
type: 'iterate',
id: ITERATE,
is_intermediate,
is_intermediate: true,
},
},
edges: [
@@ -469,7 +473,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_patchmatch',
id: INPAINT_INFILL,
is_intermediate,
is_intermediate: true,
downscale: infillPatchmatchDownscaleSize,
};
}
@@ -478,7 +482,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_lama',
id: INPAINT_INFILL,
is_intermediate,
is_intermediate: true,
};
}
@@ -486,7 +490,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_cv2',
id: INPAINT_INFILL,
is_intermediate,
is_intermediate: true,
};
}
@@ -494,7 +498,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_tile',
id: INPAINT_INFILL,
is_intermediate,
is_intermediate: true,
tile_size: infillTileSize,
};
}
@@ -508,7 +512,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[INPAINT_IMAGE_RESIZE_UP] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_UP,
is_intermediate,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
image: canvasInitImage,
@@ -516,28 +520,28 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[MASK_RESIZE_UP] = {
type: 'img_resize',
id: MASK_RESIZE_UP,
is_intermediate,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
is_intermediate,
is_intermediate: true,
width: width,
height: height,
};
graph.nodes[INPAINT_INFILL_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_INFILL_RESIZE_DOWN,
is_intermediate,
is_intermediate: true,
width: width,
height: height,
};
graph.nodes[MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: MASK_RESIZE_DOWN,
is_intermediate,
is_intermediate: true,
width: width,
height: height,
};
@@ -696,7 +700,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
type: 'create_denoise_mask',
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
is_intermediate,
is_intermediate: true,
fp32,
};
@@ -743,7 +747,7 @@ export const buildCanvasOutpaintGraph = (
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
type: 'mask_edge',
id: CANVAS_COHERENCE_MASK_EDGE,
is_intermediate,
is_intermediate: true,
edge_blur: maskBlur,
edge_size: maskBlur * 2,
low_threshold: 100,
@@ -835,6 +839,9 @@ export const buildCanvasOutpaintGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -846,7 +853,5 @@ export const buildCanvasOutpaintGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};
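
The canvas builders also change how the `CANVAS_OUTPUT` node's `is_intermediate` flag is produced. In the hunks above, one side of the diff uses a shared `const is_intermediate = true` plus a trailing `addSaveImageNode(state, graph)` call, while the other writes the literals inline and ties the output node to the canvas auto-save toggle via `is_intermediate: !shouldAutoSave`. Below is a minimal sketch of the auto-save variant, with a simplified state shape and the node's other fields omitted; the gallery-persistence behaviour described in the comments is inferred from the flag's name rather than shown in this excerpt.

```ts
// Minimal sketch of the auto-save pattern above. The state shape is
// simplified, CANVAS_OUTPUT's other fields are omitted, and the
// behavioural comments are inferred rather than taken from this diff.
type CanvasState = { shouldAutoSave: boolean };

const buildCanvasOutputNode = (canvas: CanvasState) => ({
  id: 'canvas_output', // stand-in for the CANVAS_OUTPUT constant
  type: 'color_correct' as const,
  // With auto-save on, the node is not intermediate, so its result image
  // is kept (saved to the gallery); otherwise it stays intermediate.
  is_intermediate: !canvas.shouldAutoSave,
});

// buildCanvasOutputNode({ shouldAutoSave: true })  -> { ..., is_intermediate: false }
// buildCanvasOutputNode({ shouldAutoSave: false }) -> { ..., is_intermediate: true }
```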


@@ -5,6 +5,7 @@ import { initialGenerationState } from 'features/parameters/store/generationSlic
import { ImageDTO, ImageToLatentsInvocation } from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@@ -27,7 +28,6 @@ import {
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Image to Image graph.
@@ -62,10 +62,14 @@ export const buildCanvasSDXLImageToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
@@ -119,7 +123,7 @@ export const buildCanvasSDXLImageToImageGraph = (
[NOISE]: {
type: 'noise',
id: NOISE,
is_intermediate,
is_intermediate: true,
use_cpu,
width: !isUsingScaledDimensions
? width
@@ -131,13 +135,13 @@ export const buildCanvasSDXLImageToImageGraph = (
[IMAGE_TO_LATENTS]: {
type: 'i2l',
id: IMAGE_TO_LATENTS,
is_intermediate,
is_intermediate: true,
fp32,
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
cfg_scale,
scheduler,
steps,
@@ -248,7 +252,7 @@ export const buildCanvasSDXLImageToImageGraph = (
graph.nodes[IMG2IMG_RESIZE] = {
id: IMG2IMG_RESIZE,
type: 'img_resize',
is_intermediate,
is_intermediate: true,
image: initialImage,
width: scaledBoundingBoxDimensions.width,
height: scaledBoundingBoxDimensions.height,
@@ -256,13 +260,13 @@ export const buildCanvasSDXLImageToImageGraph = (
graph.nodes[LATENTS_TO_IMAGE] = {
id: LATENTS_TO_IMAGE,
type: 'l2i',
is_intermediate,
is_intermediate: true,
fp32,
};
graph.nodes[CANVAS_OUTPUT] = {
id: CANVAS_OUTPUT,
type: 'img_resize',
is_intermediate,
is_intermediate: !shouldAutoSave,
width: width,
height: height,
};
@@ -303,7 +307,7 @@ export const buildCanvasSDXLImageToImageGraph = (
graph.nodes[CANVAS_OUTPUT] = {
type: 'l2i',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: !shouldAutoSave,
fp32,
};
@@ -389,6 +393,9 @@ export const buildCanvasSDXLImageToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -400,7 +407,5 @@ export const buildCanvasSDXLImageToImageGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};


@@ -46,7 +46,7 @@ import {
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
import { addSaveImageNode } from './addSaveImageNode';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
/**
* Builds the Canvas tab's Inpaint graph.
@@ -95,10 +95,14 @@ export const buildCanvasSDXLInpaintGraph = (
const { width, height } = state.canvas.boundingBoxDimensions;
// We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
@@ -136,32 +140,32 @@ export const buildCanvasSDXLInpaintGraph = (
[MASK_BLUR]: {
type: 'img_blur',
id: MASK_BLUR,
is_intermediate,
is_intermediate: true,
radius: maskBlur,
blur_type: maskBlurMethod,
},
[INPAINT_IMAGE]: {
type: 'i2l',
id: INPAINT_IMAGE,
is_intermediate,
is_intermediate: true,
fp32,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate,
is_intermediate: true,
},
[INPAINT_CREATE_MASK]: {
type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK,
is_intermediate,
is_intermediate: true,
fp32,
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
steps: steps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -174,18 +178,18 @@ export const buildCanvasSDXLInpaintGraph = (
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate,
is_intermediate: true,
},
[CANVAS_COHERENCE_NOISE_INCREMENT]: {
type: 'add',
id: CANVAS_COHERENCE_NOISE_INCREMENT,
b: 1,
is_intermediate,
is_intermediate: true,
},
[CANVAS_COHERENCE_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: CANVAS_COHERENCE_DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
steps: canvasCoherenceSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -195,19 +199,19 @@ export const buildCanvasSDXLInpaintGraph = (
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate,
is_intermediate: true,
fp32,
},
[CANVAS_OUTPUT]: {
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: !shouldAutoSave,
reference: canvasInitImage,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
is_intermediate,
is_intermediate: true,
// seed - must be connected manually
// start: 0,
size: iterations,
@@ -216,7 +220,7 @@ export const buildCanvasSDXLInpaintGraph = (
[ITERATE]: {
type: 'iterate',
id: ITERATE,
is_intermediate,
is_intermediate: true,
},
},
edges: [
@@ -448,7 +452,7 @@ export const buildCanvasSDXLInpaintGraph = (
graph.nodes[INPAINT_IMAGE_RESIZE_UP] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_UP,
is_intermediate,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
image: canvasInitImage,
@@ -456,7 +460,7 @@ export const buildCanvasSDXLInpaintGraph = (
graph.nodes[MASK_RESIZE_UP] = {
type: 'img_resize',
id: MASK_RESIZE_UP,
is_intermediate,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
image: canvasMaskImage,
@@ -464,14 +468,14 @@ export const buildCanvasSDXLInpaintGraph = (
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
is_intermediate,
is_intermediate: true,
width: width,
height: height,
};
graph.nodes[MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: MASK_RESIZE_DOWN,
is_intermediate,
is_intermediate: true,
width: width,
height: height,
};
@@ -609,7 +613,7 @@ export const buildCanvasSDXLInpaintGraph = (
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
type: 'create_denoise_mask',
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
is_intermediate,
is_intermediate: true,
fp32,
};
@@ -662,7 +666,7 @@ export const buildCanvasSDXLInpaintGraph = (
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
type: 'mask_edge',
id: CANVAS_COHERENCE_MASK_EDGE,
is_intermediate,
is_intermediate: true,
edge_blur: maskBlur,
edge_size: maskBlur * 2,
low_threshold: 100,
@@ -762,6 +766,9 @@ export const buildCanvasSDXLInpaintGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -773,7 +780,5 @@ export const buildCanvasSDXLInpaintGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};
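
When the canvas bounding box is scaled (`boundingBoxScaleMethod` is `'auto'` or `'manual'`), the inpaint and outpaint builders resize the init image and mask up to the scaled dimensions before denoising and back down to the bounding-box size afterwards, which is what the `INPAINT_IMAGE_RESIZE_UP/DOWN` and `MASK_RESIZE_UP/DOWN` hunks above add. Here is a small, self-contained sketch of that check and of one resize-up node; the `'none'` value and the literal node id are assumptions standing in for the repo's constants.

```ts
// Sketch of the scaled-dimensions check and one resize-up node from the
// hunks above. The 'none' value and the literal id are assumptions.
type BoundingBoxScaleMethod = 'none' | 'auto' | 'manual';

const isUsingScaledDimensions = (method: BoundingBoxScaleMethod): boolean =>
  ['auto', 'manual'].includes(method);

const makeInpaintImageResizeUpNode = (
  scaledWidth: number,
  scaledHeight: number,
  image: { image_name: string }
) => ({
  type: 'img_resize' as const,
  id: 'inpaint_image_resize_up', // stand-in for INPAINT_IMAGE_RESIZE_UP
  is_intermediate: true,
  width: scaledWidth,
  height: scaledHeight,
  image,
});

// Only added when the bounding box is actually scaled:
// if (isUsingScaledDimensions(boundingBoxScaleMethod)) {
//   graph.nodes[INPAINT_IMAGE_RESIZE_UP] = makeInpaintImageResizeUpNode(...);
// }
```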


@@ -11,6 +11,7 @@ import {
RangeOfSizeInvocation,
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@@ -48,7 +49,6 @@ import {
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Outpaint graph.
@@ -99,10 +99,14 @@ export const buildCanvasSDXLOutpaintGraph = (
const { width, height } = state.canvas.boundingBoxDimensions;
// We may need to set the inpaint width and height to scale the image
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
@@ -140,37 +144,37 @@ export const buildCanvasSDXLOutpaintGraph = (
[MASK_FROM_ALPHA]: {
type: 'tomask',
id: MASK_FROM_ALPHA,
is_intermediate,
is_intermediate: true,
image: canvasInitImage,
},
[MASK_COMBINE]: {
type: 'mask_combine',
id: MASK_COMBINE,
is_intermediate,
is_intermediate: true,
mask2: canvasMaskImage,
},
[INPAINT_IMAGE]: {
type: 'i2l',
id: INPAINT_IMAGE,
is_intermediate,
is_intermediate: true,
fp32,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate,
is_intermediate: true,
},
[INPAINT_CREATE_MASK]: {
type: 'create_denoise_mask',
id: INPAINT_CREATE_MASK,
is_intermediate,
is_intermediate: true,
fp32,
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
steps: steps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -183,18 +187,18 @@ export const buildCanvasSDXLOutpaintGraph = (
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate,
is_intermediate: true,
},
[CANVAS_COHERENCE_NOISE_INCREMENT]: {
type: 'add',
id: CANVAS_COHERENCE_NOISE_INCREMENT,
b: 1,
is_intermediate,
is_intermediate: true,
},
[CANVAS_COHERENCE_DENOISE_LATENTS]: {
type: 'denoise_latents',
id: CANVAS_COHERENCE_DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
steps: canvasCoherenceSteps,
cfg_scale: cfg_scale,
scheduler: scheduler,
@@ -204,18 +208,18 @@ export const buildCanvasSDXLOutpaintGraph = (
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate,
is_intermediate: true,
fp32,
},
[CANVAS_OUTPUT]: {
type: 'color_correct',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: !shouldAutoSave,
},
[RANGE_OF_SIZE]: {
type: 'range_of_size',
id: RANGE_OF_SIZE,
is_intermediate,
is_intermediate: true,
// seed - must be connected manually
// start: 0,
size: iterations,
@@ -224,7 +228,7 @@ export const buildCanvasSDXLOutpaintGraph = (
[ITERATE]: {
type: 'iterate',
id: ITERATE,
is_intermediate,
is_intermediate: true,
},
},
edges: [
@@ -484,7 +488,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_patchmatch',
id: INPAINT_INFILL,
is_intermediate,
is_intermediate: true,
downscale: infillPatchmatchDownscaleSize,
};
}
@@ -493,7 +497,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_lama',
id: INPAINT_INFILL,
is_intermediate,
is_intermediate: true,
};
}
@@ -501,7 +505,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_cv2',
id: INPAINT_INFILL,
is_intermediate,
is_intermediate: true,
};
}
@@ -509,7 +513,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[INPAINT_INFILL] = {
type: 'infill_tile',
id: INPAINT_INFILL,
is_intermediate,
is_intermediate: true,
tile_size: infillTileSize,
};
}
@@ -523,7 +527,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[INPAINT_IMAGE_RESIZE_UP] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_UP,
is_intermediate,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
image: canvasInitImage,
@@ -531,28 +535,28 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[MASK_RESIZE_UP] = {
type: 'img_resize',
id: MASK_RESIZE_UP,
is_intermediate,
is_intermediate: true,
width: scaledWidth,
height: scaledHeight,
};
graph.nodes[INPAINT_IMAGE_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_IMAGE_RESIZE_DOWN,
is_intermediate,
is_intermediate: true,
width: width,
height: height,
};
graph.nodes[INPAINT_INFILL_RESIZE_DOWN] = {
type: 'img_resize',
id: INPAINT_INFILL_RESIZE_DOWN,
is_intermediate,
is_intermediate: true,
width: width,
height: height,
};
graph.nodes[MASK_RESIZE_DOWN] = {
type: 'img_resize',
id: MASK_RESIZE_DOWN,
is_intermediate,
is_intermediate: true,
width: width,
height: height,
};
@@ -712,7 +716,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[CANVAS_COHERENCE_INPAINT_CREATE_MASK] = {
type: 'create_denoise_mask',
id: CANVAS_COHERENCE_INPAINT_CREATE_MASK,
is_intermediate,
is_intermediate: true,
fp32,
};
@@ -759,7 +763,7 @@ export const buildCanvasSDXLOutpaintGraph = (
graph.nodes[CANVAS_COHERENCE_MASK_EDGE] = {
type: 'mask_edge',
id: CANVAS_COHERENCE_MASK_EDGE,
is_intermediate,
is_intermediate: true,
edge_blur: maskBlur,
edge_size: maskBlur * 2,
low_threshold: 100,
@@ -865,6 +869,9 @@ export const buildCanvasSDXLOutpaintGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -876,7 +883,5 @@ export const buildCanvasSDXLOutpaintGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};
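
The outpaint builders pick the `INPAINT_INFILL` node from the configured infill method, as the four alternative hunks above show (`infill_patchmatch`, `infill_lama`, `infill_cv2`, `infill_tile`). The sketch below consolidates those branches into one function for readability; the method names and the surrounding state lookup are assumptions, while the node fields (`downscale`, `tile_size`) are taken from the hunks.

```ts
// Sketch consolidating the four INPAINT_INFILL branches shown above.
// Method names and the literal id are assumptions; the node fields are
// taken from the hunks.
const INPAINT_INFILL = 'inpaint_infill';

const buildInfillNode = (
  method: 'patchmatch' | 'lama' | 'cv2' | 'tile',
  infillPatchmatchDownscaleSize: number,
  infillTileSize: number
) => {
  switch (method) {
    case 'patchmatch':
      return {
        type: 'infill_patchmatch' as const,
        id: INPAINT_INFILL,
        is_intermediate: true,
        downscale: infillPatchmatchDownscaleSize,
      };
    case 'lama':
      return { type: 'infill_lama' as const, id: INPAINT_INFILL, is_intermediate: true };
    case 'cv2':
      return { type: 'infill_cv2' as const, id: INPAINT_INFILL, is_intermediate: true };
    case 'tile':
      return {
        type: 'infill_tile' as const,
        id: INPAINT_INFILL,
        is_intermediate: true,
        tile_size: infillTileSize,
      };
  }
};
```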


@@ -8,6 +8,7 @@ import {
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@@ -29,7 +30,6 @@ import {
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Canvas tab's Text to Image graph.
@@ -56,10 +56,14 @@ export const buildCanvasSDXLTextToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
@@ -91,7 +95,7 @@ export const buildCanvasSDXLTextToImageGraph = (
? {
type: 't2l_onnx',
id: SDXL_DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
cfg_scale,
scheduler,
steps,
@@ -99,7 +103,7 @@ export const buildCanvasSDXLTextToImageGraph = (
: {
type: 'denoise_latents',
id: SDXL_DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
cfg_scale,
scheduler,
steps,
@@ -128,27 +132,27 @@ export const buildCanvasSDXLTextToImageGraph = (
[modelLoaderNodeId]: {
type: modelLoaderNodeType,
id: modelLoaderNodeId,
is_intermediate,
is_intermediate: true,
model,
},
[POSITIVE_CONDITIONING]: {
type: isUsingOnnxModel ? 'prompt_onnx' : 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
is_intermediate,
is_intermediate: true,
prompt: positivePrompt,
style: craftedPositiveStylePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: isUsingOnnxModel ? 'prompt_onnx' : 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
is_intermediate,
is_intermediate: true,
prompt: negativePrompt,
style: craftedNegativeStylePrompt,
},
[NOISE]: {
type: 'noise',
id: NOISE,
is_intermediate,
is_intermediate: true,
width: !isUsingScaledDimensions
? width
: scaledBoundingBoxDimensions.width,
@@ -250,14 +254,14 @@ export const buildCanvasSDXLTextToImageGraph = (
graph.nodes[LATENTS_TO_IMAGE] = {
id: LATENTS_TO_IMAGE,
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
is_intermediate,
is_intermediate: true,
fp32,
};
graph.nodes[CANVAS_OUTPUT] = {
id: CANVAS_OUTPUT,
type: 'img_resize',
is_intermediate,
is_intermediate: !shouldAutoSave,
width: width,
height: height,
};
@@ -288,7 +292,7 @@ export const buildCanvasSDXLTextToImageGraph = (
graph.nodes[CANVAS_OUTPUT] = {
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: !shouldAutoSave,
fp32,
};
@@ -369,6 +373,9 @@ export const buildCanvasSDXLTextToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -380,7 +387,5 @@ export const buildCanvasSDXLTextToImageGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};


@@ -8,9 +8,9 @@ import {
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSaveImageNode } from './addSaveImageNode';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
import { addVAEToGraph } from './addVAEToGraph';
import { addWatermarkerToGraph } from './addWatermarkerToGraph';
@@ -54,10 +54,14 @@ export const buildCanvasTextToImageGraph = (
// The bounding box determines width and height, not the width and height params
const { width, height } = state.canvas.boundingBoxDimensions;
const { scaledBoundingBoxDimensions, boundingBoxScaleMethod } = state.canvas;
const {
scaledBoundingBoxDimensions,
boundingBoxScaleMethod,
shouldAutoSave,
} = state.canvas;
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingScaledDimensions = ['auto', 'manual'].includes(
boundingBoxScaleMethod
);
@@ -86,7 +90,7 @@ export const buildCanvasTextToImageGraph = (
? {
type: 't2l_onnx',
id: DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
cfg_scale,
scheduler,
steps,
@@ -94,7 +98,7 @@ export const buildCanvasTextToImageGraph = (
: {
type: 'denoise_latents',
id: DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
cfg_scale,
scheduler,
steps,
@@ -119,31 +123,31 @@ export const buildCanvasTextToImageGraph = (
[modelLoaderNodeId]: {
type: modelLoaderNodeType,
id: modelLoaderNodeId,
is_intermediate,
is_intermediate: true,
model,
},
[CLIP_SKIP]: {
type: 'clip_skip',
id: CLIP_SKIP,
is_intermediate,
is_intermediate: true,
skipped_layers: clipSkip,
},
[POSITIVE_CONDITIONING]: {
type: isUsingOnnxModel ? 'prompt_onnx' : 'compel',
id: POSITIVE_CONDITIONING,
is_intermediate,
is_intermediate: true,
prompt: positivePrompt,
},
[NEGATIVE_CONDITIONING]: {
type: isUsingOnnxModel ? 'prompt_onnx' : 'compel',
id: NEGATIVE_CONDITIONING,
is_intermediate,
is_intermediate: true,
prompt: negativePrompt,
},
[NOISE]: {
type: 'noise',
id: NOISE,
is_intermediate,
is_intermediate: true,
width: !isUsingScaledDimensions
? width
: scaledBoundingBoxDimensions.width,
@@ -236,14 +240,14 @@ export const buildCanvasTextToImageGraph = (
graph.nodes[LATENTS_TO_IMAGE] = {
id: LATENTS_TO_IMAGE,
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
is_intermediate,
is_intermediate: true,
fp32,
};
graph.nodes[CANVAS_OUTPUT] = {
id: CANVAS_OUTPUT,
type: 'img_resize',
is_intermediate,
is_intermediate: !shouldAutoSave,
width: width,
height: height,
};
@@ -274,7 +278,7 @@ export const buildCanvasTextToImageGraph = (
graph.nodes[CANVAS_OUTPUT] = {
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
id: CANVAS_OUTPUT,
is_intermediate,
is_intermediate: !shouldAutoSave,
fp32,
};
@@ -342,6 +346,9 @@ export const buildCanvasTextToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -353,7 +360,5 @@ export const buildCanvasTextToImageGraph = (
addWatermarkerToGraph(state, graph, CANVAS_OUTPUT);
}
addSaveImageNode(state, graph);
return graph;
};
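
The text-to-image builders (canvas, canvas SDXL, and linear) choose between an ONNX and a regular denoise node based on `model.model_type === 'onnx'` and then splice the result into the graph under its own id (`[t2lNode.id]: t2lNode`). A simplified sketch of that selection follows; the parameter list is trimmed, `scheduler` is widened to `string`, and the literal id stands in for the `DENOISE_LATENTS` constant from `./constants`.

```ts
// Simplified sketch of the ONNX-aware denoise node selection used by the
// text-to-image builders above; parameter lists are trimmed and the
// literal id stands in for the DENOISE_LATENTS constant.
const DENOISE_LATENTS = 'denoise_latents';

const makeT2LNode = (
  isUsingOnnxModel: boolean,
  cfg_scale: number,
  scheduler: string,
  steps: number
) =>
  isUsingOnnxModel
    ? {
        type: 't2l_onnx',
        id: DENOISE_LATENTS,
        is_intermediate: true,
        cfg_scale,
        scheduler,
        steps,
      }
    : {
        type: 'denoise_latents',
        id: DENOISE_LATENTS,
        is_intermediate: true,
        cfg_scale,
        scheduler,
        steps,
      };

// later in the builder: graph.nodes[t2lNode.id] = t2lNode;
```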


@@ -8,6 +8,7 @@ import {
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@@ -27,7 +28,6 @@ import {
RESIZE,
SEAMLESS,
} from './constants';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Image to Image tab graph.
@@ -86,7 +86,6 @@ export const buildLinearImageToImageGraph = (
}
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
let modelLoaderNodeId = MAIN_MODEL_LOADER;
@@ -102,37 +101,31 @@ export const buildLinearImageToImageGraph = (
type: 'main_model_loader',
id: modelLoaderNodeId,
model,
is_intermediate,
},
[CLIP_SKIP]: {
type: 'clip_skip',
id: CLIP_SKIP,
skipped_layers: clipSkip,
is_intermediate,
},
[POSITIVE_CONDITIONING]: {
type: 'compel',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
is_intermediate,
},
[NEGATIVE_CONDITIONING]: {
type: 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
is_intermediate,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
},
[DENOISE_LATENTS]: {
type: 'denoise_latents',
@@ -142,7 +135,6 @@ export const buildLinearImageToImageGraph = (
steps,
denoising_start: 1 - strength,
denoising_end: 1,
is_intermediate,
},
[IMAGE_TO_LATENTS]: {
type: 'i2l',
@@ -152,7 +144,6 @@ export const buildLinearImageToImageGraph = (
// image_name: initialImage.image_name,
// },
fp32,
is_intermediate,
},
},
edges: [
@@ -374,6 +365,9 @@ export const buildLinearImageToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -385,7 +379,5 @@ export const buildLinearImageToImageGraph = (
addWatermarkerToGraph(state, graph);
}
addSaveImageNode(state, graph);
return graph;
};


@@ -8,6 +8,7 @@ import {
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@@ -29,7 +30,6 @@ import {
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
import { addSaveImageNode } from './addSaveImageNode';
/**
* Builds the Image to Image tab graph.
@@ -86,7 +86,6 @@ export const buildLinearSDXLImageToImageGraph = (
}
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
// Model Loader ID
let modelLoaderNodeId = SDXL_MODEL_LOADER;
@@ -107,33 +106,28 @@ export const buildLinearSDXLImageToImageGraph = (
type: 'sdxl_model_loader',
id: modelLoaderNodeId,
model,
is_intermediate,
},
[POSITIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: craftedPositiveStylePrompt,
is_intermediate,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: craftedNegativeStylePrompt,
is_intermediate,
},
[NOISE]: {
type: 'noise',
id: NOISE,
use_cpu,
is_intermediate,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents',
@@ -145,7 +139,6 @@ export const buildLinearSDXLImageToImageGraph = (
? Math.min(refinerStart, 1 - strength)
: 1 - strength,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
is_intermediate,
},
[IMAGE_TO_LATENTS]: {
type: 'i2l',
@@ -155,7 +148,6 @@ export const buildLinearSDXLImageToImageGraph = (
// image_name: initialImage.image_name,
// },
fp32,
is_intermediate,
},
},
edges: [
@@ -393,6 +385,9 @@ export const buildLinearSDXLImageToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// Add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
@@ -407,7 +402,5 @@ export const buildLinearSDXLImageToImageGraph = (
addWatermarkerToGraph(state, graph);
}
addSaveImageNode(state, graph);
return graph;
};


@@ -4,6 +4,7 @@ import { NonNullableGraph } from 'features/nodes/types/types';
import { initialGenerationState } from 'features/parameters/store/generationSlice';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph';
import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph';
@@ -23,7 +24,6 @@ import {
SEAMLESS,
} from './constants';
import { craftSDXLStylePrompt } from './helpers/craftSDXLStylePrompt';
import { addSaveImageNode } from './addSaveImageNode';
export const buildLinearSDXLTextToImageGraph = (
state: RootState
@@ -57,13 +57,13 @@ export const buildLinearSDXLTextToImageGraph = (
const use_cpu = shouldUseNoiseSettings
? shouldUseCpuNoise
: initialGenerationState.shouldUseCpuNoise;
if (!model) {
log.error('No model found in state');
throw new Error('No model found in state');
}
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
// Construct Style Prompt
const { craftedPositiveStylePrompt, craftedNegativeStylePrompt } =
@@ -89,21 +89,18 @@ export const buildLinearSDXLTextToImageGraph = (
type: 'sdxl_model_loader',
id: modelLoaderNodeId,
model,
is_intermediate,
},
[POSITIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
style: craftedPositiveStylePrompt,
is_intermediate,
},
[NEGATIVE_CONDITIONING]: {
type: 'sdxl_compel_prompt',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
style: craftedNegativeStylePrompt,
is_intermediate,
},
[NOISE]: {
type: 'noise',
@@ -111,7 +108,6 @@ export const buildLinearSDXLTextToImageGraph = (
width,
height,
use_cpu,
is_intermediate,
},
[SDXL_DENOISE_LATENTS]: {
type: 'denoise_latents',
@@ -121,13 +117,11 @@ export const buildLinearSDXLTextToImageGraph = (
steps,
denoising_start: 0,
denoising_end: shouldUseSDXLRefiner ? refinerStart : 1,
is_intermediate,
},
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
},
},
edges: [
@@ -284,6 +278,9 @@ export const buildLinearSDXLTextToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// add IP Adapter
addIPAdapterToLinearGraph(state, graph, SDXL_DENOISE_LATENTS);
// add dynamic prompts - also sets up core iteration and seed
addDynamicPromptsToGraph(state, graph);
@@ -298,7 +295,5 @@ export const buildLinearSDXLTextToImageGraph = (
addWatermarkerToGraph(state, graph);
}
addSaveImageNode(state, graph);
return graph;
};


@@ -8,6 +8,7 @@ import {
} from 'services/api/types';
import { addControlNetToLinearGraph } from './addControlNetToLinearGraph';
import { addDynamicPromptsToGraph } from './addDynamicPromptsToGraph';
import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph';
import { addLoRAsToGraph } from './addLoRAsToGraph';
import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph';
import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph';
@@ -26,7 +27,6 @@ import {
SEAMLESS,
TEXT_TO_IMAGE_GRAPH,
} from './constants';
import { addSaveImageNode } from './addSaveImageNode';
export const buildLinearTextToImageGraph = (
state: RootState
@@ -59,7 +59,7 @@ export const buildLinearTextToImageGraph = (
}
const fp32 = vaePrecision === 'fp32';
const is_intermediate = true;
const isUsingOnnxModel = model.model_type === 'onnx';
let modelLoaderNodeId = isUsingOnnxModel
@@ -75,7 +75,7 @@ export const buildLinearTextToImageGraph = (
? {
type: 't2l_onnx',
id: DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
cfg_scale,
scheduler,
steps,
@@ -83,7 +83,7 @@ export const buildLinearTextToImageGraph = (
: {
type: 'denoise_latents',
id: DENOISE_LATENTS,
is_intermediate,
is_intermediate: true,
cfg_scale,
scheduler,
steps,
@@ -109,26 +109,26 @@ export const buildLinearTextToImageGraph = (
[modelLoaderNodeId]: {
type: modelLoaderNodeType,
id: modelLoaderNodeId,
is_intermediate,
is_intermediate: true,
model,
},
[CLIP_SKIP]: {
type: 'clip_skip',
id: CLIP_SKIP,
skipped_layers: clipSkip,
is_intermediate,
is_intermediate: true,
},
[POSITIVE_CONDITIONING]: {
type: isUsingOnnxModel ? 'prompt_onnx' : 'compel',
id: POSITIVE_CONDITIONING,
prompt: positivePrompt,
is_intermediate,
is_intermediate: true,
},
[NEGATIVE_CONDITIONING]: {
type: isUsingOnnxModel ? 'prompt_onnx' : 'compel',
id: NEGATIVE_CONDITIONING,
prompt: negativePrompt,
is_intermediate,
is_intermediate: true,
},
[NOISE]: {
type: 'noise',
@@ -136,14 +136,13 @@ export const buildLinearTextToImageGraph = (
width,
height,
use_cpu,
is_intermediate,
is_intermediate: true,
},
[t2lNode.id]: t2lNode,
[LATENTS_TO_IMAGE]: {
type: isUsingOnnxModel ? 'l2i_onnx' : 'l2i',
id: LATENTS_TO_IMAGE,
fp32,
is_intermediate,
},
},
edges: [
@@ -284,6 +283,9 @@ export const buildLinearTextToImageGraph = (
// add controlnet, mutating `graph`
addControlNetToLinearGraph(state, graph, DENOISE_LATENTS);
// add IP Adapter
addIPAdapterToLinearGraph(state, graph, DENOISE_LATENTS);
// NSFW & watermark - must be last thing added to graph
if (state.system.shouldUseNSFWChecker) {
// must add before watermarker!
@@ -295,7 +297,5 @@ export const buildLinearTextToImageGraph = (
addWatermarkerToGraph(state, graph);
}
addSaveImageNode(state, graph);
return graph;
};


@@ -55,9 +55,6 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
{} as Record<Exclude<string, 'id' | 'type'>, unknown>
);
// add reserved use_cache
transformedInputs['use_cache'] = node.data.useCache;
// Build this specific node
const graphNode = {
type,

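The final hunk, from `buildNodesGraph`, shows the workflow editor's graph serialiser copying each node's `useCache` flag onto the reserved `use_cache` field after flattening the node's inputs. Below is a simplified sketch of that pattern; the editor node type is an assumption, and the real implementation also excludes reserved keys and handles several field kinds.

```ts
// Simplified sketch of the use_cache pattern in the buildNodesGraph hunk
// above. The EditorNodeData type is an assumption; the real code works
// on the workflow editor's node data and filters reserved keys.
type EditorNodeData = {
  id: string;
  type: string;
  useCache: boolean;
  inputs: Record<string, { value: unknown }>;
};

const toGraphNode = (data: EditorNodeData) => {
  // Flatten input fields to { fieldName: value }.
  const transformedInputs = Object.entries(data.inputs).reduce(
    (acc, [name, field]) => {
      acc[name] = field.value;
      return acc;
    },
    {} as Record<string, unknown>
  );

  // add reserved use_cache
  transformedInputs['use_cache'] = data.useCache;

  // Build this specific node from its type, id, and flattened inputs.
  return { type: data.type, id: data.id, ...transformedInputs };
};
```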