Compare commits

..

12 Commits

Author SHA1 Message Date
psychedelicious
3cd836efde feat(app): cleanup of board queries 2024-09-24 15:35:20 +10:00
psychedelicious
dca5a2ce26 fix(app): fix swapped image counts for uncategorized 2024-07-20 19:57:12 +10:00
psychedelicious
9d3a72fff3 docs(app): add comments to boards queries 2024-07-15 17:25:47 +10:00
psychedelicious
c586d65a54 tidy(app): remove extraneous condition from query 2024-07-15 17:25:47 +10:00
psychedelicious
25107e427c tidy(app): move sqlite-specific objects to sqlite file 2024-07-15 17:25:47 +10:00
psychedelicious
a30d143c5a fix(ui): remove unused queries and fix invocation complete listener 2024-07-15 17:25:47 +10:00
psychedelicious
a6f1148676 feat(ui): use new uncategorized image counts & updated list boards queries 2024-07-15 17:25:47 +10:00
psychedelicious
2843a6a227 chore(ui): typegen
chore(ui): typegen
2024-07-15 17:25:47 +10:00
psychedelicious
0484f458b6 tests(app): fix bulk downloads test 2024-07-15 17:25:47 +10:00
psychedelicious
c05f97d8ca feat(app): refactor board record to include image & asset counts and cover image
This _substantially_ reduces the number of queries required to list all boards. A single query now gets one, all, or a page of boards, including counts and cover image name.

- Add helpers to build the queries, which share a common base with some joins.
- Update `BoardRecord` to include the counts.
- Update `BoardDTO`, which is now identical to `BoardRecord`. I opted to not remove `BoardDTO` because it is used in many places.
- Update boards high-level service and board records services accordingly.
2024-07-15 17:25:47 +10:00
psychedelicious
a95aa6cc16 feat(api): add get_uncategorized_image_counts endpoint 2024-07-15 17:25:47 +10:00
psychedelicious
c74b9a40af feat(app): add get_uncategorized_image_counts method on board_records service 2024-07-15 17:25:47 +10:00
207 changed files with 3119 additions and 7338 deletions

View File

@@ -55,7 +55,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
FROM node:20-slim AS web-builder
ENV PNPM_HOME="/pnpm"
ENV PATH="$PNPM_HOME:$PATH"
RUN corepack use pnpm@8.x
RUN corepack enable
WORKDIR /build

View File

@@ -5,7 +5,7 @@ from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.board_records.board_records_common import BoardChanges
from invokeai.app.services.board_records.board_records_common import BoardChanges, UncategorizedImageCounts
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
@@ -146,3 +146,14 @@ async def list_all_board_image_names(
board_id,
)
return image_names
@boards_router.get(
"/uncategorized/counts",
operation_id="get_uncategorized_image_counts",
response_model=UncategorizedImageCounts,
)
async def get_uncategorized_image_counts() -> UncategorizedImageCounts:
"""Gets count of images and assets for uncategorized images (images with no board assocation)"""
return ApiDependencies.invoker.services.board_records.get_uncategorized_image_counts()

View File

@@ -6,7 +6,7 @@ import pathlib
import traceback
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import List, Optional, Type
from typing import Any, Dict, List, Optional, Type
from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
@@ -430,11 +430,13 @@ async def delete_model_image(
async def install_model(
source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
config: ModelRecordChanges = Body(
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
# TODO(MM2): Can we type this?
config: Optional[Dict[str, Any]] = Body(
description="Dict of fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
default=None,
example={"name": "string", "description": "string"},
),
access_token: Optional[str] = None,
) -> ModelInstallJob:
"""Install a model using a string identifier.
@@ -449,9 +451,8 @@ async def install_model(
- model/name:fp16:path/to/model.safetensors
- model/name::path/to/model.safetensors
`config` is a ModelRecordChanges object. Fields in this object will override
the ones that are probed automatically. Pass an empty object to accept
all the defaults.
`config` is an optional dict containing model configuration values that will override
the ones that are probed automatically.
`access_token` is an optional access token for use with Urls that require
authentication.
@@ -736,7 +737,7 @@ async def convert_model(
# write the converted file to the convert path
raw_model = converted_model.model
assert hasattr(raw_model, "save_pretrained")
raw_model.save_pretrained(convert_path) # type: ignore
raw_model.save_pretrained(convert_path)
assert convert_path.exists()
# temporarily rename the original safetensors file so that there is no naming conflict
@@ -749,12 +750,12 @@ async def convert_model(
try:
new_key = installer.install_path(
convert_path,
config=ModelRecordChanges(
name=original_name,
description=model_config.description,
hash=model_config.hash,
source=model_config.source,
),
config={
"name": original_name,
"description": model_config.description,
"hash": model_config.hash,
"source": model_config.source,
},
)
except Exception as e:
logger.error(str(e))

View File

@@ -1,6 +1,5 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
import inspect
import os
from contextlib import ExitStack
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
@@ -40,7 +39,6 @@ from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
ControlNetData,
StableDiffusionGeneratorPipeline,
@@ -55,14 +53,6 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.util.devices import TorchDevice
@@ -324,10 +314,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
device: torch.device,
dtype: torch.dtype,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
@@ -341,10 +330,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
uncond_list = [uncond_list]
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, device, dtype
cond_list, context, unet.device, unet.dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, device, dtype
uncond_list, context, unet.device, unet.dtype
)
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
@@ -352,14 +341,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
dtype=unet.dtype,
)
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
dtype=unet.dtype,
)
if isinstance(cfg_scale, list):
@@ -466,38 +455,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
return controlnet_data
@staticmethod
def parse_controlnet_field(
exit_stack: ExitStack,
context: InvocationContext,
control_input: ControlField | list[ControlField] | None,
ext_manager: ExtensionsManager,
) -> None:
# Normalize control_input to a list.
control_list: list[ControlField]
if isinstance(control_input, ControlField):
control_list = [control_input]
elif isinstance(control_input, list):
control_list = control_input
elif control_input is None:
control_list = []
else:
raise ValueError(f"Unexpected control_input type: {type(control_input)}")
for control_info in control_list:
model = exit_stack.enter_context(context.models.load(control_info.control_model))
ext_manager.add_extension(
ControlNetExt(
model=model,
image=context.images.get_pil(control_info.image.image_name),
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
)
)
def prep_ip_adapter_image_prompts(
self,
context: InvocationContext,
@@ -750,124 +707,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
return seed, noise, latents
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def invoke(self, context: InvocationContext) -> LatentsOutput:
if os.environ.get("USE_MODULAR_DENOISE", False):
return self._new_invoke(context)
else:
return self._old_invoke(context)
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled)
device = TorchDevice.choose_torch_device()
dtype = TorchDevice.choose_torch_dtype()
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
cfg_scale=self.cfg_scale,
steps=self.steps,
latent_height=latent_height,
latent_width=latent_width,
device=device,
dtype=dtype,
# TODO: old backend, remove
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
seed=seed,
device=device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
)
denoise_ctx = DenoiseContext(
inputs=DenoiseInputs(
orig_latents=latents,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise,
seed=seed,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
attention_processor_cls=CustomAttnProcessor2_0,
),
unet=None,
scheduler=scheduler,
)
# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)
### preview
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
ext_manager.add_extension(PreviewExt(step_callback))
### cfg rescale
if self.cfg_rescale_multiplier > 0:
ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
### freeu
if self.unet.freeu_config:
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
# context for loading additional models
with ExitStack() as exit_stack:
# later should be smth like:
# for extension_field in self.extensions:
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
# ext_manager.add_extension(ext)
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (cached_weights, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet
ext_manager.patch_extensions(denoise_ctx),
# ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(unet, cached_weights),
):
sd_backend = StableDiffusionBackend(unet, scheduler)
denoise_ctx.unet = unet
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.detach().to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -946,8 +788,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
device=unet.device,
dtype=unet.dtype,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,

View File

@@ -48,7 +48,6 @@ class UIType(str, Enum, metaclass=MetaEnum):
ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
T2IAdapterModel = "T2IAdapterModelField"
SpandrelImageToImageModel = "SpandrelImageToImageModelField"
# endregion
# region Misc Field Types
@@ -135,7 +134,6 @@ class FieldDescriptions:
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
spandrel_image_to_image_model = "Image-to-Image model"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
raw_prompt = "Raw prompt text (no parsing)"

View File

@@ -1,253 +0,0 @@
from typing import Callable
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
UIType,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
image: ImageField = InputField(description="The input image")
image_to_image_model: ModelIdentifierField = InputField(
title="Image-to-Image Model",
description=FieldDescriptions.spandrel_image_to_image_model,
ui_type=UIType.SpandrelImageToImageModel,
)
tile_size: int = InputField(
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
)
@classmethod
def scale_tile(cls, tile: Tile, scale: int) -> Tile:
return Tile(
coords=TBLR(
top=tile.coords.top * scale,
bottom=tile.coords.bottom * scale,
left=tile.coords.left * scale,
right=tile.coords.right * scale,
),
overlap=TBLR(
top=tile.overlap.top * scale,
bottom=tile.overlap.bottom * scale,
left=tile.overlap.left * scale,
right=tile.overlap.right * scale,
),
)
@classmethod
def upscale_image(
cls,
image: Image.Image,
tile_size: int,
spandrel_model: SpandrelImageToImageModel,
is_canceled: Callable[[], bool],
) -> Image.Image:
# Compute the image tiles.
if tile_size > 0:
min_overlap = 20
tiles = calc_tiles_min_overlap(
image_height=image.height,
image_width=image.width,
tile_height=tile_size,
tile_width=tile_size,
min_overlap=min_overlap,
)
else:
# No tiling. Generate a single tile that covers the entire image.
min_overlap = 0
tiles = [
Tile(
coords=TBLR(top=0, bottom=image.height, left=0, right=image.width),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
]
# Sort tiles first by left x coordinate, then by top y coordinate. During tile processing, we want to iterate
# over tiles left-to-right, top-to-bottom.
tiles = sorted(tiles, key=lambda x: x.coords.left)
tiles = sorted(tiles, key=lambda x: x.coords.top)
# Prepare input image for inference.
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
# Scale the tiles for re-assembling the final image.
scale = spandrel_model.scale
scaled_tiles = [cls.scale_tile(tile, scale=scale) for tile in tiles]
# Prepare the output tensor.
_, channels, height, width = image_tensor.shape
output_tensor = torch.zeros(
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
)
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on each tile.
for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
# Exit early if the invocation has been canceled.
if is_canceled():
raise CanceledException
# Extract the current tile from the input tensor.
input_tile = image_tensor[
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on the tile.
output_tile = spandrel_model.run(input_tile)
# Convert the output tile into the output tensor's format.
# (N, C, H, W) -> (C, H, W)
output_tile = output_tile.squeeze(0)
# (C, H, W) -> (H, W, C)
output_tile = output_tile.permute(1, 2, 0)
output_tile = output_tile.clamp(0, 1)
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
# Merge the output tile into the output tensor.
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
# it seems unnecessary, but we may find a need in the future.
top_overlap = scaled_tile.overlap.top // 2
left_overlap = scaled_tile.overlap.left // 2
output_tensor[
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
:,
] = output_tile[top_overlap:, left_overlap:, :]
# Convert the output tensor to a PIL image.
np_image = output_tensor.detach().numpy().astype(np.uint8)
pil_image = Image.fromarray(np_image)
return pil_image
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
# Do the upscaling.
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)
# Upscale the image
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)
@invocation(
"spandrel_image_to_image_autoscale",
title="Image-to-Image (Autoscale)",
tags=["upscale"],
category="upscale",
version="1.0.0",
)
class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel) until the target scale is reached."""
scale: float = InputField(
default=4.0,
gt=0.0,
le=16.0,
description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
)
fit_to_multiple_of_8: bool = InputField(
default=False,
description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
)
@torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput:
# Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model)
# The target size of the image, determined by the provided scale. We'll run the upscaler until we hit this size.
# Later, we may mutate this value if the model doesn't upscale the image or if the user requested a multiple of 8.
target_width = int(image.width * self.scale)
target_height = int(image.height * self.scale)
# Do the upscaling.
with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel)
# First pass of upscaling. Note: `pil_image` will be mutated.
pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)
# Some models don't upscale the image, but we have no way to know this in advance. We'll check if the model
# upscaled the image and run the loop below if it did. We'll require the model to upscale both dimensions
# to be considered an upscale model.
is_upscale_model = pil_image.width > image.width and pil_image.height > image.height
if is_upscale_model:
# This is an upscale model, so we should keep upscaling until we reach the target size.
iterations = 1
while pil_image.width < target_width or pil_image.height < target_height:
pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
iterations += 1
# Sanity check to prevent excessive or infinite loops. All known upscaling models are at least 2x.
# Our max scale is 16x, so with a 2x model, we should never exceed 16x == 2^4 -> 4 iterations.
# We'll allow one extra iteration "just in case" and bail at 5 upscaling iterations. In practice,
# we should never reach this limit.
if iterations >= 5:
context.logger.warning(
"Upscale loop reached maximum iteration count of 5, stopping upscaling early."
)
break
else:
# This model doesn't upscale the image. We should ignore the scale parameter, modifying the output size
# to be the same as the processed image size.
# The output size is now the size of the processed image.
target_width = pil_image.width
target_height = pil_image.height
# Warn the user if they requested a scale greater than 1.
if self.scale > 1:
context.logger.warning(
"Model does not increase the size of the image, but a greater scale than 1 was requested. Image will not be scaled."
)
# We may need to resize the image to a multiple of 8. Use floor division to ensure we don't scale the image up
# in the final resize
if self.fit_to_multiple_of_8:
target_width = int(target_width // 8 * 8)
target_height = int(target_height // 8 * 8)
# Final resize. Per PIL documentation, Lanczos provides the best quality for both upscale and downscale.
# See: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table
pil_image = pil_image.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)
image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto)

View File

@@ -175,10 +175,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
_, _, latent_height, latent_width = latents.shape
# Calculate the tile locations to cover the latent-space image.
# TODO(ryand): In the future, we may want to revisit the tile overlap strategy. Things to consider:
# - How much overlap 'context' to provide for each denoising step.
# - How much overlap to use during merging/blending.
# - Should we 'jitter' the tile locations in each step so that the seams are in different places?
tiles = calc_tiles_min_overlap(
image_height=latent_height,
image_width=latent_width,
@@ -222,8 +218,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
device=unet.device,
dtype=unet.dtype,
unet=unet,
latent_height=latent_tile_height,
latent_width=latent_tile_width,
cfg_scale=self.cfg_scale,

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecord
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecord, UncategorizedImageCounts
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
@@ -48,3 +48,8 @@ class BoardRecordStorageBase(ABC):
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
"""Gets all board records."""
pass
@abstractmethod
def get_uncategorized_image_counts(self) -> UncategorizedImageCounts:
"""Gets count of images and assets for uncategorized images (images with no board assocation)."""
pass

View File

@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Optional, Union
from typing import Any, Optional, Union
from pydantic import BaseModel, Field
@@ -26,21 +26,25 @@ class BoardRecord(BaseModelExcludeNull):
"""Whether or not the board is archived."""
is_private: Optional[bool] = Field(default=None, description="Whether the board is private.")
"""Whether the board is private."""
image_count: int = Field(description="The number of images in the board.")
asset_count: int = Field(description="The number of assets in the board.")
def deserialize_board_record(board_dict: dict) -> BoardRecord:
def deserialize_board_record(board_dict: dict[str, Any]) -> BoardRecord:
"""Deserializes a board record."""
# Retrieve all the values, setting "reasonable" defaults if they are not present.
board_id = board_dict.get("board_id", "unknown")
board_name = board_dict.get("board_name", "unknown")
cover_image_name = board_dict.get("cover_image_name", "unknown")
cover_image_name = board_dict.get("cover_image_name", None)
created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
archived = board_dict.get("archived", False)
is_private = board_dict.get("is_private", False)
image_count = board_dict.get("image_count", 0)
asset_count = board_dict.get("asset_count", 0)
return BoardRecord(
board_id=board_id,
@@ -51,6 +55,8 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
deleted_at=deleted_at,
archived=archived,
is_private=is_private,
image_count=image_count,
asset_count=asset_count,
)
@@ -63,19 +69,24 @@ class BoardChanges(BaseModel, extra="forbid"):
class BoardRecordNotFoundException(Exception):
"""Raised when an board record is not found."""
def __init__(self, message="Board record not found"):
def __init__(self, message: str = "Board record not found"):
super().__init__(message)
class BoardRecordSaveException(Exception):
"""Raised when an board record cannot be saved."""
def __init__(self, message="Board record not saved"):
def __init__(self, message: str = "Board record not saved"):
super().__init__(message)
class BoardRecordDeleteException(Exception):
"""Raised when an board record cannot be deleted."""
def __init__(self, message="Board record not deleted"):
def __init__(self, message: str = "Board record not deleted"):
super().__init__(message)
class UncategorizedImageCounts(BaseModel):
image_count: int = Field(description="The number of uncategorized images.")
asset_count: int = Field(description="The number of uncategorized assets.")

View File

@@ -1,5 +1,6 @@
import sqlite3
import threading
from dataclasses import dataclass
from typing import Union, cast
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
@@ -9,12 +10,108 @@ from invokeai.app.services.board_records.board_records_common import (
BoardRecordDeleteException,
BoardRecordNotFoundException,
BoardRecordSaveException,
UncategorizedImageCounts,
deserialize_board_record,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.util.misc import uuid_string
BASE_BOARD_RECORD_QUERY = """
-- This query retrieves board records, joining with the board_images and images tables to get image counts and cover image names.
-- It is not a complete query, as it is missing a GROUP BY or WHERE clause (and is unterminated).
SELECT b.board_id,
b.board_name,
b.created_at,
b.updated_at,
b.archived,
-- Count the number of images in the board, alias image_count
COUNT(
CASE
WHEN i.image_category in ('general') -- "Images" are images in the 'general' category
AND i.is_intermediate = 0 THEN 1 -- Intermediates are not counted
END
) AS image_count,
-- Count the number of assets in the board, alias asset_count
COUNT(
CASE
WHEN i.image_category in ('control', 'mask', 'user', 'other') -- "Assets" are images in any of the other categories ('control', 'mask', 'user', 'other')
AND i.is_intermediate = 0 THEN 1 -- Intermediates are not counted
END
) AS asset_count,
-- Get the name of the the most recent image in the board, alias cover_image_name
(
SELECT bi.image_name
FROM board_images bi
JOIN images i ON bi.image_name = i.image_name
WHERE bi.board_id = b.board_id
AND i.is_intermediate = 0 -- Intermediates cannot be cover images
ORDER BY i.created_at DESC -- Sort by created_at to get the most recent image
LIMIT 1
) AS cover_image_name
FROM boards b
LEFT JOIN board_images bi ON b.board_id = bi.board_id
LEFT JOIN images i ON bi.image_name = i.image_name
"""
@dataclass
class PaginatedBoardRecordsQueries:
main_query: str
total_count_query: str
def get_paginated_list_board_records_queries(include_archived: bool) -> PaginatedBoardRecordsQueries:
"""Gets a query to retrieve a paginated list of board records."""
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
# The GROUP BY must be added _after_ the WHERE clause!
main_query = f"""
{BASE_BOARD_RECORD_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC
LIMIT ? OFFSET ?;
"""
total_count_query = f"""
SELECT COUNT(*)
FROM boards b
{archived_condition};
"""
return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query)
def get_list_all_board_records_query(include_archived: bool) -> str:
"""Gets a query to retrieve all board records."""
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
# The GROUP BY must be added _after_ the WHERE clause!
return f"""
{BASE_BOARD_RECORD_QUERY}
{archived_condition}
GROUP BY b.board_id,
b.board_name,
b.created_at,
b.updated_at
ORDER BY b.created_at DESC;
"""
def get_board_record_query() -> str:
"""Gets a query to retrieve a board record."""
return f"""
{BASE_BOARD_RECORD_QUERY}
WHERE b.board_id = ?;
"""
class SqliteBoardRecordStorage(BoardRecordStorageBase):
_conn: sqlite3.Connection
@@ -76,11 +173,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
try:
self._lock.acquire()
self._cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
get_board_record_query(),
(board_id,),
)
@@ -92,7 +185,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
self._lock.release()
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
return deserialize_board_record(dict(result))
def update(
self,
@@ -149,45 +242,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
try:
self._lock.acquire()
# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
"""
queries = get_paginated_list_board_records_queries(include_archived=include_archived)
# Determine archived filter condition
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))
self._cursor.execute(
queries.main_query,
(limit, offset),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
self._cursor.execute(count_query)
self._cursor.execute(queries.total_count_query)
count = cast(int, self._cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
@@ -201,26 +266,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
try:
self._lock.acquire()
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
"""
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
self._cursor.execute(final_query)
self._cursor.execute(get_list_all_board_records_query(include_archived=include_archived))
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
return boards
except sqlite3.Error as e:
@@ -228,3 +276,28 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
raise e
finally:
self._lock.release()
def get_uncategorized_image_counts(self) -> UncategorizedImageCounts:
try:
self._lock.acquire()
query = """
-- Get the count of uncategorized images and assets.
SELECT
CASE
WHEN i.image_category = 'general' THEN 'image_count' -- "Images" are images in the 'general' category
ELSE 'asset_count' -- "Assets" are images in any of the other categories ('control', 'mask', 'user', 'other')
END AS category_type,
COUNT(*) AS unassigned_count
FROM images i
LEFT JOIN board_images bi ON i.image_name = bi.image_name
WHERE bi.board_id IS NULL -- Uncategorized images have no board association
AND i.is_intermediate = 0 -- Omit intermediates from the counts
GROUP BY category_type; -- Group by category_type alias, as derived from the image_category column earlier
"""
self._cursor.execute(query)
results = self._cursor.fetchall()
image_count = dict(results)["image_count"]
asset_count = dict(results)["asset_count"]
return UncategorizedImageCounts(image_count=image_count, asset_count=asset_count)
finally:
self._lock.release()

View File

@@ -1,23 +1,8 @@
from typing import Optional
from pydantic import Field
from invokeai.app.services.board_records.board_records_common import BoardRecord
# TODO(psyche): BoardDTO is now identical to BoardRecord. We should consider removing it.
class BoardDTO(BoardRecord):
"""Deserialized board record with cover image URL and image count."""
"""Deserialized board record."""
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
"""The URL of the thumbnail of the most recent image in the board."""
image_count: int = Field(description="The number of images in the board.")
"""The number of images in the board."""
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
"""Converts a board record to a board DTO."""
return BoardDTO(
**board_record.model_dump(exclude={"cover_image_name"}),
cover_image_name=cover_image_name,
image_count=image_count,
)
pass

View File

@@ -1,6 +1,6 @@
from invokeai.app.services.board_records.board_records_common import BoardChanges
from invokeai.app.services.boards.boards_base import BoardServiceABC
from invokeai.app.services.boards.boards_common import BoardDTO, board_record_to_dto
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
@@ -16,17 +16,11 @@ class BoardService(BoardServiceABC):
board_name: str,
) -> BoardDTO:
board_record = self.__invoker.services.board_records.save(board_name)
return board_record_to_dto(board_record, None, 0)
return BoardDTO.model_validate(board_record.model_dump())
def get_dto(self, board_id: str) -> BoardDTO:
board_record = self.__invoker.services.board_records.get(board_id)
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
return board_record_to_dto(board_record, cover_image_name, image_count)
return BoardDTO.model_validate(board_record.model_dump())
def update(
self,
@@ -34,14 +28,7 @@ class BoardService(BoardServiceABC):
changes: BoardChanges,
) -> BoardDTO:
board_record = self.__invoker.services.board_records.update(board_id, changes)
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
return board_record_to_dto(board_record, cover_image_name, image_count)
return BoardDTO.model_validate(board_record.model_dump())
def delete(self, board_id: str) -> None:
self.__invoker.services.board_records.delete(board_id)
@@ -50,30 +37,10 @@ class BoardService(BoardServiceABC):
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
board_dtos = []
for r in board_records.items:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
board_dtos = [BoardDTO.model_validate(r.model_dump()) for r in board_records.items]
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all(include_archived)
board_dtos = []
for r in board_records:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
if cover_image:
cover_image_name = cover_image.image_name
else:
cover_image_name = None
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
board_dtos = [BoardDTO.model_validate(r.model_dump()) for r in board_records]
return board_dtos

