Compare commits

...

82 Commits

Author SHA1 Message Date
Mary Hipp
9bd1f4a4f4 ruff 2024-11-06 16:30:37 -05:00
Mary Hipp
28864f6d7f use queue room/subscription instead of separate model loading room 2024-11-06 16:30:02 -05:00
Mary Hipp
c63fe5e9bb add queue_id to all model load invocations 2024-11-06 16:19:34 -05:00
Mary Hipp
674f530501 break out model load events from other model events, add queue_id as required arg everytime model loads so that event can be emitted to proper queue 2024-11-06 13:48:59 -05:00
psychedelicious
a01d44f813 chore(ui): lint 2024-11-06 10:25:46 -05:00
psychedelicious
63fb3a15e9 feat(ui): default to no control model selected for control layers 2024-11-06 10:25:46 -05:00
psychedelicious
4d0837541b feat(ui): add simple mode filtering 2024-11-06 10:25:46 -05:00
psychedelicious
999809b4c7 fix(ui): minor viewer close button styling 2024-11-06 10:25:46 -05:00
psychedelicious
c452edfb9f feat(ui): add control layer empty state 2024-11-06 10:25:46 -05:00
psychedelicious
ad2cdbd8a2 feat(ui): tooltip for canvas preview image 2024-11-06 10:25:46 -05:00
psychedelicious
f15c24bfa7 feat(ui): add " (recommended)" to balanced control mode label 2024-11-06 10:25:46 -05:00
psychedelicious
d1f653f28c feat(ui): make default control end step 0.75 2024-11-06 10:25:46 -05:00
psychedelicious
244465d3a6 feat(ui): make default control weight 0.75 2024-11-06 10:25:46 -05:00
psychedelicious
c6236ab70c feat(ui): add menubar-ish header on comparison 2024-11-06 10:25:46 -05:00
psychedelicious
644d5cb411 feat(ui): add menubar-ish header on viewer 2024-11-06 10:25:46 -05:00
Riku
bb0a630416 fix(ui): adjust knip config to ignore parameter schema exports 2024-11-06 22:51:17 +11:00
Riku
2148ae9287 feat(ui): simplify parameter schema declaration and type inference 2024-11-06 22:51:17 +11:00
psychedelicious
42d242609c chore(gh): update pr template w/ reminder for what's new copy 2024-11-06 19:03:31 +11:00
psychedelicious
fd0a52392b feat(ui): added line about when denoising str is disabled 2024-11-06 19:01:33 +11:00
psychedelicious
e64415d59a feat(ui): revised logic to disable denoising str 2024-11-06 19:01:33 +11:00
psychedelicious
1871e0bdbf feat(ui): tweaked denoise str styling 2024-11-06 19:01:33 +11:00
Mary Hipp
3ae9a965c2 lint 2024-11-06 19:01:33 +11:00
Mary Hipp
85932e35a7 update copy again 2024-11-06 19:01:33 +11:00
Mary Hipp
41b07a56cc update popover copy and add image 2024-11-06 19:01:33 +11:00
Mary Hipp
54064c0cb8 fix(ui): match badge height to slider height so layout does not shift 2024-11-06 19:01:33 +11:00
Mary Hipp
68284b37fa remove opacity logic from WavyLine, add badge explaining disabled state, add translations 2024-11-06 19:01:33 +11:00
Mary Hipp
ae5bc6f5d6 feat(ui): move denoising strength to layers panel w/ visualization of how much change will be applied, only enable if 1+ enabled raster layer 2024-11-06 19:01:33 +11:00
Mary Hipp
6dc16c9f54 wip 2024-11-06 19:01:33 +11:00
Brandon Rising
faa9ac4e15 fix: get_clip_variant_type should never return None 2024-11-06 09:59:50 +11:00
Mary Hipp Rogers
d0460849b0 fix bad merge conflict (#7273)
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2024-11-05 16:02:03 -05:00
Mary Hipp Rogers
bed3c2dd77 update Whats New for 5.3.1 (#7272)
Co-authored-by: Mary Hipp <maryhipp@Marys-MacBook-Air.local>
2024-11-05 15:43:16 -05:00
Mary Hipp
916ddd17d7 fix(ui): fix link for infill method popover 2024-11-05 15:39:03 -05:00
Mary Hipp
accfa7407f fix undefined 2024-11-05 15:30:17 -05:00
Mary Hipp
908db31e48 feat(api,ui): allow Whats New module to get content from back-end 2024-11-05 15:30:17 -05:00
Mary Hipp
b70f632b26 fix(ui): add some feedback while layers are merging 2024-11-05 12:38:50 -05:00
Brandon Rising
d07a6385ab Always default to ClipVariantType.L instead of None 2024-11-05 12:03:40 -05:00
Brandon Rising
68df612fa1 fix: Never throw an exception when finding the clip variant type 2024-11-05 12:03:40 -05:00
psychedelicious
3b96c79461 chore: bump version to v5.4.0 2024-11-05 10:09:21 +11:00
psychedelicious
89bda5b983 Ryan/sd3 diffusers (#7222)
## Summary

Nodes to support SD3.5 txt2img generations
* adds SD3.5 to starter models
* adds default workflow for SD3.5 txt2img

## Related Issues / Discussions

<!--WHEN APPLICABLE: List any related issues or discussions on github or
discord. If this PR closes an issue, please use the "Closes #1234"
format, so that the issue will be automatically closed when the PR
merges.-->

## QA Instructions

<!--WHEN APPLICABLE: Describe how you have tested the changes in this
PR. Provide enough detail that a reviewer can reproduce your tests.-->

## Merge Plan

<!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like
DB schemas, may need some care when merging. For example, a careful
rebase by the change author, timing to not interfere with a pending
release, or a message to contributors on discord after merging.-->

## Checklist

- [ ] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
2024-11-05 08:21:28 +11:00
Brandon Rising
22bff1fb22 Fix conditional within filter_by_variant to not read all candidates as default 2024-11-04 12:42:09 -05:00
Mary Hipp
55ba6488d1 fix up types file 2024-11-04 12:42:09 -05:00
brandonrising
2d78859171 Create bespoke latents to image node for sd3 2024-11-04 12:42:09 -05:00
Mary Hipp
3a661bac34 fix(ui): exclude submodels from model manager 2024-11-04 12:42:09 -05:00
Mary Hipp
bb8a02de18 update schema 2024-11-04 12:42:09 -05:00
maryhipp
78155344f6 update node fields for SD3 to match other SD nodes 2024-11-04 12:42:09 -05:00
Brandon Rising
391a24b0f6 Re-add erroniously removed hash code 2024-11-04 12:42:09 -05:00
Brandon Rising
e75903389f Run ruff, fix bug in hf downloading code which failed to download parts of a model 2024-11-04 12:42:09 -05:00
Brandon Rising
27567052f2 Create new latent factors for sd35 2024-11-04 12:42:09 -05:00
Brandon Rising
6f447f7169 Rather than .fp16., some repos start the suffix with .fp16... for weights spread across multiple files 2024-11-04 12:42:09 -05:00
Mary Hipp
8b370cc182 (ui): dont show SD3 in main model dropdown yet 2024-11-04 12:42:09 -05:00
maryhipp
af583d2971 ruff format 2024-11-04 12:42:09 -05:00
Mary Hipp
0ebe8fb1bd (ui): add required/optional logic to other submodel fields 2024-11-04 12:42:09 -05:00
maryhipp
befb629f46 add default workflow 2024-11-04 12:42:09 -05:00
maryhipp
874d67cb37 add SD3.5 to starter models 2024-11-04 12:42:09 -05:00
Mary Hipp
19f7a1295a (ui): add fields for CLIP-L and CLIP-G, remove MainModelConfig type changes 2024-11-04 12:42:09 -05:00
maryhipp
78bd605617 (nodes,api): expose the submodels on SD3 model loader as optional, add types needed for CLIP-L and CLIP-G fields 2024-11-04 12:42:09 -05:00
Brandon Rising
b87f4e59a5 Create clip variant type, create new fucntions for discerning clipL and clipG in the frontend 2024-11-04 12:42:09 -05:00
Ryan Dick
1eca4f12c8 Make T5 encoder optonal in SD3 workflows. 2024-11-04 12:42:09 -05:00
Ryan Dick
f1de11d6bf Make the default CFG for SD3 3.5. 2024-11-04 12:42:09 -05:00
Ryan Dick
9361ed9d70 Add progress images to SD3 and make denoising cancellable. 2024-11-04 12:42:09 -05:00
Brandon Rising
ebabf4f7a8 Setup Model and T5 Encoder selection fields for sd3 nodes 2024-11-04 12:42:09 -05:00
Brandon Rising
606f3321f5 Initial wave of frontend updates for sd-3 node inputs 2024-11-04 12:42:09 -05:00
Brandon Rising
3970aa30fb define submodels on sd3 models during probe 2024-11-04 12:42:09 -05:00
Ryan Dick
678436e07c Add tqdm progress bar for SD3. 2024-11-04 12:42:09 -05:00
Ryan Dick
c620581699 Bug fixes to get SD3 text-to-image workflow running. 2024-11-04 12:42:09 -05:00
Ryan Dick
c331d42ce4 Temporary hack for testing SD3 model loader. 2024-11-04 12:42:09 -05:00
Ryan Dick
1ac9b502f1 Fix Sd3TextEncoderInvocation output type. 2024-11-04 12:42:09 -05:00
Ryan Dick
3fa478a12f Initial draft of SD3DenoiseInvocation. 2024-11-04 12:42:09 -05:00
Ryan Dick
2d86298b7f Add first draft of Sd3TextEncoderInvocation. 2024-11-04 12:42:09 -05:00
Ryan Dick
009cdb714c Add Sd3ModelLoaderInvocation. 2024-11-04 12:42:09 -05:00
Ryan Dick
9d3f5427b4 Move FluxModelLoaderInvocation to its own file. model.py was getting bloated. 2024-11-04 12:42:09 -05:00
Ryan Dick
e4b17f019a Get diffusers SD3 model probing working. 2024-11-04 12:42:09 -05:00
Ryan Dick
586c00bc02 (minor) Remove unused dict. 2024-11-04 12:42:09 -05:00
Eugene Brodsky
0f11fda65a fix(deps): pin mediapipe strictly to a known working version 2024-11-04 10:16:19 -05:00
psychedelicious
3e75331ef7 fix(ui): load workflow from file
In a8de6406c5 a change was made to many menus in an effort to improve performance. The menus were made to be lazy, so that they are mounted only while open.

This causes unexpected behaviour when there is some logic in the menu that may need to execute after the user selects a menu item.

In this case, when you click to load a workflow from file, the file picker opens but then the menuitem unmounts, taking the input element and all uploading logic with it. When you select a file, nothing happens because we've nuked the handlers by unmounting everything.

Easy fix - un-lazy-fy the menu.

Closes #7240
2024-11-04 08:02:55 -05:00
psychedelicious
be133408ac fix(nodes): relaxed validation for segment anything
The validation on this node causes graph validation to valid. It must be validated _after_ instantiation.

Also, it was a bit too strict. The only case we explicitly do not handle is when both bboxes and points are provided. It's acceptable if neither are provided.

Closes #7248
2024-11-04 08:00:52 -05:00
psychedelicious
7e1e0d6928 fix(ui): non-default filters can erase layer
When filtering, we use a listener to trigger processing the image whenever a filter setting changes. For example, if the user changes from canny to depth, and auto-process is enabled, we re-process the layer with new filter settings.

The filterer has a method to reset its ephemeral state. This includes the filter settings, so resetting the ephemeral state is expected to trigger processing of the filter.

When we exit filtering, we reset the ephemeral state before resetting everything else, like the listeners.

This can cause problem when we exit filtering. The sequence:
- Start filtering a layer.
- Auto-process the filter in response to starting the filter process.
- Change the filter settings.
- Auto-process the filter in response to the changed settings.
- Apply the filter.
- Exit filtering, first by resetting the ephemeral state.
- Auto-process the filter in response to the reset settings.*
- Finish exiting, including unsubscribing from listeners.

*Whoops! That last auto-process has now borked the layer's rendering by processing a filter when we shouldn't be processing a filter.

We need to first unsubscribe from listeners, so we don't react to that change to the filter settings and erroneously process the layer.

Also, add a check to the `processImmediate` method to prevent processing if that method is accidentally called without first starting the filterer.

The same issue could affect the segmenyanything module - same fixes are implemented there.
2024-11-04 07:11:20 -05:00
psychedelicious
cd3d8df5a8 fix(ui): save canvas to gallery does nothing
The root issue is the compositing cache. When we save the canvas to gallery, we need to first composite raster layers together and then upload the image.

The compositor makes extensive use of caching to reduce the number of images created and improve performance. There are two "layers" of caching:
1. Caching the composite canvas element, which is used both for uploading the canvas and for generation mode analysis.
2. Caching the uploaded composite canvas element as an image.

The combination of these caches allows for the various processes that require composite canvases to do minimal work.

But this causes a problem in this situation, because the user expects a new image to be uploaded when they click save to gallery.

For example, suppose we have already composited and uploaded the raster layer state for use in a generation. Then, we ask the compositor to save the canvas to gallery.

The compositor sees that we are requesting an image for the current canvas state, and instead of recompositing and uploading the image again, it just returns the cached image.

In this case, no image is uploaded and it the button does nothing.

We need to be able to opt out of the caching at some level, for certain actions. A `forceUpload` arg is added to the compositor's high-level `getCompositeImageDTO` method to do this.

When true, we ignore the uppermost caching layer (the uploaded image layer), but still use the lower caching layer (the canvas element layer). So we don't recompute the canvas element, but we do upload it as a new image to the server.
2024-11-04 07:11:20 -05:00
psychedelicious
24d3c22017 fix(ui): temp fix for stuck tooltips 2024-11-04 07:11:20 -05:00
psychedelicious
b0d37f4e51 fix(ui): progress image does not reset when canceling generation
Previously, we cleared the canvas progress image when the canvas had no active generations. This allowed for a brief flash of canvas state between the last progress image for a given generation, and when the output image for that generation rendered. Here's the sequence:
- Progress images are received and rendered
- Generation completes - no active canvas generations
- Clear the progress image -> canvas layers visible unexpectedly, creating an awkward jarring change
- Generation output image is rendered -> output image overlaid on canvas layers

In 83538c4b2b I attempted to fix this by only clearing the progress image while we were not staging.

This isn't quite right, though. We are often staging with no active generations - for example, you have a few images completed and are waiting to choose one.

In this situation, if you cancel a pending generation, the logic to clear the progress image doesn't fire because it sees staging is in progress.

What we really need is:
- Staging area module clears the progress image once it has rendered an output image.
- Progress image module clears the progress image when a generation is canceled or failed, in which case there will be no output image.

To do this, we can add an event listener to the progress image module to listen for queue item status changes, and when we get a cancelation or failure, clear the progress image.
2024-11-04 07:11:20 -05:00
psychedelicious
3559124674 feat(ui): use nanostores in CanvasProgressImageModule for internal state 2024-11-04 07:11:20 -05:00
Eugene Brodsky
6c33e02141 fix(pkg): pin torch to <2.5.0 to prevent unnecessary downloads
pip's dependency resolution doesn't take into account transitive
dependencies when choosing package versions for download.
Even though `torch=~2.4.1` is required by `diffusers`, pip will
download 2.5.0 and higher, but only install 2.4.1.
Pinning torch to <2.5.0 prevents this behaviour.
2024-11-01 12:27:28 -04:00
127 changed files with 3518 additions and 630 deletions

View File

@@ -19,3 +19,4 @@
- [ ] _The PR has a short but descriptive title, suitable for a changelog_
- [ ] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_

View File

@@ -40,6 +40,8 @@ class AppVersion(BaseModel):
version: str = Field(description="App version")
highlights: Optional[list[str]] = Field(default=None, description="Highlights of release")
class AppDependencyVersions(BaseModel):
"""App depencency Versions Response"""

View File

@@ -751,7 +751,7 @@ async def convert_model(
with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
converted_model = loader.load_model(model_config)
converted_model = loader.load_model(model_config, queue_id="default")
# write the converted file to the convert path
raw_model = converted_model.model
assert hasattr(raw_model, "save_pretrained")

View File

@@ -31,6 +31,7 @@ from invokeai.app.services.events.events_common import (
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadEventBase,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEventBase,
@@ -53,6 +54,13 @@ class BulkDownloadSubscriptionEvent(BaseModel):
bulk_download_id: str
class ModelLoadSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io model loading room.
This is a pydantic model to ensure the data is in the correct format."""
queue_id: str
QUEUE_EVENTS = {
InvocationStartedEvent,
InvocationProgressEvent,
@@ -69,8 +77,6 @@ MODEL_EVENTS = {
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
ModelLoadStartedEvent,
ModelLoadCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
@@ -79,6 +85,11 @@ MODEL_EVENTS = {
ModelInstallErrorEvent,
}
MODEL_LOAD_EVENTS = {
ModelLoadStartedEvent,
ModelLoadCompleteEvent,
}
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
@@ -101,6 +112,7 @@ class SocketIO:
register_events(QUEUE_EVENTS, self._handle_queue_event)
register_events(MODEL_EVENTS, self._handle_model_event)
register_events(MODEL_LOAD_EVENTS, self._handle_model_load_event)
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
@@ -115,9 +127,18 @@ class SocketIO:
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_sub_model_load(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, ModelLoadSubscriptionEvent(**data).queue_id)
async def _handle_unsub_model_load(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, ModelLoadSubscriptionEvent(**data).queue_id)
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
async def _handle_model_load_event(self, event: FastAPIEvent[ModelLoadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))

View File

@@ -63,12 +63,12 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)
tokenizer_info = context.models.load(self.clip.tokenizer, queue_id=context.util.get_queue_id())
text_encoder_info = context.models.load(self.clip.text_encoder, queue_id=context.util.get_queue_id())
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(lora.lora, queue_id=context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
@@ -137,8 +137,8 @@ class SDXLPromptInvocationBase:
lora_prefix: str,
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tokenizer_info = context.models.load(clip_field.tokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)
tokenizer_info = context.models.load(clip_field.tokenizer, queue_id=context.util.get_queue_id())
text_encoder_info = context.models.load(clip_field.text_encoder, queue_id=context.util.get_queue_id())
# return zero on empty
if prompt == "" and zero_on_empty:
@@ -163,7 +163,7 @@ class SDXLPromptInvocationBase:
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_field.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
lora_model = lora_info.model
assert isinstance(lora_model, LoRAModelRaw)
yield (lora_model, lora.weight)

View File

@@ -649,7 +649,9 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
return DepthAnythingPipeline(depth_anything_pipeline)
with self._context.models.load_remote_model(
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
source=DEPTH_ANYTHING_MODELS[self.model_size],
queue_id=self._context.util.get_queue_id(),
loader=load_depth_anything,
) as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
depth_map = depth_anything_detector.generate_depth(image)

View File

@@ -60,7 +60,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
)
if image_tensor is not None:
vae_info = context.models.load(self.vae.vae)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)

View File

@@ -124,7 +124,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
assert isinstance(main_model_config, MainConfigBase)
if main_model_config.variant is ModelVariantType.Inpaint:
mask = blur_tensor
vae_info: LoadedModel = context.models.load(self.vae.vae)
vae_info: LoadedModel = context.models.load(self.vae.vae, context.util.get_queue_id())
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:

View File

@@ -88,7 +88,7 @@ def get_scheduler(
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
orig_scheduler_info = context.models.load(scheduler_info, context.util.get_queue_id())
with orig_scheduler_info as orig_scheduler:
scheduler_config = orig_scheduler.config
@@ -435,7 +435,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
controlnet_data: list[ControlNetData] = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
control_model = exit_stack.enter_context(
context.models.load(control_info.control_model, context.util.get_queue_id())
)
assert isinstance(control_model, ControlNetModel)
control_image_field = control_info.image
@@ -492,7 +494,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
for control_info in control_list:
model = exit_stack.enter_context(context.models.load(control_info.control_model))
model = exit_stack.enter_context(
context.models.load(control_info.control_model, context.util.get_queue_id())
)
ext_manager.add_extension(
ControlNetExt(
model=model,
@@ -545,9 +549,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
image_prompts = []
for single_ip_adapter in ip_adapters:
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
with context.models.load(
single_ip_adapter.ip_adapter_model, context.util.get_queue_id()
) as ip_adapter_model:
assert isinstance(ip_adapter_model, IPAdapter)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
image_encoder_model_info = context.models.load(
single_ip_adapter.image_encoder_model, context.util.get_queue_id()
)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
@@ -581,7 +589,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
ip_adapters, image_prompts, strict=True
):
ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
ip_adapter_model = exit_stack.enter_context(
context.models.load(single_ip_adapter.ip_adapter_model, context.util.get_queue_id())
)
mask_field = single_ip_adapter.mask
mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
@@ -621,7 +631,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
t2i_adapter_loaded_model = context.models.load(
t2i_adapter_field.t2i_adapter_model, context.util.get_queue_id()
)
image = context.images.get_pil(t2i_adapter_field.image.image_name, mode="RGB")
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
@@ -926,7 +938,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
unet_info = context.models.load(self.unet.unet)
unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (cached_weights, unet),
@@ -989,13 +1001,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
return
unet_info = context.models.load(self.unet.unet)
unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,

View File

@@ -35,7 +35,9 @@ class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithB
model_url = DEPTH_ANYTHING_MODELS[self.model_size]
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(model_url, DepthAnythingPipeline.load_model)
loaded_model = context.models.load_remote_model(
model_url, context.util.get_queue_id(), DepthAnythingPipeline.load_model
)
with loaded_model as depth_anything_detector:
assert isinstance(depth_anything_detector, DepthAnythingPipeline)

View File

@@ -29,10 +29,10 @@ class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
onnx_pose_path = context.models.download_and_cache_model(DWOpenposeDetector2.get_model_url_pose())
loaded_session_det = context.models.load_local_model(
onnx_det_path, DWOpenposeDetector2.create_onnx_inference_session
onnx_det_path, context.util.get_queue_id(), DWOpenposeDetector2.create_onnx_inference_session
)
loaded_session_pose = context.models.load_local_model(
onnx_pose_path, DWOpenposeDetector2.create_onnx_inference_session
onnx_pose_path, context.util.get_queue_id(), DWOpenposeDetector2.create_onnx_inference_session
)
with loaded_session_det as session_det, loaded_session_pose as session_pose:

View File

@@ -41,6 +41,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
# region Model Field Types
MainModel = "MainModelField"
FluxMainModel = "FluxMainModelField"
SD3MainModel = "SD3MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
ONNXModel = "ONNXModelField"
@@ -52,6 +53,8 @@ class UIType(str, Enum, metaclass=MetaEnum):
T2IAdapterModel = "T2IAdapterModelField"
T5EncoderModel = "T5EncoderModelField"
CLIPEmbedModel = "CLIPEmbedModelField"
CLIPLEmbedModel = "CLIPLEmbedModelField"
CLIPGEmbedModel = "CLIPGEmbedModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
# endregion
@@ -131,8 +134,10 @@ class FieldDescriptions:
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
t5_encoder = "T5 tokenizer and text encoder"
clip_embed_model = "CLIP Embed loader"
clip_g_model = "CLIP-G Embed loader"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
mmditx = "MMDiTX"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
@@ -140,6 +145,7 @@ class FieldDescriptions:
lora_model = "LoRA model to load"
main_model = "Main model (UNet, VAE, CLIP) to load"
flux_model = "Flux model (Transformer) to load"
sd3_model = "SD3 model (MMDiTX) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -246,6 +252,12 @@ class FluxConditioningField(BaseModel):
conditioning_name: str = Field(description="The name of conditioning tensor")
class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
class ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""

View File

@@ -183,7 +183,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
seed=self.seed,
)
transformer_info = context.models.load(self.transformer.transformer)
transformer_info = context.models.load(self.transformer.transformer, context.util.get_queue_id())
is_schnell = "schnell" in transformer_info.config.config_path
# Calculate the timestep schedule.
@@ -468,7 +468,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# minimize peak memory.
# First, load the ControlNet models so that we can determine the ControlNet types.
controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
controlnet_models = [
context.models.load(controlnet.control_model, context.util.get_queue_id()) for controlnet in controlnets
]
# Calculate the controlnet conditioning tensors.
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
@@ -479,7 +481,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
if isinstance(controlnet_model.model, InstantXControlNetFlux):
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
vae_info = context.models.load(self.controlnet_vae.vae)
vae_info = context.models.load(self.controlnet_vae.vae, context.util.get_queue_id())
controlnet_conds.append(
InstantXControlNetExtension.prepare_controlnet_cond(
controlnet_image=image,
@@ -590,7 +592,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
pos_images.append(pos_image)
neg_images.append(neg_image)
with context.models.load(ip_adapter_field.image_encoder_model) as image_encoder_model:
with context.models.load(
ip_adapter_field.image_encoder_model, context.util.get_queue_id()
) as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
@@ -620,7 +624,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
for ip_adapter_field, pos_image_prompt_clip_embed, neg_image_prompt_clip_embed in zip(
ip_adapter_fields, pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds, strict=True
):
ip_adapter_model = exit_stack.enter_context(context.models.load(ip_adapter_field.ip_adapter_model))
ip_adapter_model = exit_stack.enter_context(
context.models.load(ip_adapter_field.ip_adapter_model, context.util.get_queue_id())
)
assert isinstance(ip_adapter_model, XlabsIpAdapterFlux)
ip_adapter_model = ip_adapter_model.to(dtype=dtype)
if ip_adapter_field.mask is not None:
@@ -649,7 +655,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -0,0 +1,89 @@
from typing import Literal
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
SubModelType,
)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)

View File

@@ -57,8 +57,8 @@ class FluxTextEncoderInvocation(BaseInvocation):
return FluxConditioningOutput.build(conditioning_name)
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder)
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer, context.util.get_queue_id())
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder, context.util.get_queue_id())
prompt = [self.prompt]
@@ -77,8 +77,8 @@ class FluxTextEncoderInvocation(BaseInvocation):
return prompt_embeds
def _clip_encode(self, context: InvocationContext) -> torch.Tensor:
clip_tokenizer_info = context.models.load(self.clip.tokenizer)
clip_text_encoder_info = context.models.load(self.clip.text_encoder)
clip_tokenizer_info = context.models.load(self.clip.tokenizer, context.util.get_queue_id())
clip_text_encoder_info = context.models.load(self.clip.text_encoder, context.util.get_queue_id())
prompt = [self.prompt]
@@ -118,7 +118,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -52,7 +52,7 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
image = self._vae_decode(vae_info=vae_info, latents=latents)
TorchDevice.empty_cache()

View File

@@ -54,7 +54,7 @@ class FluxVaeEncodeInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:

View File

@@ -94,7 +94,9 @@ class GroundingDinoInvocation(BaseInvocation):
labels = [label if label.endswith(".") else label + "." for label in labels]
with context.models.load_remote_model(
source=GROUNDING_DINO_MODEL_IDS[self.model], loader=GroundingDinoInvocation._load_grounding_dino
source=GROUNDING_DINO_MODEL_IDS[self.model],
queue_id=context.util.get_queue_id(),
loader=GroundingDinoInvocation._load_grounding_dino,
) as detector:
assert isinstance(detector, GroundingDinoPipeline)
return detector.detect(image=image, candidate_labels=labels, threshold=threshold)

View File

@@ -22,7 +22,9 @@ class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(HEDEdgeDetector.get_model_url(), HEDEdgeDetector.load_model)
loaded_model = context.models.load_remote_model(
HEDEdgeDetector.get_model_url(), context.util.get_queue_id(), HEDEdgeDetector.load_model
)
with loaded_model as model:
assert isinstance(model, ControlNetHED_Apache2)

View File

@@ -111,7 +111,7 @@ class ImageToLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
vae_info = context.models.load(self.vae.vae)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:

View File

@@ -36,7 +36,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
image: ImageField = InputField(description="The image to process")
@abstractmethod
def infill(self, image: Image.Image) -> Image.Image:
def infill(self, image: Image.Image, queue_id: str) -> Image.Image:
"""Infill the image with the specified method"""
pass
@@ -56,7 +56,7 @@ class InfillImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
return ImageOutput.build(context.images.get_dto(self.image.image_name))
# Perform Infill action
infilled_image = self.infill(input_image)
infilled_image = self.infill(input_image, context.util.get_queue_id())
# Create ImageDTO for Infilled Image
infilled_image_dto = context.images.save(image=infilled_image)
@@ -74,7 +74,7 @@ class InfillColorInvocation(InfillImageProcessorInvocation):
description="The color to use to infill",
)
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, queue_id: str):
solid_bg = Image.new("RGBA", image.size, self.color.tuple())
infilled = Image.alpha_composite(solid_bg, image.convert("RGBA"))
infilled.paste(image, (0, 0), image.split()[-1])
@@ -93,7 +93,7 @@ class InfillTileInvocation(InfillImageProcessorInvocation):
description="The seed to use for tile generation (omit for random)",
)
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, queue_id: str):
output = infill_tile(image, seed=self.seed, tile_size=self.tile_size)
return output.infilled
@@ -107,7 +107,7 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, queue_id: str):
resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]
width = int(image.width / self.downscale)
@@ -131,9 +131,10 @@ class InfillPatchMatchInvocation(InfillImageProcessorInvocation):
class LaMaInfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using the LaMa model"""
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, queue_id: str):
with self._context.models.load_remote_model(
source="https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
queue_id=queue_id,
loader=LaMA.load_jit_model,
) as model:
lama = LaMA(model)
@@ -144,7 +145,7 @@ class LaMaInfillInvocation(InfillImageProcessorInvocation):
class CV2InfillInvocation(InfillImageProcessorInvocation):
"""Infills transparent areas of an image using OpenCV Inpainting"""
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, queue_id: str):
return cv2_inpaint(image)
@@ -166,5 +167,5 @@ class MosaicInfillInvocation(InfillImageProcessorInvocation):
description="The max threshold for color",
)
def infill(self, image: Image.Image):
def infill(self, image: Image.Image, queue_id: str):
return infill_mosaic(image, (self.tile_width, self.tile_height), self.min_color.tuple(), self.max_color.tuple())

View File

@@ -57,7 +57,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))

