Compare commits

...

62 Commits

Author SHA1 Message Date
psychedelicious
ceae1dc04f chore: bump version to v5.4.1 2024-11-15 11:21:24 +11:00
psychedelicious
4b390906bc fix(ui): multiple selection dnd sometimes doesn't get full selection
Turns out a gallery image's `imageDTO` object can actually be a different object by reference, even for the same image. I thought this was not possible thanks to our quasi-normalized cache.

We need to check against the image name instead of reference equality when deciding whether to use the single image or the gallery selection for the dnd payload.
2024-11-15 11:21:03 +11:00
psychedelicious
c5b8efe03b fix(ui): unable to use text inputs within draggable 2024-11-15 10:25:30 +11:00
psychedelicious
4d08d00ad8 chore(ui): knip 2024-11-14 13:38:40 -08:00
psychedelicious
9b0130262b fix(ui): use silent upload for single-image upload buttons 2024-11-14 13:38:40 -08:00
psychedelicious
878093f64e fix(ui): image upload handling
Rework uploadImage and uploadImages helpers and the RTK listener, ensuring gallery view isn't changed unexpectedly and preventing extraneous toasts.

Fix the staging area save-to-gallery button so it makes a copy of the image instead of changing its intermediate status.
2024-11-14 13:38:40 -08:00
psychedelicious
d5ff7ef250 feat(ui): update output only masked regions
- New name: "Output only Generated Regions"
- New default: true (this was the intention, but at some point the behaviour of the setting was inverted without the default being changed)
2024-11-14 13:35:55 -08:00
psychedelicious
f36583f866 feat(ui): tweak image selection/hover styling
The gallery styling for selected vs hovered images was very similar, leading users to think that a hovered image was also selected.

Reducing the borders for hovered images to a single pixel makes it easier to distinguish between selected and hovered.
2024-11-14 16:28:53 -05:00
psychedelicious
829bc1bc7d feat(ui): progress alert config setting
- Add `invocationProgressAlert` as a disable-able feature. Hide the alert and the setting in system settings when disabled.
- Fix merge conflict
2024-11-15 05:49:05 +11:00
Mary Hipp
17c7b57145 (ui): make detailed progress view a setting that can be hidden 2024-11-15 05:49:05 +11:00
psychedelicious
6a12189542 feat(ui): updated progress event display
- Tweak layout/styling of alerts for consistent spacing
- Add the percentage to the message when one is available
- Only show events if the destination is canvas (so workflow events are hidden, for example)
2024-11-15 05:49:05 +11:00
psychedelicious
96a31a5563 feat(app): add more events when loading/running models 2024-11-15 05:49:05 +11:00
psychedelicious
067747eca9 feat(app): tweak model load events
- Pass in the `UtilInterface` to the `ModelsInterface` so we can call the simple `signal_progress` method instead of the complicated `emit_invocation_progress` method.
- Only emit load events when starting to load - not after.
- Add more detail to the messages, like the submodel type (a condensed sketch of the message format follows below)
2024-11-15 05:49:05 +11:00
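The shape of the new load messages can be seen in the `ModelsInterface` hunks further down. A condensed sketch of the pattern, with the surrounding service wiring assumed (`util` stands in for the `UtilInterface` passed in above):

```python
from enum import Enum
from typing import Optional


def signal_model_load(util, model_name: str, submodel_type: Optional[Enum] = None) -> None:
    """Emit one progress event when a model load starts; nothing is emitted after the load finishes."""
    message = f"Loading model {model_name}"
    if submodel_type:
        # In InvokeAI this is a SubModelType enum, e.g. "Loading model sd3-medium (vae)".
        message += f" ({submodel_type.value})"
    # signal_progress is the simple progress API on the UtilInterface mentioned above.
    util.signal_progress(message)
```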
Mary Hipp
c7878fddc6 (pytest) mock emit_invocation_progress on events service 2024-11-15 05:49:05 +11:00
maryhipp
54c51e0a06 (worker) add progress images for downloading remote models 2024-11-15 05:49:05 +11:00
Mary Hipp
1640ea0298 (pytest) add missing arg for mocked context 2024-11-15 05:49:05 +11:00
Mary Hipp
0c32ae9775 (pytest) fix import 2024-11-15 05:49:05 +11:00
maryhipp
fdb8ca5165 (worker) use source if name is not available 2024-11-15 05:49:05 +11:00
Mary Hipp
571faf6d7c (pytest) add queue_item and invocation to data in context for test 2024-11-15 05:49:05 +11:00
Mary Hipp
bdbdb22b74 (ui) add Canvas Alert for invocation progress messages 2024-11-15 05:49:05 +11:00
maryhipp
9bbb5644af (worker) add invocation_progress events to model loading 2024-11-15 05:49:05 +11:00
Mary Hipp
e90ad19f22 (ui): update en string for full IP adapter 2024-11-14 10:07:42 -08:00
Ryan Dick
0ba11e8f73 SD3 Image-to-Image and Inpainting (#7295)
## Summary

Add support for SD3 image-to-image and inpainting. Similar to FLUX, the
implementation supports fractional denoise_start/denoise_end for more
fine-grained denoise strength control, and a gradient mask adjustment
schedule for smoother inpainting seams.

## Example
Workflow
<img width="1016" alt="image"
src="https://github.com/user-attachments/assets/ee598d77-be80-4ca7-9355-c3cbefa2ef43">

Result

![image](https://github.com/user-attachments/assets/43953fa7-0e4e-42b5-84e8-85cfeeeee00b)

## QA Instructions

- [x] Regression test of text-to-image
- [x] Test image-to-image without mask
- [x] Test that adjusting denoising_start allows fine-grained control of the
amount of change in image-to-image
- [x] Test inpainting with mask
- [x] Smoke test SD1, SDXL, FLUX image-to-image to make sure there was
no regression with the frontend changes.

## Merge Plan


## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2024-11-14 09:33:51 -08:00
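For orientation before the SD3 denoise diff further down: the image-to-image setup described above amounts to building a flow-matching timestep schedule, clipping it by the fractional denoising_start/denoising_end values, and noising the init latents to the first kept timestep. A minimal sketch follows; the fractional clipping is approximated here with a simple index slice, whereas InvokeAI reuses `clip_timestep_schedule_fractional` from FLUX (visible in the diff):

```python
import torch


def prepare_sd3_timesteps_and_latents(
    init_latents: torch.Tensor | None,
    noise: torch.Tensor,
    steps: int,
    denoising_start: float,
    denoising_end: float,
) -> tuple[list[float], torch.Tensor]:
    # Full schedule from t=1 (pure noise) to t=0 (clean image), with one extra
    # entry so the denoise loop can walk (t_curr, t_prev) pairs.
    timesteps: list[float] = torch.linspace(1, 0, steps + 1).tolist()

    # Approximate the fractional clipping by slicing to the nearest schedule entries.
    start_idx = round(denoising_start * steps)
    end_idx = round(denoising_end * steps)
    timesteps = timesteps[start_idx : end_idx + 1]

    if init_latents is not None:
        # Image-to-image: noise the init latents by the amount implied by the first timestep.
        t_0 = timesteps[0]
        latents = t_0 * noise + (1.0 - t_0) * init_latents
    else:
        # Text-to-image: start from pure noise (denoising_start should be 0 in this case).
        latents = noise
    return timesteps, latents
```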
Ryan Dick
1cf7600f5b Merge branch 'main' into ryan/sd3-image-to-image 2024-11-14 09:25:23 -08:00
Ryan Dick
4f9d12b872 Fix FLUX diffusers LoRA models with no .proj_mlp layers (#7313)
## Summary

Add support for FLUX diffusers LoRA models without `.proj_mlp` layers.

## Related Issues / Discussions

Closes #7129 

## QA Instructions

- [x] FLUX diffusers LoRA **without .proj_mlp** layers
- [x] FLUX diffusers LoRA **with .proj_mlp** layers
- [x] FLUX diffusers LoRA **without .proj_mlp** layers, quantized base
model
- [x] FLUX diffusers LoRA **with .proj_mlp** layers, quantized base
model

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
2024-11-14 09:09:10 -08:00
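The loader handles a missing `.proj_mlp` entry by substituting all-zero LoRA weights of the right shape, so the concatenated `single_blocks.{i}.linear1` layer keeps its expected dimensions (see the `lora_model_from_flux_diffusers_state_dict` hunk below). A stripped-down sketch of that substitution:

```python
import torch


def zero_lora_values(out_features: int, in_features: int) -> dict[str, torch.Tensor]:
    """Rank-1 all-zero LoRA weights for a layer absent from the state dict.

    The product lora_up @ lora_down is an all-zero (out_features, in_features)
    matrix, so the substitute contributes nothing to the output while keeping
    the shapes of the concatenated QKV+MLP layer consistent.
    """
    return {
        "lora_up.weight": torch.zeros((out_features, 1)),
        "lora_down.weight": torch.zeros((1, in_features)),
    }
```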
Ryan Dick
68c3b0649b Add unit tests for FLUX diffusers LoRA without .proj_mlp layers. 2024-11-14 16:53:49 +00:00
Ryan Dick
8ef8bd4261 Add state dict tensor shapes for existing LoRA unit tests. 2024-11-14 16:53:49 +00:00
Ryan Dick
50897ba066 Add flag to optionally allow missing layer keys in FLUX lora loader. 2024-11-14 16:53:49 +00:00
Ryan Dick
3510643870 Support FLUX LoRAs without .proj_mlp layers. 2024-11-14 16:53:49 +00:00
Ryan Dick
ca9cb1c9ef Flux VAE broken for float16; force bfloat16 or float32 where compatible (#7213)
## Summary

The Flux VAE, like many VAEs, is broken when run with float16 inputs,
returning black images due to NaNs.
This fixes the issue by forcing the VAE to run in bfloat16 or float32,
whichever is compatible.

## Related Issues / Discussions

Fix for issue https://github.com/invoke-ai/InvokeAI/issues/7208

## QA Instructions

Tested on macOS: the VAE works with float16 set in invoke.yaml and with the
setting left at its default.
I also briefly forced it down the float32 route to check that too.
Needs testing on CUDA / ROCm.

## Merge Plan

It should be a straightforward merge.
2024-11-13 15:51:40 -08:00
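The eventual fix (see the `FluxVAELoader` hunk further down) never lets the VAE run in float16: it probes whether the device accepts bfloat16 and otherwise falls back to float32. Roughly:

```python
import torch


def choose_flux_vae_dtype(requested: torch.dtype, device: torch.device) -> torch.dtype:
    """The FLUX VAE returns NaNs (black images) in float16, so never use it there."""
    if requested != torch.float16:
        return requested
    try:
        # Probe bfloat16 support by allocating a tiny tensor on the target device.
        return torch.tensor([1.0], dtype=torch.bfloat16, device=device).dtype
    except TypeError:
        return torch.float32
```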
Ryan Dick
b89caa02bd Merge branch 'main' into flux_vae_fp16_broke 2024-11-13 15:33:43 -08:00
Ryan Dick
eaf4e08c44 Use vae.parameters() for more efficient access to the first model parameter. 2024-11-13 23:32:40 +00:00
Darrell
fb19621361 Updated link to flux ip adapter model 2024-11-12 08:11:40 -05:00
Mary Hipp
9179619077 actually use optimized denoising 2024-11-08 20:46:08 -05:00
Mary Hipp
13cb5f0ba2 Merge remote-tracking branch 'origin/main' into ryan/sd3-image-to-image 2024-11-08 20:29:56 -05:00
Mary Hipp
7e52fc1c17 Merge branch 'ryan/sd3-image-to-image' of https://github.com/invoke-ai/InvokeAI into ryan/sd3-image-to-image 2024-11-08 20:14:24 -05:00
Mary Hipp
7f60a4a282 (ui): update more generation settings for SD3 linear UI 2024-11-08 20:14:13 -05:00
psychedelicious
3f880496f7 feat(ui): clarify denoising strength badge text 2024-11-09 08:38:41 +11:00
Ryan Dick
f05efd3270 Fix import for getInfill. 2024-11-08 20:42:44 +00:00
psychedelicious
79eb8172b6 feat(ui): update warnings on upscaling tab based on model arch
When an unsupported model architecture is selected, show that warning only, without the extra warnings (i.e. no "missing tile controlnet" warning)

Update Invoke tooltip warnings accordingly

Closes #7239
Closes #7177
2024-11-09 07:34:03 +11:00
Ryan Dick
7732b5d478 Fix bug related to i2l nodes during graph construction of image-to-image workflows. 2024-11-08 20:15:34 +00:00
Mary Hipp
a2a1934b66 Merge branch 'ryan/sd3-image-to-image' of https://github.com/invoke-ai/InvokeAI into ryan/sd3-image-to-image 2024-11-08 13:43:19 -05:00
Mary Hipp
dff6570078 (ui) SD3 support in linear UI 2024-11-08 13:42:57 -05:00
maryhipp
04e4fb63af add SD3 generation modes for metadata validation 2024-11-08 13:13:58 -05:00
Vargol
83609d5008 Merge branch 'invoke-ai:main' into flux_vae_fp16_broke 2024-11-08 10:37:31 +00:00
David Burnett
2618ed0ae7 ruff complained 2024-11-08 10:31:53 +00:00
David Burnett
bb3cedddd5 Rework change based on comments 2024-11-08 10:27:47 +00:00
psychedelicious
5b3e1593ca fix(ui): restore missing image paste handler
This logic was missed during the dnd migration.
2024-11-08 16:42:39 +11:00
psychedelicious
2d08078a7d fix(ui): fit bbox to layers math 2024-11-08 16:40:24 +11:00
Ryan Dick
0e6cb91863 Update SD3 InpaintExtension with gradient adjustment to match FLUX. 2024-11-07 22:55:30 +00:00
Ryan Dick
a0fefcd43f Switch to using a custom scheduler implementation for SD3 rather than the diffusers FlowMatchEulerDiscreteScheduler. It is easier to work with and enables us to re-use the clip_timestep_schedule_fractional() utility from FLUX. 2024-11-07 22:46:52 +00:00
Ryan Dick
a5f8c23dee Add inpainting support for SD3. 2024-11-07 20:21:43 +00:00
Ryan Dick
7bb4ea57c6 Add SD3ImageToLatentsInvocation. 2024-11-07 16:07:57 +00:00
Ryan Dick
75dc961bcb Add image-to-image support for SD3 - WIP. 2024-11-07 15:48:35 +00:00
Vargol
a9a1f6ef21 Merge branch 'invoke-ai:main' into flux_vae_fp16_broke 2024-11-07 14:02:51 +00:00
Jonathan
aa40161f26 Update flux_denoise.py
Added a bool to allow the node user to add noise into the initial latents (default) or to leave them alone; a rough sketch of the idea follows below.
2024-11-07 14:02:20 +00:00
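A rough, hypothetical sketch of the option described above; the field name and surrounding code are assumed for illustration and are not taken from the actual `flux_denoise.py` diff, which is not reproduced on this page:

```python
import torch


def prepare_initial_latents(
    init_latents: torch.Tensor,
    noise: torch.Tensor,
    t_0: float,
    add_noise: bool,  # hypothetical node field, default True
) -> torch.Tensor:
    if add_noise:
        # Default behaviour: blend noise into the init latents for the first timestep.
        return t_0 * noise + (1.0 - t_0) * init_latents
    # Opt-out: leave the provided latents untouched.
    return init_latents
```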
psychedelicious
6efa812874 chore(ui): bump version to v5.4.1rc1 2024-11-07 14:02:20 +00:00
psychedelicious
8a683f5a3c feat(ui): updated whats new handling and v5.4.1 items 2024-11-07 14:02:20 +00:00
Brandon Rising
f4b0b6a93d fix: Look in known subfolders for configs for clip variants 2024-11-07 14:02:20 +00:00
Brandon Rising
1337c33ad3 fix: Avoid downloading unsafe .bin files if a safetensors file is available 2024-11-07 14:02:20 +00:00
David Burnett
496b02a3bc Same issue affects image2image, so do the same again 2024-11-06 17:47:22 -05:00
David Burnett
7b5efc2203 Flux VAE broken for float16; force bfloat16 or float32 where compatible 2024-11-06 17:47:22 -05:00
91 changed files with 5203 additions and 3418 deletions

View File

@@ -95,6 +95,7 @@ class CompelInvocation(BaseInvocation):
ti_manager,
),
):
context.util.signal_progress("Building conditioning")
assert isinstance(text_encoder, CLIPTextModel)
assert isinstance(tokenizer, CLIPTokenizer)
compel = Compel(
@@ -191,6 +192,7 @@ class SDXLPromptInvocationBase:
ti_manager,
),
):
context.util.signal_progress("Building conditioning")
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(tokenizer, CLIPTokenizer)

View File

@@ -65,6 +65,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
context.util.signal_progress("Running VAE encoder")
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
masked_latents_name = context.tensors.save(tensor=masked_latents)

View File

@@ -131,6 +131,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
image_tensor = image_tensor.unsqueeze(0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
context.util.signal_progress("Running VAE encoder")
masked_latents = ImageToLatentsInvocation.vae_encode(
vae_info, self.fp32, self.tiled, masked_image.clone()
)

View File

@@ -71,6 +71,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
context.util.signal_progress("Running T5 encoder")
prompt_embeds = t5_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)
@@ -111,6 +112,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
context.util.signal_progress("Running CLIP encoder")
pooled_prompt_embeds = clip_encoder(prompt)
assert isinstance(pooled_prompt_embeds, torch.Tensor)

View File

@@ -41,7 +41,8 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
vae_dtype = next(iter(vae.parameters())).dtype
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
img = vae.decode(latents)
img = img.clamp(-1, 1)
@@ -53,6 +54,7 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
context.util.signal_progress("Running VAE")
image = self._vae_decode(vae_info=vae_info, latents=latents)
TorchDevice.empty_cache()

View File

@@ -44,9 +44,8 @@ class FluxVaeEncodeInvocation(BaseInvocation):
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
image_tensor = image_tensor.to(
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
)
vae_dtype = next(iter(vae.parameters())).dtype
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
latents = vae.encode(image_tensor, sample=True, generator=generator)
return latents
@@ -60,6 +59,7 @@ class FluxVaeEncodeInvocation(BaseInvocation):
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
context.util.signal_progress("Running VAE")
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")

View File

@@ -117,6 +117,7 @@ class ImageToLatentsInvocation(BaseInvocation):
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
context.util.signal_progress("Running VAE encoder")
latents = self.vae_encode(
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)

View File

@@ -60,6 +60,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE decoder")
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:

View File

@@ -147,6 +147,10 @@ GENERATION_MODES = Literal[
"flux_img2img",
"flux_inpaint",
"flux_outpaint",
"sd3_txt2img",
"sd3_img2img",
"sd3_inpaint",
"sd3_outpaint",
]

View File

@@ -1,16 +1,19 @@
from typing import Callable, Tuple
from typing import Callable, Optional, Tuple
import torch
import torchvision.transforms as tv_transforms
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
Input,
InputField,
LatentsField,
SD3ConditioningField,
WithBoard,
WithMetadata,
@@ -19,7 +22,9 @@ from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.sd3.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@@ -30,16 +35,24 @@ from invokeai.backend.util.devices import TorchDevice
title="SD3 Denoise",
tags=["image", "sd3"],
category="image",
version="1.0.0",
version="1.1.0",
classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a SD3 model."""
# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
default=None, description=FieldDescriptions.latents, input=Input.Connection
)
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
transformer: TransformerField = InputField(
description=FieldDescriptions.sd3_model,
input=Input.Connection,
title="Transformer",
description=FieldDescriptions.sd3_model, input=Input.Connection, title="Transformer"
)
positive_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
@@ -61,6 +74,41 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
"""Prepare the inpaint mask.
- Loads the mask
- Resizes if necessary
- Casts to same device/dtype as latents
Args:
context (InvocationContext): The invocation context, for loading the inpaint mask.
latents (torch.Tensor): A latent image tensor. Used to determine the target shape, device, and dtype for the
inpaint mask.
Returns:
torch.Tensor | None: Inpaint mask. Values of 0.0 represent the regions to be fully denoised, and 1.0
represent the regions to be preserved.
"""
if self.denoise_mask is None:
return None
mask = context.tensors.load(self.denoise_mask.mask_name)
# The input denoise_mask contains values in [0, 1], where 0.0 represents the regions to be fully denoised, and
# 1.0 represents the regions to be preserved.
# We invert the mask so that the regions to be preserved are 0.0 and the regions to be denoised are 1.0.
mask = 1.0 - mask
_, _, latent_height, latent_width = latents.shape
mask = tv_resize(
img=mask,
size=[latent_height, latent_width],
interpolation=tv_transforms.InterpolationMode.BILINEAR,
antialias=False,
)
mask = mask.to(device=latents.device, dtype=latents.dtype)
return mask
def _load_text_conditioning(
self,
context: InvocationContext,
@@ -170,14 +218,20 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)
# Prepare the scheduler.
scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=self.steps, device=device)
timesteps = scheduler.timesteps
assert isinstance(timesteps, torch.Tensor)
# Prepare the timestep schedule.
# We add an extra step to the end to account for the final timestep of 0.0.
timesteps: list[float] = torch.linspace(1, 0, self.steps + 1).tolist()
# Clip the timesteps schedule based on denoising_start and denoising_end.
timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
total_steps = len(timesteps) - 1
# Prepare the CFG scale list.
cfg_scale = self._prepare_cfg_scale(len(timesteps))
cfg_scale = self._prepare_cfg_scale(total_steps)
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
init_latents = init_latents.to(device=device, dtype=inference_dtype)
# Generate initial latent noise.
num_channels_latents = transformer_info.model.config.in_channels
@@ -191,9 +245,34 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
device=device,
seed=self.seed,
)
latents: torch.Tensor = noise
total_steps = len(timesteps)
# Prepare input latent image.
if init_latents is not None:
# Noise the init_latents by the appropriate amount for the first timestep.
t_0 = timesteps[0]
latents = t_0 * noise + (1.0 - t_0) * init_latents
else:
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
if self.denoising_start > 1e-5:
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
latents = noise
# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
# denoising steps.
if len(timesteps) <= 1:
return latents
# Prepare inpaint extension.
inpaint_mask = self._prep_inpaint_mask(context, latents)
inpaint_extension: InpaintExtension | None = None
if inpaint_mask is not None:
assert init_latents is not None
inpaint_extension = InpaintExtension(
init_latents=init_latents,
inpaint_mask=inpaint_mask,
noise=noise,
)
step_callback = self._build_step_callback(context)
step_callback(
@@ -210,11 +289,12 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(transformer, SD3Transformer2DModel)
# 6. Denoising loop
for step_idx, t in tqdm(list(enumerate(timesteps))):
for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
# Expand the latents if we are doing CFG.
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# Expand the timestep to match the latent model input.
timestep = t.expand(latent_model_input.shape[0])
# Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
timestep = torch.tensor([t_curr * 1000], device=device).expand(latent_model_input.shape[0])
noise_pred = transformer(
hidden_states=latent_model_input,
@@ -232,21 +312,19 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Compute the previous noisy sample x_t -> x_t-1.
latents_dtype = latents.dtype
latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
latents = latents.to(dtype=torch.float32)
latents = latents + (t_prev - t_curr) * noise_pred
latents = latents.to(dtype=latents_dtype)
# TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
# needed.
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if inpaint_extension is not None:
latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, t_prev)
step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
timestep=int(t),
timestep=int(t_curr),
latents=latents,
),
)

View File

@@ -0,0 +1,65 @@
import einops
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
@invocation(
"sd3_i2l",
title="SD3 Image to Latents",
tags=["image", "latents", "vae", "i2l", "sd3"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates latents from an image."""
image: ImageField = InputField(description="The image to encode")
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, AutoencoderKL)
vae.disable_tiling()
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
# TODO: Use seed to make sampling reproducible.
latents: torch.Tensor = image_tensor_dist.sample().to(dtype=vae.dtype)
latents = vae.config.scaling_factor * latents
return latents
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
vae_info = context.models.load(self.vae.vae)
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

View File

@@ -47,6 +47,7 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE")
assert isinstance(vae, (AutoencoderKL))
latents = latents.to(vae.device)

View File

@@ -95,6 +95,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
context.util.signal_progress("Running T5 encoder")
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))
@@ -137,6 +138,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
context.util.signal_progress("Running CLIP encoder")
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(clip_tokenizer, CLIPTokenizer)

View File

@@ -160,6 +160,10 @@ class LoggerInterface(InvocationContextInterface):
class ImagesInterface(InvocationContextInterface):
def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
super().__init__(services, data)
self._util = util
def save(
self,
image: Image,
@@ -186,6 +190,8 @@ class ImagesInterface(InvocationContextInterface):
The saved image DTO.
"""
self._util.signal_progress("Saving image")
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
metadata_ = None
if metadata:
@@ -336,6 +342,10 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
"""Common API for loading, downloading and managing models."""
def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
super().__init__(services, data)
self._util = util
def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
"""Check if a model exists.
@@ -368,11 +378,15 @@ class ModelsInterface(InvocationContextInterface):
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type)
else:
_submodel_type = submodel_type or identifier.submodel_type
submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type)
message = f"Loading model {model.name}"
if submodel_type:
message += f" ({submodel_type.value})"
self._util.signal_progress(message)
return self._services.model_manager.load.load_model(model, submodel_type)
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@@ -397,6 +411,10 @@ class ModelsInterface(InvocationContextInterface):
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
message = f"Loading model {name}"
if submodel_type:
message += f" ({submodel_type.value})"
self._util.signal_progress(message)
return self._services.model_manager.load.load_model(configs[0], submodel_type)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
@@ -467,6 +485,7 @@ class ModelsInterface(InvocationContextInterface):
Returns:
Path to the downloaded model
"""
self._util.signal_progress(f"Downloading model {source}")
return self._services.model_manager.install.download_and_cache_model(source=source)
def load_local_model(
@@ -489,6 +508,8 @@ class ModelsInterface(InvocationContextInterface):
Returns:
A LoadedModelWithoutConfig object.
"""
self._util.signal_progress(f"Loading model {model_path.name}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
def load_remote_model(
@@ -514,6 +535,8 @@ class ModelsInterface(InvocationContextInterface):
A LoadedModelWithoutConfig object.
"""
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
self._util.signal_progress(f"Loading model {source}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
@@ -707,12 +730,12 @@ def build_invocation_context(
"""
logger = LoggerInterface(services=services, data=data)
images = ImagesInterface(services=services, data=data)
tensors = TensorsInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data)
config = ConfigInterface(services=services, data=data)
util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
conditioning = ConditioningInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data, util=util)
images = ImagesInterface(services=services, data=data, util=util)
boards = BoardsInterface(services=services, data=data)
ctx = InvocationContext(

View File

@@ -45,8 +45,9 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
# Constants for FLUX.1
num_double_layers = 19
num_single_layers = 38
# inner_dim = 3072
# mlp_ratio = 4.0
hidden_size = 3072
mlp_ratio = 4.0
mlp_hidden_dim = int(hidden_size * mlp_ratio)
layers: dict[str, AnyLoRALayer] = {}
@@ -62,30 +63,43 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
assert len(src_layer_dict) == 0
def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
def add_qkv_lora_layer_if_present(
src_keys: list[str],
src_weight_shapes: list[tuple[int, int]],
dst_qkv_key: str,
allow_missing_keys: bool = False,
) -> None:
"""Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
"""
# We expect that either all src keys are present or none of them are. Verify this.
keys_present = [key in grouped_state_dict for key in src_keys]
assert all(keys_present) or not any(keys_present)
# If none of the keys are present, return early.
keys_present = [key in grouped_state_dict for key in src_keys]
if not any(keys_present):
return
src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
sub_layers: list[LoRALayer] = []
for src_layer_dict in src_layer_dicts:
values = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
values["alpha"] = torch.tensor(alpha)
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
assert len(src_layer_dict) == 0
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)
for src_key, src_weight_shape in zip(src_keys, src_weight_shapes, strict=True):
src_layer_dict = grouped_state_dict.pop(src_key, None)
if src_layer_dict is not None:
values = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
values["alpha"] = torch.tensor(alpha)
assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
assert len(src_layer_dict) == 0
else:
if not allow_missing_keys:
raise ValueError(f"Missing LoRA layer: '{src_key}'.")
values = {
"lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
"lora_down.weight": torch.zeros((1, src_weight_shape[1])),
}
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers)
# time_text_embed.timestep_embedder -> time_in.
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")
@@ -118,6 +132,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
f"transformer_blocks.{i}.attn.to_k",
f"transformer_blocks.{i}.attn.to_v",
],
[(hidden_size, hidden_size), (hidden_size, hidden_size), (hidden_size, hidden_size)],
f"double_blocks.{i}.img_attn.qkv",
)
add_qkv_lora_layer_if_present(
@@ -126,6 +141,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
f"transformer_blocks.{i}.attn.add_k_proj",
f"transformer_blocks.{i}.attn.add_v_proj",
],
[(hidden_size, hidden_size), (hidden_size, hidden_size), (hidden_size, hidden_size)],
f"double_blocks.{i}.txt_attn.qkv",
)
@@ -175,7 +191,14 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
f"single_transformer_blocks.{i}.attn.to_v",
f"single_transformer_blocks.{i}.proj_mlp",
],
[
(hidden_size, hidden_size),
(hidden_size, hidden_size),
(hidden_size, hidden_size),
(mlp_hidden_dim, hidden_size),
],
f"single_blocks.{i}.linear1",
allow_missing_keys=True,
)
# Output projections.

View File

@@ -35,6 +35,7 @@ class ModelLoader(ModelLoaderBase):
self._logger = logger
self._ram_cache = ram_cache
self._torch_dtype = TorchDevice.choose_torch_dtype()
self._torch_device = TorchDevice.choose_torch_device()
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""

View File

@@ -84,7 +84,15 @@ class FluxVAELoader(ModelLoader):
model = AutoEncoder(ae_params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
model.to(dtype=self._torch_dtype)
# VAE is broken in float16, which mps defaults to
if self._torch_dtype == torch.float16:
try:
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
except TypeError:
vae_dtype = torch.float32
else:
vae_dtype = self._torch_dtype
model.to(vae_dtype)
return model

View File

@@ -300,7 +300,7 @@ ip_adapter_sdxl = StarterModel(
ip_adapter_flux = StarterModel(
name="Standard Reference (XLabs FLUX IP-Adapter)",
base=BaseModelType.Flux,
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/flux-ip-adapter.safetensors",
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/ip_adapter.safetensors",
description="References images with a more generalized/looser degree of precision.",
type=ModelType.IPAdapter,
dependencies=[clip_vit_l_image_encoder],

View File

View File

@@ -0,0 +1,58 @@
import torch
class InpaintExtension:
"""A class for managing inpainting with SD3."""
def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
"""Initialize InpaintExtension.
Args:
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0).
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
inpainted region with the background.
noise (torch.Tensor): The noise tensor used to noise the init_latents.
"""
assert init_latents.dim() == inpaint_mask.dim() == noise.dim() == 4
assert init_latents.shape[-2:] == inpaint_mask.shape[-2:] == noise.shape[-2:]
self._init_latents = init_latents
self._inpaint_mask = inpaint_mask
self._noise = noise
def _apply_mask_gradient_adjustment(self, t_prev: float) -> torch.Tensor:
"""Applies inpaint mask gradient adjustment and returns the inpaint mask to be used at the current timestep."""
# As we progress through the denoising process, we promote gradient regions of the mask to have a full weight of
# 1.0. This helps to produce more coherent seams around the inpainted region. We experimented with a (small)
# number of promotion strategies (e.g. gradual promotion based on timestep), but found that a simple cutoff
# threshold worked well.
# We use a small epsilon to avoid any potential issues with floating point precision.
eps = 1e-4
mask_gradient_t_cutoff = 0.5
if t_prev > mask_gradient_t_cutoff:
# Early in the denoising process, use the inpaint mask as-is.
return self._inpaint_mask
else:
# After the cut-off, promote all non-zero mask values to 1.0.
mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + eps), 1.0)
return mask
def merge_intermediate_latents_with_init_latents(
self, intermediate_latents: torch.Tensor, t_prev: float
) -> torch.Tensor:
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
trajectory.
This function should be called after each denoising step.
"""
mask = self._apply_mask_gradient_adjustment(t_prev)
# Noise the init latents for the current timestep.
noised_init_latents = self._noise * t_prev + (1.0 - t_prev) * self._init_latents
# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
return intermediate_latents * mask + noised_init_latents * (1.0 - mask)

View File

@@ -174,7 +174,8 @@
"placeholderSelectAModel": "Select a model",
"reset": "Reset",
"none": "None",
"new": "New"
"new": "New",
"generating": "Generating"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -704,6 +705,8 @@
"baseModel": "Base Model",
"cancel": "Cancel",
"clipEmbed": "CLIP Embed",
"clipLEmbed": "CLIP-L Embed",
"clipGEmbed": "CLIP-G Embed",
"config": "Config",
"convert": "Convert",
"convertingModelBegin": "Converting Model. Please wait.",
@@ -997,7 +1000,7 @@
"controlNetControlMode": "Control Mode",
"copyImage": "Copy Image",
"denoisingStrength": "Denoising Strength",
"noRasterLayers": "No Raster Layers",
"disabledNoRasterContent": "Disabled (No Raster Content)",
"downloadImage": "Download Image",
"general": "General",
"guidance": "Guidance",
@@ -1137,6 +1140,7 @@
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
"showDetailedInvocationProgress": "Show Progress Details",
"showProgressInViewer": "Show Progress Images in Viewer",
"ui": "User Interface",
"clearIntermediatesDisabled": "Queue must be empty to clear intermediates",
@@ -1671,7 +1675,7 @@
"clearCaches": "Clear Caches",
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"outputOnlyMaskedRegions": "Output Only Masked Regions",
"outputOnlyMaskedRegions": "Output Only Generated Regions",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",
@@ -1787,7 +1791,7 @@
},
"ipAdapterMethod": {
"ipAdapterMethod": "IP Adapter Method",
"full": "Full",
"full": "Style and Composition",
"style": "Style Only",
"composition": "Composition Only"
},
@@ -1999,7 +2003,9 @@
"upscaleModelDesc": "Upscale (image to image) model",
"missingUpscaleInitialImage": "Missing initial image for upscaling",
"missingUpscaleModel": "Missing upscale model",
"missingTileControlNetModel": "No valid tile ControlNet models installed"
"missingTileControlNetModel": "No valid tile ControlNet models installed",
"incompatibleBaseModel": "Unsupported main model architecture for upscaling",
"incompatibleBaseModelDesc": "Upscaling is supported for SD1.5 and SDXL architecture models only. Change the main model to enable upscaling."
},
"stylePresets": {
"active": "Active",

View File

@@ -8,6 +8,7 @@ import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { toast } from 'features/toast/toast';
@@ -34,8 +35,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
let buildGraphResult: Result<
{
g: Graph;
noise: Invocation<'noise' | 'flux_denoise'>;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
noise: Invocation<'noise' | 'flux_denoise' | 'sd3_denoise'>;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'sd3_text_encoder'>;
},
Error
>;
@@ -51,6 +52,9 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
case `sd-2`:
buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager));
break;
case `sd-3`:
buildGraphResult = await withResultAsync(() => buildSD3Graph(state, manager));
break;
case `flux`:
buildGraphResult = await withResultAsync(() => buildFLUXGraph(state, manager));
break;

View File

@@ -41,29 +41,33 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
log.debug({ imageDTO }, 'Image uploaded');
if (action.meta.arg.originalArgs.silent || imageDTO.is_intermediate) {
// When a "silent" upload is requested, or the image is intermediate, we can skip all post-upload actions,
// like toasts and switching the gallery view
return;
}
const boardId = imageDTO.board_id ?? 'none';
if (action.meta.arg.originalArgs.withToast) {
const DEFAULT_UPLOADED_TOAST = {
id: 'IMAGE_UPLOADED',
title: t('toast.imageUploaded'),
status: 'success',
} as const;
const DEFAULT_UPLOADED_TOAST = {
id: 'IMAGE_UPLOADED',
title: t('toast.imageUploaded'),
status: 'success',
} as const;
// default action - just upload and alert user
if (lastUploadedToastTimeout !== null) {
window.clearTimeout(lastUploadedToastTimeout);
}
const toastApi = toast({
...DEFAULT_UPLOADED_TOAST,
title: DEFAULT_UPLOADED_TOAST.title,
description: getUploadedToastDescription(boardId, state),
duration: null, // we will close the toast manually
});
lastUploadedToastTimeout = window.setTimeout(() => {
toastApi.close();
}, 3000);
// default action - just upload and alert user
if (lastUploadedToastTimeout !== null) {
window.clearTimeout(lastUploadedToastTimeout);
}
const toastApi = toast({
...DEFAULT_UPLOADED_TOAST,
title: DEFAULT_UPLOADED_TOAST.title,
description: getUploadedToastDescription(boardId, state),
duration: null, // we will close the toast manually
});
lastUploadedToastTimeout = window.setTimeout(() => {
toastApi.close();
}, 3000);
/**
* We only want to change the board and view if this is the first upload of a batch, else we end up hijacking

View File

@@ -25,7 +25,8 @@ export type AppFeature =
| 'invocationCache'
| 'bulkDownload'
| 'starterModels'
| 'hfToken';
| 'hfToken'
| 'invocationProgressAlert';
/**
* A disable-able Stable Diffusion feature

View File

@@ -61,6 +61,11 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
log.warn('Multiple files dropped but only one allowed');
return;
}
if (files.length === 0) {
// Should never happen
log.warn('No files dropped');
return;
}
const file = files[0];
assert(file !== undefined); // should never happen
const imageDTO = await uploadImage({
@@ -68,18 +73,20 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: true,
}).unwrap();
if (onUpload) {
onUpload(imageDTO);
}
} else {
//
const imageDTOs = await uploadImages(
files.map((file) => ({
files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: false,
isFirstUploadOfBatch: i === 0,
}))
);
if (onUpload) {

View File

@@ -119,11 +119,20 @@ const createSelector = (
reasons.push({ content: i18n.t('upscaling.exceedsMaxSize') });
}
}
if (!upscale.upscaleModel) {
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
}
if (!upscale.tileControlnetModel) {
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
if (model && !['sd-1', 'sdxl'].includes(model.base)) {
// When we are using an unsupported model, do not add the other warnings
reasons.push({ content: i18n.t('upscaling.incompatibleBaseModel') });
} else {
// Using a compatible model, add all warnings
if (!model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
if (!upscale.upscaleModel) {
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
}
if (!upscale.tileControlnetModel) {
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
}
}
} else {
if (canvasIsFiltering) {

View File

@@ -9,7 +9,7 @@ import {
useAddRegionalGuidance,
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
@@ -23,6 +23,7 @@ export const CanvasAddEntityButtons = memo(() => {
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
return (
<Flex w="full" h="full" justifyContent="center" gap={4}>
@@ -36,6 +37,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addGlobalReferenceImage}
isDisabled={isSD3}
>
{t('controlLayers.globalReferenceImage')}
</Button>
@@ -61,7 +63,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalGuidance}
isDisabled={isFLUX}
isDisabled={isFLUX || isSD3}
>
{t('controlLayers.regionalGuidance')}
</Button>
@@ -73,7 +75,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalReferenceImage}
isDisabled={isFLUX}
isDisabled={isFLUX || isSD3}
>
{t('controlLayers.regionalReferenceImage')}
</Button>
@@ -88,6 +90,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addControlLayer}
isDisabled={isSD3}
>
{t('controlLayers.controlLayer')}
</Button>

View File

@@ -0,0 +1,45 @@
import { Alert, AlertDescription, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectSystemShouldShowInvocationProgressDetail } from 'features/system/store/systemSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { $invocationProgressMessage } from 'services/events/stores';
const CanvasAlertsInvocationProgressContent = memo(() => {
const { t } = useTranslation();
const invocationProgressMessage = useStore($invocationProgressMessage);
if (!invocationProgressMessage) {
return null;
}
return (
<Alert status="loading" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<AlertIcon />
<AlertTitle>{t('common.generating')}</AlertTitle>
<AlertDescription>{invocationProgressMessage}</AlertDescription>
</Alert>
);
});
CanvasAlertsInvocationProgressContent.displayName = 'CanvasAlertsInvocationProgressContent';
export const CanvasAlertsInvocationProgress = memo(() => {
const isProgressMessageAlertEnabled = useFeatureStatus('invocationProgressAlert');
const shouldShowInvocationProgressDetail = useAppSelector(selectSystemShouldShowInvocationProgressDetail);
// The alert is disabled at the system level
if (!isProgressMessageAlertEnabled) {
return null;
}
// The alert is disabled at the user level
if (!shouldShowInvocationProgressDetail) {
return null;
}
return <CanvasAlertsInvocationProgressContent />;
});
CanvasAlertsInvocationProgress.displayName = 'CanvasAlertsInvocationProgress';

View File

@@ -9,7 +9,7 @@ import {
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
@@ -24,6 +24,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
return (
<Menu>
@@ -40,7 +41,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
/>
<MenuList>
<MenuGroup title={t('controlLayers.global')}>
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage}>
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage} isDisabled={isSD3}>
{t('controlLayers.globalReferenceImage')}
</MenuItem>
</MenuGroup>
@@ -48,15 +49,15 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX}>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX || isSD3}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX}>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>
{t('controlLayers.regionalReferenceImage')}
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.layer_other')}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer} isDisabled={isSD3}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>

View File

@@ -4,6 +4,7 @@ import { attachClosestEdge, extractClosestEdge } from '@atlaskit/pragmatic-drag-
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { singleCanvasEntityDndSource } from 'features/dnd/dnd';
import { type DndListTargetState, idle } from 'features/dnd/types';
import { firefoxDndFix } from 'features/dnd/util';
import type { RefObject } from 'react';
import { useEffect, useState } from 'react';
@@ -17,6 +18,7 @@ export const useCanvasEntityListDnd = (ref: RefObject<HTMLElement>, entityIdenti
return;
}
return combine(
firefoxDndFix(element),
draggable({
element,
getInitialData() {

View File

@@ -21,6 +21,8 @@ import { GatedImageViewer } from 'features/gallery/components/ImageViewer/ImageV
import { memo, useCallback, useRef } from 'react';
import { PiDotsThreeOutlineVerticalFill } from 'react-icons/pi';
import { CanvasAlertsInvocationProgress } from './CanvasAlerts/CanvasAlertsInvocationProgress';
const MenuContent = () => {
return (
<CanvasManagerProviderGate>
@@ -84,6 +86,7 @@ export const CanvasMainPanelContent = memo(() => {
<CanvasAlertsSelectedEntityStatus />
<CanvasAlertsPreserveMask />
<CanvasAlertsSendingToGallery />
<CanvasAlertsInvocationProgress />
</Flex>
<Flex position="absolute" top={1} insetInlineEnd={1}>
<Menu>

View File

@@ -17,12 +17,15 @@ import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selectIsEnabled = createSelector(selectActiveRasterLayerEntities, (entities) => entities.length > 0);
const selectHasRasterLayersWithContent = createSelector(
selectActiveRasterLayerEntities,
(entities) => entities.length > 0
);
export const ParamDenoisingStrength = memo(() => {
const img2imgStrength = useAppSelector(selectImg2imgStrength);
const dispatch = useAppDispatch();
const isEnabled = useAppSelector(selectIsEnabled);
const hasRasterLayersWithContent = useAppSelector(selectHasRasterLayersWithContent);
const onChange = useCallback(
(v: number) => {
@@ -37,16 +40,16 @@ export const ParamDenoisingStrength = memo(() => {
const [invokeBlue300] = useToken('colors', ['invokeBlue.300']);
return (
<FormControl isDisabled={!isEnabled} p={1} justifyContent="space-between" h={8}>
<FormControl isDisabled={!hasRasterLayersWithContent} p={1} justifyContent="space-between" h={8}>
<Flex gap={3} alignItems="center">
<InformationalPopover feature="paramDenoisingStrength">
<FormLabel mr={0}>{`${t('parameters.denoisingStrength')}`}</FormLabel>
</InformationalPopover>
{isEnabled && (
{hasRasterLayersWithContent && (
<WavyLine amplitude={img2imgStrength * 10} stroke={invokeBlue300} strokeWidth={1} width={40} height={14} />
)}
</Flex>
{isEnabled ? (
{hasRasterLayersWithContent ? (
<>
<CompositeSlider
step={config.coarseStep}
@@ -70,9 +73,7 @@ export const ParamDenoisingStrength = memo(() => {
</>
) : (
<Flex alignItems="center">
<Badge opacity="0.6">
{t('common.disabled')} - {t('parameters.noRasterLayers')}
</Badge>
<Badge opacity="0.6">{t('parameters.disabledNoRasterContent')}</Badge>
</Flex>
)}
</FormControl>

View File

@@ -1,17 +1,19 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { withResultAsync } from 'common/util/result';
import { selectSelectedImage } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFloppyDiskBold } from 'react-icons/pi';
import { useAddImagesToBoardMutation, useChangeImageIsIntermediateMutation } from 'services/api/endpoints/images';
import { uploadImage } from 'services/api/endpoints/images';
const TOAST_ID = 'SAVE_STAGING_AREA_IMAGE_TO_GALLERY';
export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const selectedImage = useAppSelector(selectSelectedImage);
const [addImageToBoard] = useAddImagesToBoardMutation();
const [changeIsImageIntermediate] = useChangeImageIsIntermediateMutation();
const { t } = useTranslation();
@@ -19,21 +21,42 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
if (!selectedImage) {
return;
}
if (autoAddBoardId !== 'none') {
await addImageToBoard({ imageDTOs: [selectedImage.imageDTO], board_id: autoAddBoardId }).unwrap();
// The changeIsImageIntermediate request will use the board_id on this specific imageDTO object, so we need to
// update it before making the request - else the optimistic board updates will get out of whack.
changeIsImageIntermediate({
imageDTO: { ...selectedImage.imageDTO, board_id: autoAddBoardId },
// To save the image to the gallery, we will download it and re-upload it. This allows the user to delete the image
// from the gallery without borking the canvas, which may still need this image to exist.
const result = await withResultAsync(async () => {
// Download the image
const res = await fetch(selectedImage.imageDTO.image_url);
const blob = await res.blob();
// Create a new file from the blob, named as a copy of the original, which we will upload
const file = new File([blob], `copy_of_${selectedImage.imageDTO.image_name}`, { type: 'image/png' });
await uploadImage({
file,
// Image should show up in the Images tab
image_category: 'general',
is_intermediate: false,
// TODO(psyche): Maybe this should just save to the currently-selected board?
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
// We will do our own toast - opt out of the default handling
silent: true,
});
});
if (result.isOk()) {
toast({
id: TOAST_ID,
title: t('controlLayers.savedToGalleryOk'),
status: 'success',
});
} else {
changeIsImageIntermediate({
imageDTO: selectedImage.imageDTO,
is_intermediate: false,
toast({
id: TOAST_ID,
title: t('controlLayers.savedToGalleryError'),
status: 'error',
});
}
}, [addImageToBoard, autoAddBoardId, changeIsImageIntermediate, selectedImage]);
}, [autoAddBoardId, selectedImage, t]);
return (
<IconButton
@@ -42,7 +65,7 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
icon={<PiFloppyDiskBold />}
onClick={saveSelectedImageToGallery}
colorScheme="invokeBlue"
isDisabled={!selectedImage || !selectedImage.imageDTO.is_intermediate}
isDisabled={!selectedImage}
/>
);
});

View File

@@ -25,6 +25,8 @@ import type {
RegionalGuidanceReferenceImageState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import type { BoardId } from 'features/gallery/store/types';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -70,6 +72,11 @@ const useSaveCanvas = ({ region, saveToGallery, toastOk, toastError, onSave, wit
metadata = selectCanvasMetadata(store.getState());
}
let boardId: BoardId | undefined = undefined;
if (saveToGallery) {
boardId = selectAutoAddBoardId(store.getState());
}
const result = await withResultAsync(() => {
const rasterAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
return canvasManager.compositor.getCompositeImageDTO(
@@ -78,6 +85,8 @@ const useSaveCanvas = ({ region, saveToGallery, toastOk, toastError, onSave, wit
{
is_intermediate: !saveToGallery,
metadata,
board_id: boardId,
silent: true,
},
undefined,
true // force upload the image to ensure it gets added to the gallery
@@ -222,8 +231,8 @@ export const useNewRasterLayerFromBbox = () => {
toastError: t('controlLayers.newRasterLayerError'),
};
}, [dispatch, t]);
const newRasterLayerFromBbox = useSaveCanvas(arg);
return newRasterLayerFromBbox;
const func = useSaveCanvas(arg);
return func;
};
export const useNewControlLayerFromBbox = () => {

View File

@@ -28,19 +28,17 @@ import type {
} from 'features/controlLayers/store/types';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { atom, computed } from 'nanostores';
import type { Logger } from 'roarr';
import { serializeError } from 'serialize-error';
import type { UploadImageArg } from 'services/api/endpoints/images';
import { getImageDTOSafe, uploadImage } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import type { ImageDTO, UploadImageArg } from 'services/api/types';
import stableHash from 'stable-hash';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
import type { JsonObject } from 'type-fest';
import type { JsonObject, SetOptional } from 'type-fest';
type CompositingOptions = {
/**
@@ -259,7 +257,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
getCompositeImageDTO = async (
adapters: CanvasEntityAdapter[],
rect: Rect,
uploadOptions: Pick<UploadImageArg, 'is_intermediate' | 'metadata'>,
uploadOptions: SetOptional<Omit<UploadImageArg, 'file'>, 'image_category'>,
compositingOptions?: CompositingOptions,
forceUpload?: boolean
): Promise<ImageDTO> => {
@@ -299,10 +297,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
uploadImage({
file: new File([blob], 'canvas-composite.png', { type: 'image/png' }),
image_category: 'general',
is_intermediate: uploadOptions.is_intermediate,
board_id: uploadOptions.is_intermediate ? undefined : selectAutoAddBoardId(this.manager.store.getState()),
metadata: uploadOptions.metadata,
withToast: false,
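// The file is fixed here and image_category defaults to 'general'; the caller supplies (and may override) the
// remaining upload options such as is_intermediate, board_id, metadata and silent.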
...uploadOptions,
})
);
this.$isUploading.set(false);

View File

@@ -493,7 +493,7 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
file: new File([blob], `${this.id}_rasterized.png`, { type: 'image/png' }),
image_category: 'other',
is_intermediate: true,
withToast: false,
silent: true,
});
const imageObject = imageDTOToImageObject(imageDTO);
if (replaceObjects) {

View File

@@ -1,13 +1,8 @@
import {
roundDownToMultiple,
roundToMultiple,
roundToMultipleMin,
roundUpToMultiple,
} from 'common/util/roundDownToMultiple';
import { roundToMultiple, roundToMultipleMin } from 'common/util/roundDownToMultiple';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
import { fitRectToGrid, getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectBbox } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect } from 'features/controlLayers/store/types';
@@ -398,18 +393,12 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
}
// Determine the bbox size that fits within the visible rect. The bbox must be at least 64px in width and height,
// and its width and height must be multiples of 8px.
// and its width and height must be multiples of the bbox grid size.
const gridSize = this.manager.stateApi.getBboxGridSize();
// To be conservative, we will round up the x and y to the nearest grid size, and round down the width and height.
// This ensures the bbox is never _larger_ than the visible rect. If the bbox is larger than the visible rect, we
// will always trigger the outpainting workflow, which is not what the user wants.
const x = roundUpToMultiple(visibleRect.x, gridSize);
const y = roundUpToMultiple(visibleRect.y, gridSize);
const width = roundDownToMultiple(visibleRect.width, gridSize);
const height = roundDownToMultiple(visibleRect.height, gridSize);
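// fitRectToGrid shrinks the visible rect inwards to the nearest grid-aligned rect, so the generation bbox never
// extends past the visible area.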
const rect = fitRectToGrid(visibleRect, gridSize);
this.manager.stateApi.setGenerationBbox({ x, y, width, height });
this.manager.stateApi.setGenerationBbox(rect);
};
/**

View File

@@ -1,4 +1,6 @@
import { getPrefixedId, getRectUnion } from 'features/controlLayers/konva/util';
import { roundUpToMultiple } from 'common/util/roundDownToMultiple';
import { fitRectToGrid, getPrefixedId, getRectUnion } from 'features/controlLayers/konva/util';
import type { Rect } from 'features/controlLayers/store/types';
import { describe, expect, it } from 'vitest';
describe('util', () => {
@@ -44,4 +46,74 @@ describe('util', () => {
expect(union).toEqual({ x: 0, y: 0, width: 0, height: 0 });
});
});
describe('fitRectToGrid', () => {
it('should fit rect within grid without exceeding bounds', () => {
const rect: Rect = { x: 0, y: 0, width: 1047, height: 1758 };
const gridSize = 50;
const result = fitRectToGrid(rect, gridSize);
expect(result.x).toBe(roundUpToMultiple(rect.x, gridSize));
expect(result.y).toBe(roundUpToMultiple(rect.y, gridSize));
expect(result.width).toBeLessThanOrEqual(rect.width);
expect(result.height).toBeLessThanOrEqual(rect.height);
expect(result.width % gridSize).toBe(0);
expect(result.height % gridSize).toBe(0);
});
it('should handle small rect within grid bounds', () => {
const rect: Rect = { x: 20, y: 30, width: 80, height: 90 };
const gridSize = 25;
const result = fitRectToGrid(rect, gridSize);
expect(result.x).toBe(25);
expect(result.y).toBe(50);
expect(result.width % gridSize).toBe(0);
expect(result.height % gridSize).toBe(0);
expect(result.width).toBeLessThanOrEqual(rect.width);
expect(result.height).toBeLessThanOrEqual(rect.height);
});
it('should handle rect starting outside of grid alignment', () => {
const rect: Rect = { x: 13, y: 27, width: 94, height: 112 };
const gridSize = 20;
const result = fitRectToGrid(rect, gridSize);
expect(result.x).toBe(20);
expect(result.y).toBe(40);
expect(result.width % gridSize).toBe(0);
expect(result.height % gridSize).toBe(0);
expect(result.width).toBeLessThanOrEqual(rect.width);
expect(result.height).toBeLessThanOrEqual(rect.height);
});
it('should return the same rect if already aligned to grid', () => {
const rect: Rect = { x: 100, y: 100, width: 200, height: 300 };
const gridSize = 50;
const result = fitRectToGrid(rect, gridSize);
expect(result).toEqual(rect);
});
it('should handle large grid sizes relative to rect dimensions', () => {
const rect: Rect = { x: 250, y: 300, width: 400, height: 500 };
const gridSize = 100;
const result = fitRectToGrid(rect, gridSize);
expect(result.x).toBe(300);
expect(result.y).toBe(300);
expect(result.width % gridSize).toBe(0);
expect(result.height % gridSize).toBe(0);
expect(result.width).toBeLessThanOrEqual(rect.width);
expect(result.height).toBeLessThanOrEqual(rect.height);
});
it('should handle rect with zero width and height', () => {
const rect: Rect = { x: 40, y: 60, width: 0, height: 0 };
const gridSize = 20;
const result = fitRectToGrid(rect, gridSize);
expect(result).toEqual({ x: 40, y: 60, width: 0, height: 0 });
});
});
});

View File

@@ -1,5 +1,6 @@
import type { Selector, Store } from '@reduxjs/toolkit';
import { $authToken } from 'app/store/nanostores/authToken';
import { roundDownToMultiple, roundUpToMultiple } from 'common/util/roundDownToMultiple';
import type {
CanvasEntityIdentifier,
CanvasObjectState,
@@ -560,6 +561,33 @@ export const getRectIntersection = (...rects: Rect[]): Rect => {
return rect || getEmptyRect();
};
/**
* Fits a rect to the nearest multiple of the grid size, rounding down. The returned rect will be smaller than or equal
* to the input rect, and will be aligned to the grid.
*
* In other words, shrink the rect inwards on each side until it fits within the original rect and aligns to the grid.
*
* @param rect The rect to fit
* @param gridSize The size of the grid
* @returns The fitted rect
*/
export const fitRectToGrid = (rect: Rect, gridSize: number): Rect => {
// Rounding x and y up effectively shrinks the left and top edges of the rect, and rounding width and height down
// effectively shrinks the right and bottom edges.
const x = roundUpToMultiple(rect.x, gridSize);
const y = roundUpToMultiple(rect.y, gridSize);
// Because we've just shifted the rect's x and y, we need to adjust the width and height by the same amount before
// we round those values down.
const offsetX = x - rect.x;
const offsetY = y - rect.y;
const width = roundDownToMultiple(rect.width - offsetX, gridSize);
const height = roundDownToMultiple(rect.height - offsetY, gridSize);
return { x, y, width, height };
};
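// Worked example, consistent with the grid-alignment test case: fitRectToGrid({ x: 13, y: 27, width: 94, height: 112 }, 20)
// yields { x: 20, y: 40, width: 80, height: 80 }. The x and y round up to the grid; the width and height shrink by the
// resulting offsets and then round down, so the returned rect never extends past the input rect.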
/**
* Asserts that the value is never reached. Used for exhaustive checks in switch statements or conditional logic to ensure
* that all possible values are handled.

View File

@@ -90,7 +90,7 @@ const initialState: CanvasSettingsState = {
invertScrollForToolWidth: false,
color: { r: 31, g: 160, b: 224, a: 1 }, // invokeBlue.500
sendToCanvas: false,
outputOnlyMaskedRegions: false,
outputOnlyMaskedRegions: true,
autoProcess: true,
snapToGrid: true,
showProgressOnCanvas: true,

View File

@@ -9,6 +9,8 @@ import type {
ParameterCFGRescaleMultiplier,
ParameterCFGScale,
ParameterCLIPEmbedModel,
ParameterCLIPGEmbedModel,
ParameterCLIPLEmbedModel,
ParameterGuidance,
ParameterMaskBlurMethod,
ParameterModel,
@@ -71,6 +73,8 @@ export type ParamsState = {
refinerStart: number;
t5EncoderModel: ParameterT5EncoderModel | null;
clipEmbedModel: ParameterCLIPEmbedModel | null;
clipLEmbedModel: ParameterCLIPLEmbedModel | null;
clipGEmbedModel: ParameterCLIPGEmbedModel | null;
};
const initialState: ParamsState = {
@@ -115,6 +119,8 @@ const initialState: ParamsState = {
refinerStart: 0.8,
t5EncoderModel: null,
clipEmbedModel: null,
clipLEmbedModel: null,
clipGEmbedModel: null,
};
export const paramsSlice = createSlice({
@@ -192,6 +198,12 @@ export const paramsSlice = createSlice({
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
state.clipEmbedModel = action.payload;
},
clipLEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPLEmbedModel | null>) => {
state.clipLEmbedModel = action.payload;
},
clipGEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPGEmbedModel | null>) => {
state.clipGEmbedModel = action.payload;
},
vaePrecisionChanged: (state, action: PayloadAction<ParameterPrecision>) => {
state.vaePrecision = action.payload;
},
@@ -305,6 +317,8 @@ export const {
vaePrecisionChanged,
t5EncoderModelSelected,
clipEmbedModelSelected,
clipLEmbedModelSelected,
clipGEmbedModelSelected,
setClipSkip,
shouldUseCpuNoiseChanged,
positivePromptChanged,
@@ -341,6 +355,7 @@ export const createParamsSelector = <T>(selector: Selector<ParamsState, T>) =>
export const selectBase = createParamsSelector((params) => params.model?.base);
export const selectIsSDXL = createParamsSelector((params) => params.model?.base === 'sdxl');
export const selectIsFLUX = createParamsSelector((params) => params.model?.base === 'flux');
export const selectIsSD3 = createParamsSelector((params) => params.model?.base === 'sd-3');
export const selectModel = createParamsSelector((params) => params.model);
export const selectModelKey = createParamsSelector((params) => params.model?.key);
@@ -349,6 +364,10 @@ export const selectFLUXVAE = createParamsSelector((params) => params.fluxVAE);
export const selectVAEKey = createParamsSelector((params) => params.vae?.key);
export const selectT5EncoderModel = createParamsSelector((params) => params.t5EncoderModel);
export const selectCLIPEmbedModel = createParamsSelector((params) => params.clipEmbedModel);
export const selectCLIPLEmbedModel = createParamsSelector((params) => params.clipLEmbedModel);
export const selectCLIPGEmbedModel = createParamsSelector((params) => params.clipGEmbedModel);
export const selectCFGScale = createParamsSelector((params) => params.cfgScale);
export const selectGuidance = createParamsSelector((params) => params.guidance);
export const selectSteps = createParamsSelector((params) => params.steps);

View File

@@ -1,3 +1,4 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { draggable } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import type { ImageProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Image } from '@invoke-ai/ui-library';
@@ -5,6 +6,7 @@ import { useAppStore } from 'app/store/nanostores/store';
import { singleImageDndSource } from 'features/dnd/dnd';
import type { DndDragPreviewSingleImageState } from 'features/dnd/DndDragPreviewSingleImage';
import { createSingleImageDragPreview, setSingleImageDragPreview } from 'features/dnd/DndDragPreviewSingleImage';
import { firefoxDndFix } from 'features/dnd/util';
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { memo, useEffect, useState } from 'react';
import type { ImageDTO } from 'services/api/types';
@@ -35,25 +37,28 @@ export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: Props) => {
if (!element) {
return;
}
return draggable({
element,
getInitialData: () => singleImageDndSource.getData({ imageDTO }, imageDTO.image_name),
onDragStart: () => {
setIsDragging(true);
},
onDrop: () => {
setIsDragging(false);
},
onGenerateDragPreview: (args) => {
if (singleImageDndSource.typeGuard(args.source.data)) {
setSingleImageDragPreview({
singleImageDndData: args.source.data,
onGenerateDragPreviewArgs: args,
setDragPreviewState,
});
}
},
});
return combine(
firefoxDndFix(element),
draggable({
element,
getInitialData: () => singleImageDndSource.getData({ imageDTO }, imageDTO.image_name),
onDragStart: () => {
setIsDragging(true);
},
onDrop: () => {
setIsDragging(false);
},
onGenerateDragPreview: (args) => {
if (singleImageDndSource.typeGuard(args.source.data)) {
setSingleImageDragPreview({
singleImageDndData: args.source.data,
onGenerateDragPreviewArgs: args,
setDragPreviewState,
});
}
},
})
);
}, [imageDTO, element, store]);
useImageContextMenu(imageDTO, element);

View File

@@ -11,10 +11,11 @@ import type { DndTargetState } from 'features/dnd/types';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
import { toast } from 'features/toast/toast';
import { memo, useEffect, useRef, useState } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { type UploadImageArg, uploadImages } from 'services/api/endpoints/images';
import { uploadImages } from 'services/api/endpoints/images';
import { useBoardName } from 'services/api/hooks/useBoardName';
import type { UploadImageArg } from 'services/api/types';
import { z } from 'zod';
const ACCEPTED_IMAGE_TYPES = ['image/png', 'image/jpg', 'image/jpeg'];
@@ -71,13 +72,47 @@ export const FullscreenDropzone = memo(() => {
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const [dndState, setDndState] = useState<DndTargetState>('idle');
const uploadFilesSchema = useMemo(() => getFilesSchema(maxImageUploadCount), [maxImageUploadCount]);
const validateAndUploadFiles = useCallback(
(files: File[]) => {
const { getState } = getStore();
const parseResult = uploadFilesSchema.safeParse(files);
if (!parseResult.success) {
const description =
maxImageUploadCount === undefined
? t('toast.uploadFailedInvalidUploadDesc')
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
toast({
id: 'UPLOAD_FAILED',
title: t('toast.uploadFailed'),
description,
status: 'error',
});
return;
}
const autoAddBoardId = selectAutoAddBoardId(getState());
const uploadArgs: UploadImageArg[] = files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
isFirstUploadOfBatch: i === 0,
}));
uploadImages(uploadArgs);
},
[maxImageUploadCount, t, uploadFilesSchema]
);
useEffect(() => {
const element = ref.current;
if (!element) {
return;
}
const { getState } = getStore();
const uploadFilesSchema = getFilesSchema(maxImageUploadCount);
return combine(
dropTargetForExternal({
@@ -85,32 +120,7 @@ export const FullscreenDropzone = memo(() => {
canDrop: containsFiles,
onDrop: ({ source }) => {
const files = getFiles({ source });
const parseResult = uploadFilesSchema.safeParse(files);
if (!parseResult.success) {
const description =
maxImageUploadCount === undefined
? t('toast.uploadFailedInvalidUploadDesc')
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
toast({
id: 'UPLOAD_FAILED',
title: t('toast.uploadFailed'),
description,
status: 'error',
});
return;
}
const autoAddBoardId = selectAutoAddBoardId(getState());
const uploadArgs: UploadImageArg[] = files.map((file) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
}));
uploadImages(uploadArgs);
validateAndUploadFiles(files);
},
onDragEnter: () => {
setDndState('over');
@@ -131,7 +141,27 @@ export const FullscreenDropzone = memo(() => {
},
})
);
}, [maxImageUploadCount, t]);
}, [validateAndUploadFiles]);
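// Listen for paste events at the document level so clipboard images go through the same validation and upload path
// as dropped files; the AbortController removes the listener when the component unmounts.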
useEffect(() => {
const controller = new AbortController();
document.addEventListener(
'paste',
(e) => {
if (!e.clipboardData?.files) {
return;
}
const files = Array.from(e.clipboardData.files);
validateAndUploadFiles(files);
},
{ signal: controller.signal }
);
return () => {
controller.abort();
};
}, [validateAndUploadFiles]);
return (
<Box ref={ref} data-dnd-state={dndState} sx={sx}>

View File

@@ -1,6 +1,7 @@
import type { GetOffsetFn } from '@atlaskit/pragmatic-drag-and-drop/dist/types/public-utils/element/custom-native-drag-preview/types';
import type { Input } from '@atlaskit/pragmatic-drag-and-drop/types';
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { noop } from 'lodash-es';
import type { CSSProperties } from 'react';
/**
@@ -44,3 +45,67 @@ export function triggerPostMoveFlash(element: HTMLElement, backgroundColor: CSSP
iterations: 1,
});
}
/**
* Firefox has a bug where input or textarea elements with draggable parents do not allow selection of their text.
*
* This helper function implements a workaround by setting the draggable attribute to false when the mouse is over an
* input or textarea child of the draggable. It reverts the attribute on mouse out.
*
* The fix is only applied for Firefox, and should be used in every `pragmatic-drag-and-drop` `draggable`.
*
* See:
* - https://github.com/atlassian/pragmatic-drag-and-drop/issues/111
* - https://bugzilla.mozilla.org/show_bug.cgi?id=1853069
*
* @example
* ```tsx
* useEffect(() => {
* const element = ref.current;
* if (!element) {
* return;
* }
* return combine(
* firefoxDndFix(element),
* // The rest of the draggable setup is the same
* draggable({
* element,
* // ...
* }),
* );
*```
* @param element The draggable element
* @returns A cleanup function that removes the event listeners
*/
export const firefoxDndFix = (element: HTMLElement): (() => void) => {
if (!navigator.userAgent.includes('Firefox')) {
return noop;
}
const abortController = new AbortController();
element.addEventListener(
'mouseover',
(event) => {
if (event.target instanceof HTMLTextAreaElement || event.target instanceof HTMLInputElement) {
element.setAttribute('draggable', 'false');
}
},
{ signal: abortController.signal }
);
element.addEventListener(
'mouseout',
(event) => {
if (event.target instanceof HTMLTextAreaElement || event.target instanceof HTMLInputElement) {
element.setAttribute('draggable', 'true');
}
},
{ signal: abortController.signal }
);
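// On cleanup, restore the draggable attribute and remove both listeners via the AbortController.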
return () => {
element.setAttribute('draggable', 'true');
abortController.abort();
};
};

View File

@@ -1,8 +1,10 @@
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { useAppSelector } from 'app/store/storeHooks';
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { sentImageToCanvas } from 'features/gallery/store/actions';
@@ -20,6 +22,8 @@ export const ImageMenuItemNewFromImageSubMenu = memo(() => {
const imageDTO = useImageDTOContext();
const imageViewer = useImageViewer();
const isBusy = useCanvasIsBusy();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
const onClickNewCanvasWithRasterLayerFromImage = useCallback(() => {
const { dispatch, getState } = store;
@@ -110,17 +114,25 @@ export const ImageMenuItemNewFromImageSubMenu = memo(() => {
<MenuItem
icon={<PiFileBold />}
onClickCapture={onClickNewCanvasWithControlLayerFromImage}
isDisabled={isBusy}
isDisabled={isBusy || isSD3}
>
{t('controlLayers.canvasAsControlLayer')}
</MenuItem>
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewInpaintMaskFromImage} isDisabled={isBusy}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewRegionalGuidanceFromImage} isDisabled={isBusy}>
<MenuItem
icon={<NewLayerIcon />}
onClickCapture={onClickNewRegionalGuidanceFromImage}
isDisabled={isBusy || isFLUX || isSD3}
>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewControlLayerFromImage} isDisabled={isBusy}>
<MenuItem
icon={<NewLayerIcon />}
onClickCapture={onClickNewControlLayerFromImage}
isDisabled={isBusy || isSD3}
>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewRasterLayerFromImage} isDisabled={isBusy}>

View File

@@ -12,6 +12,7 @@ import type { DndDragPreviewMultipleImageState } from 'features/dnd/DndDragPrevi
import { createMultipleImageDragPreview, setMultipleImageDragPreview } from 'features/dnd/DndDragPreviewMultipleImage';
import type { DndDragPreviewSingleImageState } from 'features/dnd/DndDragPreviewSingleImage';
import { createSingleImageDragPreview, setSingleImageDragPreview } from 'features/dnd/DndDragPreviewSingleImage';
import { firefoxDndFix } from 'features/dnd/util';
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { GalleryImageHoverIcons } from 'features/gallery/components/ImageGrid/GalleryImageHoverIcons';
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
@@ -66,7 +67,7 @@ const galleryImageContainerSX = {
},
'&:hover::before': {
boxShadow:
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
'inset 0px 0px 0px 1px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected=true]::before': {
boxShadow:
@@ -115,13 +116,17 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
return;
}
return combine(
firefoxDndFix(element),
draggable({
element,
getInitialData: () => {
const { gallery } = store.getState();
// When we have multiple images selected, and the dragged image is part of the selection, initiate a
// multi-image drag.
if (gallery.selection.length > 1 && gallery.selection.includes(imageDTO)) {
if (
gallery.selection.length > 1 &&
gallery.selection.find(({ image_name }) => image_name === imageDTO.image_name) !== undefined
) {
return multipleImageDndSource.getData({
imageDTOs: gallery.selection,
boardId: gallery.selectedBoardId,

View File

@@ -2,6 +2,7 @@ import { Box, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasAlertsInvocationProgress } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsInvocationProgress';
import { CanvasAlertsSendingToCanvas } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSendingTo';
import { DndImage } from 'features/dnd/DndImage';
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
@@ -48,9 +49,18 @@ const CurrentImagePreview = () => {
position="relative"
>
<ImageContent imageDTO={imageDTO} />
<Box position="absolute" top={0} insetInlineStart={0}>
<Flex
flexDir="column"
gap={2}
position="absolute"
top={0}
insetInlineStart={0}
pointerEvents="none"
alignItems="flex-start"
>
<CanvasAlertsSendingToCanvas />
</Box>
<CanvasAlertsInvocationProgress />
</Flex>
{shouldShowImageDetails && imageDTO && (
<Box position="absolute" opacity={0.8} top={0} width="full" height="full" borderRadius="base">
<ImageMetadataViewer image={imageDTO} />

View File

@@ -44,7 +44,6 @@ export const ImageViewer = memo(({ closeButton }: Props) => {
right={0}
bottom={0}
left={0}
rowGap={2}
alignItems="center"
justifyContent="center"
>

View File

@@ -10,6 +10,8 @@ import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
{ value: 'sd-3', label: MODEL_TYPE_MAP['sd-3'] },
{ value: 'flux', label: MODEL_TYPE_MAP['flux'] },
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
];

View File

@@ -4,6 +4,7 @@ import { attachClosestEdge, extractClosestEdge } from '@atlaskit/pragmatic-drag-
import { singleWorkflowFieldDndSource } from 'features/dnd/dnd';
import type { DndListTargetState } from 'features/dnd/types';
import { idle } from 'features/dnd/types';
import { firefoxDndFix } from 'features/dnd/util';
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { RefObject } from 'react';
import { useEffect, useState } from 'react';
@@ -18,6 +19,7 @@ export const useLinearViewFieldDnd = (ref: RefObject<HTMLElement>, fieldIdentifi
return;
}
return combine(
firefoxDndFix(element),
draggable({
element,
getInitialData() {

View File

@@ -9,8 +9,8 @@ export const prepareLinearUIBatch = (
state: RootState,
g: Graph,
prepend: boolean,
noise: Invocation<'noise' | 'flux_denoise'>,
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>,
noise: Invocation<'noise' | 'flux_denoise' | 'sd3_denoise'>,
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'sd3_text_encoder'>,
origin: 'canvas' | 'workflows' | 'upscaling',
destination: 'canvas' | 'gallery'
): BatchConfig => {

View File

@@ -2,16 +2,18 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type { CanvasState, Dimensions } from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { addImageToLatents } from 'features/nodes/util/graph/graphBuilderUtils';
import { isEqual } from 'lodash-es';
import type { Invocation } from 'services/api/types';
type AddImageToImageArg = {
g: Graph;
manager: CanvasManager;
l2i: Invocation<'l2i' | 'flux_vae_decode'>;
denoise: Invocation<'denoise_latents' | 'flux_denoise'>;
vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader'>;
l2i: Invocation<'l2i' | 'flux_vae_decode' | 'sd3_l2i'>;
i2lNodeType: 'i2l' | 'flux_vae_encode' | 'sd3_i2l';
denoise: Invocation<'denoise_latents' | 'flux_denoise' | 'sd3_denoise'>;
vaeSource: Invocation<
'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader' | 'sd3_model_loader'
>;
originalSize: Dimensions;
scaledSize: Dimensions;
bbox: CanvasState['bbox'];
@@ -23,6 +25,7 @@ export const addImageToImage = async ({
g,
manager,
l2i,
i2lNodeType,
denoise,
vaeSource,
originalSize,
@@ -30,10 +33,13 @@ export const addImageToImage = async ({
bbox,
denoising_start,
fp32,
}: AddImageToImageArg): Promise<Invocation<'img_resize' | 'l2i' | 'flux_vae_decode'>> => {
}: AddImageToImageArg): Promise<Invocation<'img_resize' | 'l2i' | 'flux_vae_decode' | 'sd3_l2i'>> => {
denoise.denoising_start = denoising_start;
const adapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, { is_intermediate: true });
const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
is_intermediate: true,
silent: true,
});
if (!isEqual(scaledSize, originalSize)) {
// Resize the initial image to the scaled size, denoise, then resize back to the original size
@@ -44,7 +50,12 @@ export const addImageToImage = async ({
...scaledSize,
});
const i2l = addImageToLatents(g, l2i.type === 'flux_vae_decode', fp32);
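// The fp32 flag only applies to the 'i2l' node type, so it is spread in conditionally.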
const i2l = g.addNode({
id: i2lNodeType,
type: i2lNodeType,
image: image_name ? { image_name } : undefined,
...(i2lNodeType === 'i2l' ? { fp32 } : {}),
});
const resizeImageToOriginalSize = g.addNode({
type: 'img_resize',
@@ -61,7 +72,12 @@ export const addImageToImage = async ({
return resizeImageToOriginalSize;
} else {
// No need to resize, just decode
const i2l = addImageToLatents(g, l2i.type === 'flux_vae_decode', fp32, image_name);
const i2l = g.addNode({
id: i2lNodeType,
type: i2lNodeType,
image: image_name ? { image_name } : undefined,
...(i2lNodeType === 'i2l' ? { fp32 } : {}),
});
g.addEdge(vaeSource, 'vae', i2l, 'vae');
g.addEdge(i2l, 'latents', denoise, 'latents');
return l2i;

View File

@@ -6,7 +6,6 @@ import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { Dimensions } from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { addImageToLatents } from 'features/nodes/util/graph/graphBuilderUtils';
import { isEqual } from 'lodash-es';
import type { Invocation } from 'services/api/types';
@@ -14,10 +13,13 @@ type AddInpaintArg = {
state: RootState;
g: Graph;
manager: CanvasManager;
l2i: Invocation<'l2i' | 'flux_vae_decode'>;
denoise: Invocation<'denoise_latents' | 'flux_denoise'>;
vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader'>;
modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader'>;
l2i: Invocation<'l2i' | 'flux_vae_decode' | 'sd3_l2i'>;
i2lNodeType: 'i2l' | 'flux_vae_encode' | 'sd3_i2l';
denoise: Invocation<'denoise_latents' | 'flux_denoise' | 'sd3_denoise'>;
vaeSource: Invocation<
'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader' | 'sd3_model_loader'
>;
modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'sd3_model_loader'>;
originalSize: Dimensions;
scaledSize: Dimensions;
denoising_start: number;
@@ -29,6 +31,7 @@ export const addInpaint = async ({
g,
manager,
l2i,
i2lNodeType,
denoise,
vaeSource,
modelLoader,
@@ -48,16 +51,23 @@ export const addInpaint = async ({
const rasterAdapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
const initialImage = await manager.compositor.getCompositeImageDTO(rasterAdapters, bbox.rect, {
is_intermediate: true,
silent: true,
});
const inpaintMaskAdapters = manager.compositor.getVisibleAdaptersOfType('inpaint_mask');
const maskImage = await manager.compositor.getCompositeImageDTO(inpaintMaskAdapters, bbox.rect, {
is_intermediate: true,
silent: true,
});
if (!isEqual(scaledSize, originalSize)) {
// Scale before processing requires some resizing
const i2l = addImageToLatents(g, modelLoader.type === 'flux_model_loader', fp32, initialImage.image_name);
const i2l = g.addNode({
id: i2lNodeType,
type: i2lNodeType,
image: initialImage.image_name ? { image_name: initialImage.image_name } : undefined,
...(i2lNodeType === 'i2l' ? { fp32 } : {}),
});
const resizeImageToScaledSize = g.addNode({
type: 'img_resize',
@@ -102,7 +112,7 @@ export const addInpaint = async ({
g.addEdge(vaeSource, 'vae', i2l, 'vae');
g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
if (modelLoader.type !== 'flux_model_loader') {
if (modelLoader.type !== 'flux_model_loader' && modelLoader.type !== 'sd3_model_loader') {
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
}
g.addEdge(resizeImageToScaledSize, 'image', createGradientMask, 'image');
@@ -126,7 +136,12 @@ export const addInpaint = async ({
return resizeOutput;
} else {
// No scale before processing, much simpler
const i2l = addImageToLatents(g, modelLoader.type === 'flux_model_loader', fp32, initialImage.image_name);
const i2l = g.addNode({
id: i2lNodeType,
type: i2lNodeType,
image: initialImage.image_name ? { image_name: initialImage.image_name } : undefined,
...(i2lNodeType === 'i2l' ? { fp32 } : {}),
});
const alphaToMask = g.addNode({
id: getPrefixedId('alpha_to_mask'),
@@ -153,7 +168,7 @@ export const addInpaint = async ({
g.addEdge(i2l, 'latents', denoise, 'latents');
g.addEdge(vaeSource, 'vae', i2l, 'vae');
g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
if (modelLoader.type !== 'flux_model_loader') {
if (modelLoader.type !== 'flux_model_loader' && modelLoader.type !== 'sd3_model_loader') {
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
}
g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');

View File

@@ -11,7 +11,7 @@ import type { Invocation } from 'services/api/types';
export const addNSFWChecker = (
g: Graph,
imageOutput: Invocation<
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode' | 'sd3_l2i'
>
): Invocation<'img_nsfw'> => {
const nsfw = g.addNode({

View File

@@ -6,7 +6,7 @@ import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { Dimensions } from 'features/controlLayers/store/types';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { addImageToLatents, getInfill } from 'features/nodes/util/graph/graphBuilderUtils';
import { getInfill } from 'features/nodes/util/graph/graphBuilderUtils';
import { isEqual } from 'lodash-es';
import type { Invocation } from 'services/api/types';
@@ -14,10 +14,13 @@ type AddOutpaintArg = {
state: RootState;
g: Graph;
manager: CanvasManager;
l2i: Invocation<'l2i' | 'flux_vae_decode'>;
denoise: Invocation<'denoise_latents' | 'flux_denoise'>;
vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader'>;
modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader'>;
l2i: Invocation<'l2i' | 'flux_vae_decode' | 'sd3_l2i'>;
i2lNodeType: 'i2l' | 'flux_vae_encode' | 'sd3_i2l';
denoise: Invocation<'denoise_latents' | 'flux_denoise' | 'sd3_denoise'>;
vaeSource: Invocation<
'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader' | 'sd3_model_loader'
>;
modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'sd3_model_loader'>;
originalSize: Dimensions;
scaledSize: Dimensions;
denoising_start: number;
@@ -29,6 +32,7 @@ export const addOutpaint = async ({
g,
manager,
l2i,
i2lNodeType,
denoise,
vaeSource,
modelLoader,
@@ -48,11 +52,13 @@ export const addOutpaint = async ({
const rasterAdapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
const initialImage = await manager.compositor.getCompositeImageDTO(rasterAdapters, bbox.rect, {
is_intermediate: true,
silent: true,
});
const inpaintMaskAdapters = manager.compositor.getVisibleAdaptersOfType('inpaint_mask');
const maskImage = await manager.compositor.getCompositeImageDTO(inpaintMaskAdapters, bbox.rect, {
is_intermediate: true,
silent: true,
});
const infill = getInfill(g, params);
@@ -108,14 +114,19 @@ export const addOutpaint = async ({
g.addEdge(infill, 'image', createGradientMask, 'image');
g.addEdge(resizeInputMaskToScaledSize, 'image', createGradientMask, 'mask');
g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
if (modelLoader.type !== 'flux_model_loader') {
if (modelLoader.type !== 'flux_model_loader' && modelLoader.type !== 'sd3_model_loader') {
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
}
g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');
// Decode infilled image and connect to denoise
const i2l = addImageToLatents(g, modelLoader.type === 'flux_model_loader', fp32);
const i2l = g.addNode({
id: i2lNodeType,
type: i2lNodeType,
...(i2lNodeType === 'i2l' ? { fp32 } : {}),
});
g.addEdge(infill, 'image', i2l, 'image');
g.addEdge(vaeSource, 'vae', i2l, 'vae');
g.addEdge(i2l, 'latents', denoise, 'latents');
@@ -150,7 +161,11 @@ export const addOutpaint = async ({
} else {
infill.image = { image_name: initialImage.image_name };
// No scale before processing, much simpler
const i2l = addImageToLatents(g, modelLoader.type === 'flux_model_loader', fp32);
const i2l = g.addNode({
id: i2lNodeType,
type: i2lNodeType,
...(i2lNodeType === 'i2l' ? { fp32 } : {}),
});
const maskAlphaToMask = g.addNode({
id: getPrefixedId('mask_alpha_to_mask'),
type: 'tomask',
@@ -187,7 +202,7 @@ export const addOutpaint = async ({
g.addEdge(i2l, 'latents', denoise, 'latents');
g.addEdge(vaeSource, 'vae', i2l, 'vae');
g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
if (modelLoader.type !== 'flux_model_loader') {
if (modelLoader.type !== 'flux_model_loader' && modelLoader.type !== 'sd3_model_loader') {
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
}

View File

@@ -6,7 +6,7 @@ import type { Invocation } from 'services/api/types';
type AddTextToImageArg = {
g: Graph;
l2i: Invocation<'l2i' | 'flux_vae_decode'>;
l2i: Invocation<'l2i' | 'flux_vae_decode' | 'sd3_l2i'>;
originalSize: Dimensions;
scaledSize: Dimensions;
};
@@ -16,7 +16,7 @@ export const addTextToImage = ({
l2i,
originalSize,
scaledSize,
}: AddTextToImageArg): Invocation<'img_resize' | 'l2i' | 'flux_vae_decode'> => {
}: AddTextToImageArg): Invocation<'img_resize' | 'l2i' | 'flux_vae_decode' | 'sd3_l2i'> => {
if (!isEqual(scaledSize, originalSize)) {
// We need to resize the output image back to the original size
const resizeImageToOriginalSize = g.addNode({

View File

@@ -11,7 +11,7 @@ import type { Invocation } from 'services/api/types';
export const addWatermarker = (
g: Graph,
imageOutput: Invocation<
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode' | 'sd3_l2i'
>
): Invocation<'img_watermark'> => {
const watermark = g.addNode({

View File

@@ -139,7 +139,7 @@ export const buildFLUXGraph = async (
}
let canvasOutput: Invocation<
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode' | 'sd3_l2i'
> = l2i;
if (generationMode === 'txt2img') {
@@ -149,6 +149,7 @@ export const buildFLUXGraph = async (
g,
manager,
l2i,
i2lNodeType: 'flux_vae_encode',
denoise,
vaeSource: modelLoader,
originalSize,
@@ -163,6 +164,7 @@ export const buildFLUXGraph = async (
g,
manager,
l2i,
i2lNodeType: 'flux_vae_encode',
denoise,
vaeSource: modelLoader,
modelLoader,
@@ -177,6 +179,7 @@ export const buildFLUXGraph = async (
g,
manager,
l2i,
i2lNodeType: 'flux_vae_encode',
denoise,
vaeSource: modelLoader,
modelLoader,

View File

@@ -170,7 +170,7 @@ export const buildSD1Graph = async (
const denoising_start = 1 - params.img2imgStrength;
let canvasOutput: Invocation<
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode' | 'sd3_l2i'
> = l2i;
if (generationMode === 'txt2img') {
@@ -180,6 +180,7 @@ export const buildSD1Graph = async (
g,
manager,
l2i,
i2lNodeType: 'i2l',
denoise,
vaeSource,
originalSize,
@@ -194,6 +195,7 @@ export const buildSD1Graph = async (
g,
manager,
l2i,
i2lNodeType: 'i2l',
denoise,
vaeSource,
modelLoader,
@@ -208,6 +210,7 @@ export const buildSD1Graph = async (
g,
manager,
l2i,
i2lNodeType: 'i2l',
denoise,
vaeSource,
modelLoader,

View File

@@ -0,0 +1,216 @@
import { logger } from 'app/logging/logger';
import type { RootState } from 'app/store/store';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
import {
CANVAS_OUTPUT_PREFIX,
getBoardField,
getPresetModifiedPrompts,
getSizes,
} from 'features/nodes/util/graph/graphBuilderUtils';
import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const buildSD3Graph = async (
state: RootState,
manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'sd3_denoise'>; posCond: Invocation<'sd3_text_encoder'> }> => {
const generationMode = await manager.compositor.getGenerationMode();
log.debug({ generationMode }, 'Building SD3 graph');
const params = selectParamsSlice(state);
const canvasSettings = selectCanvasSettingsSlice(state);
const canvas = selectCanvasSlice(state);
const { bbox } = canvas;
const {
model,
cfgScale: cfg_scale,
seed,
steps,
vae,
t5EncoderModel,
clipLEmbedModel,
clipGEmbedModel,
optimizedDenoisingEnabled,
img2imgStrength,
} = params;
assert(model, 'No model found in state');
const { originalSize, scaledSize } = getSizes(bbox);
const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);
const g = new Graph(getPrefixedId('sd3_graph'));
const modelLoader = g.addNode({
type: 'sd3_model_loader',
id: getPrefixedId('sd3_model_loader'),
model,
t5_encoder_model: t5EncoderModel,
clip_l_model: clipLEmbedModel,
clip_g_model: clipGEmbedModel,
vae_model: vae,
});
const posCond = g.addNode({
type: 'sd3_text_encoder',
id: getPrefixedId('pos_cond'),
prompt: positivePrompt,
});
const negCond = g.addNode({
type: 'sd3_text_encoder',
id: getPrefixedId('neg_cond'),
prompt: negativePrompt,
});
const denoise = g.addNode({
type: 'sd3_denoise',
id: getPrefixedId('sd3_denoise'),
cfg_scale,
steps,
denoising_start: 0,
denoising_end: 1,
width: scaledSize.width,
height: scaledSize.height,
});
const l2i = g.addNode({
type: 'sd3_l2i',
id: getPrefixedId('l2i'),
});
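// Wire the graph: the transformer drives the denoise node, the CLIP L, CLIP G and T5 encoders feed both text
// encoders, the positive and negative conditionings feed the denoise node, and the denoised latents are decoded
// by the sd3_l2i node.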
g.addEdge(modelLoader, 'transformer', denoise, 'transformer');
g.addEdge(modelLoader, 'clip_l', posCond, 'clip_l');
g.addEdge(modelLoader, 'clip_l', negCond, 'clip_l');
g.addEdge(modelLoader, 'clip_g', posCond, 'clip_g');
g.addEdge(modelLoader, 'clip_g', negCond, 'clip_g');
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
g.addEdge(modelLoader, 't5_encoder', negCond, 't5_encoder');
g.addEdge(posCond, 'conditioning', denoise, 'positive_conditioning');
g.addEdge(negCond, 'conditioning', denoise, 'negative_conditioning');
g.addEdge(denoise, 'latents', l2i, 'latents');
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
assert(modelConfig.base === 'sd-3');
g.upsertMetadata({
generation_mode: 'sd3_txt2img',
cfg_scale,
width: originalSize.width,
height: originalSize.height,
positive_prompt: positivePrompt,
negative_prompt: negativePrompt,
model: Graph.getModelMetadataField(modelConfig),
seed,
steps,
vae: vae ?? undefined,
});
g.addEdge(modelLoader, 'vae', l2i, 'vae');
let denoising_start: number;
if (optimizedDenoisingEnabled) {
// We rescale the img2imgStrength (with exponent 0.2) to effectively use the entire range [0, 1] and make the scale
// more user-friendly for SD3.5. Without this, most of the 'change' is concentrated in the high denoise strength
// range (>0.9).
denoising_start = 1 - img2imgStrength ** 0.2;
} else {
denoising_start = 1 - img2imgStrength;
}
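// For example, at img2imgStrength = 0.5 the optimized path above yields denoising_start = 1 - 0.5 ** 0.2, roughly
// 0.13, versus 0.5 on the linear path, so mid-range strengths change the image much more aggressively.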
let canvasOutput: Invocation<
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode' | 'sd3_l2i'
> = l2i;
if (generationMode === 'txt2img') {
canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
} else if (generationMode === 'img2img') {
canvasOutput = await addImageToImage({
g,
manager,
l2i,
i2lNodeType: 'sd3_i2l',
denoise,
vaeSource: modelLoader,
originalSize,
scaledSize,
bbox,
denoising_start,
fp32: false,
});
} else if (generationMode === 'inpaint') {
canvasOutput = await addInpaint({
state,
g,
manager,
l2i,
i2lNodeType: 'sd3_i2l',
denoise,
vaeSource: modelLoader,
modelLoader,
originalSize,
scaledSize,
denoising_start,
fp32: false,
});
} else if (generationMode === 'outpaint') {
canvasOutput = await addOutpaint({
state,
g,
manager,
l2i,
i2lNodeType: 'sd3_i2l',
denoise,
vaeSource: modelLoader,
modelLoader,
originalSize,
scaledSize,
denoising_start,
fp32: false,
});
} else {
assert<Equals<typeof generationMode, never>>(false);
}
if (state.system.shouldUseNSFWChecker) {
canvasOutput = addNSFWChecker(g, canvasOutput);
}
if (state.system.shouldUseWatermarker) {
canvasOutput = addWatermarker(g, canvasOutput);
}
// When sending to canvas, the image will be staged and should not be saved to the gallery or added to a board.
const is_intermediate = canvasSettings.sendToCanvas;
const board = canvasSettings.sendToCanvas ? undefined : getBoardField(state);
if (!canvasSettings.sendToCanvas) {
g.upsertMetadata(selectCanvasMetadata(state));
}
g.updateNode(canvasOutput, {
id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
is_intermediate,
use_cache: false,
board,
});
g.setMetadataReceivingNode(canvasOutput);
return { g, noise: denoise, posCond };
};

View File

@@ -175,7 +175,7 @@ export const buildSDXLGraph = async (
: 1 - params.img2imgStrength;
let canvasOutput: Invocation<
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode' | 'sd3_l2i'
> = l2i;
if (generationMode === 'txt2img') {
@@ -185,6 +185,7 @@ export const buildSDXLGraph = async (
g,
manager,
l2i,
i2lNodeType: 'i2l',
denoise,
vaeSource,
originalSize,
@@ -199,6 +200,7 @@ export const buildSDXLGraph = async (
g,
manager,
l2i,
i2lNodeType: 'i2l',
denoise,
vaeSource,
modelLoader,
@@ -213,6 +215,7 @@ export const buildSDXLGraph = async (
g,
manager,
l2i,
i2lNodeType: 'i2l',
denoise,
vaeSource,
modelLoader,

View File

@@ -118,16 +118,4 @@ export const getInfill = (
assert(false, 'Unknown infill method');
};
export const addImageToLatents = (g: Graph, isFlux: boolean, fp32: boolean, image_name?: string) => {
if (isFlux) {
return g.addNode({
id: 'flux_vae_encode',
type: 'flux_vae_encode',
image: image_name ? { image_name } : undefined,
});
} else {
return g.addNode({ id: 'i2l', type: 'i2l', fp32, image: image_name ? { image_name } : undefined });
}
};
export const CANVAS_OUTPUT_PREFIX = 'canvas_output';

View File

@@ -0,0 +1,42 @@
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useModelCombobox } from 'common/hooks/useModelCombobox';
import { clipGEmbedModelSelected, selectCLIPGEmbedModel } from 'features/controlLayers/store/paramsSlice';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPGEmbedModelConfig } from 'services/api/types';
import { isCLIPGEmbedModelConfig } from 'services/api/types';
const ParamCLIPGEmbedModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const clipEmbedModel = useAppSelector(selectCLIPGEmbedModel);
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
(clipEmbedModel: CLIPGEmbedModelConfig | null) => {
if (clipEmbedModel) {
dispatch(clipGEmbedModelSelected(zModelIdentifierField.parse(clipEmbedModel)));
}
},
[dispatch]
);
const { options, value, onChange, noOptionsMessage } = useModelCombobox({
modelConfigs: modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config)),
onChange: _onChange,
selectedModel: clipEmbedModel,
isLoading,
});
return (
<FormControl isDisabled={!options.length} isInvalid={!options.length} minW={0} flexGrow={1}>
<FormLabel m={0}>{t('modelManager.clipGEmbed')}</FormLabel>
<Combobox value={value} options={options} onChange={onChange} noOptionsMessage={noOptionsMessage} />
</FormControl>
);
};
export default memo(ParamCLIPGEmbedModelSelect);

View File

@@ -0,0 +1,42 @@
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useModelCombobox } from 'common/hooks/useModelCombobox';
import { clipLEmbedModelSelected, selectCLIPLEmbedModel } from 'features/controlLayers/store/paramsSlice';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPLEmbedModelConfig } from 'services/api/types';
import { isCLIPLEmbedModelConfig } from 'services/api/types';
const ParamCLIPLEmbedModelSelect = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const clipEmbedModel = useAppSelector(selectCLIPLEmbedModel);
const [modelConfigs, { isLoading }] = useCLIPEmbedModels();
const _onChange = useCallback(
(clipEmbedModel: CLIPLEmbedModelConfig | null) => {
if (clipEmbedModel) {
dispatch(clipLEmbedModelSelected(zModelIdentifierField.parse(clipEmbedModel)));
}
},
[dispatch]
);
const { options, value, onChange, noOptionsMessage } = useModelCombobox({
modelConfigs: modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config)),
onChange: _onChange,
selectedModel: clipEmbedModel,
isLoading,
});
return (
<FormControl isDisabled={!options.length} isInvalid={!options.length} minW={0} flexGrow={1}>
<FormLabel m={0}>{t('modelManager.clipLEmbed')}</FormLabel>
<Combobox value={value} options={options} onChange={onChange} noOptionsMessage={noOptionsMessage} />
</FormControl>
);
};
export default memo(ParamCLIPLEmbedModelSelect);

View File

@@ -9,7 +9,7 @@ import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { MdMoneyOff } from 'react-icons/md';
import { useNonSD3MainModels } from 'services/api/hooks/modelsByType';
import { useMainModels } from 'services/api/hooks/modelsByType';
import { type AnyModelConfig, isCheckpointMainModelConfig, type MainModelConfig } from 'services/api/types';
const ParamMainModelSelect = () => {
@@ -18,7 +18,7 @@ const ParamMainModelSelect = () => {
const activeTabName = useAppSelector(selectActiveTab);
const selectedModelKey = useAppSelector(selectModelKey);
// const selectedModel = useAppSelector(selectModel);
const [modelConfigs, { isLoading }] = useNonSD3MainModels();
const [modelConfigs, { isLoading }] = useMainModels();
const selectedModel = useMemo(() => {
if (!modelConfigs) {

View File

@@ -7,9 +7,10 @@ export const MODEL_TYPE_MAP = {
any: 'Any',
'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x',
'sd-3': 'Stable Diffusion 3.x',
sdxl: 'Stable Diffusion XL',
'sdxl-refiner': 'Stable Diffusion XL Refiner',
flux: 'Flux',
flux: 'FLUX',
};
/**

View File

@@ -123,6 +123,16 @@ export const zParameterCLIPEmbedModel = zModelIdentifierField;
export type ParameterCLIPEmbedModel = z.infer<typeof zParameterCLIPEmbedModel>;
// #endregion
// #region CLIP-L embed Model
export const zParameterCLIPLEmbedModel = zModelIdentifierField;
export type ParameterCLIPLEmbedModel = z.infer<typeof zParameterCLIPLEmbedModel>;
// #endregion
// #region CLIP-G embed Model
export const zParameterCLIPGEmbedModel = zModelIdentifierField;
export type ParameterCLIPGEmbedModel = z.infer<typeof zParameterCLIPGEmbedModel>;
// #endregion
// #region LoRA Model
const zParameterLoRAModel = zModelIdentifierField;
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;

View File

@@ -3,9 +3,11 @@ import { Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-libra
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsFLUX, selectParamsSlice, selectVAEKey } from 'features/controlLayers/store/paramsSlice';
import { selectIsFLUX, selectIsSD3, selectParamsSlice, selectVAEKey } from 'features/controlLayers/store/paramsSlice';
import ParamCFGRescaleMultiplier from 'features/parameters/components/Advanced/ParamCFGRescaleMultiplier';
import ParamCLIPEmbedModelSelect from 'features/parameters/components/Advanced/ParamCLIPEmbedModelSelect';
import ParamCLIPGEmbedModelSelect from 'features/parameters/components/Advanced/ParamCLIPGEmbedModelSelect';
import ParamCLIPLEmbedModelSelect from 'features/parameters/components/Advanced/ParamCLIPLEmbedModelSelect';
import ParamClipSkip from 'features/parameters/components/Advanced/ParamClipSkip';
import ParamT5EncoderModelSelect from 'features/parameters/components/Advanced/ParamT5EncoderModelSelect';
import ParamSeamlessXAxis from 'features/parameters/components/Seamless/ParamSeamlessXAxis';
@@ -35,6 +37,7 @@ export const AdvancedSettingsAccordion = memo(() => {
const { currentData: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken);
const activeTabName = useAppSelector(selectActiveTab);
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
const selectBadges = useMemo(
() =>
@@ -88,7 +91,7 @@ export const AdvancedSettingsAccordion = memo(() => {
<Flex gap={4} alignItems="center" p={4} flexDir="column" data-testid="advanced-settings-accordion">
<Flex gap={4} w="full">
{isFLUX ? <ParamFLUXVAEModelSelect /> : <ParamVAEModelSelect />}
{!isFLUX && <ParamVAEPrecision />}
{!isFLUX && !isSD3 && <ParamVAEPrecision />}
</Flex>
{activeTabName === 'upscaling' ? (
<Flex gap={4} alignItems="center">
@@ -98,7 +101,7 @@ export const AdvancedSettingsAccordion = memo(() => {
</Flex>
) : (
<>
{!isFLUX && (
{!isFLUX && !isSD3 && (
<>
<FormControlGroup formLabelProps={formLabelProps}>
<ParamClipSkip />
@@ -118,6 +121,13 @@ export const AdvancedSettingsAccordion = memo(() => {
<ParamCLIPEmbedModelSelect />
</FormControlGroup>
)}
{isSD3 && (
<FormControlGroup>
<ParamT5EncoderModelSelect />
<ParamCLIPLEmbedModelSelect />
<ParamCLIPGEmbedModelSelect />
</FormControlGroup>
)}
</>
)}
</Flex>

View File

@@ -4,7 +4,7 @@ import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { LoRAList } from 'features/lora/components/LoRAList';
import LoRASelect from 'features/lora/components/LoRASelect';
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
@@ -30,6 +30,7 @@ export const GenerationSettingsAccordion = memo(() => {
const modelConfig = useSelectedModelConfig();
const activeTabName = useAppSelector(selectActiveTab);
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
const selectBadges = useMemo(
() =>
createMemoizedSelector(selectLoRAsSlice, (loras) => {
@@ -74,7 +75,7 @@ export const GenerationSettingsAccordion = memo(() => {
<Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>
<Flex gap={4} flexDir="column" pb={4}>
<FormControlGroup formLabelProps={formLabelProps}>
{!isFLUX && <ParamScheduler />}
{!isFLUX && !isSD3 && <ParamScheduler />}
<ParamSteps />
{isFLUX ? <ParamGuidance /> : <ParamCFGScale />}
</FormControlGroup>

View File

@@ -3,7 +3,7 @@ import { Expander, Flex, FormControlGroup, StandaloneAccordion } from '@invoke-a
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsFLUX, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectIsFLUX, selectIsSD3, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectScaleMethod } from 'features/controlLayers/store/selectors';
import { ParamOptimizedDenoisingToggle } from 'features/parameters/components/Advanced/ParamOptimizedDenoisingToggle';
import BboxScaledHeight from 'features/parameters/components/Bbox/BboxScaledHeight';
@@ -60,6 +60,7 @@ export const ImageSettingsAccordion = memo(() => {
defaultIsOpen: false,
});
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);
return (
<StandaloneAccordion
@@ -77,7 +78,7 @@ export const ImageSettingsAccordion = memo(() => {
</Flex>
<Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>
<Flex gap={4} pb={4} flexDir="column">
{isFLUX && <ParamOptimizedDenoisingToggle />}
{(isFLUX || isSD3) && <ParamOptimizedDenoisingToggle />}
<BboxScaleMethod />
{scaleMethod !== 'none' && (
<FormControlGroup formLabelProps={scalingLabelProps}>

View File

@@ -34,8 +34,15 @@ export const UpscaleWarning = () => {
dispatch(tileControlnetModelChanged(validModel || null));
}, [model?.base, modelConfigs, dispatch]);
const isBaseModelCompatible = useMemo(() => {
return model && ['sd-1', 'sdxl'].includes(model.base);
}, [model]);
const modelWarnings = useMemo(() => {
const _warnings: string[] = [];
if (!isBaseModelCompatible) {
return _warnings;
}
if (!model) {
_warnings.push(t('upscaling.mainModelDesc'));
}
@@ -46,7 +53,7 @@ export const UpscaleWarning = () => {
_warnings.push(t('upscaling.upscaleModelDesc'));
}
return _warnings;
}, [model, tileControlnetModel, upscaleModel, t]);
}, [isBaseModelCompatible, model, tileControlnetModel, upscaleModel, t]);
const otherWarnings = useMemo(() => {
const _warnings: string[] = [];
@@ -58,22 +65,25 @@ export const UpscaleWarning = () => {
return _warnings;
}, [isTooLargeToUpscale, t, maxUpscaleDimension]);
const allWarnings = useMemo(() => [...modelWarnings, ...otherWarnings], [modelWarnings, otherWarnings]);
const handleGoToModelManager = useCallback(() => {
dispatch(setActiveTab('models'));
$installModelsTab.set(3);
}, [dispatch]);
if (modelWarnings.length && isModelsTabDisabled) {
if (isBaseModelCompatible && modelWarnings.length > 0 && isModelsTabDisabled) {
return null;
}
if ((!modelWarnings.length && !otherWarnings.length) || isLoading) {
if ((isBaseModelCompatible && allWarnings.length === 0) || isLoading) {
return null;
}
return (
<Flex bg="error.500" borderRadius="base" padding={4} direction="column" fontSize="sm" gap={2}>
{!!modelWarnings.length && (
{!isBaseModelCompatible && <Text>{t('upscaling.incompatibleBaseModelDesc')}</Text>}
{modelWarnings.length > 0 && (
<Text>
<Trans
i18nKey="upscaling.missingModelsWarning"
@@ -85,11 +95,13 @@ export const UpscaleWarning = () => {
/>
</Text>
)}
<UnorderedList>
{[...modelWarnings, ...otherWarnings].map((warning) => (
<ListItem key={warning}>{warning}</ListItem>
))}
</UnorderedList>
{allWarnings.length > 0 && (
<UnorderedList>
{allWarnings.map((warning) => (
<ListItem key={warning}>{warning}</ListItem>
))}
</UnorderedList>
)}
</Flex>
);
};
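
For clarity, the two early returns added above reduce to the gating sketched below as a standalone helper. The function name and argument object are hypothetical; in the component these values come from hooks and memos.

const shouldHideUpscaleWarning = (args: {
  isBaseModelCompatible: boolean;
  modelWarnings: string[];
  otherWarnings: string[];
  isModelsTabDisabled: boolean;
  isLoading: boolean;
}): boolean => {
  const { isBaseModelCompatible, modelWarnings, otherWarnings, isModelsTabDisabled, isLoading } = args;
  // Missing-model warnings are only actionable when the models tab is available.
  if (isBaseModelCompatible && modelWarnings.length > 0 && isModelsTabDisabled) {
    return true;
  }
  // A compatible base model with nothing to warn about, or a still-loading model list, shows nothing.
  const allWarnings = [...modelWarnings, ...otherWarnings];
  return (isBaseModelCompatible && allWarnings.length === 0) || isLoading;
};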

View File

@@ -26,17 +26,20 @@ import { SettingsDeveloperLogLevel } from 'features/system/components/SettingsMo
import { SettingsDeveloperLogNamespaces } from 'features/system/components/SettingsModal/SettingsDeveloperLogNamespaces';
import { useClearIntermediates } from 'features/system/components/SettingsModal/useClearIntermediates';
import { StickyScrollable } from 'features/system/components/StickyScrollable';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import {
selectSystemShouldAntialiasProgressImage,
selectSystemShouldConfirmOnDelete,
selectSystemShouldConfirmOnNewSession,
selectSystemShouldEnableInformationalPopovers,
selectSystemShouldEnableModelDescriptions,
selectSystemShouldShowInvocationProgressDetail,
selectSystemShouldUseNSFWChecker,
selectSystemShouldUseWatermarker,
setShouldConfirmOnDelete,
setShouldEnableInformationalPopovers,
setShouldEnableModelDescriptions,
setShouldShowInvocationProgressDetail,
shouldAntialiasProgressImageChanged,
shouldConfirmOnNewSessionToggled,
shouldUseNSFWCheckerChanged,
@@ -103,6 +106,8 @@ const SettingsModal = ({ config = defaultConfig, children }: SettingsModalProps)
const shouldEnableInformationalPopovers = useAppSelector(selectSystemShouldEnableInformationalPopovers);
const shouldEnableModelDescriptions = useAppSelector(selectSystemShouldEnableModelDescriptions);
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const shouldShowInvocationProgressDetail = useAppSelector(selectSystemShouldShowInvocationProgressDetail);
const isInvocationProgressAlertEnabled = useFeatureStatus('invocationProgressAlert');
const onToggleConfirmOnNewSession = useCallback(() => {
dispatch(shouldConfirmOnNewSessionToggled());
}, [dispatch]);
@@ -170,6 +175,13 @@ const SettingsModal = ({ config = defaultConfig, children }: SettingsModalProps)
[dispatch]
);
const handleChangeShouldShowInvocationProgressDetail = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(setShouldShowInvocationProgressDetail(e.target.checked));
},
[dispatch]
);
return (
<>
{cloneElement(children, {
@@ -221,6 +233,15 @@ const SettingsModal = ({ config = defaultConfig, children }: SettingsModalProps)
onChange={handleChangeShouldAntialiasProgressImage}
/>
</FormControl>
{isInvocationProgressAlertEnabled && (
<FormControl>
<FormLabel>{t('settings.showDetailedInvocationProgress')}</FormLabel>
<Switch
isChecked={shouldShowInvocationProgressDetail}
onChange={handleChangeShouldShowInvocationProgressDetail}
/>
</FormControl>
)}
<FormControl>
<InformationalPopover feature="noiseUseCPU" inPortal={false}>
<FormLabel>{t('parameters.useCpuNoise')}</FormLabel>

View File

@@ -21,6 +21,7 @@ const initialSystemState: SystemState = {
logIsEnabled: true,
logLevel: 'debug',
logNamespaces: [...zLogNamespace.options],
shouldShowInvocationProgressDetail: false,
};
export const systemSlice = createSlice({
@@ -64,6 +65,9 @@ export const systemSlice = createSlice({
shouldConfirmOnNewSessionToggled(state) {
state.shouldConfirmOnNewSession = !state.shouldConfirmOnNewSession;
},
setShouldShowInvocationProgressDetail(state, action: PayloadAction<boolean>) {
state.shouldShowInvocationProgressDetail = action.payload;
},
},
});
@@ -79,6 +83,7 @@ export const {
setShouldEnableInformationalPopovers,
setShouldEnableModelDescriptions,
shouldConfirmOnNewSessionToggled,
setShouldShowInvocationProgressDetail,
} = systemSlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
@@ -117,3 +122,6 @@ export const selectSystemShouldEnableModelDescriptions = createSystemSelector(
(system) => system.shouldEnableModelDescriptions
);
export const selectSystemShouldConfirmOnNewSession = createSystemSelector((system) => system.shouldConfirmOnNewSession);
export const selectSystemShouldShowInvocationProgressDetail = createSystemSelector(
(system) => system.shouldShowInvocationProgressDetail
);
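
A minimal sketch of the new toggle at the reducer level, useful as a reference for tests; the import path is an assumption based on this diff's layout.

import { setShouldShowInvocationProgressDetail, systemSlice } from 'features/system/store/systemSlice';

// The new flag defaults to false (see initialSystemState above).
const before = systemSlice.getInitialState();

// Running the action through the reducer flips only this flag.
const after = systemSlice.reducer(before, setShouldShowInvocationProgressDetail(true));
// after.shouldShowInvocationProgressDetail === true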

View File

@@ -41,4 +41,5 @@ export interface SystemState {
logIsEnabled: boolean;
logLevel: LogLevel;
logNamespaces: LogNamespace[];
shouldShowInvocationProgressDetail: boolean;
}

View File

@@ -6,10 +6,10 @@ import type { components, paths } from 'services/api/schema';
import type {
DeleteBoardResult,
GraphAndWorkflowResponse,
ImageCategory,
ImageDTO,
ListImagesArgs,
ListImagesResponse,
UploadImageArg,
} from 'services/api/types';
import { getCategories, getListImagesUrl } from 'services/api/util';
import type { JsonObject } from 'type-fest';
@@ -260,20 +260,7 @@ export const imagesApi = api.injectEndpoints({
return [];
},
}),
uploadImage: build.mutation<
ImageDTO,
{
file: File;
image_category: ImageCategory;
is_intermediate: boolean;
session_id?: string;
board_id?: string;
crop_visible?: boolean;
metadata?: JsonObject;
isFirstUploadOfBatch?: boolean;
withToast?: boolean;
}
>({
uploadImage: build.mutation<ImageDTO, UploadImageArg>({
query: ({ file, image_category, is_intermediate, session_id, board_id, crop_visible, metadata }) => {
const formData = new FormData();
formData.append('file', file);
@@ -558,7 +545,6 @@ export const {
useClearIntermediatesMutation,
useAddImagesToBoardMutation,
useRemoveImagesFromBoardMutation,
useChangeImageIsIntermediateMutation,
useDeleteBoardAndImagesMutation,
useDeleteBoardMutation,
useStarImagesMutation,
@@ -622,79 +608,17 @@ export const getImageMetadata = (
return req.unwrap();
};
export type UploadImageArg = {
file: File;
image_category: ImageCategory;
is_intermediate: boolean;
session_id?: string;
board_id?: string;
crop_visible?: boolean;
metadata?: JsonObject;
withToast?: boolean;
};
export const uploadImage = (arg: UploadImageArg): Promise<ImageDTO> => {
const {
file,
image_category,
is_intermediate,
crop_visible = false,
board_id,
metadata,
session_id,
withToast = true,
} = arg;
const { dispatch } = getStore();
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate(
{
file,
image_category,
is_intermediate,
crop_visible,
board_id,
metadata,
session_id,
withToast,
},
{ track: false }
)
);
const req = dispatch(imagesApi.endpoints.uploadImage.initiate(arg, { track: false }));
return req.unwrap();
};
export const uploadImages = async (args: UploadImageArg[]): Promise<ImageDTO[]> => {
const { dispatch } = getStore();
const results = await Promise.allSettled(
args.map((arg, i) => {
const {
file,
image_category,
is_intermediate,
crop_visible = false,
board_id,
metadata,
session_id,
withToast = true,
} = arg;
const req = dispatch(
imagesApi.endpoints.uploadImage.initiate(
{
file,
image_category,
is_intermediate,
crop_visible,
board_id,
metadata,
session_id,
isFirstUploadOfBatch: i === 0,
withToast,
},
{ track: false }
)
);
args.map((arg) => {
const req = dispatch(imagesApi.endpoints.uploadImage.initiate(arg, { track: false }));
return req.unwrap();
})
);

View File

@@ -18,7 +18,6 @@ import {
isIPAdapterModelConfig,
isLoRAModelConfig,
isNonRefinerMainModelConfig,
isNonSD3MainModelModelConfig,
isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig,
isSD3MainModelModelConfig,
@@ -53,7 +52,6 @@ const buildModelsHook =
};
export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSD3MainModels = buildModelsHook(isNonSD3MainModelModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);

File diff suppressed because one or more lines are too long

View File

@@ -1,5 +1,5 @@
import type { components, paths } from 'services/api/schema';
import type { SetRequired } from 'type-fest';
import type { JsonObject, SetRequired } from 'type-fest';
export type S = components['schemas'];
@@ -223,10 +223,6 @@ export const isSD3MainModelModelConfig = (config: AnyModelConfig): config is Mai
return config.type === 'main' && config.base === 'sd-3';
};
export const isNonSD3MainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base !== 'sd-3' && config.base !== 'sdxl-refiner';
};
export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'flux';
};
@@ -291,3 +287,42 @@ export type SetHFTokenResponse = NonNullable<
export type SetHFTokenArg = NonNullable<
paths['/api/v2/models/hf_login']['post']['requestBody']['content']['application/json']
>;
export type UploadImageArg = {
/**
* The file object to upload
*/
file: File;
/**
* The category of image to upload
*/
image_category: ImageCategory;
/**
* Whether the uploaded image is an intermediate image (intermediate images are not shown in the gallery)
*/
is_intermediate: boolean;
/**
* The session with which to associate the uploaded image
*/
session_id?: string;
/**
* The board id to add the image to
*/
board_id?: string;
/**
* Whether or not to crop the image to its bounding box before saving
*/
crop_visible?: boolean;
/**
* Metadata to embed in the image when saving it
*/
metadata?: JsonObject;
/**
* Whether this upload should be "silent" (no toast on upload, no changing of gallery view)
*/
silent?: boolean;
/**
* Whether this is the first upload of a batch (used when displaying user feedback with toasts - ignored if the upload is silent)
*/
isFirstUploadOfBatch?: boolean;
};
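
A minimal usage sketch for the consolidated UploadImageArg shape together with the uploadImage helper refactored above; the import path and the 'general' category value are assumptions.

import { uploadImage } from 'services/api/endpoints/images';

const uploadToGallery = async (file: File) => {
  // Only `file`, `image_category` and `is_intermediate` are required; the rest
  // of the fields documented above are optional.
  const imageDTO = await uploadImage({
    file,
    image_category: 'general', // assumed ImageCategory value
    is_intermediate: false,    // show the upload in the gallery
    silent: true,              // no toast, no change of gallery view
  });
  return imageDTO;
};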

View File

@@ -1,3 +1,4 @@
import { round } from 'lodash-es';
import { atom, computed, map } from 'nanostores';
import type { S } from 'services/api/types';
import type { AppSocket } from 'services/events/types';
@@ -10,3 +11,16 @@ export const $lastProgressEvent = atom<S['InvocationProgressEvent'] | null>(null
export const $progressImage = computed($lastProgressEvent, (val) => val?.image ?? null);
export const $hasProgressImage = computed($lastProgressEvent, (val) => Boolean(val?.image));
export const $isProgressFromCanvas = computed($lastProgressEvent, (val) => val?.destination === 'canvas');
export const $invocationProgressMessage = computed($lastProgressEvent, (val) => {
if (!val) {
return null;
}
if (val.destination !== 'canvas') {
return null;
}
let message = val.message;
if (val.percentage) {
message += ` (${round(val.percentage * 100)}%)`;
}
return message;
});
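
A minimal sketch of consuming the new computed store from a component via useStore from @nanostores/react; the import path for the store module is an assumption.

import { useStore } from '@nanostores/react';
import { memo } from 'react';
import { $invocationProgressMessage } from 'services/events/stores';

const InvocationProgressMessage = memo(() => {
  // e.g. "Loading model (46%)" for message "Loading model" and percentage 0.456;
  // null when the latest event is not destined for the canvas.
  const message = useStore($invocationProgressMessage);
  if (!message) {
    return null;
  }
  return <span>{message}</span>;
});
InvocationProgressMessage.displayName = 'InvocationProgressMessage';

export default InvocationProgressMessage;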

View File

@@ -1 +1 @@
__version__ = "5.4.1rc2"
__version__ = "5.4.1"

View File

@@ -6,7 +6,11 @@ from diffusers import AutoencoderTiny
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
from invokeai.app.services.shared.invocation_context import (
InvocationContext,
InvocationContextData,
build_invocation_context,
)
from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
@@ -19,7 +23,7 @@ def mock_context(
mock_services.model_manager = mm2_model_manager
return build_invocation_context(
services=mock_services,
data=None, # type: ignore
data=InvocationContextData(queue_item=None, invocation=None, source_invocation_id=None), # type: ignore
is_canceled=None, # type: ignore
)

View File

@@ -1,6 +1,6 @@
# State dict keys and shapes for an XLabs FLUX IP-Adapter model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/XLabs-AI/flux-ip-adapter/blob/ad16be50d78a07ea83d8c4bde44ff9753235182e/flux-ip-adapter.safetensors
# https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/ip_adapter.safetensors
xlabs_sd_shapes = {
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.bias": [3072],
"double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],

View File

@@ -0,0 +1,688 @@
# A sample state dict in the Diffusers FLUX LoRA format without .proj_mlp layers.
# This format was added in response to https://github.com/invoke-ai/InvokeAI/issues/7129
state_dict_keys = {
"transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.0.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.1.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.1.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.1.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.1.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.1.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.1.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.10.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.10.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.10.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.10.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.10.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.10.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.11.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.11.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.11.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.11.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.11.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.11.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.12.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.12.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.12.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.12.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.12.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.12.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.13.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.13.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.13.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.13.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.13.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.13.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.14.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.14.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.14.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.14.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.14.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.14.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.15.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.15.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.15.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.15.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.15.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.15.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.16.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.16.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.16.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.16.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.16.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.16.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.17.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.17.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.17.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.17.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.17.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.17.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.18.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.18.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.18.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.18.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.18.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.18.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.19.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.19.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.19.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.19.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.19.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.19.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.2.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.2.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.2.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.2.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.2.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.2.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.20.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.20.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.20.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.20.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.20.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.20.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.21.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.21.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.21.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.21.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.21.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.21.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.22.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.22.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.22.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.22.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.22.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.22.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.23.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.23.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.23.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.23.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.23.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.23.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.24.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.24.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.24.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.24.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.24.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.24.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.25.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.25.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.25.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.25.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.25.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.25.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.26.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.26.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.26.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.26.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.26.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.26.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.27.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.27.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.27.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.27.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.27.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.27.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.28.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.28.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.28.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.28.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.28.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.28.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.29.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.29.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.29.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.29.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.29.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.29.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.3.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.3.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.3.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.3.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.3.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.3.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.30.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.30.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.30.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.30.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.30.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.30.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.31.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.31.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.31.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.31.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.31.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.31.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.32.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.32.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.32.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.32.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.32.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.32.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.33.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.33.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.33.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.33.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.33.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.33.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.34.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.34.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.34.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.34.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.34.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.34.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.35.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.35.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.35.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.35.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.35.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.35.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.36.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.36.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.36.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.36.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.36.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.36.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.37.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.37.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.37.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.37.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.37.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.37.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.4.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.4.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.4.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.4.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.4.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.4.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.5.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.5.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.5.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.5.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.5.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.5.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.6.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.6.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.6.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.6.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.6.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.6.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.7.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.7.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.7.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.7.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.7.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.7.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.8.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.8.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.8.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.8.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.8.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.8.attn.to_v.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.9.attn.to_k.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.9.attn.to_k.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.9.attn.to_q.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.9.attn.to_q.lora_B.weight": [3072, 16],
"transformer.single_transformer_blocks.9.attn.to_v.lora_A.weight": [16, 3072],
"transformer.single_transformer_blocks.9.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.0.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.0.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.0.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.0.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.0.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.0.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.0.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.0.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.0.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.0.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.0.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.0.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.0.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.0.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.0.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.0.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.0.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.0.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.0.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.0.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.1.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.1.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.1.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.1.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.1.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.1.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.1.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.1.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.1.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.1.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.1.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.1.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.1.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.1.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.1.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.1.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.1.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.1.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.1.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.1.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.1.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.1.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.1.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.1.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.10.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.10.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.10.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.10.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.10.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.10.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.10.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.10.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.10.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.10.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.10.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.10.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.10.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.10.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.10.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.10.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.10.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.10.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.10.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.10.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.10.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.10.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.10.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.10.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.11.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.11.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.11.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.11.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.11.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.11.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.11.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.11.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.11.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.11.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.11.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.11.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.11.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.11.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.11.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.11.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.11.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.11.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.11.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.11.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.11.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.11.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.11.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.11.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.12.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.12.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.12.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.12.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.12.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.12.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.12.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.12.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.12.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.12.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.12.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.12.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.12.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.12.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.12.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.12.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.12.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.12.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.12.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.12.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.12.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.12.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.12.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.12.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.13.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.13.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.13.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.13.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.13.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.13.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.13.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.13.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.13.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.13.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.13.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.13.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.13.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.13.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.13.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.13.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.13.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.13.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.13.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.13.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.13.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.13.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.13.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.13.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.14.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.14.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.14.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.14.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.14.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.14.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.14.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.14.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.14.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.14.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.14.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.14.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.14.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.14.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.14.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.14.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.14.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.14.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.14.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.14.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.14.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.14.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.14.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.14.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.15.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.15.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.15.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.15.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.15.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.15.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.15.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.15.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.15.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.15.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.15.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.15.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.15.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.15.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.15.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.15.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.15.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.15.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.15.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.15.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.15.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.15.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.15.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.15.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.16.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.16.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.16.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.16.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.16.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.16.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.16.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.16.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.16.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.16.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.16.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.16.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.16.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.16.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.16.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.16.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.16.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.16.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.16.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.16.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.16.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.16.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.16.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.16.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.17.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.17.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.17.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.17.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.17.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.17.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.17.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.17.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.17.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.17.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.17.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.17.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.17.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.17.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.17.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.17.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.17.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.17.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.17.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.17.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.17.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.17.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.17.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.17.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.18.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.18.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.18.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.18.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.18.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.18.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.18.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.18.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.18.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.18.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.18.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.18.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.18.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.18.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.18.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.18.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.18.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.18.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.18.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.18.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.18.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.18.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.18.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.18.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.2.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.2.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.2.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.2.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.2.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.2.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.2.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.2.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.2.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.2.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.2.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.2.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.2.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.2.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.2.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.2.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.2.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.2.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.2.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.2.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.2.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.2.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.2.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.2.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.3.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.3.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.3.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.3.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.3.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.3.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.3.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.3.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.3.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.3.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.3.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.3.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.3.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.3.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.3.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.3.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.3.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.3.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.3.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.3.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.3.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.3.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.3.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.3.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.4.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.4.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.4.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.4.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.4.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.4.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.4.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.4.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.4.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.4.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.4.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.4.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.4.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.4.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.4.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.4.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.4.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.4.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.4.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.4.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.4.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.4.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.4.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.4.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.5.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.5.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.5.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.5.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.5.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.5.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.5.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.5.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.5.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.5.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.5.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.5.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.5.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.5.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.5.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.5.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.5.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.5.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.5.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.5.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.5.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.5.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.5.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.5.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.6.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.6.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.6.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.6.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.6.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.6.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.6.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.6.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.6.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.6.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.6.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.6.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.6.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.6.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.6.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.6.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.6.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.6.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.6.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.6.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.6.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.6.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.6.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.6.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.7.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.7.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.7.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.7.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.7.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.7.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.7.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.7.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.7.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.7.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.7.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.7.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.7.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.7.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.7.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.7.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.7.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.7.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.7.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.7.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.7.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.7.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.7.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.7.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.8.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.8.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.8.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.8.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.8.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.8.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.8.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.8.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.8.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.8.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.8.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.8.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.8.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.8.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.8.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.8.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.8.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.8.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.8.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.8.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.8.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.8.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.8.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.8.ff_context.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.9.attn.add_k_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.9.attn.add_k_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.9.attn.add_q_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.9.attn.add_q_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.9.attn.add_v_proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.9.attn.add_v_proj.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.9.attn.to_add_out.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.9.attn.to_add_out.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.9.attn.to_k.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.9.attn.to_k.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.9.attn.to_out.0.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.9.attn.to_out.0.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.9.attn.to_q.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.9.attn.to_q.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.9.attn.to_v.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.9.attn.to_v.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.9.ff.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.9.ff.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.9.ff.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.9.ff.net.2.lora_B.weight": [3072, 16],
"transformer.transformer_blocks.9.ff_context.net.0.proj.lora_A.weight": [16, 3072],
"transformer.transformer_blocks.9.ff_context.net.0.proj.lora_B.weight": [12288, 16],
"transformer.transformer_blocks.9.ff_context.net.2.lora_A.weight": [16, 12288],
"transformer.transformer_blocks.9.ff_context.net.2.lora_B.weight": [3072, 16],
}
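Every entry in this keys-to-shapes mapping follows the same pattern: a `lora_A.weight` of shape [rank, in_features] paired with a `lora_B.weight` of shape [out_features, rank], with rank 16 throughout. A hedged sanity-check sketch of that invariant follows; the helper name is hypothetical and not part of the repository:

# Hypothetical illustration only: verify the lora_A / lora_B shape pairing in a
# keys-to-shapes mapping like the one above.
def check_lora_rank_consistency(state_dict_keys: dict[str, list[int]], expected_rank: int = 16) -> None:
    for name, shape in state_dict_keys.items():
        if not name.endswith("lora_A.weight"):
            continue
        partner_shape = state_dict_keys[name.replace("lora_A.weight", "lora_B.weight")]
        assert shape[0] == expected_rank          # lora_A is [rank, in_features]
        assert partner_shape[1] == expected_rank  # lora_B is [out_features, rank]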

View File

@@ -1,8 +1,8 @@
 import torch

-def keys_to_mock_state_dict(keys: list[str]) -> dict[str, torch.Tensor]:
+def keys_to_mock_state_dict(keys: dict[str, list[int]]) -> dict[str, torch.Tensor]:
     state_dict: dict[str, torch.Tensor] = {}
-    for k in keys:
-        state_dict[k] = torch.empty(1)
+    for k, shape in keys.items():
+        state_dict[k] = torch.empty(shape)
     return state_dict
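With the helper now materializing tensors of the recorded shapes (instead of 1-element placeholders), shape-sensitive conversion code can be exercised against the mock. A minimal usage sketch, using two entries copied from the listing above (the `sample_keys` name is just for illustration):

from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict

sample_keys: dict[str, list[int]] = {
    "transformer.transformer_blocks.2.attn.to_q.lora_A.weight": [16, 3072],
    "transformer.transformer_blocks.2.attn.to_q.lora_B.weight": [3072, 16],
}
state_dict = keys_to_mock_state_dict(sample_keys)
assert state_dict["transformer.transformer_blocks.2.attn.to_q.lora_B.weight"].shape == (3072, 16)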

View File

@@ -9,22 +9,26 @@ from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRAN
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
     state_dict_keys as flux_diffusers_state_dict_keys,
 )
+from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_no_proj_mlp_format import (
+    state_dict_keys as flux_diffusers_no_proj_mlp_state_dict_keys,
+)
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
     state_dict_keys as flux_kohya_state_dict_keys,
 )
 from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict


-def test_is_state_dict_likely_in_flux_diffusers_format_true():
+@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_diffusers_no_proj_mlp_state_dict_keys])
+def test_is_state_dict_likely_in_flux_diffusers_format_true(sd_keys: dict[str, list[int]]):
     """Test that is_state_dict_likely_in_flux_diffusers_format() can identify a state dict in the Diffusers FLUX LoRA format."""
     # Construct a state dict that is in the Diffusers FLUX LoRA format.
-    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
+    state_dict = keys_to_mock_state_dict(sd_keys)

     assert is_state_dict_likely_in_flux_diffusers_format(state_dict)


 def test_is_state_dict_likely_in_flux_diffusers_format_false():
-    """Test that is_state_dict_likely_in_flux_diffusers_format() returns False for a state dict that is not in the Kohya
+    """Test that is_state_dict_likely_in_flux_diffusers_format() returns False for a state dict that is in the Kohya
     FLUX LoRA format.
     """
     # Construct a state dict that is not in the Kohya FLUX LoRA format.
@@ -33,16 +37,17 @@ def test_is_state_dict_likely_in_flux_diffusers_format_false():
     assert not is_state_dict_likely_in_flux_diffusers_format(state_dict)


-def test_lora_model_from_flux_diffusers_state_dict():
+@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_diffusers_no_proj_mlp_state_dict_keys])
+def test_lora_model_from_flux_diffusers_state_dict(sd_keys: dict[str, list[int]]):
     """Test that lora_model_from_flux_diffusers_state_dict() can load a state dict in the Diffusers FLUX LoRA format."""
     # Construct a state dict that is in the Diffusers FLUX LoRA format.
-    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
+    state_dict = keys_to_mock_state_dict(sd_keys)

     # Load the state dict into a LoRAModelRaw object.
     model = lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)

     # Check that the model has the correct number of LoRA layers.
     expected_lora_layers: set[str] = set()
-    for k in flux_diffusers_state_dict_keys:
+    for k in sd_keys:
         k = k.replace("lora_A.weight", "")
         k = k.replace("lora_B.weight", "")
         expected_lora_layers.add(k)
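For context on the shapes being exercised: each `lora_A`/`lora_B` pair encodes a rank-16 factorization of a weight update. Under the standard LoRA formulation (an assumption here; the exact scaling applied inside `lora_model_from_flux_diffusers_state_dict` is not shown in this diff), the pair combines roughly as:

import torch

rank, alpha = 16, 8.0              # alpha matches the value passed in the test above
lora_A = torch.randn(rank, 3072)   # lora_A.weight: [16, 3072]
lora_B = torch.randn(3072, rank)   # lora_B.weight: [3072, 16]

# Low-rank update applied to the base [3072, 3072] projection weight.
delta_W = (alpha / rank) * (lora_B @ lora_A)
assert delta_W.shape == (3072, 3072)

Note also that iterating a dict yields its keys, so the `for k in sd_keys` loop above behaves exactly as the old `for k in flux_diffusers_state_dict_keys` loop did after the list-to-dict migration.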

View File

@@ -23,7 +23,7 @@ from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_s
 @pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
-def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: list[str]):
+def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: dict[str, list[int]]):
     """Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
     # Construct a state dict that is in the Kohya FLUX LoRA format.
     state_dict = keys_to_mock_state_dict(sd_keys)
@@ -83,7 +83,7 @@ def test_convert_flux_transformer_kohya_state_dict_to_invoke_format_error():
 @pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
-def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
+def test_lora_model_from_flux_kohya_state_dict(sd_keys: dict[str, list[int]]):
     """Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
     # Construct a state dict that is in the Kohya FLUX LoRA format.
     state_dict = keys_to_mock_state_dict(sd_keys)

View File

@@ -9,6 +9,8 @@ from invokeai.app.invocations.baseinvocation import (
 from invokeai.app.invocations.fields import InputField, OutputField
 from invokeai.app.invocations.image import ImageField
 from invokeai.app.services.events.events_common import EventBase
+from invokeai.app.services.session_processor.session_processor_common import ProgressImage
+from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -133,6 +135,16 @@ class TestEventService(EventServiceBase):
         self.events.append(event)
         pass

+    def emit_invocation_progress(
+        self,
+        queue_item: "SessionQueueItem",
+        invocation: "BaseInvocation",
+        message: str,
+        percentage: float | None = None,
+        image: "ProgressImage | None" = None,
+    ) -> None:
+        pass
+

 def wait_until(condition: Callable[[], bool], timeout: int = 10, interval: float = 0.1) -> None:
     import time