View File

@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from pydantic.networks import AnyHttpUrl
@@ -12,7 +12,7 @@ from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordChanges, ModelRecordServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig
@@ -64,7 +64,7 @@ class ModelInstallServiceBase(ABC):
def register_path(
self,
model_path: Union[Path, str],
config: Optional[ModelRecordChanges] = None,
config: Optional[Dict[str, Any]] = None,
) -> str:
"""
Probe and register the model at model_path.
@@ -72,7 +72,7 @@ class ModelInstallServiceBase(ABC):
This keeps the model in its current location.
:param model_path: Filesystem Path to the model.
:param config: ModelRecordChanges object that will override autoassigned model record values.
:param config: Dict of attributes that will override autoassigned values.
:returns id: The string ID of the registered model.
"""
@@ -92,7 +92,7 @@ class ModelInstallServiceBase(ABC):
def install_path(
self,
model_path: Union[Path, str],
config: Optional[ModelRecordChanges] = None,
config: Optional[Dict[str, Any]] = None,
) -> str:
"""
Probe, register and install the model in the models directory.
@@ -101,7 +101,7 @@ class ModelInstallServiceBase(ABC):
the models directory handled by InvokeAI.
:param model_path: Filesystem Path to the model.
:param config: ModelRecordChanges object that will override autoassigned model record values.
:param config: Dict of attributes that will override autoassigned values.
:returns id: The string ID of the registered model.
"""
@@ -109,14 +109,14 @@ class ModelInstallServiceBase(ABC):
def heuristic_import(
self,
source: str,
config: Optional[ModelRecordChanges] = None,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
r"""Install the indicated model using heuristics to interpret user intentions.
:param source: String source
:param config: Optional ModelRecordChanges object. Any fields in this object
:param config: Optional dict. Any fields in this dict
will override corresponding autoassigned probe fields in the
model's config record as described in `import_model()`.
:param access_token: Optional access token for remote sources.
@@ -147,7 +147,7 @@ class ModelInstallServiceBase(ABC):
def import_model(
self,
source: ModelSource,
config: Optional[ModelRecordChanges] = None,
config: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
"""Install the indicated model.

View File

@@ -2,14 +2,13 @@ import re
import traceback
from enum import Enum
from pathlib import Path
from typing import Literal, Optional, Set, Union
from typing import Any, Dict, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob
from invokeai.app.services.model_records import ModelRecordChanges
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
@@ -134,9 +133,8 @@ class ModelInstallJob(BaseModel):
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: ModelRecordChanges = Field(
default_factory=ModelRecordChanges,
description="Configuration information (e.g. 'description') to apply to model.",
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."

View File

@@ -163,27 +163,26 @@ class ModelInstallService(ModelInstallServiceBase):
def register_path(
self,
model_path: Union[Path, str],
config: Optional[ModelRecordChanges] = None,
config: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102
model_path = Path(model_path)
config = config or ModelRecordChanges()
if not config.source:
config.source = model_path.resolve().as_posix()
config.source_type = ModelSourceType.Path
config = config or {}
if not config.get("source"):
config["source"] = model_path.resolve().as_posix()
config["source_type"] = ModelSourceType.Path
return self._register(model_path, config)
def install_path(
self,
model_path: Union[Path, str],
config: Optional[ModelRecordChanges] = None,
config: Optional[Dict[str, Any]] = None,
) -> str: # noqa D102
model_path = Path(model_path)
config = config or ModelRecordChanges()
info: AnyModelConfig = ModelProbe.probe(
Path(model_path), config.model_dump(), hash_algo=self._app_config.hashing_algorithm
) # type: ignore
config = config or {}
if preferred_name := config.name:
info: AnyModelConfig = ModelProbe.probe(Path(model_path), config, hash_algo=self._app_config.hashing_algorithm)
if preferred_name := config.get("name"):
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
dest_path = (
@@ -205,7 +204,7 @@ class ModelInstallService(ModelInstallServiceBase):
def heuristic_import(
self,
source: str,
config: Optional[ModelRecordChanges] = None,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
) -> ModelInstallJob:
@@ -217,7 +216,7 @@ class ModelInstallService(ModelInstallServiceBase):
source_obj.access_token = access_token
return self.import_model(source_obj, config)
def import_model(self, source: ModelSource, config: Optional[ModelRecordChanges] = None) -> ModelInstallJob: # noqa D102
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
if similar_jobs:
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
@@ -319,17 +318,16 @@ class ModelInstallService(ModelInstallServiceBase):
model_path = self._app_config.models_path / model_path
model_path = model_path.resolve()
config = ModelRecordChanges(
name=model_name,
description=stanza.get("description"),
)
config: dict[str, Any] = {}
config["name"] = model_name
config["description"] = stanza.get("description")
legacy_config_path = stanza.get("config")
if legacy_config_path:
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
legacy_config_path = self._app_config.root_path / legacy_config_path
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
config.config_path = str(legacy_config_path)
config["config_path"] = str(legacy_config_path)
try:
id = self.register_path(model_path=model_path, config=config)
self._logger.info(f"Migrated {model_name} with id {id}")
@@ -502,11 +500,11 @@ class ModelInstallService(ModelInstallServiceBase):
job.total_bytes = self._stat_size(job.local_path)
job.bytes = job.total_bytes
self._signal_job_running(job)
job.config_in.source = str(job.source)
job.config_in.source_type = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
job.config_in["source"] = str(job.source)
job.config_in["source_type"] = MODEL_SOURCE_TO_TYPE_MAP[job.source.__class__]
# enter the metadata, if there is any
if isinstance(job.source_metadata, (HuggingFaceMetadata)):
job.config_in.source_api_response = job.source_metadata.api_response
job.config_in["source_api_response"] = job.source_metadata.api_response
if job.inplace:
key = self.register_path(job.local_path, job.config_in)
@@ -641,11 +639,11 @@ class ModelInstallService(ModelInstallServiceBase):
return new_path
def _register(
self, model_path: Path, config: Optional[ModelRecordChanges] = None, info: Optional[AnyModelConfig] = None
self, model_path: Path, config: Optional[Dict[str, Any]] = None, info: Optional[AnyModelConfig] = None
) -> str:
config = config or ModelRecordChanges()
config = config or {}
info = info or ModelProbe.probe(model_path, config.model_dump(), hash_algo=self._app_config.hashing_algorithm) # type: ignore
info = info or ModelProbe.probe(model_path, config, hash_algo=self._app_config.hashing_algorithm)
model_path = model_path.resolve()
@@ -676,13 +674,11 @@ class ModelInstallService(ModelInstallServiceBase):
precision = TorchDevice.choose_torch_dtype()
return ModelRepoVariant.FP16 if precision == torch.float16 else None
def _import_local_model(
self, source: LocalModelSource, config: Optional[ModelRecordChanges] = None
) -> ModelInstallJob:
def _import_local_model(self, source: LocalModelSource, config: Optional[Dict[str, Any]]) -> ModelInstallJob:
return ModelInstallJob(
id=self._next_id(),
source=source,
config_in=config or ModelRecordChanges(),
config_in=config or {},
local_path=Path(source.path),
inplace=source.inplace or False,
)
@@ -690,7 +686,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_hf(
self,
source: HFModelSource,
config: Optional[ModelRecordChanges] = None,
config: Optional[Dict[str, Any]] = None,
) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests
if source.access_token is None:
@@ -706,7 +702,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _import_from_url(
self,
source: URLModelSource,
config: Optional[ModelRecordChanges] = None,
config: Optional[Dict[str, Any]],
) -> ModelInstallJob:
remote_files, metadata = self._remote_files_from_source(source)
return self._import_remote_model(
@@ -721,7 +717,7 @@ class ModelInstallService(ModelInstallServiceBase):
source: HFModelSource | URLModelSource,
remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata],
config: Optional[ModelRecordChanges],
config: Optional[Dict[str, Any]],
) -> ModelInstallJob:
if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found")
@@ -734,7 +730,7 @@ class ModelInstallService(ModelInstallServiceBase):
install_job = ModelInstallJob(
id=self._next_id(),
source=source,
config_in=config or ModelRecordChanges(),
config_in=config or {},
source_metadata=metadata,
local_path=destdir, # local path may change once the download has started due to content-disposition handling
bytes=0,

View File

@@ -18,7 +18,6 @@ from invokeai.backend.model_manager.config import (
ControlAdapterDefaultSettings,
MainModelDefaultSettings,
ModelFormat,
ModelSourceType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
@@ -67,16 +66,10 @@ class ModelRecordChanges(BaseModelExcludeNull):
"""A set of changes to apply to a model."""
# Changes applicable to all models
source: Optional[str] = Field(description="original source of the model", default=None)
source_type: Optional[ModelSourceType] = Field(description="type of model source", default=None)
source_api_response: Optional[str] = Field(description="metadata from remote source", default=None)
name: Optional[str] = Field(description="Name of the model.", default=None)
path: Optional[str] = Field(description="Path to the model.", default=None)
description: Optional[str] = Field(description="Model description", default=None)
base: Optional[BaseModelType] = Field(description="The base model.", default=None)
type: Optional[ModelType] = Field(description="Type of model", default=None)
key: Optional[str] = Field(description="Database ID for this model", default=None)
hash: Optional[str] = Field(description="hash of model file", default=None)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[MainModelDefaultSettings | ControlAdapterDefaultSettings] = Field(
description="Default settings for this model", default=None

View File

@@ -124,14 +124,16 @@ class IPAdapter(RawModel):
self.device, dtype=self.dtype
)
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
def to(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
):
if device is not None:
self.device = device
if dtype is not None:
self.dtype = dtype
self._image_proj_model.to(device=self.device, dtype=self.dtype)
self.attn_weights.to(device=self.device, dtype=self.dtype)
self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
def calc_size(self) -> int:
# HACK(ryand): Fix this issue with circular imports.

View File

@@ -11,6 +11,7 @@ from typing_extensions import Self
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.raw_model import RawModel
from invokeai.backend.util.devices import TorchDevice
class LoRALayerBase:
@@ -56,9 +57,14 @@ class LoRALayerBase:
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
# TODO: find and debug lora/locon with bias
@@ -100,14 +106,19 @@ class LoRALayer(LoRALayerBase):
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
super().to(device=device, dtype=dtype)
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype, non_blocking=non_blocking)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
class LoHALayer(LoRALayerBase):
@@ -156,18 +167,23 @@ class LoHALayer(LoRALayerBase):
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
class LoKRLayer(LoRALayerBase):
@@ -248,7 +264,12 @@ class LoKRLayer(LoRALayerBase):
model_size += val.nelement() * val.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
if self.w1 is not None:
@@ -256,19 +277,19 @@ class LoKRLayer(LoRALayerBase):
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
class FullLayer(LoRALayerBase):
@@ -298,10 +319,15 @@ class FullLayer(LoRALayerBase):
model_size += self.weight.nelement() * self.weight.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
class IA3Layer(LoRALayerBase):
@@ -333,11 +359,16 @@ class IA3Layer(LoRALayerBase):
model_size += self.on_input.nelement() * self.on_input.element_size()
return model_size
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@@ -359,10 +390,15 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
def name(self) -> str:
return self._name
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
# TODO: try revert if exception?
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
def calc_size(self) -> int:
model_size = 0
@@ -485,7 +521,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
model.layers[layer_key] = layer
return model

View File

@@ -67,7 +67,6 @@ class ModelType(str, Enum):
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
SpandrelImageToImage = "spandrel_image_to_image"
class SubModelType(str, Enum):
@@ -354,7 +353,7 @@ class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
@@ -365,24 +364,13 @@ class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T2IAdapter.value}.{ModelFormat.Diffusers.value}")
class SpandrelImageToImageConfig(ModelConfigBase):
"""Model config for Spandrel Image to Image models."""
type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.SpandrelImageToImage.value}.{ModelFormat.Checkpoint.value}")
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
@@ -419,7 +407,6 @@ AnyModelConfig = Annotated[
Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),

View File

@@ -167,8 +167,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
size = calc_model_size_by_data(self.logger, model)
self.make_room(size)
running_on_cpu = self.execution_device == torch.device("cpu")
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@@ -290,9 +289,11 @@ class ModelCache(ModelCacheBase[AnyModel]):
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(target_device, copy=True)
new_dict[k] = v.to(
target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device)
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)

View File

@@ -1,45 +0,0 @@
from pathlib import Path
from typing import Optional
import torch
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
@ModelLoaderRegistry.register(
base=BaseModelType.Any, type=ModelType.SpandrelImageToImage, format=ModelFormat.Checkpoint
)
class SpandrelImageToImageModelLoader(ModelLoader):
"""Class for loading Spandrel Image-to-Image models (i.e. models wrapped by spandrel.ImageModelDescriptor)."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("Unexpected submodel requested for Spandrel model.")
model_path = Path(config.path)
model = SpandrelImageToImageModel.load_from_file(model_path)
torch_dtype = self._torch_dtype
if not model.supports_dtype(torch_dtype):
self._logger.warning(
f"The configured dtype ('{self._torch_dtype}') is not supported by the {model.get_model_type_name()} "
"model. Falling back to 'float32'."
)
torch_dtype = torch.float32
model.to(dtype=torch_dtype)
return model

View File

@@ -98,9 +98,6 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
ModelVariantType.Normal: StableDiffusionXLPipeline,
ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
},
BaseModelType.StableDiffusionXLRefiner: {
ModelVariantType.Normal: StableDiffusionXLPipeline,
},
}
assert isinstance(config, MainCheckpointConfig)
try:

View File

@@ -15,7 +15,6 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw
@@ -34,7 +33,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
elif isinstance(model, CLIPTokenizer):
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
return 0
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)):
return model.calc_size()
else:
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the

View File

@@ -4,7 +4,6 @@ from pathlib import Path
from typing import Any, Dict, Literal, Optional, Union
import safetensors.torch
import spandrel
import torch
from picklescan.scanner import scan_file_path
@@ -26,7 +25,6 @@ from invokeai.backend.model_manager.config import (
SchedulerPredictionType,
)
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length, read_checkpoint_meta
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.util.silence_warnings import SilenceWarnings
CkptType = Dict[str | int, Any]
@@ -222,46 +220,24 @@ class ModelProbe(object):
ckpt = ckpt.get("state_dict", ckpt)
for key in [str(k) for k in ckpt.keys()]:
if key.startswith(("cond_stage_model.", "first_stage_model.", "model.diffusion_model.")):
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
return ModelType.Main
elif key.startswith(("encoder.conv_in", "decoder.conv_in")):
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
return ModelType.VAE
elif key.startswith(("lora_te_", "lora_unet_")):
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
return ModelType.LoRA
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight")):
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
return ModelType.LoRA
elif key.startswith(("controlnet", "control_model", "input_blocks")):
elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
return ModelType.ControlNet
elif key.startswith(("image_proj.", "ip_adapter.")):
elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
return ModelType.IPAdapter
elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion
# diffusers-ti
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
# Check if the model can be loaded as a SpandrelImageToImageModel.
# This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
try:
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
# explored to avoid this:
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
# supported on meta tensors.
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
# maintain it, and the risk of false positive detections is higher.
SpandrelImageToImageModel.load_from_file(model_path)
return ModelType.SpandrelImageToImage
except spandrel.UnsupportedModelError:
pass
except RuntimeError as e:
if "No such file or directory" in str(e):
# This error is expected if the model_path does not exist (which is the case in some unit tests).
pass
else:
raise e
else:
# diffusers-ti
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
@@ -593,11 +569,6 @@ class T2IAdapterCheckpointProbe(CheckpointProbeBase):
raise NotImplementedError()
class SpandrelImageToImageCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
########################################################
# classes for probing folders
#######################################################
@@ -805,11 +776,6 @@ class CLIPVisionFolderProbe(FolderProbeBase):
return BaseModelType.Any
class SpandrelImageToImageFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.model_path / "config.json"
@@ -839,7 +805,6 @@ ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderPro
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.SpandrelImageToImage, SpandrelImageToImageFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
@@ -849,6 +814,5 @@ ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpoi
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.SpandrelImageToImage, SpandrelImageToImageCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@@ -187,171 +187,157 @@ STARTER_MODELS: list[StarterModel] = [
# endregion
# region ControlNet
StarterModel(
name="QRCode Monster v2 (SD1.5)",
name="QRCode Monster",
base=BaseModelType.StableDiffusion1,
source="monster-labs/control_v1p_sd15_qrcode_monster::v2",
description="ControlNet model that generates scannable creative QR codes",
type=ModelType.ControlNet,
),
StarterModel(
name="QRCode Monster (SDXL)",
base=BaseModelType.StableDiffusionXL,
source="monster-labs/control_v1p_sdxl_qrcode_monster",
description="ControlNet model that generates scannable creative QR codes",
source="monster-labs/control_v1p_sd15_qrcode_monster",
description="Controlnet model that generates scannable creative QR codes",
type=ModelType.ControlNet,
),
StarterModel(
name="canny",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_canny",
description="ControlNet weights trained on sd-1.5 with canny conditioning.",
description="Controlnet weights trained on sd-1.5 with canny conditioning.",
type=ModelType.ControlNet,
),
StarterModel(
name="inpaint",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_inpaint",
description="ControlNet weights trained on sd-1.5 with canny conditioning, inpaint version",
description="Controlnet weights trained on sd-1.5 with canny conditioning, inpaint version",
type=ModelType.ControlNet,
),
StarterModel(
name="mlsd",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_mlsd",
description="ControlNet weights trained on sd-1.5 with canny conditioning, MLSD version",
description="Controlnet weights trained on sd-1.5 with canny conditioning, MLSD version",
type=ModelType.ControlNet,
),
StarterModel(
name="depth",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11f1p_sd15_depth",
description="ControlNet weights trained on sd-1.5 with depth conditioning",
description="Controlnet weights trained on sd-1.5 with depth conditioning",
type=ModelType.ControlNet,
),
StarterModel(
name="normal_bae",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_normalbae",
description="ControlNet weights trained on sd-1.5 with normalbae image conditioning",
description="Controlnet weights trained on sd-1.5 with normalbae image conditioning",
type=ModelType.ControlNet,
),
StarterModel(
name="seg",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_seg",
description="ControlNet weights trained on sd-1.5 with seg image conditioning",
description="Controlnet weights trained on sd-1.5 with seg image conditioning",
type=ModelType.ControlNet,
),
StarterModel(
name="lineart",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_lineart",
description="ControlNet weights trained on sd-1.5 with lineart image conditioning",
description="Controlnet weights trained on sd-1.5 with lineart image conditioning",
type=ModelType.ControlNet,
),
StarterModel(
name="lineart_anime",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15s2_lineart_anime",
description="ControlNet weights trained on sd-1.5 with anime image conditioning",
description="Controlnet weights trained on sd-1.5 with anime image conditioning",
type=ModelType.ControlNet,
),
StarterModel(
name="openpose",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_openpose",
description="ControlNet weights trained on sd-1.5 with openpose image conditioning",
description="Controlnet weights trained on sd-1.5 with openpose image conditioning",
type=ModelType.ControlNet,
),
StarterModel(
name="scribble",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_scribble",
description="ControlNet weights trained on sd-1.5 with scribble image conditioning",
description="Controlnet weights trained on sd-1.5 with scribble image conditioning",
type=ModelType.ControlNet,
),
StarterModel(
name="softedge",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11p_sd15_softedge",
description="ControlNet weights trained on sd-1.5 with soft edge conditioning",
description="Controlnet weights trained on sd-1.5 with soft edge conditioning",
type=ModelType.ControlNet,
),
StarterModel(
name="shuffle",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11e_sd15_shuffle",
description="ControlNet weights trained on sd-1.5 with shuffle image conditioning",
description="Controlnet weights trained on sd-1.5 with shuffle image conditioning",
type=ModelType.ControlNet,
),
StarterModel(
name="tile",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11f1e_sd15_tile",
description="ControlNet weights trained on sd-1.5 with tiled image conditioning",
description="Controlnet weights trained on sd-1.5 with tiled image conditioning",
type=ModelType.ControlNet,
),
StarterModel(
name="ip2p",
base=BaseModelType.StableDiffusion1,
source="lllyasviel/control_v11e_sd15_ip2p",
description="ControlNet weights trained on sd-1.5 with ip2p conditioning.",
description="Controlnet weights trained on sd-1.5 with ip2p conditioning.",
type=ModelType.ControlNet,
),
StarterModel(
name="canny-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlNet-canny-sdxl-1.0",
description="ControlNet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
source="xinsir/controlnet-canny-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
type=ModelType.ControlNet,
),
StarterModel(
name="depth-sdxl",
base=BaseModelType.StableDiffusionXL,
source="diffusers/controlNet-depth-sdxl-1.0",
description="ControlNet weights trained on sdxl-1.0 with depth conditioning.",
source="diffusers/controlnet-depth-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 with depth conditioning.",
type=ModelType.ControlNet,
),
StarterModel(
name="softedge-dexined-sdxl",
base=BaseModelType.StableDiffusionXL,
source="SargeZT/controlNet-sd-xl-1.0-softedge-dexined",
description="ControlNet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
source="SargeZT/controlnet-sd-xl-1.0-softedge-dexined",
description="Controlnet weights trained on sdxl-1.0 with dexined soft edge preprocessing.",
type=ModelType.ControlNet,
),
StarterModel(
name="depth-16bit-zoe-sdxl",
base=BaseModelType.StableDiffusionXL,
source="SargeZT/controlNet-sd-xl-1.0-depth-16bit-zoe",
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
source="SargeZT/controlnet-sd-xl-1.0-depth-16bit-zoe",
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (16 bits).",
type=ModelType.ControlNet,
),
StarterModel(
name="depth-zoe-sdxl",
base=BaseModelType.StableDiffusionXL,
source="diffusers/controlNet-zoe-depth-sdxl-1.0",
description="ControlNet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
source="diffusers/controlnet-zoe-depth-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
type=ModelType.ControlNet,
),
StarterModel(
name="openpose-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlNet-openpose-sdxl-1.0",
description="ControlNet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
source="xinsir/controlnet-openpose-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
type=ModelType.ControlNet,
),
StarterModel(
name="scribble-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlNet-scribble-sdxl-1.0",
description="ControlNet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
type=ModelType.ControlNet,
),
StarterModel(
name="tile-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlNet-tile-sdxl-1.0",
description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
source="xinsir/controlnet-scribble-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
type=ModelType.ControlNet,
),
# endregion
@@ -413,43 +399,6 @@ STARTER_MODELS: list[StarterModel] = [
type=ModelType.T2IAdapter,
),
# endregion
# region SpandrelImageToImage
StarterModel(
name="RealESRGAN_x4plus_anime_6B",
base=BaseModelType.Any,
source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
description="A Real-ESRGAN 4x upscaling model (optimized for anime images).",
type=ModelType.SpandrelImageToImage,
),
StarterModel(
name="RealESRGAN_x4plus",
base=BaseModelType.Any,
source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
description="A Real-ESRGAN 4x upscaling model (general-purpose).",
type=ModelType.SpandrelImageToImage,
),
StarterModel(
name="ESRGAN_SRx4_DF2KOST_official",
base=BaseModelType.Any,
source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
description="The official ESRGAN 4x upscaling model.",
type=ModelType.SpandrelImageToImage,
),
StarterModel(
name="RealESRGAN_x2plus",
base=BaseModelType.Any,
source="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
description="A Real-ESRGAN 2x upscaling model (general-purpose).",
type=ModelType.SpandrelImageToImage,
),
StarterModel(
name="SwinIR - realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN",
base=BaseModelType.Any,
source="https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN-with-dict-keys-params-and-params_ema.pth",
description="A SwinIR 4x upscaling model.",
type=ModelType.SpandrelImageToImage,
),
# endregion
]
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -32,27 +32,8 @@ with LoRAHelper.apply_lora_unet(unet, loras):
"""
# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
@staticmethod
@contextmanager
def patch_unet_attention_processor(unet: UNet2DConditionModel, processor_cls: Type[Any]):
"""A context manager that patches `unet` with the provided attention processor.
Args:
unet (UNet2DConditionModel): The UNet model to patch.
processor (Type[Any]): Class which will be initialized for each key and passed to set_attn_processor(...).
"""
unet_orig_processors = unet.attn_processors
# create separate instance for each attention, to be able modify each attention separately
unet_new_processors = {key: processor_cls() for key in unet_orig_processors.keys()}
try:
unet.set_attn_processor(unet_new_processors)
yield None
finally:
unet.set_attn_processor(unet_orig_processors)
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
@@ -158,12 +139,15 @@ class ModelPatcher:
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device))
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device=TorchDevice.CPU_DEVICE)
layer.to(
device=TorchDevice.CPU_DEVICE,
non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape:
@@ -172,7 +156,7 @@ class ModelPatcher:
layer_weight = layer_weight.reshape(module.weight.shape)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
module.weight += layer_weight.to(dtype=dtype)
module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
yield # wait for context manager exit
@@ -180,7 +164,9 @@ class ModelPatcher:
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)
model.get_submodule(module_key).weight.copy_(
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
)
@classmethod
@contextmanager