View File

@@ -23,7 +23,9 @@ class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
model_url = LineartEdgeDetector.get_model_url(self.coarse)
loaded_model = context.models.load_remote_model(model_url, LineartEdgeDetector.load_model)
loaded_model = context.models.load_remote_model(
model_url, context.util.get_queue_id(), LineartEdgeDetector.load_model
)
with loaded_model as model:
assert isinstance(model, Generator)

View File

@@ -20,7 +20,9 @@ class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoar
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
model_url = LineartAnimeEdgeDetector.get_model_url()
loaded_model = context.models.load_remote_model(model_url, LineartAnimeEdgeDetector.load_model)
loaded_model = context.models.load_remote_model(
model_url, context.util.get_queue_id(), LineartAnimeEdgeDetector.load_model
)
with loaded_model as model:
assert isinstance(model, UnetGenerator)

View File

@@ -28,7 +28,9 @@ class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(MLSDDetector.get_model_url(), MLSDDetector.load_model)
loaded_model = context.models.load_remote_model(
MLSDDetector.get_model_url(), context.util.get_queue_id(), MLSDDetector.load_model
)
with loaded_model as model:
assert isinstance(model, MobileV2_MLSD_Large)

View File

@@ -1,5 +1,5 @@
import copy
from typing import List, Literal, Optional
from typing import List, Optional
from pydantic import BaseModel, Field
@@ -13,11 +13,9 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.flux.util import max_seq_lengths
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
CheckpointConfigBase,
ModelType,
SubModelType,
)
@@ -139,78 +137,6 @@ class ModelIdentifierInvocation(BaseInvocation):
return ModelIdentifierOutput(model=self.model)
@invocation_output("flux_model_loader_output")
class FluxModelLoaderOutput(BaseInvocationOutput):
"""Flux base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
max_seq_len: Literal[256, 512] = OutputField(
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)",
title="Max Seq Length",
)
@invocation(
"flux_model_loader",
title="Flux Main Model",
tags=["model", "flux"],
category="model",
version="1.0.4",
classification=Classification.Prototype,
)
class FluxModelLoaderInvocation(BaseInvocation):
"""Loads a flux base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.flux_model,
ui_type=UIType.FluxMainModel,
input=Input.Direct,
)
t5_encoder_model: ModelIdentifierField = InputField(
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder"
)
clip_embed_model: ModelIdentifierField = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPEmbedModel,
input=Input.Direct,
title="CLIP Embed",
)
vae_model: ModelIdentifierField = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE"
)
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput:
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]:
if not context.models.exists(key):
raise ValueError(f"Unknown model: {key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
transformer_config = context.models.get_config(transformer)
assert isinstance(transformer_config, CheckpointConfigBase)
return FluxModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
max_seq_len=max_seq_lengths[transformer_config.config_path],
)
@invocation(
"main_model_loader",
title="Main Model",

View File

@@ -20,7 +20,9 @@ class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(NormalMapDetector.get_model_url(), NormalMapDetector.load_model)
loaded_model = context.models.load_remote_model(
NormalMapDetector.get_model_url(), context.util.get_queue_id(), NormalMapDetector.load_model
)
with loaded_model as model:
assert isinstance(model, NNET)

View File

@@ -22,7 +22,9 @@ class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
loaded_model = context.models.load_remote_model(PIDINetDetector.get_model_url(), PIDINetDetector.load_model)
loaded_model = context.models.load_remote_model(
PIDINetDetector.get_model_url(), context.util.get_queue_id(), PIDINetDetector.load_model
)
with loaded_model as model:
assert isinstance(model, PiDiNet)

View File

@@ -18,6 +18,7 @@ from invokeai.app.invocations.fields import (
InputField,
LatentsField,
OutputField,
SD3ConditioningField,
TensorField,
UIComponent,
)
@@ -426,6 +427,17 @@ class FluxConditioningOutput(BaseInvocationOutput):
return cls(conditioning=FluxConditioningField(conditioning_name=conditioning_name))
@invocation_output("sd3_conditioning_output")
class SD3ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single SD3 conditioning tensor"""
conditioning: SD3ConditioningField = OutputField(description=FieldDescriptions.cond)
@classmethod
def build(cls, conditioning_name: str) -> "SD3ConditioningOutput":
return cls(conditioning=SD3ConditioningField(conditioning_name=conditioning_name))
@invocation_output("conditioning_output")
class ConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output a single conditioning tensor"""

View File

@@ -0,0 +1,260 @@
from typing import Callable, Tuple
import torch
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
SD3ConditioningField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
"sd3_denoise",
title="SD3 Denoise",
tags=["image", "sd3"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a SD3 model."""
transformer: TransformerField = InputField(
description=FieldDescriptions.sd3_model,
input=Input.Connection,
title="Transformer",
)
positive_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
cfg_scale: float | list[float] = InputField(default=3.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
latents = latents.detach().to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _load_text_conditioning(
self,
context: InvocationContext,
conditioning_name: str,
joint_attention_dim: int,
dtype: torch.dtype,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Load the conditioning data.
cond_data = context.conditioning.load(conditioning_name)
assert len(cond_data.conditionings) == 1
sd3_conditioning = cond_data.conditionings[0]
assert isinstance(sd3_conditioning, SD3ConditioningInfo)
sd3_conditioning = sd3_conditioning.to(dtype=dtype, device=device)
t5_embeds = sd3_conditioning.t5_embeds
if t5_embeds is None:
t5_embeds = torch.zeros(
(1, SD3_T5_MAX_SEQ_LEN, joint_attention_dim),
device=device,
dtype=dtype,
)
clip_prompt_embeds = torch.cat([sd3_conditioning.clip_l_embeds, sd3_conditioning.clip_g_embeds], dim=-1)
clip_prompt_embeds = torch.nn.functional.pad(
clip_prompt_embeds, (0, t5_embeds.shape[-1] - clip_prompt_embeds.shape[-1])
)
prompt_embeds = torch.cat([clip_prompt_embeds, t5_embeds], dim=-2)
pooled_prompt_embeds = torch.cat(
[sd3_conditioning.clip_l_pooled_embeds, sd3_conditioning.clip_g_pooled_embeds], dim=-1
)
return prompt_embeds, pooled_prompt_embeds
def _get_noise(
self,
num_samples: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
seed: int,
) -> torch.Tensor:
# We always generate noise on the same device and dtype then cast to ensure consistency across devices/dtypes.
rand_device = "cpu"
rand_dtype = torch.float16
return torch.randn(
num_samples,
num_channels_latents,
int(height) // LATENT_SCALE_FACTOR,
int(width) // LATENT_SCALE_FACTOR,
device=rand_device,
dtype=rand_dtype,
generator=torch.Generator(device=rand_device).manual_seed(seed),
).to(device=device, dtype=dtype)
def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
"""Prepare the CFG scale list.
Args:
num_timesteps (int): The number of timesteps in the scheduler. Could be different from num_steps depending
on the scheduler used (e.g. higher order schedulers).
Returns:
list[float]: _description_
"""
if isinstance(self.cfg_scale, float):
cfg_scale = [self.cfg_scale] * num_timesteps
elif isinstance(self.cfg_scale, list):
assert len(self.cfg_scale) == num_timesteps
cfg_scale = self.cfg_scale
else:
raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
return cfg_scale
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = TorchDevice.choose_torch_dtype()
device = TorchDevice.choose_torch_device()
transformer_info = context.models.load(self.transformer.transformer, context.util.get_queue_id())
# Load/process the conditioning data.
# TODO(ryand): Make CFG optional.
do_classifier_free_guidance = True
pos_prompt_embeds, pos_pooled_prompt_embeds = self._load_text_conditioning(
context=context,
conditioning_name=self.positive_conditioning.conditioning_name,
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
dtype=inference_dtype,
device=device,
)
neg_prompt_embeds, neg_pooled_prompt_embeds = self._load_text_conditioning(
context=context,
conditioning_name=self.negative_conditioning.conditioning_name,
joint_attention_dim=transformer_info.model.config.joint_attention_dim,
dtype=inference_dtype,
device=device,
)
# TODO(ryand): Support both sequential and batched CFG inference.
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
# Prepare the scheduler.
scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=self.steps, device=device)
timesteps = scheduler.timesteps
assert isinstance(timesteps, torch.Tensor)
# Prepare the CFG scale list.
cfg_scale = self._prepare_cfg_scale(len(timesteps))
# Generate initial latent noise.
num_channels_latents = transformer_info.model.config.in_channels
assert isinstance(num_channels_latents, int)
noise = self._get_noise(
num_samples=1,
num_channels_latents=num_channels_latents,
height=self.height,
width=self.width,
dtype=inference_dtype,
device=device,
seed=self.seed,
)
latents: torch.Tensor = noise
total_steps = len(timesteps)
step_callback = self._build_step_callback(context)
step_callback(
PipelineIntermediateState(
step=0,
order=1,
total_steps=total_steps,
timestep=int(timesteps[0]),
latents=latents,
),
)
with transformer_info.model_on_device() as (cached_weights, transformer):
assert isinstance(transformer, SD3Transformer2DModel)
# 6. Denoising loop
for step_idx, t in tqdm(list(enumerate(timesteps))):
# Expand the latents if we are doing CFG.
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# Expand the timestep to match the latent model input.
timestep = t.expand(latent_model_input.shape[0])
noise_pred = transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=None,
return_dict=False,
)[0]
# Apply CFG.
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
# Compute the previous noisy sample x_t -> x_t-1.
latents_dtype = latents.dtype
latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
# TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
# needed.
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
timestep=int(t),
latents=latents,
),
)
return latents
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, BaseModelType.StableDiffusion3)
return step_callback

View File

@@ -0,0 +1,73 @@
from contextlib import nullcontext
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice
@invocation(
"sd3_l2i",
title="SD3 Latents to Image",
tags=["latents", "image", "vae", "l2i", "sd3"],
category="latents",
version="1.3.0",
)
class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
latents: LatentsField = InputField(
description=FieldDescriptions.latents,
input=Input.Connection,
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae, context.util.get_queue_id())
assert isinstance(vae_info.model, (AutoencoderKL))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL))
latents = latents.to(vae.device)
vae.disable_tiling()
tiling_context = nullcontext()
# clear memory as vae decode can request a lot
TorchDevice.empty_cache()
with torch.inference_mode(), tiling_context:
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
img = vae.decode(latents, return_dict=False)[0]
img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c") # noqa: F821
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
TorchDevice.empty_cache()
image_dto = context.images.save(image=img_pil)
return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,108 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import SubModelType
@invocation_output("sd3_model_loader_output")
class Sd3ModelLoaderOutput(BaseInvocationOutput):
"""SD3 base model loader output."""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip_l: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP L")
clip_g: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP G")
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation(
"sd3_model_loader",
title="SD3 Main Model",
tags=["model", "sd3"],
category="model",
version="1.0.0",
classification=Classification.Prototype,
)
class Sd3ModelLoaderInvocation(BaseInvocation):
"""Loads a SD3 base model, outputting its submodels."""
model: ModelIdentifierField = InputField(
description=FieldDescriptions.sd3_model,
ui_type=UIType.SD3MainModel,
input=Input.Direct,
)
t5_encoder_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.t5_encoder,
ui_type=UIType.T5EncoderModel,
input=Input.Direct,
title="T5 Encoder",
default=None,
)
clip_l_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.clip_embed_model,
ui_type=UIType.CLIPLEmbedModel,
input=Input.Direct,
title="CLIP L Encoder",
default=None,
)
clip_g_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.clip_g_model,
ui_type=UIType.CLIPGEmbedModel,
input=Input.Direct,
title="CLIP G Encoder",
default=None,
)
vae_model: Optional[ModelIdentifierField] = InputField(
description=FieldDescriptions.vae_model, ui_type=UIType.VAEModel, title="VAE", default=None
)
def invoke(self, context: InvocationContext) -> Sd3ModelLoaderOutput:
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
vae = (
self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE})
if self.vae_model
else self.model.model_copy(update={"submodel_type": SubModelType.VAE})
)
tokenizer_l = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
clip_encoder_l = (
self.clip_l_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
if self.clip_l_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
)
tokenizer_g = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
clip_encoder_g = (
self.clip_g_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
if self.clip_g_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
)
tokenizer_t5 = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
)
t5_encoder = (
self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
if self.t5_encoder_model
else self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
)
return Sd3ModelLoaderOutput(
transformer=TransformerField(transformer=transformer, loras=[]),
clip_l=CLIPField(tokenizer=tokenizer_l, text_encoder=clip_encoder_l, loras=[], skipped_layers=0),
clip_g=CLIPField(tokenizer=tokenizer_g, text_encoder=clip_encoder_g, loras=[], skipped_layers=0),
t5_encoder=T5EncoderField(tokenizer=tokenizer_t5, text_encoder=t5_encoder),
vae=VAEField(vae=vae),
)

View File

@@ -0,0 +1,199 @@
from contextlib import ExitStack
from typing import Iterator, Tuple
import torch
from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
T5EncoderModel,
T5Tokenizer,
T5TokenizerFast,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import SD3ConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
from invokeai.backend.model_manager.config import ModelFormat
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
# The SD3 T5 Max Sequence Length set based on the default in diffusers.
SD3_T5_MAX_SEQ_LEN = 256
@invocation(
"sd3_text_encoder",
title="SD3 Text Encoding",
tags=["prompt", "conditioning", "sd3"],
category="conditioning",
version="1.0.0",
classification=Classification.Prototype,
)
class Sd3TextEncoderInvocation(BaseInvocation):
"""Encodes and preps a prompt for a SD3 image."""
clip_l: CLIPField = InputField(
title="CLIP L",
description=FieldDescriptions.clip,
input=Input.Connection,
)
clip_g: CLIPField = InputField(
title="CLIP G",
description=FieldDescriptions.clip,
input=Input.Connection,
)
# The SD3 models were trained with text encoder dropout, so the T5 encoder can be omitted to save time/memory.
t5_encoder: T5EncoderField | None = InputField(
title="T5Encoder",
default=None,
description=FieldDescriptions.t5_encoder,
input=Input.Connection,
)
prompt: str = InputField(description="Text prompt to encode.")
@torch.no_grad()
def invoke(self, context: InvocationContext) -> SD3ConditioningOutput:
# Note: The text encoding model are run in separate functions to ensure that all model references are locally
# scoped. This ensures that earlier models can be freed and gc'd before loading later models (if necessary).
clip_l_embeddings, clip_l_pooled_embeddings = self._clip_encode(context, self.clip_l)
clip_g_embeddings, clip_g_pooled_embeddings = self._clip_encode(context, self.clip_g)
t5_embeddings: torch.Tensor | None = None
if self.t5_encoder is not None:
t5_embeddings = self._t5_encode(context, SD3_T5_MAX_SEQ_LEN)
conditioning_data = ConditioningFieldData(
conditionings=[
SD3ConditioningInfo(
clip_l_embeds=clip_l_embeddings,
clip_l_pooled_embeds=clip_l_pooled_embeddings,
clip_g_embeds=clip_g_embeddings,
clip_g_pooled_embeds=clip_g_pooled_embeddings,
t5_embeds=t5_embeddings,
)
]
)
conditioning_name = context.conditioning.save(conditioning_data)
return SD3ConditioningOutput.build(conditioning_name)
def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tensor:
assert self.t5_encoder is not None
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer, context.util.get_queue_id())
t5_text_encoder_info = context.models.load(self.t5_encoder.text_encoder, context.util.get_queue_id())
prompt = [self.prompt]
with (
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
text_inputs = t5_tokenizer(
prompt,
padding="max_length",
max_length=max_seq_len,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = t5_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
assert isinstance(text_input_ids, torch.Tensor)
assert isinstance(untruncated_ids, torch.Tensor)
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = t5_tokenizer.batch_decode(untruncated_ids[:, max_seq_len - 1 : -1])
context.logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_seq_len} tokens: {removed_text}"
)
prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
def _clip_encode(
self, context: InvocationContext, clip_model: CLIPField, tokenizer_max_length: int = 77
) -> Tuple[torch.Tensor, torch.Tensor]:
clip_tokenizer_info = context.models.load(clip_model.tokenizer, context.util.get_queue_id())
clip_text_encoder_info = context.models.load(clip_model.text_encoder, context.util.get_queue_id())
prompt = [self.prompt]
with (
clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder),
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(clip_tokenizer, CLIPTokenizer)
clip_text_encoder_config = clip_text_encoder_info.config
assert clip_text_encoder_config is not None
# Apply LoRA models to the CLIP encoder.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
if clip_text_encoder_config.format in [ModelFormat.Diffusers]:
# The model is non-quantized, so we can apply the LoRA weights directly into the model.
exit_stack.enter_context(
LoRAPatcher.apply_lora_patches(
model=clip_text_encoder,
patches=self._clip_lora_iterator(context, clip_model),
prefix=FLUX_LORA_CLIP_PREFIX,
cached_weights=cached_weights,
)
)
else:
# There are currently no supported CLIP quantized models. Add support here if needed.
raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}")
clip_text_encoder = clip_text_encoder.eval().requires_grad_(False)
text_inputs = clip_tokenizer(
prompt,
padding="max_length",
max_length=tokenizer_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
assert isinstance(text_input_ids, torch.Tensor)
assert isinstance(untruncated_ids, torch.Tensor)
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = clip_tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
context.logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = clip_text_encoder(
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
return prompt_embeds, pooled_prompt_embeds
def _clip_lora_iterator(
self, context: InvocationContext, clip_model: CLIPField
) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in clip_model.loras:
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info

View File

@@ -5,7 +5,7 @@ from typing import Literal
import numpy as np
import torch
from PIL import Image
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from transformers import AutoModelForMaskGeneration, AutoProcessor
from transformers.models.sam import SamModel
from transformers.models.sam.processing_sam import SamProcessor
@@ -77,19 +77,14 @@ class SegmentAnythingInvocation(BaseInvocation):
default="all",
)
@model_validator(mode="after")
def check_point_lists_or_bounding_box(self):
if self.point_lists is None and self.bounding_boxes is None:
raise ValueError("Either point_lists or bounding_box must be provided.")
elif self.point_lists is not None and self.bounding_boxes is not None:
raise ValueError("Only one of point_lists or bounding_box can be provided.")
return self
@torch.no_grad()
def invoke(self, context: InvocationContext) -> MaskOutput:
# The models expect a 3-channel RGB image.
image_pil = context.images.get_pil(self.image.image_name, mode="RGB")
if self.point_lists is not None and self.bounding_boxes is not None:
raise ValueError("Only one of point_lists or bounding_box can be provided.")
if (not self.bounding_boxes or len(self.bounding_boxes) == 0) and (
not self.point_lists or len(self.point_lists) == 0
):
@@ -130,7 +125,9 @@ class SegmentAnythingInvocation(BaseInvocation):
with (
context.models.load_remote_model(
source=SEGMENT_ANYTHING_MODEL_IDS[self.model], loader=SegmentAnythingInvocation._load_sam_model
source=SEGMENT_ANYTHING_MODEL_IDS[self.model],
queue_id=context.util.get_queue_id(),
loader=SegmentAnythingInvocation._load_sam_model,
) as sam_pipeline,
):
assert isinstance(sam_pipeline, SegmentAnythingPipeline)

View File

@@ -158,7 +158,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
spandrel_model_info = context.models.load(self.image_to_image_model, context.util.get_queue_id())
def step_callback(step: int, total_steps: int) -> None:
context.util.signal_progress(
@@ -207,7 +207,7 @@ class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
spandrel_model_info = context.models.load(self.image_to_image_model, context.util.get_queue_id())
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.

View File

@@ -196,13 +196,13 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
lora_info = context.models.load(lora.lora, context.util.get_queue_id())
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
unet_info = context.models.load(self.unet.unet, context.util.get_queue_id())
with (
ExitStack() as exit_stack,

View File

@@ -90,7 +90,7 @@ class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard):
raise ValueError(msg)
loadnet = context.models.load_remote_model(
source=ESRGAN_MODEL_URLS[self.model_name],
source=ESRGAN_MODEL_URLS[self.model_name], queue_id=context.util.get_queue_id()
)
with loadnet as loadnet_model:

View File

@@ -131,15 +131,17 @@ class EventServiceBase:
# region Model loading
def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
def emit_model_load_started(
self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
) -> None:
"""Emitted when a model load is started."""
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
self.dispatch(ModelLoadStartedEvent.build(config, queue_id, submodel_type))
def emit_model_load_complete(
self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
self, config: "AnyModelConfig", queue_id: str, submodel_type: Optional["SubModelType"] = None
) -> None:
"""Emitted when a model load is complete."""
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
self.dispatch(ModelLoadCompleteEvent.build(config, queue_id, submodel_type))
# endregion

View File

@@ -383,12 +383,14 @@ class DownloadErrorEvent(DownloadEventBase):
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
class ModelEventBase(EventBase):
"""Base class for events associated with a model"""
class ModelLoadEventBase(EventBase):
"""Base class for queue events"""
queue_id: str = Field(description="The ID of the queue")
@payload_schema.register
class ModelLoadStartedEvent(ModelEventBase):
class ModelLoadStartedEvent(ModelLoadEventBase):
"""Event model for model_load_started"""
__event_name__ = "model_load_started"
@@ -397,12 +399,14 @@ class ModelLoadStartedEvent(ModelEventBase):
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
return cls(config=config, submodel_type=submodel_type)
def build(
cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> "ModelLoadStartedEvent":
return cls(config=config, queue_id=queue_id, submodel_type=submodel_type)
@payload_schema.register
class ModelLoadCompleteEvent(ModelEventBase):
class ModelLoadCompleteEvent(ModelLoadEventBase):
"""Event model for model_load_complete"""
__event_name__ = "model_load_complete"
@@ -411,8 +415,14 @@ class ModelLoadCompleteEvent(ModelEventBase):
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
return cls(config=config, submodel_type=submodel_type)
def build(
cls, config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> "ModelLoadCompleteEvent":
return cls(config=config, queue_id=queue_id, submodel_type=submodel_type)
class ModelEventBase(EventBase):
"""Base class for model events"""
@payload_schema.register

View File

@@ -14,7 +14,9 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader."""
@abstractmethod
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(
self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
@@ -29,7 +31,7 @@ class ModelLoadServiceBase(ABC):
@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
self, model_path: Path, queue_id: str, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
"""
Load the model file or directory located at the indicated Path.

View File

@@ -49,7 +49,9 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the RAM cache used by this loader."""
return self._ram_cache
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(
self, model_config: AnyModelConfig, queue_id: str, submodel_type: Optional[SubModelType] = None
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
@@ -60,7 +62,7 @@ class ModelLoadService(ModelLoadServiceBase):
# We don't have an invoker during testing
# TODO(psyche): Mock this method on the invoker in the tests
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
self._invoker.services.events.emit_model_load_started(model_config, queue_id, submodel_type)
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model: LoadedModel = implementation(
@@ -70,12 +72,12 @@ class ModelLoadService(ModelLoadServiceBase):
).load_model(model_config, submodel_type)
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
self._invoker.services.events.emit_model_load_complete(model_config, queue_id, submodel_type)
return loaded_model
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None
self, model_path: Path, queue_id: str, loader: Optional[Callable[[Path], AnyModel]] = None
) -> LoadedModelWithoutConfig:
cache_key = str(model_path)
ram_cache = self.ram_cache

View File

@@ -15,6 +15,7 @@ from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ClipVariantType,
ControlAdapterDefaultSettings,
MainModelDefaultSettings,
ModelFormat,
@@ -85,7 +86,7 @@ class ModelRecordChanges(BaseModelExcludeNull):
# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[ModelVariantType] = Field(description="The variant of the model.", default=None)
variant: Optional[ModelVariantType | ClipVariantType] = Field(description="The variant of the model.", default=None)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None
)

View File

@@ -351,7 +351,10 @@ class ModelsInterface(InvocationContextInterface):
return self._services.model_manager.store.exists(identifier.key)
def load(
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
self,
identifier: Union[str, "ModelIdentifierField"],
queue_id: str,
submodel_type: Optional[SubModelType] = None,
) -> LoadedModel:
"""Load a model.
@@ -368,14 +371,19 @@ class ModelsInterface(InvocationContextInterface):
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type)
return self._services.model_manager.load.load_model(model, queue_id, submodel_type)
else:
_submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type)
return self._services.model_manager.load.load_model(model, queue_id, _submodel_type)
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
self,
name: str,
base: BaseModelType,
type: ModelType,
queue_id: str,
submodel_type: Optional[SubModelType] = None,
) -> LoadedModel:
"""Load a model by its attributes.
@@ -397,7 +405,7 @@ class ModelsInterface(InvocationContextInterface):
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
return self._services.model_manager.load.load_model(configs[0], submodel_type)
return self._services.model_manager.load.load_model(configs[0], queue_id, submodel_type)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Get a model's config.
@@ -472,6 +480,7 @@ class ModelsInterface(InvocationContextInterface):
def load_local_model(
self,
model_path: Path,
queue_id: str,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
@@ -489,11 +498,14 @@ class ModelsInterface(InvocationContextInterface):
Returns:
A LoadedModelWithoutConfig object.
"""
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
return self._services.model_manager.load.load_model_from_path(
model_path=model_path, queue_id=queue_id, loader=loader
)
def load_remote_model(
self,
source: str | AnyHttpUrl,
queue_id: str,
loader: Optional[Callable[[Path], AnyModel]] = None,
) -> LoadedModelWithoutConfig:
"""
@@ -514,7 +526,9 @@ class ModelsInterface(InvocationContextInterface):
A LoadedModelWithoutConfig object.
"""
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
return self._services.model_manager.load.load_model_from_path(
model_path=model_path, queue_id=queue_id, loader=loader
)
class ConfigInterface(InvocationContextInterface):
@@ -535,6 +549,14 @@ class UtilInterface(InvocationContextInterface):
super().__init__(services, data)
self._is_canceled = is_canceled
def get_queue_id(self) -> str:
"""Checks if the current session has been canceled.
Returns:
True if the current session has been canceled, False if not.
"""
return self._data.queue_item.queue_id
def is_canceled(self) -> bool:
"""Checks if the current session has been canceled.

View File

@@ -0,0 +1,382 @@
{
"name": "SD3.5 Text to Image",
"author": "InvokeAI",
"description": "Sample text to image workflow for Stable Diffusion 3.5",
"version": "1.0.0",
"contact": "invoke@invoke.ai",
"tags": "text2image, SD3.5, default",
"notes": "",
"exposedFields": [
{
"nodeId": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"fieldName": "model"
},
{
"nodeId": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"fieldName": "prompt"
}
],
"meta": {
"version": "3.0.0",
"category": "default"
},
"id": "e3a51d6b-8208-4d6d-b187-fcfe8b32934c",
"nodes": [
{
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"type": "invocation",
"data": {
"id": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"type": "sd3_model_loader",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"model": {
"name": "model",
"label": "",
"value": {
"key": "f7b20be9-92a8-4cfb-bca4-6c3b5535c10b",
"hash": "placeholder",
"name": "stable-diffusion-3.5-medium",
"base": "sd-3",
"type": "main"
}
},
"t5_encoder_model": {
"name": "t5_encoder_model",
"label": ""
},
"clip_l_model": {
"name": "clip_l_model",
"label": ""
},
"clip_g_model": {
"name": "clip_g_model",
"label": ""
},
"vae_model": {
"name": "vae_model",
"label": ""
}
}
},
"position": {
"x": -55.58689609637031,
"y": -111.53602444662268
}
},
{
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
"type": "invocation",
"data": {
"id": "f7e394ac-6394-4096-abcb-de0d346506b3",
"type": "rand_int",
"version": "1.0.1",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": false,
"nodePack": "invokeai",
"inputs": {
"low": {
"name": "low",
"label": "",
"value": 0
},
"high": {
"name": "high",
"label": "",
"value": 2147483647
}
}
},
"position": {
"x": 470.45870147220353,
"y": 350.3141781644303
}
},
{
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"type": "invocation",
"data": {
"id": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"type": "sd3_l2i",
"version": "1.3.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": false,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"latents": {
"name": "latents",
"label": ""
},
"vae": {
"name": "vae",
"label": ""
}
}
},
"position": {
"x": 1192.3097009334897,
"y": -366.0994675072209
}
},
{
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"type": "invocation",
"data": {
"id": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"type": "sd3_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"clip_l": {
"name": "clip_l",
"label": ""
},
"clip_g": {
"name": "clip_g",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"prompt": {
"name": "prompt",
"label": "",
"value": ""
}
}
},
"position": {
"x": 408.16054647924784,
"y": 65.06415352118786
}
},
{
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"type": "invocation",
"data": {
"id": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"type": "sd3_text_encoder",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"clip_l": {
"name": "clip_l",
"label": ""
},
"clip_g": {
"name": "clip_g",
"label": ""
},
"t5_encoder": {
"name": "t5_encoder",
"label": ""
},
"prompt": {
"name": "prompt",
"label": "",
"value": ""
}
}
},
"position": {
"x": 378.9283412440941,
"y": -302.65777497352553
}
},
{
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"type": "invocation",
"data": {
"id": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"type": "sd3_denoise",
"version": "1.0.0",
"label": "",
"notes": "",
"isOpen": true,
"isIntermediate": true,
"useCache": true,
"nodePack": "invokeai",
"inputs": {
"board": {
"name": "board",
"label": ""
},
"metadata": {
"name": "metadata",
"label": ""
},
"transformer": {
"name": "transformer",
"label": ""
},
"positive_conditioning": {
"name": "positive_conditioning",
"label": ""
},
"negative_conditioning": {
"name": "negative_conditioning",
"label": ""
},
"cfg_scale": {
"name": "cfg_scale",
"label": "",
"value": 3.5
},
"width": {
"name": "width",
"label": "",
"value": 1024
},
"height": {
"name": "height",
"label": "",
"value": 1024
},
"steps": {
"name": "steps",
"label": "",
"value": 30
},
"seed": {
"name": "seed",
"label": "",
"value": 0
}
}
},
"position": {
"x": 813.7814762740603,
"y": -142.20529727605867
}
}
],
"edges": [
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cvae-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48bvae",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"sourceHandle": "vae",
"targetHandle": "vae"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-3b4f7f27-cfc0-4373-a009-99c5290d0cd6t5_encoder",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ct5_encoder-e17d34e7-6ed1-493c-9a85-4fcd291cb084t5_encoder",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "t5_encoder",
"targetHandle": "t5_encoder"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_g",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "clip_g",
"targetHandle": "clip_g"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_g-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_g",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "clip_g",
"targetHandle": "clip_g"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-3b4f7f27-cfc0-4373-a009-99c5290d0cd6clip_l",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"sourceHandle": "clip_l",
"targetHandle": "clip_l"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4cclip_l-e17d34e7-6ed1-493c-9a85-4fcd291cb084clip_l",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"sourceHandle": "clip_l",
"targetHandle": "clip_l"
},
{
"id": "reactflow__edge-3f22f668-0e02-4fde-a2bb-c339586ceb4ctransformer-c7539f7b-7ac5-49b9-93eb-87ede611409ftransformer",
"type": "default",
"source": "3f22f668-0e02-4fde-a2bb-c339586ceb4c",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "transformer",
"targetHandle": "transformer"
},
{
"id": "reactflow__edge-f7e394ac-6394-4096-abcb-de0d346506b3value-c7539f7b-7ac5-49b9-93eb-87ede611409fseed",
"type": "default",
"source": "f7e394ac-6394-4096-abcb-de0d346506b3",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "value",
"targetHandle": "seed"
},
{
"id": "reactflow__edge-c7539f7b-7ac5-49b9-93eb-87ede611409flatents-9eb72af0-dd9e-4ec5-ad87-d65e3c01f48blatents",
"type": "default",
"source": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"target": "9eb72af0-dd9e-4ec5-ad87-d65e3c01f48b",
"sourceHandle": "latents",
"targetHandle": "latents"
},
{
"id": "reactflow__edge-e17d34e7-6ed1-493c-9a85-4fcd291cb084conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fpositive_conditioning",
"type": "default",
"source": "e17d34e7-6ed1-493c-9a85-4fcd291cb084",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "conditioning",
"targetHandle": "positive_conditioning"
},
{
"id": "reactflow__edge-3b4f7f27-cfc0-4373-a009-99c5290d0cd6conditioning-c7539f7b-7ac5-49b9-93eb-87ede611409fnegative_conditioning",
"type": "default",
"source": "3b4f7f27-cfc0-4373-a009-99c5290d0cd6",
"target": "c7539f7b-7ac5-49b9-93eb-87ede611409f",
"sourceHandle": "conditioning",
"targetHandle": "negative_conditioning"
}
]
}

