Mirror of https://github.com/invoke-ai/InvokeAI.git

Compare commits: ryan/multi ... maryhipp/s (27 commits)
Commits in this comparison (27):

89c5662848, e3e8d689d7, 9d86c2e2c1, c3dd91e3c2, aaf83de364, 959f70da71, d551338d62,
1304fbb36f, a2a70b6eb0, 9c328056d5, 977dbd8051, 14250a0593, 62b4614aed, 451c0f00e0,
05485e1b47, 01164a404f, f0b587da27, f6b30d2b6b, 6d4fc6e55b, 4e1a0b8a7f, 67abe33c02,
a3c736c0dc, e4738b4bee, fa13ec1f6b, 5ced646210, b03073d888, a43d602f16
@@ -316,6 +316,7 @@ async def list_image_dtos(
),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of image DTOs"""

@@ -326,6 +327,7 @@ async def list_image_dtos(
categories,
is_intermediate,
board_id,
search_term
)

return image_dtos
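A small client-side sketch of the new `search_term` query parameter added above. The host, port, route path, and the `items` key of the response are assumptions based on a typical local InvokeAI setup; only the query parameters themselves come from the hunk.

```python
import requests

# Assumed local InvokeAI server and route path; adjust to your deployment.
resp = requests.get(
    "http://127.0.0.1:9090/api/v1/images/",
    params={"offset": 0, "limit": 10, "search_term": "sunset"},
    timeout=10,
)
resp.raise_for_status()
# OffsetPaginatedResults[ImageDTO]: the paginated payload is assumed to expose an "items" list.
print(resp.json()["items"][:3])
```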
@@ -55,7 +55,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
)
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
from invokeai.backend.util.mask import to_standard_float_mask
from invokeai.backend.util.silence_warnings import SilenceWarnings

@@ -66,9 +65,6 @@ def get_scheduler(
scheduler_name: str,
seed: int,
) -> Scheduler:
"""Load a scheduler and apply some scheduler-specific overrides."""
# TODO(ryand): Silently falling back to ddim seems like a bad idea. Look into why this was added and remove if
# possible.
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP["ddim"])
orig_scheduler_info = context.models.load(scheduler_info)
with orig_scheduler_info as orig_scheduler:
@@ -186,8 +182,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v

@staticmethod
def _get_text_embeddings_and_masks(
self,
cond_list: list[ConditioningField],
context: InvocationContext,
device: torch.device,

@@ -207,9 +203,8 @@ class DenoiseLatentsInvocation(BaseInvocation):

return text_embeddings, text_embeddings_masks

@staticmethod
def _preprocess_regional_prompt_mask(
mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.

@@ -233,8 +228,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
resized_mask = tf(mask)
return resized_mask

@staticmethod
def _concat_regional_text_embeddings(
self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int,

@@ -284,9 +279,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
)
processed_masks.append(
DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
mask, latent_height, latent_width, dtype=dtype
)
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
)

cur_text_embedding_len += text_embedding_info.embeds.shape[1]
@@ -308,41 +301,36 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
return BasicConditioningInfo(embeds=text_embedding), regions

@staticmethod
def get_conditioning_data(
self,
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
) -> TextConditioningData:
# Normalize positive_conditioning_field and negative_conditioning_field to lists.
cond_list = positive_conditioning_field
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
cond_list = self.positive_conditioning
if not isinstance(cond_list, list):
cond_list = [cond_list]
uncond_list = negative_conditioning_field
uncond_list = self.negative_conditioning
if not isinstance(uncond_list, list):
uncond_list = [uncond_list]

cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
)

cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
@@ -350,21 +338,23 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype=unet.dtype,
)

if isinstance(cfg_scale, list):
assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"
if isinstance(self.cfg_scale, list):
assert (
len(self.cfg_scale) == self.steps
), "cfg_scale (list) must have the same length as the number of steps"

conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding,
uncond_regions=uncond_regions,
cond_regions=cond_regions,
guidance_scale=cfg_scale,
guidance_rescale_multiplier=cfg_rescale_multiplier,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
)
return conditioning_data

@staticmethod
def create_pipeline(
self,
unet: UNet2DConditionModel,
scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline:
@@ -387,38 +377,38 @@ class DenoiseLatentsInvocation(BaseInvocation):
requires_safety_checker=False,
)

@staticmethod
def prep_control_data(
self,
context: InvocationContext,
control_input: ControlField | list[ControlField] | None,
control_input: Optional[Union[ControlField, List[ControlField]]],
latents_shape: List[int],
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> list[ControlNetData] | None:
# Normalize control_input to a list.
control_list: list[ControlField]
if isinstance(control_input, ControlField):
control_list = [control_input]
elif isinstance(control_input, list):
control_list = control_input
elif control_input is None:
control_list = []
else:
raise ValueError(f"Unexpected control_input type: {type(control_input)}")

if len(control_list) == 0:
return None

) -> Optional[List[ControlNetData]]:
# Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
_, _, latent_height, latent_width = latents_shape
control_height_resize = latent_height * LATENT_SCALE_FACTOR
control_width_resize = latent_width * LATENT_SCALE_FACTOR
control_height_resize = latents_shape[2] * LATENT_SCALE_FACTOR
control_width_resize = latents_shape[3] * LATENT_SCALE_FACTOR
if control_input is None:
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
control_list = None
elif isinstance(control_input, ControlField):
control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
control_list = control_input
else:
control_list = None
if control_list is None:
return None
# After above handling, any control that is not None should now be of type list[ControlField].

controlnet_data: list[ControlNetData] = []
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
assert isinstance(control_model, ControlNetModel)

# control_models.append(control_model)
control_image_field = control_info.image
input_image = context.images.get_pil(control_image_field.image_name)
# self.image.image_type, self.image.image_name

@@ -439,7 +429,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model,
model=control_model, # model object
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
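A minimal standalone sketch of the control-input normalization used above. The ControlField here is a stand-in dataclass for illustration only (the real class lives in invokeai.app.invocations.controlnet_image_processors); the branch structure and the "empty means None" convention follow the hunk.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class ControlField:  # stand-in for the real ControlField, for illustration only
    control_model: str


def normalize_control_input(
    control_input: Optional[Union[ControlField, list[ControlField]]],
) -> Optional[list[ControlField]]:
    # Normalize a single field, a list, or None into a list.
    if isinstance(control_input, ControlField):
        control_list = [control_input]
    elif isinstance(control_input, list):
        control_list = control_input
    elif control_input is None:
        control_list = []
    else:
        raise ValueError(f"Unexpected control_input type: {type(control_input)}")
    # An empty list means "no ControlNet"; callers treat None as the skip signal.
    return control_list or None


assert normalize_control_input(None) is None
assert normalize_control_input(ControlField("cn-tile")) == [ControlField("cn-tile")]
```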
@@ -593,15 +583,15 @@ class DenoiseLatentsInvocation(BaseInvocation):

# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
@staticmethod
def init_scheduler(
self,
scheduler: Union[Scheduler, ConfigMixin],
device: torch.device,
steps: int,
denoising_start: float,
denoising_end: float,
seed: int,
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
) -> Tuple[int, List[int], int, Dict[str, Any]]:
assert isinstance(scheduler, ConfigMixin)
if scheduler.config.get("cpu_only", False):
scheduler.set_timesteps(steps, device="cpu")

@@ -627,6 +617,7 @@ class DenoiseLatentsInvocation(BaseInvocation):

init_timestep = timesteps[t_start_idx : t_start_idx + 1]
timesteps = timesteps[t_start_idx : t_start_idx + t_end_idx]
num_inference_steps = len(timesteps) // scheduler.order

scheduler_step_kwargs: Dict[str, Any] = {}
scheduler_step_signature = inspect.signature(scheduler.step)

@@ -648,7 +639,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
if isinstance(scheduler, TCDScheduler):
scheduler_step_kwargs.update({"eta": 1.0})

return timesteps, init_timestep, scheduler_step_kwargs
return num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs

def prep_inpaint_mask(
self, context: InvocationContext, latents: torch.Tensor
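A standalone sketch of the timestep windowing performed in init_scheduler(). The derivation of t_start_idx/t_end_idx from denoising_start/denoising_end is outside this hunk and is assumed here to yield a start index and a window length; the slicing and the division by the scheduler order mirror the lines above.

```python
import torch


def slice_timesteps(
    timesteps: torch.Tensor, t_start_idx: int, t_end_idx: int, scheduler_order: int = 1
) -> tuple[torch.Tensor, torch.Tensor, int]:
    # The first timestep of the window is used to noise the initial latents.
    init_timestep = timesteps[t_start_idx : t_start_idx + 1]
    # The window itself is what the denoising loop iterates over.
    window = timesteps[t_start_idx : t_start_idx + t_end_idx]
    # Second-order schedulers emit multiple entries per step, hence the division by order.
    num_inference_steps = len(window) // scheduler_order
    return window, init_timestep, num_inference_steps


# e.g. a 10-entry schedule where only the last 4 entries are denoised (refiner-style usage)
ts = torch.linspace(999, 0, 10).long()
window, init_t, n = slice_timesteps(ts, t_start_idx=6, t_end_idx=4)
assert n == 4 and init_t.item() == window[0].item()
```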
@@ -665,51 +656,30 @@ class DenoiseLatentsInvocation(BaseInvocation):

return 1 - mask, masked_latents, self.denoise_mask.gradient

@staticmethod
def prepare_noise_and_latents(
context: InvocationContext, noise_field: LatentsField | None, latents_field: LatentsField | None
) -> Tuple[int, torch.Tensor | None, torch.Tensor]:
"""Depending on the workflow, we expect different combinations of noise and latents to be provided. This
function handles preparing these values accordingly.

Expected workflows:
- Text-to-Image Denoising: `noise` is provided, `latents` is not. `latents` is initialized to zeros.
- Image-to-Image Denoising: `noise` and `latents` are both provided.
- Text-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.
- Image-to-Image SDXL Refiner Denoising: `latents` is provided, `noise` is not.

NOTE(ryand): I wrote this docstring, but I am not the original author of this code. There may be other workflows
I haven't considered.
"""
noise = None
if noise_field is not None:
noise = context.tensors.load(noise_field.latents_name)

if latents_field is not None:
latents = context.tensors.load(latents_field.latents_name)
elif noise is not None:
latents = torch.zeros_like(noise)
else:
raise ValueError("'latents' or 'noise' must be provided!")

if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise ValueError(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")

# The seed comes from (in order of priority): the noise field, the latents field, or 0.
seed = 0
if noise_field is not None and noise_field.seed is not None:
seed = noise_field.seed
elif latents_field is not None and latents_field.seed is not None:
seed = latents_field.seed
else:
seed = 0

return seed, noise, latents

@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
seed = None
noise = None
if self.noise is not None:
noise = context.tensors.load(self.noise.latents_name)
seed = self.noise.seed

if self.latents is not None:
latents = context.tensors.load(self.latents.latents_name)
if seed is None:
seed = self.latents.seed

if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise Exception(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")

elif noise is not None:
latents = torch.zeros_like(noise)
else:
raise Exception("'latents' or 'noise' must be provided!")

if seed is None:
seed = 0

mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
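A standalone sketch of the noise/latents/seed resolution described in the docstring above, using plain tensors and optional seeds in place of the LatentsField lookups done through the invocation context.

```python
from typing import Optional, Tuple

import torch


def resolve_noise_and_latents(
    noise: Optional[torch.Tensor],
    noise_seed: Optional[int],
    latents: Optional[torch.Tensor],
    latents_seed: Optional[int],
) -> Tuple[int, Optional[torch.Tensor], torch.Tensor]:
    if latents is None:
        if noise is None:
            raise ValueError("'latents' or 'noise' must be provided!")
        # Text-to-image: start the denoising loop from all-zero latents.
        latents = torch.zeros_like(noise)
    if noise is not None and noise.shape[1:] != latents.shape[1:]:
        raise ValueError(f"Incompatible 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
    # Seed priority: the noise field, then the latents field, then 0.
    seed = noise_seed if noise_seed is not None else (latents_seed if latents_seed is not None else 0)
    return seed, noise, latents


noise = torch.randn(1, 4, 64, 64)
seed, _, latents = resolve_noise_and_latents(noise, noise_seed=123, latents=None, latents_seed=None)
assert seed == 123 and torch.count_nonzero(latents) == 0
```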
@@ -784,15 +754,7 @@ class DenoiseLatentsInvocation(BaseInvocation):

_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
)

controlnet_data = self.prep_control_data(

@@ -814,7 +776,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype=unet.dtype,
)

timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,

@@ -831,7 +793,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
seed=seed,
mask=mask,
masked_latents=masked_latents,
is_gradient_mask=gradient_mask,
gradient_mask=gradient_mask,
num_inference_steps=num_inference_steps,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=controlnet_data,
@@ -8,7 +8,7 @@ from diffusers.models.attention_processor import (
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from PIL import Image
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION

@@ -23,7 +23,6 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.util.devices import TorchDevice

@@ -49,20 +48,16 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)

@staticmethod
def vae_decode(
context: InvocationContext,
vae_info: LoadedModel,
seamless_axes: list[str],
latents: torch.Tensor,
use_fp32: bool,
use_tiling: bool,
) -> Image.Image:
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)

vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (UNet2DConditionModel, AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
assert isinstance(vae, torch.nn.Module)
latents = latents.to(vae.device)
if use_fp32:
if self.fp32:
vae.to(dtype=torch.float32)

use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(

@@ -87,7 +82,7 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.to(dtype=torch.float16)
latents = latents.half()

if use_tiling or context.config.get().force_tiled_decode:
if self.tiled or context.config.get().force_tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()

@@ -107,21 +102,6 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):

TorchDevice.empty_cache()

return image

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)

image = self.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=latents,
use_fp32=self.fp32,
use_tiling=self.tiled,
)
image_dto = context.images.save(image=image)

return ImageOutput.build(image_dto)
@@ -1,268 +0,0 @@
|
||||
import copy
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.model import UNetField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData
|
||||
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
|
||||
MultiDiffusionPipeline,
|
||||
MultiDiffusionRegionConditioning,
|
||||
)
|
||||
from invokeai.backend.tiles.tiles import (
|
||||
calc_tiles_min_overlap,
|
||||
)
|
||||
from invokeai.backend.tiles.utils import TBLR
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
||||
def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData:
|
||||
"""Crop a ControlNetData object to a region."""
|
||||
# Create a shallow copy of the control_data object.
|
||||
control_data_copy = copy.copy(control_data)
|
||||
# The ControlNet reference image is the only attribute that needs to be cropped.
|
||||
control_data_copy.image_tensor = control_data.image_tensor[
|
||||
:,
|
||||
:,
|
||||
latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR,
|
||||
latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR,
|
||||
]
|
||||
return control_data_copy
|
||||
|
||||
|
||||
@invocation(
|
||||
"tiled_multi_diffusion_denoise_latents",
|
||||
title="Tiled Multi-Diffusion Denoise Latents",
|
||||
tags=["upscale", "denoise"],
|
||||
category="latents",
|
||||
# TODO(ryand): Reset to 1.0.0 right before release.
|
||||
version="1.0.0",
|
||||
)
|
||||
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
"""Tiled Multi-Diffusion denoising.
|
||||
|
||||
This node handles automatically tiling the input image. Future iterations of
|
||||
this node should allow the user to specify custom regions with different parameters for each region to harness the
|
||||
full power of Multi-Diffusion.
|
||||
|
||||
This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
|
||||
T2I-Adapter, masking, etc.).
|
||||
"""
|
||||
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||
)
|
||||
noise: LatentsField | None = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.noise,
|
||||
input=Input.Connection,
|
||||
)
|
||||
latents: LatentsField | None = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.latents,
|
||||
input=Input.Connection,
|
||||
)
|
||||
# TODO(ryand): Add multiple-of validation.
|
||||
# TODO(ryand): Smaller defaults might make more sense.
|
||||
tile_height: int = InputField(default=112, gt=0, description="Height of the tiles in latent space.")
|
||||
tile_width: int = InputField(default=112, gt=0, description="Width of the tiles in latent space.")
|
||||
tile_min_overlap: int = InputField(
|
||||
default=16,
|
||||
gt=0,
|
||||
description="The minimum overlap between adjacent tiles in latent space. The actual overlap may be larger than "
|
||||
"this to evenly cover the entire image.",
|
||||
)
|
||||
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
# TODO(ryand): The default here should probably be 0.0.
|
||||
denoising_start: float = InputField(
|
||||
default=0.65,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
description=FieldDescriptions.scheduler,
|
||||
ui_type=UIType.Scheduler,
|
||||
)
|
||||
unet: UNetField = InputField(
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
cfg_rescale_multiplier: float = InputField(
|
||||
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
||||
)
|
||||
control: ControlField | list[ControlField] | None = InputField(
|
||||
default=None,
|
||||
input=Input.Connection,
|
||||
)
|
||||
|
||||
@field_validator("cfg_scale")
|
||||
def ge_one(cls, v: list[float] | float) -> list[float] | float:
|
||||
"""Validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def create_pipeline(
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: SchedulerMixin,
|
||||
) -> MultiDiffusionPipeline:
|
||||
# TODO(ryand): Get rid of this FakeVae hack.
|
||||
class FakeVae:
|
||||
class FakeVaeConfig:
|
||||
def __init__(self) -> None:
|
||||
self.block_out_channels = [0]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.config = FakeVae.FakeVaeConfig()
|
||||
|
||||
return MultiDiffusionPipeline(
|
||||
vae=FakeVae(), # TODO: oh...
|
||||
text_encoder=None,
|
||||
tokenizer=None,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
|
||||
# Calculate the tile locations to cover the latent-space image.
|
||||
# TODO(ryand): Add constraints on the tile params. Is there a multiple-of constraint?
|
||||
tiles = calc_tiles_min_overlap(
|
||||
image_height=latent_height,
|
||||
image_width=latent_width,
|
||||
tile_height=self.tile_height,
|
||||
tile_width=self.tile_width,
|
||||
min_overlap=self.tile_min_overlap,
|
||||
)
|
||||
|
||||
# Prepare an iterator that yields the UNet's LoRA models and their weights.
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
# Load the UNet model.
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
|
||||
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
)
|
||||
pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
|
||||
|
||||
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
|
||||
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
unet=unet,
|
||||
latent_height=self.tile_height,
|
||||
latent_width=self.tile_width,
|
||||
cfg_scale=self.cfg_scale,
|
||||
steps=self.steps,
|
||||
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
)
|
||||
|
||||
controlnet_data = DenoiseLatentsInvocation.prep_control_data(
|
||||
context=context,
|
||||
control_input=self.control,
|
||||
latents_shape=list(latents.shape),
|
||||
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
|
||||
# Split the controlnet_data into tiles.
|
||||
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
|
||||
controlnet_data_tiles: list[list[ControlNetData]] = []
|
||||
for tile in tiles:
|
||||
tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
|
||||
controlnet_data_tiles.append(tile_controlnet_data)
|
||||
|
||||
# Prepare the MultiDiffusionRegionConditioning list.
|
||||
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = []
|
||||
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
|
||||
multi_diffusion_conditioning.append(
|
||||
MultiDiffusionRegionConditioning(
|
||||
region=tile.coords,
|
||||
text_conditioning_data=conditioning_data,
|
||||
control_data=tile_controlnet_data,
|
||||
)
|
||||
)
|
||||
|
||||
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
# Run Multi-Diffusion denoising.
|
||||
result_latents = pipeline.multi_diffusion_denoise(
|
||||
multi_diffusion_conditioning=multi_diffusion_conditioning,
|
||||
latents=latents,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
noise=noise,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
# TODO(ryand): Add proper callback.
|
||||
callback=lambda x: None,
|
||||
)
|
||||
|
||||
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
|
||||
result_latents = result_latents.to("cpu")
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=result_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|
||||
@@ -1,380 +0,0 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Iterator, Tuple
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from PIL import Image
|
||||
from pydantic import field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
|
||||
from invokeai.app.invocations.fields import (
|
||||
ConditioningField,
|
||||
FieldDescriptions,
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
UIType,
|
||||
)
|
||||
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
||||
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
|
||||
from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField
|
||||
from invokeai.app.invocations.noise import get_noise
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
|
||||
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
|
||||
from invokeai.backend.tiles.utils import Tile
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.hotfixes import ControlNetModel
|
||||
|
||||
|
||||
@invocation(
|
||||
"tiled_stable_diffusion_refine",
|
||||
title="Tiled Stable Diffusion Refine",
|
||||
tags=["upscale", "denoise"],
|
||||
category="latents",
|
||||
version="1.0.0",
|
||||
)
|
||||
class TiledStableDiffusionRefineInvocation(BaseInvocation):
|
||||
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
|
||||
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
|
||||
"""
|
||||
|
||||
image: ImageField = InputField(description="Image to be refined.")
|
||||
|
||||
positive_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.positive_cond, input=Input.Connection
|
||||
)
|
||||
negative_conditioning: ConditioningField = InputField(
|
||||
description=FieldDescriptions.negative_cond, input=Input.Connection
|
||||
)
|
||||
# TODO(ryand): Add multiple-of validation.
|
||||
tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.")
|
||||
tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.")
|
||||
tile_overlap: int = InputField(
|
||||
default=16,
|
||||
gt=0,
|
||||
description="Target overlap between adjacent tiles (the last row/column may overlap more than this).",
|
||||
)
|
||||
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
|
||||
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
denoising_start: float = InputField(
|
||||
default=0.65,
|
||||
ge=0,
|
||||
le=1,
|
||||
description=FieldDescriptions.denoising_start,
|
||||
)
|
||||
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
||||
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||
default="euler",
|
||||
description=FieldDescriptions.scheduler,
|
||||
ui_type=UIType.Scheduler,
|
||||
)
|
||||
unet: UNetField = InputField(
|
||||
description=FieldDescriptions.unet,
|
||||
input=Input.Connection,
|
||||
title="UNet",
|
||||
)
|
||||
cfg_rescale_multiplier: float = InputField(
|
||||
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
||||
)
|
||||
vae: VAEField = InputField(
|
||||
description=FieldDescriptions.vae,
|
||||
input=Input.Connection,
|
||||
)
|
||||
vae_fp32: bool = InputField(
|
||||
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
|
||||
)
|
||||
# HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But, we akwardly
|
||||
# don't want to use the image field. Figure out how best to handle this.
|
||||
# TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
|
||||
# much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
|
||||
# CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
|
||||
# etc.
|
||||
control_model: ModelIdentifierField = InputField(
|
||||
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
|
||||
)
|
||||
control_weight: float = InputField(default=0.6)
|
||||
|
||||
@field_validator("cfg_scale")
|
||||
def ge_one(cls, v: list[float] | float) -> list[float] | float:
|
||||
"""Validate that all cfg_scale values are >= 1"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
else:
|
||||
if v < 1:
|
||||
raise ValueError("cfg_scale must be greater than 1")
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
|
||||
"""Crop the latent-space tensor to the area corresponding to the image-space tile.
|
||||
The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
|
||||
"""
|
||||
for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
|
||||
if coord % LATENT_SCALE_FACTOR != 0:
|
||||
raise ValueError(
|
||||
f"The tile coordinates must all be divisible by the latent scale factor"
|
||||
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
|
||||
)
|
||||
assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
|
||||
|
||||
top = image_tile.coords.top // LATENT_SCALE_FACTOR
|
||||
left = image_tile.coords.left // LATENT_SCALE_FACTOR
|
||||
bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
|
||||
right = image_tile.coords.right // LATENT_SCALE_FACTOR
|
||||
return latents[..., top:bottom, left:right]
|
||||
|
||||
def run_controlnet(
|
||||
self,
|
||||
image: Image.Image,
|
||||
controlnet_model: ControlNetModel,
|
||||
weight: float,
|
||||
do_classifier_free_guidance: bool,
|
||||
width: int,
|
||||
height: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
control_mode: CONTROLNET_MODE_VALUES = "balanced",
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
|
||||
) -> ControlNetData:
|
||||
control_image = prepare_control_image(
|
||||
image=image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=width,
|
||||
height=height,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
control_mode=control_mode,
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
return ControlNetData(
|
||||
model=controlnet_model,
|
||||
image_tensor=control_image,
|
||||
weight=weight,
|
||||
begin_step_percent=0.0,
|
||||
end_step_percent=1.0,
|
||||
control_mode=control_mode,
|
||||
# Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
|
||||
# ControlNetData in case needed in the future.
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
# TODO(ryand): Expose the seed parameter.
|
||||
seed = 0
|
||||
|
||||
# Load the input image.
|
||||
input_image = context.images.get_pil(self.image.image_name)
|
||||
|
||||
# Calculate the tile locations to cover the image.
|
||||
# We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This
|
||||
# facilitates conversions between image space and latent space.
|
||||
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
|
||||
tiles = calc_tiles_with_overlap(
|
||||
image_height=input_image.height,
|
||||
image_width=input_image.width,
|
||||
tile_height=self.tile_height,
|
||||
tile_width=self.tile_width,
|
||||
overlap=self.tile_overlap,
|
||||
)
|
||||
|
||||
# Convert the input image to a torch.Tensor.
|
||||
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
|
||||
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
|
||||
# Validate our assumptions about the shape of input_image_torch.
|
||||
assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width).
|
||||
assert input_image_torch.shape[:2] == (1, 3)
|
||||
|
||||
# Split the input image into tiles in torch.Tensor format.
|
||||
image_tiles_torch: list[torch.Tensor] = []
|
||||
for tile in tiles:
|
||||
image_tile = input_image_torch[
|
||||
:,
|
||||
:,
|
||||
tile.coords.top : tile.coords.bottom,
|
||||
tile.coords.left : tile.coords.right,
|
||||
]
|
||||
image_tiles_torch.append(image_tile)
|
||||
|
||||
# Split the input image into tiles in numpy format.
|
||||
# TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
|
||||
# with torch.Tensor tiles.
|
||||
input_image_np = np.array(input_image)
|
||||
image_tiles_np: list[npt.NDArray[np.uint8]] = []
|
||||
for tile in tiles:
|
||||
image_tile_np = input_image_np[
|
||||
tile.coords.top : tile.coords.bottom,
|
||||
tile.coords.left : tile.coords.right,
|
||||
:,
|
||||
]
|
||||
image_tiles_np.append(image_tile_np)
|
||||
|
||||
# VAE-encode each image tile independently.
|
||||
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
|
||||
# about for decoding?
|
||||
vae_info = context.models.load(self.vae.vae)
|
||||
latent_tiles: list[torch.Tensor] = []
|
||||
for image_tile_torch in image_tiles_torch:
|
||||
latent_tiles.append(
|
||||
ImageToLatentsInvocation.vae_encode(
|
||||
vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile_torch
|
||||
)
|
||||
)
|
||||
|
||||
# Generate noise with dimensions corresponding to the full image in latent space.
|
||||
# It is important that the noise tensor is generated at the full image dimension and then tiled, rather than
|
||||
# generating for each tile independently. This ensures that overlapping regions between tiles use the same
|
||||
# noise.
|
||||
assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
|
||||
assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
|
||||
global_noise = get_noise(
|
||||
width=input_image_torch.shape[3],
|
||||
height=input_image_torch.shape[2],
|
||||
device=TorchDevice.choose_torch_device(),
|
||||
seed=seed,
|
||||
downsampling_factor=LATENT_SCALE_FACTOR,
|
||||
use_cpu=True,
|
||||
)
|
||||
|
||||
# Crop the global noise into tiles.
|
||||
noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
|
||||
|
||||
# Prepare an iterator that yields the UNet's LoRA models and their weights.
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.unet.loras:
|
||||
lora_info = context.models.load(lora.lora)
|
||||
assert isinstance(lora_info.model, LoRAModelRaw)
|
||||
yield (lora_info.model, lora.weight)
|
||||
del lora_info
|
||||
|
||||
# Load the UNet model.
|
||||
unet_info = context.models.load(self.unet.unet)
|
||||
|
||||
refined_latent_tiles: list[torch.Tensor] = []
|
||||
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
|
||||
assert isinstance(unet, UNet2DConditionModel)
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.unet.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=seed,
|
||||
)
|
||||
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
|
||||
|
||||
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
|
||||
# Assume that all tiles have the same shape.
|
||||
_, _, latent_height, latent_width = latent_tiles[0].shape
|
||||
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
|
||||
context=context,
|
||||
positive_conditioning_field=self.positive_conditioning,
|
||||
negative_conditioning_field=self.negative_conditioning,
|
||||
unet=unet,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
cfg_scale=self.cfg_scale,
|
||||
steps=self.steps,
|
||||
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||
)
|
||||
|
||||
# Load the ControlNet model.
|
||||
# TODO(ryand): Support multiple ControlNet models.
|
||||
controlnet_model = exit_stack.enter_context(context.models.load(self.control_model))
|
||||
assert isinstance(controlnet_model, ControlNetModel)
|
||||
|
||||
# Denoise (i.e. "refine") each tile independently.
|
||||
for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
|
||||
assert latent_tile.shape == noise_tile.shape
|
||||
|
||||
# Prepare a PIL Image for ControlNet processing.
|
||||
# TODO(ryand): This is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of
|
||||
# the tiles. Ideally, the ControlNet code should be able to work with Tensors.
|
||||
image_tile_pil = Image.fromarray(image_tile_np)
|
||||
|
||||
# Run the ControlNet on the image tile.
|
||||
height, width, _ = image_tile_np.shape
|
||||
# The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we
|
||||
# validate this assumption here.
|
||||
assert height % LATENT_SCALE_FACTOR == 0
|
||||
assert width % LATENT_SCALE_FACTOR == 0
|
||||
controlnet_data = self.run_controlnet(
|
||||
image=image_tile_pil,
|
||||
controlnet_model=controlnet_model,
|
||||
weight=self.control_weight,
|
||||
do_classifier_free_guidance=True,
|
||||
width=width,
|
||||
height=height,
|
||||
device=controlnet_model.device,
|
||||
dtype=controlnet_model.dtype,
|
||||
control_mode="balanced",
|
||||
resize_mode="just_resize_simple",
|
||||
)
|
||||
|
||||
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
|
||||
scheduler,
|
||||
device=unet.device,
|
||||
steps=self.steps,
|
||||
denoising_start=self.denoising_start,
|
||||
denoising_end=self.denoising_end,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
|
||||
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
|
||||
noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype)
|
||||
refined_latent_tile = pipeline.latents_from_embeddings(
|
||||
latents=latent_tile,
|
||||
timesteps=timesteps,
|
||||
init_timestep=init_timestep,
|
||||
noise=noise_tile,
|
||||
seed=seed,
|
||||
mask=None,
|
||||
masked_latents=None,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=[controlnet_data],
|
||||
ip_adapter_data=None,
|
||||
t2i_adapter_data=None,
|
||||
callback=lambda x: None,
|
||||
)
|
||||
refined_latent_tiles.append(refined_latent_tile)
|
||||
|
||||
# VAE-decode each refined latent tile independently.
|
||||
refined_image_tiles: list[Image.Image] = []
|
||||
for refined_latent_tile in refined_latent_tiles:
|
||||
refined_image_tile = LatentsToImageInvocation.vae_decode(
|
||||
context=context,
|
||||
vae_info=vae_info,
|
||||
seamless_axes=self.vae.seamless_axes,
|
||||
latents=refined_latent_tile,
|
||||
use_fp32=self.vae_fp32,
|
||||
use_tiling=False,
|
||||
)
|
||||
refined_image_tiles.append(refined_image_tile)
|
||||
|
||||
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
# Merge the refined image tiles back into a single image.
|
||||
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
|
||||
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
|
||||
# TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter?
|
||||
merge_tiles_with_linear_blending(
|
||||
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap
|
||||
)
|
||||
|
||||
# Save the refined image and return its reference.
|
||||
merged_image_pil = Image.fromarray(merged_image_np)
|
||||
image_dto = context.images.save(image=merged_image_pil)
|
||||
|
||||
return ImageOutput.build(image_dto)
|
||||
@@ -113,6 +113,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
clear_queue_on_startup: Empties session queue on startup.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory.

@@ -186,6 +187,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")

# NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")

@@ -41,6 +41,7 @@ class ImageRecordStorageBase(ABC):
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass

@@ -148,6 +148,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()

@@ -208,6 +209,13 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
query_params.append(board_id)

# Search term condition
if search_term:
query_conditions += """--sql
AND json_extract(images.metadata, '$') LIKE ?
"""
query_params.append(f'%{search_term}%')

query_pagination = """--sql
ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
"""
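A standalone sqlite3 sketch of the metadata search added above: the term is wrapped in % wildcards and bound as a parameter, and json_extract(metadata, '$') exposes the full metadata JSON text to LIKE. The table layout is a simplified stand-in for the real images table, and SQLite must be built with JSON support (standard in recent CPython builds).

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE images (image_name TEXT, starred BOOL, created_at TEXT, metadata TEXT)")
conn.execute(
    "INSERT INTO images VALUES (?, ?, ?, ?)",
    ("a.png", False, "2024-06-01", json.dumps({"positive_prompt": "a photo of a cat"})),
)
conn.execute(
    "INSERT INTO images VALUES (?, ?, ?, ?)",
    ("b.png", False, "2024-06-02", json.dumps({"positive_prompt": "a dog"})),
)

search_term = "cat"
rows = conn.execute(
    "SELECT image_name FROM images"
    " WHERE json_extract(images.metadata, '$') LIKE ?"
    " ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?",
    (f"%{search_term}%", 10, 0),
).fetchall()
assert rows == [("a.png",)]
```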
@@ -120,6 +120,7 @@ class ImageServiceABC(ABC):
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass

@@ -206,6 +206,7 @@ class ImageService(ImageServiceABC):
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self.__invoker.services.image_records.get_many(

@@ -215,6 +216,7 @@ class ImageService(ImageServiceABC):
categories,
is_intermediate,
board_id,
search_term
)

image_dtos = [
@@ -37,10 +37,14 @@ class SqliteSessionQueue(SessionQueueBase):
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)

if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
if self.__invoker.services.configuration.clear_queue_on_startup:
clear_result = self.clear(DEFAULT_QUEUE_ID)
if clear_result.deleted > 0:
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
else:
prune_result = self.prune(DEFAULT_QUEUE_ID)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")

def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
@@ -289,7 +289,7 @@ def prepare_control_image(
width: int,
height: int,
num_channels: int = 3,
device: str | torch.device = "cuda",
device: str = "cuda",
dtype: torch.dtype = torch.float16,
control_mode: CONTROLNET_MODE_VALUES = "balanced",
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",

@@ -304,7 +304,7 @@ def prepare_control_image(
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
device (str | torch.Device, optional): The target device for the output image. Defaults to "cuda".
device (str, optional): The target device for the output image. Defaults to "cuda".
dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
Defaults to True.
@@ -22,8 +22,7 @@ from .generic_diffusers import GenericDiffusersLoader


@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VAELoader(GenericDiffusersLoader):
"""Class to load VAE models."""

@@ -40,12 +39,8 @@ class VAELoader(GenericDiffusersLoader):
return True

def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
# TODO(MM2): check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"VAE conversion not supported for model type: {config.base}")
else:
assert isinstance(config, CheckpointConfigBase)
config_file = self._app_config.legacy_conf_path / config.config_path
assert isinstance(config, CheckpointConfigBase)
config_file = self._app_config.legacy_conf_path / config.config_path

if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu")
@@ -451,8 +451,16 @@ class PipelineCheckpointProbe(CheckpointProbeBase):

class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1
# VAEs of all base types have the same structure, so we wimp out and
# guess using the name.
for regexp, basetype in [
(r"xl", BaseModelType.StableDiffusionXL),
(r"sd2", BaseModelType.StableDiffusion2),
(r"vae", BaseModelType.StableDiffusion1),
]:
if re.search(regexp, self.model_path.name, re.IGNORECASE):
return basetype
raise InvalidModelConfigException("Cannot determine base type")


class LoRACheckpointProbe(CheckpointProbeBase):
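The name-based heuristic above can be exercised standalone. The regex-to-base-type mapping is taken directly from the hunk; the enum values and the example filenames are illustrative stand-ins.

```python
import re
from enum import Enum


class BaseModelType(str, Enum):  # stand-in for the real enum, for illustration only
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
    StableDiffusionXL = "sdxl"


def guess_vae_base_type(file_name: str) -> BaseModelType:
    # Checkpoint VAEs share one structure, so the base type is guessed from the file name.
    for regexp, basetype in [
        (r"xl", BaseModelType.StableDiffusionXL),
        (r"sd2", BaseModelType.StableDiffusion2),
        (r"vae", BaseModelType.StableDiffusion1),
    ]:
        if re.search(regexp, file_name, re.IGNORECASE):
            return basetype
    raise ValueError("Cannot determine base type")


assert guess_vae_base_type("sdxl_vae.safetensors") is BaseModelType.StableDiffusionXL
assert guess_vae_base_type("vae-ft-mse-840000.ckpt") is BaseModelType.StableDiffusion1
```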
@@ -10,11 +10,12 @@ import PIL.Image
import psutil
import torch
import torchvision.transforms as T
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils.import_utils import is_xformers_available
from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

@@ -25,7 +26,6 @@ from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel


@dataclass

@@ -39,17 +39,55 @@ class PipelineIntermediateState:


@dataclass
class AddsMaskGuidance:
class AddsMaskLatents:
"""Add the channels required for inpainting model input.

The inpainting model takes the normal latent channels as input, _plus_ a one-channel mask
and the latent encoding of the base image.

This class assumes the same mask and base image should apply to all items in the batch.
"""

forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
mask: torch.Tensor
mask_latents: torch.Tensor
initial_image_latents: torch.Tensor

def __call__(
self,
latents: torch.Tensor,
t: torch.Tensor,
text_embeddings: torch.Tensor,
**kwargs,
) -> torch.Tensor:
model_input = self.add_mask_channels(latents)
return self.forward(model_input, t, text_embeddings, **kwargs)

def add_mask_channels(self, latents):
batch_size = latents.size(0)
# duplicate mask and latents for each batch
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
image_latents = einops.repeat(self.initial_image_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
# add mask and image as additional channels
model_input, _ = einops.pack([latents, mask, image_latents], "b * h w")
return model_input


def are_like_tensors(a: torch.Tensor, b: object) -> bool:
return isinstance(b, torch.Tensor) and (a.size() == b.size())


@dataclass
class AddsMaskGuidance:
mask: torch.FloatTensor
mask_latents: torch.FloatTensor
scheduler: SchedulerMixin
noise: torch.Tensor
is_gradient_mask: bool
gradient_mask: bool

def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return self.apply_mask(latents, t)

def apply_mask(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
|
||||
batch_size = latents.size(0)
|
||||
mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
if t.dim() == 0:
|
||||
@@ -62,7 +100,7 @@ class AddsMaskGuidance:
|
||||
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
|
||||
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
|
||||
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
|
||||
if self.is_gradient_mask:
|
||||
if self.gradient_mask:
|
||||
threshold = (t.item()) / self.scheduler.config.num_train_timesteps
mask_bool = mask > threshold # I don't know when mask got inverted, but it did
|
||||
masked_input = torch.where(mask_bool, latents, mask_latents)
|
||||
@@ -162,6 +200,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
safety_checker: Optional[StableDiffusionSafetyChecker],
|
||||
feature_extractor: Optional[CLIPFeatureExtractor],
|
||||
requires_safety_checker: bool = False,
|
||||
control_model: ControlNetModel = None,
|
||||
):
|
||||
super().__init__(
|
||||
vae=vae,
|
||||
@@ -175,6 +214,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
)
|
||||
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward)
|
||||
self.control_model = control_model
|
||||
self.use_ip_adapter = False
|
||||
|
||||
def _adjust_memory_efficient_attention(self, latents: torch.Tensor):
|
||||
"""
|
||||
@@ -239,131 +280,116 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
||||
raise Exception("Should not be called")
|
||||
|
||||
def add_inpainting_channels_to_latents(
|
||||
self, latents: torch.Tensor, masked_ref_image_latents: torch.Tensor, inpainting_mask: torch.Tensor
|
||||
):
|
||||
"""Given a `latents` tensor, adds the mask and image latents channels required for inpainting.
|
||||
|
||||
Standard (non-inpainting) SD UNet models expect an input with shape (N, 4, H, W). Inpainting models expect an
|
||||
input of shape (N, 9, H, W). The 9 channels are defined as follows:
|
||||
- Channel 0-3: The latents being denoised.
|
||||
- Channel 4: The mask indicating which parts of the image are being inpainted.
|
||||
- Channel 5-8: The latent representation of the masked reference image being inpainted.
|
||||
|
||||
This function assumes that the same mask and base image should apply to all items in the batch.
|
||||
"""
|
||||
# Validate assumptions about input tensor shapes.
|
||||
batch_size, latent_channels, latent_height, latent_width = latents.shape
|
||||
assert latent_channels == 4
|
||||
assert masked_ref_image_latents.shape == (1, 4, latent_height, latent_width)
assert inpainting_mask.shape == (1, 1, latent_height, latent_width)
|
||||
|
||||
# Repeat original_image_latents and inpainting_mask to match the latents batch size.
|
||||
original_image_latents = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
|
||||
inpainting_mask = inpainting_mask.expand(batch_size, -1, -1, -1)
|
||||
|
||||
# Concatenate along the channel dimension.
|
||||
return torch.cat([latents, inpainting_mask, original_image_latents], dim=1)
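For readers unfamiliar with the 9-channel inpainting layout described in the docstring, here is a shape-only sketch; the dimensions are illustrative assumptions, not values from this change:

import torch

batch_size, latent_height, latent_width = 2, 64, 64
latents = torch.zeros(batch_size, 4, latent_height, latent_width)
inpainting_mask = torch.ones(1, 1, latent_height, latent_width)
masked_ref_image_latents = torch.zeros(1, 4, latent_height, latent_width)

# Broadcast the single mask/reference pair across the batch, then concatenate along the
# channel dimension: 4 (latents) + 1 (mask) + 4 (reference latents) = 9 channels.
mask = inpainting_mask.expand(batch_size, -1, -1, -1)
ref = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
model_input = torch.cat([latents, mask, ref], dim=1)
assert model_input.shape == (batch_size, 9, latent_height, latent_width)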
|
||||
def latents_from_embeddings(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
num_inference_steps: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
conditioning_data: TextConditioningData,
|
||||
*,
|
||||
noise: Optional[torch.Tensor],
|
||||
seed: int,
|
||||
timesteps: torch.Tensor,
|
||||
init_timestep: torch.Tensor,
|
||||
callback: Callable[[PipelineIntermediateState], None],
|
||||
control_data: list[ControlNetData] | None = None,
|
||||
additional_guidance: List[Callable] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
masked_latents: Optional[torch.Tensor] = None,
|
||||
is_gradient_mask: bool = False,
|
||||
gradient_mask: Optional[bool] = False,
|
||||
seed: int,
|
||||
) -> torch.Tensor:
|
||||
"""Denoise the latents.
|
||||
|
||||
Args:
|
||||
latents: The latent-space image to denoise.
|
||||
- If we are inpainting, this is the initial latent image before noise has been added.
|
||||
- If we are generating a new image, this should be initialized to zeros.
|
||||
- In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
|
||||
scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method.
|
||||
conditioning_data: Text conditioning data.
|
||||
noise: Noise used for two purposes:
|
||||
1. Used by the scheduler to noise the initial `latents` before denoising.
|
||||
2. Used to noise the `masked_latents` when inpainting.
|
||||
`noise` should be None if the `latents` tensor has already been noised.
|
||||
seed: The seed used to generate the noise for the denoising process.
|
||||
HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
|
||||
same noise used earlier in the pipeline. This should really be handled in a clearer way.
|
||||
timesteps: The timestep schedule for the denoising process.
|
||||
init_timestep: The first timestep in the schedule.
|
||||
TODO(ryand): I'm pretty sure this should always be the same as timesteps[0:1]. Confirm that that is the
|
||||
case, and remove this duplicate param.
|
||||
callback: A callback function that is called to report progress during the denoising process.
|
||||
control_data: ControlNet data.
|
||||
ip_adapter_data: IP-Adapter data.
|
||||
t2i_adapter_data: T2I-Adapter data.
|
||||
mask: A mask indicating which parts of the image are being inpainted. The presence of mask is used to
|
||||
determine whether we are inpainting or not. `mask` should have the same spatial dimensions as the
|
||||
`latents` tensor.
|
||||
TODO(ryand): Check and document the expected dtype, range, and values used to represent
|
||||
foreground/background.
|
||||
masked_latents: A latent-space representation of a masked inpainting reference image. This tensor is only
|
||||
used if an *inpainting* model is being used i.e. this tensor is not used when inpainting with a standard
|
||||
SD UNet model.
|
||||
is_gradient_mask: A flag indicating whether `mask` is a gradient mask or not.
|
||||
"""
|
||||
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
|
||||
# cases where denoising_start and denoising_end are set such that there are no timesteps.
|
||||
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
|
||||
if init_timestep.shape[0] == 0:
|
||||
return latents
|
||||
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
orig_latents = latents.clone()
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
batched_init_timestep = init_timestep.expand(batch_size)
|
||||
batched_t = init_timestep.expand(batch_size)
|
||||
|
||||
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||
if noise is not None:
|
||||
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
|
||||
# full noise. Investigate the history of why this got commented out.
|
||||
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_t)
|
||||
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
if mask is not None:
|
||||
if is_inpainting_model(self.unet):
|
||||
if masked_latents is None:
|
||||
raise Exception("Source image required for inpaint mask when inpaint model used!")
|
||||
|
||||
# Handle mask guidance (a.k.a. inpainting).
|
||||
mask_guidance: AddsMaskGuidance | None = None
|
||||
if mask is not None and not is_inpainting_model(self.unet):
|
||||
# We are doing inpainting, since a mask is provided, but we are not using an inpainting model, so we will
|
||||
# apply mask guidance to the latents.
|
||||
self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
|
||||
self._unet_forward, mask, masked_latents
|
||||
)
|
||||
else:
|
||||
# if no noise provided, noisify unmasked area based on seed
|
||||
if noise is None:
|
||||
noise = torch.randn(
|
||||
orig_latents.shape,
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
generator=torch.Generator(device="cpu").manual_seed(seed),
|
||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||
|
||||
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||
# We still need noise for inpainting, so we generate it from the seed here.
|
||||
if noise is None:
|
||||
noise = torch.randn(
|
||||
orig_latents.shape,
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
generator=torch.Generator(device="cpu").manual_seed(seed),
|
||||
).to(device=orig_latents.device, dtype=orig_latents.dtype)
|
||||
additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
|
||||
|
||||
mask_guidance = AddsMaskGuidance(
|
||||
mask=mask,
|
||||
mask_latents=orig_latents,
|
||||
scheduler=self.scheduler,
|
||||
noise=noise,
|
||||
is_gradient_mask=is_gradient_mask,
|
||||
try:
|
||||
latents = self.generate_latents_from_embeddings(
|
||||
latents,
|
||||
timesteps,
|
||||
conditioning_data,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
t2i_adapter_data=t2i_adapter_data,
|
||||
callback=callback,
|
||||
)
|
||||
finally:
|
||||
self.invokeai_diffuser.model_forward_callback = self._unet_forward
|
||||
|
||||
# restore unmasked part after the last step is completed
|
||||
# in-process masking happens before each step
|
||||
if mask is not None:
|
||||
if gradient_mask:
|
||||
latents = torch.where(mask > 0, latents, orig_latents)
|
||||
else:
|
||||
latents = torch.lerp(
|
||||
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
|
||||
)
|
||||
|
||||
return latents
|
||||
|
||||
def generate_latents_from_embeddings(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
timesteps,
|
||||
conditioning_data: TextConditioningData,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
*,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
) -> torch.Tensor:
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
if timesteps.shape[0] == 0:
|
||||
return latents
|
||||
|
||||
use_ip_adapter = ip_adapter_data is not None
|
||||
use_regional_prompting = (
|
||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
||||
)
|
||||
unet_attention_patcher = None
|
||||
self.use_ip_adapter = use_ip_adapter
|
||||
attn_ctx = nullcontext()
|
||||
|
||||
if use_ip_adapter or use_regional_prompting:
|
||||
@@ -376,28 +402,28 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||
|
||||
with attn_ctx:
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=self.scheduler.config.num_train_timesteps,
|
||||
latents=latents,
|
||||
if callback is not None:
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=self.scheduler.config.num_train_timesteps,
|
||||
latents=latents,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# print("timesteps:", timesteps)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t = t.expand(batch_size)
|
||||
step_output = self.step(
|
||||
t=batched_t,
|
||||
latents=latents,
|
||||
conditioning_data=conditioning_data,
|
||||
batched_t,
|
||||
latents,
|
||||
conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps),
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
mask_guidance=mask_guidance,
|
||||
mask=mask,
|
||||
masked_latents=masked_latents,
|
||||
additional_guidance=additional_guidance,
|
||||
control_data=control_data,
|
||||
ip_adapter_data=ip_adapter_data,
|
||||
t2i_adapter_data=t2i_adapter_data,
|
||||
@@ -405,28 +431,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
latents = step_output.prev_sample
|
||||
predicted_original = getattr(step_output, "pred_original_sample", None)
|
||||
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=i,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=int(t),
|
||||
latents=latents,
|
||||
predicted_original=predicted_original,
|
||||
if callback is not None:
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=i,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=int(t),
|
||||
latents=latents,
|
||||
predicted_original=predicted_original,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# restore unmasked part after the last step is completed
|
||||
# in-process masking happens before each step
|
||||
if mask is not None:
|
||||
if is_gradient_mask:
|
||||
latents = torch.where(mask > 0, latents, orig_latents)
|
||||
else:
|
||||
latents = torch.lerp(
|
||||
orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
|
||||
)
|
||||
|
||||
return latents
|
||||
return latents
|
||||
|
||||
@torch.inference_mode()
|
||||
def step(
|
||||
@@ -437,20 +454,19 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
mask_guidance: AddsMaskGuidance | None,
|
||||
mask: torch.Tensor | None,
|
||||
masked_latents: torch.Tensor | None,
|
||||
control_data: list[ControlNetData] | None = None,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: Optional[list[IPAdapterData]] = None,
|
||||
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
|
||||
):
|
||||
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
|
||||
timestep = t[0]
|
||||
if additional_guidance is None:
|
||||
additional_guidance = []
|
||||
|
||||
# Handle masked image-to-image (a.k.a inpainting).
|
||||
if mask_guidance is not None:
|
||||
# NOTE: This is intentionally done *before* self.scheduler.scale_model_input(...).
|
||||
latents = mask_guidance(latents, timestep)
|
||||
# one day we will expand this extension point, but for now it just does denoise masking
|
||||
for guidance in additional_guidance:
|
||||
latents = guidance(latents, timestep)
|
||||
|
||||
# TODO: should this scaling happen here or inside self._unet_forward?
|
||||
# i.e. before or after passing it to InvokeAIDiffuserComponent
|
||||
@@ -498,31 +514,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
down_intrablock_additional_residuals = accum_adapter_state
|
||||
|
||||
# Handle inpainting models.
|
||||
if is_inpainting_model(self.unet):
|
||||
# NOTE: These calls to add_inpainting_channels_to_latents(...) are intentionally done *after*
|
||||
# self.scheduler.scale_model_input(...) so that the scaling is not applied to the mask or reference image
|
||||
# latents.
|
||||
if mask is not None:
|
||||
if masked_latents is None:
|
||||
raise ValueError("Source image required for inpaint mask when inpaint model used!")
|
||||
latent_model_input = self.add_inpainting_channels_to_latents(
|
||||
latents=latent_model_input, masked_ref_image_latents=masked_latents, inpainting_mask=mask
|
||||
)
|
||||
else:
|
||||
# We are using an inpainting model, but no mask was provided, so we are not really "inpainting".
|
||||
# We generate a global mask and empty original image so that we can still generate in this
|
||||
# configuration.
|
||||
# TODO(ryand): Should we just raise an exception here instead? I can't think of a use case for wanting
|
||||
# to do this.
|
||||
# TODO(ryand): If we decide that there is a good reason to keep this, then we should generate the 'fake'
|
||||
# mask and original image once rather than on every denoising step.
|
||||
latent_model_input = self.add_inpainting_channels_to_latents(
|
||||
latents=latent_model_input,
|
||||
masked_ref_image_latents=torch.zeros_like(latent_model_input[:1]),
|
||||
inpainting_mask=torch.ones_like(latent_model_input[:1, :1]),
|
||||
)
|
||||
|
||||
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
|
||||
sample=latent_model_input,
|
||||
timestep=t, # TODO: debug how handled batched and non batched timesteps
|
||||
@@ -551,18 +542,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
|
||||
|
||||
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting
|
||||
# again.
|
||||
if mask_guidance is not None:
|
||||
# Apply the mask to any "denoised" or "pred_original_sample" fields.
|
||||
# TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
|
||||
for guidance in additional_guidance:
|
||||
# apply the mask to any "denoised" or "pred_original_sample" fields
|
||||
if hasattr(step_output, "denoised"):
|
||||
step_output.pred_original_sample = mask_guidance(step_output.denoised, self.scheduler.timesteps[-1])
|
||||
step_output.pred_original_sample = guidance(step_output.denoised, self.scheduler.timesteps[-1])
|
||||
elif hasattr(step_output, "pred_original_sample"):
|
||||
step_output.pred_original_sample = mask_guidance(
|
||||
step_output.pred_original_sample = guidance(
|
||||
step_output.pred_original_sample, self.scheduler.timesteps[-1]
|
||||
)
|
||||
else:
|
||||
step_output.pred_original_sample = mask_guidance(latents, self.scheduler.timesteps[-1])
|
||||
step_output.pred_original_sample = guidance(latents, self.scheduler.timesteps[-1])
|
||||
|
||||
return step_output
|
||||
|
||||
@@ -585,6 +575,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
**kwargs,
|
||||
):
|
||||
"""predict the noise residual"""
|
||||
if is_inpainting_model(self.unet) and latents.size(1) == 4:
|
||||
# Pad out normal non-inpainting inputs for an inpainting model.
|
||||
# FIXME: There are too many layers of functions and we have too many different ways of
|
||||
# overriding things! This should get handled in a way more consistent with the other
|
||||
# use of AddsMaskLatents.
|
||||
latents = AddsMaskLatents(
|
||||
self._unet_forward,
|
||||
mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
|
||||
initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype),
|
||||
).add_mask_channels(latents)
|
||||
|
||||
# First three args should be positional, not keywords, so torch hooks can see them.
|
||||
return self.unet(
|
||||
latents,
|
||||
|
||||
@@ -1,242 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
||||
ControlNetData,
|
||||
PipelineIntermediateState,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
|
||||
from invokeai.backend.tiles.utils import TBLR
|
||||
|
||||
# The maximum number of regions with compatible sizes that will be batched together.
|
||||
# Larger batch sizes improve speed, but require more device memory.
|
||||
MAX_REGION_BATCH_SIZE = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiDiffusionRegionConditioning:
|
||||
# Region coords in latent space.
|
||||
region: TBLR
|
||||
text_conditioning_data: TextConditioningData
|
||||
control_data: list[ControlNetData]
|
||||
|
||||
|
||||
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
"""A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
|
||||
|
||||
def _split_into_region_batches(
|
||||
self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]
|
||||
) -> list[list[MultiDiffusionRegionConditioning]]:
|
||||
# Group the regions by shape. Only regions with the same shape can be batched together.
|
||||
conditioning_by_shape: dict[tuple[int, int], list[MultiDiffusionRegionConditioning]] = {}
|
||||
for region_conditioning in multi_diffusion_conditioning:
|
||||
shape_hw = (
|
||||
region_conditioning.region.bottom - region_conditioning.region.top,
|
||||
region_conditioning.region.right - region_conditioning.region.left,
|
||||
)
|
||||
# In Python, a tuple of hashable objects is hashable, so it can be used as a key in a dict.
|
||||
if shape_hw not in conditioning_by_shape:
|
||||
conditioning_by_shape[shape_hw] = []
|
||||
conditioning_by_shape[shape_hw].append(region_conditioning)
|
||||
|
||||
# Split the regions into batches, respecting the MAX_REGION_BATCH_SIZE constraint.
|
||||
region_conditioning_batches = []
|
||||
for region_conditioning_batch in conditioning_by_shape.values():
|
||||
for i in range(0, len(region_conditioning_batch), MAX_REGION_BATCH_SIZE):
|
||||
region_conditioning_batches.append(region_conditioning_batch[i : i + MAX_REGION_BATCH_SIZE])
|
||||
|
||||
return region_conditioning_batches
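A rough illustration of the grouping above, using hypothetical region sizes instead of MultiDiffusionRegionConditioning objects: regions are bucketed by their latent-space (height, width), and each bucket is then chunked by MAX_REGION_BATCH_SIZE.

# Hypothetical regions, given only as (height, width) of their latent-space crop.
regions = [(64, 64), (64, 64), (64, 64), (64, 64), (64, 64), (32, 64)]
MAX_REGION_BATCH_SIZE = 4

by_shape = {}
for r in regions:
    by_shape.setdefault(r, []).append(r)

batches = []
for group in by_shape.values():
    for i in range(0, len(group), MAX_REGION_BATCH_SIZE):
        batches.append(group[i : i + MAX_REGION_BATCH_SIZE])

# Five 64x64 regions become batches of 4 and 1; the lone 32x64 region cannot share a batch.
assert [len(b) for b in batches] == [4, 1, 1]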
|
||||
def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
|
||||
"""Check the input conditioning and confirm that regional prompting is not used."""
|
||||
for region_conditioning in multi_diffusion_conditioning:
|
||||
if (
|
||||
region_conditioning.text_conditioning_data.cond_regions is not None
|
||||
or region_conditioning.text_conditioning_data.uncond_regions is not None
|
||||
):
|
||||
raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
|
||||
|
||||
def multi_diffusion_denoise(
|
||||
self,
|
||||
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
|
||||
latents: torch.Tensor,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
noise: Optional[torch.Tensor],
|
||||
timesteps: torch.Tensor,
|
||||
init_timestep: torch.Tensor,
|
||||
callback: Callable[[PipelineIntermediateState], None],
|
||||
) -> torch.Tensor:
|
||||
self._check_regional_prompting(multi_diffusion_conditioning)
|
||||
|
||||
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
|
||||
# cases where denoising_start and denoising_end are set such that there are no timesteps.
|
||||
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
|
||||
return latents
|
||||
|
||||
batch_size, _, latent_height, latent_width = latents.shape
|
||||
batched_init_timestep = init_timestep.expand(batch_size)
|
||||
|
||||
# noise can be None if the latents have already been noised (e.g. when running the SDXL refiner).
|
||||
if noise is not None:
|
||||
# TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with
|
||||
# full noise. Investigate the history of why this got commented out.
|
||||
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
|
||||
|
||||
# TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
|
||||
# cropping into regions.
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
|
||||
# Populate a weighted mask that will be used to combine the results from each region after every step.
|
||||
# For now, we assume that each region has the same weight (1.0).
|
||||
region_weight_mask = torch.zeros(
|
||||
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
|
||||
)
|
||||
for region_conditioning in multi_diffusion_conditioning:
|
||||
region = region_conditioning.region
|
||||
region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
|
||||
|
||||
# Group the region conditioning into batches for faster processing.
|
||||
# region_conditioning_batches[b][r] is the r'th region in the b'th batch.
|
||||
region_conditioning_batches = self._split_into_region_batches(multi_diffusion_conditioning)
|
||||
|
||||
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
|
||||
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
|
||||
# separate scheduler state for each region batch.
|
||||
region_batch_schedulers: list[SchedulerMixin] = [
|
||||
copy.deepcopy(self.scheduler) for _ in region_conditioning_batches
|
||||
]
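# Editor's note: a small illustrative sketch (not part of this change) of why each region batch
# needs its own scheduler copy -- a stateful object shared across batches stepped at the same
# timestep would let the batches interfere with each other:
import copy

class TinyStatefulScheduler:
    """Stand-in for a diffusers scheduler that mutates internal state on every step()."""

    def __init__(self):
        self.steps_taken = 0

    def step(self):
        self.steps_taken += 1

shared = TinyStatefulScheduler()
per_batch = [copy.deepcopy(shared) for _ in range(3)]
for scheduler in per_batch:
    scheduler.step()

# Each copy advanced exactly once; the original is untouched.
assert [s.steps_taken for s in per_batch] == [1, 1, 1] and shared.steps_taken == 0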
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=-1,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=self.scheduler.config.num_train_timesteps,
|
||||
latents=latents,
|
||||
)
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
batched_t = t.expand(batch_size)
|
||||
|
||||
merged_latents = torch.zeros_like(latents)
|
||||
merged_pred_original: torch.Tensor | None = None
|
||||
for region_batch_idx, region_conditioning_batch in enumerate(region_conditioning_batches):
|
||||
# Switch to the scheduler for the region batch.
|
||||
self.scheduler = region_batch_schedulers[region_batch_idx]
|
||||
|
||||
# TODO(ryand): This logic has not yet been tested with input latents with a batch_size > 1.
|
||||
|
||||
# Prepare the latents for the region batch.
|
||||
batch_latents = torch.cat(
|
||||
[
|
||||
latents[
|
||||
:,
|
||||
:,
|
||||
region_conditioning.region.top : region_conditioning.region.bottom,
|
||||
region_conditioning.region.left : region_conditioning.region.right,
|
||||
]
|
||||
for region_conditioning in region_conditioning_batch
|
||||
],
|
||||
)
|
||||
|
||||
# TODO(ryand): Do we have to repeat the text_conditioning_data to match the batch size? Or does step()
|
||||
# handle broadcasting properly?
|
||||
|
||||
# TODO(ryand): Resume here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
||||
# Run the denoising step on the region.
|
||||
step_output = self.step(
|
||||
t=batched_t,
|
||||
latents=batch_latents,
|
||||
conditioning_data=region_conditioning.text_conditioning_data,
|
||||
step_index=i,
|
||||
total_step_count=total_step_count,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
mask_guidance=None,
|
||||
mask=None,
|
||||
masked_latents=None,
|
||||
control_data=region_conditioning.control_data,
|
||||
)
|
||||
# Run a denoising step on the region.
|
||||
# step_output = self._region_step(
|
||||
# region_conditioning=region_conditioning,
|
||||
# t=batched_t,
|
||||
# latents=latents,
|
||||
# step_index=i,
|
||||
# total_step_count=len(timesteps),
|
||||
# scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
# )
|
||||
|
||||
# Store the results from the region.
|
||||
region = region_conditioning.region
|
||||
merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample
|
||||
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
|
||||
if pred_orig_sample is not None:
|
||||
# If one region has pred_original_sample, then we can assume that all regions will have it, because
|
||||
# they all use the same scheduler.
|
||||
if merged_pred_original is None:
|
||||
merged_pred_original = torch.zeros_like(latents)
|
||||
merged_pred_original[:, :, region.top : region.bottom, region.left : region.right] += (
|
||||
pred_orig_sample
|
||||
)
|
||||
|
||||
# Normalize the merged results.
|
||||
latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
|
||||
predicted_original = None
|
||||
if merged_pred_original is not None:
|
||||
predicted_original = torch.where(
|
||||
region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
|
||||
)
|
||||
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
step=i,
|
||||
order=self.scheduler.order,
|
||||
total_steps=len(timesteps),
|
||||
timestep=int(t),
|
||||
latents=latents,
|
||||
predicted_original=predicted_original,
|
||||
)
|
||||
)
|
||||
|
||||
return latents
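To make the merge above concrete: each region adds its denoised crop into an accumulator, a weight mask counts how many regions cover each pixel, and dividing by that count yields a per-pixel average wherever regions overlap. A minimal sketch with made-up sizes:

import torch

merged = torch.zeros(1, 4, 8, 8)
weight = torch.zeros(1, 1, 8, 8)

# Two hypothetical regions that overlap in columns 4-5.
region_tiles = [(slice(0, 6), torch.ones(1, 4, 8, 6)), (slice(4, 8), 2 * torch.ones(1, 4, 8, 4))]
for cols, tile in region_tiles:
    merged[:, :, :, cols] += tile
    weight[:, :, :, cols] += 1.0

# Average where regions overlap; pixels covered by a single region keep that region's value.
latents = torch.where(weight > 0, merged / weight, merged)
assert torch.allclose(latents[:, :, :, 4:6], torch.full((1, 4, 8, 2), 1.5))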
|
||||
@torch.inference_mode()
|
||||
def _region_batch_step(
|
||||
self,
|
||||
region_conditioning: MultiDiffusionRegionConditioning,
|
||||
t: torch.Tensor,
|
||||
latents: torch.Tensor,
|
||||
step_index: int,
|
||||
total_step_count: int,
|
||||
scheduler_step_kwargs: dict[str, Any],
|
||||
):
|
||||
# Crop the inputs to the region.
|
||||
region_latents = latents[
|
||||
:,
|
||||
:,
|
||||
region_conditioning.region.top : region_conditioning.region.bottom,
|
||||
region_conditioning.region.left : region_conditioning.region.right,
|
||||
]
|
||||
|
||||
# Run the denoising step on the region.
|
||||
return self.step(
|
||||
t=t,
|
||||
latents=region_latents,
|
||||
conditioning_data=region_conditioning.text_conditioning_data,
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||
mask_guidance=None,
|
||||
mask=None,
|
||||
masked_latents=None,
|
||||
control_data=region_conditioning.control_data,
|
||||
)
|
||||
@@ -37,7 +37,11 @@
|
||||
"selectBoard": "Select a Board",
|
||||
"topMessage": "This board contains images used in the following features:",
|
||||
"uncategorized": "Uncategorized",
|
||||
"downloadBoard": "Download Board"
|
||||
"downloadBoard": "Download Board",
|
||||
"imagesWithCount_one": "{{count}} image",
|
||||
"imagesWithCount_other": "{{count}} images",
|
||||
"assetsWithCount_one": "{{count}} asset",
|
||||
"assetsWithCount_other": "{{count}} assets"
|
||||
},
|
||||
"accordions": {
|
||||
"generation": {
|
||||
@@ -380,7 +384,11 @@
|
||||
"problemDeletingImagesDesc": "One or more images could not be deleted",
|
||||
"viewerImage": "Viewer Image",
|
||||
"compareImage": "Compare Image",
|
||||
"noActiveSearch": "No active search",
|
||||
"openInViewer": "Open in Viewer",
|
||||
"searchingBy": "Searching by",
|
||||
"selectAllOnPage": "Select All On Page",
|
||||
"selectAllOnBoard": "Select All On Board",
|
||||
"selectForCompare": "Select for Compare",
|
||||
"selectAnImageToCompare": "Select an Image to Compare",
|
||||
"slider": "Slider",
|
||||
|
||||
@@ -2,8 +2,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageCache } from 'services/api/types';
|
||||
import { getListImagesUrl, imagesSelectors } from 'services/api/util';
|
||||
import { getListImagesUrl } from 'services/api/util';
|
||||
|
||||
export const addFirstListImagesListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
@@ -18,13 +17,10 @@ export const addFirstListImagesListener = (startAppListening: AppStartListening)
|
||||
cancelActiveListeners();
|
||||
unsubscribe();
|
||||
|
||||
// TODO: figure out how to type the predicate
|
||||
const data = action.payload as ImageCache;
|
||||
const data = action.payload;
|
||||
|
||||
if (data.ids.length > 0) {
|
||||
// Select the first image
|
||||
const firstImage = imagesSelectors.selectAll(data)[0];
|
||||
dispatch(imageSelected(firstImage ?? null));
|
||||
if (data.items.length > 0) {
|
||||
dispatch(imageSelected(data.items[0] ?? null));
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
boardIdSelected,
|
||||
galleryViewChanged,
|
||||
imageSelected,
|
||||
selectionChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
export const addBoardIdSelectedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
@@ -14,14 +18,9 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
|
||||
|
||||
const state = getState();
|
||||
|
||||
const board_id = boardIdSelected.match(action) ? action.payload.boardId : state.gallery.selectedBoardId;
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
|
||||
const galleryView = galleryViewChanged.match(action) ? action.payload : state.gallery.galleryView;
|
||||
|
||||
// when a board is selected, we need to wait until the board has loaded *some* images, then select the first one
|
||||
const categories = galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES;
|
||||
|
||||
const queryArgs = { board_id: board_id ?? 'none', categories };
|
||||
dispatch(selectionChanged([]));
|
||||
|
||||
// wait until the board has some images - maybe it already has some from a previous fetch
|
||||
// must use getState() to ensure we do not have stale state
|
||||
@@ -35,11 +34,12 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
|
||||
const { data: boardImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
|
||||
|
||||
if (boardImagesData && boardIdSelected.match(action) && action.payload.selectedImageName) {
|
||||
const selectedImage = imagesSelectors.selectById(boardImagesData, action.payload.selectedImageName);
|
||||
const selectedImage = boardImagesData.items.find(
|
||||
(item) => item.image_name === action.payload.selectedImageName
|
||||
);
|
||||
dispatch(imageSelected(selectedImage || null));
|
||||
} else if (boardImagesData) {
|
||||
const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
|
||||
dispatch(imageSelected(firstImage || null));
|
||||
dispatch(imageSelected(boardImagesData.items[0] || null));
|
||||
} else {
|
||||
// board has no images - deselect
|
||||
dispatch(imageSelected(null));
|
||||
|
||||
@@ -4,7 +4,6 @@ import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelecto
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
export const galleryImageClicked = createAction<{
|
||||
imageDTO: ImageDTO;
|
||||
@@ -32,14 +31,14 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
|
||||
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
|
||||
const state = getState();
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||
const queryResult = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||
|
||||
if (!listImagesData) {
|
||||
if (!queryResult.data) {
|
||||
// Should never happen if we have clicked a gallery image
|
||||
return;
|
||||
}
|
||||
|
||||
const imageDTOs = imagesSelectors.selectAll(listImagesData);
|
||||
const imageDTOs = queryResult.data.items;
|
||||
const selection = state.gallery.selection;
|
||||
|
||||
if (altKey) {
|
||||
|
||||
@@ -22,11 +22,10 @@ import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isImageFieldInputInstance } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { clamp, forEach } from 'lodash-es';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { api } from 'services/api';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
state.nodes.present.nodes.forEach((node) => {
|
||||
@@ -118,32 +117,7 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
|
||||
}
|
||||
|
||||
dispatch(isModalOpenChanged(false));
|
||||
|
||||
const state = getState();
|
||||
const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
|
||||
|
||||
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
|
||||
const { image_name } = imageDTO;
|
||||
|
||||
const baseQueryArgs = selectListImagesQueryArgs(state);
|
||||
const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
||||
|
||||
const cachedImageDTOs = data ? imagesSelectors.selectAll(data) : [];
|
||||
|
||||
const deletedImageIndex = cachedImageDTOs.findIndex((i) => i.image_name === image_name);
|
||||
|
||||
const filteredImageDTOs = cachedImageDTOs.filter((i) => i.image_name !== image_name);
|
||||
|
||||
const newSelectedImageIndex = clamp(deletedImageIndex, 0, filteredImageDTOs.length - 1);
|
||||
|
||||
const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex];
|
||||
|
||||
if (newSelectedImageDTO) {
|
||||
dispatch(imageSelected(newSelectedImageDTO));
|
||||
} else {
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
}
|
||||
|
||||
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
|
||||
if (imageUsage.isCanvasImage) {
|
||||
@@ -168,6 +142,20 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
|
||||
if (wasImageDeleted) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
|
||||
}
|
||||
|
||||
const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
|
||||
|
||||
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
|
||||
const baseQueryArgs = selectListImagesQueryArgs(state);
|
||||
const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
||||
|
||||
if (data && data.items) {
|
||||
const newlySelectedImage = data?.items.find((img) => img.image_name !== imageDTO?.image_name);
|
||||
dispatch(imageSelected(newlySelectedImage || null));
|
||||
} else {
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -188,10 +176,8 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
const { data } = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||
|
||||
const newSelectedImageDTO = data ? imagesSelectors.selectAll(data)[0] : undefined;
|
||||
|
||||
if (newSelectedImageDTO) {
|
||||
dispatch(imageSelected(newSelectedImageDTO));
|
||||
if (data && data.items[0]) {
|
||||
dispatch(imageSelected(data.items[0]));
|
||||
} else {
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
|
||||
@@ -15,7 +15,12 @@ import {
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { isValidDrop } from 'features/dnd/util/isValidDrop';
|
||||
import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import {
|
||||
imageSelected,
|
||||
imageToCompareChanged,
|
||||
isImageViewerOpenChanged,
|
||||
selectionChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
@@ -216,6 +221,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
board_id: boardId,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -233,6 +239,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
imageDTO,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -248,6 +255,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
board_id: boardId,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -261,6 +269,7 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
imageDTOs,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
},
|
||||
|
||||
@@ -8,14 +8,14 @@ import {
|
||||
galleryViewChanged,
|
||||
imageSelected,
|
||||
isImageViewerOpenChanged,
|
||||
offsetChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { imagesAdapter } from 'services/api/util';
|
||||
import { getCategories, getListImagesUrl } from 'services/api/util';
|
||||
import { socketInvocationComplete } from 'services/events/actions';
|
||||
|
||||
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
|
||||
@@ -52,24 +52,6 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
}
|
||||
|
||||
if (!imageDTO.is_intermediate) {
|
||||
/**
|
||||
* Cache updates for when an image result is received
|
||||
* - add it to the no_board/images
|
||||
*/
|
||||
|
||||
dispatch(
|
||||
imagesApi.util.updateQueryData(
|
||||
'listImages',
|
||||
{
|
||||
board_id: imageDTO.board_id ?? 'none',
|
||||
categories: IMAGE_CATEGORIES,
|
||||
},
|
||||
(draft) => {
|
||||
imagesAdapter.addOne(draft, imageDTO);
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
// update the total images for the board
|
||||
dispatch(
|
||||
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
|
||||
@@ -78,7 +60,18 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
})
|
||||
);
|
||||
|
||||
dispatch(imagesApi.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
|
||||
dispatch(
|
||||
imagesApi.util.invalidateTags([
|
||||
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
|
||||
{
|
||||
type: 'ImageList',
|
||||
id: getListImagesUrl({
|
||||
board_id: imageDTO.board_id ?? 'none',
|
||||
categories: getCategories(imageDTO),
|
||||
}),
|
||||
},
|
||||
])
|
||||
);
|
||||
|
||||
const { shouldAutoSwitch } = gallery;
|
||||
|
||||
@@ -98,6 +91,8 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(offsetChanged(0));
|
||||
|
||||
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
|
||||
dispatch(
|
||||
boardIdSelected({
|
||||
|
||||
@@ -1,47 +1,37 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import type { IconButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import type { MouseEvent, ReactElement } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import type { MouseEvent } from 'react';
|
||||
import { memo } from 'react';
|
||||
|
||||
type Props = {
|
||||
const sx: SystemStyleObject = {
|
||||
minW: 0,
|
||||
svg: {
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
fill: 'base.100',
|
||||
_hover: { fill: 'base.50' },
|
||||
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
|
||||
},
|
||||
};
|
||||
|
||||
type Props = Omit<IconButtonProps, 'aria-label' | 'onClick' | 'tooltip'> & {
|
||||
onClick: (event: MouseEvent<HTMLButtonElement>) => void;
|
||||
tooltip: string;
|
||||
icon?: ReactElement;
|
||||
styleOverrides?: SystemStyleObject;
|
||||
};
|
||||
|
||||
const IAIDndImageIcon = (props: Props) => {
|
||||
const { onClick, tooltip, icon, styleOverrides } = props;
|
||||
|
||||
const sx = useMemo(
|
||||
() => ({
|
||||
position: 'absolute',
|
||||
top: 1,
|
||||
insetInlineEnd: 1,
|
||||
p: 0,
|
||||
minW: 0,
|
||||
svg: {
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
fill: 'base.100',
|
||||
_hover: { fill: 'base.50' },
|
||||
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
|
||||
},
|
||||
...styleOverrides,
|
||||
}),
|
||||
[styleOverrides]
|
||||
);
|
||||
const { onClick, tooltip, icon, ...rest } = props;
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
onClick={onClick}
|
||||
aria-label={tooltip}
|
||||
tooltip={tooltip}
|
||||
icon={icon}
|
||||
size="sm"
|
||||
variant="link"
|
||||
sx={sx}
|
||||
data-testid={tooltip}
|
||||
{...rest}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
/**
|
||||
* Comparator function for sorting dates in ascending order
|
||||
*/
|
||||
export const dateComparator = (a: string, b: string) => {
|
||||
const dateA = new Date(a);
|
||||
const dateB = new Date(b);
|
||||
|
||||
// sort in ascending order
|
||||
if (dateA > dateB) {
|
||||
return 1;
|
||||
}
|
||||
if (dateA < dateB) {
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Spinner } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
@@ -185,7 +184,7 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
|
||||
/>
|
||||
</Box>
|
||||
|
||||
<>
|
||||
<Flex flexDir="column" top={1} insetInlineEnd={1}>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
|
||||
@@ -195,15 +194,13 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
|
||||
onClick={handleSaveControlImage}
|
||||
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
|
||||
tooltip={t('controlnet.saveControlImage')}
|
||||
styleOverrides={saveControlImageStyleOverrides}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
|
||||
tooltip={t('controlnet.setControlImageDimensions')}
|
||||
styleOverrides={setControlImageDimensionsStyleOverrides}
|
||||
/>
|
||||
</>
|
||||
</Flex>
|
||||
|
||||
{pendingControlImages.includes(id) && (
|
||||
<Flex
|
||||
@@ -226,6 +223,3 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
|
||||
};
|
||||
|
||||
export default memo(ControlAdapterImagePreview);
|
||||
|
||||
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
|
||||
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Spinner, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -203,13 +202,13 @@ export const ControlAdapterImagePreview = memo(
|
||||
onClick={handleSaveControlImage}
|
||||
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
|
||||
tooltip={t('controlnet.saveControlImage')}
|
||||
styleOverrides={saveControlImageStyleOverrides}
|
||||
mt={6}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
|
||||
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
|
||||
styleOverrides={setControlImageDimensionsStyleOverrides}
|
||||
mt={12}
|
||||
/>
|
||||
</>
|
||||
|
||||
@@ -235,6 +234,3 @@ export const ControlAdapterImagePreview = memo(
|
||||
);
|
||||
|
||||
ControlAdapterImagePreview.displayName = 'ControlAdapterImagePreview';
|
||||
|
||||
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
|
||||
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -100,7 +99,7 @@ export const IPAdapterImagePreview = memo(
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
|
||||
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
|
||||
styleOverrides={setControlImageDimensionsStyleOverrides}
|
||||
mt={6}
|
||||
/>
|
||||
</>
|
||||
</Flex>
|
||||
@@ -109,5 +108,3 @@ export const IPAdapterImagePreview = memo(
|
||||
);
|
||||
|
||||
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';
|
||||
|
||||
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 6 };
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -97,7 +96,7 @@ export const InitialImagePreview = memo(({ image, onChangeImage, droppableData,
|
||||
onClick={onUseSize}
|
||||
icon={imageDTO ? <PiRulerBold size={16} /> : undefined}
|
||||
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
|
||||
styleOverrides={useSizeStyleOverrides}
|
||||
mt={6}
|
||||
/>
|
||||
</>
|
||||
</Flex>
|
||||
@@ -105,5 +104,3 @@ export const InitialImagePreview = memo(({ image, onChangeImage, droppableData,
|
||||
});
|
||||
|
||||
InitialImagePreview.displayName = 'InitialImagePreview';
|
||||
|
||||
const useSizeStyleOverrides: SystemStyleObject = { mt: 6 };
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
|
||||
|
||||
type Props = {
|
||||
board_id: string;
|
||||
};
|
||||
|
||||
export const BoardTotalsTooltip = ({ board_id }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { imagesTotal } = useGetBoardImagesTotalQuery(board_id, {
|
||||
selectFromResult: ({ data }) => {
|
||||
return { imagesTotal: data?.total ?? 0 };
|
||||
},
|
||||
});
|
||||
const { assetsTotal } = useGetBoardAssetsTotalQuery(board_id, {
|
||||
selectFromResult: ({ data }) => {
|
||||
return { assetsTotal: data?.total ?? 0 };
|
||||
},
|
||||
});
|
||||
return `${t('boards.imagesWithCount', { count: imagesTotal })}, ${t('boards.assetsWithCount', { count: assetsTotal })}`;
|
||||
};
|
||||
@@ -8,15 +8,12 @@ import SelectionOverlay from 'common/components/SelectionOverlay';
import type { AddToBoardDropData } from 'features/dnd/types';
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
import { autoAddBoardIdChanged, boardIdSelected, selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImagesSquare } from 'react-icons/pi';
import {
useGetBoardAssetsTotalQuery,
useGetBoardImagesTotalQuery,
useUpdateBoardMutation,
} from 'services/api/endpoints/boards';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { BoardDTO } from 'services/api/types';

@@ -51,17 +48,6 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
setIsHovered(false);
}, []);

const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id);
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id);
const tooltip = useMemo(() => {
if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
return undefined;
}
return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
assetsTotal.total
} asset${assetsTotal.total === 1 ? '' : 's'}`;
}, [assetsTotal, imagesTotal]);

const { currentData: coverImage } = useGetImageDTOQuery(board.cover_image_name ?? skipToken);

const { board_name, board_id } = board;
@@ -132,7 +118,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
>
<BoardContextMenu board={board} board_id={board_id} setBoardToDelete={setBoardToDelete}>
{(ref) => (
<Tooltip label={tooltip} openDelay={1000}>
<Tooltip label={<BoardTotalsTooltip board_id={board.board_id} />} openDelay={1000}>
<Flex
ref={ref}
onClick={handleSelectBoard}

@@ -5,11 +5,11 @@ import SelectionOverlay from 'common/components/SelectionOverlay';
import type { RemoveFromBoardDropData } from 'features/dnd/types';
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice';
import InvokeLogoSVG from 'public/assets/images/invoke-symbol-wht-lrg.svg';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
import { useBoardName } from 'services/api/hooks/useBoardName';

interface Props {
@@ -29,17 +29,6 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
}, [dispatch, autoAssignBoardOnClick]);
const [isHovered, setIsHovered] = useState(false);

const { data: imagesTotal } = useGetBoardImagesTotalQuery('none');
const { data: assetsTotal } = useGetBoardAssetsTotalQuery('none');
const tooltip = useMemo(() => {
if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
return undefined;
}
return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
assetsTotal.total
} asset${assetsTotal.total === 1 ? '' : 's'}`;
}, [assetsTotal, imagesTotal]);

const handleMouseOver = useCallback(() => {
setIsHovered(true);
}, []);
@@ -71,7 +60,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
>
<BoardContextMenu board_id="none">
{(ref) => (
<Tooltip label={tooltip} openDelay={1000}>
<Tooltip label={<BoardTotalsTooltip board_id="none" />} openDelay={1000}>
<Flex
ref={ref}
onClick={handleSelectBoard}

@@ -0,0 +1,55 @@
import { Flex, IconButton, Spacer, Tag, TagCloseButton, TagLabel, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { selectionChanged } from 'features/gallery/store/gallerySlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { BiSelectMultiple } from 'react-icons/bi';

import { GallerySearch } from './GallerySearch';

export const GalleryBulkSelect = () => {
const dispatch = useAppDispatch();
const { selection } = useAppSelector((s) => s.gallery);
const { t } = useTranslation();
const { imageDTOs } = useGalleryImages();

const onClickClearSelection = useCallback(() => {
dispatch(selectionChanged([]));
}, [dispatch]);

const onClickSelectAllPage = useCallback(() => {
dispatch(selectionChanged(selection.concat(imageDTOs)));
}, [dispatch, imageDTOs, selection]);

return (
<Flex alignItems="center" justifyContent="space-between">
<Flex>
{selection.length > 0 ? (
<Tag>
<TagLabel>
{selection.length} {t('common.selected')}
</TagLabel>
<Tooltip label="Clear selection">
<TagCloseButton onClick={onClickClearSelection} />
</Tooltip>
</Tag>
) : (
<Spacer />
)}

<Tooltip label={t('gallery.selectAllOnPage')}>
<IconButton
variant="outline"
size="sm"
icon={<BiSelectMultiple />}
aria-label="Bulk select"
onClick={onClickSelectAllPage}
/>
</Tooltip>
</Flex>

<GallerySearch />
</Flex>
);
};
@@ -0,0 +1,97 @@
import { Flex, IconButton, Input, InputGroup, InputRightElement, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { searchTermChanged } from 'features/gallery/store/gallerySlice';
import { motion } from 'framer-motion';
import { debounce } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiMagnifyingGlassBold, PiXBold } from 'react-icons/pi';

export const GallerySearch = () => {
const dispatch = useAppDispatch();
const { searchTerm } = useAppSelector((s) => s.gallery);
const { t } = useTranslation();

const [expanded, setExpanded] = useState(false);
const [searchTermInput, setSearchTermInput] = useState('');

const debouncedSetSearchTerm = useMemo(() => {
return debounce((value: string) => {
dispatch(searchTermChanged(value));
}, 1000);
}, [dispatch]);

const onChangeInput = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
setSearchTermInput(e.target.value);
debouncedSetSearchTerm(e.target.value);
},
[debouncedSetSearchTerm]
);

const onClearInput = useCallback(() => {
setSearchTermInput('');
debouncedSetSearchTerm('');
}, [debouncedSetSearchTerm]);

const toggleExpanded = useCallback((newState: boolean) => {
setExpanded(newState);
}, []);

return (
<Flex>
{!expanded && (
<Tooltip
label={
searchTerm && searchTerm.length ? `${t('gallery.searchingBy')} ${searchTerm}` : t('gallery.noActiveSearch')
}
>
<IconButton
aria-label="Close"
icon={<PiMagnifyingGlassBold />}
onClick={toggleExpanded.bind(null, true)}
variant="outline"
size="sm"
/>
</Tooltip>
)}
<motion.div
initial={false}
animate={{ width: expanded ? '200px' : '0px' }}
transition={{ duration: 0.3 }}
style={{ overflow: 'hidden' }}
>
<InputGroup size="sm">
<IconButton
aria-label="Close"
icon={<PiMagnifyingGlassBold />}
onClick={toggleExpanded.bind(null, false)}
variant="ghost"
size="sm"
/>

<Input
type="text"
placeholder="Search..."
size="sm"
variant="outline"
onChange={onChangeInput}
value={searchTermInput}
/>
{searchTermInput && searchTermInput.length && (
<InputRightElement h="full" pe={2}>
<IconButton
onClick={onClearInput}
size="sm"
variant="link"
aria-label={t('boards.clearSearch')}
icon={<PiXBold />}
/>
</InputRightElement>
)}
</InputGroup>
</motion.div>
</Flex>
);
};
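The search box above keeps its own input state for immediate feedback and only commits the term to the store through a 1000 ms lodash-es debounce, so the gallery list query re-fires at most once per pause in typing. A minimal standalone sketch of that pattern, assuming a hypothetical onSearch callback in place of the Redux dispatch:

import { debounce } from 'lodash-es';

// onSearch is a hypothetical stand-in for dispatch(searchTermChanged(value)).
export const makeDebouncedSearch = (onSearch: (term: string) => void, delayMs = 1000) => {
  // The caller keeps the raw input value for the UI; only the committed value
  // is debounced, so downstream queries fire at most once per delayMs.
  const commit = debounce(onSearch, delayMs);
  return {
    onChange: (value: string) => commit(value),
    clear: () => commit(''),
  };
};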
@@ -1,22 +1,22 @@
|
||||
import { Box, Button, ButtonGroup, Flex, Tab, TabList, Tabs, useDisclosure, VStack } from '@invoke-ai/ui-library';
|
||||
import { Box, Button, ButtonGroup, Flex, Tab, TabList, Tabs, useDisclosure } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $galleryHeader } from 'app/store/nanostores/galleryHeader';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { galleryViewChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiImagesBold } from 'react-icons/pi';
|
||||
import { RiServerLine } from 'react-icons/ri';
|
||||
|
||||
import BoardsList from './Boards/BoardsList/BoardsList';
|
||||
import GalleryBoardName from './GalleryBoardName';
|
||||
import { GalleryBulkSelect } from './GalleryBulkSelect';
|
||||
import GallerySettingsPopover from './GallerySettingsPopover';
|
||||
import GalleryImageGrid from './ImageGrid/GalleryImageGrid';
|
||||
import { GalleryPagination } from './ImageGrid/GalleryPagination';
|
||||
|
||||
const ImageGalleryContent = () => {
|
||||
const { t } = useTranslation();
|
||||
const resizeObserverRef = useRef<HTMLDivElement>(null);
|
||||
const galleryGridRef = useRef<HTMLDivElement>(null);
|
||||
const galleryView = useAppSelector((s) => s.gallery.galleryView);
|
||||
const dispatch = useAppDispatch();
|
||||
const galleryHeader = useStore($galleryHeader);
|
||||
@@ -31,10 +31,10 @@ const ImageGalleryContent = () => {
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<VStack layerStyle="first" flexDirection="column" h="full" w="full" borderRadius="base" p={2}>
|
||||
<Flex layerStyle="first" flexDirection="column" h="full" w="full" borderRadius="base" p={2} gap={2}>
|
||||
{galleryHeader}
|
||||
<Box w="full">
|
||||
<Flex ref={resizeObserverRef} alignItems="center" justifyContent="space-between" gap={2}>
|
||||
<Box>
|
||||
<Flex alignItems="center" justifyContent="space-between" gap={2}>
|
||||
<GalleryBoardName isOpen={isBoardListOpen} onToggle={onToggleBoardList} />
|
||||
<GallerySettingsPopover />
|
||||
</Flex>
|
||||
@@ -42,40 +42,41 @@ const ImageGalleryContent = () => {
|
||||
<BoardsList isOpen={isBoardListOpen} />
|
||||
</Box>
|
||||
</Box>
|
||||
<Flex ref={galleryGridRef} direction="column" gap={2} h="full" w="full">
|
||||
<Flex alignItems="center" justifyContent="space-between" gap={2}>
|
||||
<Tabs index={galleryView === 'images' ? 0 : 1} variant="unstyled" size="sm" w="full">
|
||||
<TabList>
|
||||
<ButtonGroup w="full">
|
||||
<Tab
|
||||
as={Button}
|
||||
size="sm"
|
||||
isChecked={galleryView === 'images'}
|
||||
onClick={handleClickImages}
|
||||
w="full"
|
||||
leftIcon={<PiImagesBold size="16px" />}
|
||||
data-testid="images-tab"
|
||||
>
|
||||
{t('parameters.images')}
|
||||
</Tab>
|
||||
<Tab
|
||||
as={Button}
|
||||
size="sm"
|
||||
isChecked={galleryView === 'assets'}
|
||||
onClick={handleClickAssets}
|
||||
w="full"
|
||||
leftIcon={<RiServerLine size="16px" />}
|
||||
data-testid="assets-tab"
|
||||
>
|
||||
{t('gallery.assets')}
|
||||
</Tab>
|
||||
</ButtonGroup>
|
||||
</TabList>
|
||||
</Tabs>
|
||||
</Flex>
|
||||
<GalleryImageGrid />
|
||||
<Flex alignItems="center" justifyContent="space-between" gap={2}>
|
||||
<Tabs index={galleryView === 'images' ? 0 : 1} variant="unstyled" size="sm" w="full">
|
||||
<TabList>
|
||||
<ButtonGroup w="full">
|
||||
<Tab
|
||||
as={Button}
|
||||
size="sm"
|
||||
isChecked={galleryView === 'images'}
|
||||
onClick={handleClickImages}
|
||||
w="full"
|
||||
leftIcon={<PiImagesBold size="16px" />}
|
||||
data-testid="images-tab"
|
||||
>
|
||||
{t('parameters.images')}
|
||||
</Tab>
|
||||
<Tab
|
||||
as={Button}
|
||||
size="sm"
|
||||
isChecked={galleryView === 'assets'}
|
||||
onClick={handleClickAssets}
|
||||
w="full"
|
||||
leftIcon={<RiServerLine size="16px" />}
|
||||
data-testid="assets-tab"
|
||||
>
|
||||
{t('gallery.assets')}
|
||||
</Tab>
|
||||
</ButtonGroup>
|
||||
</TabList>
|
||||
</Tabs>
|
||||
</Flex>
|
||||
</VStack>
|
||||
<GalleryBulkSelect />
|
||||
|
||||
<GalleryImageGrid />
|
||||
<GalleryPagination />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -16,13 +16,13 @@ import type { MouseEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiStarBold, PiStarFill, PiTrashSimpleFill } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery, useStarImagesMutation, useUnstarImagesMutation } from 'services/api/endpoints/images';
|
||||
import { useStarImagesMutation, useUnstarImagesMutation } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
// This class name is used to calculate the number of images that fit in the gallery
|
||||
export const GALLERY_IMAGE_CLASS_NAME = 'gallery-image';
|
||||
|
||||
const imageSx: SystemStyleObject = { w: 'full', h: 'full' };
|
||||
const imageIconStyleOverrides: SystemStyleObject = {
|
||||
bottom: 2,
|
||||
top: 'auto',
|
||||
};
|
||||
const boxSx: SystemStyleObject = {
|
||||
containerType: 'inline-size',
|
||||
};
|
||||
@@ -34,24 +34,22 @@ const badgeSx: SystemStyleObject = {
|
||||
};
|
||||
|
||||
interface HoverableImageProps {
|
||||
imageName: string;
|
||||
imageDTO: ImageDTO;
|
||||
index: number;
|
||||
}
|
||||
|
||||
const GalleryImage = (props: HoverableImageProps) => {
|
||||
const GalleryImage = ({ index, imageDTO }: HoverableImageProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { imageName } = props;
|
||||
const { currentData: imageDTO } = useGetImageDTOQuery(imageName);
|
||||
const shift = useShiftModifier();
|
||||
const { t } = useTranslation();
|
||||
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
||||
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
|
||||
const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageName);
|
||||
const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageDTO.image_name);
|
||||
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
|
||||
|
||||
const customStarUi = useStore($customStarUI);
|
||||
|
||||
const imageContainerRef = useScrollIntoView(isSelected, props.index, areMultiplesSelected);
|
||||
const imageContainerRef = useScrollIntoView(isSelected, index, areMultiplesSelected);
|
||||
|
||||
const handleDelete = useCallback(
|
||||
(e: MouseEvent<HTMLButtonElement>) => {
|
||||
@@ -114,32 +112,32 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
}, []);
|
||||
|
||||
const starIcon = useMemo(() => {
|
||||
if (imageDTO?.starred) {
|
||||
if (imageDTO.starred) {
|
||||
return customStarUi ? customStarUi.on.icon : <PiStarFill size="20" />;
|
||||
}
|
||||
if (!imageDTO?.starred && isHovered) {
|
||||
if (!imageDTO.starred && isHovered) {
|
||||
return customStarUi ? customStarUi.off.icon : <PiStarBold size="20" />;
|
||||
}
|
||||
}, [imageDTO?.starred, isHovered, customStarUi]);
|
||||
}, [imageDTO.starred, isHovered, customStarUi]);
|
||||
|
||||
const starTooltip = useMemo(() => {
|
||||
if (imageDTO?.starred) {
|
||||
if (imageDTO.starred) {
|
||||
return customStarUi ? customStarUi.off.text : 'Unstar';
|
||||
}
|
||||
if (!imageDTO?.starred) {
|
||||
if (!imageDTO.starred) {
|
||||
return customStarUi ? customStarUi.on.text : 'Star';
|
||||
}
|
||||
return '';
|
||||
}, [imageDTO?.starred, customStarUi]);
|
||||
}, [imageDTO.starred, customStarUi]);
|
||||
|
||||
const dataTestId = useMemo(() => getGalleryImageDataTestId(imageDTO?.image_name), [imageDTO?.image_name]);
|
||||
const dataTestId = useMemo(() => getGalleryImageDataTestId(imageDTO.image_name), [imageDTO.image_name]);
|
||||
|
||||
if (!imageDTO) {
|
||||
return <IAIFillSkeleton />;
|
||||
}
|
||||
|
||||
return (
|
||||
<Box w="full" h="full" className="gallerygrid-image" data-testid={dataTestId} sx={boxSx}>
|
||||
<Box w="full" h="full" p={1.5} className={GALLERY_IMAGE_CLASS_NAME} data-testid={dataTestId} sx={boxSx}>
|
||||
<Flex
|
||||
ref={imageContainerRef}
|
||||
userSelect="none"
|
||||
@@ -183,14 +181,23 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
pointerEvents="none"
|
||||
>{`${imageDTO.width}x${imageDTO.height}`}</Text>
|
||||
)}
|
||||
<IAIDndImageIcon onClick={toggleStarredState} icon={starIcon} tooltip={starTooltip} />
|
||||
<IAIDndImageIcon
|
||||
onClick={toggleStarredState}
|
||||
icon={starIcon}
|
||||
tooltip={starTooltip}
|
||||
position="absolute"
|
||||
top={1}
|
||||
insetInlineEnd={1}
|
||||
/>
|
||||
|
||||
{isHovered && shift && (
|
||||
<IAIDndImageIcon
|
||||
onClick={handleDelete}
|
||||
icon={<PiTrashSimpleFill size="16px" />}
|
||||
tooltip={t('gallery.deleteImage', { count: 1 })}
|
||||
styleOverrides={imageIconStyleOverrides}
|
||||
tooltip={t('gallery.deleteImage_one')}
|
||||
position="absolute"
|
||||
bottom={1}
|
||||
insetInlineEnd={1}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
|
||||
@@ -1,120 +1,32 @@
|
||||
import { Box, Button, Flex } from '@invoke-ai/ui-library';
|
||||
import type { EntityId } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { Box, Flex, Grid } from '@invoke-ai/ui-library';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
|
||||
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
|
||||
import { useGalleryHotkeys } from 'features/gallery/hooks/useGalleryHotkeys';
|
||||
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
|
||||
import { useOverlayScrollbars } from 'overlayscrollbars-react';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { limitChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { debounce } from 'lodash-es';
|
||||
import { memo, useEffect, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiImageBold, PiWarningCircleBold } from 'react-icons/pi';
|
||||
import type { GridComponents, ItemContent, ListRange, VirtuosoGridHandle } from 'react-virtuoso';
|
||||
import { VirtuosoGrid } from 'react-virtuoso';
|
||||
import { useBoardTotal } from 'services/api/hooks/useBoardTotal';
|
||||
import { useListImagesQuery } from 'services/api/endpoints/images';
|
||||
|
||||
import GalleryImage from './GalleryImage';
|
||||
import ImageGridItemContainer from './ImageGridItemContainer';
|
||||
import ImageGridListContainer from './ImageGridListContainer';
|
||||
|
||||
const components: GridComponents = {
|
||||
Item: ImageGridItemContainer,
|
||||
List: ImageGridListContainer,
|
||||
};
|
||||
|
||||
const virtuosoStyles: CSSProperties = { height: '100%' };
|
||||
import { GALLERY_GRID_CLASS_NAME } from './constants';
|
||||
import GalleryImage, { GALLERY_IMAGE_CLASS_NAME } from './GalleryImage';
|
||||
|
||||
const GalleryImageGrid = () => {
|
||||
const { t } = useTranslation();
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
const [scroller, setScroller] = useState<HTMLElement | null>(null);
|
||||
const [initialize, osInstance] = useOverlayScrollbars(overlayScrollbarsParams);
|
||||
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
||||
const { currentViewTotal } = useBoardTotal(selectedBoardId);
|
||||
const virtuosoRangeRef = useRef<ListRange | null>(null);
|
||||
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
|
||||
const {
|
||||
areMoreImagesAvailable,
|
||||
handleLoadMoreImages,
|
||||
queryResult: { currentData, isFetching, isSuccess, isError },
|
||||
} = useGalleryImages();
|
||||
useGalleryHotkeys();
|
||||
const itemContentFunc: ItemContent<EntityId, void> = useCallback(
|
||||
(index, imageName) => <GalleryImage key={imageName} index={index} imageName={imageName as string} />,
|
||||
[]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
// Initialize the gallery's custom scrollbar
|
||||
const { current: root } = rootRef;
|
||||
if (scroller && root) {
|
||||
initialize({
|
||||
target: root,
|
||||
elements: {
|
||||
viewport: scroller,
|
||||
},
|
||||
});
|
||||
}
|
||||
return () => osInstance()?.destroy();
|
||||
}, [scroller, initialize, osInstance]);
|
||||
|
||||
const onRangeChanged = useCallback((range: ListRange) => {
|
||||
virtuosoRangeRef.current = range;
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
virtuosoGridRefs.set({ rootRef, virtuosoRangeRef, virtuosoRef });
|
||||
return () => {
|
||||
virtuosoGridRefs.set({});
|
||||
};
|
||||
}, []);
|
||||
|
||||
if (!currentData) {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<IAINoContentFallback label={t('gallery.loading')} icon={PiImageBold} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
if (isSuccess && currentData?.ids.length === 0) {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<IAINoContentFallback label={t('gallery.noImagesInGallery')} icon={PiImageBold} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
if (isSuccess && currentData) {
|
||||
return (
|
||||
<>
|
||||
<Box ref={rootRef} data-overlayscrollbars="" h="100%" id="gallery-grid">
|
||||
<VirtuosoGrid
|
||||
style={virtuosoStyles}
|
||||
data={currentData.ids}
|
||||
endReached={handleLoadMoreImages}
|
||||
components={components}
|
||||
scrollerRef={setScroller}
|
||||
itemContent={itemContentFunc}
|
||||
ref={virtuosoRef}
|
||||
rangeChanged={onRangeChanged}
|
||||
overscan={10}
|
||||
/>
|
||||
</Box>
|
||||
<Button
|
||||
onClick={handleLoadMoreImages}
|
||||
isDisabled={!areMoreImagesAvailable}
|
||||
isLoading={isFetching}
|
||||
loadingText={t('gallery.loading')}
|
||||
flexShrink={0}
|
||||
>
|
||||
{`${t('accessibility.loadMore')} (${currentData.ids.length} / ${currentViewTotal})`}
|
||||
</Button>
|
||||
</>
|
||||
);
|
||||
}
|
||||
const { t } = useTranslation();
|
||||
const queryArgs = useAppSelector(selectListImagesQueryArgs);
|
||||
const { imageDTOs, isLoading, isError, isFetching } = useListImagesQuery(queryArgs, {
|
||||
selectFromResult: ({ data, isLoading, isSuccess, isError, isFetching }) => ({
|
||||
imageDTOs: data?.items ?? EMPTY_ARRAY,
|
||||
isLoading,
|
||||
isSuccess,
|
||||
isError,
|
||||
isFetching,
|
||||
}),
|
||||
});
|
||||
|
||||
if (isError) {
|
||||
return (
|
||||
@@ -124,7 +36,115 @@ const GalleryImageGrid = () => {
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
if (isLoading || isFetching) {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<IAINoContentFallback label={t('gallery.loading')} icon={PiImageBold} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
if (imageDTOs.length === 0) {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<IAINoContentFallback label={t('gallery.noImagesInGallery')} icon={PiImageBold} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return <Content />;
|
||||
};
|
||||
|
||||
export default memo(GalleryImageGrid);
|
||||
|
||||
const Content = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);
|
||||
|
||||
const queryArgs = useAppSelector(selectListImagesQueryArgs);
|
||||
const { imageDTOs } = useListImagesQuery(queryArgs, {
|
||||
selectFromResult: ({ data }) => ({ imageDTOs: data?.items ?? EMPTY_ARRAY }),
|
||||
});
|
||||
// Use a callback ref to get reactivity on the container element because it is conditionally rendered
|
||||
const [container, containerRef] = useState<HTMLDivElement | null>(null);
|
||||
|
||||
const calculateNewLimit = useMemo(() => {
|
||||
// Debounce this to not thrash the API
|
||||
return debounce(() => {
|
||||
if (!container) {
|
||||
// Container not rendered yet
|
||||
return;
|
||||
}
|
||||
// Managing refs for dynamically rendered components is a bit tedious:
|
||||
// - https://react.dev/learn/manipulating-the-dom-with-refs#how-to-manage-a-list-of-refs-using-a-ref-callback
|
||||
// As a easy workaround, we can just grab the first gallery image element directly.
|
||||
const galleryImageEl = document.querySelector(`.${GALLERY_IMAGE_CLASS_NAME}`);
|
||||
if (!galleryImageEl) {
|
||||
// No images in gallery?
|
||||
return;
|
||||
}
|
||||
|
||||
const galleryImageRect = galleryImageEl.getBoundingClientRect();
|
||||
const containerRect = container.getBoundingClientRect();
|
||||
|
||||
if (!galleryImageRect.width || !galleryImageRect.height || !containerRect.width || !containerRect.height) {
|
||||
// Gallery is too small to fit images or not rendered yet
|
||||
return;
|
||||
}
|
||||
|
||||
// Floating-point precision requires we round to get the correct number of images per row
|
||||
const imagesPerRow = Math.round(containerRect.width / galleryImageRect.width);
|
||||
// However, when calculating the number of images per column, we want to floor the value to not overflow the container
|
||||
const imagesPerColumn = Math.floor(containerRect.height / galleryImageRect.height);
|
||||
// Always load at least 1 row of images
|
||||
const limit = Math.max(imagesPerRow, imagesPerRow * imagesPerColumn);
|
||||
dispatch(limitChanged(limit));
|
||||
}, 300);
|
||||
}, [container, dispatch]);
|
||||
|
||||
useEffect(() => {
|
||||
// We want to recalculate the limit when image size changes
|
||||
calculateNewLimit();
|
||||
}, [calculateNewLimit, galleryImageMinimumWidth]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
|
||||
const resizeObserver = new ResizeObserver(calculateNewLimit);
|
||||
resizeObserver.observe(container);
|
||||
|
||||
// First render
|
||||
calculateNewLimit();
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
}, [calculateNewLimit, container, dispatch]);
|
||||
|
||||
return (
|
||||
<Box position="relative" w="full" h="full">
|
||||
<Box
|
||||
ref={containerRef}
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
w="full"
|
||||
h="full"
|
||||
overflow="hidden"
|
||||
>
|
||||
<Grid
|
||||
className={GALLERY_GRID_CLASS_NAME}
|
||||
gridTemplateColumns={`repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr))`}
|
||||
>
|
||||
{imageDTOs.map((imageDTO, index) => (
|
||||
<GalleryImage key={imageDTO.image_name} imageDTO={imageDTO} index={index} />
|
||||
))}
|
||||
</Grid>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,73 @@
import { Button, Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { PiCaretDoubleLeftBold, PiCaretDoubleRightBold, PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';

export const GalleryPagination = () => {
const {
goPrev,
goNext,
goToFirst,
goToLast,
isFirstEnabled,
isLastEnabled,
isPrevEnabled,
isNextEnabled,
pageButtons,
goToPage,
currentPage,
rangeDisplay,
total,
} = useGalleryPagination();

if (!total) {
return <Flex flexDir="column" alignItems="center" gap="2" height="48px"></Flex>;
}

return (
<Flex flexDir="column" alignItems="center" gap="2" height="48px">
<Flex gap={2} alignItems="center" w="full">
<IconButton
size="sm"
aria-label="prev"
icon={<PiCaretDoubleLeftBold />}
onClick={goToFirst}
isDisabled={!isFirstEnabled}
/>
<IconButton
size="sm"
aria-label="prev"
icon={<PiCaretLeftBold />}
onClick={goPrev}
isDisabled={!isPrevEnabled}
/>
<Spacer />
{pageButtons.map((page) => (
<Button
size="sm"
key={page}
onClick={goToPage.bind(null, page)}
variant={currentPage === page ? 'solid' : 'outline'}
>
{page + 1}
</Button>
))}
<Spacer />
<IconButton
size="sm"
aria-label="next"
icon={<PiCaretRightBold />}
onClick={goNext}
isDisabled={!isNextEnabled}
/>
<IconButton
size="sm"
aria-label="next"
icon={<PiCaretDoubleRightBold />}
onClick={goToLast}
isDisabled={!isLastEnabled}
/>
</Flex>
<Text>{rangeDisplay}</Text>
</Flex>
);
};
@@ -1,15 +0,0 @@
import type { FlexProps } from '@invoke-ai/ui-library';
import { Box, forwardRef } from '@invoke-ai/ui-library';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';

export const imageItemContainerTestId = 'image-item-container';

type ItemContainerProps = PropsWithChildren & FlexProps;
const ItemContainer = forwardRef((props: ItemContainerProps, ref) => (
<Box className="item-container" ref={ref} p={1.5} data-testid={imageItemContainerTestId}>
{props.children}
</Box>
));

export default memo(ItemContainer);
@@ -1,26 +0,0 @@
import type { FlexProps } from '@invoke-ai/ui-library';
import { forwardRef, Grid } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';

export const imageListContainerTestId = 'image-list-container';

type ListContainerProps = PropsWithChildren & FlexProps;
const ListContainer = forwardRef((props: ListContainerProps, ref) => {
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);

return (
<Grid
{...props}
className="list-container"
ref={ref}
gridTemplateColumns={`repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr))`}
data-testid={imageListContainerTestId}
>
{props.children}
</Grid>
);
});

export default memo(ListContainer);
@@ -0,0 +1 @@
export const GALLERY_GRID_CLASS_NAME = 'gallery-grid';
@@ -2,6 +2,7 @@ import type { ChakraProps } from '@invoke-ai/ui-library';
import { Box, Flex, IconButton, Spinner } from '@invoke-ai/ui-library';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { useGalleryNavigation } from 'features/gallery/hooks/useGalleryNavigation';
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDoubleRightBold, PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
@@ -16,11 +17,8 @@ const NextPrevImageButtons = () => {

const { prevImage, nextImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();

const {
areMoreImagesAvailable,
handleLoadMoreImages,
queryResult: { isFetching },
} = useGalleryImages();
const { isFetching } = useGalleryImages().queryResult;
const { isNextEnabled, goNext } = useGalleryPagination();

return (
<Box pos="relative" h="full" w="full">
@@ -47,17 +45,17 @@ const NextPrevImageButtons = () => {
sx={nextPrevButtonStyles}
/>
)}
{isOnLastImage && areMoreImagesAvailable && !isFetching && (
{isOnLastImage && isNextEnabled && !isFetching && (
<IconButton
aria-label={t('accessibility.loadMore')}
icon={<PiCaretDoubleRightBold size={64} />}
variant="unstyled"
onClick={handleLoadMoreImages}
onClick={goNext}
boxSize={16}
sx={nextPrevButtonStyles}
/>
)}
{isOnLastImage && areMoreImagesAvailable && isFetching && (
{isOnLastImage && isNextEnabled && isFetching && (
<Flex w={16} h={16} alignItems="center" justifyContent="center">
<Spinner opacity={0.5} size="xl" />
</Flex>

@@ -1,10 +1,12 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
|
||||
import { useGalleryNavigation } from 'features/gallery/hooks/useGalleryNavigation';
|
||||
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useMemo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useListImagesQuery } from 'services/api/endpoints/images';
|
||||
|
||||
/**
|
||||
* Registers gallery hotkeys. This hook is a singleton.
|
||||
@@ -17,21 +19,30 @@ export const useGalleryHotkeys = () => {
|
||||
return activeTabName !== 'canvas' || !isStaging;
|
||||
}, [activeTabName, isStaging]);
|
||||
|
||||
const {
|
||||
areMoreImagesAvailable,
|
||||
handleLoadMoreImages,
|
||||
queryResult: { isFetching },
|
||||
} = useGalleryImages();
|
||||
const { goNext, goPrev, isNextEnabled, isPrevEnabled } = useGalleryPagination();
|
||||
const queryArgs = useAppSelector(selectListImagesQueryArgs);
|
||||
const queryResult = useListImagesQuery(queryArgs);
|
||||
|
||||
const { handleLeftImage, handleRightImage, handleUpImage, handleDownImage, isOnLastImage, areImagesBelowCurrent } =
|
||||
useGalleryNavigation();
|
||||
const {
|
||||
handleLeftImage,
|
||||
handleRightImage,
|
||||
handleUpImage,
|
||||
handleDownImage,
|
||||
areImagesBelowCurrent,
|
||||
isOnFirstImageOfView,
|
||||
isOnLastImageOfView,
|
||||
} = useGalleryNavigation();
|
||||
|
||||
useHotkeys(
|
||||
['left', 'alt+left'],
|
||||
(e) => {
|
||||
if (isOnFirstImageOfView && isPrevEnabled && !queryResult.isFetching) {
|
||||
goPrev();
|
||||
return;
|
||||
}
|
||||
canNavigateGallery && handleLeftImage(e.altKey);
|
||||
},
|
||||
[handleLeftImage, canNavigateGallery]
|
||||
[handleLeftImage, canNavigateGallery, isOnFirstImageOfView, goPrev, isPrevEnabled, queryResult.isFetching]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
@@ -40,15 +51,15 @@ export const useGalleryHotkeys = () => {
|
||||
if (!canNavigateGallery) {
|
||||
return;
|
||||
}
|
||||
if (isOnLastImage && areMoreImagesAvailable && !isFetching) {
|
||||
handleLoadMoreImages();
|
||||
if (isOnLastImageOfView && isNextEnabled && !queryResult.isFetching) {
|
||||
goNext();
|
||||
return;
|
||||
}
|
||||
if (!isOnLastImage) {
|
||||
if (!isOnLastImageOfView) {
|
||||
handleRightImage(e.altKey);
|
||||
}
|
||||
},
|
||||
[isOnLastImage, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleRightImage, canNavigateGallery]
|
||||
[isOnLastImageOfView, goNext, isNextEnabled, queryResult.isFetching, handleRightImage, canNavigateGallery]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
@@ -63,13 +74,13 @@ export const useGalleryHotkeys = () => {
|
||||
useHotkeys(
|
||||
['down', 'alt+down'],
|
||||
(e) => {
|
||||
if (!areImagesBelowCurrent && areMoreImagesAvailable && !isFetching) {
|
||||
handleLoadMoreImages();
|
||||
if (!areImagesBelowCurrent && isNextEnabled && !queryResult.isFetching) {
|
||||
goNext();
|
||||
return;
|
||||
}
|
||||
handleDownImage(e.altKey);
|
||||
},
|
||||
{ preventDefault: true },
|
||||
[areImagesBelowCurrent, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleDownImage]
|
||||
[areImagesBelowCurrent, goNext, isNextEnabled, queryResult.isFetching, handleDownImage]
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,38 +1,15 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { moreImagesLoaded } from 'features/gallery/store/gallerySlice';
import { useCallback, useMemo } from 'react';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
import { useMemo } from 'react';
import { useListImagesQuery } from 'services/api/endpoints/images';

/**
 * Provides access to the gallery images and a way to imperatively fetch more.
 */
export const useGalleryImages = () => {
const dispatch = useAppDispatch();
const galleryView = useAppSelector((s) => s.gallery.galleryView);
const queryArgs = useAppSelector(selectListImagesQueryArgs);
const queryResult = useListImagesQuery(queryArgs);
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(selectedBoardId);
const { data: imagesTotal } = useGetBoardImagesTotalQuery(selectedBoardId);
const currentViewTotal = useMemo(
() => (galleryView === 'images' ? imagesTotal?.total : assetsTotal?.total),
[assetsTotal?.total, galleryView, imagesTotal?.total]
);
const areMoreImagesAvailable = useMemo(() => {
if (!currentViewTotal || !queryResult.data) {
return false;
}
return queryResult.data.ids.length < currentViewTotal;
}, [queryResult.data, currentViewTotal]);
const handleLoadMoreImages = useCallback(() => {
dispatch(moreImagesLoaded());
}, [dispatch]);

const imageDTOs = useMemo(() => queryResult.data?.items ?? EMPTY_ARRAY, [queryResult.data]);
return {
areMoreImagesAvailable,
handleLoadMoreImages,
imageDTOs,
queryResult,
};
};

@@ -1,8 +1,8 @@
|
||||
import { useAltModifier } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { GALLERY_GRID_CLASS_NAME } from 'features/gallery/components/ImageGrid/constants';
|
||||
import { GALLERY_IMAGE_CLASS_NAME } from 'features/gallery/components/ImageGrid/GalleryImage';
|
||||
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
|
||||
import { imageItemContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridItemContainer';
|
||||
import { imageListContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridListContainer';
|
||||
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
|
||||
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
|
||||
import { imageSelected, imageToCompareChanged } from 'features/gallery/store/gallerySlice';
|
||||
@@ -11,7 +11,6 @@ import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAli
|
||||
import { clamp } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
/**
|
||||
* This hook is used to navigate the gallery using the arrow keys.
|
||||
@@ -29,10 +28,9 @@ import { imagesSelectors } from 'services/api/util';
|
||||
*/
|
||||
const getImagesPerRow = (): number => {
|
||||
const widthOfGalleryImage =
|
||||
document.querySelector(`[data-testid="${imageItemContainerTestId}"]`)?.getBoundingClientRect().width ?? 1;
|
||||
document.querySelector(`.${GALLERY_IMAGE_CLASS_NAME}`)?.getBoundingClientRect().width ?? 1;
|
||||
|
||||
const widthOfGalleryGrid =
|
||||
document.querySelector(`[data-testid="${imageListContainerTestId}"]`)?.getBoundingClientRect().width ?? 0;
|
||||
const widthOfGalleryGrid = document.querySelector(`.${GALLERY_GRID_CLASS_NAME}`)?.getBoundingClientRect().width ?? 0;
|
||||
|
||||
const imagesPerRow = Math.round(widthOfGalleryGrid / widthOfGalleryImage);
|
||||
|
||||
@@ -115,6 +113,8 @@ type UseGalleryNavigationReturn = {
|
||||
isOnFirstImage: boolean;
|
||||
isOnLastImage: boolean;
|
||||
areImagesBelowCurrent: boolean;
|
||||
isOnFirstImageOfView: boolean;
|
||||
isOnLastImageOfView: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -134,23 +134,19 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
return lastSelected;
|
||||
}
|
||||
});
|
||||
const {
|
||||
queryResult: { data },
|
||||
} = useGalleryImages();
|
||||
const loadedImagesCount = useMemo(() => data?.ids.length ?? 0, [data?.ids.length]);
|
||||
const { imageDTOs } = useGalleryImages();
|
||||
const loadedImagesCount = useMemo(() => imageDTOs.length, [imageDTOs.length]);
|
||||
|
||||
const lastSelectedImageIndex = useMemo(() => {
|
||||
if (!data || !lastSelectedImage) {
|
||||
if (imageDTOs.length === 0 || !lastSelectedImage) {
|
||||
return 0;
|
||||
}
|
||||
return imagesSelectors.selectAll(data).findIndex((i) => i.image_name === lastSelectedImage.image_name);
|
||||
}, [lastSelectedImage, data]);
|
||||
return imageDTOs.findIndex((i) => i.image_name === lastSelectedImage.image_name);
|
||||
}, [imageDTOs, lastSelectedImage]);
|
||||
|
||||
const handleNavigation = useCallback(
|
||||
(direction: 'left' | 'right' | 'up' | 'down', alt?: boolean) => {
|
||||
if (!data) {
|
||||
return;
|
||||
}
|
||||
const { index, image } = getImageFuncs[direction](imagesSelectors.selectAll(data), lastSelectedImageIndex);
|
||||
const { index, image } = getImageFuncs[direction](imageDTOs, lastSelectedImageIndex);
|
||||
if (!image || index === lastSelectedImageIndex) {
|
||||
return;
|
||||
}
|
||||
@@ -161,7 +157,7 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
}
|
||||
scrollToImage(image.image_name, index);
|
||||
},
|
||||
[data, lastSelectedImageIndex, dispatch]
|
||||
[imageDTOs, lastSelectedImageIndex, dispatch]
|
||||
);
|
||||
|
||||
const isOnFirstImage = useMemo(() => lastSelectedImageIndex === 0, [lastSelectedImageIndex]);
|
||||
@@ -176,6 +172,14 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
return lastSelectedImageIndex + imagesPerRow < loadedImagesCount;
|
||||
}, [lastSelectedImageIndex, loadedImagesCount]);
|
||||
|
||||
const isOnFirstImageOfView = useMemo(() => {
|
||||
return lastSelectedImageIndex === 0;
|
||||
}, [lastSelectedImageIndex]);
|
||||
|
||||
const isOnLastImageOfView = useMemo(() => {
|
||||
return lastSelectedImageIndex === loadedImagesCount - 1;
|
||||
}, [lastSelectedImageIndex, loadedImagesCount]);
|
||||
|
||||
const handleLeftImage = useCallback(
|
||||
(alt?: boolean) => {
|
||||
handleNavigation('left', alt);
|
||||
@@ -222,5 +226,7 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
areImagesBelowCurrent,
|
||||
nextImage,
|
||||
prevImage,
|
||||
isOnFirstImageOfView,
|
||||
isOnLastImageOfView,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -0,0 +1,131 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { offsetChanged } from 'features/gallery/store/gallerySlice';
import { useCallback, useEffect, useMemo } from 'react';
import { useListImagesQuery } from 'services/api/endpoints/images';

export const useGalleryPagination = (pageButtonsPerSide: number = 2) => {
const dispatch = useAppDispatch();
const { offset, limit } = useAppSelector((s) => s.gallery);
const queryArgs = useAppSelector(selectListImagesQueryArgs);

const { count, total } = useListImagesQuery(queryArgs, {
selectFromResult: ({ data }) => ({ count: data?.items.length ?? 0, total: data?.total ?? 0 }),
});

const currentPage = useMemo(() => Math.ceil(offset / (limit || 0)), [offset, limit]);
const pages = useMemo(() => Math.ceil(total / (limit || 0)), [total, limit]);

const isNextEnabled = useMemo(() => {
if (!count) {
return false;
}
return currentPage + 1 < pages;
}, [count, currentPage, pages]);
const isPrevEnabled = useMemo(() => {
if (!count) {
return false;
}
return offset > 0;
}, [count, offset]);

const goNext = useCallback(() => {
dispatch(offsetChanged(offset + (limit || 0)));
}, [dispatch, offset, limit]);

const goPrev = useCallback(() => {
dispatch(offsetChanged(Math.max(offset - (limit || 0), 0)));
}, [dispatch, offset, limit]);

const goToPage = useCallback(
(page: number) => {
dispatch(offsetChanged(page * (limit || 0)));
},
[dispatch, limit]
);
const goToFirst = useCallback(() => {
dispatch(offsetChanged(0));
}, [dispatch]);
const goToLast = useCallback(() => {
dispatch(offsetChanged((pages - 1) * (limit || 0)));
}, [dispatch, pages, limit]);

// handle when total/pages decrease and user is on high page number (ie bulk removing or deleting)
useEffect(() => {
if (pages && currentPage + 1 > pages) {
goToLast();
}
}, [currentPage, pages, goToLast]);

// calculate the page buttons to display - current page with 3 around it
const pageButtons = useMemo(() => {
const buttons = [];
const maxPageButtons = pageButtonsPerSide * 2 + 1;
let startPage = Math.max(currentPage - Math.floor(maxPageButtons / 2), 0);
const endPage = Math.min(startPage + maxPageButtons - 1, pages - 1);

if (endPage - startPage < maxPageButtons - 1) {
startPage = Math.max(endPage - maxPageButtons + 1, 0);
}

for (let i = startPage; i <= endPage; i++) {
buttons.push(i);
}

return buttons;
}, [currentPage, pageButtonsPerSide, pages]);

const isFirstEnabled = useMemo(() => currentPage > 0, [currentPage]);
const isLastEnabled = useMemo(() => currentPage < pages - 1, [currentPage, pages]);

const rangeDisplay = useMemo(() => {
const startItem = currentPage * (limit || 0) + 1;
const endItem = Math.min((currentPage + 1) * (limit || 0), total);
return `${startItem}-${endItem} of ${total}`;
}, [total, currentPage, limit]);

const numberOnPage = useMemo(() => {
return Math.min((currentPage + 1) * (limit || 0), total);
}, [currentPage, limit, total]);

const api = useMemo(
() => ({
count,
total,
currentPage,
pages,
isNextEnabled,
isPrevEnabled,
goNext,
goPrev,
goToPage,
goToFirst,
goToLast,
pageButtons,
isFirstEnabled,
isLastEnabled,
rangeDisplay,
numberOnPage,
}),
[
count,
total,
currentPage,
pages,
isNextEnabled,
isPrevEnabled,
goNext,
goPrev,
goToPage,
goToFirst,
goToLast,
pageButtons,
isFirstEnabled,
isLastEnabled,
rangeDisplay,
numberOnPage,
]
);

return api;
};
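A worked example of the offset/limit arithmetic in the hook above, using illustrative numbers that are not taken from the diff:

// Illustrative values only; same arithmetic as useGalleryPagination.
const limit = 20;
const total = 45;
const offset = 40;

const currentPage = Math.ceil(offset / limit); // 2 (zero-based third page)
const pages = Math.ceil(total / limit); // 3
const startItem = currentPage * limit + 1; // 41
const endItem = Math.min((currentPage + 1) * limit, total); // 45
console.log(`${startItem}-${endItem} of ${total}`); // "41-45 of 45"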
@@ -1,3 +1,5 @@
import type { SkipToken } from '@reduxjs/toolkit/query';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
@@ -10,11 +12,15 @@ export const selectLastSelectedImage = createMemoizedSelector(

export const selectListImagesQueryArgs = createMemoizedSelector(
selectGallerySlice,
(gallery): ListImagesArgs => ({
board_id: gallery.selectedBoardId,
categories: gallery.galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES,
offset: gallery.offset,
limit: gallery.limit,
is_intermediate: false,
})
(gallery): ListImagesArgs | SkipToken =>
gallery.limit
? {
board_id: gallery.selectedBoardId,
categories: gallery.galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES,
offset: gallery.offset,
limit: gallery.limit,
is_intermediate: false,
search_term: gallery.searchTerm,
}
: skipToken
);

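The selector above returns skipToken until the gallery has a limit, which RTK Query treats as "do not run the query yet"; once the grid measures its container and dispatches limitChanged, the args become concrete and the request fires with the current board, view, page, and search term. A small sketch of consuming it, assuming a hypothetical hook name (the imports are the ones used elsewhere in this diff):

import { useAppSelector } from 'app/store/storeHooks';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { useListImagesQuery } from 'services/api/endpoints/images';

// Hypothetical consumer: while the selector yields skipToken, RTK Query makes
// no request and data stays undefined, so count falls back to 0.
export const useGalleryPageItemCount = () => {
  const queryArgs = useAppSelector(selectListImagesQueryArgs);
  const { count } = useListImagesQuery(queryArgs, {
    selectFromResult: ({ data }) => ({ count: data?.items.length ?? 0 }),
  });
  return count;
};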
@@ -7,7 +7,7 @@ import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';

import type { BoardId, ComparisonMode, GalleryState, GalleryView } from './types';
import { IMAGE_LIMIT, INITIAL_IMAGE_LIMIT } from './types';
import { IMAGE_LIMIT } from './types';

const initialGalleryState: GalleryState = {
selection: [],
@@ -19,7 +19,7 @@ const initialGalleryState: GalleryState = {
selectedBoardId: 'none',
galleryView: 'images',
boardSearchText: '',
limit: INITIAL_IMAGE_LIMIT,
limit: 20,
offset: 0,
isImageViewerOpen: true,
imageToCompare: null,
@@ -72,7 +72,6 @@ export const gallerySlice = createSlice({
state.selectedBoardId = action.payload.boardId;
state.galleryView = 'images';
state.offset = 0;
state.limit = INITIAL_IMAGE_LIMIT;
},
autoAddBoardIdChanged: (state, action: PayloadAction<BoardId>) => {
if (!action.payload) {
@@ -84,20 +83,11 @@ export const gallerySlice = createSlice({
galleryViewChanged: (state, action: PayloadAction<GalleryView>) => {
state.galleryView = action.payload;
state.offset = 0;
state.limit = INITIAL_IMAGE_LIMIT;
state.limit = IMAGE_LIMIT;
},
boardSearchTextChanged: (state, action: PayloadAction<string>) => {
state.boardSearchText = action.payload;
},
moreImagesLoaded: (state) => {
if (state.offset === 0 && state.limit === INITIAL_IMAGE_LIMIT) {
state.offset = INITIAL_IMAGE_LIMIT;
state.limit = IMAGE_LIMIT;
} else {
state.offset += IMAGE_LIMIT;
state.limit += IMAGE_LIMIT;
}
},
alwaysShowImageSizeBadgeChanged: (state, action: PayloadAction<boolean>) => {
state.alwaysShowImageSizeBadge = action.payload;
},
@@ -114,6 +104,15 @@ export const gallerySlice = createSlice({
comparisonFitChanged: (state, action: PayloadAction<'contain' | 'fill'>) => {
state.comparisonFit = action.payload;
},
offsetChanged: (state, action: PayloadAction<number>) => {
state.offset = action.payload;
},
limitChanged: (state, action: PayloadAction<number>) => {
state.limit = action.payload;
},
searchTermChanged: (state, action: PayloadAction<string | undefined>) => {
state.searchTerm = action.payload;
},
},
extraReducers: (builder) => {
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
@@ -149,7 +148,6 @@ export const {
galleryViewChanged,
selectionChanged,
boardSearchTextChanged,
moreImagesLoaded,
alwaysShowImageSizeBadgeChanged,
isImageViewerOpenChanged,
imageToCompareChanged,
@@ -157,6 +155,9 @@ export const {
comparedImagesSwapped,
comparisonFitChanged,
comparisonModeCycled,
offsetChanged,
limitChanged,
searchTermChanged,
} = gallerySlice.actions;

const isAnyBoardDeleted = isAnyOf(

@@ -2,8 +2,7 @@ import type { ImageCategory, ImageDTO } from 'services/api/types';

export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
export const ASSETS_CATEGORIES: ImageCategory[] = ['control', 'mask', 'user', 'other'];
export const INITIAL_IMAGE_LIMIT = 100;
export const IMAGE_LIMIT = 20;
export const IMAGE_LIMIT = 15;

export type GalleryView = 'images' | 'assets';
export type BoardId = 'none' | (string & Record<never, never>);
@@ -21,6 +20,7 @@ export type GalleryState = {
boardSearchText: string;
offset: number;
limit: number;
searchTerm?: string;
alwaysShowImageSizeBadge: boolean;
imageToCompare: ImageDTO | null;
comparisonMode: ComparisonMode;

File diff suppressed because it is too large
@@ -1,18 +0,0 @@
import { useAppSelector } from 'app/store/storeHooks';
import type { BoardId } from 'features/gallery/store/types';
import { useMemo } from 'react';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';

export const useBoardTotal = (board_id: BoardId) => {
const galleryView = useAppSelector((s) => s.gallery.galleryView);

const { data: totalImages } = useGetBoardImagesTotalQuery(board_id);
const { data: totalAssets } = useGetBoardAssetsTotalQuery(board_id);

const currentViewTotal = useMemo(
() => (galleryView === 'images' ? totalImages?.total : totalAssets?.total),
[galleryView, totalAssets, totalImages]
);

return { totalImages, totalAssets, currentViewTotal };
};
@@ -7283,144 +7283,144 @@ export type components = {
      project_id: string | null;
    };
    InvocationOutputMap: {
      midas_depth_image_processor: components["schemas"]["ImageOutput"];
      lscale: components["schemas"]["LatentsOutput"];
      string_split: components["schemas"]["String2Output"];
      mask_edge: components["schemas"]["ImageOutput"];
      content_shuffle_image_processor: components["schemas"]["ImageOutput"];
      color_correct: components["schemas"]["ImageOutput"];
      save_image: components["schemas"]["ImageOutput"];
      show_image: components["schemas"]["ImageOutput"];
      segment_anything_processor: components["schemas"]["ImageOutput"];
      latents: components["schemas"]["LatentsOutput"];
      lineart_image_processor: components["schemas"]["ImageOutput"];
      hed_image_processor: components["schemas"]["ImageOutput"];
      infill_lama: components["schemas"]["ImageOutput"];
      infill_patchmatch: components["schemas"]["ImageOutput"];
      float_collection: components["schemas"]["FloatCollectionOutput"];
      denoise_latents: components["schemas"]["LatentsOutput"];
      metadata: components["schemas"]["MetadataOutput"];
      compel: components["schemas"]["ConditioningOutput"];
      img_blur: components["schemas"]["ImageOutput"];
      img_crop: components["schemas"]["ImageOutput"];
      image_mask_to_tensor: components["schemas"]["MaskOutput"];
      sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"];
      img_ilerp: components["schemas"]["ImageOutput"];
      img_paste: components["schemas"]["ImageOutput"];
      core_metadata: components["schemas"]["MetadataOutput"];
      lora_collection_loader: components["schemas"]["LoRALoaderOutput"];
      lora_selector: components["schemas"]["LoRASelectorOutput"];
      create_denoise_mask: components["schemas"]["DenoiseMaskOutput"];
      rectangle_mask: components["schemas"]["MaskOutput"];
      noise: components["schemas"]["NoiseOutput"];
      float_to_int: components["schemas"]["IntegerOutput"];
      esrgan: components["schemas"]["ImageOutput"];
      merge_tiles_to_image: components["schemas"]["ImageOutput"];
      prompt_from_file: components["schemas"]["StringCollectionOutput"];
      infill_rgba: components["schemas"]["ImageOutput"];
      sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"];
      lora_loader: components["schemas"]["LoRALoaderOutput"];
      iterate: components["schemas"]["IterateInvocationOutput"];
      t2i_adapter: components["schemas"]["T2IAdapterOutput"];
      color_map_image_processor: components["schemas"]["ImageOutput"];
      blank_image: components["schemas"]["ImageOutput"];
      normalbae_image_processor: components["schemas"]["ImageOutput"];
      canvas_paste_back: components["schemas"]["ImageOutput"];
      string_split_neg: components["schemas"]["StringPosNegOutput"];
      img_channel_offset: components["schemas"]["ImageOutput"];
      face_mask_detection: components["schemas"]["FaceMaskOutput"];
      cv_inpaint: components["schemas"]["ImageOutput"];
      clip_skip: components["schemas"]["CLIPSkipInvocationOutput"];
      latents_collection: components["schemas"]["LatentsCollectionOutput"];
      metadata: components["schemas"]["MetadataOutput"];
      invert_tensor_mask: components["schemas"]["MaskOutput"];
      tomask: components["schemas"]["ImageOutput"];
      main_model_loader: components["schemas"]["ModelLoaderOutput"];
      img_watermark: components["schemas"]["ImageOutput"];
      img_pad_crop: components["schemas"]["ImageOutput"];
      random_range: components["schemas"]["IntegerCollectionOutput"];
      mlsd_image_processor: components["schemas"]["ImageOutput"];
      merge_metadata: components["schemas"]["MetadataOutput"];
      lora_collection_loader: components["schemas"]["LoRALoaderOutput"];
      string_split: components["schemas"]["String2Output"];
      integer_collection: components["schemas"]["IntegerCollectionOutput"];
      boolean_collection: components["schemas"]["BooleanCollectionOutput"];
      noise: components["schemas"]["NoiseOutput"];
      float_math: components["schemas"]["FloatOutput"];
      seamless: components["schemas"]["SeamlessModeOutput"];
      img_lerp: components["schemas"]["ImageOutput"];
      img_blur: components["schemas"]["ImageOutput"];
      string_join: components["schemas"]["StringOutput"];
      vae_loader: components["schemas"]["VAEOutput"];
      calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"];
      calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"];
      mask_from_id: components["schemas"]["ImageOutput"];
      zoe_depth_image_processor: components["schemas"]["ImageOutput"];
      img_resize: components["schemas"]["ImageOutput"];
      string_replace: components["schemas"]["StringOutput"];
      face_identifier: components["schemas"]["ImageOutput"];
      t2i_adapter: components["schemas"]["T2IAdapterOutput"];
      mul: components["schemas"]["IntegerOutput"];
      l2i: components["schemas"]["ImageOutput"];
      img_chan: components["schemas"]["ImageOutput"];
      conditioning_collection: components["schemas"]["ConditioningCollectionOutput"];
      blank_image: components["schemas"]["ImageOutput"];
      ip_adapter: components["schemas"]["IPAdapterOutput"];
      tile_image_processor: components["schemas"]["ImageOutput"];
      integer_math: components["schemas"]["IntegerOutput"];
      infill_tile: components["schemas"]["ImageOutput"];
      color_correct: components["schemas"]["ImageOutput"];
      show_image: components["schemas"]["ImageOutput"];
      float: components["schemas"]["FloatOutput"];
      prompt_from_file: components["schemas"]["StringCollectionOutput"];
      merge_metadata: components["schemas"]["MetadataOutput"];
      img_scale: components["schemas"]["ImageOutput"];
      string_join_three: components["schemas"]["StringOutput"];
      dw_openpose_image_processor: components["schemas"]["ImageOutput"];
      freeu: components["schemas"]["UNetOutput"];
      img_channel_multiply: components["schemas"]["ImageOutput"];
      sdxl_compel_prompt: components["schemas"]["ConditioningOutput"];
      img_conv: components["schemas"]["ImageOutput"];
      latents: components["schemas"]["LatentsOutput"];
      face_mask_detection: components["schemas"]["FaceMaskOutput"];
      canny_image_processor: components["schemas"]["ImageOutput"];
      collect: components["schemas"]["CollectInvocationOutput"];
      infill_tile: components["schemas"]["ImageOutput"];
      integer_collection: components["schemas"]["IntegerCollectionOutput"];
      img_lerp: components["schemas"]["ImageOutput"];
      step_param_easing: components["schemas"]["FloatCollectionOutput"];
      lresize: components["schemas"]["LatentsOutput"];
      img_mul: components["schemas"]["ImageOutput"];
      create_gradient_mask: components["schemas"]["GradientMaskOutput"];
      img_scale: components["schemas"]["ImageOutput"];
      rand_float: components["schemas"]["FloatOutput"];
      tile_to_properties: components["schemas"]["TileToPropertiesOutput"];
      calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"];
      range_of_size: components["schemas"]["IntegerCollectionOutput"];
      sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"];
      heuristic_resize: components["schemas"]["ImageOutput"];
      controlnet: components["schemas"]["ControlOutput"];
      string: components["schemas"]["StringOutput"];
      tile_image_processor: components["schemas"]["ImageOutput"];
      metadata_item: components["schemas"]["MetadataItemOutput"];
      freeu: components["schemas"]["UNetOutput"];
      round_float: components["schemas"]["FloatOutput"];
      conditioning: components["schemas"]["ConditioningOutput"];
      ideal_size: components["schemas"]["IdealSizeOutput"];
      float: components["schemas"]["FloatOutput"];
      conditioning_collection: components["schemas"]["ConditioningCollectionOutput"];
      alpha_mask_to_tensor: components["schemas"]["MaskOutput"];
      integer_math: components["schemas"]["IntegerOutput"];
      string_collection: components["schemas"]["StringCollectionOutput"];
      img_conv: components["schemas"]["ImageOutput"];
      img_channel_multiply: components["schemas"]["ImageOutput"];
      lblend: components["schemas"]["LatentsOutput"];
      calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"];
      color: components["schemas"]["ColorOutput"];
      image: components["schemas"]["ImageOutput"];
      sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"];
      image_collection: components["schemas"]["ImageCollectionOutput"];
      model_identifier: components["schemas"]["ModelIdentifierOutput"];
      l2i: components["schemas"]["ImageOutput"];
      seamless: components["schemas"]["SeamlessModeOutput"];
      boolean_collection: components["schemas"]["BooleanCollectionOutput"];
      string_join_three: components["schemas"]["StringOutput"];
      ip_adapter: components["schemas"]["IPAdapterOutput"];
      add: components["schemas"]["IntegerOutput"];
      crop_latents: components["schemas"]["LatentsOutput"];
      float_range: components["schemas"]["FloatCollectionOutput"];
      mul: components["schemas"]["IntegerOutput"];
      dw_openpose_image_processor: components["schemas"]["ImageOutput"];
      boolean: components["schemas"]["BooleanOutput"];
      dynamic_prompt: components["schemas"]["StringCollectionOutput"];
      mediapipe_face_processor: components["schemas"]["ImageOutput"];
      i2l: components["schemas"]["LatentsOutput"];
      latents_collection: components["schemas"]["LatentsCollectionOutput"];
      integer: components["schemas"]["IntegerOutput"];
      img_chan: components["schemas"]["ImageOutput"];
      pair_tile_image: components["schemas"]["PairTileImageOutput"];
      unsharp_mask: components["schemas"]["ImageOutput"];
      img_hue_adjust: components["schemas"]["ImageOutput"];
      lineart_anime_image_processor: components["schemas"]["ImageOutput"];
      face_off: components["schemas"]["FaceOffOutput"];
      mask_combine: components["schemas"]["ImageOutput"];
      leres_image_processor: components["schemas"]["ImageOutput"];
      image_mask_to_tensor: components["schemas"]["MaskOutput"];
      sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"];
      scheduler: components["schemas"]["SchedulerOutput"];
      sub: components["schemas"]["IntegerOutput"];
      pidi_image_processor: components["schemas"]["ImageOutput"];
      infill_cv2: components["schemas"]["ImageOutput"];
      div: components["schemas"]["IntegerOutput"];
      img_nsfw: components["schemas"]["ImageOutput"];
      depth_anything_image_processor: components["schemas"]["ImageOutput"];
      sdxl_compel_prompt: components["schemas"]["ConditioningOutput"];
      range: components["schemas"]["IntegerCollectionOutput"];
      range_of_size: components["schemas"]["IntegerCollectionOutput"];
      img_resize: components["schemas"]["ImageOutput"];
      img_watermark: components["schemas"]["ImageOutput"];
      esrgan: components["schemas"]["ImageOutput"];
      calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"];
      img_paste: components["schemas"]["ImageOutput"];
      face_identifier: components["schemas"]["ImageOutput"];
      create_denoise_mask: components["schemas"]["DenoiseMaskOutput"];
      content_shuffle_image_processor: components["schemas"]["ImageOutput"];
      round_float: components["schemas"]["FloatOutput"];
      calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"];
      lscale: components["schemas"]["LatentsOutput"];
      rand_int: components["schemas"]["IntegerOutput"];
      float_math: components["schemas"]["FloatOutput"];
      infill_cv2: components["schemas"]["ImageOutput"];
      sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"];
      img_nsfw: components["schemas"]["ImageOutput"];
      main_model_loader: components["schemas"]["ModelLoaderOutput"];
      tomask: components["schemas"]["ImageOutput"];
      string_replace: components["schemas"]["StringOutput"];
      face_off: components["schemas"]["FaceOffOutput"];
      string: components["schemas"]["StringOutput"];
      heuristic_resize: components["schemas"]["ImageOutput"];
      midas_depth_image_processor: components["schemas"]["ImageOutput"];
      alpha_mask_to_tensor: components["schemas"]["MaskOutput"];
      mask_combine: components["schemas"]["ImageOutput"];
      clip_skip: components["schemas"]["CLIPSkipInvocationOutput"];
      image: components["schemas"]["ImageOutput"];
      infill_rgba: components["schemas"]["ImageOutput"];
      img_hue_adjust: components["schemas"]["ImageOutput"];
      vae_loader: components["schemas"]["VAEOutput"];
      sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"];
      segment_anything_processor: components["schemas"]["ImageOutput"];
      sub: components["schemas"]["IntegerOutput"];
      iterate: components["schemas"]["IterateInvocationOutput"];
      img_mul: components["schemas"]["ImageOutput"];
      denoise_latents: components["schemas"]["LatentsOutput"];
      lineart_image_processor: components["schemas"]["ImageOutput"];
      rand_float: components["schemas"]["FloatOutput"];
      rectangle_mask: components["schemas"]["MaskOutput"];
      lora_selector: components["schemas"]["LoRASelectorOutput"];
      pair_tile_image: components["schemas"]["PairTileImageOutput"];
      cv_inpaint: components["schemas"]["ImageOutput"];
      hed_image_processor: components["schemas"]["ImageOutput"];
      range: components["schemas"]["IntegerCollectionOutput"];
      img_pad_crop: components["schemas"]["ImageOutput"];
      string_split_neg: components["schemas"]["StringPosNegOutput"];
      string_collection: components["schemas"]["StringCollectionOutput"];
      zoe_depth_image_processor: components["schemas"]["ImageOutput"];
      save_image: components["schemas"]["ImageOutput"];
      img_ilerp: components["schemas"]["ImageOutput"];
      compel: components["schemas"]["ConditioningOutput"];
      unsharp_mask: components["schemas"]["ImageOutput"];
      image_collection: components["schemas"]["ImageCollectionOutput"];
      lineart_anime_image_processor: components["schemas"]["ImageOutput"];
      float_to_int: components["schemas"]["IntegerOutput"];
      random_range: components["schemas"]["IntegerCollectionOutput"];
      ideal_size: components["schemas"]["IdealSizeOutput"];
      i2l: components["schemas"]["LatentsOutput"];
      infill_patchmatch: components["schemas"]["ImageOutput"];
      depth_anything_image_processor: components["schemas"]["ImageOutput"];
      infill_lama: components["schemas"]["ImageOutput"];
      mask_from_id: components["schemas"]["ImageOutput"];
      conditioning: components["schemas"]["ConditioningOutput"];
      lresize: components["schemas"]["LatentsOutput"];
      step_param_easing: components["schemas"]["FloatCollectionOutput"];
      metadata_item: components["schemas"]["MetadataItemOutput"];
      controlnet: components["schemas"]["ControlOutput"];
      merge_tiles_to_image: components["schemas"]["ImageOutput"];
      boolean: components["schemas"]["BooleanOutput"];
      core_metadata: components["schemas"]["MetadataOutput"];
      img_channel_offset: components["schemas"]["ImageOutput"];
      model_identifier: components["schemas"]["ModelIdentifierOutput"];
      scheduler: components["schemas"]["SchedulerOutput"];
      create_gradient_mask: components["schemas"]["GradientMaskOutput"];
      color_map_image_processor: components["schemas"]["ImageOutput"];
      canvas_paste_back: components["schemas"]["ImageOutput"];
      mask_edge: components["schemas"]["ImageOutput"];
      lora_loader: components["schemas"]["LoRALoaderOutput"];
      float_collection: components["schemas"]["FloatCollectionOutput"];
      float_range: components["schemas"]["FloatCollectionOutput"];
      normalbae_image_processor: components["schemas"]["ImageOutput"];
      lblend: components["schemas"]["LatentsOutput"];
      sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"];
      dynamic_prompt: components["schemas"]["StringCollectionOutput"];
      leres_image_processor: components["schemas"]["ImageOutput"];
      add: components["schemas"]["IntegerOutput"];
      tile_to_properties: components["schemas"]["TileToPropertiesOutput"];
      img_crop: components["schemas"]["ImageOutput"];
      integer: components["schemas"]["IntegerOutput"];
      crop_latents: components["schemas"]["LatentsOutput"];
      mlsd_image_processor: components["schemas"]["ImageOutput"];
    };
    /**
     * InvocationStartedEvent
@@ -14108,7 +14108,7 @@ export type operations = {
  install_hugging_face_model: {
    parameters: {
      query: {
        /** @description Hugging Face repo_id to install */
        /** @description HuggingFace repo_id to install */
        source: string;
      };
    };
@@ -14698,6 +14698,8 @@ export type operations = {
        offset?: number;
        /** @description The number of images per page */
        limit?: number;
        /** @description The term to search for */
        search_term?: string | null;
      };
    };
    responses: {

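For illustration (not part of this commit): with the new search_term query parameter, the list endpoint can be exercised directly. A small sketch using fetch against a relative URL; the term and page size are placeholders.

// Hypothetical direct call to the images list endpoint.
async function searchImages(term: string) {
  const query = new URLSearchParams({ offset: '0', limit: '15', search_term: term });
  const res = await fetch(`/api/v1/images/?${query.toString()}`);
  if (!res.ok) {
    throw new Error(`Listing images failed with status ${res.status}`);
  }
  // The body is an offset-paginated page of image DTOs (items, offset, limit, total).
  return res.json();
}
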
@@ -1,12 +1,10 @@
import type { EntityState } from '@reduxjs/toolkit';
import type { components, paths } from 'services/api/schema';
import type { O } from 'ts-toolbelt';

export type S = components['schemas'];

export type ImageCache = EntityState<ImageDTO, string>;

export type ListImagesArgs = NonNullable<paths['/api/v1/images/']['get']['parameters']['query']>;
export type ListImagesResponse = paths['/api/v1/images/']['get']['responses']['200']['content']['application/json'];

export type DeleteBoardResult =
  paths['/api/v1/boards/{board_id}']['delete']['responses']['200']['content']['application/json'];

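For illustration (not part of this commit): because ListImagesArgs is derived straight from the OpenAPI paths type, the new search_term parameter becomes available on it without any hand-written typing. A sketch of a possible argument object; the concrete values are placeholders.

// Hypothetical query args — every field here is optional on the generated type.
const exampleArgs: ListImagesArgs = {
  board_id: 'none',
  categories: ['general'],
  offset: 0,
  limit: 15,
  search_term: 'portrait',
};
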
@@ -1,56 +1,8 @@
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { dateComparator } from 'common/util/dateComparator';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
import queryString from 'query-string';
import { buildV1Url } from 'services/api';

import type { ImageCache, ImageDTO, ListImagesArgs } from './types';

export const getIsImageInDateRange = (data: ImageCache | undefined, imageDTO: ImageDTO) => {
  if (!data) {
    return false;
  }

  const totalCachedImageDtos = imagesSelectors.selectAll(data);

  if (totalCachedImageDtos.length <= 1) {
    return true;
  }

  const cachedStarredImages = [];
  const cachedUnstarredImages = [];

  for (let index = 0; index < totalCachedImageDtos.length; index++) {
    const image = totalCachedImageDtos[index];
    if (image?.starred) {
      cachedStarredImages.push(image);
    }
    if (!image?.starred) {
      cachedUnstarredImages.push(image);
    }
  }

  if (imageDTO.starred) {
    const lastStarredImage = cachedStarredImages[cachedStarredImages.length - 1];
    // if starring or already starred, want to look in list of starred images
    if (!lastStarredImage) {
      return true;
    } // no starred images showing, so always show this one
    const createdDate = new Date(imageDTO.created_at);
    const oldestDate = new Date(lastStarredImage.created_at);
    return createdDate >= oldestDate;
  } else {
    const lastUnstarredImage = cachedUnstarredImages[cachedUnstarredImages.length - 1];
    // if unstarring or already unstarred, want to look in list of unstarred images
    if (!lastUnstarredImage) {
      return false;
    } // no unstarred images showing, so don't show this one
    const createdDate = new Date(imageDTO.created_at);
    const oldestDate = new Date(lastUnstarredImage.created_at);
    return createdDate >= oldestDate;
  }
};
import type { ImageDTO, ListImagesArgs } from './types';

export const getCategories = (imageDTO: ImageDTO) => {
  if (IMAGE_CATEGORIES.includes(imageDTO.image_category)) {
@@ -59,25 +11,6 @@ export const getCategories = (imageDTO: ImageDTO) => {
  return ASSETS_CATEGORIES;
};

// The adapter is not actually the data store - it just provides helper functions to interact
// with some other store of data. We will use the RTK Query cache as that store.
export const imagesAdapter = createEntityAdapter<ImageDTO, string>({
  selectId: (image) => image.image_name,
  sortComparer: (a, b) => {
    // Compare starred images first
    if (a.starred && !b.starred) {
      return -1;
    }
    if (!a.starred && b.starred) {
      return 1;
    }
    return dateComparator(b.created_at, a.created_at);
  },
});

// Create selectors for the adapter.
export const imagesSelectors = imagesAdapter.getSelectors(undefined, getSelectorsOptions);

// Helper to create the url for the listImages endpoint. Also we use it to create the cache key.
export const getListImagesUrl = (queryArgs: ListImagesArgs) =>
  buildV1Url(`images/?${queryString.stringify(queryArgs, { arrayFormat: 'none' })}`);

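For illustration (not part of this commit): getListImagesUrl doubles as the RTK Query cache key for a page of results, so every distinct combination of args — now including search_term — gets its own cache entry. A usage sketch; the exact prefix depends on buildV1Url, and query-string sorts keys alphabetically by default.

// Hypothetical call — with default sorting the keys appear in alphabetical order.
const url = getListImagesUrl({ board_id: 'none', categories: ['general'], limit: 15, offset: 0, search_term: 'cat' });
// roughly: api/v1/images/?board_id=none&categories=general&limit=15&offset=0&search_term=cat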