View File

@@ -190,7 +190,12 @@ class IAIOnnxRuntimeModel(RawModel):
return self.session.run(None, inputs)
# compatability with RawModel ABC
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
pass
# compatability with diffusers load code

View File

@@ -1,3 +1,15 @@
"""Base class for 'Raw' models.
The RawModel class is the base class of LoRAModelRaw and TextualInversionModelRaw,
and is used for type checking of calls to the model patcher. Its main purpose
is to avoid a circular import issues when lora.py tries to import BaseModelType
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
from lora.py.
The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
that adds additional methods and attributes.
"""
from abc import ABC, abstractmethod
from typing import Optional
@@ -5,18 +17,13 @@ import torch
class RawModel(ABC):
"""Base class for 'Raw' models.
The RawModel class is the base class of LoRAModelRaw, TextualInversionModelRaw, etc.
and is used for type checking of calls to the model patcher. Its main purpose
is to avoid a circular import issues when lora.py tries to import BaseModelType
from invokeai.backend.model_manager.config, and the latter tries to import LoRAModelRaw
from lora.py.
The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
that adds additional methods and attributes.
"""
"""Abstract base class for 'Raw' model wrappers."""
@abstractmethod
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
pass

View File

@@ -1,139 +0,0 @@
from pathlib import Path
from typing import Any, Optional
import numpy as np
import torch
from PIL import Image
from spandrel import ImageModelDescriptor, ModelLoader
from invokeai.backend.raw_model import RawModel
class SpandrelImageToImageModel(RawModel):
"""A wrapper for a Spandrel Image-to-Image model.
The main reason for having a wrapper class is to integrate with the type handling of RawModel.
"""
def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
self._spandrel_model = spandrel_model
@staticmethod
def pil_to_tensor(image: Image.Image) -> torch.Tensor:
"""Convert PIL Image to the torch.Tensor format expected by SpandrelImageToImageModel.run().
Args:
image (Image.Image): A PIL Image with shape (H, W, C) and values in the range [0, 255].
Returns:
torch.Tensor: A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
"""
image_np = np.array(image)
# (H, W, C) -> (C, H, W)
image_np = np.transpose(image_np, (2, 0, 1))
image_np = image_np / 255
image_tensor = torch.from_numpy(image_np).float()
# (C, H, W) -> (N, C, H, W)
image_tensor = image_tensor.unsqueeze(0)
return image_tensor
@staticmethod
def tensor_to_pil(tensor: torch.Tensor) -> Image.Image:
"""Convert a torch.Tensor produced by SpandrelImageToImageModel.run() to a PIL Image.
Args:
tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
Returns:
Image.Image: A PIL Image with shape (H, W, C) and values in the range [0, 255].
"""
# (N, C, H, W) -> (C, H, W)
tensor = tensor.squeeze(0)
# (C, H, W) -> (H, W, C)
tensor = tensor.permute(1, 2, 0)
tensor = tensor.clamp(0, 1)
tensor = (tensor * 255).cpu().detach().numpy().astype(np.uint8)
image = Image.fromarray(tensor)
return image
def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""Run the image-to-image model.
Args:
image_tensor (torch.Tensor): A torch.Tensor with shape (N, C, H, W) and values in the range [0, 1].
"""
return self._spandrel_model(image_tensor)
@classmethod
def load_from_file(cls, file_path: str | Path):
model = ModelLoader().load_from_file(file_path)
if not isinstance(model, ImageModelDescriptor):
raise ValueError(
f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
"('ImageModelDescriptor')."
)
return cls(spandrel_model=model)
@classmethod
def load_from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
model = ModelLoader().load_from_state_dict(state_dict)
if not isinstance(model, ImageModelDescriptor):
raise ValueError(
f"Loaded a spandrel model of type '{type(model)}'. Only image-to-image models are supported "
"('ImageModelDescriptor')."
)
return cls(spandrel_model=model)
def supports_dtype(self, dtype: torch.dtype) -> bool:
"""Check if the model supports the given dtype."""
if dtype == torch.float16:
return self._spandrel_model.supports_half
elif dtype == torch.bfloat16:
return self._spandrel_model.supports_bfloat16
elif dtype == torch.float32:
# All models support float32.
return True
else:
raise ValueError(f"Unexpected dtype '{dtype}'.")
def get_model_type_name(self) -> str:
"""The model type name. Intended for logging / debugging purposes. Do not rely on this field remaining
consistent over time.
"""
return str(type(self._spandrel_model.model))
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
"""Note: Some models have limited dtype support. Call supports_dtype(...) to check if the dtype is supported.
Note: The non_blocking parameter is currently ignored."""
# TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will have to access the
# model directly if we want to apply this optimization.
self._spandrel_model.to(device=device, dtype=dtype)
@property
def device(self) -> torch.device:
"""The device of the underlying model."""
return self._spandrel_model.device
@property
def dtype(self) -> torch.dtype:
"""The dtype of the underlying model."""
return self._spandrel_model.dtype
@property
def scale(self) -> int:
"""The scale of the model (e.g. 1x, 2x, 4x, etc.)."""
return self._spandrel_model.scale
def calc_size(self) -> int:
"""Get size of the model in memory in bytes."""
# HACK(ryand): Fix this issue with circular imports.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_module_size(self._spandrel_model.model)

View File

@@ -1,131 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union
import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData
@dataclass
class UNetKwargs:
sample: torch.Tensor
timestep: Union[torch.Tensor, float, int]
encoder_hidden_states: torch.Tensor
class_labels: Optional[torch.Tensor] = None
timestep_cond: Optional[torch.Tensor] = None
attention_mask: Optional[torch.Tensor] = None
cross_attention_kwargs: Optional[Dict[str, Any]] = None
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
mid_block_additional_residual: Optional[torch.Tensor] = None
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
encoder_attention_mask: Optional[torch.Tensor] = None
# return_dict: bool = True
@dataclass
class DenoiseInputs:
"""Initial variables passed to denoise. Supposed to be unchanged."""
# The latent-space image to denoise.
# Shape: [batch, channels, latent_height, latent_width]
# - If we are inpainting, this is the initial latent image before noise has been added.
# - If we are generating a new image, this should be initialized to zeros.
# - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
orig_latents: torch.Tensor
# kwargs forwarded to the scheduler.step() method.
scheduler_step_kwargs: dict[str, Any]
# Text conditionging data.
conditioning_data: TextConditioningData
# Noise used for two purposes:
# 1. Used by the scheduler to noise the initial `latents` before denoising.
# 2. Used to noise the `masked_latents` when inpainting.
# `noise` should be None if the `latents` tensor has already been noised.
# Shape: [1 or batch, channels, latent_height, latent_width]
noise: Optional[torch.Tensor]
# The seed used to generate the noise for the denoising process.
# HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
# same noise used earlier in the pipeline. This should really be handled in a clearer way.
seed: int
# The timestep schedule for the denoising process.
timesteps: torch.Tensor
# The first timestep in the schedule. This is used to determine the initial noise level, so
# should be populated if you want noise applied *even* if timesteps is empty.
init_timestep: torch.Tensor
# Class of attention processor that is used.
attention_processor_cls: Type[Any]
@dataclass
class DenoiseContext:
"""Context with all variables in denoise"""
# Initial variables passed to denoise. Supposed to be unchanged.
inputs: DenoiseInputs
# Scheduler which used to apply noise predictions.
scheduler: SchedulerMixin
# UNet model.
unet: Optional[UNet2DConditionModel] = None
# Current state of latent-space image in denoising process.
# None until `PRE_DENOISE_LOOP` callback.
# Shape: [batch, channels, latent_height, latent_width]
latents: Optional[torch.Tensor] = None
# Current denoising step index.
# None until `PRE_STEP` callback.
step_index: Optional[int] = None
# Current denoising step timestep.
# None until `PRE_STEP` callback.
timestep: Optional[torch.Tensor] = None
# Arguments which will be passed to UNet model.
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
unet_kwargs: Optional[UNetKwargs] = None
# SchedulerOutput class returned from step function(normally, generated by scheduler).
# Supposed to be used only in `POST_STEP` callback, otherwise can be None.
step_output: Optional[SchedulerOutput] = None
# Scaled version of `latents`, which will be passed to unet_kwargs initialization.
# Available in events inside step(between `PRE_STEP` and `POST_STEP`).
# Shape: [batch, channels, latent_height, latent_width]
latent_model_input: Optional[torch.Tensor] = None
# [TMP] Defines on which conditionings current unet call will be runned.
# Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None.
conditioning_mode: Optional[ConditioningMode] = None
# [TMP] Noise predictions from negative conditioning.
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width]
negative_noise_pred: Optional[torch.Tensor] = None
# [TMP] Noise predictions from positive conditioning.
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width]
positive_noise_pred: Optional[torch.Tensor] = None
# Combined noise prediction from passed conditionings.
# Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None.
# Shape: [batch, channels, latent_height, latent_width]
noise_pred: Optional[torch.Tensor] = None
# Dictionary for extensions to pass extra info about denoise process to other extensions.
extra: dict = field(default_factory=dict)

View File

@@ -23,12 +23,21 @@ from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState
from invokeai.backend.util.attention import auto_detect_slice_size
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@dataclass
class PipelineIntermediateState:
step: int
order: int
total_steps: int
timestep: int
latents: torch.Tensor
predicted_original: Optional[torch.Tensor] = None
@dataclass
class AddsMaskGuidance:
mask: torch.Tensor

View File

@@ -1,17 +1,10 @@
from __future__ import annotations
import math
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import List, Optional, Union
import torch
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
if TYPE_CHECKING:
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
@dataclass
@@ -102,12 +95,6 @@ class TextConditioningRegions:
assert self.masks.shape[1] == len(self.ranges)
class ConditioningMode(Enum):
Both = "both"
Negative = "negative"
Positive = "positive"
class TextConditioningData:
def __init__(
self,
@@ -116,7 +103,7 @@ class TextConditioningData:
uncond_regions: Optional[TextConditioningRegions],
cond_regions: Optional[TextConditioningRegions],
guidance_scale: Union[float, List[float]],
guidance_rescale_multiplier: float = 0, # TODO: old backend, remove
guidance_rescale_multiplier: float = 0,
):
self.uncond_text = uncond_text
self.cond_text = cond_text
@@ -127,7 +114,6 @@ class TextConditioningData:
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
# images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
self.guidance_scale = guidance_scale
# TODO: old backend, remove
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
self.guidance_rescale_multiplier = guidance_rescale_multiplier
@@ -135,114 +121,3 @@ class TextConditioningData:
def is_sdxl(self):
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
return isinstance(self.cond_text, SDXLConditioningInfo)
def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
"""Fills unet arguments with data from provided conditionings.
Args:
unet_kwargs (UNetKwargs): Object which stores UNet model arguments.
conditioning_mode (ConditioningMode): Describes which conditionings should be used.
"""
_, _, h, w = unet_kwargs.sample.shape
device = unet_kwargs.sample.device
dtype = unet_kwargs.sample.dtype
# TODO: combine regions with conditionings
if conditioning_mode == ConditioningMode.Both:
conditionings = [self.uncond_text, self.cond_text]
c_regions = [self.uncond_regions, self.cond_regions]
elif conditioning_mode == ConditioningMode.Positive:
conditionings = [self.cond_text]
c_regions = [self.cond_regions]
elif conditioning_mode == ConditioningMode.Negative:
conditionings = [self.uncond_text]
c_regions = [self.uncond_regions]
else:
raise ValueError(f"Unexpected conditioning mode: {conditioning_mode}")
encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
[c.embeds for c in conditionings]
)
unet_kwargs.encoder_hidden_states = encoder_hidden_states
unet_kwargs.encoder_attention_mask = encoder_attention_mask
if self.is_sdxl():
added_cond_kwargs = dict( # noqa: C408
text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
time_ids=torch.cat([c.add_time_ids for c in conditionings]),
)
unet_kwargs.added_cond_kwargs = added_cond_kwargs
if any(r is not None for r in c_regions):
tmp_regions = []
for c, r in zip(conditionings, c_regions, strict=True):
if r is None:
r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=dtype),
ranges=[Range(start=0, end=c.embeds.shape[1])],
)
tmp_regions.append(r)
if unet_kwargs.cross_attention_kwargs is None:
unet_kwargs.cross_attention_kwargs = {}
unet_kwargs.cross_attention_kwargs.update(
regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
)
@staticmethod
def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
@classmethod
def _pad_conditioning(
cls,
cond: torch.Tensor,
target_len: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pad provided conditioning tensor to target_len by zeros and returns mask of unpadded bytes.
Args:
cond (torch.Tensor): Conditioning tensor which to pads by zeros.
target_len (int): To which length(tokens count) pad tensor.
"""
conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)
if cond.shape[1] < target_len:
conditioning_attention_mask = cls._pad_zeros(
conditioning_attention_mask,
pad_shape=(cond.shape[0], target_len - cond.shape[1]),
dim=1,
)
cond = cls._pad_zeros(
cond,
pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
dim=1,
)
return cond, conditioning_attention_mask
@classmethod
def _concat_conditionings_for_batch(
cls,
conditionings: List[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Concatenate provided conditioning tensors to one batched tensor.
If tensors have different sizes then pad them by zeros and creates
encoder_attention_mask to exclude padding from attention.
Args:
conditionings (List[torch.Tensor]): List of conditioning tensors to concatenate.
"""
encoder_attention_mask = None
max_len = max([c.shape[1] for c in conditionings])
if any(c.shape[1] != max_len for c in conditionings):
encoder_attention_masks = [None] * len(conditionings)
for i in range(len(conditionings)):
conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(conditionings[i], max_len)
encoder_attention_mask = torch.cat(encoder_attention_masks)
return torch.cat(conditionings), encoder_attention_mask

View File

@@ -1,14 +1,9 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningRegions,
)
class RegionalPromptData:

View File

@@ -1,142 +0,0 @@
from __future__ import annotations
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from tqdm.auto import tqdm
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
class StableDiffusionBackend:
def __init__(
self,
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
):
self.unet = unet
self.scheduler = scheduler
config = get_config()
self._sequential_guidance = config.sequential_guidance
def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
if ctx.inputs.init_timestep.shape[0] == 0:
return ctx.inputs.orig_latents
ctx.latents = ctx.inputs.orig_latents.clone()
if ctx.inputs.noise is not None:
batch_size = ctx.latents.shape[0]
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
ctx.latents = ctx.scheduler.add_noise(
ctx.latents, ctx.inputs.noise, ctx.inputs.init_timestep.expand(batch_size)
)
# if no work to do, return latents
if ctx.inputs.timesteps.shape[0] == 0:
return ctx.latents
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
# ext: preview[pre_denoise_loop, priority=low]
ext_manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx)
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020
# ext: inpaint (apply mask to latents on non-inpaint models)
ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx)
# ext: tiles? [override: step]
ctx.step_output = self.step(ctx, ext_manager)
# ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
# ext: preview[post_step, priority=low]
ext_manager.run_callback(ExtensionCallbackType.POST_STEP, ctx)
ctx.latents = ctx.step_output.prev_sample
# ext: inpaint[post_denoise_loop] (restore unmasked part)
ext_manager.run_callback(ExtensionCallbackType.POST_DENOISE_LOOP, ctx)
return ctx.latents
@torch.inference_mode()
def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput:
ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep)
# TODO: conditionings as list(conditioning_data.to_unet_kwargs - ready)
# Note: The current handling of conditioning doesn't feel very future-proof.
# This might change in the future as new requirements come up, but for now,
# this is the rough plan.
if self._sequential_guidance:
ctx.negative_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Negative)
ctx.positive_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Positive)
else:
both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both)
ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2)
# ext: override combine_noise_preds
ctx.noise_pred = self.combine_noise_preds(ctx)
# ext: cfg_rescale [modify_noise_prediction]
# TODO: rename
ext_manager.run_callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS, ctx)
# compute the previous noisy sample x_t -> x_t-1
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
# clean up locals
ctx.latent_model_input = None
ctx.negative_noise_pred = None
ctx.positive_noise_pred = None
ctx.noise_pred = None
return step_output
@staticmethod
def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
guidance_scale = ctx.inputs.conditioning_data.guidance_scale
if isinstance(guidance_scale, list):
guidance_scale = guidance_scale[ctx.step_index]
# Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
# in slightly different outputs. It is suspected that this is caused by small precision differences.
# return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
sample = ctx.latent_model_input
if conditioning_mode == ConditioningMode.Both:
sample = torch.cat([sample] * 2)
ctx.unet_kwargs = UNetKwargs(
sample=sample,
timestep=ctx.timestep,
encoder_hidden_states=None, # set later by conditoning
cross_attention_kwargs=dict( # noqa: C408
percent_through=ctx.step_index / len(ctx.inputs.timesteps),
),
)
ctx.conditioning_mode = conditioning_mode
ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
# ext: controlnet/ip/t2i [pre_unet]
ext_manager.run_callback(ExtensionCallbackType.PRE_UNET, ctx)
# ext: inpaint [pre_unet, priority=low]
# or
# ext: inpaint [override: unet_forward]
noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))
ext_manager.run_callback(ExtensionCallbackType.POST_UNET, ctx)
# clean up locals
ctx.unet_kwargs = None
ctx.conditioning_mode = None
return noise_pred
def _unet_forward(self, **kwargs) -> torch.Tensor:
return self.unet(**kwargs).sample

View File

@@ -1,12 +0,0 @@
from enum import Enum
class ExtensionCallbackType(Enum):
SETUP = "setup"
PRE_DENOISE_LOOP = "pre_denoise_loop"
POST_DENOISE_LOOP = "post_denoise_loop"
PRE_STEP = "pre_step"
POST_STEP = "post_step"
PRE_UNET = "pre_unet"
POST_UNET = "post_unet"
POST_COMBINE_NOISE_PREDS = "post_combine_noise_preds"

View File

@@ -1,60 +0,0 @@
from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import torch
from diffusers import UNet2DConditionModel
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
@dataclass
class CallbackMetadata:
callback_type: ExtensionCallbackType
order: int
@dataclass
class CallbackFunctionWithMetadata:
metadata: CallbackMetadata
function: Callable[[DenoiseContext], None]
def callback(callback_type: ExtensionCallbackType, order: int = 0):
def _decorator(function):
function._ext_metadata = CallbackMetadata(
callback_type=callback_type,
order=order,
)
return function
return _decorator
class ExtensionBase:
def __init__(self):
self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
# Register all of the callback methods for this instance.
for func_name in dir(self):
func = getattr(self, func_name)
metadata = getattr(func, "_ext_metadata", None)
if metadata is not None and isinstance(metadata, CallbackMetadata):
if metadata.callback_type not in self._callbacks:
self._callbacks[metadata.callback_type] = []
self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))
def get_callbacks(self):
return self._callbacks
@contextmanager
def patch_extension(self, ctx: DenoiseContext):
yield None
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
yield None

View File

@@ -1,158 +0,0 @@
from __future__ import annotations
import math
from contextlib import contextmanager
from typing import TYPE_CHECKING, List, Optional, Union
import torch
from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.util.hotfixes import ControlNetModel
class ControlNetExt(ExtensionBase):
def __init__(
self,
model: ControlNetModel,
image: Image,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
control_mode: CONTROLNET_MODE_VALUES,
resize_mode: CONTROLNET_RESIZE_VALUES,
):
super().__init__()
self._model = model
self._image = image
self._weight = weight
self._begin_step_percent = begin_step_percent
self._end_step_percent = end_step_percent
self._control_mode = control_mode
self._resize_mode = resize_mode
self._image_tensor: Optional[torch.Tensor] = None
@contextmanager
def patch_extension(self, ctx: DenoiseContext):
original_processors = self._model.attn_processors
try:
self._model.set_attn_processor(ctx.inputs.attention_processor_cls())
yield None
finally:
self._model.set_attn_processor(original_processors)
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
def resize_image(self, ctx: DenoiseContext):
_, _, latent_height, latent_width = ctx.latents.shape
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
self._image_tensor = prepare_control_image(
image=self._image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=ctx.latents.device,
dtype=ctx.latents.dtype,
control_mode=self._control_mode,
resize_mode=self._resize_mode,
)
@callback(ExtensionCallbackType.PRE_UNET)
def pre_unet_step(self, ctx: DenoiseContext):
# skip if model not active in current step
total_steps = len(ctx.inputs.timesteps)
first_step = math.floor(self._begin_step_percent * total_steps)
last_step = math.ceil(self._end_step_percent * total_steps)
if ctx.step_index < first_step or ctx.step_index > last_step:
return
# convert mode to internal flags
soft_injection = self._control_mode in ["more_prompt", "more_control"]
cfg_injection = self._control_mode in ["more_control", "unbalanced"]
# no negative conditioning in cfg_injection mode
if cfg_injection:
if ctx.conditioning_mode == ConditioningMode.Negative:
return
down_samples, mid_sample = self._run(ctx, soft_injection, ConditioningMode.Positive)
if ctx.conditioning_mode == ConditioningMode.Both:
# add zeros as samples for negative conditioning
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])
else:
down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode)
if (
ctx.unet_kwargs.down_block_additional_residuals is None
and ctx.unet_kwargs.mid_block_additional_residual is None
):
ctx.unet_kwargs.down_block_additional_residuals = down_samples
ctx.unet_kwargs.mid_block_additional_residual = mid_sample
else:
# add controlnet outputs together if have multiple controlnets
ctx.unet_kwargs.down_block_additional_residuals = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(
ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True
)
]
ctx.unet_kwargs.mid_block_additional_residual += mid_sample
def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode):
total_steps = len(ctx.inputs.timesteps)
model_input = ctx.latent_model_input
image_tensor = self._image_tensor
if conditioning_mode == ConditioningMode.Both:
model_input = torch.cat([model_input] * 2)
image_tensor = torch.cat([image_tensor] * 2)
cn_unet_kwargs = UNetKwargs(
sample=model_input,
timestep=ctx.timestep,
encoder_hidden_states=None, # set later by conditioning
cross_attention_kwargs=dict( # noqa: C408
percent_through=ctx.step_index / total_steps,
),
)
ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)
# get static weight, or weight corresponding to current step
weight = self._weight
if isinstance(weight, list):
weight = weight[ctx.step_index]
tmp_kwargs = vars(cn_unet_kwargs)
# Remove kwargs not related to ControlNet unet
# ControlNet guidance fields
del tmp_kwargs["down_block_additional_residuals"]
del tmp_kwargs["mid_block_additional_residual"]
# T2i Adapter guidance fields
del tmp_kwargs["down_intrablock_additional_residuals"]
# controlnet(s) inference
down_samples, mid_sample = self._model(
controlnet_cond=image_tensor,
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
**vars(cn_unet_kwargs),
)
return down_samples, mid_sample

View File

@@ -1,35 +0,0 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import TYPE_CHECKING, Dict, Optional
import torch
from diffusers import UNet2DConditionModel
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
if TYPE_CHECKING:
from invokeai.app.shared.models import FreeUConfig
class FreeUExt(ExtensionBase):
def __init__(
self,
freeu_config: FreeUConfig,
):
super().__init__()
self._freeu_config = freeu_config
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
unet.enable_freeu(
b1=self._freeu_config.b1,
b2=self._freeu_config.b2,
s1=self._freeu_config.s1,
s2=self._freeu_config.s2,
)
try:
yield
finally:
unet.disable_freeu()

View File

@@ -1,63 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional
import torch
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
# TODO: change event to accept image instead of latents
@dataclass
class PipelineIntermediateState:
step: int
order: int
total_steps: int
timestep: int
latents: torch.Tensor
predicted_original: Optional[torch.Tensor] = None
class PreviewExt(ExtensionBase):
def __init__(self, callback: Callable[[PipelineIntermediateState], None]):
super().__init__()
self.callback = callback
# do last so that all other changes shown
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
def initial_preview(self, ctx: DenoiseContext):
self.callback(
PipelineIntermediateState(
step=-1,
order=ctx.scheduler.order,
total_steps=len(ctx.inputs.timesteps),
timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it?
latents=ctx.latents,
)
)
# do last so that all other changes shown
@callback(ExtensionCallbackType.POST_STEP, order=1000)
def step_preview(self, ctx: DenoiseContext):
if hasattr(ctx.step_output, "denoised"):
predicted_original = ctx.step_output.denoised
elif hasattr(ctx.step_output, "pred_original_sample"):
predicted_original = ctx.step_output.pred_original_sample
else:
predicted_original = ctx.step_output.prev_sample
self.callback(
PipelineIntermediateState(
step=ctx.step_index,
order=ctx.scheduler.order,
total_steps=len(ctx.inputs.timesteps),
timestep=int(ctx.timestep), # TODO: is there any code which uses it?
latents=ctx.step_output.prev_sample,
predicted_original=predicted_original, # TODO: is there any reason for additional field?
)
)

View File

@@ -1,36 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
class RescaleCFGExt(ExtensionBase):
def __init__(self, rescale_multiplier: float):
super().__init__()
self._rescale_multiplier = rescale_multiplier
@staticmethod
def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7):
"""Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf."""
ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True)
ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True)
x_rescaled = total_noise_pred * (ro_pos / ro_cfg)
x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred
return x_final
@callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS)
def rescale_noise_pred(self, ctx: DenoiseContext):
if self._rescale_multiplier > 0:
ctx.noise_pred = self._rescale_cfg(
ctx.noise_pred,
ctx.positive_noise_pred,
self._rescale_multiplier,
)

View File

@@ -1,75 +0,0 @@
from __future__ import annotations
from contextlib import ExitStack, contextmanager
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
import torch
from diffusers import UNet2DConditionModel
from invokeai.app.services.session_processor.session_processor_common import CanceledException
if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase
class ExtensionsManager:
def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
self._is_canceled = is_canceled
# A list of extensions in the order that they were added to the ExtensionsManager.
self._extensions: List[ExtensionBase] = []
self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
def add_extension(self, extension: ExtensionBase):
self._extensions.append(extension)
self._regenerate_ordered_callbacks()
def _regenerate_ordered_callbacks(self):
"""Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
self._ordered_callbacks = {}
# Fill the ordered callbacks dictionary.
for extension in self._extensions:
for callback_type, callbacks in extension.get_callbacks().items():
if callback_type not in self._ordered_callbacks:
self._ordered_callbacks[callback_type] = []
self._ordered_callbacks[callback_type].extend(callbacks)
# Sort each callback list.
for callback_type, callbacks in self._ordered_callbacks.items():
# Note that sorted() is stable, so if two callbacks have the same order, the order that they extensions were
# added will be preserved.
self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)
def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
if self._is_canceled and self._is_canceled():
raise CanceledException
callbacks = self._ordered_callbacks.get(callback_type, [])
for cb in callbacks:
cb.function(ctx)
@contextmanager
def patch_extensions(self, ctx: DenoiseContext):
if self._is_canceled and self._is_canceled():
raise CanceledException
with ExitStack() as exit_stack:
for ext in self._extensions:
exit_stack.enter_context(ext.patch_extension(ctx))
yield None
@contextmanager
def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
if self._is_canceled and self._is_canceled():
raise CanceledException
# TODO: create weight patch logic in PR with extension which uses it
with ExitStack() as exit_stack:
for ext in self._extensions:
exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
yield None

View File

@@ -61,7 +61,6 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
# full noise. Investigate the history of why this got commented out.
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
assert isinstance(latents, torch.Tensor) # For static type checking.
# TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
# cropping into regions.
@@ -123,42 +122,19 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
control_data=region_conditioning.control_data,
)
# Build a region_weight matrix that applies gradient blending to the edges of the region.
# Store the results from the region.
# If two tiles overlap by more than the target overlap amount, crop the left and top edges of the
# affected tiles to achieve the target overlap.
region = region_conditioning.region
_, _, region_height, region_width = step_output.prev_sample.shape
region_weight = torch.ones(
(1, 1, region_height, region_width),
dtype=latents.dtype,
device=latents.device,
)
if region.overlap.left > 0:
left_grad = torch.linspace(
0, 1, region.overlap.left, device=latents.device, dtype=latents.dtype
).view((1, 1, 1, -1))
region_weight[:, :, :, : region.overlap.left] *= left_grad
if region.overlap.top > 0:
top_grad = torch.linspace(
0, 1, region.overlap.top, device=latents.device, dtype=latents.dtype
).view((1, 1, -1, 1))
region_weight[:, :, : region.overlap.top, :] *= top_grad
if region.overlap.right > 0:
right_grad = torch.linspace(
1, 0, region.overlap.right, device=latents.device, dtype=latents.dtype
).view((1, 1, 1, -1))
region_weight[:, :, :, -region.overlap.right :] *= right_grad
if region.overlap.bottom > 0:
bottom_grad = torch.linspace(
1, 0, region.overlap.bottom, device=latents.device, dtype=latents.dtype
).view((1, 1, -1, 1))
region_weight[:, :, -region.overlap.bottom :, :] *= bottom_grad
# Update the merged results with the region results.
merged_latents[
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
] += step_output.prev_sample * region_weight
merged_latents_weights[
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
] += region_weight
top_adjustment = max(0, region.overlap.top - target_overlap)
left_adjustment = max(0, region.overlap.left - target_overlap)
region_height_slice = slice(region.coords.top + top_adjustment, region.coords.bottom)
region_width_slice = slice(region.coords.left + left_adjustment, region.coords.right)
merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[
:, :, top_adjustment:, left_adjustment:
]
# For now, we treat every region as having the same weight.
merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
if pred_orig_sample is not None:
@@ -166,9 +142,9 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
# they all use the same scheduler.
if merged_pred_original is None:
merged_pred_original = torch.zeros_like(latents)
merged_pred_original[
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
] += pred_orig_sample
merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[
:, :, top_adjustment:, left_adjustment:
]
# Normalize the merged results.
latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)

