Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-16 15:18:05 -05:00)

Compare commits: v5.4.1rc2 ... psychedeli (62 commits)
Commit SHA1s: ceae1dc04f, 4b390906bc, c5b8efe03b, 4d08d00ad8, 9b0130262b, 878093f64e, d5ff7ef250, f36583f866, 829bc1bc7d, 17c7b57145, 6a12189542, 96a31a5563, 067747eca9, c7878fddc6, 54c51e0a06, 1640ea0298, 0c32ae9775, fdb8ca5165, 571faf6d7c, bdbdb22b74, 9bbb5644af, e90ad19f22, 0ba11e8f73, 1cf7600f5b, 4f9d12b872, 68c3b0649b, 8ef8bd4261, 50897ba066, 3510643870, ca9cb1c9ef, b89caa02bd, eaf4e08c44, fb19621361, 9179619077, 13cb5f0ba2, 7e52fc1c17, 7f60a4a282, 3f880496f7, f05efd3270, 79eb8172b6, 7732b5d478, a2a1934b66, dff6570078, 04e4fb63af, 83609d5008, 2618ed0ae7, bb3cedddd5, 5b3e1593ca, 2d08078a7d, 0e6cb91863, a0fefcd43f, a5f8c23dee, 7bb4ea57c6, 75dc961bcb, a9a1f6ef21, aa40161f26, 6efa812874, 8a683f5a3c, f4b0b6a93d, 1337c33ad3, 496b02a3bc, 7b5efc2203
@@ -95,6 +95,7 @@ class CompelInvocation(BaseInvocation):
ti_manager,
),
):
context.util.signal_progress("Building conditioning")
assert isinstance(text_encoder, CLIPTextModel)
assert isinstance(tokenizer, CLIPTokenizer)
compel = Compel(

@@ -191,6 +192,7 @@ class SDXLPromptInvocationBase:
ti_manager,
),
):
context.util.signal_progress("Building conditioning")
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(tokenizer, CLIPTokenizer)

@@ -65,6 +65,7 @@ class CreateDenoiseMaskInvocation(BaseInvocation):
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
context.util.signal_progress("Running VAE encoder")
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())
masked_latents_name = context.tensors.save(tensor=masked_latents)

@@ -131,6 +131,7 @@ class CreateGradientMaskInvocation(BaseInvocation):
image_tensor = image_tensor.unsqueeze(0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
context.util.signal_progress("Running VAE encoder")
masked_latents = ImageToLatentsInvocation.vae_encode(
vae_info, self.fp32, self.tiled, masked_image.clone()
)

@@ -71,6 +71,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
context.util.signal_progress("Running T5 encoder")
prompt_embeds = t5_encoder(prompt)
assert isinstance(prompt_embeds, torch.Tensor)

@@ -111,6 +112,7 @@ class FluxTextEncoderInvocation(BaseInvocation):
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
context.util.signal_progress("Running CLIP encoder")
pooled_prompt_embeds = clip_encoder(prompt)
assert isinstance(pooled_prompt_embeds, torch.Tensor)

@@ -41,7 +41,8 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Image:
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype())
vae_dtype = next(iter(vae.parameters())).dtype
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
img = vae.decode(latents)
img = img.clamp(-1, 1)

@@ -53,6 +54,7 @@ class FluxVaeDecodeInvocation(BaseInvocation, WithMetadata, WithBoard):
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
context.util.signal_progress("Running VAE")
image = self._vae_decode(vae_info=vae_info, latents=latents)
TorchDevice.empty_cache()

@@ -44,9 +44,8 @@ class FluxVaeEncodeInvocation(BaseInvocation):
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
with vae_info as vae:
assert isinstance(vae, AutoEncoder)
image_tensor = image_tensor.to(
device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
)
vae_dtype = next(iter(vae.parameters())).dtype
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
latents = vae.encode(image_tensor, sample=True, generator=generator)
return latents

@@ -60,6 +59,7 @@ class FluxVaeEncodeInvocation(BaseInvocation):
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
context.util.signal_progress("Running VAE")
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
latents = latents.to("cpu")
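The FluxVaeDecode and FluxVaeEncode hunks above stop casting inputs to the globally chosen torch dtype and instead read the dtype off the loaded VAE's own parameters. A minimal sketch of that pattern, with an illustrative helper name that is not part of InvokeAI's API:

```python
import torch


def cast_to_model_dtype(x: torch.Tensor, model: torch.nn.Module) -> torch.Tensor:
    # Read the dtype from the model's first parameter so inputs always match
    # however the weights were actually loaded (e.g. a bfloat16 fallback),
    # rather than trusting a globally configured dtype.
    model_dtype = next(iter(model.parameters())).dtype
    return x.to(dtype=model_dtype)
```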
@@ -117,6 +117,7 @@ class ImageToLatentsInvocation(BaseInvocation):
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
context.util.signal_progress("Running VAE encoder")
latents = self.vae_encode(
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)

@@ -60,6 +60,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE decoder")
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:

@@ -147,6 +147,10 @@ GENERATION_MODES = Literal[
"flux_img2img",
"flux_inpaint",
"flux_outpaint",
"sd3_txt2img",
"sd3_img2img",
"sd3_inpaint",
"sd3_outpaint",
]
@@ -1,16 +1,19 @@
from typing import Callable, Tuple
from typing import Callable, Optional, Tuple

import torch
import torchvision.transforms as tv_transforms
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
Input,
InputField,
LatentsField,
SD3ConditioningField,
WithBoard,
WithMetadata,

@@ -19,7 +22,9 @@ from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.sd3_text_encoder import SD3_T5_MAX_SEQ_LEN
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional
from invokeai.backend.model_manager.config import BaseModelType
from invokeai.backend.sd3.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice

@@ -30,16 +35,24 @@ from invokeai.backend.util.devices import TorchDevice
title="SD3 Denoise",
tags=["image", "sd3"],
category="image",
version="1.0.0",
version="1.1.0",
classification=Classification.Prototype,
)
class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a SD3 model."""

# If latents is provided, this means we are doing image-to-image.
latents: Optional[LatentsField] = InputField(
default=None, description=FieldDescriptions.latents, input=Input.Connection
)
# denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
)
denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
transformer: TransformerField = InputField(
description=FieldDescriptions.sd3_model,
input=Input.Connection,
title="Transformer",
description=FieldDescriptions.sd3_model, input=Input.Connection, title="Transformer"
)
positive_conditioning: SD3ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection

@@ -61,6 +74,41 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
"""Prepare the inpaint mask.
- Loads the mask
- Resizes if necessary
- Casts to same device/dtype as latents

Args:
context (InvocationContext): The invocation context, for loading the inpaint mask.
latents (torch.Tensor): A latent image tensor. Used to determine the target shape, device, and dtype for the
inpaint mask.

Returns:
torch.Tensor | None: Inpaint mask. Values of 0.0 represent the regions to be fully denoised, and 1.0
represent the regions to be preserved.
"""
if self.denoise_mask is None:
return None
mask = context.tensors.load(self.denoise_mask.mask_name)

# The input denoise_mask contains values in [0, 1], where 0.0 represents the regions to be fully denoised, and
# 1.0 represents the regions to be preserved.
# We invert the mask so that the regions to be preserved are 0.0 and the regions to be denoised are 1.0.
mask = 1.0 - mask

_, _, latent_height, latent_width = latents.shape
mask = tv_resize(
img=mask,
size=[latent_height, latent_width],
interpolation=tv_transforms.InterpolationMode.BILINEAR,
antialias=False,
)

mask = mask.to(device=latents.device, dtype=latents.dtype)
return mask

def _load_text_conditioning(
self,
context: InvocationContext,

@@ -170,14 +218,20 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
prompt_embeds = torch.cat([neg_prompt_embeds, pos_prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([neg_pooled_prompt_embeds, pos_pooled_prompt_embeds], dim=0)

# Prepare the scheduler.
scheduler = FlowMatchEulerDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=self.steps, device=device)
timesteps = scheduler.timesteps
assert isinstance(timesteps, torch.Tensor)
# Prepare the timestep schedule.
# We add an extra step to the end to account for the final timestep of 0.0.
timesteps: list[float] = torch.linspace(1, 0, self.steps + 1).tolist()
# Clip the timesteps schedule based on denoising_start and denoising_end.
timesteps = clip_timestep_schedule_fractional(timesteps, self.denoising_start, self.denoising_end)
total_steps = len(timesteps) - 1

# Prepare the CFG scale list.
cfg_scale = self._prepare_cfg_scale(len(timesteps))
cfg_scale = self._prepare_cfg_scale(total_steps)

# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
init_latents = init_latents.to(device=device, dtype=inference_dtype)

# Generate initial latent noise.
num_channels_latents = transformer_info.model.config.in_channels
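The schedule in the hunk above is now built by hand rather than taken from FlowMatchEulerDiscreteScheduler: `steps + 1` fractional timesteps from 1.0 down to 0.0, which clip_timestep_schedule_fractional then trims to the denoising_start/denoising_end window. A small worked example of the base schedule and the step count it implies (clipping omitted):

```python
import torch

steps = 4
# N denoising steps need N + 1 schedule points; the extra point is the final
# timestep of 0.0, and each adjacent (t_curr, t_prev) pair is one step.
timesteps: list[float] = torch.linspace(1, 0, steps + 1).tolist()
print(timesteps)           # [1.0, 0.75, 0.5, 0.25, 0.0]
print(len(timesteps) - 1)  # 4 denoising steps
```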
@@ -191,9 +245,34 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
device=device,
seed=self.seed,
)
latents: torch.Tensor = noise

total_steps = len(timesteps)
# Prepare input latent image.
if init_latents is not None:
# Noise the init_latents by the appropriate amount for the first timestep.
t_0 = timesteps[0]
latents = t_0 * noise + (1.0 - t_0) * init_latents
else:
# init_latents are not provided, so we are not doing image-to-image (i.e. we are starting from pure noise).
if self.denoising_start > 1e-5:
raise ValueError("denoising_start should be 0 when initial latents are not provided.")
latents = noise

# If len(timesteps) == 1, then short-circuit. We are just noising the input latents, but not taking any
# denoising steps.
if len(timesteps) <= 1:
return latents

# Prepare inpaint extension.
inpaint_mask = self._prep_inpaint_mask(context, latents)
inpaint_extension: InpaintExtension | None = None
if inpaint_mask is not None:
assert init_latents is not None
inpaint_extension = InpaintExtension(
init_latents=init_latents,
inpaint_mask=inpaint_mask,
noise=noise,
)

step_callback = self._build_step_callback(context)

step_callback(

@@ -210,11 +289,12 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
assert isinstance(transformer, SD3Transformer2DModel)

# 6. Denoising loop
for step_idx, t in tqdm(list(enumerate(timesteps))):
for step_idx, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
# Expand the latents if we are doing CFG.
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# Expand the timestep to match the latent model input.
timestep = t.expand(latent_model_input.shape[0])
# Multiply by 1000 to match the default FlowMatchEulerDiscreteScheduler num_train_timesteps.
timestep = torch.tensor([t_curr * 1000], device=device).expand(latent_model_input.shape[0])

noise_pred = transformer(
hidden_states=latent_model_input,

@@ -232,21 +312,19 @@ class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# Compute the previous noisy sample x_t -> x_t-1.
latents_dtype = latents.dtype
latents = scheduler.step(model_output=noise_pred, timestep=t, sample=latents, return_dict=False)[0]
latents = latents.to(dtype=torch.float32)
latents = latents + (t_prev - t_curr) * noise_pred
latents = latents.to(dtype=latents_dtype)

# TODO(ryand): This MPS dtype handling was copied from diffusers, I haven't tested to see if it's
# needed.
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if inpaint_extension is not None:
latents = inpaint_extension.merge_intermediate_latents_with_init_latents(latents, t_prev)

step_callback(
PipelineIntermediateState(
step=step_idx + 1,
order=1,
total_steps=total_steps,
timestep=int(t),
timestep=int(t_curr),
latents=latents,
),
)
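Taken together, these hunks replace the diffusers scheduler step with an explicit Euler update over the fractional schedule, and seed image-to-image runs by blending noise into the init latents at the first timestep. A minimal sketch of the loop's core; `predict_velocity` is a stand-in for the CFG-wrapped SD3 transformer call, not InvokeAI's API:

```python
from typing import Callable

import torch


def flow_matching_euler_sample(
    noise: torch.Tensor,
    timesteps: list[float],
    predict_velocity: Callable[[torch.Tensor, float], torch.Tensor],
    init_latents: torch.Tensor | None = None,
) -> torch.Tensor:
    if init_latents is not None:
        # Image-to-image: noise the init latents by the amount implied by the first timestep.
        t_0 = timesteps[0]
        latents = t_0 * noise + (1.0 - t_0) * init_latents
    else:
        latents = noise

    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        v = predict_velocity(latents, t_curr)
        # Euler step along the flow: x(t_prev) = x(t_curr) + (t_prev - t_curr) * v
        latents = latents + (t_prev - t_curr) * v
    return latents
```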
invokeai/app/invocations/sd3_image_to_latents.py (new file, 65 lines)

@@ -0,0 +1,65 @@
import einops
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor

@invocation(
"sd3_i2l",
title="SD3 Image to Latents",
tags=["image", "latents", "vae", "i2l", "sd3"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates latents from an image."""

image: ImageField = InputField(description="The image to encode")
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, AutoencoderKL)
vae.disable_tiling()
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
# TODO: Use seed to make sampling reproducible.
latents: torch.Tensor = image_tensor_dist.sample().to(dtype=vae.dtype)

latents = vae.config.scaling_factor * latents

return latents

@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)

image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

vae_info = context.models.load(self.vae.vae)
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
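The new invocation above scales the sampled posterior by the VAE's configured scaling factor before saving. A hedged sketch of the encode/decode pair this implies; the decode side is an assumption about the matching path, not code from this compare:

```python
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL


def encode_to_scaled_latents(vae: AutoencoderKL, image_tensor: torch.Tensor) -> torch.Tensor:
    # Mirror of the new invocation: sample the latent distribution, then scale.
    with torch.inference_mode():
        latents = vae.encode(image_tensor.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
    return vae.config.scaling_factor * latents


def decode_from_scaled_latents(vae: AutoencoderKL, latents: torch.Tensor) -> torch.Tensor:
    # Assumed inverse: undo the scaling before decoding back to image space.
    with torch.inference_mode():
        return vae.decode(latents / vae.config.scaling_factor).sample
```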
@@ -47,6 +47,7 @@ class SD3LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE")
assert isinstance(vae, (AutoencoderKL))
latents = latents.to(vae.device)

@@ -95,6 +95,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
context.util.signal_progress("Running T5 encoder")
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))

@@ -137,6 +138,7 @@ class Sd3TextEncoderInvocation(BaseInvocation):
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
context.util.signal_progress("Running CLIP encoder")
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(clip_tokenizer, CLIPTokenizer)

@@ -160,6 +160,10 @@ class LoggerInterface(InvocationContextInterface):

class ImagesInterface(InvocationContextInterface):
def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
super().__init__(services, data)
self._util = util

def save(
self,
image: Image,

@@ -186,6 +190,8 @@ class ImagesInterface(InvocationContextInterface):
The saved image DTO.
"""
self._util.signal_progress("Saving image")

# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
metadata_ = None
if metadata:

@@ -336,6 +342,10 @@ class ConditioningInterface(InvocationContextInterface):
class ModelsInterface(InvocationContextInterface):
"""Common API for loading, downloading and managing models."""

def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
super().__init__(services, data)
self._util = util

def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
"""Check if a model exists.

@@ -368,11 +378,15 @@ class ModelsInterface(InvocationContextInterface):
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type)
else:
_submodel_type = submodel_type or identifier.submodel_type
submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type)

message = f"Loading model {model.name}"
if submodel_type:
message += f" ({submodel_type.value})"
self._util.signal_progress(message)
return self._services.model_manager.load.load_model(model, submodel_type)

def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None

@@ -397,6 +411,10 @@ class ModelsInterface(InvocationContextInterface):
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")

message = f"Loading model {name}"
if submodel_type:
message += f" ({submodel_type.value})"
self._util.signal_progress(message)
return self._services.model_manager.load.load_model(configs[0], submodel_type)

def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:

@@ -467,6 +485,7 @@ class ModelsInterface(InvocationContextInterface):
Returns:
Path to the downloaded model
"""
self._util.signal_progress(f"Downloading model {source}")
return self._services.model_manager.install.download_and_cache_model(source=source)

def load_local_model(

@@ -489,6 +508,8 @@ class ModelsInterface(InvocationContextInterface):
Returns:
A LoadedModelWithoutConfig object.
"""
self._util.signal_progress(f"Loading model {model_path.name}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)

def load_remote_model(

@@ -514,6 +535,8 @@ class ModelsInterface(InvocationContextInterface):
A LoadedModelWithoutConfig object.
"""
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))

self._util.signal_progress(f"Loading model {source}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)

@@ -707,12 +730,12 @@ def build_invocation_context(
"""
logger = LoggerInterface(services=services, data=data)
images = ImagesInterface(services=services, data=data)
tensors = TensorsInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data)
config = ConfigInterface(services=services, data=data)
util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
conditioning = ConditioningInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data, util=util)
images = ImagesInterface(services=services, data=data, util=util)
boards = BoardsInterface(services=services, data=data)

ctx = InvocationContext(
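The invocation-context hunks above thread a UtilInterface into the images and models interfaces so long-running operations can announce progress. A simplified sketch of that wiring; the class names are stand-ins for the real interfaces, not InvokeAI's classes:

```python
class UtilSketch:
    def signal_progress(self, message: str) -> None:
        # The real UtilInterface forwards this to the event bus / progress UI.
        print(f"[progress] {message}")


class ModelsSketch:
    def __init__(self, util: UtilSketch) -> None:
        self._util = util

    def load(self, name: str) -> None:
        # Announce the load before the (potentially slow) work starts.
        self._util.signal_progress(f"Loading model {name}")


models = ModelsSketch(util=UtilSketch())
models.load("some-model")  # hypothetical model name, for illustration only
```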
@@ -45,8 +45,9 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
# Constants for FLUX.1
num_double_layers = 19
num_single_layers = 38
# inner_dim = 3072
# mlp_ratio = 4.0
hidden_size = 3072
mlp_ratio = 4.0
mlp_hidden_dim = int(hidden_size * mlp_ratio)

layers: dict[str, AnyLoRALayer] = {}

@@ -62,30 +63,43 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
assert len(src_layer_dict) == 0

def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
def add_qkv_lora_layer_if_present(
src_keys: list[str],
src_weight_shapes: list[tuple[int, int]],
dst_qkv_key: str,
allow_missing_keys: bool = False,
) -> None:
"""Handle the Q, K, V matrices for a transformer block. We need special handling because the diffusers format
stores them in separate matrices, whereas the BFL format used internally by InvokeAI concatenates them.
"""
# We expect that either all src keys are present or none of them are. Verify this.
keys_present = [key in grouped_state_dict for key in src_keys]
assert all(keys_present) or not any(keys_present)

# If none of the keys are present, return early.
keys_present = [key in grouped_state_dict for key in src_keys]
if not any(keys_present):
return

src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
sub_layers: list[LoRALayer] = []
for src_layer_dict in src_layer_dicts:
values = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
values["alpha"] = torch.tensor(alpha)
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
assert len(src_layer_dict) == 0
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)
for src_key, src_weight_shape in zip(src_keys, src_weight_shapes, strict=True):
src_layer_dict = grouped_state_dict.pop(src_key, None)
if src_layer_dict is not None:
values = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
values["alpha"] = torch.tensor(alpha)
assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
assert len(src_layer_dict) == 0
else:
if not allow_missing_keys:
raise ValueError(f"Missing LoRA layer: '{src_key}'.")
values = {
"lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
"lora_down.weight": torch.zeros((1, src_weight_shape[1])),
}
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers)

# time_text_embed.timestep_embedder -> time_in.
add_lora_layer_if_present("time_text_embed.timestep_embedder.linear_1", "time_in.in_layer")

@@ -118,6 +132,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
f"transformer_blocks.{i}.attn.to_k",
f"transformer_blocks.{i}.attn.to_v",
],
[(hidden_size, hidden_size), (hidden_size, hidden_size), (hidden_size, hidden_size)],
f"double_blocks.{i}.img_attn.qkv",
)
add_qkv_lora_layer_if_present(

@@ -126,6 +141,7 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
f"transformer_blocks.{i}.attn.add_k_proj",
f"transformer_blocks.{i}.attn.add_v_proj",
],
[(hidden_size, hidden_size), (hidden_size, hidden_size), (hidden_size, hidden_size)],
f"double_blocks.{i}.txt_attn.qkv",
)

@@ -175,7 +191,14 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
f"single_transformer_blocks.{i}.attn.to_v",
f"single_transformer_blocks.{i}.proj_mlp",
],
[
(hidden_size, hidden_size),
(hidden_size, hidden_size),
(hidden_size, hidden_size),
(mlp_hidden_dim, hidden_size),
],
f"single_blocks.{i}.linear1",
allow_missing_keys=True,
)

# Output projections.
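The conversion above now passes explicit per-matrix weight shapes so a missing Q/K/V (or proj_mlp) LoRA can be replaced by a rank-1 zero layer of the right size before the sub-layers are concatenated into the fused BFL weight. A small sketch of the arithmetic behind that concatenation, using hypothetical tensors:

```python
import torch

hidden_size = 3072
rank = 16

# Per-matrix LoRA factors for to_q / to_k / to_v (lora_down: rank x in, lora_up: out x rank).
downs = [torch.randn(rank, hidden_size) for _ in range(3)]
ups = [torch.randn(hidden_size, rank) for _ in range(3)]

# A missing sub-layer contributes a zero delta of the right output size,
# which is what the rank-1 zero tensors in the diff achieve (here: proj_mlp).
downs.append(torch.zeros(1, hidden_size))
ups.append(torch.zeros(4 * hidden_size, 1))

# Effective delta for the fused weight: stack the per-matrix deltas along the output axis.
fused_delta = torch.cat([up @ down for up, down in zip(ups, downs)], dim=0)
print(fused_delta.shape)  # torch.Size([21504, 3072]) == (3 + 4) * 3072 rows
```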
@@ -35,6 +35,7 @@ class ModelLoader(ModelLoaderBase):
self._logger = logger
self._ram_cache = ram_cache
self._torch_dtype = TorchDevice.choose_torch_dtype()
self._torch_device = TorchDevice.choose_torch_device()

def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""

@@ -84,7 +84,15 @@ class FluxVAELoader(ModelLoader):
model = AutoEncoder(ae_params[config.config_path])
sd = load_file(model_path)
model.load_state_dict(sd, assign=True)
model.to(dtype=self._torch_dtype)
# VAE is broken in float16, which mps defaults to
if self._torch_dtype == torch.float16:
try:
vae_dtype = torch.tensor([1.0], dtype=torch.bfloat16, device=self._torch_device).dtype
except TypeError:
vae_dtype = torch.float32
else:
vae_dtype = self._torch_dtype
model.to(vae_dtype)

return model
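The FluxVAELoader hunk works around the FLUX VAE misbehaving in float16 (the MPS default) by probing whether bfloat16 tensors can be created on the target device and falling back to float32 otherwise. A sketch of that fallback; the exact exception raised by an unsupported device is an assumption carried over from the diff:

```python
import torch


def pick_vae_dtype(requested: torch.dtype, device: torch.device) -> torch.dtype:
    if requested != torch.float16:
        return requested
    try:
        # Probe: creating a bfloat16 tensor fails on devices without bfloat16 support.
        return torch.tensor([1.0], dtype=torch.bfloat16, device=device).dtype
    except TypeError:
        return torch.float32
```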
@@ -300,7 +300,7 @@ ip_adapter_sdxl = StarterModel(
ip_adapter_flux = StarterModel(
name="Standard Reference (XLabs FLUX IP-Adapter)",
base=BaseModelType.Flux,
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/flux-ip-adapter.safetensors",
source="https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/ip_adapter.safetensors",
description="References images with a more generalized/looser degree of precision.",
type=ModelType.IPAdapter,
dependencies=[clip_vit_l_image_encoder],

invokeai/backend/sd3/__init__.py (new file, 0 lines)
invokeai/backend/sd3/extensions/__init__.py (new file, 0 lines)
invokeai/backend/sd3/extensions/inpaint_extension.py (new file, 58 lines)

@@ -0,0 +1,58 @@
import torch

class InpaintExtension:
"""A class for managing inpainting with SD3."""

def __init__(self, init_latents: torch.Tensor, inpaint_mask: torch.Tensor, noise: torch.Tensor):
"""Initialize InpaintExtension.

Args:
init_latents (torch.Tensor): The initial latents (i.e. un-noised at timestep 0).
inpaint_mask (torch.Tensor): A mask specifying which elements to inpaint. Range [0, 1]. Values of 1 will be
re-generated. Values of 0 will remain unchanged. Values between 0 and 1 can be used to blend the
inpainted region with the background.
noise (torch.Tensor): The noise tensor used to noise the init_latents.
"""
assert init_latents.dim() == inpaint_mask.dim() == noise.dim() == 4
assert init_latents.shape[-2:] == inpaint_mask.shape[-2:] == noise.shape[-2:]

self._init_latents = init_latents
self._inpaint_mask = inpaint_mask
self._noise = noise

def _apply_mask_gradient_adjustment(self, t_prev: float) -> torch.Tensor:
"""Applies inpaint mask gradient adjustment and returns the inpaint mask to be used at the current timestep."""
# As we progress through the denoising process, we promote gradient regions of the mask to have a full weight of
# 1.0. This helps to produce more coherent seams around the inpainted region. We experimented with a (small)
# number of promotion strategies (e.g. gradual promotion based on timestep), but found that a simple cutoff
# threshold worked well.
# We use a small epsilon to avoid any potential issues with floating point precision.
eps = 1e-4
mask_gradient_t_cutoff = 0.5
if t_prev > mask_gradient_t_cutoff:
# Early in the denoising process, use the inpaint mask as-is.
return self._inpaint_mask
else:
# After the cut-off, promote all non-zero mask values to 1.0.
mask = self._inpaint_mask.where(self._inpaint_mask <= (0.0 + eps), 1.0)

return mask

def merge_intermediate_latents_with_init_latents(
self, intermediate_latents: torch.Tensor, t_prev: float
) -> torch.Tensor:
"""Merge the intermediate latents with the initial latents for the current timestep using the inpaint mask. I.e.
update the intermediate latents to keep the regions that are not being inpainted on the correct noise
trajectory.

This function should be called after each denoising step.
"""
mask = self._apply_mask_gradient_adjustment(t_prev)

# Noise the init latents for the current timestep.
noised_init_latents = self._noise * t_prev + (1.0 - t_prev) * self._init_latents

# Merge the intermediate latents with the noised_init_latents using the inpaint_mask.
return intermediate_latents * mask + noised_init_latents * (1.0 - mask)
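InpaintExtension keeps the preserved regions of the image on the same noise trajectory as the init latents: after every step it re-noises the init latents for the upcoming timestep and blends them with the freshly denoised latents using the mask. A minimal standalone sketch of that merge:

```python
import torch


def merge_with_init_latents(
    intermediate_latents: torch.Tensor,
    init_latents: torch.Tensor,
    noise: torch.Tensor,
    mask: torch.Tensor,  # 1.0 = re-generate, 0.0 = preserve
    t_prev: float,
) -> torch.Tensor:
    # Re-noise the init latents to where they should sit at timestep t_prev ...
    noised_init_latents = noise * t_prev + (1.0 - t_prev) * init_latents
    # ... then take re-generated regions from the denoiser and preserved regions
    # from the re-noised init latents.
    return intermediate_latents * mask + noised_init_latents * (1.0 - mask)
```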
@@ -174,7 +174,8 @@
"placeholderSelectAModel": "Select a model",
"reset": "Reset",
"none": "None",
"new": "New"
"new": "New",
"generating": "Generating"
},
"hrf": {
"hrf": "High Resolution Fix",

@@ -704,6 +705,8 @@
"baseModel": "Base Model",
"cancel": "Cancel",
"clipEmbed": "CLIP Embed",
"clipLEmbed": "CLIP-L Embed",
"clipGEmbed": "CLIP-G Embed",
"config": "Config",
"convert": "Convert",
"convertingModelBegin": "Converting Model. Please wait.",

@@ -997,7 +1000,7 @@
"controlNetControlMode": "Control Mode",
"copyImage": "Copy Image",
"denoisingStrength": "Denoising Strength",
"noRasterLayers": "No Raster Layers",
"disabledNoRasterContent": "Disabled (No Raster Content)",
"downloadImage": "Download Image",
"general": "General",
"guidance": "Guidance",

@@ -1137,6 +1140,7 @@
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
"showDetailedInvocationProgress": "Show Progress Details",
"showProgressInViewer": "Show Progress Images in Viewer",
"ui": "User Interface",
"clearIntermediatesDisabled": "Queue must be empty to clear intermediates",

@@ -1671,7 +1675,7 @@
"clearCaches": "Clear Caches",
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"outputOnlyMaskedRegions": "Output Only Masked Regions",
"outputOnlyMaskedRegions": "Output Only Generated Regions",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",

@@ -1787,7 +1791,7 @@
},
"ipAdapterMethod": {
"ipAdapterMethod": "IP Adapter Method",
"full": "Full",
"full": "Style and Composition",
"style": "Style Only",
"composition": "Composition Only"
},

@@ -1999,7 +2003,9 @@
"upscaleModelDesc": "Upscale (image to image) model",
"missingUpscaleInitialImage": "Missing initial image for upscaling",
"missingUpscaleModel": "Missing upscale model",
"missingTileControlNetModel": "No valid tile ControlNet models installed"
"missingTileControlNetModel": "No valid tile ControlNet models installed",
"incompatibleBaseModel": "Unsupported main model architecture for upscaling",
"incompatibleBaseModelDesc": "Upscaling is supported for SD1.5 and SDXL architecture models only. Change the main model to enable upscaling."
},
"stylePresets": {
"active": "Active",
@@ -8,6 +8,7 @@ import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
import { toast } from 'features/toast/toast';

@@ -34,8 +35,8 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
let buildGraphResult: Result<
{
g: Graph;
noise: Invocation<'noise' | 'flux_denoise'>;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
noise: Invocation<'noise' | 'flux_denoise' | 'sd3_denoise'>;
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'sd3_text_encoder'>;
},
Error
>;

@@ -51,6 +52,9 @@ export const addEnqueueRequestedLinear = (startAppListening: AppStartListening)
case `sd-2`:
buildGraphResult = await withResultAsync(() => buildSD1Graph(state, manager));
break;
case `sd-3`:
buildGraphResult = await withResultAsync(() => buildSD3Graph(state, manager));
break;
case `flux`:
buildGraphResult = await withResultAsync(() => buildFLUXGraph(state, manager));
break;

@@ -41,29 +41,33 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
log.debug({ imageDTO }, 'Image uploaded');

if (action.meta.arg.originalArgs.silent || imageDTO.is_intermediate) {
// When a "silent" upload is requested, or the image is intermediate, we can skip all post-upload actions,
// like toasts and switching the gallery view
return;
}

const boardId = imageDTO.board_id ?? 'none';

if (action.meta.arg.originalArgs.withToast) {
const DEFAULT_UPLOADED_TOAST = {
id: 'IMAGE_UPLOADED',
title: t('toast.imageUploaded'),
status: 'success',
} as const;
const DEFAULT_UPLOADED_TOAST = {
id: 'IMAGE_UPLOADED',
title: t('toast.imageUploaded'),
status: 'success',
} as const;

// default action - just upload and alert user
if (lastUploadedToastTimeout !== null) {
window.clearTimeout(lastUploadedToastTimeout);
}
const toastApi = toast({
...DEFAULT_UPLOADED_TOAST,
title: DEFAULT_UPLOADED_TOAST.title,
description: getUploadedToastDescription(boardId, state),
duration: null, // we will close the toast manually
});
lastUploadedToastTimeout = window.setTimeout(() => {
toastApi.close();
}, 3000);
// default action - just upload and alert user
if (lastUploadedToastTimeout !== null) {
window.clearTimeout(lastUploadedToastTimeout);
}
const toastApi = toast({
...DEFAULT_UPLOADED_TOAST,
title: DEFAULT_UPLOADED_TOAST.title,
description: getUploadedToastDescription(boardId, state),
duration: null, // we will close the toast manually
});
lastUploadedToastTimeout = window.setTimeout(() => {
toastApi.close();
}, 3000);

/**
* We only want to change the board and view if this is the first upload of a batch, else we end up hijacking

@@ -25,7 +25,8 @@ export type AppFeature =
| 'invocationCache'
| 'bulkDownload'
| 'starterModels'
| 'hfToken';
| 'hfToken'
| 'invocationProgressAlert';

/**
* A disable-able Stable Diffusion feature

@@ -61,6 +61,11 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
log.warn('Multiple files dropped but only one allowed');
return;
}
if (files.length === 0) {
// Should never happen
log.warn('No files dropped');
return;
}
const file = files[0];
assert(file !== undefined); // should never happen
const imageDTO = await uploadImage({

@@ -68,18 +73,20 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: true,
}).unwrap();
if (onUpload) {
onUpload(imageDTO);
}
} else {
//
const imageDTOs = await uploadImages(
files.map((file) => ({
files.map((file, i) => ({
file,
image_category: 'user',
is_intermediate: false,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
silent: false,
isFirstUploadOfBatch: i === 0,
}))
);
if (onUpload) {

@@ -119,11 +119,20 @@ const createSelector = (
reasons.push({ content: i18n.t('upscaling.exceedsMaxSize') });
}
}
if (!upscale.upscaleModel) {
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
}
if (!upscale.tileControlnetModel) {
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
if (model && !['sd-1', 'sdxl'].includes(model.base)) {
// When we are using an unsupported model, do not add the other warnings
reasons.push({ content: i18n.t('upscaling.incompatibleBaseModel') });
} else {
// Using a compatible model, add all warnings
if (!model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
if (!upscale.upscaleModel) {
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
}
if (!upscale.tileControlnetModel) {
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
}
}
} else {
if (canvasIsFiltering) {
@@ -9,7 +9,7 @@ import {
useAddRegionalGuidance,
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';

@@ -23,6 +23,7 @@ export const CanvasAddEntityButtons = memo(() => {
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);

return (
<Flex w="full" h="full" justifyContent="center" gap={4}>

@@ -36,6 +37,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addGlobalReferenceImage}
isDisabled={isSD3}
>
{t('controlLayers.globalReferenceImage')}
</Button>

@@ -61,7 +63,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalGuidance}
isDisabled={isFLUX}
isDisabled={isFLUX || isSD3}
>
{t('controlLayers.regionalGuidance')}
</Button>

@@ -73,7 +75,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalReferenceImage}
isDisabled={isFLUX}
isDisabled={isFLUX || isSD3}
>
{t('controlLayers.regionalReferenceImage')}
</Button>

@@ -88,6 +90,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addControlLayer}
isDisabled={isSD3}
>
{t('controlLayers.controlLayer')}
</Button>

@@ -0,0 +1,45 @@
import { Alert, AlertDescription, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectSystemShouldShowInvocationProgressDetail } from 'features/system/store/systemSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { $invocationProgressMessage } from 'services/events/stores';

const CanvasAlertsInvocationProgressContent = memo(() => {
const { t } = useTranslation();
const invocationProgressMessage = useStore($invocationProgressMessage);

if (!invocationProgressMessage) {
return null;
}

return (
<Alert status="loading" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<AlertIcon />
<AlertTitle>{t('common.generating')}</AlertTitle>
<AlertDescription>{invocationProgressMessage}</AlertDescription>
</Alert>
);
});

CanvasAlertsInvocationProgressContent.displayName = 'CanvasAlertsInvocationProgressContent';

export const CanvasAlertsInvocationProgress = memo(() => {
const isProgressMessageAlertEnabled = useFeatureStatus('invocationProgressAlert');
const shouldShowInvocationProgressDetail = useAppSelector(selectSystemShouldShowInvocationProgressDetail);

// The alert is disabled at the system level
if (!isProgressMessageAlertEnabled) {
return null;
}

// The alert is disabled at the user level
if (!shouldShowInvocationProgressDetail) {
return null;
}

return <CanvasAlertsInvocationProgressContent />;
});

CanvasAlertsInvocationProgress.displayName = 'CanvasAlertsInvocationProgress';

@@ -9,7 +9,7 @@ import {
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';

@@ -24,6 +24,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);

return (
<Menu>

@@ -40,7 +41,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
/>
<MenuList>
<MenuGroup title={t('controlLayers.global')}>
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage}>
<MenuItem icon={<PiPlusBold />} onClick={addGlobalReferenceImage} isDisabled={isSD3}>
{t('controlLayers.globalReferenceImage')}
</MenuItem>
</MenuGroup>

@@ -48,15 +49,15 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
<MenuItem icon={<PiPlusBold />} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX}>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isFLUX || isSD3}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX}>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>
{t('controlLayers.regionalReferenceImage')}
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.layer_other')}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer} isDisabled={isSD3}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>

@@ -4,6 +4,7 @@ import { attachClosestEdge, extractClosestEdge } from '@atlaskit/pragmatic-drag-
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { singleCanvasEntityDndSource } from 'features/dnd/dnd';
import { type DndListTargetState, idle } from 'features/dnd/types';
import { firefoxDndFix } from 'features/dnd/util';
import type { RefObject } from 'react';
import { useEffect, useState } from 'react';

@@ -17,6 +18,7 @@ export const useCanvasEntityListDnd = (ref: RefObject<HTMLElement>, entityIdenti
return;
}
return combine(
firefoxDndFix(element),
draggable({
element,
getInitialData() {
@@ -21,6 +21,8 @@ import { GatedImageViewer } from 'features/gallery/components/ImageViewer/ImageV
import { memo, useCallback, useRef } from 'react';
import { PiDotsThreeOutlineVerticalFill } from 'react-icons/pi';

import { CanvasAlertsInvocationProgress } from './CanvasAlerts/CanvasAlertsInvocationProgress';

const MenuContent = () => {
return (
<CanvasManagerProviderGate>

@@ -84,6 +86,7 @@ export const CanvasMainPanelContent = memo(() => {
<CanvasAlertsSelectedEntityStatus />
<CanvasAlertsPreserveMask />
<CanvasAlertsSendingToGallery />
<CanvasAlertsInvocationProgress />
</Flex>
<Flex position="absolute" top={1} insetInlineEnd={1}>
<Menu>

@@ -17,12 +17,15 @@ import { selectImg2imgStrengthConfig } from 'features/system/store/configSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

const selectIsEnabled = createSelector(selectActiveRasterLayerEntities, (entities) => entities.length > 0);
const selectHasRasterLayersWithContent = createSelector(
selectActiveRasterLayerEntities,
(entities) => entities.length > 0
);

export const ParamDenoisingStrength = memo(() => {
const img2imgStrength = useAppSelector(selectImg2imgStrength);
const dispatch = useAppDispatch();
const isEnabled = useAppSelector(selectIsEnabled);
const hasRasterLayersWithContent = useAppSelector(selectHasRasterLayersWithContent);

const onChange = useCallback(
(v: number) => {

@@ -37,16 +40,16 @@ export const ParamDenoisingStrength = memo(() => {
const [invokeBlue300] = useToken('colors', ['invokeBlue.300']);

return (
<FormControl isDisabled={!isEnabled} p={1} justifyContent="space-between" h={8}>
<FormControl isDisabled={!hasRasterLayersWithContent} p={1} justifyContent="space-between" h={8}>
<Flex gap={3} alignItems="center">
<InformationalPopover feature="paramDenoisingStrength">
<FormLabel mr={0}>{`${t('parameters.denoisingStrength')}`}</FormLabel>
</InformationalPopover>
{isEnabled && (
{hasRasterLayersWithContent && (
<WavyLine amplitude={img2imgStrength * 10} stroke={invokeBlue300} strokeWidth={1} width={40} height={14} />
)}
</Flex>
{isEnabled ? (
{hasRasterLayersWithContent ? (
<>
<CompositeSlider
step={config.coarseStep}

@@ -70,9 +73,7 @@ export const ParamDenoisingStrength = memo(() => {
</>
) : (
<Flex alignItems="center">
<Badge opacity="0.6">
{t('common.disabled')} - {t('parameters.noRasterLayers')}
</Badge>
<Badge opacity="0.6">{t('parameters.disabledNoRasterContent')}</Badge>
</Flex>
)}
</FormControl>

@@ -1,17 +1,19 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { withResultAsync } from 'common/util/result';
import { selectSelectedImage } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFloppyDiskBold } from 'react-icons/pi';
import { useAddImagesToBoardMutation, useChangeImageIsIntermediateMutation } from 'services/api/endpoints/images';
import { uploadImage } from 'services/api/endpoints/images';

const TOAST_ID = 'SAVE_STAGING_AREA_IMAGE_TO_GALLERY';

export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const selectedImage = useAppSelector(selectSelectedImage);
const [addImageToBoard] = useAddImagesToBoardMutation();
const [changeIsImageIntermediate] = useChangeImageIsIntermediateMutation();

const { t } = useTranslation();

@@ -19,21 +21,42 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
if (!selectedImage) {
return;
}
if (autoAddBoardId !== 'none') {
await addImageToBoard({ imageDTOs: [selectedImage.imageDTO], board_id: autoAddBoardId }).unwrap();
// The changeIsImageIntermediate request will use the board_id on this specific imageDTO object, so we need to
// update it before making the request - else the optimistic board updates will get out of whack.
changeIsImageIntermediate({
imageDTO: { ...selectedImage.imageDTO, board_id: autoAddBoardId },

// To save the image to gallery, we will download it and re-upload it. This allows the user to delete the image
// the gallery without borking the canvas, which may need this image to exist.
const result = await withResultAsync(async () => {
// Download the image
const res = await fetch(selectedImage.imageDTO.image_url);
const blob = await res.blob();
// Create a new file with the same name, which we will upload
const file = new File([blob], `copy_of_${selectedImage.imageDTO.image_name}`, { type: 'image/png' });

await uploadImage({
file,
// Image should show up in the Images tab
image_category: 'general',
is_intermediate: false,
// TODO(psyche): Maybe this should just save to the currently-selected board?
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
// We will do our own toast - opt out of the default handling
silent: true,
});
});

if (result.isOk()) {
toast({
id: TOAST_ID,
title: t('controlLayers.savedToGalleryOk'),
status: 'success',
});
} else {
changeIsImageIntermediate({
imageDTO: selectedImage.imageDTO,
is_intermediate: false,
toast({
id: TOAST_ID,
title: t('controlLayers.savedToGalleryError'),
status: 'error',
});
}
}, [addImageToBoard, autoAddBoardId, changeIsImageIntermediate, selectedImage]);
}, [autoAddBoardId, selectedImage, t]);

return (
<IconButton

@@ -42,7 +65,7 @@ export const StagingAreaToolbarSaveSelectedToGalleryButton = memo(() => {
icon={<PiFloppyDiskBold />}
onClick={saveSelectedImageToGallery}
colorScheme="invokeBlue"
isDisabled={!selectedImage || !selectedImage.imageDTO.is_intermediate}
isDisabled={!selectedImage}
/>
);
});

@@ -25,6 +25,8 @@ import type {
RegionalGuidanceReferenceImageState,
} from 'features/controlLayers/store/types';
import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } from 'features/controlLayers/store/util';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import type { BoardId } from 'features/gallery/store/types';
import { toast } from 'features/toast/toast';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';

@@ -70,6 +72,11 @@ const useSaveCanvas = ({ region, saveToGallery, toastOk, toastError, onSave, wit
metadata = selectCanvasMetadata(store.getState());
}

let boardId: BoardId | undefined = undefined;
if (saveToGallery) {
boardId = selectAutoAddBoardId(store.getState());
}

const result = await withResultAsync(() => {
const rasterAdapters = canvasManager.compositor.getVisibleAdaptersOfType('raster_layer');
return canvasManager.compositor.getCompositeImageDTO(

@@ -78,6 +85,8 @@ const useSaveCanvas = ({ region, saveToGallery, toastOk, toastError, onSave, wit
{
is_intermediate: !saveToGallery,
metadata,
board_id: boardId,
silent: true,
},
undefined,
true // force upload the image to ensure it gets added to the gallery
@@ -222,8 +231,8 @@ export const useNewRasterLayerFromBbox = () => {
|
||||
toastError: t('controlLayers.newRasterLayerError'),
|
||||
};
|
||||
}, [dispatch, t]);
|
||||
const newRasterLayerFromBbox = useSaveCanvas(arg);
|
||||
return newRasterLayerFromBbox;
|
||||
const func = useSaveCanvas(arg);
|
||||
return func;
|
||||
};
|
||||
|
||||
export const useNewControlLayerFromBbox = () => {
|
||||
|
||||
@@ -28,19 +28,17 @@ import type {
|
||||
} from 'features/controlLayers/store/types';
|
||||
import { getEntityIdentifier } from 'features/controlLayers/store/types';
|
||||
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import type { Logger } from 'roarr';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import type { UploadImageArg } from 'services/api/endpoints/images';
|
||||
import { getImageDTOSafe, uploadImage } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import type { ImageDTO, UploadImageArg } from 'services/api/types';
|
||||
import stableHash from 'stable-hash';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
import type { JsonObject, SetOptional } from 'type-fest';
|
||||
|
||||
type CompositingOptions = {
|
||||
/**
|
||||
@@ -259,7 +257,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
|
||||
getCompositeImageDTO = async (
|
||||
adapters: CanvasEntityAdapter[],
|
||||
rect: Rect,
|
||||
uploadOptions: Pick<UploadImageArg, 'is_intermediate' | 'metadata'>,
|
||||
uploadOptions: SetOptional<Omit<UploadImageArg, 'file'>, 'image_category'>,
|
||||
compositingOptions?: CompositingOptions,
|
||||
forceUpload?: boolean
|
||||
): Promise<ImageDTO> => {
|
||||
@@ -299,10 +297,7 @@ export class CanvasCompositorModule extends CanvasModuleBase {
|
||||
uploadImage({
|
||||
file: new File([blob], 'canvas-composite.png', { type: 'image/png' }),
|
||||
image_category: 'general',
|
||||
is_intermediate: uploadOptions.is_intermediate,
|
||||
board_id: uploadOptions.is_intermediate ? undefined : selectAutoAddBoardId(this.manager.store.getState()),
|
||||
metadata: uploadOptions.metadata,
|
||||
withToast: false,
|
||||
...uploadOptions,
|
||||
})
|
||||
);
|
||||
this.$isUploading.set(false);
|
||||
|
||||
@@ -493,7 +493,7 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase {
|
||||
file: new File([blob], `${this.id}_rasterized.png`, { type: 'image/png' }),
|
||||
image_category: 'other',
|
||||
is_intermediate: true,
|
||||
withToast: false,
|
||||
silent: true,
|
||||
});
|
||||
const imageObject = imageDTOToImageObject(imageDTO);
|
||||
if (replaceObjects) {
|
||||
|
||||
@@ -1,13 +1,8 @@
import {
roundDownToMultiple,
roundToMultiple,
roundToMultipleMin,
roundUpToMultiple,
} from 'common/util/roundDownToMultiple';
import { roundToMultiple, roundToMultipleMin } from 'common/util/roundDownToMultiple';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule';
import { getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
import { fitRectToGrid, getKonvaNodeDebugAttrs, getPrefixedId } from 'features/controlLayers/konva/util';
import { selectBboxOverlay } from 'features/controlLayers/store/canvasSettingsSlice';
import { selectBbox } from 'features/controlLayers/store/selectors';
import type { Coordinate, Rect } from 'features/controlLayers/store/types';
@@ -398,18 +393,12 @@ export class CanvasBboxToolModule extends CanvasModuleBase {
}

// Determine the bbox size that fits within the visible rect. The bbox must be at least 64px in width and height,
// and its width and height must be multiples of 8px.
// and its width and height must be multiples of the bbox grid size.
const gridSize = this.manager.stateApi.getBboxGridSize();

// To be conservative, we will round up the x and y to the nearest grid size, and round down the width and height.
// This ensures the bbox is never _larger_ than the visible rect. If the bbox is larger than the visible rect, we
// will always trigger the outpainting workflow, which is not what the user wants.
const x = roundUpToMultiple(visibleRect.x, gridSize);
const y = roundUpToMultiple(visibleRect.y, gridSize);
const width = roundDownToMultiple(visibleRect.width, gridSize);
const height = roundDownToMultiple(visibleRect.height, gridSize);
const rect = fitRectToGrid(visibleRect, gridSize);

this.manager.stateApi.setGenerationBbox({ x, y, width, height });
this.manager.stateApi.setGenerationBbox(rect);
};

/**

@@ -1,4 +1,6 @@
import { getPrefixedId, getRectUnion } from 'features/controlLayers/konva/util';
import { roundUpToMultiple } from 'common/util/roundDownToMultiple';
import { fitRectToGrid, getPrefixedId, getRectUnion } from 'features/controlLayers/konva/util';
import type { Rect } from 'features/controlLayers/store/types';
import { describe, expect, it } from 'vitest';

describe('util', () => {
@@ -44,4 +46,74 @@ describe('util', () => {
expect(union).toEqual({ x: 0, y: 0, width: 0, height: 0 });
});
});

describe('fitRectToGrid', () => {
it('should fit rect within grid without exceeding bounds', () => {
const rect: Rect = { x: 0, y: 0, width: 1047, height: 1758 };
const gridSize = 50;
const result = fitRectToGrid(rect, gridSize);

expect(result.x).toBe(roundUpToMultiple(rect.x, gridSize));
expect(result.y).toBe(roundUpToMultiple(rect.y, gridSize));
expect(result.width).toBeLessThanOrEqual(rect.width);
expect(result.height).toBeLessThanOrEqual(rect.height);
expect(result.width % gridSize).toBe(0);
expect(result.height % gridSize).toBe(0);
});

it('should handle small rect within grid bounds', () => {
const rect: Rect = { x: 20, y: 30, width: 80, height: 90 };
const gridSize = 25;
const result = fitRectToGrid(rect, gridSize);

expect(result.x).toBe(25);
expect(result.y).toBe(50);
expect(result.width % gridSize).toBe(0);
expect(result.height % gridSize).toBe(0);
expect(result.width).toBeLessThanOrEqual(rect.width);
expect(result.height).toBeLessThanOrEqual(rect.height);
});

it('should handle rect starting outside of grid alignment', () => {
const rect: Rect = { x: 13, y: 27, width: 94, height: 112 };
const gridSize = 20;
const result = fitRectToGrid(rect, gridSize);

expect(result.x).toBe(20);
expect(result.y).toBe(40);
expect(result.width % gridSize).toBe(0);
expect(result.height % gridSize).toBe(0);
expect(result.width).toBeLessThanOrEqual(rect.width);
expect(result.height).toBeLessThanOrEqual(rect.height);
});

it('should return the same rect if already aligned to grid', () => {
const rect: Rect = { x: 100, y: 100, width: 200, height: 300 };
const gridSize = 50;
const result = fitRectToGrid(rect, gridSize);

expect(result).toEqual(rect);
});

it('should handle large grid sizes relative to rect dimensions', () => {
const rect: Rect = { x: 250, y: 300, width: 400, height: 500 };
const gridSize = 100;
const result = fitRectToGrid(rect, gridSize);

expect(result.x).toBe(300);
expect(result.y).toBe(300);
expect(result.width % gridSize).toBe(0);
expect(result.height % gridSize).toBe(0);
expect(result.width).toBeLessThanOrEqual(rect.width);
expect(result.height).toBeLessThanOrEqual(rect.height);
});

it('should handle a rect that is already aligned to the grid', () => {
const rect: Rect = { x: 40, y: 60, width: 100, height: 200 };
const gridSize = 20;
const result = fitRectToGrid(rect, gridSize);

expect(result).toEqual({ x: 40, y: 60, width: 100, height: 200 });
});
});
});

@@ -1,5 +1,6 @@
import type { Selector, Store } from '@reduxjs/toolkit';
import { $authToken } from 'app/store/nanostores/authToken';
import { roundDownToMultiple, roundUpToMultiple } from 'common/util/roundDownToMultiple';
import type {
CanvasEntityIdentifier,
CanvasObjectState,
@@ -560,6 +561,33 @@ export const getRectIntersection = (...rects: Rect[]): Rect => {
return rect || getEmptyRect();
};

/**
* Fits a rect to the nearest multiple of the grid size, rounding down. The returned rect will be smaller than or equal
* to the input rect, and will be aligned to the grid.
*
* In other words, shrink the rect inwards on each side until it fits within the original rect and aligns to the grid.
*
* @param rect The rect to fit
* @param gridSize The size of the grid
* @returns The fitted rect
*/
export const fitRectToGrid = (rect: Rect, gridSize: number): Rect => {
// Rounding x and y up effectively shrinks the left and top edges of the rect, and rounding width and height down
// effectively shrinks the right and bottom edges.
const x = roundUpToMultiple(rect.x, gridSize);
const y = roundUpToMultiple(rect.y, gridSize);

// Because we've just shifted the rect's x and y, we need to adjust the width and height by the same amount before
// we round those values down.
const offsetX = x - rect.x;
const offsetY = y - rect.y;

const width = roundDownToMultiple(rect.width - offsetX, gridSize);
const height = roundDownToMultiple(rect.height - offsetY, gridSize);

return { x, y, width, height };
};
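// Illustrative usage sketch (editor's addition, not part of this change). The numbers are
// hypothetical but follow from the implementation above: x and y are rounded up, then width and
// height are rounded down after accounting for the x/y shift.
// fitRectToGrid({ x: 13, y: 27, width: 94, height: 112 }, 20);
// => { x: 20, y: 40, width: 80, height: 80 }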

/**
* Asserts that the value is never reached. Used for exhaustive checks in switch statements or conditional logic to ensure
* that all possible values are handled.

@@ -90,7 +90,7 @@ const initialState: CanvasSettingsState = {
invertScrollForToolWidth: false,
color: { r: 31, g: 160, b: 224, a: 1 }, // invokeBlue.500
sendToCanvas: false,
outputOnlyMaskedRegions: false,
outputOnlyMaskedRegions: true,
autoProcess: true,
snapToGrid: true,
showProgressOnCanvas: true,

@@ -9,6 +9,8 @@ import type {
|
||||
ParameterCFGRescaleMultiplier,
|
||||
ParameterCFGScale,
|
||||
ParameterCLIPEmbedModel,
|
||||
ParameterCLIPGEmbedModel,
|
||||
ParameterCLIPLEmbedModel,
|
||||
ParameterGuidance,
|
||||
ParameterMaskBlurMethod,
|
||||
ParameterModel,
|
||||
@@ -71,6 +73,8 @@ export type ParamsState = {
|
||||
refinerStart: number;
|
||||
t5EncoderModel: ParameterT5EncoderModel | null;
|
||||
clipEmbedModel: ParameterCLIPEmbedModel | null;
|
||||
clipLEmbedModel: ParameterCLIPLEmbedModel | null;
|
||||
clipGEmbedModel: ParameterCLIPGEmbedModel | null;
|
||||
};
|
||||
|
||||
const initialState: ParamsState = {
|
||||
@@ -115,6 +119,8 @@ const initialState: ParamsState = {
|
||||
refinerStart: 0.8,
|
||||
t5EncoderModel: null,
|
||||
clipEmbedModel: null,
|
||||
clipLEmbedModel: null,
|
||||
clipGEmbedModel: null,
|
||||
};
|
||||
|
||||
export const paramsSlice = createSlice({
|
||||
@@ -192,6 +198,12 @@ export const paramsSlice = createSlice({
|
||||
clipEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPEmbedModel | null>) => {
|
||||
state.clipEmbedModel = action.payload;
|
||||
},
|
||||
clipLEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPLEmbedModel | null>) => {
|
||||
state.clipLEmbedModel = action.payload;
|
||||
},
|
||||
clipGEmbedModelSelected: (state, action: PayloadAction<ParameterCLIPGEmbedModel | null>) => {
|
||||
state.clipGEmbedModel = action.payload;
|
||||
},
|
||||
vaePrecisionChanged: (state, action: PayloadAction<ParameterPrecision>) => {
|
||||
state.vaePrecision = action.payload;
|
||||
},
|
||||
@@ -305,6 +317,8 @@ export const {
|
||||
vaePrecisionChanged,
|
||||
t5EncoderModelSelected,
|
||||
clipEmbedModelSelected,
|
||||
clipLEmbedModelSelected,
|
||||
clipGEmbedModelSelected,
|
||||
setClipSkip,
|
||||
shouldUseCpuNoiseChanged,
|
||||
positivePromptChanged,
|
||||
@@ -341,6 +355,7 @@ export const createParamsSelector = <T>(selector: Selector<ParamsState, T>) =>
|
||||
export const selectBase = createParamsSelector((params) => params.model?.base);
|
||||
export const selectIsSDXL = createParamsSelector((params) => params.model?.base === 'sdxl');
|
||||
export const selectIsFLUX = createParamsSelector((params) => params.model?.base === 'flux');
|
||||
export const selectIsSD3 = createParamsSelector((params) => params.model?.base === 'sd-3');
|
||||
|
||||
export const selectModel = createParamsSelector((params) => params.model);
|
||||
export const selectModelKey = createParamsSelector((params) => params.model?.key);
|
||||
@@ -349,6 +364,10 @@ export const selectFLUXVAE = createParamsSelector((params) => params.fluxVAE);
|
||||
export const selectVAEKey = createParamsSelector((params) => params.vae?.key);
|
||||
export const selectT5EncoderModel = createParamsSelector((params) => params.t5EncoderModel);
|
||||
export const selectCLIPEmbedModel = createParamsSelector((params) => params.clipEmbedModel);
|
||||
export const selectCLIPLEmbedModel = createParamsSelector((params) => params.clipLEmbedModel);
|
||||
|
||||
export const selectCLIPGEmbedModel = createParamsSelector((params) => params.clipGEmbedModel);
|
||||
|
||||
export const selectCFGScale = createParamsSelector((params) => params.cfgScale);
|
||||
export const selectGuidance = createParamsSelector((params) => params.guidance);
|
||||
export const selectSteps = createParamsSelector((params) => params.steps);
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
|
||||
import { draggable } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
|
||||
import type { ImageProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Image } from '@invoke-ai/ui-library';
|
||||
@@ -5,6 +6,7 @@ import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { singleImageDndSource } from 'features/dnd/dnd';
|
||||
import type { DndDragPreviewSingleImageState } from 'features/dnd/DndDragPreviewSingleImage';
|
||||
import { createSingleImageDragPreview, setSingleImageDragPreview } from 'features/dnd/DndDragPreviewSingleImage';
|
||||
import { firefoxDndFix } from 'features/dnd/util';
|
||||
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||
import { memo, useEffect, useState } from 'react';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -35,25 +37,28 @@ export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: Props) => {
|
||||
if (!element) {
|
||||
return;
|
||||
}
|
||||
return draggable({
|
||||
element,
|
||||
getInitialData: () => singleImageDndSource.getData({ imageDTO }, imageDTO.image_name),
|
||||
onDragStart: () => {
|
||||
setIsDragging(true);
|
||||
},
|
||||
onDrop: () => {
|
||||
setIsDragging(false);
|
||||
},
|
||||
onGenerateDragPreview: (args) => {
|
||||
if (singleImageDndSource.typeGuard(args.source.data)) {
|
||||
setSingleImageDragPreview({
|
||||
singleImageDndData: args.source.data,
|
||||
onGenerateDragPreviewArgs: args,
|
||||
setDragPreviewState,
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
return combine(
|
||||
firefoxDndFix(element),
|
||||
draggable({
|
||||
element,
|
||||
getInitialData: () => singleImageDndSource.getData({ imageDTO }, imageDTO.image_name),
|
||||
onDragStart: () => {
|
||||
setIsDragging(true);
|
||||
},
|
||||
onDrop: () => {
|
||||
setIsDragging(false);
|
||||
},
|
||||
onGenerateDragPreview: (args) => {
|
||||
if (singleImageDndSource.typeGuard(args.source.data)) {
|
||||
setSingleImageDragPreview({
|
||||
singleImageDndData: args.source.data,
|
||||
onGenerateDragPreviewArgs: args,
|
||||
setDragPreviewState,
|
||||
});
|
||||
}
|
||||
},
|
||||
})
|
||||
);
|
||||
}, [imageDTO, element, store]);
|
||||
|
||||
useImageContextMenu(imageDTO, element);
|
||||
|
||||
@@ -11,10 +11,11 @@ import type { DndTargetState } from 'features/dnd/types';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { memo, useEffect, useRef, useState } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { type UploadImageArg, uploadImages } from 'services/api/endpoints/images';
|
||||
import { uploadImages } from 'services/api/endpoints/images';
|
||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
import type { UploadImageArg } from 'services/api/types';
|
||||
import { z } from 'zod';
|
||||
|
||||
const ACCEPTED_IMAGE_TYPES = ['image/png', 'image/jpg', 'image/jpeg'];
|
||||
@@ -71,13 +72,47 @@ export const FullscreenDropzone = memo(() => {
|
||||
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
|
||||
const [dndState, setDndState] = useState<DndTargetState>('idle');
|
||||
|
||||
const uploadFilesSchema = useMemo(() => getFilesSchema(maxImageUploadCount), [maxImageUploadCount]);
|
||||
|
||||
const validateAndUploadFiles = useCallback(
|
||||
(files: File[]) => {
|
||||
const { getState } = getStore();
|
||||
const parseResult = uploadFilesSchema.safeParse(files);
|
||||
|
||||
if (!parseResult.success) {
|
||||
const description =
|
||||
maxImageUploadCount === undefined
|
||||
? t('toast.uploadFailedInvalidUploadDesc')
|
||||
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
|
||||
|
||||
toast({
|
||||
id: 'UPLOAD_FAILED',
|
||||
title: t('toast.uploadFailed'),
|
||||
description,
|
||||
status: 'error',
|
||||
});
|
||||
return;
|
||||
}
|
||||
const autoAddBoardId = selectAutoAddBoardId(getState());
|
||||
|
||||
const uploadArgs: UploadImageArg[] = files.map((file, i) => ({
|
||||
file,
|
||||
image_category: 'user',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
isFirstUploadOfBatch: i === 0,
|
||||
}));
|
||||
|
||||
uploadImages(uploadArgs);
|
||||
},
|
||||
[maxImageUploadCount, t, uploadFilesSchema]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const element = ref.current;
|
||||
if (!element) {
|
||||
return;
|
||||
}
|
||||
const { getState } = getStore();
|
||||
const uploadFilesSchema = getFilesSchema(maxImageUploadCount);
|
||||
|
||||
return combine(
|
||||
dropTargetForExternal({
|
||||
@@ -85,32 +120,7 @@ export const FullscreenDropzone = memo(() => {
|
||||
canDrop: containsFiles,
|
||||
onDrop: ({ source }) => {
|
||||
const files = getFiles({ source });
|
||||
const parseResult = uploadFilesSchema.safeParse(files);
|
||||
|
||||
if (!parseResult.success) {
|
||||
const description =
|
||||
maxImageUploadCount === undefined
|
||||
? t('toast.uploadFailedInvalidUploadDesc')
|
||||
: t('toast.uploadFailedInvalidUploadDesc_withCount', { count: maxImageUploadCount });
|
||||
|
||||
toast({
|
||||
id: 'UPLOAD_FAILED',
|
||||
title: t('toast.uploadFailed'),
|
||||
description,
|
||||
status: 'error',
|
||||
});
|
||||
return;
|
||||
}
|
||||
const autoAddBoardId = selectAutoAddBoardId(getState());
|
||||
|
||||
const uploadArgs: UploadImageArg[] = files.map((file) => ({
|
||||
file,
|
||||
image_category: 'user',
|
||||
is_intermediate: false,
|
||||
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
|
||||
}));
|
||||
|
||||
uploadImages(uploadArgs);
|
||||
validateAndUploadFiles(files);
|
||||
},
|
||||
onDragEnter: () => {
|
||||
setDndState('over');
|
||||
@@ -131,7 +141,27 @@ export const FullscreenDropzone = memo(() => {
|
||||
},
|
||||
})
|
||||
);
|
||||
}, [maxImageUploadCount, t]);
|
||||
}, [validateAndUploadFiles]);
|
||||
|
||||
useEffect(() => {
|
||||
const controller = new AbortController();
|
||||
|
||||
document.addEventListener(
|
||||
'paste',
|
||||
(e) => {
|
||||
if (!e.clipboardData?.files) {
|
||||
return;
|
||||
}
|
||||
const files = Array.from(e.clipboardData.files);
|
||||
validateAndUploadFiles(files);
|
||||
},
|
||||
{ signal: controller.signal }
|
||||
);
|
||||
|
||||
return () => {
|
||||
controller.abort();
|
||||
};
|
||||
}, [validateAndUploadFiles]);
|
||||
|
||||
return (
|
||||
<Box ref={ref} data-dnd-state={dndState} sx={sx}>
|
||||
|
||||
@@ -1,6 +1,7 @@
import type { GetOffsetFn } from '@atlaskit/pragmatic-drag-and-drop/dist/types/public-utils/element/custom-native-drag-preview/types';
import type { Input } from '@atlaskit/pragmatic-drag-and-drop/types';
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { noop } from 'lodash-es';
import type { CSSProperties } from 'react';

/**
@@ -44,3 +45,67 @@ export function triggerPostMoveFlash(element: HTMLElement, backgroundColor: CSSP
iterations: 1,
});
}

/**
* Firefox has a bug where input or textarea elements with draggable parents do not allow selection of their text.
*
* This helper function implements a workaround by setting the draggable attribute to false when the mouse is over an
* input or textarea child of the draggable. It reverts the attribute on mouse out.
*
* The fix is only applied for Firefox, and should be used in every `pragmatic-drag-and-drop` `draggable`.
*
* See:
* - https://github.com/atlassian/pragmatic-drag-and-drop/issues/111
* - https://bugzilla.mozilla.org/show_bug.cgi?id=1853069
*
* @example
* ```tsx
* useEffect(() => {
* const element = ref.current;
* if (!element) {
* return;
* }
* return combine(
* firefoxDndFix(element),
* // The rest of the draggable setup is the same
* draggable({
* element,
* // ...
* }),
* );
* ```
* @param element The draggable element
* @returns A cleanup function that removes the event listeners
*/
export const firefoxDndFix = (element: HTMLElement): (() => void) => {
if (!navigator.userAgent.includes('Firefox')) {
return noop;
}

const abortController = new AbortController();

element.addEventListener(
'mouseover',
(event) => {
if (event.target instanceof HTMLTextAreaElement || event.target instanceof HTMLInputElement) {
element.setAttribute('draggable', 'false');
}
},
{ signal: abortController.signal }
);

element.addEventListener(
'mouseout',
(event) => {
if (event.target instanceof HTMLTextAreaElement || event.target instanceof HTMLInputElement) {
element.setAttribute('draggable', 'true');
}
},
{ signal: abortController.signal }
);

return () => {
element.setAttribute('draggable', 'true');
abortController.abort();
};
};

@@ -1,8 +1,10 @@
|
||||
import { Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu';
|
||||
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
|
||||
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
|
||||
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
|
||||
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
|
||||
import { sentImageToCanvas } from 'features/gallery/store/actions';
|
||||
@@ -20,6 +22,8 @@ export const ImageMenuItemNewFromImageSubMenu = memo(() => {
|
||||
const imageDTO = useImageDTOContext();
|
||||
const imageViewer = useImageViewer();
|
||||
const isBusy = useCanvasIsBusy();
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
const isSD3 = useAppSelector(selectIsSD3);
|
||||
|
||||
const onClickNewCanvasWithRasterLayerFromImage = useCallback(() => {
|
||||
const { dispatch, getState } = store;
|
||||
@@ -110,17 +114,25 @@ export const ImageMenuItemNewFromImageSubMenu = memo(() => {
|
||||
<MenuItem
|
||||
icon={<PiFileBold />}
|
||||
onClickCapture={onClickNewCanvasWithControlLayerFromImage}
|
||||
isDisabled={isBusy}
|
||||
isDisabled={isBusy || isSD3}
|
||||
>
|
||||
{t('controlLayers.canvasAsControlLayer')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewInpaintMaskFromImage} isDisabled={isBusy}>
|
||||
{t('controlLayers.inpaintMask')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewRegionalGuidanceFromImage} isDisabled={isBusy}>
|
||||
<MenuItem
|
||||
icon={<NewLayerIcon />}
|
||||
onClickCapture={onClickNewRegionalGuidanceFromImage}
|
||||
isDisabled={isBusy || isFLUX || isSD3}
|
||||
>
|
||||
{t('controlLayers.regionalGuidance')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewControlLayerFromImage} isDisabled={isBusy}>
|
||||
<MenuItem
|
||||
icon={<NewLayerIcon />}
|
||||
onClickCapture={onClickNewControlLayerFromImage}
|
||||
isDisabled={isBusy || isSD3}
|
||||
>
|
||||
{t('controlLayers.controlLayer')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClickNewRasterLayerFromImage} isDisabled={isBusy}>
|
||||
|
||||
@@ -12,6 +12,7 @@ import type { DndDragPreviewMultipleImageState } from 'features/dnd/DndDragPrevi
|
||||
import { createMultipleImageDragPreview, setMultipleImageDragPreview } from 'features/dnd/DndDragPreviewMultipleImage';
|
||||
import type { DndDragPreviewSingleImageState } from 'features/dnd/DndDragPreviewSingleImage';
|
||||
import { createSingleImageDragPreview, setSingleImageDragPreview } from 'features/dnd/DndDragPreviewSingleImage';
|
||||
import { firefoxDndFix } from 'features/dnd/util';
|
||||
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||
import { GalleryImageHoverIcons } from 'features/gallery/components/ImageGrid/GalleryImageHoverIcons';
|
||||
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
|
||||
@@ -66,7 +67,7 @@ const galleryImageContainerSX = {
|
||||
},
|
||||
'&:hover::before': {
|
||||
boxShadow:
|
||||
'inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-800)',
|
||||
'inset 0px 0px 0px 1px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-800)',
|
||||
},
|
||||
'&:hover[data-selected=true]::before': {
|
||||
boxShadow:
|
||||
@@ -115,13 +116,17 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
|
||||
return;
|
||||
}
|
||||
return combine(
|
||||
firefoxDndFix(element),
|
||||
draggable({
|
||||
element,
|
||||
getInitialData: () => {
|
||||
const { gallery } = store.getState();
|
||||
// When we have multiple images selected, and the dragged image is part of the selection, initiate a
|
||||
// multi-image drag.
|
||||
if (gallery.selection.length > 1 && gallery.selection.includes(imageDTO)) {
|
||||
if (
|
||||
gallery.selection.length > 1 &&
|
||||
gallery.selection.find(({ image_name }) => image_name === imageDTO.image_name) !== undefined
|
||||
) {
|
||||
return multipleImageDndSource.getData({
|
||||
imageDTOs: gallery.selection,
|
||||
boardId: gallery.selectedBoardId,
|
||||
|
||||
@@ -2,6 +2,7 @@ import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { CanvasAlertsInvocationProgress } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsInvocationProgress';
|
||||
import { CanvasAlertsSendingToCanvas } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSendingTo';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
|
||||
@@ -48,9 +49,18 @@ const CurrentImagePreview = () => {
|
||||
position="relative"
|
||||
>
|
||||
<ImageContent imageDTO={imageDTO} />
|
||||
<Box position="absolute" top={0} insetInlineStart={0}>
|
||||
<Flex
|
||||
flexDir="column"
|
||||
gap={2}
|
||||
position="absolute"
|
||||
top={0}
|
||||
insetInlineStart={0}
|
||||
pointerEvents="none"
|
||||
alignItems="flex-start"
|
||||
>
|
||||
<CanvasAlertsSendingToCanvas />
|
||||
</Box>
|
||||
<CanvasAlertsInvocationProgress />
|
||||
</Flex>
|
||||
{shouldShowImageDetails && imageDTO && (
|
||||
<Box position="absolute" opacity={0.8} top={0} width="full" height="full" borderRadius="base">
|
||||
<ImageMetadataViewer image={imageDTO} />
|
||||
|
||||
@@ -44,7 +44,6 @@ export const ImageViewer = memo(({ closeButton }: Props) => {
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
rowGap={2}
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
>
|
||||
|
||||
@@ -10,6 +10,8 @@ import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
{ value: 'sd-3', label: MODEL_TYPE_MAP['sd-3'] },
{ value: 'flux', label: MODEL_TYPE_MAP['flux'] },
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
];

@@ -4,6 +4,7 @@ import { attachClosestEdge, extractClosestEdge } from '@atlaskit/pragmatic-drag-
|
||||
import { singleWorkflowFieldDndSource } from 'features/dnd/dnd';
|
||||
import type { DndListTargetState } from 'features/dnd/types';
|
||||
import { idle } from 'features/dnd/types';
|
||||
import { firefoxDndFix } from 'features/dnd/util';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import type { RefObject } from 'react';
|
||||
import { useEffect, useState } from 'react';
|
||||
@@ -18,6 +19,7 @@ export const useLinearViewFieldDnd = (ref: RefObject<HTMLElement>, fieldIdentifi
|
||||
return;
|
||||
}
|
||||
return combine(
|
||||
firefoxDndFix(element),
|
||||
draggable({
|
||||
element,
|
||||
getInitialData() {
|
||||
|
||||
@@ -9,8 +9,8 @@ export const prepareLinearUIBatch = (
state: RootState,
g: Graph,
prepend: boolean,
noise: Invocation<'noise' | 'flux_denoise'>,
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>,
noise: Invocation<'noise' | 'flux_denoise' | 'sd3_denoise'>,
posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder' | 'sd3_text_encoder'>,
origin: 'canvas' | 'workflows' | 'upscaling',
destination: 'canvas' | 'gallery'
): BatchConfig => {

@@ -2,16 +2,18 @@ import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import type { CanvasState, Dimensions } from 'features/controlLayers/store/types';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { addImageToLatents } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
|
||||
type AddImageToImageArg = {
|
||||
g: Graph;
|
||||
manager: CanvasManager;
|
||||
l2i: Invocation<'l2i' | 'flux_vae_decode'>;
|
||||
denoise: Invocation<'denoise_latents' | 'flux_denoise'>;
|
||||
vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader'>;
|
||||
l2i: Invocation<'l2i' | 'flux_vae_decode' | 'sd3_l2i'>;
|
||||
i2lNodeType: 'i2l' | 'flux_vae_encode' | 'sd3_i2l';
|
||||
denoise: Invocation<'denoise_latents' | 'flux_denoise' | 'sd3_denoise'>;
|
||||
vaeSource: Invocation<
|
||||
'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader' | 'sd3_model_loader'
|
||||
>;
|
||||
originalSize: Dimensions;
|
||||
scaledSize: Dimensions;
|
||||
bbox: CanvasState['bbox'];
|
||||
@@ -23,6 +25,7 @@ export const addImageToImage = async ({
|
||||
g,
|
||||
manager,
|
||||
l2i,
|
||||
i2lNodeType,
|
||||
denoise,
|
||||
vaeSource,
|
||||
originalSize,
|
||||
@@ -30,10 +33,13 @@ export const addImageToImage = async ({
|
||||
bbox,
|
||||
denoising_start,
|
||||
fp32,
|
||||
}: AddImageToImageArg): Promise<Invocation<'img_resize' | 'l2i' | 'flux_vae_decode'>> => {
|
||||
}: AddImageToImageArg): Promise<Invocation<'img_resize' | 'l2i' | 'flux_vae_decode' | 'sd3_l2i'>> => {
|
||||
denoise.denoising_start = denoising_start;
|
||||
const adapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
|
||||
const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, { is_intermediate: true });
|
||||
const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, {
|
||||
is_intermediate: true,
|
||||
silent: true,
|
||||
});
|
||||
|
||||
if (!isEqual(scaledSize, originalSize)) {
|
||||
// Resize the initial image to the scaled size, denoise, then resize back to the original size
|
||||
@@ -44,7 +50,12 @@ export const addImageToImage = async ({
|
||||
...scaledSize,
|
||||
});
|
||||
|
||||
const i2l = addImageToLatents(g, l2i.type === 'flux_vae_decode', fp32);
|
||||
const i2l = g.addNode({
|
||||
id: i2lNodeType,
|
||||
type: i2lNodeType,
|
||||
image: image_name ? { image_name } : undefined,
|
||||
...(i2lNodeType === 'i2l' ? { fp32 } : {}),
|
||||
});
|
||||
|
||||
const resizeImageToOriginalSize = g.addNode({
|
||||
type: 'img_resize',
|
||||
@@ -61,7 +72,12 @@ export const addImageToImage = async ({
|
||||
return resizeImageToOriginalSize;
|
||||
} else {
|
||||
// No need to resize, just decode
|
||||
const i2l = addImageToLatents(g, l2i.type === 'flux_vae_decode', fp32, image_name);
|
||||
const i2l = g.addNode({
|
||||
id: i2lNodeType,
|
||||
type: i2lNodeType,
|
||||
image: image_name ? { image_name } : undefined,
|
||||
...(i2lNodeType === 'i2l' ? { fp32 } : {}),
|
||||
});
|
||||
g.addEdge(vaeSource, 'vae', i2l, 'vae');
|
||||
g.addEdge(i2l, 'latents', denoise, 'latents');
|
||||
return l2i;
|
||||
|
||||
@@ -6,7 +6,6 @@ import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import type { Dimensions } from 'features/controlLayers/store/types';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { addImageToLatents } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
|
||||
@@ -14,10 +13,13 @@ type AddInpaintArg = {
|
||||
state: RootState;
|
||||
g: Graph;
|
||||
manager: CanvasManager;
|
||||
l2i: Invocation<'l2i' | 'flux_vae_decode'>;
|
||||
denoise: Invocation<'denoise_latents' | 'flux_denoise'>;
|
||||
vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader'>;
|
||||
modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader'>;
|
||||
l2i: Invocation<'l2i' | 'flux_vae_decode' | 'sd3_l2i'>;
|
||||
i2lNodeType: 'i2l' | 'flux_vae_encode' | 'sd3_i2l';
|
||||
denoise: Invocation<'denoise_latents' | 'flux_denoise' | 'sd3_denoise'>;
|
||||
vaeSource: Invocation<
|
||||
'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader' | 'sd3_model_loader'
|
||||
>;
|
||||
modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'sd3_model_loader'>;
|
||||
originalSize: Dimensions;
|
||||
scaledSize: Dimensions;
|
||||
denoising_start: number;
|
||||
@@ -29,6 +31,7 @@ export const addInpaint = async ({
|
||||
g,
|
||||
manager,
|
||||
l2i,
|
||||
i2lNodeType,
|
||||
denoise,
|
||||
vaeSource,
|
||||
modelLoader,
|
||||
@@ -48,16 +51,23 @@ export const addInpaint = async ({
|
||||
const rasterAdapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
|
||||
const initialImage = await manager.compositor.getCompositeImageDTO(rasterAdapters, bbox.rect, {
|
||||
is_intermediate: true,
|
||||
silent: true,
|
||||
});
|
||||
|
||||
const inpaintMaskAdapters = manager.compositor.getVisibleAdaptersOfType('inpaint_mask');
|
||||
const maskImage = await manager.compositor.getCompositeImageDTO(inpaintMaskAdapters, bbox.rect, {
|
||||
is_intermediate: true,
|
||||
silent: true,
|
||||
});
|
||||
|
||||
if (!isEqual(scaledSize, originalSize)) {
|
||||
// Scale before processing requires some resizing
|
||||
const i2l = addImageToLatents(g, modelLoader.type === 'flux_model_loader', fp32, initialImage.image_name);
|
||||
const i2l = g.addNode({
|
||||
id: i2lNodeType,
|
||||
type: i2lNodeType,
|
||||
image: initialImage.image_name ? { image_name: initialImage.image_name } : undefined,
|
||||
...(i2lNodeType === 'i2l' ? { fp32 } : {}),
|
||||
});
|
||||
|
||||
const resizeImageToScaledSize = g.addNode({
|
||||
type: 'img_resize',
|
||||
@@ -102,7 +112,7 @@ export const addInpaint = async ({
|
||||
g.addEdge(vaeSource, 'vae', i2l, 'vae');
|
||||
|
||||
g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
|
||||
if (modelLoader.type !== 'flux_model_loader') {
|
||||
if (modelLoader.type !== 'flux_model_loader' && modelLoader.type !== 'sd3_model_loader') {
|
||||
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
|
||||
}
|
||||
g.addEdge(resizeImageToScaledSize, 'image', createGradientMask, 'image');
|
||||
@@ -126,7 +136,12 @@ export const addInpaint = async ({
|
||||
return resizeOutput;
|
||||
} else {
|
||||
// No scale before processing, much simpler
|
||||
const i2l = addImageToLatents(g, modelLoader.type === 'flux_model_loader', fp32, initialImage.image_name);
|
||||
const i2l = g.addNode({
|
||||
id: i2lNodeType,
|
||||
type: i2lNodeType,
|
||||
image: initialImage.image_name ? { image_name: initialImage.image_name } : undefined,
|
||||
...(i2lNodeType === 'i2l' ? { fp32 } : {}),
|
||||
});
|
||||
|
||||
const alphaToMask = g.addNode({
|
||||
id: getPrefixedId('alpha_to_mask'),
|
||||
@@ -153,7 +168,7 @@ export const addInpaint = async ({
|
||||
g.addEdge(i2l, 'latents', denoise, 'latents');
|
||||
g.addEdge(vaeSource, 'vae', i2l, 'vae');
|
||||
g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
|
||||
if (modelLoader.type !== 'flux_model_loader') {
|
||||
if (modelLoader.type !== 'flux_model_loader' && modelLoader.type !== 'sd3_model_loader') {
|
||||
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
|
||||
}
|
||||
g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');
|
||||
|
||||
@@ -11,7 +11,7 @@ import type { Invocation } from 'services/api/types';
|
||||
export const addNSFWChecker = (
|
||||
g: Graph,
|
||||
imageOutput: Invocation<
|
||||
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'
|
||||
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode' | 'sd3_l2i'
|
||||
>
|
||||
): Invocation<'img_nsfw'> => {
|
||||
const nsfw = g.addNode({
|
||||
|
||||
@@ -6,7 +6,7 @@ import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import type { Dimensions } from 'features/controlLayers/store/types';
|
||||
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import { addImageToLatents, getInfill } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { getInfill } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
|
||||
@@ -14,10 +14,13 @@ type AddOutpaintArg = {
|
||||
state: RootState;
|
||||
g: Graph;
|
||||
manager: CanvasManager;
|
||||
l2i: Invocation<'l2i' | 'flux_vae_decode'>;
|
||||
denoise: Invocation<'denoise_latents' | 'flux_denoise'>;
|
||||
vaeSource: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader'>;
|
||||
modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader'>;
|
||||
l2i: Invocation<'l2i' | 'flux_vae_decode' | 'sd3_l2i'>;
|
||||
i2lNodeType: 'i2l' | 'flux_vae_encode' | 'sd3_i2l';
|
||||
denoise: Invocation<'denoise_latents' | 'flux_denoise' | 'sd3_denoise'>;
|
||||
vaeSource: Invocation<
|
||||
'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'seamless' | 'vae_loader' | 'sd3_model_loader'
|
||||
>;
|
||||
modelLoader: Invocation<'main_model_loader' | 'sdxl_model_loader' | 'flux_model_loader' | 'sd3_model_loader'>;
|
||||
originalSize: Dimensions;
|
||||
scaledSize: Dimensions;
|
||||
denoising_start: number;
|
||||
@@ -29,6 +32,7 @@ export const addOutpaint = async ({
|
||||
g,
|
||||
manager,
|
||||
l2i,
|
||||
i2lNodeType,
|
||||
denoise,
|
||||
vaeSource,
|
||||
modelLoader,
|
||||
@@ -48,11 +52,13 @@ export const addOutpaint = async ({
|
||||
const rasterAdapters = manager.compositor.getVisibleAdaptersOfType('raster_layer');
|
||||
const initialImage = await manager.compositor.getCompositeImageDTO(rasterAdapters, bbox.rect, {
|
||||
is_intermediate: true,
|
||||
silent: true,
|
||||
});
|
||||
|
||||
const inpaintMaskAdapters = manager.compositor.getVisibleAdaptersOfType('inpaint_mask');
|
||||
const maskImage = await manager.compositor.getCompositeImageDTO(inpaintMaskAdapters, bbox.rect, {
|
||||
is_intermediate: true,
|
||||
silent: true,
|
||||
});
|
||||
|
||||
const infill = getInfill(g, params);
|
||||
@@ -108,14 +114,19 @@ export const addOutpaint = async ({
|
||||
g.addEdge(infill, 'image', createGradientMask, 'image');
|
||||
g.addEdge(resizeInputMaskToScaledSize, 'image', createGradientMask, 'mask');
|
||||
g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
|
||||
if (modelLoader.type !== 'flux_model_loader') {
|
||||
if (modelLoader.type !== 'flux_model_loader' && modelLoader.type !== 'sd3_model_loader') {
|
||||
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
|
||||
}
|
||||
|
||||
g.addEdge(createGradientMask, 'denoise_mask', denoise, 'denoise_mask');
|
||||
|
||||
// Decode infilled image and connect to denoise
|
||||
const i2l = addImageToLatents(g, modelLoader.type === 'flux_model_loader', fp32);
|
||||
const i2l = g.addNode({
|
||||
id: i2lNodeType,
|
||||
type: i2lNodeType,
|
||||
...(i2lNodeType === 'i2l' ? { fp32 } : {}),
|
||||
});
|
||||
|
||||
g.addEdge(infill, 'image', i2l, 'image');
|
||||
g.addEdge(vaeSource, 'vae', i2l, 'vae');
|
||||
g.addEdge(i2l, 'latents', denoise, 'latents');
|
||||
@@ -150,7 +161,11 @@ export const addOutpaint = async ({
|
||||
} else {
|
||||
infill.image = { image_name: initialImage.image_name };
|
||||
// No scale before processing, much simpler
|
||||
const i2l = addImageToLatents(g, modelLoader.type === 'flux_model_loader', fp32);
|
||||
const i2l = g.addNode({
|
||||
id: i2lNodeType,
|
||||
type: i2lNodeType,
|
||||
...(i2lNodeType === 'i2l' ? { fp32 } : {}),
|
||||
});
|
||||
const maskAlphaToMask = g.addNode({
|
||||
id: getPrefixedId('mask_alpha_to_mask'),
|
||||
type: 'tomask',
|
||||
@@ -187,7 +202,7 @@ export const addOutpaint = async ({
|
||||
g.addEdge(i2l, 'latents', denoise, 'latents');
|
||||
g.addEdge(vaeSource, 'vae', i2l, 'vae');
|
||||
g.addEdge(vaeSource, 'vae', createGradientMask, 'vae');
|
||||
if (modelLoader.type !== 'flux_model_loader') {
|
||||
if (modelLoader.type !== 'flux_model_loader' && modelLoader.type !== 'sd3_model_loader') {
|
||||
g.addEdge(modelLoader, 'unet', createGradientMask, 'unet');
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import type { Invocation } from 'services/api/types';
|
||||
|
||||
type AddTextToImageArg = {
|
||||
g: Graph;
|
||||
l2i: Invocation<'l2i' | 'flux_vae_decode'>;
|
||||
l2i: Invocation<'l2i' | 'flux_vae_decode' | 'sd3_l2i'>;
|
||||
originalSize: Dimensions;
|
||||
scaledSize: Dimensions;
|
||||
};
|
||||
@@ -16,7 +16,7 @@ export const addTextToImage = ({
|
||||
l2i,
|
||||
originalSize,
|
||||
scaledSize,
|
||||
}: AddTextToImageArg): Invocation<'img_resize' | 'l2i' | 'flux_vae_decode'> => {
|
||||
}: AddTextToImageArg): Invocation<'img_resize' | 'l2i' | 'flux_vae_decode' | 'sd3_l2i'> => {
|
||||
if (!isEqual(scaledSize, originalSize)) {
|
||||
// We need to resize the output image back to the original size
|
||||
const resizeImageToOriginalSize = g.addNode({
|
||||
|
||||
@@ -11,7 +11,7 @@ import type { Invocation } from 'services/api/types';
|
||||
export const addWatermarker = (
|
||||
g: Graph,
|
||||
imageOutput: Invocation<
|
||||
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'
|
||||
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode' | 'sd3_l2i'
|
||||
>
|
||||
): Invocation<'img_watermark'> => {
|
||||
const watermark = g.addNode({
|
||||
|
||||
@@ -139,7 +139,7 @@ export const buildFLUXGraph = async (
|
||||
}
|
||||
|
||||
let canvasOutput: Invocation<
|
||||
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'
|
||||
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode' | 'sd3_l2i'
|
||||
> = l2i;
|
||||
|
||||
if (generationMode === 'txt2img') {
|
||||
@@ -149,6 +149,7 @@ export const buildFLUXGraph = async (
|
||||
g,
|
||||
manager,
|
||||
l2i,
|
||||
i2lNodeType: 'flux_vae_encode',
|
||||
denoise,
|
||||
vaeSource: modelLoader,
|
||||
originalSize,
|
||||
@@ -163,6 +164,7 @@ export const buildFLUXGraph = async (
|
||||
g,
|
||||
manager,
|
||||
l2i,
|
||||
i2lNodeType: 'flux_vae_encode',
|
||||
denoise,
|
||||
vaeSource: modelLoader,
|
||||
modelLoader,
|
||||
@@ -177,6 +179,7 @@ export const buildFLUXGraph = async (
|
||||
g,
|
||||
manager,
|
||||
l2i,
|
||||
i2lNodeType: 'flux_vae_encode',
|
||||
denoise,
|
||||
vaeSource: modelLoader,
|
||||
modelLoader,
|
||||
|
||||
@@ -170,7 +170,7 @@ export const buildSD1Graph = async (
|
||||
const denoising_start = 1 - params.img2imgStrength;
|
||||
|
||||
let canvasOutput: Invocation<
|
||||
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'
|
||||
'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode' | 'sd3_l2i'
|
||||
> = l2i;
|
||||
|
||||
if (generationMode === 'txt2img') {
|
||||
@@ -180,6 +180,7 @@ export const buildSD1Graph = async (
|
||||
g,
|
||||
manager,
|
||||
l2i,
|
||||
i2lNodeType: 'i2l',
|
||||
denoise,
|
||||
vaeSource,
|
||||
originalSize,
|
||||
@@ -194,6 +195,7 @@ export const buildSD1Graph = async (
|
||||
g,
|
||||
manager,
|
||||
l2i,
|
||||
i2lNodeType: 'i2l',
|
||||
denoise,
|
||||
vaeSource,
|
||||
modelLoader,
|
||||
@@ -208,6 +210,7 @@ export const buildSD1Graph = async (
|
||||
g,
|
||||
manager,
|
||||
l2i,
|
||||
i2lNodeType: 'i2l',
|
||||
denoise,
|
||||
vaeSource,
|
||||
modelLoader,
|
||||
|
||||
@@ -0,0 +1,216 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSettingsSlice';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers';
|
||||
import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage';
|
||||
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
|
||||
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
|
||||
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
|
||||
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
|
||||
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
|
||||
import { Graph } from 'features/nodes/util/graph/generation/Graph';
|
||||
import {
|
||||
CANVAS_OUTPUT_PREFIX,
|
||||
getBoardField,
|
||||
getPresetModifiedPrompts,
|
||||
getSizes,
|
||||
} from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import { isNonRefinerMainModelConfig } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';

const log = logger('system');

export const buildSD3Graph = async (
  state: RootState,
  manager: CanvasManager
): Promise<{ g: Graph; noise: Invocation<'sd3_denoise'>; posCond: Invocation<'sd3_text_encoder'> }> => {
  const generationMode = await manager.compositor.getGenerationMode();
  log.debug({ generationMode }, 'Building SD3 graph');

  const params = selectParamsSlice(state);
  const canvasSettings = selectCanvasSettingsSlice(state);
  const canvas = selectCanvasSlice(state);

  const { bbox } = canvas;

  const {
    model,
    cfgScale: cfg_scale,
    seed,
    steps,
    vae,
    t5EncoderModel,
    clipLEmbedModel,
    clipGEmbedModel,
    optimizedDenoisingEnabled,
    img2imgStrength,
  } = params;

  assert(model, 'No model found in state');

  const { originalSize, scaledSize } = getSizes(bbox);
  const { positivePrompt, negativePrompt } = getPresetModifiedPrompts(state);

  const g = new Graph(getPrefixedId('sd3_graph'));
  const modelLoader = g.addNode({
    type: 'sd3_model_loader',
    id: getPrefixedId('sd3_model_loader'),
    model,
    t5_encoder_model: t5EncoderModel,
    clip_l_model: clipLEmbedModel,
    clip_g_model: clipGEmbedModel,
    vae_model: vae,
  });
  const posCond = g.addNode({
    type: 'sd3_text_encoder',
    id: getPrefixedId('pos_cond'),
    prompt: positivePrompt,
  });

  const negCond = g.addNode({
    type: 'sd3_text_encoder',
    id: getPrefixedId('neg_cond'),
    prompt: negativePrompt,
  });

  const denoise = g.addNode({
    type: 'sd3_denoise',
    id: getPrefixedId('sd3_denoise'),
    cfg_scale,
    steps,
    denoising_start: 0,
    denoising_end: 1,
    width: scaledSize.width,
    height: scaledSize.height,
  });
  const l2i = g.addNode({
    type: 'sd3_l2i',
    id: getPrefixedId('l2i'),
  });

  g.addEdge(modelLoader, 'transformer', denoise, 'transformer');
  g.addEdge(modelLoader, 'clip_l', posCond, 'clip_l');
  g.addEdge(modelLoader, 'clip_l', negCond, 'clip_l');
  g.addEdge(modelLoader, 'clip_g', posCond, 'clip_g');
  g.addEdge(modelLoader, 'clip_g', negCond, 'clip_g');
  g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
  g.addEdge(modelLoader, 't5_encoder', negCond, 't5_encoder');

  g.addEdge(posCond, 'conditioning', denoise, 'positive_conditioning');
  g.addEdge(negCond, 'conditioning', denoise, 'negative_conditioning');

  g.addEdge(denoise, 'latents', l2i, 'latents');

  const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
  assert(modelConfig.base === 'sd-3');

  g.upsertMetadata({
    generation_mode: 'sd3_txt2img',
    cfg_scale,
    width: originalSize.width,
    height: originalSize.height,
    positive_prompt: positivePrompt,
    negative_prompt: negativePrompt,
    model: Graph.getModelMetadataField(modelConfig),
    seed,
    steps,
    vae: vae ?? undefined,
  });
  g.addEdge(modelLoader, 'vae', l2i, 'vae');

  let denoising_start: number;
  if (optimizedDenoisingEnabled) {
    // We rescale the img2imgStrength (with exponent 0.2) to effectively use the entire range [0, 1] and make the scale
    // more user-friendly for SD3.5. Without this, most of the 'change' is concentrated in the high denoise strength
    // range (>0.9).
    denoising_start = 1 - img2imgStrength ** 0.2;
  } else {
    denoising_start = 1 - img2imgStrength;
  }
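  // Rough worked example (values assumed for illustration, not taken from this diff): with optimized
  // denoising and img2imgStrength = 0.5, denoising_start = 1 - 0.5 ** 0.2 ≈ 0.13, so nearly the whole
  // schedule still runs; the linear mapping would instead start at 0.5.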

  let canvasOutput: Invocation<
    'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode' | 'sd3_l2i'
  > = l2i;

  if (generationMode === 'txt2img') {
    canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize });
  } else if (generationMode === 'img2img') {
    canvasOutput = await addImageToImage({
      g,
      manager,
      l2i,
      i2lNodeType: 'sd3_i2l',
      denoise,
      vaeSource: modelLoader,
      originalSize,
      scaledSize,
      bbox,
      denoising_start,
      fp32: false,
    });
  } else if (generationMode === 'inpaint') {
    canvasOutput = await addInpaint({
      state,
      g,
      manager,
      l2i,
      i2lNodeType: 'sd3_i2l',
      denoise,
      vaeSource: modelLoader,
      modelLoader,
      originalSize,
      scaledSize,
      denoising_start,
      fp32: false,
    });
  } else if (generationMode === 'outpaint') {
    canvasOutput = await addOutpaint({
      state,
      g,
      manager,
      l2i,
      i2lNodeType: 'sd3_i2l',
      denoise,
      vaeSource: modelLoader,
      modelLoader,
      originalSize,
      scaledSize,
      denoising_start,
      fp32: false,
    });
  } else {
    assert<Equals<typeof generationMode, never>>(false);
  }

  if (state.system.shouldUseNSFWChecker) {
    canvasOutput = addNSFWChecker(g, canvasOutput);
  }

  if (state.system.shouldUseWatermarker) {
    canvasOutput = addWatermarker(g, canvasOutput);
  }

  // This image will be staged, should not be saved to the gallery or added to a board.
  const is_intermediate = canvasSettings.sendToCanvas;
  const board = canvasSettings.sendToCanvas ? undefined : getBoardField(state);

  if (!canvasSettings.sendToCanvas) {
    g.upsertMetadata(selectCanvasMetadata(state));
  }

  g.updateNode(canvasOutput, {
    id: getPrefixedId(CANVAS_OUTPUT_PREFIX),
    is_intermediate,
    use_cache: false,
    board,
  });

  g.setMetadataReceivingNode(canvasOutput);
  return { g, noise: denoise, posCond };
};

@@ -175,7 +175,7 @@ export const buildSDXLGraph = async (
    : 1 - params.img2imgStrength;

  let canvasOutput: Invocation<
    'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode'
    'l2i' | 'img_nsfw' | 'img_watermark' | 'img_resize' | 'canvas_v2_mask_and_crop' | 'flux_vae_decode' | 'sd3_l2i'
  > = l2i;

  if (generationMode === 'txt2img') {
@@ -185,6 +185,7 @@ export const buildSDXLGraph = async (
      g,
      manager,
      l2i,
      i2lNodeType: 'i2l',
      denoise,
      vaeSource,
      originalSize,
@@ -199,6 +200,7 @@ export const buildSDXLGraph = async (
      g,
      manager,
      l2i,
      i2lNodeType: 'i2l',
      denoise,
      vaeSource,
      modelLoader,
@@ -213,6 +215,7 @@ export const buildSDXLGraph = async (
      g,
      manager,
      l2i,
      i2lNodeType: 'i2l',
      denoise,
      vaeSource,
      modelLoader,

@@ -118,16 +118,4 @@ export const getInfill = (
  assert(false, 'Unknown infill method');
};

export const addImageToLatents = (g: Graph, isFlux: boolean, fp32: boolean, image_name?: string) => {
  if (isFlux) {
    return g.addNode({
      id: 'flux_vae_encode',
      type: 'flux_vae_encode',
      image: image_name ? { image_name } : undefined,
    });
  } else {
    return g.addNode({ id: 'i2l', type: 'i2l', fp32, image: image_name ? { image_name } : undefined });
  }
};

export const CANVAS_OUTPUT_PREFIX = 'canvas_output';
@@ -0,0 +1,42 @@
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useModelCombobox } from 'common/hooks/useModelCombobox';
import { clipGEmbedModelSelected, selectCLIPGEmbedModel } from 'features/controlLayers/store/paramsSlice';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPGEmbedModelConfig } from 'services/api/types';
import { isCLIPGEmbedModelConfig } from 'services/api/types';

const ParamCLIPEmbedModelSelect = () => {
  const dispatch = useAppDispatch();
  const { t } = useTranslation();
  const clipEmbedModel = useAppSelector(selectCLIPGEmbedModel);
  const [modelConfigs, { isLoading }] = useCLIPEmbedModels();

  const _onChange = useCallback(
    (clipEmbedModel: CLIPGEmbedModelConfig | null) => {
      if (clipEmbedModel) {
        dispatch(clipGEmbedModelSelected(zModelIdentifierField.parse(clipEmbedModel)));
      }
    },
    [dispatch]
  );

  const { options, value, onChange, noOptionsMessage } = useModelCombobox({
    modelConfigs: modelConfigs.filter((config) => isCLIPGEmbedModelConfig(config)),
    onChange: _onChange,
    selectedModel: clipEmbedModel,
    isLoading,
  });

  return (
    <FormControl isDisabled={!options.length} isInvalid={!options.length} minW={0} flexGrow={1}>
      <FormLabel m={0}>{t('modelManager.clipGEmbed')}</FormLabel>
      <Combobox value={value} options={options} onChange={onChange} noOptionsMessage={noOptionsMessage} />
    </FormControl>
  );
};

export default memo(ParamCLIPEmbedModelSelect);
@@ -0,0 +1,42 @@
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useModelCombobox } from 'common/hooks/useModelCombobox';
import { clipLEmbedModelSelected, selectCLIPLEmbedModel } from 'features/controlLayers/store/paramsSlice';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useCLIPEmbedModels } from 'services/api/hooks/modelsByType';
import type { CLIPLEmbedModelConfig } from 'services/api/types';
import { isCLIPLEmbedModelConfig } from 'services/api/types';

const ParamCLIPEmbedModelSelect = () => {
  const dispatch = useAppDispatch();
  const { t } = useTranslation();
  const clipEmbedModel = useAppSelector(selectCLIPLEmbedModel);
  const [modelConfigs, { isLoading }] = useCLIPEmbedModels();

  const _onChange = useCallback(
    (clipEmbedModel: CLIPLEmbedModelConfig | null) => {
      if (clipEmbedModel) {
        dispatch(clipLEmbedModelSelected(zModelIdentifierField.parse(clipEmbedModel)));
      }
    },
    [dispatch]
  );

  const { options, value, onChange, noOptionsMessage } = useModelCombobox({
    modelConfigs: modelConfigs.filter((config) => isCLIPLEmbedModelConfig(config)),
    onChange: _onChange,
    selectedModel: clipEmbedModel,
    isLoading,
  });

  return (
    <FormControl isDisabled={!options.length} isInvalid={!options.length} minW={0} flexGrow={1}>
      <FormLabel m={0}>{t('modelManager.clipLEmbed')}</FormLabel>
      <Combobox value={value} options={options} onChange={onChange} noOptionsMessage={noOptionsMessage} />
    </FormControl>
  );
};

export default memo(ParamCLIPEmbedModelSelect);
@@ -9,7 +9,7 @@ import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { MdMoneyOff } from 'react-icons/md';
import { useNonSD3MainModels } from 'services/api/hooks/modelsByType';
import { useMainModels } from 'services/api/hooks/modelsByType';
import { type AnyModelConfig, isCheckpointMainModelConfig, type MainModelConfig } from 'services/api/types';

const ParamMainModelSelect = () => {
@@ -18,7 +18,7 @@ const ParamMainModelSelect = () => {
  const activeTabName = useAppSelector(selectActiveTab);
  const selectedModelKey = useAppSelector(selectModelKey);
  // const selectedModel = useAppSelector(selectModel);
  const [modelConfigs, { isLoading }] = useNonSD3MainModels();
  const [modelConfigs, { isLoading }] = useMainModels();

  const selectedModel = useMemo(() => {
    if (!modelConfigs) {

@@ -7,9 +7,10 @@ export const MODEL_TYPE_MAP = {
  any: 'Any',
  'sd-1': 'Stable Diffusion 1.x',
  'sd-2': 'Stable Diffusion 2.x',
  'sd-3': 'Stable Diffusion 3.x',
  sdxl: 'Stable Diffusion XL',
  'sdxl-refiner': 'Stable Diffusion XL Refiner',
  flux: 'Flux',
  flux: 'FLUX',
};

/**

@@ -123,6 +123,16 @@ export const zParameterCLIPEmbedModel = zModelIdentifierField;
export type ParameterCLIPEmbedModel = z.infer<typeof zParameterCLIPEmbedModel>;
// #endregion

// #region CLIP embed Model
export const zParameterCLIPLEmbedModel = zModelIdentifierField;
export type ParameterCLIPLEmbedModel = z.infer<typeof zParameterCLIPLEmbedModel>;
// #endregion

// #region CLIP embed Model
export const zParameterCLIPGEmbedModel = zModelIdentifierField;
export type ParameterCLIPGEmbedModel = z.infer<typeof zParameterCLIPGEmbedModel>;
// #endregion

// #region LoRA Model
const zParameterLoRAModel = zModelIdentifierField;
export type ParameterLoRAModel = z.infer<typeof zParameterLoRAModel>;
@@ -3,9 +3,11 @@ import { Flex, FormControlGroup, StandaloneAccordion } from '@invoke-ai/ui-libra
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectIsFLUX, selectParamsSlice, selectVAEKey } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectIsFLUX, selectIsSD3, selectParamsSlice, selectVAEKey } from 'features/controlLayers/store/paramsSlice';
|
||||
import ParamCFGRescaleMultiplier from 'features/parameters/components/Advanced/ParamCFGRescaleMultiplier';
|
||||
import ParamCLIPEmbedModelSelect from 'features/parameters/components/Advanced/ParamCLIPEmbedModelSelect';
|
||||
import ParamCLIPGEmbedModelSelect from 'features/parameters/components/Advanced/ParamCLIPGEmbedModelSelect';
|
||||
import ParamCLIPLEmbedModelSelect from 'features/parameters/components/Advanced/ParamCLIPLEmbedModelSelect';
|
||||
import ParamClipSkip from 'features/parameters/components/Advanced/ParamClipSkip';
|
||||
import ParamT5EncoderModelSelect from 'features/parameters/components/Advanced/ParamT5EncoderModelSelect';
|
||||
import ParamSeamlessXAxis from 'features/parameters/components/Seamless/ParamSeamlessXAxis';
|
||||
@@ -35,6 +37,7 @@ export const AdvancedSettingsAccordion = memo(() => {
|
||||
const { currentData: vaeConfig } = useGetModelConfigQuery(vaeKey ?? skipToken);
|
||||
const activeTabName = useAppSelector(selectActiveTab);
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
const isSD3 = useAppSelector(selectIsSD3);
|
||||
|
||||
const selectBadges = useMemo(
|
||||
() =>
|
||||
@@ -88,7 +91,7 @@ export const AdvancedSettingsAccordion = memo(() => {
|
||||
<Flex gap={4} alignItems="center" p={4} flexDir="column" data-testid="advanced-settings-accordion">
|
||||
<Flex gap={4} w="full">
|
||||
{isFLUX ? <ParamFLUXVAEModelSelect /> : <ParamVAEModelSelect />}
|
||||
{!isFLUX && <ParamVAEPrecision />}
|
||||
{!isFLUX && !isSD3 && <ParamVAEPrecision />}
|
||||
</Flex>
|
||||
{activeTabName === 'upscaling' ? (
|
||||
<Flex gap={4} alignItems="center">
|
||||
@@ -98,7 +101,7 @@ export const AdvancedSettingsAccordion = memo(() => {
|
||||
</Flex>
|
||||
) : (
|
||||
<>
|
||||
{!isFLUX && (
|
||||
{!isFLUX && !isSD3 && (
|
||||
<>
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
<ParamClipSkip />
|
||||
@@ -118,6 +121,13 @@ export const AdvancedSettingsAccordion = memo(() => {
|
||||
<ParamCLIPEmbedModelSelect />
|
||||
</FormControlGroup>
|
||||
)}
|
||||
{isSD3 && (
|
||||
<FormControlGroup>
|
||||
<ParamT5EncoderModelSelect />
|
||||
<ParamCLIPLEmbedModelSelect />
|
||||
<ParamCLIPGEmbedModelSelect />
|
||||
</FormControlGroup>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
|
||||
@@ -4,7 +4,7 @@ import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
|
||||
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
|
||||
import { LoRAList } from 'features/lora/components/LoRAList';
|
||||
import LoRASelect from 'features/lora/components/LoRASelect';
|
||||
import ParamCFGScale from 'features/parameters/components/Core/ParamCFGScale';
|
||||
@@ -30,6 +30,7 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
const modelConfig = useSelectedModelConfig();
|
||||
const activeTabName = useAppSelector(selectActiveTab);
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
const isSD3 = useAppSelector(selectIsSD3);
|
||||
const selectBadges = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectLoRAsSlice, (loras) => {
|
||||
@@ -74,7 +75,7 @@ export const GenerationSettingsAccordion = memo(() => {
|
||||
<Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>
|
||||
<Flex gap={4} flexDir="column" pb={4}>
|
||||
<FormControlGroup formLabelProps={formLabelProps}>
|
||||
{!isFLUX && <ParamScheduler />}
|
||||
{!isFLUX && !isSD3 && <ParamScheduler />}
|
||||
<ParamSteps />
|
||||
{isFLUX ? <ParamGuidance /> : <ParamCFGScale />}
|
||||
</FormControlGroup>
|
||||
|
||||
@@ -3,7 +3,7 @@ import { Expander, Flex, FormControlGroup, StandaloneAccordion } from '@invoke-a
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectIsFLUX, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectIsFLUX, selectIsSD3, selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectCanvasSlice, selectScaleMethod } from 'features/controlLayers/store/selectors';
|
||||
import { ParamOptimizedDenoisingToggle } from 'features/parameters/components/Advanced/ParamOptimizedDenoisingToggle';
|
||||
import BboxScaledHeight from 'features/parameters/components/Bbox/BboxScaledHeight';
|
||||
@@ -60,6 +60,7 @@ export const ImageSettingsAccordion = memo(() => {
|
||||
defaultIsOpen: false,
|
||||
});
|
||||
const isFLUX = useAppSelector(selectIsFLUX);
|
||||
const isSD3 = useAppSelector(selectIsSD3);
|
||||
|
||||
return (
|
||||
<StandaloneAccordion
|
||||
@@ -77,7 +78,7 @@ export const ImageSettingsAccordion = memo(() => {
|
||||
</Flex>
|
||||
<Expander label={t('accordions.advanced.options')} isOpen={isOpenExpander} onToggle={onToggleExpander}>
|
||||
<Flex gap={4} pb={4} flexDir="column">
|
||||
{isFLUX && <ParamOptimizedDenoisingToggle />}
|
||||
{(isFLUX || isSD3) && <ParamOptimizedDenoisingToggle />}
|
||||
<BboxScaleMethod />
|
||||
{scaleMethod !== 'none' && (
|
||||
<FormControlGroup formLabelProps={scalingLabelProps}>
|
||||
|
||||
@@ -34,8 +34,15 @@ export const UpscaleWarning = () => {
|
||||
dispatch(tileControlnetModelChanged(validModel || null));
|
||||
}, [model?.base, modelConfigs, dispatch]);
|
||||
|
||||
const isBaseModelCompatible = useMemo(() => {
|
||||
return model && ['sd-1', 'sdxl'].includes(model.base);
|
||||
}, [model]);
|
||||
|
||||
const modelWarnings = useMemo(() => {
|
||||
const _warnings: string[] = [];
|
||||
if (!isBaseModelCompatible) {
|
||||
return _warnings;
|
||||
}
|
||||
if (!model) {
|
||||
_warnings.push(t('upscaling.mainModelDesc'));
|
||||
}
|
||||
@@ -46,7 +53,7 @@ export const UpscaleWarning = () => {
|
||||
_warnings.push(t('upscaling.upscaleModelDesc'));
|
||||
}
|
||||
return _warnings;
|
||||
}, [model, tileControlnetModel, upscaleModel, t]);
|
||||
}, [isBaseModelCompatible, model, tileControlnetModel, upscaleModel, t]);
|
||||
|
||||
const otherWarnings = useMemo(() => {
|
||||
const _warnings: string[] = [];
|
||||
@@ -58,22 +65,25 @@ export const UpscaleWarning = () => {
|
||||
return _warnings;
|
||||
}, [isTooLargeToUpscale, t, maxUpscaleDimension]);
|
||||
|
||||
const allWarnings = useMemo(() => [...modelWarnings, ...otherWarnings], [modelWarnings, otherWarnings]);
|
||||
|
||||
const handleGoToModelManager = useCallback(() => {
|
||||
dispatch(setActiveTab('models'));
|
||||
$installModelsTab.set(3);
|
||||
}, [dispatch]);
|
||||
|
||||
if (modelWarnings.length && isModelsTabDisabled) {
|
||||
if (isBaseModelCompatible && modelWarnings.length > 0 && isModelsTabDisabled) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if ((!modelWarnings.length && !otherWarnings.length) || isLoading) {
|
||||
if ((isBaseModelCompatible && allWarnings.length === 0) || isLoading) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex bg="error.500" borderRadius="base" padding={4} direction="column" fontSize="sm" gap={2}>
|
||||
{!!modelWarnings.length && (
|
||||
{!isBaseModelCompatible && <Text>{t('upscaling.incompatibleBaseModelDesc')}</Text>}
|
||||
{modelWarnings.length > 0 && (
|
||||
<Text>
|
||||
<Trans
|
||||
i18nKey="upscaling.missingModelsWarning"
|
||||
@@ -85,11 +95,13 @@ export const UpscaleWarning = () => {
|
||||
/>
|
||||
</Text>
|
||||
)}
|
||||
<UnorderedList>
|
||||
{[...modelWarnings, ...otherWarnings].map((warning) => (
|
||||
<ListItem key={warning}>{warning}</ListItem>
|
||||
))}
|
||||
</UnorderedList>
|
||||
{allWarnings.length > 0 && (
|
||||
<UnorderedList>
|
||||
{allWarnings.map((warning) => (
|
||||
<ListItem key={warning}>{warning}</ListItem>
|
||||
))}
|
||||
</UnorderedList>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -26,17 +26,20 @@ import { SettingsDeveloperLogLevel } from 'features/system/components/SettingsMo
|
||||
import { SettingsDeveloperLogNamespaces } from 'features/system/components/SettingsModal/SettingsDeveloperLogNamespaces';
|
||||
import { useClearIntermediates } from 'features/system/components/SettingsModal/useClearIntermediates';
|
||||
import { StickyScrollable } from 'features/system/components/StickyScrollable';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
import {
|
||||
selectSystemShouldAntialiasProgressImage,
|
||||
selectSystemShouldConfirmOnDelete,
|
||||
selectSystemShouldConfirmOnNewSession,
|
||||
selectSystemShouldEnableInformationalPopovers,
|
||||
selectSystemShouldEnableModelDescriptions,
|
||||
selectSystemShouldShowInvocationProgressDetail,
|
||||
selectSystemShouldUseNSFWChecker,
|
||||
selectSystemShouldUseWatermarker,
|
||||
setShouldConfirmOnDelete,
|
||||
setShouldEnableInformationalPopovers,
|
||||
setShouldEnableModelDescriptions,
|
||||
setShouldShowInvocationProgressDetail,
|
||||
shouldAntialiasProgressImageChanged,
|
||||
shouldConfirmOnNewSessionToggled,
|
||||
shouldUseNSFWCheckerChanged,
|
||||
@@ -103,6 +106,8 @@ const SettingsModal = ({ config = defaultConfig, children }: SettingsModalProps)
|
||||
const shouldEnableInformationalPopovers = useAppSelector(selectSystemShouldEnableInformationalPopovers);
|
||||
const shouldEnableModelDescriptions = useAppSelector(selectSystemShouldEnableModelDescriptions);
|
||||
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
|
||||
const shouldShowInvocationProgressDetail = useAppSelector(selectSystemShouldShowInvocationProgressDetail);
|
||||
const isInvocationProgressAlertEnabled = useFeatureStatus('invocationProgressAlert');
|
||||
const onToggleConfirmOnNewSession = useCallback(() => {
|
||||
dispatch(shouldConfirmOnNewSessionToggled());
|
||||
}, [dispatch]);
|
||||
@@ -170,6 +175,13 @@ const SettingsModal = ({ config = defaultConfig, children }: SettingsModalProps)
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleChangeShouldShowInvocationProgressDetail = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
dispatch(setShouldShowInvocationProgressDetail(e.target.checked));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
{cloneElement(children, {
|
||||
@@ -221,6 +233,15 @@ const SettingsModal = ({ config = defaultConfig, children }: SettingsModalProps)
|
||||
onChange={handleChangeShouldAntialiasProgressImage}
|
||||
/>
|
||||
</FormControl>
|
||||
{isInvocationProgressAlertEnabled && (
|
||||
<FormControl>
|
||||
<FormLabel>{t('settings.showDetailedInvocationProgress')}</FormLabel>
|
||||
<Switch
|
||||
isChecked={shouldShowInvocationProgressDetail}
|
||||
onChange={handleChangeShouldShowInvocationProgressDetail}
|
||||
/>
|
||||
</FormControl>
|
||||
)}
|
||||
<FormControl>
|
||||
<InformationalPopover feature="noiseUseCPU" inPortal={false}>
|
||||
<FormLabel>{t('parameters.useCpuNoise')}</FormLabel>
|
||||
|
||||
@@ -21,6 +21,7 @@ const initialSystemState: SystemState = {
|
||||
logIsEnabled: true,
|
||||
logLevel: 'debug',
|
||||
logNamespaces: [...zLogNamespace.options],
|
||||
shouldShowInvocationProgressDetail: false,
|
||||
};
|
||||
|
||||
export const systemSlice = createSlice({
|
||||
@@ -64,6 +65,9 @@ export const systemSlice = createSlice({
|
||||
shouldConfirmOnNewSessionToggled(state) {
|
||||
state.shouldConfirmOnNewSession = !state.shouldConfirmOnNewSession;
|
||||
},
|
||||
setShouldShowInvocationProgressDetail(state, action: PayloadAction<boolean>) {
|
||||
state.shouldShowInvocationProgressDetail = action.payload;
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@@ -79,6 +83,7 @@ export const {
|
||||
setShouldEnableInformationalPopovers,
|
||||
setShouldEnableModelDescriptions,
|
||||
shouldConfirmOnNewSessionToggled,
|
||||
setShouldShowInvocationProgressDetail,
|
||||
} = systemSlice.actions;
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
@@ -117,3 +122,6 @@ export const selectSystemShouldEnableModelDescriptions = createSystemSelector(
|
||||
(system) => system.shouldEnableModelDescriptions
|
||||
);
|
||||
export const selectSystemShouldConfirmOnNewSession = createSystemSelector((system) => system.shouldConfirmOnNewSession);
|
||||
export const selectSystemShouldShowInvocationProgressDetail = createSystemSelector(
|
||||
(system) => system.shouldShowInvocationProgressDetail
|
||||
);
|
||||
|
||||
@@ -41,4 +41,5 @@ export interface SystemState {
|
||||
logIsEnabled: boolean;
|
||||
logLevel: LogLevel;
|
||||
logNamespaces: LogNamespace[];
|
||||
shouldShowInvocationProgressDetail: boolean;
|
||||
}
|
||||
|
||||
@@ -6,10 +6,10 @@ import type { components, paths } from 'services/api/schema';
|
||||
import type {
|
||||
DeleteBoardResult,
|
||||
GraphAndWorkflowResponse,
|
||||
ImageCategory,
|
||||
ImageDTO,
|
||||
ListImagesArgs,
|
||||
ListImagesResponse,
|
||||
UploadImageArg,
|
||||
} from 'services/api/types';
|
||||
import { getCategories, getListImagesUrl } from 'services/api/util';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
@@ -260,20 +260,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
return [];
|
||||
},
|
||||
}),
|
||||
uploadImage: build.mutation<
|
||||
ImageDTO,
|
||||
{
|
||||
file: File;
|
||||
image_category: ImageCategory;
|
||||
is_intermediate: boolean;
|
||||
session_id?: string;
|
||||
board_id?: string;
|
||||
crop_visible?: boolean;
|
||||
metadata?: JsonObject;
|
||||
isFirstUploadOfBatch?: boolean;
|
||||
withToast?: boolean;
|
||||
}
|
||||
>({
|
||||
uploadImage: build.mutation<ImageDTO, UploadImageArg>({
|
||||
query: ({ file, image_category, is_intermediate, session_id, board_id, crop_visible, metadata }) => {
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
@@ -558,7 +545,6 @@ export const {
|
||||
useClearIntermediatesMutation,
|
||||
useAddImagesToBoardMutation,
|
||||
useRemoveImagesFromBoardMutation,
|
||||
useChangeImageIsIntermediateMutation,
|
||||
useDeleteBoardAndImagesMutation,
|
||||
useDeleteBoardMutation,
|
||||
useStarImagesMutation,
|
||||
@@ -622,79 +608,17 @@ export const getImageMetadata = (
|
||||
return req.unwrap();
|
||||
};
|
||||
|
||||
export type UploadImageArg = {
|
||||
file: File;
|
||||
image_category: ImageCategory;
|
||||
is_intermediate: boolean;
|
||||
session_id?: string;
|
||||
board_id?: string;
|
||||
crop_visible?: boolean;
|
||||
metadata?: JsonObject;
|
||||
withToast?: boolean;
|
||||
};
|
||||
|
||||
export const uploadImage = (arg: UploadImageArg): Promise<ImageDTO> => {
|
||||
const {
|
||||
file,
|
||||
image_category,
|
||||
is_intermediate,
|
||||
crop_visible = false,
|
||||
board_id,
|
||||
metadata,
|
||||
session_id,
|
||||
withToast = true,
|
||||
} = arg;
|
||||
|
||||
const { dispatch } = getStore();
|
||||
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate(
|
||||
{
|
||||
file,
|
||||
image_category,
|
||||
is_intermediate,
|
||||
crop_visible,
|
||||
board_id,
|
||||
metadata,
|
||||
session_id,
|
||||
withToast,
|
||||
},
|
||||
{ track: false }
|
||||
)
|
||||
);
|
||||
const req = dispatch(imagesApi.endpoints.uploadImage.initiate(arg, { track: false }));
|
||||
return req.unwrap();
|
||||
};
|
||||
|
||||
export const uploadImages = async (args: UploadImageArg[]): Promise<ImageDTO[]> => {
|
||||
const { dispatch } = getStore();
|
||||
const results = await Promise.allSettled(
|
||||
args.map((arg, i) => {
|
||||
const {
|
||||
file,
|
||||
image_category,
|
||||
is_intermediate,
|
||||
crop_visible = false,
|
||||
board_id,
|
||||
metadata,
|
||||
session_id,
|
||||
withToast = true,
|
||||
} = arg;
|
||||
const req = dispatch(
|
||||
imagesApi.endpoints.uploadImage.initiate(
|
||||
{
|
||||
file,
|
||||
image_category,
|
||||
is_intermediate,
|
||||
crop_visible,
|
||||
board_id,
|
||||
metadata,
|
||||
session_id,
|
||||
isFirstUploadOfBatch: i === 0,
|
||||
withToast,
|
||||
},
|
||||
{ track: false }
|
||||
)
|
||||
);
|
||||
args.map((arg) => {
|
||||
const req = dispatch(imagesApi.endpoints.uploadImage.initiate(arg, { track: false }));
|
||||
return req.unwrap();
|
||||
})
|
||||
);
|
||||
|
||||
@@ -18,7 +18,6 @@ import {
  isIPAdapterModelConfig,
  isLoRAModelConfig,
  isNonRefinerMainModelConfig,
  isNonSD3MainModelModelConfig,
  isNonSDXLMainModelConfig,
  isRefinerMainModelModelConfig,
  isSD3MainModelModelConfig,
@@ -53,7 +52,6 @@ const buildModelsHook =
  };

export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSD3MainModels = buildModelsHook(isNonSD3MainModelModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useFluxModels = buildModelsHook(isFluxMainModelModelConfig);

File diff suppressed because one or more lines are too long
@@ -1,5 +1,5 @@
import type { components, paths } from 'services/api/schema';
import type { SetRequired } from 'type-fest';
import type { JsonObject, SetRequired } from 'type-fest';

export type S = components['schemas'];

@@ -223,10 +223,6 @@ export const isSD3MainModelModelConfig = (config: AnyModelConfig): config is Mai
  return config.type === 'main' && config.base === 'sd-3';
};

export const isNonSD3MainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
  return config.type === 'main' && config.base !== 'sd-3' && config.base !== 'sdxl-refiner';
};

export const isFluxMainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
  return config.type === 'main' && config.base === 'flux';
};
@@ -291,3 +287,42 @@ export type SetHFTokenResponse = NonNullable<
export type SetHFTokenArg = NonNullable<
  paths['/api/v2/models/hf_login']['post']['requestBody']['content']['application/json']
>;

export type UploadImageArg = {
  /**
   * The file object to upload
   */
  file: File;
  /**
   * The category of image to upload
   */
  image_category: ImageCategory;
  /**
   * Whether the uploaded image is an intermediate image (intermediate images are not shown in the gallery)
   */
  is_intermediate: boolean;
  /**
   * The session with which to associate the uploaded image
   */
  session_id?: string;
  /**
   * The board id to add the image to
   */
  board_id?: string;
  /**
   * Whether or not to crop the image to its bounding box before saving
   */
  crop_visible?: boolean;
  /**
   * Metadata to embed in the image when saving it
   */
  metadata?: JsonObject;
  /**
   * Whether this upload should be "silent" (no toast on upload, no changing of gallery view)
   */
  silent?: boolean;
  /**
   * Whether this is the first upload of a batch (used when displaying user feedback with toasts - ignored if the upload is silent)
   */
  isFirstUploadOfBatch?: boolean;
};
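For illustration, a minimal sketch of how a caller might build an UploadImageArg and hand it to the uploadImage helper; the variable name and the 'general' category value are assumptions, not taken from this diff:

// Hypothetical caller: upload a user-picked file without a toast or gallery switch.
const arg: UploadImageArg = {
  file: pickedFile, // assumed: a File from an <input type="file"> change event
  image_category: 'general', // assumed category value
  is_intermediate: false,
  silent: true,
};
const imageDTO = await uploadImage(arg);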

@@ -1,3 +1,4 @@
import { round } from 'lodash-es';
import { atom, computed, map } from 'nanostores';
import type { S } from 'services/api/types';
import type { AppSocket } from 'services/events/types';
@@ -10,3 +11,16 @@ export const $lastProgressEvent = atom<S['InvocationProgressEvent'] | null>(null
export const $progressImage = computed($lastProgressEvent, (val) => val?.image ?? null);
export const $hasProgressImage = computed($lastProgressEvent, (val) => Boolean(val?.image));
export const $isProgressFromCanvas = computed($lastProgressEvent, (val) => val?.destination === 'canvas');
export const $invocationProgressMessage = computed($lastProgressEvent, (val) => {
  if (!val) {
    return null;
  }
  if (val.destination !== 'canvas') {
    return null;
  }
  let message = val.message;
  if (val.percentage) {
    message += ` (${round(val.percentage * 100)}%)`;
  }
  return message;
});
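As a rough illustration (values assumed, not taken from this diff):

// e.g. { destination: 'canvas', message: 'Running VAE encoder', percentage: 0.42 }
//   => $invocationProgressMessage.get() === 'Running VAE encoder (42%)'
// A zero or undefined percentage leaves the message unchanged, since `if (val.percentage)` is falsy.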

@@ -1 +1 @@
__version__ = "5.4.1rc2"
__version__ = "5.4.1"

@@ -6,7 +6,11 @@ from diffusers import AutoencoderTiny

from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
from invokeai.app.services.shared.invocation_context import (
    InvocationContext,
    InvocationContextData,
    build_invocation_context,
)
from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403

@@ -19,7 +23,7 @@ def mock_context(
    mock_services.model_manager = mm2_model_manager
    return build_invocation_context(
        services=mock_services,
        data=None, # type: ignore
        data=InvocationContextData(queue_item=None, invocation=None, source_invocation_id=None), # type: ignore
        is_canceled=None, # type: ignore
    )

@@ -1,6 +1,6 @@
# State dict keys and shapes for an XLabs FLUX IP-Adapter model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/XLabs-AI/flux-ip-adapter/blob/ad16be50d78a07ea83d8c4bde44ff9753235182e/flux-ip-adapter.safetensors
# https://huggingface.co/XLabs-AI/flux-ip-adapter/resolve/main/ip_adapter.safetensors
xlabs_sd_shapes = {
    "double_blocks.0.processor.ip_adapter_double_stream_k_proj.bias": [3072],
    "double_blocks.0.processor.ip_adapter_double_stream_k_proj.weight": [3072, 4096],

File diff suppressed because it is too large
@@ -0,0 +1,688 @@
|
||||
# A sample state dict in the Diffusers FLUX LoRA format without .proj_mlp layers.
|
||||
# This format was added in response to https://github.com/invoke-ai/InvokeAI/issues/7129
|
||||
state_dict_keys = {
|
||||
"transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.0.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.0.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.0.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.0.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.0.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.1.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.1.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.1.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.1.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.1.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.1.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.10.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.10.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.10.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.10.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.10.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.10.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.11.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.11.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.11.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.11.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.11.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.11.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.12.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.12.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.12.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.12.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.12.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.12.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.13.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.13.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.13.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.13.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.13.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.13.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.14.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.14.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.14.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.14.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.14.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.14.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.15.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.15.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.15.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.15.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.15.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.15.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.16.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.16.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.16.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.16.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.16.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.16.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.17.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.17.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.17.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.17.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.17.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.17.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.18.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.18.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.18.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.18.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.18.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.18.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.19.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.19.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.19.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.19.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.19.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.19.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.2.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.2.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.2.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.2.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.2.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.2.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.20.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.20.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.20.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.20.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.20.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.20.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.21.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.21.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.21.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.21.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.21.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.21.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.22.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.22.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.22.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.22.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.22.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.22.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.23.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.23.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.23.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.23.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.23.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.23.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.24.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.24.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.24.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.24.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.24.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.24.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.25.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.25.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.25.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.25.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.25.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.25.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.26.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.26.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.26.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.26.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.26.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.26.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.27.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.27.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.27.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.27.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.27.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.27.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.28.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.28.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.28.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.28.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.28.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.28.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.29.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.29.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.29.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.29.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.29.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.29.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.3.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.3.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.3.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.3.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.3.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.3.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.30.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.30.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.30.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.30.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.30.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.30.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.31.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.31.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.31.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.31.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.31.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.31.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.32.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.32.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.32.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.32.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.32.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.32.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.33.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.33.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.33.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.33.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.33.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.33.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.34.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.34.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.34.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.34.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.34.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.34.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.35.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.35.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.35.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.35.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.35.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.35.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.36.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.36.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.36.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.36.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.36.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.36.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.37.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.37.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.37.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.37.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.37.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.37.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.4.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.4.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.4.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.4.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.4.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.4.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.5.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.5.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.5.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.5.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.5.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.5.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.6.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.6.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.6.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.6.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.6.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.6.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.7.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.7.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.7.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.7.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.7.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.7.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.8.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.8.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.8.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.8.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.8.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.8.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.9.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.9.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.9.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.9.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.single_transformer_blocks.9.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.single_transformer_blocks.9.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.0.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.0.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.0.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.0.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.0.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.0.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.0.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.0.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.0.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.0.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.0.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.0.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.0.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.0.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.0.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.0.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.0.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.0.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.0.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.0.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.0.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.0.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.0.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.0.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.1.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.1.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.1.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.1.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.1.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.1.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.1.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.1.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.1.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.1.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.1.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.1.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.1.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.1.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.1.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.1.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.1.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.1.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.1.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.1.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.1.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.1.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.1.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.1.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.10.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.10.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.10.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.10.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.10.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.10.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.10.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.10.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.10.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.10.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.10.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.10.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.10.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.10.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.10.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.10.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.10.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.10.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.10.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.10.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.10.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.10.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.10.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.10.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.11.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.11.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.11.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.11.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.11.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.11.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.11.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.11.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.11.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.11.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.11.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.11.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.11.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.11.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.11.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.11.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.11.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.11.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.11.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.11.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.11.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.11.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.11.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.11.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.12.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.12.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.12.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.12.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.12.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.12.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.12.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.12.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.12.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.12.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.12.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.12.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.12.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.12.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.12.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.12.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.12.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.12.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.12.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.12.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.12.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.12.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.12.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.12.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.13.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.13.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.13.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.13.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.13.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.13.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.13.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.13.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.13.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.13.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.13.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.13.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.13.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.13.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.13.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.13.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.13.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.13.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.13.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.13.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.13.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.13.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.13.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.13.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.14.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.14.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.14.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.14.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.14.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.14.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.14.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.14.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.14.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.14.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.14.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.14.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.14.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.14.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.14.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.14.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.14.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.14.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.14.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.14.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.14.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.14.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.14.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.14.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.15.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.15.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.15.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.15.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.15.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.15.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.15.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.15.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.15.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.15.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.15.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.15.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.15.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.15.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.15.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.15.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.15.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.15.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.15.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.15.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.15.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.15.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.15.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.15.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.16.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.16.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.16.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.16.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.16.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.16.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.16.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.16.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.16.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.16.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.16.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.16.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.16.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.16.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.16.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.16.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.16.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.16.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.16.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.16.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.16.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.16.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.16.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.16.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.17.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.17.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.17.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.17.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.17.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.17.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.17.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.17.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.17.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.17.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.17.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.17.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.17.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.17.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.17.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.17.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.17.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.17.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.17.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.17.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.17.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.17.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.17.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.17.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.18.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.18.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.18.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.18.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.18.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.18.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.18.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.18.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.18.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.18.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.18.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.18.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.18.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.18.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.18.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.18.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.18.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.18.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.18.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.18.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.18.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.18.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.18.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.18.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.2.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.2.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.2.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.2.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.2.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.2.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.2.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.2.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.2.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.2.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.2.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.2.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.2.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.2.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.2.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.2.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.2.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.2.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.2.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.2.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.2.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.2.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.2.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.2.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.3.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.3.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.3.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.3.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.3.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.3.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.3.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.3.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.3.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.3.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.3.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.3.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.3.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.3.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.3.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.3.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.3.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.3.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.3.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.3.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.3.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.3.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.3.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.3.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.4.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.4.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.4.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.4.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.4.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.4.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.4.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.4.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.4.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.4.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.4.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.4.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.4.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.4.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.4.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.4.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.4.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.4.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.4.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.4.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.4.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.4.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.4.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.4.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.5.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.5.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.5.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.5.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.5.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.5.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.5.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.5.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.5.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.5.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.5.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.5.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.5.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.5.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.5.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.5.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.5.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.5.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.5.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.5.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.5.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.5.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.5.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.5.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.6.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.6.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.6.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.6.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.6.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.6.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.6.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.6.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.6.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.6.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.6.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.6.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.6.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.6.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.6.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.6.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.6.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.6.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.6.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.6.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.6.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.6.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.6.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.6.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.7.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.7.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.7.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.7.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.7.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.7.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.7.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.7.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.7.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.7.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.7.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.7.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.7.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.7.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.7.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.7.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.7.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.7.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.7.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.7.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.7.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.7.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.7.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.7.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.8.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.8.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.8.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.8.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.8.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.8.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.8.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.8.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.8.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.8.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.8.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.8.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.8.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.8.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.8.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.8.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.8.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.8.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.8.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.8.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.8.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.8.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.8.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.8.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.9.attn.add_k_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.9.attn.add_k_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.9.attn.add_q_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.9.attn.add_q_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.9.attn.add_v_proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.9.attn.add_v_proj.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.9.attn.to_add_out.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.9.attn.to_add_out.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.9.attn.to_k.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.9.attn.to_k.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.9.attn.to_out.0.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.9.attn.to_out.0.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.9.attn.to_q.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.9.attn.to_q.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.9.attn.to_v.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.9.attn.to_v.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.9.ff.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.9.ff.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.9.ff.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.9.ff.net.2.lora_B.weight": [3072, 16],
|
||||
"transformer.transformer_blocks.9.ff_context.net.0.proj.lora_A.weight": [16, 3072],
|
||||
"transformer.transformer_blocks.9.ff_context.net.0.proj.lora_B.weight": [12288, 16],
|
||||
"transformer.transformer_blocks.9.ff_context.net.2.lora_A.weight": [16, 12288],
|
||||
"transformer.transformer_blocks.9.ff_context.net.2.lora_B.weight": [3072, 16],
|
||||
}
|
||||
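The shapes above follow the standard LoRA factorization: each lora_A weight is [rank, in_features] and its partner lora_B is [out_features, rank], so this fixture describes rank-16 adapters over the 3072-wide FLUX attention projections and the 12288-wide feed-forward projections. A minimal sanity-check sketch over such a shape map (the check_lora_shape_pairs helper is hypothetical, written only to illustrate the pairing):

def check_lora_shape_pairs(state_dict_keys: dict[str, list[int]], rank: int = 16) -> None:
    # Every lora_A key should have a matching lora_B key, and both should agree on the rank.
    for key, shape in state_dict_keys.items():
        if key.endswith("lora_A.weight"):
            partner = key.replace("lora_A.weight", "lora_B.weight")
            assert shape[0] == rank, f"{key} is not rank {rank}"
            assert state_dict_keys[partner][1] == rank, f"{partner} does not match rank {rank}"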
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,8 +1,8 @@
 import torch


-def keys_to_mock_state_dict(keys: list[str]) -> dict[str, torch.Tensor]:
+def keys_to_mock_state_dict(keys: dict[str, list[int]]) -> dict[str, torch.Tensor]:
     state_dict: dict[str, torch.Tensor] = {}
-    for k in keys:
-        state_dict[k] = torch.empty(1)
+    for k, shape in keys.items():
+        state_dict[k] = torch.empty(shape)
     return state_dict
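Because the fixtures now record tensor shapes, the mock state dicts built for these tests carry correctly sized tensors instead of single-element placeholders. A small usage sketch of the helper defined just above (the two keys are copied from the fixture listing; the snippet itself is illustrative, not part of the test suite):

example_keys = {
    "transformer.transformer_blocks.0.attn.to_q.lora_A.weight": [16, 3072],
    "transformer.transformer_blocks.0.attn.to_q.lora_B.weight": [3072, 16],
}
state_dict = keys_to_mock_state_dict(example_keys)
# Each mocked tensor now has the shape recorded in the fixture.
assert state_dict["transformer.transformer_blocks.0.attn.to_q.lora_A.weight"].shape == (16, 3072)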
@@ -9,22 +9,26 @@ from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRAN
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format import (
     state_dict_keys as flux_diffusers_state_dict_keys,
 )
+from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_no_proj_mlp_format import (
+    state_dict_keys as flux_diffusers_no_proj_mlp_state_dict_keys,
+)
 from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import (
     state_dict_keys as flux_kohya_state_dict_keys,
 )
 from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict


-def test_is_state_dict_likely_in_flux_diffusers_format_true():
+@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_diffusers_no_proj_mlp_state_dict_keys])
+def test_is_state_dict_likely_in_flux_diffusers_format_true(sd_keys: dict[str, list[int]]):
     """Test that is_state_dict_likely_in_flux_diffusers_format() can identify a state dict in the Diffusers FLUX LoRA format."""
     # Construct a state dict that is in the Diffusers FLUX LoRA format.
-    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
+    state_dict = keys_to_mock_state_dict(sd_keys)

     assert is_state_dict_likely_in_flux_diffusers_format(state_dict)


 def test_is_state_dict_likely_in_flux_diffusers_format_false():
-    """Test that is_state_dict_likely_in_flux_diffusers_format() returns False for a state dict that is not in the Kohya
+    """Test that is_state_dict_likely_in_flux_diffusers_format() returns False for a state dict that is in the Kohya
     FLUX LoRA format.
     """
     # Construct a state dict that is not in the Kohya FLUX LoRA format.
@@ -33,16 +37,17 @@ def test_is_state_dict_likely_in_flux_diffusers_format_false():
     assert not is_state_dict_likely_in_flux_diffusers_format(state_dict)


-def test_lora_model_from_flux_diffusers_state_dict():
+@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_diffusers_no_proj_mlp_state_dict_keys])
+def test_lora_model_from_flux_diffusers_state_dict(sd_keys: dict[str, list[int]]):
     """Test that lora_model_from_flux_diffusers_state_dict() can load a state dict in the Diffusers FLUX LoRA format."""
     # Construct a state dict that is in the Diffusers FLUX LoRA format.
-    state_dict = keys_to_mock_state_dict(flux_diffusers_state_dict_keys)
+    state_dict = keys_to_mock_state_dict(sd_keys)
     # Load the state dict into a LoRAModelRaw object.
     model = lora_model_from_flux_diffusers_state_dict(state_dict, alpha=8.0)

     # Check that the model has the correct number of LoRA layers.
     expected_lora_layers: set[str] = set()
-    for k in flux_diffusers_state_dict_keys:
+    for k in sd_keys:
         k = k.replace("lora_A.weight", "")
         k = k.replace("lora_B.weight", "")
         expected_lora_layers.add(k)
@@ -23,7 +23,7 @@ from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_s


 @pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
-def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: list[str]):
+def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: dict[str, list[int]]):
     """Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format."""
     # Construct a state dict that is in the Kohya FLUX LoRA format.
     state_dict = keys_to_mock_state_dict(sd_keys)
@@ -83,7 +83,7 @@ def test_convert_flux_transformer_kohya_state_dict_to_invoke_format_error():


 @pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys])
-def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]):
+def test_lora_model_from_flux_kohya_state_dict(sd_keys: dict[str, list[int]]):
     """Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format."""
     # Construct a state dict that is in the Kohya FLUX LoRA format.
     state_dict = keys_to_mock_state_dict(sd_keys)
@@ -9,6 +9,8 @@ from invokeai.app.invocations.baseinvocation import (
 from invokeai.app.invocations.fields import InputField, OutputField
 from invokeai.app.invocations.image import ImageField
 from invokeai.app.services.events.events_common import EventBase
+from invokeai.app.services.session_processor.session_processor_common import ProgressImage
+from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -133,6 +135,16 @@ class TestEventService(EventServiceBase):
         self.events.append(event)
         pass

+    def emit_invocation_progress(
+        self,
+        queue_item: "SessionQueueItem",
+        invocation: "BaseInvocation",
+        message: str,
+        percentage: float | None = None,
+        image: "ProgressImage | None" = None,
+    ) -> None:
+        pass
+

 def wait_until(condition: Callable[[], bool], timeout: int = 10, interval: float = 0.1) -> None:
     import time
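The new emit_invocation_progress stub keeps TestEventService compatible with the extended event-service interface without recording anything extra, and wait_until remains the polling helper these service tests rely on. A rough usage sketch of the polling helper (the setup lines are placeholders, not real fixtures from the test suite, and the no-argument constructor is assumed):

# Hypothetical: poll until the test event service has captured at least one event.
event_service = TestEventService()
# ... run a queue item that emits events through event_service ...
wait_until(lambda: len(event_service.events) > 0, timeout=10, interval=0.1)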