View File

@@ -34,6 +34,25 @@ SD1_5_LATENT_RGB_FACTORS = [
[-0.1307, -0.1874, -0.7445], # L4
]
SD3_5_LATENT_RGB_FACTORS = [
[-0.05240681, 0.03251581, 0.0749016],
[-0.0580572, 0.00759826, 0.05729818],
[0.16144888, 0.01270368, -0.03768577],
[0.14418615, 0.08460266, 0.15941818],
[0.04894035, 0.0056485, -0.06686988],
[0.05187166, 0.19222395, 0.06261094],
[0.1539433, 0.04818359, 0.07103094],
[-0.08601796, 0.09013458, 0.10893912],
[-0.12398469, -0.06766567, 0.0033688],
[-0.0439737, 0.07825329, 0.02258823],
[0.03101129, 0.06382551, 0.07753657],
[-0.01315361, 0.08554491, -0.08772475],
[0.06464487, 0.05914605, 0.13262741],
[-0.07863674, -0.02261737, -0.12761454],
[-0.09923835, -0.08010759, -0.06264447],
[-0.03392309, -0.0804029, -0.06078822],
]
FLUX_LATENT_RGB_FACTORS = [
[-0.0412, 0.0149, 0.0521],
[0.0056, 0.0291, 0.0768],
@@ -110,6 +129,9 @@ def stable_diffusion_step_callback(
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
elif base_model == BaseModelType.StableDiffusion3:
sd3_latent_rgb_factors = torch.tensor(SD3_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, sd3_latent_rgb_factors)
else:
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)

View File

@@ -22,7 +22,7 @@ def generate_ti_list(
for trigger in extract_ti_triggers_from_prompt(prompt):
name_or_key = trigger[1:-1]
try:
loaded_model = context.models.load(name_or_key)
loaded_model = context.models.load(name_or_key, queue_id=context.util.get_queue_id())
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)
assert loaded_model.config.base == base
@@ -30,7 +30,7 @@ def generate_ti_list(
except UnknownModelException:
try:
loaded_model = context.models.load_by_attrs(
name=name_or_key, base=base, type=ModelType.TextualInversion
name=name_or_key, base=base, type=ModelType.TextualInversion, queue_id=context.util.get_queue_id()
)
model = loaded_model.model
assert isinstance(model, TextualInversionModelRaw)

View File

@@ -53,6 +53,7 @@ class BaseModelType(str, Enum):
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusion3 = "sd-3"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
Flux = "flux"
@@ -83,8 +84,10 @@ class SubModelType(str, Enum):
Transformer = "transformer"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
TextEncoder3 = "text_encoder_3"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Tokenizer3 = "tokenizer_3"
VAE = "vae"
VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder"
@@ -92,6 +95,13 @@ class SubModelType(str, Enum):
SafetyChecker = "safety_checker"
class ClipVariantType(str, Enum):
"""Variant type."""
L = "large"
G = "gigantic"
class ModelVariantType(str, Enum):
"""Variant type."""
@@ -147,6 +157,15 @@ class ModelSourceType(str, Enum):
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
class SubmodelDefinition(BaseModel):
path_or_prefix: str
model_type: ModelType
variant: AnyVariant = None
class MainModelDefaultSettings(BaseModel):
vae: str | None = Field(default=None, description="Default VAE for this model (model key)")
vae_precision: DEFAULTS_PRECISION | None = Field(default=None, description="Default VAE precision for this model")
@@ -193,6 +212,9 @@ class ModelConfigBase(BaseModel):
schema["required"].extend(["key", "type", "format"])
model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
submodels: Optional[Dict[SubModelType, SubmodelDefinition]] = Field(
description="Loadable submodels in this model", default=None
)
class CheckpointConfigBase(ModelConfigBase):
@@ -335,7 +357,7 @@ class MainConfigBase(ModelConfigBase):
default_settings: Optional[MainModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
variant: ModelVariantType = ModelVariantType.Normal
variant: AnyVariant = ModelVariantType.Normal
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
@@ -419,12 +441,33 @@ class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
variant: ClipVariantType = ClipVariantType.L
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}")
class CLIPGEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
"""Model config for CLIP-G Embeddings."""
variant: ClipVariantType = ClipVariantType.G
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.G}")
class CLIPLEmbedDiffusersConfig(CLIPEmbedDiffusersConfig):
"""Model config for CLIP-L Embeddings."""
variant: ClipVariantType = ClipVariantType.L
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPEmbed.value}.{ModelFormat.Diffusers.value}.{ClipVariantType.L}")
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision."""
@@ -501,6 +544,8 @@ AnyModelConfig = Annotated[
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
Annotated[CLIPLEmbedDiffusersConfig, CLIPLEmbedDiffusersConfig.get_tag()],
Annotated[CLIPGEmbedDiffusersConfig, CLIPGEmbedDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]

View File

@@ -128,9 +128,9 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
"The bnb modules are not available. Please install bitsandbytes if available on your platform."
)
match submodel_type:
case SubModelType.Tokenizer2:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
with accelerate.init_empty_weights():
@@ -172,9 +172,9 @@ class T5EncoderCheckpointModel(ModelLoader):
raise ValueError("Only T5EncoderConfig models are currently supported here.")
match submodel_type:
case SubModelType.Tokenizer2:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2:
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
return T5EncoderModel.from_pretrained(Path(config.path) / "text_encoder_2", torch_dtype="auto")
raise ValueError(

View File

@@ -42,6 +42,7 @@ VARIANT_TO_IN_CHANNEL_MAP = {
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXLRefiner, type=ModelType.Main, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.Main, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Main, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.Main, format=ModelFormat.Checkpoint)
@@ -51,13 +52,6 @@ VARIANT_TO_IN_CHANNEL_MAP = {
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""
model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
}
def _load_model(
self,
config: AnyModelConfig,

View File

@@ -1,7 +1,7 @@
import json
import re
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union
from typing import Any, Callable, Dict, Literal, Optional, Union
import safetensors.torch
import spandrel
@@ -22,6 +22,7 @@ from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import i
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS, ModelHash
from invokeai.backend.model_manager.config import (
AnyModelConfig,
AnyVariant,
BaseModelType,
ControlAdapterDefaultSettings,
InvalidModelConfigException,
@@ -33,8 +34,15 @@ from invokeai.backend.model_manager.config import (
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubmodelDefinition,
SubModelType,
)
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import ConfigLoader
from invokeai.backend.model_manager.util.model_util import (
get_clip_variant_type,
lora_token_vector_length,
read_checkpoint_meta,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@@ -112,6 +120,7 @@ class ModelProbe(object):
"StableDiffusionXLPipeline": ModelType.Main,
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"StableDiffusion3Pipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"AutoencoderKL": ModelType.VAE,
"AutoencoderTiny": ModelType.VAE,
@@ -122,8 +131,12 @@ class ModelProbe(object):
"CLIPTextModel": ModelType.CLIPEmbed,
"T5EncoderModel": ModelType.T5Encoder,
"FluxControlNetModel": ModelType.ControlNet,
"SD3Transformer2DModel": ModelType.Main,
"CLIPTextModelWithProjection": ModelType.CLIPEmbed,
}
TYPE2VARIANT: Dict[ModelType, Callable[[str], Optional[AnyVariant]]] = {ModelType.CLIPEmbed: get_clip_variant_type}
@classmethod
def register_probe(
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
@@ -170,7 +183,10 @@ class ModelProbe(object):
fields["path"] = model_path.as_posix()
fields["type"] = fields.get("type") or model_type
fields["base"] = fields.get("base") or probe.get_base_type()
fields["variant"] = fields.get("variant") or probe.get_variant_type()
variant_func = cls.TYPE2VARIANT.get(fields["type"], None)
fields["variant"] = (
fields.get("variant") or (variant_func and variant_func(model_path.as_posix())) or probe.get_variant_type()
)
fields["prediction_type"] = fields.get("prediction_type") or probe.get_scheduler_prediction_type()
fields["image_encoder_model_id"] = fields.get("image_encoder_model_id") or probe.get_image_encoder_model_id()
fields["name"] = fields.get("name") or cls.get_model_name(model_path)
@@ -217,6 +233,10 @@ class ModelProbe(object):
and fields["prediction_type"] == SchedulerPredictionType.VPrediction
)
get_submodels = getattr(probe, "get_submodels", None)
if fields["base"] == BaseModelType.StableDiffusion3 and callable(get_submodels):
fields["submodels"] = get_submodels()
model_info = ModelConfigFactory.make_config(fields) # , key=fields.get("key", None))
return model_info
@@ -747,18 +767,33 @@ class FolderProbeBase(ProbeBase):
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
# Handle pipelines with a UNet (i.e SD 1.x, SD2, SDXL).
config_path = self.model_path / "unet" / "config.json"
if config_path.exists():
with open(config_path) as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
return BaseModelType.StableDiffusion1
elif unet_conf["cross_attention_dim"] == 1024:
return BaseModelType.StableDiffusion2
elif unet_conf["cross_attention_dim"] == 1280:
return BaseModelType.StableDiffusionXLRefiner
elif unet_conf["cross_attention_dim"] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
# Handle pipelines with a transformer (i.e. SD3).
config_path = self.model_path / "transformer" / "config.json"
if config_path.exists():
with open(config_path) as file:
transformer_conf = json.load(file)
if transformer_conf["_class_name"] == "SD3Transformer2DModel":
return BaseModelType.StableDiffusion3
else:
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
raise InvalidModelConfigException(f"Unknown base model for {self.model_path}")
def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
with open(self.model_path / "scheduler" / "scheduler_config.json", "r") as file:
@@ -770,6 +805,23 @@ class PipelineFolderProbe(FolderProbeBase):
else:
raise InvalidModelConfigException("Unknown scheduler prediction type: {scheduler_conf['prediction_type']}")
def get_submodels(self) -> Dict[SubModelType, SubmodelDefinition]:
config = ConfigLoader.load_config(self.model_path, config_name="model_index.json")
submodels: Dict[SubModelType, SubmodelDefinition] = {}
for key, value in config.items():
if key.startswith("_") or not (isinstance(value, list) and len(value) == 2):
continue
model_loader = str(value[1])
if model_type := ModelProbe.CLASS2TYPE.get(model_loader):
variant_func = ModelProbe.TYPE2VARIANT.get(model_type, None)
submodels[SubModelType(key)] = SubmodelDefinition(
path_or_prefix=(self.model_path / key).resolve().as_posix(),
model_type=model_type,
variant=variant_func and variant_func((self.model_path / key).as_posix()),
)
return submodels
def get_variant_type(self) -> ModelVariantType:
# This only works for pipelines! Any kind of
# exception results in our returning the

View File

@@ -140,6 +140,22 @@ flux_dev = StarterModel(
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
sd35_medium = StarterModel(
name="SD3.5 Medium",
base=BaseModelType.StableDiffusion3,
source="stabilityai/stable-diffusion-3.5-medium",
description="Medium SD3.5 Model: ~15GB",
type=ModelType.Main,
dependencies=[],
)
sd35_large = StarterModel(
name="SD3.5 Large",
base=BaseModelType.StableDiffusion3,
source="stabilityai/stable-diffusion-3.5-large",
description="Large SD3.5 Model: ~19G",
type=ModelType.Main,
dependencies=[],
)
cyberrealistic_sd1 = StarterModel(
name="CyberRealistic v4.1",
base=BaseModelType.StableDiffusion1,
@@ -570,6 +586,8 @@ STARTER_MODELS: list[StarterModel] = [
flux_dev_quantized,
flux_schnell,
flux_dev,
sd35_medium,
sd35_large,
cyberrealistic_sd1,
rev_animated_sd1,
dreamshaper_8_sd1,

View File

@@ -8,6 +8,7 @@ import safetensors
import torch
from picklescan.scanner import scan_file_path
from invokeai.backend.model_manager.config import ClipVariantType
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
@@ -165,3 +166,23 @@ def convert_bundle_to_flux_transformer_checkpoint(
del transformer_state_dict[k]
return original_state_dict
def get_clip_variant_type(location: str) -> Optional[ClipVariantType]:
try:
path = Path(location)
config_path = path / "config.json"
if not config_path.exists():
return ClipVariantType.L
with open(config_path) as file:
clip_conf = json.load(file)
hidden_size = clip_conf.get("hidden_size", -1)
match hidden_size:
case 1280:
return ClipVariantType.G
case 768:
return ClipVariantType.L
case _:
return ClipVariantType.L
except Exception:
return ClipVariantType.L

View File

@@ -129,9 +129,11 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
if candidate_variant_label == f".{variant}" or (
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
):
if (
variant is not ModelRepoVariant.Default
and candidate_variant_label
and candidate_variant_label.startswith(f".{variant.value}")
) or (not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]):
score += 1
if parent not in subfolder_weights:
@@ -146,7 +148,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# Check if at least one of the files has the explicit fp16 variant.
at_least_one_fp16 = False
for candidate in candidate_list:
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0] == ".fp16":
if len(candidate.path.suffixes) == 2 and candidate.path.suffixes[0].startswith(".fp16"):
at_least_one_fp16 = True
break
@@ -162,7 +164,16 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
# candidate.
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate:
result.add(highest_score_candidate.path)
pattern = r"^(.*?)-\d+-of-\d+(\.\w+)$"
match = re.match(pattern, highest_score_candidate.path.as_posix())
if match:
for candidate in candidate_list:
if candidate.path.as_posix().startswith(match.group(1)) and candidate.path.as_posix().endswith(
match.group(2)
):
result.add(candidate.path)
else:
result.add(highest_score_candidate.path)
# If one of the architecture-related variants was specified and no files matched other than
# config and text files then we return an empty list

View File

@@ -49,9 +49,32 @@ class FLUXConditioningInfo:
return self
@dataclass
class SD3ConditioningInfo:
clip_l_pooled_embeds: torch.Tensor
clip_l_embeds: torch.Tensor
clip_g_pooled_embeds: torch.Tensor
clip_g_embeds: torch.Tensor
t5_embeds: torch.Tensor | None
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.clip_l_pooled_embeds = self.clip_l_pooled_embeds.to(device=device, dtype=dtype)
self.clip_l_embeds = self.clip_l_embeds.to(device=device, dtype=dtype)
self.clip_g_pooled_embeds = self.clip_g_pooled_embeds.to(device=device, dtype=dtype)
self.clip_g_embeds = self.clip_g_embeds.to(device=device, dtype=dtype)
if self.t5_embeds is not None:
self.t5_embeds = self.t5_embeds.to(device=device, dtype=dtype)
return self
@dataclass
class ConditioningFieldData:
conditionings: List[BasicConditioningInfo] | List[SDXLConditioningInfo] | List[FLUXConditioningInfo]
conditionings: (
List[BasicConditioningInfo]
| List[SDXLConditioningInfo]
| List[FLUXConditioningInfo]
| List[SD3ConditioningInfo]
)
@dataclass

View File

@@ -29,7 +29,7 @@ class LoRAExt(ExtensionBase):
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
lora_model = self._node_context.models.load(self._model_id).model
lora_model = self._node_context.models.load(self._model_id, self._node_context.util.get_queue_id()).model
assert isinstance(lora_model, LoRAModelRaw)
LoRAPatcher.apply_lora_patch(
model=unet,

View File

@@ -54,7 +54,7 @@ class T2IAdapterExt(ExtensionBase):
@callback(ExtensionCallbackType.SETUP)
def setup(self, ctx: DenoiseContext):
t2i_model: T2IAdapter
with self._node_context.models.load(self._model_id) as t2i_model:
with self._node_context.models.load(self._model_id, self._node_context.util.get_queue_id()) as t2i_model:
_, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
self._adapter_state = self._run_model(

View File

@@ -9,6 +9,7 @@ const config: KnipConfig = {
'src/services/api/schema.ts',
'src/features/nodes/types/v1/**',
'src/features/nodes/types/v2/**',
'src/features/parameters/types/parameterSchemas.ts',
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
'src/features/controlLayers/konva/util.ts',
// TODO(psyche): restore HRF functionality?

Binary file not shown.

After

Width:  |  Height:  |  Size: 895 KiB

View File

@@ -997,6 +997,7 @@
"controlNetControlMode": "Control Mode",
"copyImage": "Copy Image",
"denoisingStrength": "Denoising Strength",
"noRasterLayers": "No Raster Layers",
"downloadImage": "Download Image",
"general": "General",
"guidance": "Guidance",
@@ -1412,8 +1413,9 @@
"paramDenoisingStrength": {
"heading": "Denoising Strength",
"paragraphs": [
"How much noise is added to the input image.",
"0 will result in an identical image, while 1 will result in a completely new image."
"Controls how much the generated image varies from the raster layer(s).",
"Lower strength stays closer to the combined visible raster layers. Higher strength relies more on the global prompt.",
"When there are no raster layers with visible content, this setting is ignored."
]
},
"paramHeight": {
@@ -1662,6 +1664,7 @@
"mergeDown": "Merge Down",
"mergeVisibleOk": "Merged layers",
"mergeVisibleError": "Error merging layers",
"mergingLayers": "Merging layers",
"clearHistory": "Clear History",
"bboxOverlay": "Show Bbox Overlay",
"resetCanvas": "Reset Canvas",
@@ -1774,9 +1777,10 @@
"newCanvasSession": "New Canvas Session",
"newCanvasSessionDesc": "This will clear the canvas and all settings except for your model selection. Generations will be staged on the canvas.",
"replaceCurrent": "Replace Current",
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, or draw on the canvas to get started.",
"controlMode": {
"controlMode": "Control Mode",
"balanced": "Balanced",
"balanced": "Balanced (recommended)",
"prompt": "Prompt",
"control": "Control",
"megaControl": "Mega Control"
@@ -1815,6 +1819,9 @@
"process": "Process",
"apply": "Apply",
"cancel": "Cancel",
"advanced": "Advanced",
"processingLayerWith": "Processing layer with the {{type}} filter.",
"forMoreControl": "For more control, click Advanced below.",
"spandrel_filter": {
"label": "Image-to-Image Model",
"description": "Run an image-to-image model on the selected layer.",
@@ -2095,9 +2102,8 @@
},
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"line1": "<ItalicComponent>Select Object</ItalicComponent> tool for precise object selection and editing",
"line2": "Expanded Flux support, now with Global Reference Images",
"line3": "Improved tooltips and context menus",
"line1": "<StrongComponent>Layer Merging</StrongComponent>: New <StrongComponent>Merge Down</StrongComponent> and improved <StrongComponent>Merge Visible</StrongComponent> for all layers, with special handling for Regional Guidance and Control Layers.",
"line2": "<StrongComponent>HF Token Support</StrongComponent>: Upload models that require Hugging Face authentication.",
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
"watchUiUpdatesOverview": "Watch UI Updates Overview"

View File

@@ -2,7 +2,7 @@ import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { selectDefaultControlAdapter, selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import {
controlLayerAdded,
@@ -23,7 +23,7 @@ import type {
CanvasReferenceImageState,
CanvasRegionalGuidanceState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
@@ -163,11 +163,10 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
const state = getState();
const imageObject = imageDTOToImageObject(activeData.payload.imageDTO);
const { x, y } = selectCanvasSlice(state).bbox.rect;
const defaultControlAdapter = selectDefaultControlAdapter(state);
const overrides: Partial<CanvasControlLayerState> = {
objects: [imageObject],
position: { x, y },
controlAdapter: defaultControlAdapter,
controlAdapter: deepClone(initialControlNet),
};
dispatch(controlLayerAdded({ overrides, isSelected: true }));
return;

View File

@@ -164,7 +164,7 @@ const handleVAEModels: ModelHandler = (models, state, dispatch, log) => {
// We have a VAE selected, need to check if it is available
// Grab just the VAE models
const vaeModels = models.filter(isNonFluxVAEModelConfig);
const vaeModels = models.filter((m) => isNonFluxVAEModelConfig(m));
// If the current VAE model is available, we don't need to do anything
if (vaeModels.some((m) => m.key === selectedVAEModel.key)) {
@@ -297,7 +297,7 @@ const handleUpscaleModel: ModelHandler = (models, state, dispatch, log) => {
const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
const selectedT5EncoderModel = state.params.t5EncoderModel;
const t5EncoderModels = models.filter(isT5EncoderModelConfig);
const t5EncoderModels = models.filter((m) => isT5EncoderModelConfig(m));
// If the currently selected model is available, we don't need to do anything
if (selectedT5EncoderModel && t5EncoderModels.some((m) => m.key === selectedT5EncoderModel.key)) {
@@ -325,7 +325,7 @@ const handleT5EncoderModels: ModelHandler = (models, state, dispatch, log) => {
const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
const selectedCLIPEmbedModel = state.params.clipEmbedModel;
const CLIPEmbedModels = models.filter(isCLIPEmbedModelConfig);
const CLIPEmbedModels = models.filter((m) => isCLIPEmbedModelConfig(m));
// If the currently selected model is available, we don't need to do anything
if (selectedCLIPEmbedModel && CLIPEmbedModels.some((m) => m.key === selectedCLIPEmbedModel.key)) {
@@ -353,7 +353,7 @@ const handleCLIPEmbedModels: ModelHandler = (models, state, dispatch, log) => {
const handleFLUXVAEModels: ModelHandler = (models, state, dispatch, log) => {
const selectedFLUXVAEModel = state.params.fluxVAE;
const fluxVAEModels = models.filter(isFluxVAEModelConfig);
const fluxVAEModels = models.filter((m) => isFluxVAEModelConfig(m));
// If the currently selected model is available, we don't need to do anything
if (selectedFLUXVAEModel && fluxVAEModels.some((m) => m.key === selectedFLUXVAEModel.key)) {

View File

@@ -4,8 +4,10 @@ import { atom } from 'nanostores';
/**
* A fallback non-writable atom that always returns `false`, used when a nanostores atom is only conditionally available
* in a hook or component.
*
* @knipignore
*/
// export const $false: ReadableAtom<boolean> = atom(false);
export const $false: ReadableAtom<boolean> = atom(false);
/**
* A fallback non-writable atom that always returns `true`, used when a nanostores atom is only conditionally available
* in a hook or component.

View File

@@ -1,5 +1,6 @@
import type { PopoverProps } from '@invoke-ai/ui-library';
import commercialLicenseBg from 'public/assets/images/commercial-license-bg.png';
import denoisingStrength from 'public/assets/images/denoising-strength.png';
export type Feature =
| 'clipSkip'
@@ -125,7 +126,7 @@ export const POPOVER_DATA: { [key in Feature]?: PopoverData } = {
href: 'https://support.invoke.ai/support/solutions/articles/151000158838-compositing-settings',
},
infillMethod: {
href: 'https://support.invoke.ai/support/solutions/articles/151000158841-infill-and-scaling',
href: 'https://support.invoke.ai/support/solutions/articles/151000158838-compositing-settings',
},
scaleBeforeProcessing: {
href: 'https://support.invoke.ai/support/solutions/articles/151000158841',
@@ -138,6 +139,7 @@ export const POPOVER_DATA: { [key in Feature]?: PopoverData } = {
},
paramDenoisingStrength: {
href: 'https://support.invoke.ai/support/solutions/articles/151000094998-image-to-image',
image: denoisingStrength,
},
paramHrf: {
href: 'https://support.invoke.ai/support/solutions/articles/151000096700-how-can-i-get-larger-images-what-does-upscaling-do-',

View File

@@ -0,0 +1,57 @@
type Props = {
/**
* The amplitude of the wave. 0 is a straight line, higher values create more pronounced waves.
*/
amplitude: number;
/**
* The number of segments in the line. More segments create a smoother wave.
*/
segments?: number;
/**
* The color of the wave.
*/
stroke: string;
/**
* The width of the wave.
*/
strokeWidth: number;
/**
* The width of the SVG.
*/
width: number;
/**
* The height of the SVG.
*/
height: number;
};
const WavyLine = ({ amplitude, stroke, strokeWidth, width, height, segments = 5 }: Props) => {
// Calculate the path dynamically based on waviness
const generatePath = () => {
if (amplitude === 0) {
// If waviness is 0, return a straight line
return `M0,${height / 2} L${width},${height / 2}`;
}
const clampedAmplitude = Math.min(height / 2, amplitude); // Cap amplitude to half the height
const segmentWidth = width / segments;
let path = `M0,${height / 2}`; // Start in the middle of the left edge
// Loop through each segment and alternate the y position to create waves
for (let i = 1; i <= segments; i++) {
const x = i * segmentWidth;
const y = height / 2 + (i % 2 === 0 ? clampedAmplitude : -clampedAmplitude);
path += ` Q${x - segmentWidth / 2},${y} ${x},${height / 2}`;
}
return path;
};
return (
<svg width={width} height={height} viewBox={`0 0 ${width} ${height}`} xmlns="http://www.w3.org/2000/svg">
<path d={generatePath()} fill="none" stroke={stroke} strokeWidth={strokeWidth} />
</svg>
);
};
export default WavyLine;

View File

@@ -0,0 +1,15 @@
import type { CSSProperties } from 'react';
/**
* Chakra's Tooltip's method of finding the nearest scroll parent has a problem - it assumes the first parent with
* `overflow: hidden` is the scroll parent. In this case, the Collapse component has that style, but isn't scrollable
* itself. The result is that the tooltip does not close on scroll, because the scrolling happens higher up in the DOM.
*
* As a hacky workaround, we can set the overflow to `visible`, which allows the scroll parent search to continue up to
* the actual scroll parent (in this case, the OverlayScrollbarsComponent in BoardsListWrapper).
*
* See: https://github.com/chakra-ui/chakra-ui/issues/7871#issuecomment-2453780958
*/
export const fixTooltipCloseOnScrollStyles: CSSProperties = {
overflow: 'visible',
};

View File

@@ -7,6 +7,8 @@ import { EntityListSelectedEntityActionBar } from 'features/controlLayers/compon
import { selectHasEntities } from 'features/controlLayers/store/selectors';
import { memo, useRef } from 'react';
import { ParamDenoisingStrength } from './ParamDenoisingStrength';
export const CanvasLayersPanelContent = memo(() => {
const hasEntities = useAppSelector(selectHasEntities);
const layersPanelFocusRef = useRef<HTMLDivElement>(null);
@@ -16,6 +18,8 @@ export const CanvasLayersPanelContent = memo(() => {
<Flex ref={layersPanelFocusRef} flexDir="column" gap={2} w="full" h="full">
<EntityListSelectedEntityActionBar />
<Divider py={0} />
<ParamDenoisingStrength />
<Divider py={0} />
{!hasEntities && <CanvasAddEntityButtons />}
{hasEntities && <CanvasEntityList />}
</Flex>

View File

@@ -7,7 +7,7 @@ import { CanvasEntityPreviewImage } from 'features/controlLayers/components/comm
import { CanvasEntitySettingsWrapper } from 'features/controlLayers/components/common/CanvasEntitySettingsWrapper';
import { CanvasEntityEditableTitle } from 'features/controlLayers/components/common/CanvasEntityTitleEdit';
import { ControlLayerBadges } from 'features/controlLayers/components/ControlLayer/ControlLayerBadges';
import { ControlLayerControlAdapter } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapter';
import { ControlLayerSettings } from 'features/controlLayers/components/ControlLayer/ControlLayerSettings';
import { ControlLayerAdapterGate } from 'features/controlLayers/contexts/EntityAdapterContext';
import { EntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
@@ -41,7 +41,7 @@ export const ControlLayer = memo(({ id }: Props) => {
<CanvasEntityHeaderCommonActions />
</CanvasEntityHeader>
<CanvasEntitySettingsWrapper>
<ControlLayerControlAdapter />
<ControlLayerSettings />
</CanvasEntitySettingsWrapper>
<IAIDroppable data={dropData} dropLabel={t('controlLayers.replaceLayer')} />
</CanvasEntityContainer>

View File

@@ -6,6 +6,7 @@ import { BeginEndStepPct } from 'features/controlLayers/components/common/BeginE
import { Weight } from 'features/controlLayers/components/common/Weight';
import { ControlLayerControlAdapterControlMode } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterControlMode';
import { ControlLayerControlAdapterModel } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapterModel';
import { useEntityAdapterContext } from 'features/controlLayers/contexts/EntityAdapterContext';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { usePullBboxIntoLayer } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
@@ -16,6 +17,7 @@ import {
controlLayerModelChanged,
controlLayerWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { getFilterForModel } from 'features/controlLayers/store/filters';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
@@ -44,6 +46,7 @@ export const ControlLayerControlAdapter = memo(() => {
const controlAdapter = useControlLayerControlAdapter(entityIdentifier);
const filter = useEntityFilter(entityIdentifier);
const isFLUX = useAppSelector(selectIsFLUX);
const adapter = useEntityAdapterContext('control_layer');
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
@@ -69,8 +72,43 @@ export const ControlLayerControlAdapter = memo(() => {
const onChangeModel = useCallback(
(modelConfig: ControlNetModelConfig | T2IAdapterModelConfig) => {
dispatch(controlLayerModelChanged({ entityIdentifier, modelConfig }));
// When we change the model, we need may need to start filtering w/ the simplified filter mode, and/or change the
// filter config.
const isFiltering = adapter.filterer.$isFiltering.get();
const isSimple = adapter.filterer.$simple.get();
// If we are filtering and _not_ in simple mode, that means the user has clicked Advanced. They want to be in control
// of the settings. Bail early without doing anything else.
if (isFiltering && !isSimple) {
return;
}
// Else, we are in simple mode and will take care of some things for the user.
// First, check if the newly-selected model has a default filter. It may not - for example, Tile controlnet models
// don't have a default filter.
const defaultFilterForNewModel = getFilterForModel(modelConfig);
if (!defaultFilterForNewModel) {
// The user has chosen a model that doesn't have a default filter - cancel any in-progress filtering and bail.
if (isFiltering) {
adapter.filterer.cancel();
}
return;
}
// At this point, we know the user has selected a model that has a default filter. We need to either start filtering
// with that default filter, or update the existing filter config to match the new model's default filter.
const filterConfig = defaultFilterForNewModel.buildDefaults();
if (isFiltering) {
adapter.filterer.$filterConfig.set(filterConfig);
} else {
adapter.filterer.start(filterConfig);
}
// The user may have disabled auto-processing, so we should process the filter manually. This is essentially a
// no-op if auto-processing is already enabled, because the process method is debounced.
adapter.filterer.process();
},
[dispatch, entityIdentifier]
[adapter.filterer, dispatch, entityIdentifier]
);
const pullBboxIntoLayer = usePullBboxIntoLayer(entityIdentifier);

View File

@@ -0,0 +1,18 @@
import { ControlLayerControlAdapter } from 'features/controlLayers/components/ControlLayer/ControlLayerControlAdapter';
import { ControlLayerSettingsEmptyState } from 'features/controlLayers/components/ControlLayer/ControlLayerSettingsEmptyState';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useEntityIsEmpty } from 'features/controlLayers/hooks/useEntityIsEmpty';
import { memo } from 'react';
export const ControlLayerSettings = memo(() => {
const entityIdentifier = useEntityIdentifierContext();
const isEmpty = useEntityIsEmpty(entityIdentifier);
if (isEmpty) {
return <ControlLayerSettingsEmptyState />;
}
return <ControlLayerControlAdapter />;
});
ControlLayerSettings.displayName = 'ControlLayerSettings';

View File

@@ -0,0 +1,50 @@
import { Button, Flex, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { memo, useCallback, useMemo } from 'react';
import { Trans } from 'react-i18next';
import type { PostUploadAction } from 'services/api/types';
export const ControlLayerSettingsEmptyState = memo(() => {
const entityIdentifier = useEntityIdentifierContext('control_layer');
const dispatch = useAppDispatch();
const isBusy = useCanvasIsBusy();
const postUploadAction = useMemo<PostUploadAction>(
() => ({ type: 'REPLACE_LAYER_WITH_IMAGE', entityIdentifier }),
[entityIdentifier]
);
const uploadApi = useImageUploadButton({ postUploadAction });
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
return (
<Flex flexDir="column" gap={3} position="relative" w="full" p={4}>
<Text textAlign="center" color="base.300">
<Trans
i18nKey="controlLayers.controlLayerEmptyState"
components={{
UploadButton: (
<Button
isDisabled={isBusy}
size="sm"
variant="link"
color="base.300"
{...uploadApi.getUploadButtonProps()}
/>
),
GalleryButton: (
<Button onClick={onClickGalleryButton} isDisabled={isBusy} size="sm" variant="link" color="base.300" />
),
}}
/>
</Text>
<input {...uploadApi.getUploadInputProps()} />
</Flex>
);
});
ControlLayerSettingsEmptyState.displayName = 'ControlLayerSettingsEmptyState';

View File

@@ -9,6 +9,7 @@ import {
MenuList,
Spacer,
Spinner,
Text,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
@@ -28,13 +29,10 @@ import { memo, useCallback, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
const FilterContent = memo(
const FilterContentAdvanced = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const ref = useRef<HTMLDivElement>(null);
useFocusRegion('canvas', ref, { focusOnMount: true });
const config = useStore(adapter.filterer.$filterConfig);
const isCanvasFocused = useIsRegionFocused('canvas');
const isProcessing = useStore(adapter.filterer.$isProcessing);
const hasImageState = useStore(adapter.filterer.$hasImageState);
const autoProcess = useAppSelector(selectAutoProcess);
@@ -73,36 +71,8 @@ const FilterContent = memo(
adapter.filterer.saveAs('control_layer');
}, [adapter.filterer]);
useRegisteredHotkeys({
id: 'applyFilter',
category: 'canvas',
callback: adapter.filterer.apply,
options: { enabled: !isProcessing && isCanvasFocused },
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
});
useRegisteredHotkeys({
id: 'cancelFilter',
category: 'canvas',
callback: adapter.filterer.cancel,
options: { enabled: !isProcessing && isCanvasFocused },
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
});
return (
<Flex
ref={ref}
bg="base.800"
borderRadius="base"
p={4}
flexDir="column"
gap={4}
w={420}
h="auto"
shadow="dark-lg"
transitionProperty="height"
transitionDuration="normal"
>
<>
<Flex w="full" gap={4}>
<Heading size="md" color="base.300" userSelect="none">
{t('controlLayers.filter.filter')}
@@ -169,12 +139,67 @@ const FilterContent = memo(
{t('controlLayers.filter.cancel')}
</Button>
</ButtonGroup>
</Flex>
</>
);
}
);
FilterContent.displayName = 'FilterContent';
FilterContentAdvanced.displayName = 'FilterContentAdvanced';
const FilterContentSimple = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const { t } = useTranslation();
const config = useStore(adapter.filterer.$filterConfig);
const isProcessing = useStore(adapter.filterer.$isProcessing);
const hasImageState = useStore(adapter.filterer.$hasImageState);
const isValid = useMemo(() => {
return IMAGE_FILTERS[config.type].validateConfig?.(config as never) ?? true;
}, [config]);
const onClickAdvanced = useCallback(() => {
adapter.filterer.$simple.set(false);
}, [adapter.filterer.$simple]);
return (
<>
<Flex w="full" gap={4}>
<Heading size="md" color="base.300" userSelect="none">
{t('controlLayers.filter.filter')}
</Heading>
<Spacer />
</Flex>
<Flex flexDir="column" w="full" gap={2} pb={2}>
<Text color="base.500" textAlign="center">
{t('controlLayers.filter.processingLayerWith', { type: t(`controlLayers.filter.${config.type}.label`) })}
</Text>
<Text color="base.500" textAlign="center">
{t('controlLayers.filter.forMoreControl')}
</Text>
</Flex>
<ButtonGroup isAttached={false} size="sm" w="full">
<Button variant="ghost" onClick={onClickAdvanced}>
{t('controlLayers.filter.advanced')}
</Button>
<Spacer />
<Button
onClick={adapter.filterer.apply}
loadingText={t('controlLayers.filter.apply')}
variant="ghost"
isDisabled={isProcessing || !isValid || !hasImageState}
>
{t('controlLayers.filter.apply')}
</Button>
<Button variant="ghost" onClick={adapter.filterer.cancel} loadingText={t('controlLayers.filter.cancel')}>
{t('controlLayers.filter.cancel')}
</Button>
</ButtonGroup>
</>
);
}
);
FilterContentSimple.displayName = 'FilterContentSimple';
export const Filter = () => {
const canvasManager = useCanvasManager();
@@ -182,8 +207,54 @@ export const Filter = () => {
if (!adapter) {
return null;
}
return <FilterContent adapter={adapter} />;
};
Filter.displayName = 'Filter';
const FilterContent = memo(
({ adapter }: { adapter: CanvasEntityAdapterRasterLayer | CanvasEntityAdapterControlLayer }) => {
const simplified = useStore(adapter.filterer.$simple);
const isCanvasFocused = useIsRegionFocused('canvas');
const isProcessing = useStore(adapter.filterer.$isProcessing);
const ref = useRef<HTMLDivElement>(null);
useFocusRegion('canvas', ref, { focusOnMount: true });
useRegisteredHotkeys({
id: 'applyFilter',
category: 'canvas',
callback: adapter.filterer.apply,
options: { enabled: !isProcessing && isCanvasFocused, enableOnFormTags: true },
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
});
useRegisteredHotkeys({
id: 'cancelFilter',
category: 'canvas',
callback: adapter.filterer.cancel,
options: { enabled: !isProcessing && isCanvasFocused, enableOnFormTags: true },
dependencies: [adapter.filterer, isProcessing, isCanvasFocused],
});
return (
<Flex
ref={ref}
bg="base.800"
borderRadius="base"
p={4}
flexDir="column"
gap={4}
w={420}
h="auto"
shadow="dark-lg"
transitionProperty="height"
transitionDuration="normal"
>
{simplified && <FilterContentSimple adapter={adapter} />}
{!simplified && <FilterContentAdvanced adapter={adapter} />}
</Flex>
);
}
);
FilterContent.displayName = 'FilterContent';

View File

@@ -0,0 +1,82 @@
import {
Badge,
CompositeNumberInput,
CompositeSlider,
Flex,
FormControl,
FormLabel,
useToken,
} from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import WavyLine from 'common/components/WavyLine';
import { selectImg2imgStrength, setImg2imgStrength } from 'features/controlLayers/store/paramsSlice';
import { selectActiveRasterLayerEntities } from 'features/controlLayers/store/selectors';
import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selectIsEnabled = createSelector(selectActiveRasterLayerEntities, (entities) => entities.length > 0);
export const ParamDenoisingStrength = memo(() => {
const img2imgStrength = useAppSelector(selectImg2imgStrength);
const dispatch = useAppDispatch();
const isEnabled = useAppSelector(selectIsEnabled);
const onChange = useCallback(
(v: number) => {
dispatch(setImg2imgStrength(v));
},
[dispatch]
);
const config = useAppSelector(selectImg2imgStrengthConfig);
const { t } = useTranslation();
const [invokeBlue300] = useToken('colors', ['invokeBlue.300']);
return (
<FormControl isDisabled={!isEnabled} p={1} justifyContent="space-between" h={8}>
<Flex gap={3} alignItems="center">
<InformationalPopover feature="paramDenoisingStrength">
<FormLabel mr={0}>{`${t('parameters.denoisingStrength')}`}</FormLabel>
</InformationalPopover>
{isEnabled && (
<WavyLine amplitude={img2imgStrength * 10} stroke={invokeBlue300} strokeWidth={1} width={40} height={14} />
)}
</Flex>
{isEnabled ? (
<>
<CompositeSlider
step={config.coarseStep}
fineStep={config.fineStep}
min={config.sliderMin}
max={config.sliderMax}
defaultValue={config.initial}
onChange={onChange}
value={img2imgStrength}
/>
<CompositeNumberInput
step={config.coarseStep}
fineStep={config.fineStep}
min={config.numberInputMin}
max={config.numberInputMax}
defaultValue={config.initial}
onChange={onChange}
value={img2imgStrength}
variant="outline"
/>
</>
) : (
<Flex alignItems="center">
<Badge opacity="0.6">
{t('common.disabled')} - {t('parameters.noRasterLayers')}
</Badge>
</Flex>
)}
</FormControl>
);
});
ParamDenoisingStrength.displayName = 'ParamDenoisingStrength';

View File

@@ -1,8 +1,8 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { deepClone } from 'common/util/deepClone';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useEntityIsLocked } from 'features/controlLayers/hooks/useEntityIsLocked';
import {
@@ -10,6 +10,7 @@ import {
rasterLayerConvertedToInpaintMask,
rasterLayerConvertedToRegionalGuidance,
} from 'features/controlLayers/store/canvasSlice';
import { initialControlNet } from 'features/controlLayers/store/util';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiSwapBold } from 'react-icons/pi';
@@ -20,7 +21,6 @@ export const RasterLayerMenuItemsConvertToSubMenu = memo(() => {
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('raster_layer');
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
const isBusy = useCanvasIsBusy();
const isLocked = useEntityIsLocked(entityIdentifier);
@@ -37,10 +37,10 @@ export const RasterLayerMenuItemsConvertToSubMenu = memo(() => {
rasterLayerConvertedToControlLayer({
entityIdentifier,
replace: true,
overrides: { controlAdapter: defaultControlAdapter },
overrides: { controlAdapter: deepClone(initialControlNet) },
})
);
}, [defaultControlAdapter, dispatch, entityIdentifier]);
}, [dispatch, entityIdentifier]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiSwapBold />} isDisabled={isBusy || isLocked}>

View File

@@ -1,15 +1,16 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { deepClone } from 'common/util/deepClone';
import { CanvasEntityMenuItemsCopyToClipboard } from 'features/controlLayers/components/common/CanvasEntityMenuItemsCopyToClipboard';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { selectDefaultControlAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import {
rasterLayerConvertedToControlLayer,
rasterLayerConvertedToInpaintMask,
rasterLayerConvertedToRegionalGuidance,
} from 'features/controlLayers/store/canvasSlice';
import { initialControlNet } from 'features/controlLayers/store/util';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCopyBold } from 'react-icons/pi';
@@ -20,7 +21,6 @@ export const RasterLayerMenuItemsCopyToSubMenu = memo(() => {
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext('raster_layer');
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
const isBusy = useCanvasIsBusy();
const copyToInpaintMask = useCallback(() => {
@@ -35,10 +35,10 @@ export const RasterLayerMenuItemsCopyToSubMenu = memo(() => {
dispatch(
rasterLayerConvertedToControlLayer({
entityIdentifier,
overrides: { controlAdapter: defaultControlAdapter },
overrides: { controlAdapter: deepClone(initialControlNet) },
})
);
}, [defaultControlAdapter, dispatch, entityIdentifier]);
}, [dispatch, entityIdentifier]);
return (
<MenuItem {...subMenu.parentMenuItemProps} icon={<PiCopyBold />} isDisabled={isBusy}>

View File

@@ -2,6 +2,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Button, Collapse, Flex, Icon, Spacer, Text } from '@invoke-ai/ui-library';
import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover';
import { useBoolean } from 'common/hooks/useBoolean';
import { fixTooltipCloseOnScrollStyles } from 'common/util/fixTooltipCloseOnScrollStyles';
import { CanvasEntityAddOfTypeButton } from 'features/controlLayers/components/common/CanvasEntityAddOfTypeButton';
import { CanvasEntityMergeVisibleButton } from 'features/controlLayers/components/common/CanvasEntityMergeVisibleButton';
import { CanvasEntityTypeIsHiddenToggle } from 'features/controlLayers/components/common/CanvasEntityTypeIsHiddenToggle';
@@ -78,7 +79,7 @@ export const CanvasEntityGroupList = memo(({ isSelected, type, children }: Props
{isRenderableEntityType(type) && <CanvasEntityTypeIsHiddenToggle type={type} />}
<CanvasEntityAddOfTypeButton type={type} />
</Flex>
<Collapse in={collapse.isTrue}>
<Collapse in={collapse.isTrue} style={fixTooltipCloseOnScrollStyles}>
<Flex flexDir="column" gap={2} pt={2}>
{children}
</Flex>

View File

@@ -1,4 +1,4 @@
import { Box, chakra, Flex } from '@invoke-ai/ui-library';
import { Box, chakra, Flex, Tooltip } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { rgbColorToString } from 'common/util/colorCodeTransformers';
@@ -86,13 +86,63 @@ export const CanvasEntityPreviewImage = memo(() => {
useEffect(updatePreview, [updatePreview, canvasCache, nodeRect, pixelRect]);
return (
<Tooltip label={<TooltipContent canvasRef={canvasRef} />} p={2} closeOnScroll>
<Flex
position="relative"
alignItems="center"
justifyContent="center"
w={CONTAINER_WIDTH_PX}
h={CONTAINER_WIDTH_PX}
borderRadius="sm"
borderWidth={1}
bg="base.900"
flexShrink={0}
>
<Box
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
bgImage={TRANSPARENCY_CHECKERBOARD_PATTERN_DARK_DATAURL}
bgSize="5px"
/>
<ChakraCanvas position="relative" ref={canvasRef} objectFit="contain" maxW="full" maxH="full" />
</Flex>
</Tooltip>
);
});
CanvasEntityPreviewImage.displayName = 'CanvasEntityPreviewImage';
const TooltipContent = ({ canvasRef }: { canvasRef: React.RefObject<HTMLCanvasElement> }) => {
const canvasRef2 = useRef<HTMLCanvasElement>(null);
useEffect(() => {
if (!canvasRef2.current || !canvasRef.current) {
return;
}
const ctx = canvasRef2.current.getContext('2d');
if (!ctx) {
return;
}
canvasRef2.current.width = canvasRef.current.width;
canvasRef2.current.height = canvasRef.current.height;
ctx.clearRect(0, 0, canvasRef2.current.width, canvasRef2.current.height);
ctx.drawImage(canvasRef.current, 0, 0);
}, [canvasRef]);
return (
<Flex
position="relative"
alignItems="center"
justifyContent="center"
w={CONTAINER_WIDTH_PX}
h={CONTAINER_WIDTH_PX}
w={150}
h={150}
borderRadius="sm"
borderWidth={1}
bg="base.900"
@@ -105,11 +155,9 @@ export const CanvasEntityPreviewImage = memo(() => {
bottom={0}
left={0}
bgImage={TRANSPARENCY_CHECKERBOARD_PATTERN_DARK_DATAURL}
bgSize="5px"
bgSize="8px"
/>
<ChakraCanvas position="relative" ref={canvasRef} objectFit="contain" maxW="full" maxH="full" />
<ChakraCanvas position="relative" ref={canvasRef2} objectFit="contain" maxW="full" maxH="full" />
</Flex>
);
});
CanvasEntityPreviewImage.displayName = 'CanvasEntityPreviewImage';
};

View File

@@ -4,9 +4,10 @@ import type { CanvasEntityAdapterControlLayer } from 'features/controlLayers/kon
import type { CanvasEntityAdapterInpaintMask } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterInpaintMask';
import type { CanvasEntityAdapterRasterLayer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRasterLayer';
import type { CanvasEntityAdapterRegionalGuidance } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterRegionalGuidance';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import type { CanvasEntityAdapterFromType } from 'features/controlLayers/konva/CanvasEntity/types';
import type { CanvasEntityIdentifier, CanvasRenderableEntityType } from 'features/controlLayers/store/types';
import type { PropsWithChildren } from 'react';
import { createContext, memo, useMemo, useSyncExternalStore } from 'react';
import { createContext, memo, useContext, useMemo, useSyncExternalStore } from 'react';
import { assert } from 'tsafe';
const EntityAdapterContext = createContext<
@@ -95,6 +96,17 @@ export const RegionalGuidanceAdapterGate = memo(({ children }: PropsWithChildren
return <EntityAdapterContext.Provider value={adapter}>{children}</EntityAdapterContext.Provider>;
});
export const useEntityAdapterContext = <T extends CanvasRenderableEntityType | undefined = CanvasRenderableEntityType>(
type?: T
): CanvasEntityAdapterFromType<T extends undefined ? CanvasRenderableEntityType : T> => {
const adapter = useContext(EntityAdapterContext);
assert(adapter, 'useEntityIdentifier must be used within a EntityIdentifierProvider');
if (type) {
assert(adapter.entityIdentifier.type === type, 'useEntityIdentifier must be used with the correct type');
}
return adapter as CanvasEntityAdapterFromType<T extends undefined ? CanvasRenderableEntityType : T>;
};
RegionalGuidanceAdapterGate.displayName = 'RegionalGuidanceAdapterGate';
export const useEntityAdapterSafe = (

View File

@@ -49,6 +49,7 @@ import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'ser
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
/** @knipignore */
export const selectDefaultControlAdapter = createSelector(
selectModelConfigsQuery,
selectBase,
@@ -92,11 +93,10 @@ export const selectDefaultIPAdapter = createSelector(
export const useAddControlLayer = () => {
const dispatch = useAppDispatch();
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
const func = useCallback(() => {
const overrides = { controlAdapter: defaultControlAdapter };
const overrides = { controlAdapter: deepClone(initialControlNet) };
dispatch(controlLayerAdded({ isSelected: true, overrides }));
}, [defaultControlAdapter, dispatch]);
}, [dispatch]);
return func;
};

View File

@@ -4,7 +4,7 @@ import type { SerializableObject } from 'common/types';
import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { selectDefaultControlAdapter, selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import {
controlLayerAdded,
@@ -25,7 +25,7 @@ import type {
Rect,
RegionalGuidanceReferenceImageState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -72,10 +72,16 @@ const useSaveCanvas = ({ region, saveToGallery, toastOk, toastError, onSave, wit
const result = await withResultAsync(() => {
const rasterAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
return canvasManager.compositor.getCompositeImageDTO(rasterAdapters, rect, {
is_intermediate: !saveToGallery,
metadata,
});
return canvasManager.compositor.getCompositeImageDTO(
rasterAdapters,
rect,
{
is_intermediate: !saveToGallery,
metadata,
},
undefined,
true // force upload the image to ensure it gets added to the gallery
);
});
if (result.isOk()) {
@@ -223,13 +229,12 @@ export const useNewRasterLayerFromBbox = () => {
export const useNewControlLayerFromBbox = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const defaultControlAdapter = useAppSelector(selectDefaultControlAdapter);
const arg = useMemo<UseSaveCanvasArg>(() => {
const onSave = (imageDTO: ImageDTO, rect: Rect) => {
const overrides: Partial<CanvasControlLayerState> = {
objects: [imageDTOToImageObject(imageDTO)],
controlAdapter: deepClone(defaultControlAdapter),
controlAdapter: deepClone(initialControlNet),
position: { x: rect.x, y: rect.y },
};
dispatch(controlLayerAdded({ overrides, isSelected: true }));
@@ -242,7 +247,7 @@ export const useNewControlLayerFromBbox = () => {
toastOk: t('controlLayers.newControlLayerOk'),
toastError: t('controlLayers.newControlLayerError'),
};
}, [defaultControlAdapter, dispatch, t]);
}, [dispatch, t]);
const func = useSaveCanvas(arg);
return func;
};

View File

@@ -253,18 +253,20 @@ export class CanvasCompositorModule extends CanvasModuleBase {
* @param rect The region to include in the rasterized image
* @param uploadOptions Options for uploading the image
* @param compositingOptions Options for compositing the entities
* @param forceUpload If true, the image is always re-uploaded, returning a new image DTO
* @returns A promise that resolves to the image DTO
*/
getCompositeImageDTO = async (
adapters: CanvasEntityAdapter[],
rect: Rect,
uploadOptions: Pick<UploadOptions, 'is_intermediate' | 'metadata'>,
compositingOptions?: CompositingOptions
compositingOptions?: CompositingOptions,
forceUpload?: boolean
): Promise<ImageDTO> => {
assert(rect.width > 0 && rect.height > 0, 'Unable to rasterize empty rect');
const hash = this.getCompositeHash(adapters, { rect });
const cachedImageName = this.manager.cache.imageNameCache.get(hash);
const cachedImageName = forceUpload ? undefined : this.manager.cache.imageNameCache.get(hash);
let imageDTO: ImageDTO | null = null;
@@ -327,6 +329,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
entityIdentifiers: T[],
deleteMergedEntities: boolean
): Promise<ImageDTO | null> => {
toast({ id: 'MERGE_LAYERS_TOAST', title: t('controlLayers.mergingLayers'), withCount: false });
if (entityIdentifiers.length <= 1) {
this.log.warn({ entityIdentifiers }, 'Cannot merge less than 2 entities');
return null;
@@ -349,7 +352,12 @@ export class CanvasCompositorModule extends CanvasModuleBase {
if (result.isErr()) {
this.log.error({ error: serializeError(result.error) }, 'Failed to merge selected entities');
toast({ title: t('controlLayers.mergeVisibleError'), status: 'error' });
toast({
id: 'MERGE_LAYERS_TOAST',
title: t('controlLayers.mergeVisibleError'),
status: 'error',
withCount: false,
});
return null;
}
@@ -381,7 +389,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
assert<Equals<typeof type, never>>(false, 'Unsupported type for merge');
}
toast({ title: t('controlLayers.mergeVisibleOk') });
toast({ id: 'MERGE_LAYERS_TOAST', title: t('controlLayers.mergeVisibleOk'), status: 'success', withCount: false });
return result.value;
};

View File

@@ -83,6 +83,13 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
* Whether the module has an image state. This is a computed value based on $imageState.
*/
$hasImageState = computed(this.$imageState, (imageState) => imageState !== null);
/**
* Whether the filter is in simple mode. In simple mode, the filter is started with a default filter config and the
* user is not presented with filter settings.
*/
$simple = atom<boolean>(false);
/**
* The filtered image object module, if it exists.
*/
@@ -147,7 +154,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
/**
* Starts the filter module.
* @param config The filter config to start with. If omitted, the default filter config is used.
* @param config The filter config to use. If omitted, the default filter config is used.
*/
start = (config?: FilterConfig) => {
const filteringAdapter = this.manager.stateApi.$filteringAdapter.get();
@@ -174,12 +181,14 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
// If a config is provided, use it
this.$filterConfig.set(config);
this.$initialFilterConfig.set(config);
this.$simple.set(true);
} else {
this.$filterConfig.set(this.createInitialFilterConfig());
const initialConfig = this.createInitialFilterConfig();
this.$filterConfig.set(initialConfig);
this.$initialFilterConfig.set(initialConfig);
this.$simple.set(false);
}
this.$initialFilterConfig.set(this.$filterConfig.get());
this.subscribe();
this.manager.stateApi.$filteringAdapter.set(this.parent);
@@ -198,7 +207,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
);
const modelConfig = this.manager.stateApi.runSelector(selectModelConfig);
// This always returns a filter
const filter = getFilterForModel(modelConfig);
const filter = getFilterForModel(modelConfig) ?? IMAGE_FILTERS.canny_edge_detection;
return filter.buildDefaults();
} else {
// Otherwise, used the default filter
@@ -210,6 +219,10 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
* Processes the filter, updating the module's state and rendering the filtered image.
*/
processImmediate = async () => {
if (!this.$isFiltering.get()) {
this.log.warn('Cannot process filter when not initialized');
return;
}
const config = this.$filterConfig.get();
const filterData = IMAGE_FILTERS[config.type];
@@ -342,7 +355,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
});
// Final cleanup and teardown, returning user to main canvas UI
this.resetEphemeralState();
this.teardown();
};
@@ -401,7 +413,7 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
this.imageModule.destroy();
this.imageModule = null;
}
const initialFilterConfig = this.$initialFilterConfig.get() ?? this.createInitialFilterConfig();
const initialFilterConfig = deepClone(this.$initialFilterConfig.get() ?? this.createInitialFilterConfig());
this.$filterConfig.set(initialFilterConfig);
this.$imageState.set(null);
this.$lastProcessedHash.set('');
@@ -409,9 +421,11 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
};
teardown = () => {
this.$initialFilterConfig.set(null);
this.konva.group.remove();
this.unsubscribe();
this.konva.group.remove();
// The reset must be done _after_ unsubscribing from listeners, in case the listeners would otherwise react to
// the reset. For example, if auto-processing is enabled and we reset the state, it may trigger processing.
this.resetEphemeralState();
this.$isFiltering.set(false);
this.manager.stateApi.$filteringAdapter.set(null);
};
@@ -428,7 +442,6 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
cancel = () => {
this.log.trace('Canceling');
this.resetEphemeralState();
this.teardown();
};

View File

@@ -1,4 +1,5 @@
import { Mutex } from 'async-mutex';
import { parseify } from 'common/util/serialize';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { getPrefixedId, loadImage } from 'features/controlLayers/konva/util';
@@ -26,13 +27,13 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
group: Konva.Group;
image: Konva.Image | null; // The image is loaded asynchronously, so it may not be available immediately
};
isLoading: boolean = false;
isError: boolean = false;
$isLoading = atom<boolean>(false);
$isError = atom<boolean>(false);
imageElement: HTMLImageElement | null = null;
subscriptions = new Set<() => void>();
$lastProgressEvent = atom<ProgressEventWithImage | null>(null);
hasActiveGeneration: boolean = false;
$hasActiveGeneration = atom<boolean>(false);
mutex: Mutex = new Mutex();
constructor(manager: CanvasManager) {
@@ -56,12 +57,9 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
this.subscriptions.add(
this.manager.stateApi.createStoreSubscription(selectCanvasQueueCounts, ({ data }) => {
if (data && (data.in_progress > 0 || data.pending > 0)) {
this.hasActiveGeneration = true;
this.$hasActiveGeneration.set(true);
} else {
this.hasActiveGeneration = false;
if (!this.manager.stagingArea.$isStaging.get()) {
this.$lastProgressEvent.set(null);
}
this.$hasActiveGeneration.set(false);
}
})
);
@@ -76,23 +74,36 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
if (!isProgressEventWithImage(data)) {
return;
}
if (!this.hasActiveGeneration) {
if (!this.$hasActiveGeneration.get()) {
return;
}
this.$lastProgressEvent.set(data);
};
// Handle a canceled or failed canvas generation. We should clear the progress image in this case.
const queueItemStatusChangedListener = (data: S['QueueItemStatusChangedEvent']) => {
if (data.destination !== 'canvas') {
return;
}
if (data.status === 'failed' || data.status === 'canceled') {
this.$lastProgressEvent.set(null);
this.$hasActiveGeneration.set(false);
}
};
const clearProgress = () => {
this.$lastProgressEvent.set(null);
};
this.manager.socket.on('invocation_progress', progressListener);
this.manager.socket.on('queue_item_status_changed', queueItemStatusChangedListener);
this.manager.socket.on('connect', clearProgress);
this.manager.socket.on('connect_error', clearProgress);
this.manager.socket.on('disconnect', clearProgress);
return () => {
this.manager.socket.off('invocation_progress', progressListener);
this.manager.socket.off('queue_item_status_changed', queueItemStatusChangedListener);
this.manager.socket.off('connect', clearProgress);
this.manager.socket.off('connect_error', clearProgress);
this.manager.socket.off('disconnect', clearProgress);
@@ -114,13 +125,13 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
this.konva.image?.destroy();
this.konva.image = null;
this.imageElement = null;
this.isLoading = false;
this.isError = false;
this.$isLoading.set(false);
this.$isError.set(false);
release();
return;
}
this.isLoading = true;
this.$isLoading.set(true);
const { x, y, width, height } = this.manager.stateApi.getBbox().rect;
try {
@@ -149,9 +160,9 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
// Should not be visible if the user has disabled showing staging images
this.konva.group.visible(this.manager.stagingArea.$shouldShowStagedImage.get());
} catch {
this.isError = true;
this.$isError.set(true);
} finally {
this.isLoading = false;
this.$isLoading.set(false);
release();
}
};
@@ -162,4 +173,16 @@ export class CanvasProgressImageModule extends CanvasModuleBase {
this.subscriptions.clear();
this.konva.group.destroy();
};
repr = () => {
return {
id: this.id,
type: this.type,
path: this.path,
$lastProgressEvent: parseify(this.$lastProgressEvent.get()),
$hasActiveGeneration: this.$hasActiveGeneration.get(),
$isError: this.$isError.get(),
$isLoading: this.$isLoading.get(),
};
};
}

View File

@@ -535,6 +535,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
* Processes the SAM points to segment the entity, updating the module's state and rendering the mask.
*/
processImmediate = async () => {
if (!this.$isSegmenting.get()) {
this.log.warn('Cannot process segmentation when not initialized');
return;
}
if (this.$isProcessing.get()) {
this.log.warn('Already processing');
return;
@@ -689,7 +694,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
});
// Final cleanup and teardown, returning user to main canvas UI
this.resetEphemeralState();
this.teardown();
};
@@ -758,7 +762,6 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
cancel = () => {
this.log.trace('Canceling');
// Reset the module's state and tear down, returning user to main canvas UI
this.resetEphemeralState();
this.teardown();
};
@@ -773,8 +776,11 @@ export class CanvasSegmentAnythingModule extends CanvasModuleBase {
* - Resets the global segmenting adapter
*/
teardown = () => {
this.konva.group.remove();
this.unsubscribe();
this.konva.group.remove();
// The reset must be done _after_ unsubscribing from listeners, in case the listeners would otherwise react to
// the reset. For example, if auto-processing is enabled and we reset the state, it may trigger processing.
this.resetEphemeralState();
this.$isSegmenting.set(false);
this.manager.stateApi.$segmentingAdapter.set(null);
};

View File

@@ -456,14 +456,14 @@ const PROCESSOR_TO_FILTER_MAP: Record<string, FilterType> = {
*/
export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapterModelConfig | null) => {
if (!modelConfig) {
// No model, use the default filter
return IMAGE_FILTERS.canny_edge_detection;
// No model
return null;
}
const preprocessor = modelConfig?.default_settings?.preprocessor;
if (!preprocessor) {
// No preprocessor, use the default filter
return IMAGE_FILTERS.canny_edge_detection;
// No preprocessor
return null;
}
if (isFilterType(preprocessor)) {
@@ -473,8 +473,8 @@ export const getFilterForModel = (modelConfig: ControlNetModelConfig | T2IAdapte
const filterName = PROCESSOR_TO_FILTER_MAP[preprocessor];
if (!filterName) {
// No filter found, use the default filter
return IMAGE_FILTERS.canny_edge_detection;
// No filter found
return null;
}
// Found a filter, use it

View File

@@ -78,8 +78,8 @@ export const initialT2IAdapter: T2IAdapterConfig = {
export const initialControlNet: ControlNetConfig = {
type: 'controlnet',
model: null,
weight: 1,
beginEndStepPct: [0, 1],
weight: 0.75,
beginEndStepPct: [0, 0.75],
controlMode: 'balanced',
};

View File

@@ -27,6 +27,8 @@ export const DeleteImageButton = memo((props: DeleteImageButtonProps) => {
aria-label={labelMessage}
isDisabled={isDisabled || !isConnected}
colorScheme="error"
variant="link"
alignSelf="stretch"
/>
);
});

View File

@@ -1,6 +1,7 @@
import { Button, Collapse, Flex, Icon, Text, useDisclosure } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { fixTooltipCloseOnScrollStyles } from 'common/util/fixTooltipCloseOnScrollStyles';
import {
selectBoardSearchText,
selectListBoardsQueryArgs,
@@ -104,7 +105,7 @@ export const BoardsList = ({ isPrivate }: Props) => {
)}
<AddBoardButton isPrivateBoard={isPrivate} />
</Flex>
<Collapse in={isOpen}>
<Collapse in={isOpen} style={fixTooltipCloseOnScrollStyles}>
<Flex direction="column" gap={1}>
{boardElements.length ? (
boardElements

View File

@@ -22,7 +22,7 @@ import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/us
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { Trans, useTranslation } from 'react-i18next';
import { PiArrowsOutBold, PiQuestion, PiSwapBold } from 'react-icons/pi';
import { PiArrowsLeftRightBold, PiArrowsOutBold, PiQuestion } from 'react-icons/pi';
export const CompareToolbar = memo(() => {
const { t } = useTranslation();
@@ -60,14 +60,16 @@ export const CompareToolbar = memo(() => {
useRegisteredHotkeys({ id: 'nextComparisonMode', category: 'viewer', callback: nextMode, dependencies: [nextMode] });
return (
<Flex w="full" gap={2}>
<Flex w="full" px={2} gap={2} bg="base.750" borderTopRadius="base" h={12}>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineEnd="auto">
<Flex marginInlineEnd="auto" alignItems="center">
<IconButton
icon={<PiSwapBold />}
icon={<PiArrowsLeftRightBold />}
aria-label={`${t('gallery.swapImages')} (C)`}
tooltip={`${t('gallery.swapImages')} (C)`}
onClick={swapImages}
variant="link"
alignSelf="stretch"
/>
{comparisonMode !== 'side-by-side' && (
<IconButton
@@ -75,14 +77,15 @@ export const CompareToolbar = memo(() => {
tooltip={t('gallery.stretchToFit')}
onClick={toggleComparisonFit}
colorScheme={comparisonFit === 'fill' ? 'invokeBlue' : 'base'}
variant="outline"
variant="link"
alignSelf="stretch"
icon={<PiArrowsOutBold />}
/>
)}
</Flex>
</Flex>
<Flex flex={1} gap={4} justifyContent="center">
<ButtonGroup variant="outline">
<Flex flex={1} justifyContent="center">
<ButtonGroup variant="outline" alignItems="center">
<Button
flexShrink={0}
onClick={setComparisonModeSlider}
@@ -110,11 +113,13 @@ export const CompareToolbar = memo(() => {
<Flex gap={2} marginInlineStart="auto" alignItems="center">
<Tooltip label={<CompareHelp />}>
<Flex alignItems="center">
<Icon boxSize={6} color="base.500" as={PiQuestion} lineHeight={0} />
<Icon boxSize={6} color="base.300" as={PiQuestion} lineHeight={0} />
</Flex>
</Tooltip>
<Button
variant="ghost"
variant="link"
alignSelf="stretch"
px={2}
aria-label={`${t('gallery.exitCompare')} (Esc)`}
tooltip={`${t('gallery.exitCompare')} (Esc)`}
onClick={exitCompare}

View File

@@ -1,4 +1,4 @@
import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
import { Divider, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
@@ -46,73 +46,81 @@ const CurrentImageButtonsContent = memo(({ imageDTO }: { imageDTO: ImageDTO }) =
return (
<>
<ButtonGroup>
<Menu isLazy>
<MenuButton
as={IconButton}
aria-label={t('parameters.imageActions')}
tooltip={t('parameters.imageActions')}
isDisabled={!imageDTO}
icon={<PiDotsThreeOutlineFill />}
/>
<MenuList>{imageDTO && <SingleSelectionMenuItems imageDTO={imageDTO} />}</MenuList>
</Menu>
</ButtonGroup>
<Menu isLazy>
<MenuButton
as={IconButton}
aria-label={t('parameters.imageActions')}
tooltip={t('parameters.imageActions')}
isDisabled={!imageDTO}
variant="link"
alignSelf="stretch"
icon={<PiDotsThreeOutlineFill />}
/>
<MenuList>{imageDTO && <SingleSelectionMenuItems imageDTO={imageDTO} />}</MenuList>
</Menu>
<ButtonGroup>
<IconButton
icon={<PiFlowArrowBold />}
tooltip={`${t('nodes.loadWorkflow')} (W)`}
aria-label={`${t('nodes.loadWorkflow')} (W)`}
isDisabled={!imageActions.hasWorkflow || !hasTemplates}
onClick={imageActions.loadWorkflow}
/>
<IconButton
icon={<PiArrowsCounterClockwiseBold />}
tooltip={`${t('parameters.remixImage')} (R)`}
aria-label={`${t('parameters.remixImage')} (R)`}
isDisabled={!imageActions.hasMetadata}
onClick={imageActions.remix}
/>
<IconButton
icon={<PiQuotesBold />}
tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!imageActions.hasPrompts}
onClick={imageActions.recallPrompts}
/>
<IconButton
icon={<PiPlantBold />}
tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!imageActions.hasSeed}
onClick={imageActions.recallSeed}
/>
<IconButton
icon={<PiRulerBold />}
tooltip={`${t('parameters.useSize')} (D)`}
aria-label={`${t('parameters.useSize')} (D)`}
onClick={imageActions.recallSize}
isDisabled={isStaging}
/>
<IconButton
icon={<PiAsteriskBold />}
tooltip={`${t('parameters.useAll')} (A)`}
aria-label={`${t('parameters.useAll')} (A)`}
isDisabled={!imageActions.hasMetadata}
onClick={imageActions.recallAll}
/>
</ButtonGroup>
<Divider orientation="vertical" h={8} mx={2} />
{isUpscalingEnabled && (
<ButtonGroup>
<PostProcessingPopover imageDTO={imageDTO} />
</ButtonGroup>
)}
<IconButton
icon={<PiFlowArrowBold />}
tooltip={`${t('nodes.loadWorkflow')} (W)`}
aria-label={`${t('nodes.loadWorkflow')} (W)`}
isDisabled={!imageActions.hasWorkflow || !hasTemplates}
variant="link"
alignSelf="stretch"
onClick={imageActions.loadWorkflow}
/>
<IconButton
icon={<PiArrowsCounterClockwiseBold />}
tooltip={`${t('parameters.remixImage')} (R)`}
aria-label={`${t('parameters.remixImage')} (R)`}
isDisabled={!imageActions.hasMetadata}
variant="link"
alignSelf="stretch"
onClick={imageActions.remix}
/>
<IconButton
icon={<PiQuotesBold />}
tooltip={`${t('parameters.usePrompt')} (P)`}
aria-label={`${t('parameters.usePrompt')} (P)`}
isDisabled={!imageActions.hasPrompts}
variant="link"
alignSelf="stretch"
onClick={imageActions.recallPrompts}
/>
<IconButton
icon={<PiPlantBold />}
tooltip={`${t('parameters.useSeed')} (S)`}
aria-label={`${t('parameters.useSeed')} (S)`}
isDisabled={!imageActions.hasSeed}
variant="link"
alignSelf="stretch"
onClick={imageActions.recallSeed}
/>
<IconButton
icon={<PiRulerBold />}
tooltip={`${t('parameters.useSize')} (D)`}
aria-label={`${t('parameters.useSize')} (D)`}
variant="link"
alignSelf="stretch"
onClick={imageActions.recallSize}
isDisabled={isStaging}
/>
<IconButton
icon={<PiAsteriskBold />}
tooltip={`${t('parameters.useAll')} (A)`}
aria-label={`${t('parameters.useAll')} (A)`}
isDisabled={!imageActions.hasMetadata}
variant="link"
alignSelf="stretch"
onClick={imageActions.recallAll}
/>
<ButtonGroup>
<DeleteImageButton onClick={imageActions.delete} />
</ButtonGroup>
{isUpscalingEnabled && <PostProcessingPopover imageDTO={imageDTO} />}
<Divider orientation="vertical" h={8} mx={2} />
<DeleteImageButton onClick={imageActions.delete} />
</>
);
});

View File

@@ -37,7 +37,6 @@ export const ImageViewer = memo(({ closeButton }: Props) => {
ref={ref}
tabIndex={-1}
layerStyle="first"
p={2}
borderRadius="base"
position="absolute"
flexDirection="column"
@@ -51,7 +50,7 @@ export const ImageViewer = memo(({ closeButton }: Props) => {
>
{hasImageToCompare && <CompareToolbar />}
{!hasImageToCompare && <ViewerToolbar closeButton={closeButton} />}
<Box ref={containerRef} w="full" h="full">
<Box ref={containerRef} w="full" h="full" p={2}>
{!hasImageToCompare && <CurrentImagePreview />}
{hasImageToCompare && <ImageComparison containerDims={containerDims} />}
</Box>
@@ -84,7 +83,8 @@ const ImageViewerCloseButton = memo(() => {
tooltip={t('gallery.closeViewer')}
aria-label={t('gallery.closeViewer')}
icon={<PiXBold />}
variant="ghost"
variant="link"
alignSelf="stretch"
onClick={imageViewer.close}
/>
);

View File

@@ -38,7 +38,8 @@ export const ToggleMetadataViewerButton = memo(() => {
aria-label={`${t('parameters.info')} (I)`}
onClick={toggleMetadataViewer}
isDisabled={!imageDTO}
variant="outline"
variant="link"
alignSelf="stretch"
colorScheme={shouldShowImageDetails ? 'invokeBlue' : 'base'}
data-testid="toggle-show-metadata-button"
/>

View File

@@ -21,7 +21,8 @@ export const ToggleProgressButton = memo(() => {
tooltip={t('settings.displayInProgress')}
icon={<PiHourglassHighBold />}
onClick={onClick}
variant="outline"
variant="link"
alignSelf="stretch"
colorScheme={shouldShowProgressInViewer ? 'invokeBlue' : 'base'}
data-testid="toggle-show-progress-button"
/>

View File

@@ -12,18 +12,18 @@ type Props = {
export const ViewerToolbar = memo(({ closeButton }: Props) => {
return (
<Flex w="full" gap={2}>
<Flex w="full" px={2} gap={2} bg="base.750" borderTopRadius="base" h={12}>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineEnd="auto">
<Flex marginInlineEnd="auto" alignItems="center">
<ToggleProgressButton />
<ToggleMetadataViewerButton />
</Flex>
</Flex>
<Flex flex={1} gap={2} justifyContent="center">
<Flex flex={1} justifyContent="center" alignItems="center">
<CurrentImageButtons />
</Flex>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto">
<Flex marginInlineStart="auto" alignItems="center">
{closeButton}
</Flex>
</Flex>

View File

@@ -11,6 +11,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
any: 'base',
'sd-1': 'green',
'sd-2': 'teal',
'sd-3': 'purple',
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
flux: 'gold',

View File

@@ -80,19 +80,19 @@ const ModelList = () => {
[clipVisionModels, searchTerm, filteredModelType]
);
const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels();
const [vaeModels, { isLoading: isLoadingVAEModels }] = useVAEModels({ excludeSubmodels: true });
const filteredVAEModels = useMemo(
() => modelsFilter(vaeModels, searchTerm, filteredModelType),
[vaeModels, searchTerm, filteredModelType]
);
const [t5EncoderModels, { isLoading: isLoadingT5EncoderModels }] = useT5EncoderModels();
const [t5EncoderModels, { isLoading: isLoadingT5EncoderModels }] = useT5EncoderModels({ excludeSubmodels: true });
const filteredT5EncoderModels = useMemo(
() => modelsFilter(t5EncoderModels, searchTerm, filteredModelType),
[t5EncoderModels, searchTerm, filteredModelType]
);
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels();
const [clipEmbedModels, { isLoading: isLoadingClipEmbedModels }] = useCLIPEmbedModels({ excludeSubmodels: true });
const filteredClipEmbedModels = useMemo(
() => modelsFilter(clipEmbedModels, searchTerm, filteredModelType),
[clipEmbedModels, searchTerm, filteredModelType]

View File

@@ -8,6 +8,10 @@ import {
isBooleanFieldInputTemplate,
isCLIPEmbedModelFieldInputInstance,
isCLIPEmbedModelFieldInputTemplate,
isCLIPGEmbedModelFieldInputInstance,
isCLIPGEmbedModelFieldInputTemplate,
isCLIPLEmbedModelFieldInputInstance,
isCLIPLEmbedModelFieldInputTemplate,
isColorFieldInputInstance,
isColorFieldInputTemplate,
isControlNetModelFieldInputInstance,
@@ -34,6 +38,8 @@ import {
isModelIdentifierFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isSD3MainModelFieldInputInstance,
isSD3MainModelFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
@@ -54,6 +60,8 @@ import { memo } from 'react';
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
import CLIPGEmbedModelFieldInputComponent from './inputs/CLIPGEmbedModelFieldInputComponent';
import CLIPLEmbedModelFieldInputComponent from './inputs/CLIPLEmbedModelFieldInputComponent';
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
@@ -66,6 +74,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent'
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import SpandrelImageToImageModelFieldInputComponent from './inputs/SpandrelImageToImageModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
@@ -132,6 +141,14 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isCLIPLEmbedModelFieldInputInstance(fieldInstance) && isCLIPLEmbedModelFieldInputTemplate(fieldTemplate)) {
return <CLIPLEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isCLIPGEmbedModelFieldInputInstance(fieldInstance) && isCLIPGEmbedModelFieldInputTemplate(fieldTemplate)) {
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
@@ -168,10 +185,15 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@@ -39,14 +39,15 @@ const CLIPEmbedModelFieldInputComponent = (props: Props) => {
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
<Combobox
value={value}
placeholder={placeholder}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}

View File

@@ -0,0 +1,62 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldCLIPGEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPGEmbedModelConfig, isCLIPGEmbedModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CLIPGEmbedModelFieldInputInstance, CLIPGEmbedModelFieldInputTemplate>;
const CLIPGEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
(value: CLIPGEmbedModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldCLIPGEmbedValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};
export default memo(CLIPGEmbedModelFieldInputComponent);

View File

@@ -0,0 +1,62 @@
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldCLIPLEmbedValueChanged } from 'features/nodes/store/nodesSlice';
import type { CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import { type CLIPLEmbedModelConfig, isCLIPLEmbedModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<CLIPLEmbedModelFieldInputInstance, CLIPLEmbedModelFieldInputTemplate>;
const CLIPLEmbedModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
(value: CLIPLEmbedModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldCLIPLEmbedValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs: modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config)),
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
const required = props.fieldTemplate.required;
return (
<Flex w="full" alignItems="center" gap={2}>
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value && required}>
<Combobox
value={value}
placeholder={required ? placeholder : `(Optional) ${placeholder}`}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Tooltip>
</Flex>
);
};
export default memo(CLIPLEmbedModelFieldInputComponent);

View File

@@ -0,0 +1,59 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSD3Models } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;
const SD3MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSD3Models();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl
className="nowheel nodrag"
isDisabled={!options.length}
isInvalid={!value && props.fieldTemplate.required}
>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(SD3MainModelFieldInputComponent);

Some files were not shown because too many files have changed in this diff Show More