View File

@@ -65,12 +65,17 @@ class TextualInversionModelRaw(RawModel):
return result
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
if not torch.cuda.is_available():
return
for emb in [self.embedding, self.embedding_2]:
if emb is not None:
emb.to(device=device, dtype=dtype)
emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
def calc_size(self) -> int:
"""Get the size of this model in bytes."""

View File

@@ -112,3 +112,15 @@ class TorchDevice:
@classmethod
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
return NAME_TO_PRECISION[precision_name]
@staticmethod
def get_non_blocking(to_device: torch.device) -> bool:
"""Return the non_blocking flag to be used when moving a tensor to a given device.
MPS may have unexpected errors with non-blocking operations - we should not use non-blocking when moving _to_ MPS.
When moving _from_ MPS, we can use non-blocking operations.
See:
- https://github.com/pytorch/pytorch/issues/107455
- https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28
"""
return False if to_device.type == "mps" else True

View File

@@ -155,8 +155,5 @@
"vite-plugin-eslint": "^1.8.1",
"vite-tsconfig-paths": "^4.3.2",
"vitest": "^1.6.0"
},
"engines": {
"pnpm": "8"
}
}

View File

@@ -77,6 +77,10 @@
"title": "استعادة الوجوه",
"desc": "استعادة الصورة الحالية"
},
"upscale": {
"title": "تحسين الحجم",
"desc": "تحسين حجم الصورة الحالية"
},
"showInfo": {
"title": "عرض المعلومات",
"desc": "عرض معلومات البيانات الخاصة بالصورة الحالية"
@@ -251,6 +255,8 @@
"type": "نوع",
"strength": "قوة",
"upscaling": "تصغير",
"upscale": "تصغير",
"upscaleImage": "تصغير الصورة",
"scale": "مقياس",
"imageFit": "ملائمة الصورة الأولية لحجم الخرج",
"scaleBeforeProcessing": "تحجيم قبل المعالجة",

View File

@@ -187,6 +187,10 @@
"title": "Gesicht restaurieren",
"desc": "Das aktuelle Bild restaurieren"
},
"upscale": {
"title": "Hochskalieren",
"desc": "Das aktuelle Bild hochskalieren"
},
"showInfo": {
"title": "Info anzeigen",
"desc": "Metadaten des aktuellen Bildes anzeigen"
@@ -429,6 +433,8 @@
"type": "Art",
"strength": "Stärke",
"upscaling": "Hochskalierung",
"upscale": "Hochskalieren (Shift + U)",
"upscaleImage": "Bild hochskalieren",
"scale": "Maßstab",
"imageFit": "Ausgangsbild an Ausgabegröße anpassen",
"scaleBeforeProcessing": "Skalieren vor der Verarbeitung",

View File

@@ -32,14 +32,12 @@
"deleteBoardAndImages": "Delete Board and Images",
"deleteBoardOnly": "Delete Board Only",
"deletedBoardsCannotbeRestored": "Deleted boards cannot be restored",
"hideBoards": "Hide Boards",
"loading": "Loading...",
"menuItemAutoAdd": "Auto-add to this Board",
"move": "Move",
"movingImagesToBoard_one": "Moving {{count}} image to board:",
"movingImagesToBoard_other": "Moving {{count}} images to board:",
"myBoard": "My Board",
"noBoards": "No {{boardType}} Boards",
"noMatching": "No matching Boards",
"private": "Private Boards",
"searchBoard": "Search Boards...",
@@ -48,7 +46,6 @@
"topMessage": "This board contains images used in the following features:",
"unarchiveBoard": "Unarchive Board",
"uncategorized": "Uncategorized",
"viewBoards": "View Boards",
"downloadBoard": "Download Board",
"imagesWithCount_one": "{{count}} image",
"imagesWithCount_other": "{{count}} images",
@@ -105,7 +102,6 @@
"negativePrompt": "Negative Prompt",
"discordLabel": "Discord",
"dontAskMeAgain": "Don't ask me again",
"dontShowMeThese": "Don't show me these",
"editor": "Editor",
"error": "Error",
"file": "File",
@@ -377,14 +373,10 @@
"displayBoardSearch": "Display Board Search",
"displaySearch": "Display Search",
"download": "Download",
"exitBoardSearch": "Exit Board Search",
"exitSearch": "Exit Search",
"featuresWillReset": "If you delete this image, those features will immediately be reset.",
"galleryImageSize": "Image Size",
"gallerySettings": "Gallery Settings",
"go": "Go",
"image": "image",
"jump": "Jump",
"loading": "Loading",
"loadMore": "Load More",
"newestFirst": "Newest First",
@@ -644,9 +636,9 @@
"title": "Undo Stroke"
},
"unifiedCanvasHotkeys": "Unified Canvas",
"postProcess": {
"desc": "Process the current image using the selected post-processing model",
"title": "Process Image"
"upscale": {
"desc": "Upscale the current image",
"title": "Upscale"
},
"toggleViewer": {
"desc": "Switches between the Image Viewer and workspace for the current tab.",
@@ -1035,7 +1027,6 @@
"imageActions": "Image Actions",
"sendToImg2Img": "Send to Image to Image",
"sendToUnifiedCanvas": "Send To Unified Canvas",
"sendToUpscale": "Send To Upscale",
"showOptionsPanel": "Show Side Panel (O or T)",
"shuffle": "Shuffle Seed",
"steps": "Steps",
@@ -1043,8 +1034,8 @@
"symmetry": "Symmetry",
"tileSize": "Tile Size",
"type": "Type",
"postProcessing": "Post-Processing (Shift + U)",
"processImage": "Process Image",
"upscale": "Upscale (Shift + U)",
"upscaleImage": "Upscale Image",
"upscaling": "Upscaling",
"useAll": "Use All",
"useSize": "Use Size",
@@ -1100,8 +1091,6 @@
"displayInProgress": "Display Progress Images",
"enableImageDebugging": "Enable Image Debugging",
"enableInformationalPopovers": "Enable Informational Popovers",
"informationalPopoversDisabled": "Informational Popovers Disabled",
"informationalPopoversDisabledDesc": "Informational popovers have been disabled. Enable them in Settings.",
"enableInvisibleWatermark": "Enable Invisible Watermark",
"enableNSFWChecker": "Enable NSFW Checker",
"general": "General",
@@ -1509,30 +1498,6 @@
"seamlessTilingYAxis": {
"heading": "Seamless Tiling Y Axis",
"paragraphs": ["Seamlessly tile an image along the vertical axis."]
},
"upscaleModel": {
"heading": "Upscale Model",
"paragraphs": [
"The upscale model scales the image to the output size before details are added. Any supported upscale model may be used, but some are specialized for different kinds of images, like photos or line drawings."
]
},
"scale": {
"heading": "Scale",
"paragraphs": [
"Scale controls the output image size, and is based on a multiple of the input image resolution. For example a 2x upscale on a 1024x1024 image would produce a 2048 x 2048 output."
]
},
"creativity": {
"heading": "Creativity",
"paragraphs": [
"Creativity controls the amount of freedom granted to the model when adding details. Low creativity stays close to the original image, while high creativity allows for more change. When using a prompt, high creativity increases the influence of the prompt."
]
},
"structure": {
"heading": "Structure",
"paragraphs": [
"Structure controls how closely the output image will keep to the layout of the original. Low structure allows major changes, while high structure strictly maintains the original composition and layout."
]
}
},
"unifiedCanvas": {
@@ -1675,27 +1640,6 @@
"layers_one": "Layer",
"layers_other": "Layers"
},
"upscaling": {
"creativity": "Creativity",
"structure": "Structure",
"upscaleModel": "Upscale Model",
"postProcessingModel": "Post-Processing Model",
"scale": "Scale",
"postProcessingMissingModelWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install a post-processing (image to image) model.",
"missingModelsWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install the required models:",
"mainModelDesc": "Main model (SD1.5 or SDXL architecture)",
"tileControlNetModelDesc": "Tile ControlNet model for the chosen main model architecture",
"upscaleModelDesc": "Upscale (image to image) model",
"missingUpscaleInitialImage": "Missing initial image for upscaling",
"missingUpscaleModel": "Missing upscale model",
"missingTileControlNetModel": "No valid tile ControlNet models installed"
},
"upsell": {
"inviteTeammates": "Invite Teammates",
"professional": "Professional",
"professionalUpsell": "Available in Invokes Professional Edition. Click here or visit invoke.com/pricing for more details.",
"shareAccess": "Share Access"
},
"ui": {
"tabs": {
"generation": "Generation",
@@ -1707,9 +1651,7 @@
"models": "Models",
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
"queue": "Queue",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
"upscaling": "Upscaling",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
"queueTab": "$t(ui.tabs.queue) $t(common.tab)"
}
}
}

View File

@@ -151,6 +151,10 @@
"title": "Restaurar rostros",
"desc": "Restaurar rostros en la imagen actual"
},
"upscale": {
"title": "Aumentar resolución",
"desc": "Aumentar la resolución de la imagen actual"
},
"showInfo": {
"title": "Mostrar información",
"desc": "Mostar metadatos de la imagen actual"
@@ -356,6 +360,8 @@
"type": "Tipo",
"strength": "Fuerza",
"upscaling": "Aumento de resolución",
"upscale": "Aumentar resolución",
"upscaleImage": "Aumentar la resolución de la imagen",
"scale": "Escala",
"imageFit": "Ajuste tamaño de imagen inicial al tamaño objetivo",
"scaleBeforeProcessing": "Redimensionar antes de procesar",
@@ -402,12 +408,7 @@
"showProgressInViewer": "Mostrar las imágenes del progreso en el visor",
"ui": "Interfaz del usuario",
"generation": "Generación",
"beta": "Beta",
"reloadingIn": "Recargando en",
"intermediatesClearedFailed": "Error limpiando los intermediarios",
"intermediatesCleared_one": "Borrado {{count}} intermediario",
"intermediatesCleared_many": "Borrados {{count}} intermediarios",
"intermediatesCleared_other": "Borrados {{count}} intermediarios"
"beta": "Beta"
},
"toast": {
"uploadFailed": "Error al subir archivo",
@@ -425,12 +426,7 @@
"parameterSet": "Conjunto de parámetros",
"parameterNotSet": "Parámetro no configurado",
"problemCopyingImage": "No se puede copiar la imagen",
"errorCopied": "Error al copiar",
"baseModelChanged": "Modelo base cambiado",
"addedToBoard": "Añadido al tablero",
"baseModelChangedCleared_one": "Borrado o desactivado {{count}} submodelo incompatible",
"baseModelChangedCleared_many": "Borrados o desactivados {{count}} submodelos incompatibles",
"baseModelChangedCleared_other": "Borrados o desactivados {{count}} submodelos incompatibles"
"errorCopied": "Error al copiar"
},
"tooltip": {
"feature": {
@@ -544,13 +540,7 @@
"downloadBoard": "Descargar panel",
"deleteBoardOnly": "Borrar solo el panel",
"myBoard": "Mi panel",
"noMatching": "No hay paneles que coincidan",
"imagesWithCount_one": "{{count}} imagen",
"imagesWithCount_many": "{{count}} imágenes",
"imagesWithCount_other": "{{count}} imágenes",
"assetsWithCount_one": "{{count}} activo",
"assetsWithCount_many": "{{count}} activos",
"assetsWithCount_other": "{{count}} activos"
"noMatching": "No hay paneles que coincidan"
},
"accordions": {
"compositing": {
@@ -600,27 +590,6 @@
"balanced": "Equilibrado",
"beginEndStepPercent": "Inicio / Final Porcentaje de pasos",
"detectResolution": "Detectar resolución",
"beginEndStepPercentShort": "Inicio / Final %",
"t2i_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.t2iAdapter))",
"controlnet": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.controlNet))",
"ip_adapter": "$t(controlnet.controlAdapter_one) #{{number}} ($t(common.ipAdapter))",
"addControlNet": "Añadir $t(common.controlNet)",
"addIPAdapter": "Añadir $t(common.ipAdapter)",
"controlAdapter_one": "Adaptador de control",
"controlAdapter_many": "Adaptadores de control",
"controlAdapter_other": "Adaptadores de control",
"addT2IAdapter": "Añadir $t(common.t2iAdapter)"
},
"queue": {
"back": "Atrás",
"front": "Delante",
"batchQueuedDesc_one": "Se agregó {{count}} sesión a {{direction}} la cola",
"batchQueuedDesc_many": "Se agregaron {{count}} sesiones a {{direction}} la cola",
"batchQueuedDesc_other": "Se agregaron {{count}} sesiones a {{direction}} la cola"
},
"upsell": {
"inviteTeammates": "Invitar compañeros de equipo",
"shareAccess": "Compartir acceso",
"professionalUpsell": "Disponible en la edición profesional de Invoke. Haz clic aquí o visita invoke.com/pricing para obtener más detalles."
"beginEndStepPercentShort": "Inicio / Final %"
}
}

View File

@@ -130,6 +130,10 @@
"title": "Restaurer les visages",
"desc": "Restaurer l'image actuelle"
},
"upscale": {
"title": "Agrandir",
"desc": "Agrandir l'image actuelle"
},
"showInfo": {
"title": "Afficher les informations",
"desc": "Afficher les informations de métadonnées de l'image actuelle"
@@ -304,6 +308,8 @@
"type": "Type",
"strength": "Force",
"upscaling": "Agrandissement",
"upscale": "Agrandir",
"upscaleImage": "Image en Agrandissement",
"scale": "Echelle",
"imageFit": "Ajuster Image Initiale à la Taille de Sortie",
"scaleBeforeProcessing": "Echelle Avant Traitement",

View File

@@ -90,6 +90,10 @@
"desc": "שחזור התמונה הנוכחית",
"title": "שחזור פרצופים"
},
"upscale": {
"title": "הגדלת קנה מידה",
"desc": "הגדל את התמונה הנוכחית"
},
"showInfo": {
"title": "הצג מידע",
"desc": "הצגת פרטי מטא-נתונים של התמונה הנוכחית"
@@ -259,6 +263,8 @@
"seed": "זרע",
"type": "סוג",
"strength": "חוזק",
"upscale": "הגדלת קנה מידה",
"upscaleImage": "הגדלת קנה מידת התמונה",
"denoisingStrength": "חוזק מנטרל הרעש",
"scaleBeforeProcessing": "שנה קנה מידה לפני עיבוד",
"scaledWidth": "קנה מידה לאחר שינוי W",

View File

@@ -150,11 +150,7 @@
"showArchivedBoards": "Mostra le bacheche archiviate",
"searchImages": "Ricerca per metadati",
"displayBoardSearch": "Mostra la ricerca nelle Bacheche",
"displaySearch": "Mostra la ricerca",
"selectAllOnPage": "Seleziona tutto nella pagina",
"selectAllOnBoard": "Seleziona tutto nella bacheca",
"exitBoardSearch": "Esci da Ricerca bacheca",
"exitSearch": "Esci dalla ricerca"
"displaySearch": "Mostra la ricerca"
},
"hotkeys": {
"keyboardShortcuts": "Tasti di scelta rapida",
@@ -214,6 +210,10 @@
"title": "Restaura volti",
"desc": "Restaura l'immagine corrente"
},
"upscale": {
"title": "Amplia",
"desc": "Amplia l'immagine corrente"
},
"showInfo": {
"title": "Mostra informazioni",
"desc": "Mostra le informazioni sui metadati dell'immagine corrente"
@@ -377,10 +377,6 @@
"toggleViewer": {
"title": "Attiva/disattiva il visualizzatore di immagini",
"desc": "Passa dal visualizzatore immagini all'area di lavoro per la scheda corrente."
},
"postProcess": {
"desc": "Elabora l'immagine corrente utilizzando il modello di post-elaborazione selezionato",
"title": "Elabora immagine"
}
},
"modelManager": {
@@ -509,6 +505,8 @@
"type": "Tipo",
"strength": "Forza",
"upscaling": "Ampliamento",
"upscale": "Amplia (Shift + U)",
"upscaleImage": "Amplia Immagine",
"scale": "Scala",
"imageFit": "Adatta l'immagine iniziale alle dimensioni di output",
"scaleBeforeProcessing": "Scala prima dell'elaborazione",
@@ -593,10 +591,7 @@
"infillColorValue": "Colore di riempimento",
"globalSettings": "Impostazioni globali",
"globalPositivePromptPlaceholder": "Prompt positivo globale",
"globalNegativePromptPlaceholder": "Prompt negativo globale",
"processImage": "Elabora Immagine",
"sendToUpscale": "Invia a Ampliare",
"postProcessing": "Post-elaborazione (Shift + U)"
"globalNegativePromptPlaceholder": "Prompt negativo globale"
},
"settings": {
"models": "Modelli",
@@ -969,10 +964,7 @@
"boards": "Bacheche",
"private": "Bacheche private",
"shared": "Bacheche condivise",
"addPrivateBoard": "Aggiungi una Bacheca Privata",
"noBoards": "Nessuna bacheca {{boardType}}",
"hideBoards": "Nascondi bacheche",
"viewBoards": "Visualizza bacheche"
"addPrivateBoard": "Aggiungi una Bacheca Privata"
},
"controlnet": {
"contentShuffleDescription": "Rimescola il contenuto di un'immagine",
@@ -1692,30 +1684,7 @@
"models": "Modelli",
"modelsTab": "$t(ui.tabs.models) $t(common.tab)",
"queue": "Coda",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
"upscaling": "Ampliamento",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)"
"queueTab": "$t(ui.tabs.queue) $t(common.tab)"
}
},
"upscaling": {
"creativity": "Creatività",
"structure": "Struttura",
"upscaleModel": "Modello di Ampliamento",
"scale": "Scala",
"missingModelsWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare i modelli richiesti:",
"mainModelDesc": "Modello principale (architettura SD1.5 o SDXL)",
"tileControlNetModelDesc": "Modello Tile ControlNet per l'architettura del modello principale scelto",
"upscaleModelDesc": "Modello per l'ampliamento (da immagine a immagine)",
"missingUpscaleInitialImage": "Immagine iniziale mancante per l'ampliamento",
"missingUpscaleModel": "Modello per lampliamento mancante",
"missingTileControlNetModel": "Nessun modello ControlNet Tile valido installato",
"postProcessingModel": "Modello di post-elaborazione",
"postProcessingMissingModelWarning": "Visita <LinkComponent>Gestione modelli</LinkComponent> per installare un modello di post-elaborazione (da immagine a immagine)."
},
"upsell": {
"inviteTeammates": "Invita collaboratori",
"shareAccess": "Condividi l'accesso",
"professional": "Professionale",
"professionalUpsell": "Disponibile nell'edizione Professional di Invoke. Fai clic qui o visita invoke.com/pricing per ulteriori dettagli."
}
}

View File

@@ -199,6 +199,10 @@
"title": "顔の修復",
"desc": "現在の画像を修復"
},
"upscale": {
"title": "アップスケール",
"desc": "現在の画像をアップスケール"
},
"showInfo": {
"title": "情報を見る",
"desc": "現在の画像のメタデータ情報を表示"
@@ -423,6 +427,8 @@
"shuffle": "シャッフル",
"strength": "強度",
"upscaling": "アップスケーリング",
"upscale": "アップスケール",
"upscaleImage": "画像をアップスケール",
"scale": "Scale",
"scaleBeforeProcessing": "処理前のスケール",
"scaledWidth": "幅のスケール",

View File

@@ -258,6 +258,10 @@
"desc": "캔버스 브러시를 선택",
"title": "브러시 선택"
},
"upscale": {
"desc": "현재 이미지를 업스케일",
"title": "업스케일"
},
"previousImage": {
"title": "이전 이미지",
"desc": "갤러리에 이전 이미지 표시"

View File

@@ -168,6 +168,10 @@
"title": "Herstel gezichten",
"desc": "Herstelt de huidige afbeelding"
},
"upscale": {
"title": "Schaal op",
"desc": "Schaalt de huidige afbeelding op"
},
"showInfo": {
"title": "Toon info",
"desc": "Toont de metagegevens van de huidige afbeelding"
@@ -408,6 +412,8 @@
"type": "Soort",
"strength": "Sterkte",
"upscaling": "Opschalen",
"upscale": "Vergroot (Shift + U)",
"upscaleImage": "Schaal afbeelding op",
"scale": "Schaal",
"imageFit": "Pas initiële afbeelding in uitvoergrootte",
"scaleBeforeProcessing": "Schalen voor verwerking",

View File

@@ -78,6 +78,10 @@
"title": "Popraw twarze",
"desc": "Uruchamia proces poprawiania twarzy dla aktywnego obrazu"
},
"upscale": {
"title": "Powiększ",
"desc": "Uruchamia proces powiększania aktywnego obrazu"
},
"showInfo": {
"title": "Pokaż informacje",
"desc": "Pokazuje metadane zapisane w aktywnym obrazie"
@@ -228,6 +232,8 @@
"type": "Metoda",
"strength": "Siła",
"upscaling": "Powiększanie",
"upscale": "Powiększ",
"upscaleImage": "Powiększ obraz",
"scale": "Skala",
"imageFit": "Przeskaluj oryginalny obraz",
"scaleBeforeProcessing": "Tryb skalowania",

View File

@@ -160,6 +160,10 @@
"title": "Restaurar Rostos",
"desc": "Restaurar a imagem atual"
},
"upscale": {
"title": "Redimensionar",
"desc": "Redimensionar a imagem atual"
},
"showInfo": {
"title": "Mostrar Informações",
"desc": "Mostrar metadados de informações da imagem atual"
@@ -271,6 +275,8 @@
"showOptionsPanel": "Mostrar Painel de Opções",
"strength": "Força",
"upscaling": "Redimensionando",
"upscale": "Redimensionar",
"upscaleImage": "Redimensionar Imagem",
"scaleBeforeProcessing": "Escala Antes do Processamento",
"images": "Imagems",
"steps": "Passos",

View File

@@ -80,6 +80,10 @@
"title": "Restaurar Rostos",
"desc": "Restaurar a imagem atual"
},
"upscale": {
"title": "Redimensionar",
"desc": "Redimensionar a imagem atual"
},
"showInfo": {
"title": "Mostrar Informações",
"desc": "Mostrar metadados de informações da imagem atual"
@@ -264,6 +268,8 @@
"type": "Tipo",
"strength": "Força",
"upscaling": "Redimensionando",
"upscale": "Redimensionar",
"upscaleImage": "Redimensionar Imagem",
"scale": "Escala",
"imageFit": "Caber Imagem Inicial No Tamanho de Saída",
"scaleBeforeProcessing": "Escala Antes do Processamento",

View File

@@ -214,6 +214,10 @@
"title": "Восстановить лица",
"desc": "Восстановить лица на текущем изображении"
},
"upscale": {
"title": "Увеличение",
"desc": "Увеличить текущеее изображение"
},
"showInfo": {
"title": "Показать метаданные",
"desc": "Показать метаданные из текущего изображения"
@@ -508,6 +512,8 @@
"type": "Тип",
"strength": "Сила",
"upscaling": "Увеличение",
"upscale": "Увеличить",
"upscaleImage": "Увеличить изображение",
"scale": "Масштаб",
"imageFit": "Уместить изображение",
"scaleBeforeProcessing": "Масштабировать",

View File

@@ -90,6 +90,10 @@
"title": "Återskapa ansikten",
"desc": "Återskapa nuvarande bild"
},
"upscale": {
"title": "Skala upp",
"desc": "Skala upp nuvarande bild"
},
"showInfo": {
"title": "Visa info",
"desc": "Visa metadata för nuvarande bild"

View File

@@ -416,6 +416,10 @@
"desc": "Maske/Taban katmanları arasında geçiş yapar",
"title": "Katmanı Gizle-Göster"
},
"upscale": {
"title": "Büyüt",
"desc": "Seçili görseli büyüt"
},
"setSeed": {
"title": "Tohumu Kullan",
"desc": "Seçili görselin tohumunu kullan"
@@ -637,6 +641,7 @@
"copyImage": "Görseli Kopyala",
"height": "Boy",
"width": "En",
"upscale": "Büyüt (Shift + U)",
"useSize": "Boyutu Kullan",
"symmetry": "Bakışım",
"tileSize": "Döşeme Boyutu",
@@ -652,6 +657,7 @@
"showOptionsPanel": "Yan Paneli Göster (O ya da T)",
"shuffle": "Kar",
"usePrompt": "İstemi Kullan",
"upscaleImage": "Görseli Büyüt",
"setToOptimalSizeTooSmall": "$t(parameters.setToOptimalSize) (çok küçük olabilir)",
"setToOptimalSizeTooLarge": "$t(parameters.setToOptimalSize) (çok büyük olabilir)",
"cfgRescaleMultiplier": "CFG Rescale Çarpanı",

View File

@@ -85,6 +85,10 @@
"title": "Відновити обличчя",
"desc": "Відновити обличчя на поточному зображенні"
},
"upscale": {
"title": "Збільшення",
"desc": "Збільшити поточне зображення"
},
"showInfo": {
"title": "Показати метадані",
"desc": "Показати метадані з поточного зображення"
@@ -272,6 +276,8 @@
"type": "Тип",
"strength": "Сила",
"upscaling": "Збільшення",
"upscale": "Збільшити",
"upscaleImage": "Збільшити зображення",
"scale": "Масштаб",
"imageFit": "Вмістити зображення",
"scaleBeforeProcessing": "Масштабувати",

View File

@@ -193,6 +193,10 @@
"title": "面部修复",
"desc": "对当前图像进行面部修复"
},
"upscale": {
"title": "放大",
"desc": "对当前图像进行放大"
},
"showInfo": {
"title": "显示信息",
"desc": "显示当前图像的元数据"
@@ -418,6 +422,8 @@
"type": "种类",
"strength": "强度",
"upscaling": "放大",
"upscale": "放大 (Shift + U)",
"upscaleImage": "放大图像",
"scale": "等级",
"imageFit": "使生成图像长宽适配初始图像",
"scaleBeforeProcessing": "处理前缩放",

View File

@@ -1,6 +1,5 @@
import type { TypedStartListening } from '@reduxjs/toolkit';
import { createListenerMiddleware } from '@reduxjs/toolkit';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addCommitStagingAreaImageListener } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
@@ -48,11 +47,11 @@ import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddlewa
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
import { addEnqueueRequestedUpscale } from './listeners/enqueueRequestedUpscale';
export const listenerMiddleware = createListenerMiddleware();
@@ -86,7 +85,6 @@ addGalleryOffsetChangedListener(startAppListening);
addEnqueueRequestedCanvasListener(startAppListening);
addEnqueueRequestedNodes(startAppListening);
addEnqueueRequestedLinear(startAppListening);
addEnqueueRequestedUpscale(startAppListening);
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
@@ -142,7 +140,7 @@ addModelsLoadedListener(startAppListening);
addAppConfigReceivedListener(startAppListening);
// Ad-hoc upscale workflwo
addAdHocPostProcessingRequestedListener(startAppListening);
addUpscaleRequestedListener(startAppListening);
// Prompts
addDynamicPromptsListener(startAppListening);

View File

@@ -1,36 +0,0 @@
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
import { queueApi } from 'services/api/endpoints/queue';
export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
startAppListening({
predicate: (action): action is ReturnType<typeof enqueueRequested> =>
enqueueRequested.match(action) && action.payload.tabName === 'upscaling',
effect: async (action, { getState, dispatch }) => {
const state = getState();
const { shouldShowProgressInViewer } = state.ui;
const { prepend } = action.payload;
const graph = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch(state, graph, prepend);
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
fixedCacheKey: 'enqueueBatch',
})
);
try {
await req.unwrap();
if (shouldShowProgressInViewer) {
dispatch(isImageViewerOpenChanged(true));
}
} finally {
req.reset();
}
},
});
};

View File

@@ -23,7 +23,6 @@ import {
} from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { imagesApi } from 'services/api/endpoints/images';
export const dndDropped = createAction<{
@@ -244,20 +243,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
return;
}
/**
* Image dropped on upscale initial image
*/
if (
overData.actionType === 'SET_UPSCALE_INITIAL_IMAGE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(upscaleInitialImageChanged(imageDTO));
return;
}
/**
* Multiple images dropped on user board
*/

View File

@@ -14,7 +14,6 @@ import {
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { omit } from 'lodash-es';
@@ -90,15 +89,6 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
return;
}
if (postUploadAction?.type === 'SET_UPSCALE_INITIAL_IMAGE') {
dispatch(upscaleInitialImageChanged(imageDTO));
toast({
...DEFAULT_UPLOADED_TOAST,
description: 'set as upscale initial image',
});
return;
}
if (postUploadAction?.type === 'SET_CONTROL_ADAPTER_IMAGE') {
const { id } = postUploadAction;
dispatch(

View File

@@ -10,7 +10,6 @@ import { heightChanged, widthChanged } from 'features/controlLayers/store/contro
import { loraRemoved } from 'features/lora/store/loraSlice';
import { calculateNewSize } from 'features/parameters/components/ImageSize/calculateNewSize';
import { modelChanged, vaeSelected } from 'features/parameters/store/generationSlice';
import { postProcessingModelChanged, upscaleModelChanged } from 'features/parameters/store/upscaleSlice';
import { zParameterModel, zParameterVAEModel } from 'features/parameters/types/parameterSchemas';
import { getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { refinerModelChanged } from 'features/sdxl/store/sdxlSlice';
@@ -18,12 +17,7 @@ import { forEach } from 'lodash-es';
import type { Logger } from 'roarr';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig } from 'services/api/types';
import {
isNonRefinerMainModelConfig,
isRefinerMainModelModelConfig,
isSpandrelImageToImageModelConfig,
isVAEModelConfig,
} from 'services/api/types';
import { isNonRefinerMainModelConfig, isRefinerMainModelModelConfig, isVAEModelConfig } from 'services/api/types';
export const addModelsLoadedListener = (startAppListening: AppStartListening) => {
startAppListening({
@@ -42,7 +36,6 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
handleVAEModels(models, state, dispatch, log);
handleLoRAModels(models, state, dispatch, log);
handleControlAdapterModels(models, state, dispatch, log);
handleSpandrelImageToImageModels(models, state, dispatch, log);
},
});
};
@@ -184,25 +177,3 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, _log)
dispatch(controlAdapterModelCleared({ id: ca.id }));
});
};
const handleSpandrelImageToImageModels: ModelHandler = (models, state, dispatch, _log) => {
const { upscaleModel: currentUpscaleModel, postProcessingModel: currentPostProcessingModel } = state.upscale;
const upscaleModels = models.filter(isSpandrelImageToImageModelConfig);
const firstModel = upscaleModels[0] || null;
const isCurrentUpscaleModelAvailable = currentUpscaleModel
? upscaleModels.some((m) => m.key === currentUpscaleModel.key)
: false;
if (!isCurrentUpscaleModelAvailable) {
dispatch(upscaleModelChanged(firstModel));
}
const isCurrentPostProcessingModelAvailable = currentPostProcessingModel
? upscaleModels.some((m) => m.key === currentPostProcessingModel.key)
: false;
if (!isCurrentPostProcessingModelAvailable) {
dispatch(postProcessingModelChanged(firstModel));
}
};

View File

@@ -13,7 +13,6 @@ import {
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { socketInvocationComplete } from 'services/events/actions';
@@ -52,14 +51,6 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
}
if (!imageDTO.is_intermediate) {
// update the total images for the board
dispatch(
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
draft.total += 1;
})
);
dispatch(
imagesApi.util.invalidateTags([
{ type: 'Board', id: imageDTO.board_id ?? 'none' },

View File

@@ -2,28 +2,46 @@ import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { buildAdHocPostProcessingGraph } from 'features/nodes/util/graph/buildAdHocPostProcessingGraph';
import { buildAdHocUpscaleGraph } from 'features/nodes/util/graph/buildAdHocUpscaleGraph';
import { createIsAllowedToUpscaleSelector } from 'features/parameters/hooks/useIsAllowedToUpscale';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig, ImageDTO } from 'services/api/types';
export const adHocPostProcessingRequested = createAction<{ imageDTO: ImageDTO }>(`upscaling/postProcessingRequested`);
export const upscaleRequested = createAction<{ imageDTO: ImageDTO }>(`upscale/upscaleRequested`);
export const addAdHocPostProcessingRequestedListener = (startAppListening: AppStartListening) => {
export const addUpscaleRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: adHocPostProcessingRequested,
actionCreator: upscaleRequested,
effect: async (action, { dispatch, getState }) => {
const log = logger('session');
const { imageDTO } = action.payload;
const { image_name } = imageDTO;
const state = getState();
const { isAllowedToUpscale, detailTKey } = createIsAllowedToUpscaleSelector(imageDTO)(state);
// if we can't upscale, show a toast and return
if (!isAllowedToUpscale) {
log.error(
{ imageDTO },
t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge') // should never coalesce
);
toast({
id: 'NOT_ALLOWED_TO_UPSCALE',
title: t(detailTKey ?? 'parameters.isAllowedToUpscale.tooLarge'), // should never coalesce
status: 'error',
});
return;
}
const enqueueBatchArg: BatchConfig = {
prepend: true,
batch: {
graph: await buildAdHocPostProcessingGraph({
image: imageDTO,
graph: buildAdHocUpscaleGraph({
image_name,
state,
}),
runs: 1,

View File

@@ -25,7 +25,7 @@ import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/no
import { workflowSettingsPersistConfig, workflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice';
import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice';
import { upscalePersistConfig, upscaleSlice } from 'features/parameters/store/upscaleSlice';
import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice';
import { queueSlice } from 'features/queue/store/queueSlice';
import { sdxlPersistConfig, sdxlSlice } from 'features/sdxl/store/sdxlSlice';
import { configSlice } from 'features/system/store/configSlice';
@@ -52,6 +52,7 @@ const allReducers = {
[gallerySlice.name]: gallerySlice.reducer,
[generationSlice.name]: generationSlice.reducer,
[nodesSlice.name]: undoable(nodesSlice.reducer, nodesUndoableConfig),
[postprocessingSlice.name]: postprocessingSlice.reducer,
[systemSlice.name]: systemSlice.reducer,
[configSlice.name]: configSlice.reducer,
[uiSlice.name]: uiSlice.reducer,
@@ -68,7 +69,6 @@ const allReducers = {
[controlLayersSlice.name]: undoable(controlLayersSlice.reducer, controlLayersUndoableConfig),
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
[api.reducerPath]: api.reducer,
[upscaleSlice.name]: upscaleSlice.reducer,
};
const rootReducer = combineReducers(allReducers);
@@ -102,6 +102,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[galleryPersistConfig.name]: galleryPersistConfig,
[generationPersistConfig.name]: generationPersistConfig,
[nodesPersistConfig.name]: nodesPersistConfig,
[postprocessingPersistConfig.name]: postprocessingPersistConfig,
[systemPersistConfig.name]: systemPersistConfig,
[workflowPersistConfig.name]: workflowPersistConfig,
[uiPersistConfig.name]: uiPersistConfig,
@@ -113,7 +114,6 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[hrfPersistConfig.name]: hrfPersistConfig,
[controlLayersPersistConfig.name]: controlLayersPersistConfig,
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
[upscalePersistConfig.name]: upscalePersistConfig,
};
const unserialize: UnserializeFunction = (data, key) => {

View File

@@ -72,6 +72,7 @@ export type AppConfig = {
canRestoreDeletedImagesFromBin: boolean;
nodesAllowlist: string[] | undefined;
nodesDenylist: string[] | undefined;
maxUpscalePixels?: number;
metadataFetchDebounce?: number;
workflowFetchDebounce?: number;
isLocal?: boolean;

View File

@@ -10,12 +10,9 @@ import {
PopoverContent,
PopoverTrigger,
Portal,
Spacer,
Text,
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setShouldEnableInformationalPopovers } from 'features/system/store/systemSlice';
import { toast } from 'features/toast/toast';
import { useAppSelector } from 'app/store/storeHooks';
import { merge, omit } from 'lodash-es';
import type { ReactElement } from 'react';
import { memo, useCallback, useMemo } from 'react';
@@ -74,7 +71,7 @@ type ContentProps = {
const Content = ({ data, feature }: ContentProps) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const heading = useMemo<string | undefined>(() => t(`popovers.${feature}.heading`), [feature, t]);
const paragraphs = useMemo<string[]>(
@@ -85,25 +82,16 @@ const Content = ({ data, feature }: ContentProps) => {
[feature, t]
);
const onClickLearnMore = useCallback(() => {
const handleClick = useCallback(() => {
if (!data?.href) {
return;
}
window.open(data.href);
}, [data?.href]);
const onClickDontShowMeThese = useCallback(() => {
dispatch(setShouldEnableInformationalPopovers(false));
toast({
title: t('settings.informationalPopoversDisabled'),
description: t('settings.informationalPopoversDisabledDesc'),
status: 'info',
});
}, [dispatch, t]);
return (
<PopoverContent maxW={300}>
<PopoverCloseButton top={2} />
<PopoverContent w={96}>
<PopoverCloseButton />
<PopoverBody>
<Flex gap={2} flexDirection="column" alignItems="flex-start">
{heading && (
@@ -128,19 +116,20 @@ const Content = ({ data, feature }: ContentProps) => {
{paragraphs.map((p) => (
<Text key={p}>{p}</Text>
))}
<Divider />
<Flex alignItems="center" justifyContent="space-between" w="full">
<Button onClick={onClickDontShowMeThese} variant="link" size="sm">
{t('common.dontShowMeThese')}
</Button>
<Spacer />
{data?.href && (
<Button onClick={onClickLearnMore} leftIcon={<PiArrowSquareOutBold />} variant="link" size="sm">
{data?.href && (
<>
<Divider />
<Button
pt={1}
onClick={handleClick}
leftIcon={<PiArrowSquareOutBold />}
alignSelf="flex-end"
variant="link"
>
{t('common.learnMore') ?? heading}
</Button>
)}
</Flex>
</>
)}
</Flex>
</PopoverBody>
</PopoverContent>

View File

@@ -53,11 +53,7 @@ export type Feature =
| 'refinerCfgScale'
| 'scaleBeforeProcessing'
| 'seamlessTilingXAxis'
| 'seamlessTilingYAxis'
| 'upscaleModel'
| 'scale'
| 'creativity'
| 'structure';
| 'seamlessTilingYAxis';
export type PopoverData = PopoverProps & {
image?: string;

View File

@@ -1,18 +0,0 @@
import { useEffect } from 'react';
import { assert } from 'tsafe';
const IDS = new Set<string>();
/**
* Asserts that there is only one instance of a singleton entity. It can be a hook or a component.
* @param id The ID of the singleton entity.
*/
export function useAssertSingleton(id: string) {
useEffect(() => {
assert(!IDS.has(id), `There should be only one instance of ${id}`);
IDS.add(id);
return () => {
IDS.delete(id);
};
}, [id]);
}

View File

@@ -21,10 +21,6 @@ const selectPostUploadAction = createMemoizedSelector(activeTabNameSelector, (ac
postUploadAction = { type: 'SET_CANVAS_INITIAL_IMAGE' };
}
if (activeTabName === 'upscaling') {
postUploadAction = { type: 'SET_UPSCALE_INITIAL_IMAGE' };
}
return postUploadAction;
});

View File

@@ -15,7 +15,6 @@ import type { Templates } from 'features/nodes/store/types';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { selectUpscalelice } from 'features/parameters/store/upscaleSlice';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import i18n from 'i18next';
@@ -41,19 +40,8 @@ const createSelector = (templates: Templates) =>
selectDynamicPromptsSlice,
selectControlLayersSlice,
activeTabNameSelector,
selectUpscalelice,
],
(
controlAdapters,
generation,
system,
nodes,
workflowSettings,
dynamicPrompts,
controlLayers,
activeTabName,
upscale
) => {
(controlAdapters, generation, system, nodes, workflowSettings, dynamicPrompts, controlLayers, activeTabName) => {
const { model } = generation;
const { size } = controlLayers.present;
const { positivePrompt } = controlLayers.present;
@@ -206,16 +194,6 @@ const createSelector = (templates: Templates) =>
reasons.push({ prefix, content });
}
});
} else if (activeTabName === 'upscaling') {
if (!upscale.upscaleInitialImage) {
reasons.push({ content: i18n.t('upscaling.missingUpscaleInitialImage') });
}
if (!upscale.upscaleModel) {
reasons.push({ content: i18n.t('upscaling.missingUpscaleModel') });
}
if (!upscale.tileControlnetModel) {
reasons.push({ content: i18n.t('upscaling.missingTileControlNetModel') });
}
} else {
// Handling for all other tabs
selectControlAdapterAll(controlAdapters)

View File

@@ -62,10 +62,6 @@ export type CanvasInitialImageDropData = BaseDropData & {
actionType: 'SET_CANVAS_INITIAL_IMAGE';
};
type UpscaleInitialImageDropData = BaseDropData & {
actionType: 'SET_UPSCALE_INITIAL_IMAGE';
};
type NodesImageDropData = BaseDropData & {
actionType: 'SET_NODES_IMAGE';
context: {
@@ -102,8 +98,7 @@ export type TypesafeDroppableData =
| IPALayerImageDropData
| RGLayerIPAdapterImageDropData
| IILayerImageDropData
| SelectForCompareDropData
| UpscaleInitialImageDropData;
| SelectForCompareDropData;
type BaseDragData = {
id: string;

View File

@@ -27,8 +27,6 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
return payloadType === 'IMAGE_DTO';
case 'SET_CANVAS_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_UPSCALE_INITIAL_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SELECT_FOR_COMPARE':

View File

@@ -1,47 +0,0 @@
import { Flex, Image, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useTranslation } from 'react-i18next';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { BoardDTO } from 'services/api/types';
type Props = {
board: BoardDTO | null;
};
export const BoardTooltip = ({ board }: Props) => {
const { t } = useTranslation();
const { imagesTotal } = useGetBoardImagesTotalQuery(board?.board_id || 'none', {
selectFromResult: ({ data }) => {
return { imagesTotal: data?.total ?? 0 };
},
});
const { assetsTotal } = useGetBoardAssetsTotalQuery(board?.board_id || 'none', {
selectFromResult: ({ data }) => {
return { assetsTotal: data?.total ?? 0 };
},
});
const { currentData: coverImage } = useGetImageDTOQuery(board?.cover_image_name ?? skipToken);
return (
<Flex flexDir="column" alignItems="center" gap={1}>
{coverImage && (
<Image
src={coverImage.thumbnail_url}
draggable={false}
objectFit="cover"
maxW={150}
aspectRatio="1/1"
borderRadius="base"
borderBottomRadius="lg"
/>
)}
<Flex flexDir="column" alignItems="center">
<Text noOfLines={1}>
{t('boards.imagesWithCount', { count: imagesTotal })}, {t('boards.assetsWithCount', { count: assetsTotal })}
</Text>
{board?.archived && <Text>({t('boards.archived')})</Text>}
</Flex>
</Flex>
);
};

View File

@@ -0,0 +1,12 @@
import { useTranslation } from 'react-i18next';
type Props = {
imageCount: number;
assetCount: number;
isArchived: boolean;
};
export const BoardTotalsTooltip = ({ imageCount, assetCount, isArchived }: Props) => {
const { t } = useTranslation();
return `${t('boards.imagesWithCount', { count: imageCount })}, ${t('boards.assetsWithCount', { count: assetCount })}${isArchived ? ` (${t('boards.archived')})` : ''}`;
};

View File

@@ -1,10 +1,13 @@
import { Button, Collapse, Flex, Icon, Text, useDisclosure } from '@invoke-ai/ui-library';
import { Box, Flex, Text } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { useMemo } from 'react';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { CSSProperties } from 'react';
import { memo, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold } from 'react-icons/pi';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import type { BoardDTO } from 'services/api/types';
@@ -12,111 +15,101 @@ import AddBoardButton from './AddBoardButton';
import GalleryBoard from './GalleryBoard';
import NoBoardBoard from './NoBoardBoard';
type Props = {
isPrivate: boolean;
setBoardToDelete: (board?: BoardDTO) => void;
const overlayScrollbarsStyles: CSSProperties = {
height: '100%',
width: '100%',
};
export const BoardsList = ({ isPrivate, setBoardToDelete }: Props) => {
const { t } = useTranslation();
const BoardsList = () => {
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const boardSearchText = useAppSelector((s) => s.gallery.boardSearchText);
const allowPrivateBoards = useAppSelector((s) => s.config.allowPrivateBoards);
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { data: boards } = useListAllBoardsQuery(queryArgs);
const allowPrivateBoards = useAppSelector((s) => s.config.allowPrivateBoards);
const { isOpen, onToggle } = useDisclosure({ defaultIsOpen: true });
const [boardToDelete, setBoardToDelete] = useState<BoardDTO>();
const { t } = useTranslation();
const filteredBoards = useMemo(() => {
if (!boards) {
return EMPTY_ARRAY;
}
return boards.filter((board) => {
if (boardSearchText.length) {
return board.is_private === isPrivate && board.board_name.toLowerCase().includes(boardSearchText.toLowerCase());
} else {
return board.is_private === isPrivate;
}
});
}, [boardSearchText, boards, isPrivate]);
const boardElements = useMemo(() => {
const elements = [];
if (allowPrivateBoards && isPrivate && !boardSearchText.length) {
elements.push(<NoBoardBoard key="none" isSelected={selectedBoardId === 'none'} />);
}
if (!allowPrivateBoards && !boardSearchText.length) {
elements.push(<NoBoardBoard key="none" isSelected={selectedBoardId === 'none'} />);
}
filteredBoards.forEach((board) => {
elements.push(
<GalleryBoard
board={board}
isSelected={selectedBoardId === board.board_id}
setBoardToDelete={setBoardToDelete}
key={board.board_id}
/>
);
});
return elements;
}, [allowPrivateBoards, isPrivate, boardSearchText.length, filteredBoards, selectedBoardId, setBoardToDelete]);
const boardListTitle = useMemo(() => {
if (allowPrivateBoards) {
return isPrivate ? t('boards.private') : t('boards.shared');
} else {
return t('boards.boards');
}
}, [isPrivate, allowPrivateBoards, t]);
const { filteredPrivateBoards, filteredSharedBoards } = useMemo(() => {
const filteredBoards = boardSearchText
? boards?.filter((board) => board.board_name.toLowerCase().includes(boardSearchText.toLowerCase()))
: boards;
const filteredPrivateBoards = filteredBoards?.filter((board) => board.is_private) ?? EMPTY_ARRAY;
const filteredSharedBoards = filteredBoards?.filter((board) => !board.is_private) ?? EMPTY_ARRAY;
return { filteredPrivateBoards, filteredSharedBoards };
}, [boardSearchText, boards]);
return (
<Flex direction="column">
<Flex
position="sticky"
w="full"
justifyContent="space-between"
alignItems="center"
ps={2}
py={1}
zIndex={1}
top={0}
bg="base.900"
>
{allowPrivateBoards ? (
<Button variant="unstyled" onClick={onToggle}>
<Flex gap="2" alignItems="center">
<Icon
boxSize={4}
as={PiCaretDownBold}
transform={isOpen ? undefined : 'rotate(-90deg)'}
fill="base.500"
/>
<Text fontSize="sm" fontWeight="semibold" userSelect="none" color="base.500">
{boardListTitle}
</Text>
<>
<Box position="relative" w="full" h="full">
<Box position="absolute" top={0} right={0} bottom={0} left={0}>
<OverlayScrollbarsComponent defer style={overlayScrollbarsStyles} options={overlayScrollbarsParams.options}>
{allowPrivateBoards && (
<Flex direction="column" gap={1}>
<Flex
position="sticky"
w="full"
justifyContent="space-between"
alignItems="center"
ps={2}
pb={1}
pt={2}
zIndex={1}
top={0}
bg="base.900"
>
<Text fontSize="md" fontWeight="semibold" userSelect="none">
{t('boards.private')}
</Text>
<AddBoardButton isPrivateBoard={true} />
</Flex>
<Flex direction="column" gap={1}>
<NoBoardBoard isSelected={selectedBoardId === 'none'} />
{filteredPrivateBoards.map((board) => (
<GalleryBoard
board={board}
isSelected={selectedBoardId === board.board_id}
setBoardToDelete={setBoardToDelete}
key={board.board_id}
/>
))}
</Flex>
</Flex>
)}
<Flex direction="column" gap={1}>
<Flex
position="sticky"
w="full"
justifyContent="space-between"
alignItems="center"
ps={2}
pb={1}
pt={2}
zIndex={1}
top={0}
bg="base.900"
>
<Text fontSize="md" fontWeight="semibold" userSelect="none">
{allowPrivateBoards ? t('boards.shared') : t('boards.boards')}
</Text>
<AddBoardButton isPrivateBoard={false} />
</Flex>
<Flex direction="column" gap={1}>
{!allowPrivateBoards && <NoBoardBoard isSelected={selectedBoardId === 'none'} />}
{filteredSharedBoards.map((board) => (
<GalleryBoard
board={board}
isSelected={selectedBoardId === board.board_id}
setBoardToDelete={setBoardToDelete}
key={board.board_id}
/>
))}
</Flex>
</Flex>
</Button>
) : (
<Text fontSize="sm" fontWeight="semibold" userSelect="none" color="base.500">
{boardListTitle}
</Text>
)}
<AddBoardButton isPrivateBoard={isPrivate} />
</Flex>
<Collapse in={isOpen}>
<Flex direction="column" gap={1}>
{boardElements.length ? (
boardElements
) : (
<Text variant="subtext" textAlign="center">
{t('boards.noBoards', { boardType: boardSearchText.length ? 'Matching' : '' })}
</Text>
)}
</Flex>
</Collapse>
</Flex>
</OverlayScrollbarsComponent>
</Box>
</Box>
<DeleteBoardModal boardToDelete={boardToDelete} setBoardToDelete={setBoardToDelete} />
</>
);
};
export default memo(BoardsList);

View File

@@ -1,35 +0,0 @@
import { Box } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { CSSProperties } from 'react';
import { memo, useState } from 'react';
import type { BoardDTO } from 'services/api/types';
import { BoardsList } from './BoardsList';
const overlayScrollbarsStyles: CSSProperties = {
height: '100%',
width: '100%',
};
const BoardsListWrapper = () => {
const allowPrivateBoards = useAppSelector((s) => s.config.allowPrivateBoards);
const [boardToDelete, setBoardToDelete] = useState<BoardDTO>();
return (
<>
<Box position="relative" w="full" h="full">
<Box position="absolute" top={0} right={0} bottom={0} left={0}>
<OverlayScrollbarsComponent defer style={overlayScrollbarsStyles} options={overlayScrollbarsParams.options}>
{allowPrivateBoards && <BoardsList isPrivate={true} setBoardToDelete={setBoardToDelete} />}
<BoardsList isPrivate={false} setBoardToDelete={setBoardToDelete} />
</OverlayScrollbarsComponent>
</Box>
</Box>
<DeleteBoardModal boardToDelete={boardToDelete} setBoardToDelete={setBoardToDelete} />
</>
);
};
export default memo(BoardsListWrapper);

View File

@@ -40,7 +40,7 @@ const BoardsSearch = () => {
);
return (
<InputGroup>
<InputGroup pt={2}>
<Input
placeholder={t('boards.searchBoard')}
value={boardSearchText}

View File

@@ -17,7 +17,7 @@ import IAIDroppable from 'common/components/IAIDroppable';
import type { AddToBoardDropData } from 'features/dnd/types';
import { AutoAddBadge } from 'features/gallery/components/Boards/AutoAddBadge';
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
import { BoardTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTooltip';
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice';
import type { MouseEvent, MouseEventHandler, MutableRefObject } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
@@ -115,7 +115,18 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
return (
<BoardContextMenu board={board} setBoardToDelete={setBoardToDelete}>
{(ref) => (
<Tooltip label={<BoardTooltip board={board} />} openDelay={1000} placement="left" closeOnScroll p={2}>
<Tooltip
label={
<BoardTotalsTooltip
imageCount={board.image_count}
assetCount={board.asset_count}
isArchived={Boolean(board.archived)}
/>
}
openDelay={1000}
placement="left"
closeOnScroll
>
<Flex
position="relative"
ref={ref}
@@ -126,12 +137,10 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
borderRadius="base"
cursor="pointer"
py={1}
ps={1}
pe={4}
gap={4}
px={2}
gap={2}
bg={isSelected ? 'base.850' : undefined}
_hover={_hover}
h={12}
>
<CoverImage board={board} />
<Editable
@@ -146,26 +155,26 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
onChange={onChange}
onSubmit={onSubmit}
isPreviewFocusable={false}
fontSize="sm"
>
<EditablePreview
cursor="pointer"
p={0}
fontSize="sm"
fontSize="md"
textOverflow="ellipsis"
noOfLines={1}
w="fit-content"
wordBreak="break-all"
fontWeight={isSelected ? 'bold' : 'normal'}
color={isSelected ? 'base.100' : 'base.400'}
fontWeight={isSelected ? 'semibold' : 'normal'}
/>
<EditableInput sx={editableInputStyles} />
<JankEditableHijack onStartEditingRef={onStartEditingRef} />
</Editable>
{autoAddBoardId === board.board_id && !editingDisclosure.isOpen && <AutoAddBadge />}
{board.archived && !editingDisclosure.isOpen && <Icon as={PiArchiveBold} fill="base.300" />}
{!editingDisclosure.isOpen && <Text variant="subtext">{board.image_count}</Text>}
{!editingDisclosure.isOpen && <Text variant="subtext">{board.image_count + board.asset_count}</Text>}
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="lg">{t('unifiedCanvas.move')}</Text>} />
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="md">{t('unifiedCanvas.move')}</Text>} />
</Flex>
</Tooltip>
)}
@@ -194,8 +203,8 @@ const CoverImage = ({ board }: { board: BoardDTO }) => {
src={coverImage.thumbnail_url}
draggable={false}
objectFit="cover"
w={10}
h={10}
w={8}
h={8}
borderRadius="base"
borderBottomRadius="lg"
/>
@@ -203,8 +212,8 @@ const CoverImage = ({ board }: { board: BoardDTO }) => {
}
return (
<Flex w={10} h={10} justifyContent="center" alignItems="center">
<Icon boxSize={10} as={PiImageSquare} opacity={0.7} color="base.500" />
<Flex w={8} h={8} justifyContent="center" alignItems="center">
<Icon boxSize={8} as={PiImageSquare} opacity={0.7} color="base.500" />
</Flex>
);
};

View File

@@ -4,12 +4,12 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import IAIDroppable from 'common/components/IAIDroppable';
import type { RemoveFromBoardDropData } from 'features/dnd/types';
import { AutoAddBadge } from 'features/gallery/components/Boards/AutoAddBadge';
import { BoardTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTooltip';
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
import NoBoardBoardContextMenu from 'features/gallery/components/Boards/NoBoardBoardContextMenu';
import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
import { useGetUncategorizedImageCountsQuery } from 'services/api/endpoints/boards';
import { useBoardName } from 'services/api/hooks/useBoardName';
interface Props {
@@ -22,11 +22,7 @@ const _hover: SystemStyleObject = {
const NoBoardBoard = memo(({ isSelected }: Props) => {
const dispatch = useAppDispatch();
const { imagesTotal } = useGetBoardImagesTotalQuery('none', {
selectFromResult: ({ data }) => {
return { imagesTotal: data?.total ?? 0 };
},
});
const { data } = useGetUncategorizedImageCountsQuery();
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const boardSearchText = useAppSelector((s) => s.gallery.boardSearchText);
@@ -46,16 +42,31 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
[]
);
const filteredOut = useMemo(() => {
return boardSearchText ? !boardName.toLowerCase().includes(boardSearchText.toLowerCase()) : false;
}, [boardName, boardSearchText]);
const { t } = useTranslation();
if (boardSearchText.length) {
if (filteredOut) {
return null;
}
return (
<NoBoardBoardContextMenu>
{(ref) => (
<Tooltip label={<BoardTooltip board={null} />} openDelay={1000} placement="left" closeOnScroll>
<Tooltip
label={
<BoardTotalsTooltip
imageCount={data?.image_count ?? 0}
assetCount={data?.asset_count ?? 0}
isArchived={false}
/>
}
openDelay={1000}
placement="left"
closeOnScroll
>
<Flex
position="relative"
ref={ref}
@@ -64,17 +75,15 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
alignItems="center"
borderRadius="base"
cursor="pointer"
px={2}
py={1}
ps={1}
pe={4}
gap={4}
gap={2}
bg={isSelected ? 'base.850' : undefined}
_hover={_hover}
h={12}
>
<Flex w="10" justifyContent="space-around">
<Flex w={8} h={8} justifyContent="center" alignItems="center">
{/* iconified from public/assets/images/invoke-symbol-wht-lrg.svg */}
<Icon boxSize={8} opacity={1} stroke="base.500" viewBox="0 0 66 66" fill="none">
<Icon boxSize={6} opacity={1} stroke="base.500" viewBox="0 0 66 66" fill="none">
<path
d="M43.9137 16H63.1211V3H3.12109V16H22.3285L43.9137 50H63.1211V63H3.12109V50H22.3285"
strokeWidth="5"
@@ -82,12 +91,18 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
</Icon>
</Flex>
<Text fontSize="sm" fontWeight={isSelected ? 'bold' : 'normal'} noOfLines={1} flexGrow={1}>
<Text
fontSize="md"
color={isSelected ? 'base.100' : 'base.400'}
fontWeight={isSelected ? 'semibold' : 'normal'}
noOfLines={1}
flexGrow={1}
>
{boardName}
</Text>
{autoAddBoardId === 'none' && <AutoAddBadge />}
<Text variant="subtext">{imagesTotal}</Text>
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="lg">{t('unifiedCanvas.move')}</Text>} />
<Text variant="subtext">{(data?.image_count ?? 0) + (data?.asset_count ?? 0)}</Text>
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="md">{t('unifiedCanvas.move')}</Text>} />
</Flex>
</Tooltip>
)}

View File

@@ -1,105 +0,0 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import {
Box,
Collapse,
Flex,
IconButton,
Spacer,
Tab,
TabList,
Tabs,
Text,
useDisclosure,
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGallerySearchTerm } from 'features/gallery/components/ImageGrid/useGallerySearchTerm';
import { galleryViewChanged } from 'features/gallery/store/gallerySlice';
import type { CSSProperties } from 'react';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiMagnifyingGlassBold } from 'react-icons/pi';
import { useBoardName } from 'services/api/hooks/useBoardName';
import GalleryImageGrid from './ImageGrid/GalleryImageGrid';
import { GalleryPagination } from './ImageGrid/GalleryPagination';
import { GallerySearch } from './ImageGrid/GallerySearch';
const BASE_STYLES: ChakraProps['sx'] = {
fontWeight: 'semibold',
fontSize: 'sm',
color: 'base.300',
};
const SELECTED_STYLES: ChakraProps['sx'] = {
borderColor: 'base.800',
borderBottomColor: 'base.900',
color: 'invokeBlue.300',
};
const COLLAPSE_STYLES: CSSProperties = { flexShrink: 0, minHeight: 0 };
export const Gallery = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const galleryView = useAppSelector((s) => s.gallery.galleryView);
const initialSearchTerm = useAppSelector((s) => s.gallery.searchTerm);
const searchDisclosure = useDisclosure({ defaultIsOpen: initialSearchTerm.length > 0 });
const [searchTerm, onChangeSearchTerm, onResetSearchTerm] = useGallerySearchTerm();
const handleClickImages = useCallback(() => {
dispatch(galleryViewChanged('images'));
}, [dispatch]);
const handleClickAssets = useCallback(() => {
dispatch(galleryViewChanged('assets'));
}, [dispatch]);
const handleClickSearch = useCallback(() => {
searchDisclosure.onToggle();
onResetSearchTerm();
}, [onResetSearchTerm, searchDisclosure]);
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const boardName = useBoardName(selectedBoardId);
return (
<Flex flexDirection="column" alignItems="center" justifyContent="space-between" h="full" w="full" pt={1}>
<Tabs index={galleryView === 'images' ? 0 : 1} variant="enclosed" display="flex" flexDir="column" w="full">
<TabList gap={2} fontSize="sm" borderColor="base.800" alignItems="center" w="full">
<Text fontSize="sm" fontWeight="semibold" noOfLines={1} px="2">
{boardName}
</Text>
<Spacer />
<Tab sx={BASE_STYLES} _selected={SELECTED_STYLES} onClick={handleClickImages} data-testid="images-tab">
{t('parameters.images')}
</Tab>
<Tab sx={BASE_STYLES} _selected={SELECTED_STYLES} onClick={handleClickAssets} data-testid="assets-tab">
{t('gallery.assets')}
</Tab>
<IconButton
onClick={handleClickSearch}
tooltip={searchDisclosure.isOpen ? `${t('gallery.exitSearch')}` : `${t('gallery.displaySearch')}`}
aria-label={t('gallery.displaySearch')}
icon={<PiMagnifyingGlassBold />}
colorScheme={searchDisclosure.isOpen ? 'invokeBlue' : 'base'}
variant="link"
/>
</TabList>
</Tabs>
<Box w="full">
<Collapse in={searchDisclosure.isOpen} style={COLLAPSE_STYLES}>
<Box w="full" pt={2}>
<GallerySearch
searchTerm={searchTerm}
onChangeSearchTerm={onChangeSearchTerm}
onResetSearchTerm={onResetSearchTerm}
/>
</Box>
</Collapse>
</Box>
<GalleryImageGrid />
<GalleryPagination />
</Flex>
);
};

View File

@@ -0,0 +1,33 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { memo } from 'react';
import { useBoardName } from 'services/api/hooks/useBoardName';
type Props = {
onClick: () => void;
};
const GalleryBoardName = (props: Props) => {
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const boardName = useBoardName(selectedBoardId);
return (
<Flex
onClick={props.onClick}
as="button"
h="full"
w="full"
borderWidth={1}
borderRadius="base"
alignItems="center"
justifyContent="center"
px={2}
>
<Text fontWeight="semibold" fontSize="md" noOfLines={1} wordBreak="break-all" color="base.200">
{boardName}
</Text>
</Flex>
);
};
export default memo(GalleryBoardName);

View File

@@ -3,21 +3,32 @@ import { useStore } from '@nanostores/react';
import { $projectName, $projectUrl } from 'app/store/nanostores/projectId';
import { memo } from 'react';
export const GalleryHeader = memo(() => {
import GalleryBoardName from './GalleryBoardName';
type Props = {
onClickBoardName: () => void;
};
export const GalleryHeader = memo((props: Props) => {
const projectName = useStore($projectName);
const projectUrl = useStore($projectUrl);
if (projectName && projectUrl) {
return (
<Flex gap={2} w="full" alignItems="center" justifyContent="space-evenly" pe={2}>
<Text fontSize="md" fontWeight="semibold" noOfLines={1} wordBreak="break-all" w="full" textAlign="center">
<Text fontSize="md" fontWeight="semibold" noOfLines={1} w="full" textAlign="center">
<Link href={projectUrl}>{projectName}</Link>
</Text>
<GalleryBoardName onClick={props.onClickBoardName} />
</Flex>
);
}
return null;
return (
<Flex w="full" pe={2}>
<GalleryBoardName onClick={props.onClickBoardName} />
</Flex>
);
});
GalleryHeader.displayName = 'GalleryHeader';

View File

@@ -13,7 +13,6 @@ import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/ac
import { imageToCompareChanged } from 'features/gallery/store/gallerySlice';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { toast } from 'features/toast/toast';
import { setActiveTab } from 'features/ui/store/uiSlice';
@@ -125,11 +124,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
dispatch(imageToCompareChanged(imageDTO));
}, [dispatch, imageDTO]);
const handleSendToUpscale = useCallback(() => {
dispatch(upscaleInitialImageChanged(imageDTO));
dispatch(setActiveTab('upscaling'));
}, [dispatch, imageDTO]);
return (
<>
<MenuItem as="a" href={imageDTO.image_url} target="_blank" icon={<PiShareFatBold />}>
@@ -191,9 +185,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
{t('parameters.sendToUnifiedCanvas')}
</MenuItem>
)}
<MenuItem icon={<PiShareFatBold />} onClickCapture={handleSendToUpscale} id="send-to-upscale">
{t('parameters.sendToUpscale')}
</MenuItem>
<MenuDivider />
<MenuItem icon={<PiFoldersBold />} onClickCapture={handleChangeBoard}>
{t('boards.changeBoard')}

View File

@@ -1,28 +1,57 @@
import { Box, Button, Collapse, Divider, Flex, IconButton, useDisclosure } from '@invoke-ai/ui-library';
import type { ChakraProps } from '@invoke-ai/ui-library';
import {
Box,
Collapse,
Divider,
Flex,
IconButton,
Spacer,
Tab,
TabList,
Tabs,
useDisclosure,
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { GalleryHeader } from 'features/gallery/components/GalleryHeader';
import { boardSearchTextChanged } from 'features/gallery/store/gallerySlice';
import { galleryViewChanged } from 'features/gallery/store/gallerySlice';
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
import { usePanel, type UsePanelOptions } from 'features/ui/hooks/usePanel';
import type { CSSProperties } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold, PiCaretUpBold, PiMagnifyingGlassBold } from 'react-icons/pi';
import { PiMagnifyingGlassBold } from 'react-icons/pi';
import type { ImperativePanelGroupHandle } from 'react-resizable-panels';
import { Panel, PanelGroup } from 'react-resizable-panels';
import BoardsListWrapper from './Boards/BoardsList/BoardsListWrapper';
import BoardsList from './Boards/BoardsList/BoardsList';
import BoardsSearch from './Boards/BoardsList/BoardsSearch';
import { Gallery } from './Gallery';
import GallerySettingsPopover from './GallerySettingsPopover/GallerySettingsPopover';
import GalleryImageGrid from './ImageGrid/GalleryImageGrid';
import { GalleryPagination } from './ImageGrid/GalleryPagination';
import { GallerySearch } from './ImageGrid/GallerySearch';
const COLLAPSE_STYLES: CSSProperties = { flexShrink: 0, minHeight: 0 };
const BASE_STYLES: ChakraProps['sx'] = {
fontWeight: 'semibold',
fontSize: 'sm',
color: 'base.300',
};
const SELECTED_STYLES: ChakraProps['sx'] = {
borderColor: 'base.800',
borderBottomColor: 'base.900',
color: 'invokeBlue.300',
};
const ImageGalleryContent = () => {
const { t } = useTranslation();
const galleryView = useAppSelector((s) => s.gallery.galleryView);
const searchTerm = useAppSelector((s) => s.gallery.searchTerm);
const boardSearchText = useAppSelector((s) => s.gallery.boardSearchText);
const dispatch = useAppDispatch();
const boardSearchDisclosure = useDisclosure({ defaultIsOpen: !!boardSearchText.length });
const searchDisclosure = useDisclosure({ defaultIsOpen: false });
const boardSearchDisclosure = useDisclosure({ defaultIsOpen: false });
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
const boardsListPanelOptions = useMemo<UsePanelOptions>(
@@ -38,58 +67,42 @@ const ImageGalleryContent = () => {
);
const boardsListPanel = usePanel(boardsListPanelOptions);
const handleClickBoardSearch = useCallback(() => {
if (boardSearchText.length) {
dispatch(boardSearchTextChanged(''));
}
boardSearchDisclosure.onToggle();
boardsListPanel.expand();
}, [boardSearchText.length, boardSearchDisclosure, boardsListPanel, dispatch]);
const handleClickImages = useCallback(() => {
dispatch(galleryViewChanged('images'));
}, [dispatch]);
const handleToggleBoardPanel = useCallback(() => {
if (boardSearchText.length) {
dispatch(boardSearchTextChanged(''));
}
boardSearchDisclosure.onClose();
boardsListPanel.toggle();
}, [boardSearchText.length, boardSearchDisclosure, boardsListPanel, dispatch]);
const handleClickAssets = useCallback(() => {
dispatch(galleryViewChanged('assets'));
}, [dispatch]);
return (
<Flex position="relative" flexDirection="column" h="full" w="full" pt={2}>
<Flex alignItems="center" gap={0}>
<GalleryHeader />
<Flex alignItems="center" justifyContent="space-between" w="full">
<Button
w={112}
size="sm"
variant="ghost"
onClick={handleToggleBoardPanel}
rightIcon={boardsListPanel.isCollapsed ? <PiCaretDownBold /> : <PiCaretUpBold />}
>
{boardsListPanel.isCollapsed ? t('boards.viewBoards') : t('boards.hideBoards')}
</Button>
<Flex alignItems="center" justifyContent="space-between">
<GallerySettingsPopover />
<Flex>
<IconButton
w="full"
h="full"
onClick={handleClickBoardSearch}
tooltip={
boardSearchDisclosure.isOpen
? `${t('gallery.exitBoardSearch')}`
: `${t('gallery.displayBoardSearch')}`
}
aria-label={t('gallery.displayBoardSearch')}
icon={<PiMagnifyingGlassBold />}
colorScheme={boardSearchDisclosure.isOpen ? 'invokeBlue' : 'base'}
variant="link"
/>
</Flex>
</Flex>
</Flex>
<Flex alignItems="center" gap={2}>
<GalleryHeader onClickBoardName={boardsListPanel.toggle} />
<GallerySettingsPopover />
<Box position="relative" h="full">
<IconButton
w="full"
h="full"
onClick={boardSearchDisclosure.onToggle}
tooltip={`${t('gallery.displayBoardSearch')}`}
aria-label={t('gallery.displayBoardSearch')}
icon={<PiMagnifyingGlassBold />}
variant="link"
/>
{boardSearchText && (
<Box
position="absolute"
w={2}
h={2}
bg="invokeBlue.300"
borderRadius="full"
insetBlockStart={2}
insetInlineEnd={2}
/>
)}
</Box>
</Flex>
<PanelGroup ref={panelGroupRef} direction="vertical">
<Panel
id="boards-list-panel"
@@ -102,12 +115,10 @@ const ImageGalleryContent = () => {
>
<Flex flexDir="column" w="full" h="full">
<Collapse in={boardSearchDisclosure.isOpen} style={COLLAPSE_STYLES}>
<Box w="full" pt={2}>
<BoardsSearch />
</Box>
<BoardsSearch />
</Collapse>
<Divider pt={2} />
<BoardsListWrapper />
<BoardsList />
</Flex>
</Panel>
<ResizeHandle
@@ -116,7 +127,50 @@ const ImageGalleryContent = () => {
onDoubleClick={boardsListPanel.onDoubleClickHandle}
/>
<Panel id="gallery-wrapper-panel" minSize={20}>
<Gallery />
<Flex flexDirection="column" alignItems="center" justifyContent="space-between" h="full" w="full">
<Tabs index={galleryView === 'images' ? 0 : 1} variant="enclosed" display="flex" flexDir="column" w="full">
<TabList gap={2} fontSize="sm" borderColor="base.800">
<Tab sx={BASE_STYLES} _selected={SELECTED_STYLES} onClick={handleClickImages} data-testid="images-tab">
{t('parameters.images')}
</Tab>
<Tab sx={BASE_STYLES} _selected={SELECTED_STYLES} onClick={handleClickAssets} data-testid="assets-tab">
{t('gallery.assets')}
</Tab>
<Spacer />
<Box position="relative">
<IconButton
w="full"
h="full"
onClick={searchDisclosure.onToggle}
tooltip={`${t('gallery.displaySearch')}`}
aria-label={t('gallery.displaySearch')}
icon={<PiMagnifyingGlassBold />}
variant="link"
/>
{searchTerm && (
<Box
position="absolute"
w={2}
h={2}
bg="invokeBlue.300"
borderRadius="full"
insetBlockStart={2}
insetInlineEnd={2}
/>
)}
</Box>
</TabList>
</Tabs>
<Box w="full">
<Collapse in={searchDisclosure.isOpen} style={COLLAPSE_STYLES}>
<Box w="full" pt={2}>
<GallerySearch />
</Box>
</Collapse>
</Box>
<GalleryImageGrid />
<GalleryPagination />
</Flex>
</Panel>
</PanelGroup>
</Flex>

View File

@@ -3,8 +3,6 @@ import { ELLIPSIS, useGalleryPagination } from 'features/gallery/hooks/useGaller
import { useCallback } from 'react';
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
import { JumpTo } from './JumpTo';
export const GalleryPagination = () => {
const { goPrev, goNext, isPrevEnabled, isNextEnabled, pageButtons, goToPage, currentPage, total } =
useGalleryPagination();
@@ -22,7 +20,7 @@ export const GalleryPagination = () => {
}
return (
<Flex justifyContent="center" alignItems="center" w="full" gap={1} pt={2}>
<Flex gap={2} alignItems="center" w="full">
<IconButton
size="sm"
aria-label="prev"
@@ -32,9 +30,25 @@ export const GalleryPagination = () => {
variant="ghost"
/>
<Spacer />
{pageButtons.map((page, i) => (
<PageButton key={`${page}_${i}`} page={page} currentPage={currentPage} goToPage={goToPage} />
))}
{pageButtons.map((page, i) => {
if (page === ELLIPSIS) {
return (
<Button size="sm" key={`ellipsis_${i}`} variant="link" isDisabled>
...
</Button>
);
}
return (
<Button
size="sm"
key={page}
onClick={goToPage.bind(null, page - 1)}
variant={currentPage === page - 1 ? 'solid' : 'outline'}
>
{page}
</Button>
);
})}
<Spacer />
<IconButton
size="sm"
@@ -44,28 +58,6 @@ export const GalleryPagination = () => {
isDisabled={!isNextEnabled}
variant="ghost"
/>
<JumpTo />
</Flex>
);
};
type PageButtonProps = {
page: number | typeof ELLIPSIS;
currentPage: number;
goToPage: (page: number) => void;
};
const PageButton = ({ page, currentPage, goToPage }: PageButtonProps) => {
if (page === ELLIPSIS) {
return (
<Button size="sm" variant="link" isDisabled>
...
</Button>
);
}
return (
<Button size="sm" onClick={goToPage.bind(null, page - 1)} variant={currentPage === page - 1 ? 'solid' : 'outline'}>
{page}
</Button>
);
};

View File

@@ -1,60 +1,59 @@
import { IconButton, Input, InputGroup, InputRightElement, Spinner } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import type { ChangeEvent, KeyboardEvent } from 'react';
import { useCallback } from 'react';
import { searchTermChanged } from 'features/gallery/store/gallerySlice';
import { debounce } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { useListImagesQuery } from 'services/api/endpoints/images';
type Props = {
searchTerm: string;
onChangeSearchTerm: (value: string) => void;
onResetSearchTerm: () => void;
};
export const GallerySearch = ({ searchTerm, onChangeSearchTerm, onResetSearchTerm }: Props) => {
export const GallerySearch = () => {
const dispatch = useAppDispatch();
const searchTerm = useAppSelector((s) => s.gallery.searchTerm);
const { t } = useTranslation();
const [searchTermInput, setSearchTermInput] = useState(searchTerm);
const queryArgs = useAppSelector(selectListImagesQueryArgs);
const { isPending } = useListImagesQuery(queryArgs, {
selectFromResult: ({ isLoading, isFetching }) => ({ isPending: isLoading || isFetching }),
});
const debouncedSetSearchTerm = useMemo(() => {
return debounce((value: string) => {
dispatch(searchTermChanged(value));
}, 1000);
}, [dispatch]);
const handleChangeInput = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChangeSearchTerm(e.target.value);
setSearchTermInput(e.target.value);
debouncedSetSearchTerm(e.target.value);
},
[onChangeSearchTerm]
[debouncedSetSearchTerm]
);
const handleKeydown = useCallback(
(e: KeyboardEvent<HTMLInputElement>) => {
// exit search mode on escape
if (e.key === 'Escape') {
onResetSearchTerm();
}
},
[onResetSearchTerm]
);
const handleClearInput = useCallback(() => {
setSearchTermInput('');
dispatch(searchTermChanged(''));
}, [dispatch]);
return (
<InputGroup>
<Input
placeholder={t('gallery.searchImages')}
value={searchTerm}
value={searchTermInput}
onChange={handleChangeInput}
data-testid="image-search-input"
onKeyDown={handleKeydown}
/>
{isPending && (
<InputRightElement h="full" pe={2}>
<Spinner size="sm" opacity={0.5} />
</InputRightElement>
)}
{!isPending && searchTerm.length && (
{!isPending && searchTermInput.length && (
<InputRightElement h="full" pe={2}>
<IconButton
onClick={onResetSearchTerm}
onClick={handleClearInput}
size="sm"
variant="link"
aria-label={t('boards.clearSearch')}

View File

@@ -1,97 +0,0 @@
import {
Button,
CompositeNumberInput,
Flex,
FormControl,
Popover,
PopoverArrow,
PopoverBody,
PopoverContent,
PopoverTrigger,
useDisclosure,
} from '@invoke-ai/ui-library';
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { useCallback, useEffect, useRef, useState } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
export const JumpTo = () => {
const { t } = useTranslation();
const { goToPage, currentPage, pages } = useGalleryPagination();
const [newPage, setNewPage] = useState(currentPage);
const { isOpen, onToggle, onClose } = useDisclosure();
const ref = useRef<HTMLInputElement>(null);
const onOpen = useCallback(() => {
setNewPage(currentPage);
setTimeout(() => {
const input = ref.current?.querySelector('input');
input?.focus();
input?.select();
}, 0);
}, [currentPage]);
const onChangeJumpTo = useCallback((v: number) => {
setNewPage(v - 1);
}, []);
const onClickGo = useCallback(() => {
goToPage(newPage);
onClose();
}, [newPage, goToPage, onClose]);
useHotkeys(
'enter',
() => {
onClickGo();
},
{ enabled: isOpen, enableOnFormTags: ['input'] },
[isOpen, onClickGo]
);
useHotkeys(
'esc',
() => {
setNewPage(currentPage);
onClose();
},
{ enabled: isOpen, enableOnFormTags: ['input'] },
[isOpen, onClose]
);
useEffect(() => {
setNewPage(currentPage);
}, [currentPage]);
return (
<Popover isOpen={isOpen} onClose={onClose} onOpen={onOpen}>
<PopoverTrigger>
<Button aria-label={t('gallery.jump')} size="sm" onClick={onToggle} variant="outline">
{t('gallery.jump')}
</Button>
</PopoverTrigger>
<PopoverContent>
<PopoverArrow />
<PopoverBody>
<Flex gap={2} alignItems="center">
<FormControl>
<CompositeNumberInput
ref={ref}
size="sm"
maxW="60px"
value={newPage + 1}
min={1}
max={pages}
step={1}
onChange={onChangeJumpTo}
/>
</FormControl>
<Button h="full" size="sm" onClick={onClickGo}>
{t('gallery.go')}
</Button>
</Flex>
</PopoverBody>
</PopoverContent>
</Popover>
);
};

View File

@@ -1,37 +0,0 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { searchTermChanged } from 'features/gallery/store/gallerySlice';
import { debounce } from 'lodash-es';
import { useCallback, useMemo, useState } from 'react';
export const useGallerySearchTerm = () => {
// Highlander!
useAssertSingleton('gallery-search-state');
const dispatch = useAppDispatch();
const searchTerm = useAppSelector((s) => s.gallery.searchTerm);
const [localSearchTerm, setLocalSearchTerm] = useState(searchTerm);
const debouncedSetSearchTerm = useMemo(() => {
return debounce((val: string) => {
dispatch(searchTermChanged(val));
}, 1000);
}, [dispatch]);
const onChange = useCallback(
(val: string) => {
setLocalSearchTerm(val);
debouncedSetSearchTerm(val);
},
[debouncedSetSearchTerm]
);
const onReset = useCallback(() => {
debouncedSetSearchTerm.cancel();
setLocalSearchTerm('');
dispatch(searchTermChanged(''));
}, [debouncedSetSearchTerm, dispatch]);
return [localSearchTerm, onChange, onReset] as const;
};

View File

@@ -2,7 +2,7 @@ import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/query';
import { adHocPostProcessingRequested } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice';
import { DeleteImageButton } from 'features/deleteImageModal/components/DeleteImageButton';
@@ -14,7 +14,7 @@ import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { parseAndRecallImageDimensions } from 'features/metadata/util/handlers';
import { $templates } from 'features/nodes/store/nodesSlice';
import { PostProcessingPopover } from 'features/parameters/components/PostProcessing/PostProcessingPopover';
import ParamUpscalePopover from 'features/parameters/components/Upscale/ParamUpscaleSettings';
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectSystemSlice } from 'features/system/store/systemSlice';
@@ -97,7 +97,7 @@ const CurrentImageButtons = () => {
if (!imageDTO) {
return;
}
dispatch(adHocPostProcessingRequested({ imageDTO }));
dispatch(upscaleRequested({ imageDTO }));
}, [dispatch, imageDTO]);
const handleDelete = useCallback(() => {
@@ -193,7 +193,7 @@ const CurrentImageButtons = () => {
{isUpscalingEnabled && (
<ButtonGroup isDisabled={isQueueMutationInProgress}>
{isUpscalingEnabled && <PostProcessingPopover imageDTO={imageDTO} />}
{isUpscalingEnabled && <ParamUpscalePopover imageDTO={imageDTO} />}
</ButtonGroup>
)}

View File

@@ -9,13 +9,7 @@ import CurrentImageButtons from './CurrentImageButtons';
import { ViewerToggleMenu } from './ViewerToggleMenu';
export const ViewerToolbar = memo(() => {
const showToggle = useAppSelector((s) => {
const tab = activeTabNameSelector(s);
if (tab === 'upscaling' || tab === 'workflows') {
return false;
}
return true;
});
const tab = useAppSelector(activeTabNameSelector);
return (
<Flex w="full" gap={2}>
<Flex flex={1} justifyContent="center">
@@ -29,7 +23,7 @@ export const ViewerToolbar = memo(() => {
</Flex>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto">
{showToggle && <ViewerToggleMenu />}
{tab !== 'workflows' && <ViewerToggleMenu />}
</Flex>
</Flex>
</Flex>

View File

@@ -1,7 +1,6 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { offsetChanged } from 'features/gallery/store/gallerySlice';
import { throttle } from 'lodash-es';
import { useCallback, useEffect, useMemo } from 'react';
import { useListImagesQuery } from 'services/api/endpoints/images';
@@ -81,41 +80,32 @@ export const useGalleryPagination = () => {
return offset > 0;
}, [count, offset]);
const onOffsetChanged = useCallback(
(arg: Parameters<typeof offsetChanged>[0]) => {
dispatch(offsetChanged(arg));
},
[dispatch]
);
const throttledOnOffsetChanged = useMemo(() => throttle(onOffsetChanged, 500), [onOffsetChanged]);
const goNext = useCallback(
(withHotkey?: 'arrow' | 'alt+arrow') => {
throttledOnOffsetChanged({ offset: offset + (limit || 0), withHotkey });
dispatch(offsetChanged({ offset: offset + (limit || 0), withHotkey }));
},
[throttledOnOffsetChanged, offset, limit]
[dispatch, offset, limit]
);
const goPrev = useCallback(
(withHotkey?: 'arrow' | 'alt+arrow') => {
throttledOnOffsetChanged({ offset: Math.max(offset - (limit || 0), 0), withHotkey });
dispatch(offsetChanged({ offset: Math.max(offset - (limit || 0), 0), withHotkey }));
},
[throttledOnOffsetChanged, offset, limit]
[dispatch, offset, limit]
);
const goToPage = useCallback(
(page: number) => {
throttledOnOffsetChanged({ offset: page * (limit || 0) });
dispatch(offsetChanged({ offset: page * (limit || 0) }));
},
[throttledOnOffsetChanged, limit]
[dispatch, limit]
);
const goToFirst = useCallback(() => {
throttledOnOffsetChanged({ offset: 0 });
}, [throttledOnOffsetChanged]);
dispatch(offsetChanged({ offset: 0 }));
}, [dispatch]);
const goToLast = useCallback(() => {
throttledOnOffsetChanged({ offset: (pages - 1) * (limit || 0) });
}, [throttledOnOffsetChanged, pages, limit]);
dispatch(offsetChanged({ offset: (pages - 1) * (limit || 0) }));
}, [dispatch, pages, limit]);
// handle when total/pages decrease and user is on high page number (ie bulk removing or deleting)
useEffect(() => {

Some files were not shown because too many files have changed in this diff Show More