Compare commits


51 Commits

Author SHA1 Message Date
Ryan Dick 6bcf48aa37 WIP - Started working towards MultiDiffusion batching. 2024-06-18 15:44:39 -04:00
Ryan Dick b1bb1511fe Delete rough notes. 2024-06-18 15:36:36 -04:00
Ryan Dick 99046a8145 Fix advanced scheduler behaviour in MultiDiffusionPipeline. 2024-06-18 15:36:36 -04:00
Ryan Dick 72be7e71e3 Fix handling of stateful schedulers in MultiDiffusionPipeline. 2024-06-18 15:36:36 -04:00
Ryan Dick 35adaf1c17 Connect TiledMultiDiffusionDenoiseLatents to the MultiDiffusionPipeline backend. 2024-06-18 15:36:34 -04:00
Ryan Dick 865c2335de Remove regional conditioning logic from MultiDiffusionPipeline - it is not yet supported. 2024-06-18 15:35:52 -04:00
Ryan Dick 49ca42f84a Initial (untested) implementation of MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
Ryan Dick 493fcd8660 Remove inpainting support from MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
Ryan Dick 20322d781e Remove IP-Adapter and T2I-Adapter support from MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
Ryan Dick 889d13e02a Document plan for the rest of the MultiDiffusion implementation. 2024-06-18 15:35:52 -04:00
Ryan Dick 6ccd2a867b Add detailed docstring to latents_from_embeddings(). 2024-06-18 15:35:52 -04:00
Ryan Dick 5861fa1719 Copy StableDiffusionGeneratorPipeline as a starting point for a new MultiDiffusionPipeline. 2024-06-18 15:35:52 -04:00
Ryan Dick dfd4beb62b Simplify handling of inpainting models. Improve the in-code documentation around inpainting. 2024-06-18 15:35:52 -04:00
Ryan Dick 83df0c0df5 Minor tidying of latents_from_embeddings(...). 2024-06-18 15:35:52 -04:00
Ryan Dick c58c4069a7 Consolidate latents_from_embeddings(...) and generate_latents_from_embeddings(...) into a single function. 2024-06-18 15:35:52 -04:00
Ryan Dick 3937fffa94 Fix invocation name of tiled_multi_diffusion_denoise_latents. 2024-06-18 15:35:52 -04:00
Ryan Dick bbf5f67691 Improve clarity of comments regarded when 'noise' and 'latents' are expected to be set. 2024-06-18 15:35:52 -04:00
Ryan Dick 2f5c147b84 Fix static check errors on imports in diffusers_pipeline.py. 2024-06-18 15:35:52 -04:00
Ryan Dick bd2839b748 Remove a condition for handling inpainting models that never resolves to True. The same logic is already applied earlier by AddsMaskLatents. 2024-06-18 15:35:52 -04:00
Ryan Dick 4f70dd7ce1 Add clarifying comment to explain why noise might be None in latents_from_embedding(). 2024-06-18 15:35:52 -04:00
Ryan Dick 066672fbfd Remove unused are_like_tensors() function. 2024-06-18 15:35:52 -04:00
Ryan Dick abefaee4d1 Remove unused StableDiffusionGeneratorPipeline.use_ip_adapter member. 2024-06-18 15:35:52 -04:00
Ryan Dick 3254ba5904 Remove unused StableDiffusionGeneratorPipeline.control_model. 2024-06-18 15:35:52 -04:00
Ryan Dick 73a8c55852 Stricter typing for the is_gradient_mask: bool. 2024-06-18 15:35:52 -04:00
Ryan Dick f82af7c22d Fix typing of control_data to reflect that it can be None. 2024-06-18 15:35:52 -04:00
Ryan Dick 3aef717ef4 Fix typing of timesteps and init_timestep. 2024-06-18 15:35:52 -04:00
Ryan Dick c2cf1137e9 Fix typing to reflect that the callback arg to latents_from_embeddings is never None. 2024-06-18 15:35:52 -04:00
Ryan Dick 803a24bc0a Move seed above optional params. 2024-06-18 15:35:52 -04:00
Ryan Dick 7d24ad8ccd Simplify handling of AddsMaskGuidance, and fix some related type errors. 2024-06-18 15:35:52 -04:00
Ryan Dick cb389063b2 Remove unused num_inference_steps. 2024-06-18 15:35:52 -04:00
Ryan Dick 81b8a69e1a WIP TiledMultiDiffusionDenoiseLatents. Updated parameter list and first half of the logic. 2024-06-18 15:35:50 -04:00
Ryan Dick 7ee5db87ad Tidy DenoiseLatentsInvocation.prep_control_data(...) and fix some type errors. 2024-06-18 15:34:30 -04:00
Ryan Dick 66cf2c59bd Make DenoiseLatentsInvocation.prep_control_data(...) a staticmethod so that it can be called externally. 2024-06-18 15:34:30 -04:00
Ryan Dick 3bad1367e9 Copy TiledStableDiffusionRefineInvocation as a starting point for TiledMultiDiffusionDenoiseLatents.py 2024-06-18 15:34:22 -04:00
Ryan Dick 867a7642a6 Change tiling strategy to make TiledStableDiffusionRefineInvocation work with more tile shapes and overlaps. 2024-06-18 15:31:58 -04:00
Ryan Dick d9d1c8f9cb Expose a few more params from TiledStableDiffusionRefineInvocation. 2024-06-18 15:31:58 -04:00
Ryan Dick e03eb7fb45 Add support for LoRA models in TiledStableDiffusionRefineInvocation. 2024-06-18 15:31:58 -04:00
Ryan Dick 85db33bc7e Add naive ControlNet support to TiledStableDiffusionRefineInvocation 2024-06-18 15:31:58 -04:00
Ryan Dick 93e3a2b504 Fix ControlNetModel type hint import source. 2024-06-18 15:31:58 -04:00
Ryan Dick 6a7a26f1bf Rough prototype of TiledStableDiffusionRefineInvocation is working. 2024-06-18 15:31:58 -04:00
Ryan Dick 08ca03ef9f WIP - TiledStableDiffusionRefine 2024-06-18 15:31:54 -04:00
Ryan Dick ccf90b6bd6 Minor improvements to LatentsToImageInvocation type hints. 2024-06-18 15:31:21 -04:00
Ryan Dick 753239b48d Expose vae_decode(...) as a staticmethod on LatentsToImageInvocation. 2024-06-18 15:31:21 -04:00
Ryan Dick 65fa4664c9 Fix return type of prepare_noise_and_latents(...). 2024-06-18 15:31:21 -04:00
Ryan Dick 297570ded3 Make init_scheduler() a staticmethod on DenoiseLatentsInvocation so that it can be called externally. 2024-06-18 15:31:21 -04:00
Ryan Dick 680fdcf293 Only allow a single positive/negative prompt conditioning input for tiled refine. 2024-06-18 15:31:21 -04:00
Ryan Dick 5ff91f2c44 WIP on TiledStableDiffusionRefine 2024-06-18 15:31:14 -04:00
Ryan Dick 69aa7057e7 Convert several methods in DenoiseLatentsInvocation to staticmethods so that they can be called externally. 2024-06-18 15:25:08 -04:00
Ryan Dick d3932f40de Simplify the logic in prepare_noise_and_latents(...). 2024-06-18 15:25:08 -04:00
Ryan Dick ee74cd7fab Split out the prepare_noise_and_latents(...) logic in DenoiseLatentsInvocation so that it can be called from other invocations. 2024-06-18 15:25:08 -04:00
Ryan Dick bda25b40c9 (minor) Add a TODO note to get_scheduler(...). 2024-06-18 15:25:08 -04:00
134 changed files with 3368 additions and 2832 deletions

View File

@@ -12,24 +12,12 @@
Invoke is a leading creative engine built to empower professionals and enthusiasts alike. Generate and create stunning visual media using the latest AI-driven technologies. Invoke offers an industry leading web-based UI, and serves as the foundation for multiple commercial products.
Invoke is available in two editions:
| **Community Edition** | **Professional Edition** |
|----------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------|
| **For users looking for a locally installed, self-hosted and self-managed service** | **For users or teams looking for a cloud-hosted, fully managed service** |
| - Free to use under a commercially-friendly license | - Monthly subscription fee with three different plan levels |
| - Download and install on compatible hardware | - Offers additional benefits, including multi-user support, improved model training, and more |
| - Includes all core studio features: generate, refine, iterate on images, and build workflows | - Hosted in the cloud for easy, secure model access and scalability |
| Quick Start -> [Installation and Updates][installation docs] | More Information -> [www.invoke.com/pricing](https://www.invoke.com/pricing) |
[Installation and Updates][installation docs] - [Documentation and Tutorials][docs home] - [Bug Reports][github issues] - [Contributing][contributing docs]
<div align="center">
![Highlighted Features - Canvas and Workflows](https://github.com/invoke-ai/InvokeAI/assets/31807370/708f7a82-084f-4860-bfbe-e2588c53548d)
# Documentation
| **Quick Links** |
|----------------------------------------------------------------------------------------------------------------------------|
| [Installation and Updates][installation docs] - [Documentation and Tutorials][docs home] - [Bug Reports][github issues] - [Contributing][contributing docs] |
</div>
## Quick Start

View File

@@ -73,6 +73,15 @@ model's lifetime it may be transformed in various ways, such as
changing its precision or converting it from a .safetensors to a
diffusers model.
`ModelType`, `ModelFormat` and `BaseModelType` are string enums that
are defined in `invokeai.backend.model_manager.config`. They are also
imported by, and can be reexported from,
`invokeai.app.services.model_manager.model_records`:
```
from invokeai.app.services.model_records import ModelType, ModelFormat, BaseModelType
```
The `path` field can be absolute or relative. If relative, it is taken
to be relative to the `models_dir` setting in the user's
`invokeai.yaml` file.
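As an aside, here is a minimal sketch of how a relative `path` would resolve against the `models_dir` setting; the directory and file names are hypothetical, and this is not the model manager's actual loading code:
```python
from pathlib import Path

# Hypothetical values: `models_dir` as configured in invokeai.yaml, and a relative `path` field.
models_dir = Path("/home/user/invokeai/models")
model_path = Path("sd-1/main/example-model.safetensors")

# A relative path is taken to be relative to models_dir; an absolute path is used as-is.
resolved = model_path if model_path.is_absolute() else models_dir / model_path
print(resolved)  # /home/user/invokeai/models/sd-1/main/example-model.safetensors
```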

View File

@@ -118,13 +118,15 @@ async def list_boards(
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
offset: Optional[int] = Query(default=None, description="The page offset"),
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
include_archived: bool = Query(default=False, description="Whether or not to include archived boards in list"),
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
"""Gets a list of boards"""
if all:
return ApiDependencies.invoker.services.boards.get_all(include_archived)
return ApiDependencies.invoker.services.boards.get_all()
elif offset is not None and limit is not None:
return ApiDependencies.invoker.services.boards.get_many(offset, limit, include_archived)
return ApiDependencies.invoker.services.boards.get_many(
offset,
limit,
)
else:
raise HTTPException(
status_code=400,

View File

@@ -9,14 +9,9 @@ from PIL import Image
from pydantic import BaseModel, Field, JsonValue
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecordChanges,
ResourceOrigin,
)
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from ..dependencies import ApiDependencies
@@ -321,14 +316,16 @@ async def list_image_dtos(
),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of image DTOs"""
image_dtos = ApiDependencies.invoker.services.images.get_many(
offset, limit, starred_first, order_dir, image_origin, categories, is_intermediate, board_id, search_term
offset,
limit,
image_origin,
categories,
is_intermediate,
board_id,
)
return image_dtos

View File

@@ -3,9 +3,9 @@
import io
import pathlib
import shutil
import traceback
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Type
from fastapi import Body, Path, Query, Response, UploadFile
@@ -19,6 +19,7 @@ from typing_extensions import Annotated
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,
ModelRecordChanges,
UnknownModelException,
@@ -29,6 +30,7 @@ from invokeai.backend.model_manager.config import (
MainCheckpointConfig,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
@@ -172,6 +174,18 @@ async def get_model_record(
raise HTTPException(status_code=404, detail=str(e))
# @model_manager_router.get("/summary", operation_id="list_model_summary")
# async def list_model_summary(
# page: int = Query(default=0, description="The page to get"),
# per_page: int = Query(default=10, description="The number of models per page"),
# order_by: ModelRecordOrderBy = Query(default=ModelRecordOrderBy.Default, description="The attribute to order by"),
# ) -> PaginatedResults[ModelSummary]:
# """Gets a page of model summary data."""
# record_store = ApiDependencies.invoker.services.model_manager.store
# results: PaginatedResults[ModelSummary] = record_store.list_models(page=page, per_page=per_page, order_by=order_by)
# return results
class FoundModel(BaseModel):
path: str = Field(description="Path to the model")
is_installed: bool = Field(description="Whether or not the model is already installed")
@@ -732,36 +746,39 @@ async def convert_model(
logger.error(f"The model with key {key} is not a main checkpoint model.")
raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")
with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
converted_model = loader.load_model(model_config)
# write the converted file to the convert path
raw_model = converted_model.model
assert hasattr(raw_model, "save_pretrained")
raw_model.save_pretrained(convert_path)
assert convert_path.exists()
# loading the model will convert it into a cached diffusers file
try:
cc_size = loader.convert_cache.max_size
if cc_size == 0: # temporary set the convert cache to a positive number so that cached model is written
loader._convert_cache.max_size = 1.0
loader.load_model(model_config, submodel_type=SubModelType.Scheduler)
finally:
loader._convert_cache.max_size = cc_size
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)
# Get the path of the converted model from the loader
cache_path = loader.convert_cache.cache_path(key)
assert cache_path.exists()
# install the diffusers
try:
new_key = installer.install_path(
convert_path,
config={
"name": original_name,
"description": model_config.description,
"hash": model_config.hash,
"source": model_config.source,
},
)
except Exception as e:
logger.error(str(e))
store.update_model(key, changes=ModelRecordChanges(name=original_name))
raise HTTPException(status_code=409, detail=str(e))
# temporarily rename the original safetensors file so that there is no naming conflict
original_name = model_config.name
model_config.name = f"{original_name}.DELETE"
changes = ModelRecordChanges(name=model_config.name)
store.update_model(key, changes=changes)
# install the diffusers
try:
new_key = installer.install_path(
cache_path,
config={
"name": original_name,
"description": model_config.description,
"hash": model_config.hash,
"source": model_config.source,
},
)
except DuplicateModelException as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
# Update the model image if the model had one
try:
@@ -774,8 +791,8 @@ async def convert_model(
# delete the original safetensors file
installer.delete(key)
# delete the temporary directory
# shutil.rmtree(cache_path)
# delete the cached version
shutil.rmtree(cache_path)
# return the config record for the new diffusers directory
new_config = store.get_model(new_key)

View File

@@ -1,5 +1,6 @@
from typing import Literal
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
LATENT_SCALE_FACTOR = 8
@@ -10,6 +11,9 @@ factor is hard-coded to a literal '8' rather than using this constant.
The ratio of image:latent dimensions is LATENT_SCALE_FACTOR:1, or 8:1.
"""
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())]
"""A literal type representing the valid scheduler names."""
IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"]
"""A literal type for PIL image modes supported by Invoke"""

View File

@@ -19,8 +19,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor

View File

@@ -17,7 +17,7 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPVisionModelWithProjection
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.fields import (
ConditioningField,
@@ -53,7 +53,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningData,
TextConditioningRegions,
)
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_MAP, SCHEDULER_NAME_VALUES
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
from invokeai.backend.util.mask import to_standard_float_mask
@@ -693,7 +693,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("'latents' or 'noise' must be provided!")
if noise is not None and noise.shape[1:] != latents.shape[1:]:
raise ValueError(f"Incompatible 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
raise ValueError(f"Incompatable 'noise' and 'latents' shapes: {latents.shape=} {noise.shape=}")
# The seed comes from (in order of priority): the noise field, the latents field, or 0.
seed = 0
@@ -736,7 +736,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
# The image prompts are then passed to prep_ip_adapter_data().
image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
# get the unet's config so that we can pass the base to sd_step_callback()
# get the unet's config so that we can pass the base to dispatch_progress()
unet_config = context.models.get_config(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None:

View File

@@ -160,8 +160,6 @@ class FieldDescriptions:
fp32 = "Whether or not to use full float32 precision"
precision = "Precision to use"
tiled = "Processing using overlapping tiles (reduce memory consumption)"
vae_tile_size = "The tile size for VAE tiling in pixels (image space). If set to 0, the default tile size for the "
"model will be used. Larger tile sizes generally produce better results at the cost of higher memory usage."
detect_res = "Pixel resolution for detection"
image_res = "Pixel resolution for output image"
safe_mode = "Whether or not to use safe mode"

View File

@@ -1,4 +1,3 @@
from contextlib import nullcontext
from functools import singledispatchmethod
import einops
@@ -13,7 +12,7 @@ from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
@@ -23,9 +22,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
@invocation(
@@ -33,7 +31,7 @@ from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
title="Image to Latents",
tags=["latents", "image", "vae", "i2l"],
category="latents",
version="1.1.0",
version="1.0.2",
)
class ImageToLatentsInvocation(BaseInvocation):
"""Encodes an image into latents."""
@@ -46,17 +44,12 @@ class ImageToLatentsInvocation(BaseInvocation):
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@staticmethod
def vae_encode(
vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
) -> torch.Tensor:
def vae_encode(vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor) -> torch.Tensor:
with vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
assert isinstance(vae, torch.nn.Module)
orig_dtype = vae.dtype
if upcast:
vae.to(dtype=torch.float32)
@@ -88,18 +81,9 @@ class ImageToLatentsInvocation(BaseInvocation):
else:
vae.disable_tiling()
tiling_context = nullcontext()
if tile_size > 0:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=tile_size,
tile_latent_min_size=tile_size // LATENT_SCALE_FACTOR,
tile_overlap_factor=0.25,
)
# non_noised_latents_from_image
image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
with torch.inference_mode(), tiling_context:
with torch.inference_mode():
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)
latents = vae.config.scaling_factor * latents
@@ -117,9 +101,7 @@ class ImageToLatentsInvocation(BaseInvocation):
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
latents = self.vae_encode(
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)
latents = self.vae_encode(vae_info, self.fp32, self.tiled, image_tensor)
latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)

View File

@@ -1,5 +1,3 @@
from contextlib import nullcontext
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import (
@@ -10,9 +8,10 @@ from diffusers.models.attention_processor import (
)
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
@@ -24,8 +23,8 @@ from invokeai.app.invocations.fields import (
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion import set_seamless
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
from invokeai.backend.util.devices import TorchDevice
@@ -34,7 +33,7 @@ from invokeai.backend.util.devices import TorchDevice
title="Latents to Image",
tags=["latents", "image", "vae", "l2i"],
category="latents",
version="1.3.0",
version="1.2.2",
)
class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents."""
@@ -48,21 +47,22 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
)
tiled: bool = InputField(default=False, description=FieldDescriptions.tiled)
# NOTE: tile_size = 0 is a special value. We use this rather than `int | None`, because the workflow UI does not
# offer a way to directly set None values.
tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
fp32: bool = InputField(default=DEFAULT_PRECISION == torch.float32, description=FieldDescriptions.fp32)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
@staticmethod
def vae_decode(
context: InvocationContext,
vae_info: LoadedModel,
seamless_axes: list[str],
latents: torch.Tensor,
use_fp32: bool,
use_tiling: bool,
) -> Image.Image:
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
with set_seamless(vae_info.model, seamless_axes), vae_info as vae:
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:
if use_fp32:
vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = hasattr(vae.decoder, "mid_block") and isinstance(
@@ -87,24 +87,15 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
vae.to(dtype=torch.float16)
latents = latents.half()
if self.tiled or context.config.get().force_tiled_decode:
if use_tiling or context.config.get().force_tiled_decode:
vae.enable_tiling()
else:
vae.disable_tiling()
tiling_context = nullcontext()
if self.tile_size > 0:
tiling_context = patch_vae_tiling_params(
vae,
tile_sample_min_size=self.tile_size,
tile_latent_min_size=self.tile_size // LATENT_SCALE_FACTOR,
tile_overlap_factor=0.25,
)
# clear memory as vae decode can request a lot
TorchDevice.empty_cache()
with torch.inference_mode(), tiling_context:
with torch.inference_mode():
# copied from diffusers pipeline
latents = latents / vae.config.scaling_factor
image = vae.decode(latents, return_dict=False)[0]
@@ -116,6 +107,21 @@ class LatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
TorchDevice.empty_cache()
return image
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
image = self.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=latents,
use_fp32=self.fp32,
use_tiling=self.tiled,
)
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)

View File

@@ -1,4 +1,5 @@
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.invocations.fields import (
FieldDescriptions,
InputField,
@@ -6,7 +7,6 @@ from invokeai.app.invocations.fields import (
UIType,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@invocation_output("scheduler_output")

View File

@@ -7,8 +7,8 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
@@ -24,12 +24,11 @@ from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData
from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
MultiDiffusionPipeline,
MultiDiffusionRegionConditioning,
)
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from invokeai.backend.tiles.tiles import (
calc_tiles_min_overlap,
)
@@ -56,15 +55,15 @@ def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> C
title="Tiled Multi-Diffusion Denoise Latents",
tags=["upscale", "denoise"],
category="latents",
classification=Classification.Beta,
# TODO(ryand): Reset to 1.0.0 right before release.
version="1.0.0",
)
class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
"""Tiled Multi-Diffusion denoising.
This node handles automatically tiling the input image, and is primarily intended for global refinement of images
in tiled upscaling workflows. Future Multi-Diffusion nodes should allow the user to specify custom regions with
different parameters for each region to harness the full power of Multi-Diffusion.
This node handles automatically tiling the input image. Future iterations of
this node should allow the user to specify custom regions with different parameters for each region to harness the
full power of Multi-Diffusion.
This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
T2I-Adapter, masking, etc.).
@@ -86,24 +85,21 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
description=FieldDescriptions.latents,
input=Input.Connection,
)
tile_height: int = InputField(
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Height of the tiles in image space."
)
tile_width: int = InputField(
default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Width of the tiles in image space."
)
tile_overlap: int = InputField(
default=32,
multiple_of=LATENT_SCALE_FACTOR,
# TODO(ryand): Add multiple-of validation.
# TODO(ryand): Smaller defaults might make more sense.
tile_height: int = InputField(default=112, gt=0, description="Height of the tiles in latent space.")
tile_width: int = InputField(default=112, gt=0, description="Width of the tiles in latent space.")
tile_min_overlap: int = InputField(
default=16,
gt=0,
description="The overlap between adjacent tiles in pixel space. (Of course, tile merging is applied in latent "
"space.) Tiles will be cropped during merging (if necessary) to ensure that they overlap by exactly this "
"amount.",
description="The minimum overlap between adjacent tiles in latent space. The actual overlap may be larger than "
"this to evenly cover the entire image.",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
# TODO(ryand): The default here should probably be 0.0.
denoising_start: float = InputField(
default=0.0,
default=0.65,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
@@ -154,7 +150,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
self.config = FakeVae.FakeVaeConfig()
return MultiDiffusionPipeline(
vae=FakeVae(),
vae=FakeVae(), # TODO: oh...
text_encoder=None,
tokenizer=None,
unet=unet,
@@ -166,29 +162,19 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# Convert tile image-space dimensions to latent-space dimensions.
latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR
latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR
latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR
seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
_, _, latent_height, latent_width = latents.shape
# Calculate the tile locations to cover the latent-space image.
# TODO(ryand): Add constraints on the tile params. Is there a multiple-of constraint?
tiles = calc_tiles_min_overlap(
image_height=latent_height,
image_width=latent_width,
tile_height=latent_tile_height,
tile_width=latent_tile_width,
min_overlap=latent_tile_overlap,
tile_height=self.tile_height,
tile_width=self.tile_width,
min_overlap=self.tile_min_overlap,
)
# Get the unet's config so that we can pass the base to sd_step_callback().
unet_config = context.models.get_config(self.unet.unet.key)
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
@@ -219,8 +205,8 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_tile_height,
latent_width=latent_tile_width,
latent_height=self.tile_height,
latent_width=self.tile_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
@@ -247,7 +233,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
multi_diffusion_conditioning.append(
MultiDiffusionRegionConditioning(
region=tile,
region=tile.coords,
text_conditioning_data=conditioning_data,
control_data=tile_controlnet_data,
)
@@ -265,17 +251,17 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
# Run Multi-Diffusion denoising.
result_latents = pipeline.multi_diffusion_denoise(
multi_diffusion_conditioning=multi_diffusion_conditioning,
target_overlap=latent_tile_overlap,
latents=latents,
scheduler_step_kwargs=scheduler_step_kwargs,
noise=noise,
timesteps=timesteps,
init_timestep=init_timestep,
callback=step_callback,
# TODO(ryand): Add proper callback.
callback=lambda x: None,
)
result_latents = result_latents.to("cpu")
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
result_latents = result_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)

View File

@@ -0,0 +1,380 @@
from contextlib import ExitStack
from typing import Iterator, Tuple
import numpy as np
import numpy.typing as npt
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import (
ConditioningField,
FieldDescriptions,
ImageField,
Input,
InputField,
UIType,
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField
from invokeai.app.invocations.noise import get_noise
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
from invokeai.backend.tiles.utils import Tile
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@invocation(
"tiled_stable_diffusion_refine",
title="Tiled Stable Diffusion Refine",
tags=["upscale", "denoise"],
category="latents",
version="1.0.0",
)
class TiledStableDiffusionRefineInvocation(BaseInvocation):
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
"""
image: ImageField = InputField(description="Image to be refined.")
positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection
)
# TODO(ryand): Add multiple-of validation.
tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.")
tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.")
tile_overlap: int = InputField(
default=16,
gt=0,
description="Target overlap between adjacent tiles (the last row/column may overlap more than this).",
)
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
denoising_start: float = InputField(
default=0.65,
ge=0,
le=1,
description=FieldDescriptions.denoising_start,
)
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
unet: UNetField = InputField(
description=FieldDescriptions.unet,
input=Input.Connection,
title="UNet",
)
cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
)
vae: VAEField = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
vae_fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
)
# HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But, we awkwardly
# don't want to use the image field. Figure out how best to handle this.
# TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
# much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
# CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
# etc.
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: float = InputField(default=0.6)
@field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float:
"""Validate that all cfg_scale values are >= 1"""
if isinstance(v, list):
for i in v:
if i < 1:
raise ValueError("cfg_scale must be greater than 1")
else:
if v < 1:
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
"""Crop the latent-space tensor to the area corresponding to the image-space tile.
The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
"""
for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
if coord % LATENT_SCALE_FACTOR != 0:
raise ValueError(
f"The tile coordinates must all be divisible by the latent scale factor"
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
)
assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
top = image_tile.coords.top // LATENT_SCALE_FACTOR
left = image_tile.coords.left // LATENT_SCALE_FACTOR
bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
right = image_tile.coords.right // LATENT_SCALE_FACTOR
return latents[..., top:bottom, left:right]
def run_controlnet(
self,
image: Image.Image,
controlnet_model: ControlNetModel,
weight: float,
do_classifier_free_guidance: bool,
width: int,
height: int,
device: torch.device,
dtype: torch.dtype,
control_mode: CONTROLNET_MODE_VALUES = "balanced",
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
) -> ControlNetData:
control_image = prepare_control_image(
image=image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
device=device,
dtype=dtype,
control_mode=control_mode,
resize_mode=resize_mode,
)
return ControlNetData(
model=controlnet_model,
image_tensor=control_image,
weight=weight,
begin_step_percent=0.0,
end_step_percent=1.0,
control_mode=control_mode,
# Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
# ControlNetData in case needed in the future.
resize_mode=resize_mode,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
# TODO(ryand): Expose the seed parameter.
seed = 0
# Load the input image.
input_image = context.images.get_pil(self.image.image_name)
# Calculate the tile locations to cover the image.
# We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This
# facilitates conversions between image space and latent space.
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
tiles = calc_tiles_with_overlap(
image_height=input_image.height,
image_width=input_image.width,
tile_height=self.tile_height,
tile_width=self.tile_width,
overlap=self.tile_overlap,
)
# Convert the input image to a torch.Tensor.
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
# Validate our assumptions about the shape of input_image_torch.
assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width).
assert input_image_torch.shape[:2] == (1, 3)
# Split the input image into tiles in torch.Tensor format.
image_tiles_torch: list[torch.Tensor] = []
for tile in tiles:
image_tile = input_image_torch[
:,
:,
tile.coords.top : tile.coords.bottom,
tile.coords.left : tile.coords.right,
]
image_tiles_torch.append(image_tile)
# Split the input image into tiles in numpy format.
# TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
# with torch.Tensor tiles.
input_image_np = np.array(input_image)
image_tiles_np: list[npt.NDArray[np.uint8]] = []
for tile in tiles:
image_tile_np = input_image_np[
tile.coords.top : tile.coords.bottom,
tile.coords.left : tile.coords.right,
:,
]
image_tiles_np.append(image_tile_np)
# VAE-encode each image tile independently.
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
# about for decoding?
vae_info = context.models.load(self.vae.vae)
latent_tiles: list[torch.Tensor] = []
for image_tile_torch in image_tiles_torch:
latent_tiles.append(
ImageToLatentsInvocation.vae_encode(
vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile_torch
)
)
# Generate noise with dimensions corresponding to the full image in latent space.
# It is important that the noise tensor is generated at the full image dimension and then tiled, rather than
# generating for each tile independently. This ensures that overlapping regions between tiles use the same
# noise.
assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
global_noise = get_noise(
width=input_image_torch.shape[3],
height=input_image_torch.shape[2],
device=TorchDevice.choose_torch_device(),
seed=seed,
downsampling_factor=LATENT_SCALE_FACTOR,
use_cpu=True,
)
# Crop the global noise into tiles.
noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
# Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.unet.loras:
lora_info = context.models.load(lora.lora)
assert isinstance(lora_info.model, LoRAModelRaw)
yield (lora_info.model, lora.weight)
del lora_info
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
refined_latent_tiles: list[torch.Tensor] = []
with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
assert isinstance(unet, UNet2DConditionModel)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
seed=seed,
)
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
# Assume that all tiles have the same shape.
_, _, latent_height, latent_width = latent_tiles[0].shape
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
# Load the ControlNet model.
# TODO(ryand): Support multiple ControlNet models.
controlnet_model = exit_stack.enter_context(context.models.load(self.control_model))
assert isinstance(controlnet_model, ControlNetModel)
# Denoise (i.e. "refine") each tile independently.
for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
assert latent_tile.shape == noise_tile.shape
# Prepare a PIL Image for ControlNet processing.
# TODO(ryand): This is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of
# the tiles. Ideally, the ControlNet code should be able to work with Tensors.
image_tile_pil = Image.fromarray(image_tile_np)
# Run the ControlNet on the image tile.
height, width, _ = image_tile_np.shape
# The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we
# validate this assumption here.
assert height % LATENT_SCALE_FACTOR == 0
assert width % LATENT_SCALE_FACTOR == 0
controlnet_data = self.run_controlnet(
image=image_tile_pil,
controlnet_model=controlnet_model,
weight=self.control_weight,
do_classifier_free_guidance=True,
width=width,
height=height,
device=controlnet_model.device,
dtype=controlnet_model.dtype,
control_mode="balanced",
resize_mode="just_resize_simple",
)
timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
# TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM.
latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype)
noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype)
refined_latent_tile = pipeline.latents_from_embeddings(
latents=latent_tile,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise_tile,
seed=seed,
mask=None,
masked_latents=None,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=[controlnet_data],
ip_adapter_data=None,
t2i_adapter_data=None,
callback=lambda x: None,
)
refined_latent_tiles.append(refined_latent_tile)
# VAE-decode each refined latent tile independently.
refined_image_tiles: list[Image.Image] = []
for refined_latent_tile in refined_latent_tiles:
refined_image_tile = LatentsToImageInvocation.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=refined_latent_tile,
use_fp32=self.vae_fp32,
use_tiling=False,
)
refined_image_tiles.append(refined_image_tile)
# TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
TorchDevice.empty_cache()
# Merge the refined image tiles back into a single image.
refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
# TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter?
merge_tiles_with_linear_blending(
dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap
)
# Save the refined image and return its reference.
merged_image_pil = Image.fromarray(merged_image_np)
image_dto = context.images.save(image=merged_image_pil)
return ImageOutput.build(image_dto)

View File

@@ -40,12 +40,16 @@ class BoardRecordStorageBase(ABC):
@abstractmethod
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records."""
pass
@abstractmethod
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
def get_all(
self,
) -> list[BoardRecord]:
"""Gets all board records."""
pass

View File

@@ -22,8 +22,6 @@ class BoardRecord(BaseModelExcludeNull):
"""The updated timestamp of the image."""
cover_image_name: Optional[str] = Field(default=None, description="The name of the cover image of the board.")
"""The name of the cover image of the board."""
archived: bool = Field(description="Whether or not the board is archived.")
"""Whether or not the board is archived."""
def deserialize_board_record(board_dict: dict) -> BoardRecord:
@@ -37,7 +35,6 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
created_at = board_dict.get("created_at", get_iso_timestamp())
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
archived = board_dict.get("archived", False)
return BoardRecord(
board_id=board_id,
@@ -46,14 +43,12 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
archived=archived,
)
class BoardChanges(BaseModel, extra="forbid"):
board_name: Optional[str] = Field(default=None, description="The board's new name.")
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")
class BoardRecordNotFoundException(Exception):

View File

@@ -125,17 +125,6 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
(changes.cover_image_name, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
self._cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
@@ -145,49 +134,35 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
return self.get(board_id)
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()
# Build base query
base_query = """
# Get all the boards
self._cursor.execute(
"""--sql
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))
""",
(limit, offset),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
# Get the total number of boards
self._cursor.execute(
"""--sql
SELECT COUNT(*)
FROM boards
WHERE 1=1;
"""
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
self._cursor.execute(count_query)
)
count = cast(int, self._cursor.fetchone()[0])
@@ -199,25 +174,20 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
finally:
self._lock.release()
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
def get_all(
self,
) -> list[BoardRecord]:
try:
self._lock.acquire()
base_query = """
# Get all the boards
self._cursor.execute(
"""--sql
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
"""
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
final_query = base_query.format(archived_filter=archived_filter)
self._cursor.execute(final_query)
"""
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]

View File

@@ -44,12 +44,16 @@ class BoardServiceABC(ABC):
@abstractmethod
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
self,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many boards."""
pass
@abstractmethod
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
def get_all(
self,
) -> list[BoardDTO]:
"""Gets all boards."""
pass

View File

@@ -48,10 +48,8 @@ class BoardService(BoardServiceABC):
def delete(self, board_id: str) -> None:
self.__invoker.services.board_records.delete(board_id)
def get_many(
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
def get_many(self, offset: int = 0, limit: int = 10) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(offset, limit)
board_dtos = []
for r in board_records.items:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
@@ -65,8 +63,8 @@ class BoardService(BoardServiceABC):
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all(include_archived)
def get_all(self) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all()
board_dtos = []
for r in board_records:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import copy
import locale
import os
import re
@@ -26,13 +25,14 @@ DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEFAULT_RAM_CACHE = 10.0
DEFAULT_VRAM_CACHE = 0.25
DEFAULT_CONVERT_CACHE = 20.0
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
CONFIG_SCHEMA_VERSION = "4.0.2"
CONFIG_SCHEMA_VERSION = "4.0.1"
def get_default_ram_cache_size() -> float:
@@ -85,7 +85,7 @@ class InvokeAIAppConfig(BaseSettings):
log_tokenization: Enable logging of parsed prompt tokens.
patchmatch: Enable patchmatch inpaint code.
models_dir: Path to the models directory.
convert_cache_dir: Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions).
convert_cache_dir: Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.
download_cache_dir: Path to the directory that contains dynamically downloaded models.
legacy_conf_dir: Path to directory of legacy checkpoint config files.
db_dir: Path to InvokeAI databases directory.
@@ -102,6 +102,7 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path to profiles output directory.
ram: Maximum memory amount used by memory model cache for rapid switching (GB).
vram: Amount of VRAM reserved for model storage (GB).
convert_cache: Maximum size of on-disk converted models cache (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
@@ -112,7 +113,6 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
max_queue_size: Maximum number of items in the session queue.
clear_queue_on_startup: Empties session queue on startup.
allow_nodes: List of nodes to allow. Omit to allow all.
deny_nodes: List of nodes to deny. Omit to deny none.
node_cache_size: How many cached nodes to keep in memory.
@@ -147,7 +147,7 @@ class InvokeAIAppConfig(BaseSettings):
# PATHS
models_dir: Path = Field(default=Path("models"), description="Path to the models directory.")
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory (DEPRECATED, but do not delete because it is needed for migration from previous versions).")
convert_cache_dir: Path = Field(default=Path("models/.convert_cache"), description="Path to the converted models cache directory. When loading a non-diffusers model, it will be converted and store on disk at this location.")
download_cache_dir: Path = Field(default=Path("models/.download_cache"), description="Path to the directory that contains dynamically downloaded models.")
legacy_conf_dir: Path = Field(default=Path("configs"), description="Path to directory of legacy checkpoint config files.")
db_dir: Path = Field(default=Path("databases"), description="Path to InvokeAI databases directory.")
@@ -169,8 +169,9 @@ class InvokeAIAppConfig(BaseSettings):
profiles_dir: Path = Field(default=Path("profiles"), description="Path to profiles output directory.")
# CACHE
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
ram: float = Field(default_factory=get_default_ram_cache_size, gt=0, description="Maximum memory amount used by memory model cache for rapid switching (GB).")
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
convert_cache: float = Field(default=DEFAULT_CONVERT_CACHE, ge=0, description="Maximum size of on-disk converted models cache (GB).")
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
@@ -185,7 +186,6 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")
# NODES
allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.")
@@ -355,14 +355,14 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
return (init_settings,)
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a v4.0.0.
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate a v3 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
An `InvokeAIAppConfig` config dict.
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
@@ -396,41 +396,32 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
parsed_config_dict["schema_version"] = "4.0.0"
return parsed_config_dict
# When migrating the config file, we should not include currently-set environment variables.
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
def migrate_v4_0_0_to_4_0_1_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to a v4.0.1 config dictionary
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.0 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
A config dict with the settings migrated to v4.0.1.
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
# precision "autocast" was replaced by "auto" in v4.0.1
if parsed_config_dict.get("precision") == "autocast":
parsed_config_dict["precision"] = "auto"
parsed_config_dict["schema_version"] = "4.0.1"
return parsed_config_dict
def migrate_v4_0_1_to_4_0_2_config_dict(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.1 config dictionary to a v4.0.2 config dictionary.
Args:
config_dict: A dictionary of settings from a v4.0.1 config file.
Returns:
A config dict with the settings migrated to v4.0.2.
"""
parsed_config_dict: dict[str, Any] = copy.deepcopy(config_dict)
# convert_cache was removed in 4.0.2
parsed_config_dict.pop("convert_cache", None)
parsed_config_dict["schema_version"] = "4.0.2"
return parsed_config_dict
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
@@ -444,31 +435,27 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""
assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict: dict[str, Any] = yaml.safe_load(file)
loaded_config_dict = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict)
migrated = False
if "InvokeAI" in loaded_config_dict:
migrated = True
loaded_config_dict = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
if loaded_config_dict["schema_version"] == "4.0.0":
migrated = True
loaded_config_dict = migrate_v4_0_0_to_4_0_1_config_dict(loaded_config_dict)
if loaded_config_dict["schema_version"] == "4.0.1":
migrated = True
loaded_config_dict = migrate_v4_0_1_to_4_0_2_config_dict(loaded_config_dict)
if migrated:
# This is a v3 config file, attempt to migrate it
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# load and write without environment variables
migrated_config = DefaultInvokeAIAppConfig.model_validate(loaded_config_dict)
migrated_config.write_file(config_path)
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config
if loaded_config_dict["schema_version"] == "4.0.0":
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
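The migration functions above form a stepwise chain keyed on schema_version. A small, self-contained sketch of that pattern (version numbers and field names mirror the diff; the plain dict stands in for the real InvokeAIAppConfig):
from typing import Any, Callable

Migrator = Callable[[dict[str, Any]], dict[str, Any]]

def migrate_4_0_0_to_4_0_1(d: dict[str, Any]) -> dict[str, Any]:
    out = dict(d)
    if out.get("precision") == "autocast":  # "autocast" was replaced by "auto" in 4.0.1
        out["precision"] = "auto"
    out["schema_version"] = "4.0.1"
    return out

def migrate_4_0_1_to_4_0_2(d: dict[str, Any]) -> dict[str, Any]:
    out = dict(d)
    out.pop("convert_cache", None)          # setting removed in 4.0.2
    out["schema_version"] = "4.0.2"
    return out

MIGRATIONS: dict[str, Migrator] = {
    "4.0.0": migrate_4_0_0_to_4_0_1,
    "4.0.1": migrate_4_0_1_to_4_0_2,
}

def migrate(config: dict[str, Any], target: str) -> dict[str, Any]:
    # Apply one migration at a time until the target schema version is reached.
    while config["schema_version"] != target:
        config = MIGRATIONS[config["schema_version"]](config)
    return config

print(migrate({"schema_version": "4.0.0", "precision": "autocast", "convert_cache": 20.0}, "4.0.2"))
# {'schema_version': '4.0.2', 'precision': 'auto'}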

View File

@@ -4,7 +4,6 @@ from typing import Optional
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from .image_records_common import ImageCategory, ImageRecord, ImageRecordChanges, ResourceOrigin
@@ -38,13 +37,10 @@ class ImageRecordStorageBase(ABC):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass

View File

@@ -5,7 +5,6 @@ from typing import Optional, Union, cast
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from .image_records_base import ImageRecordStorageBase
@@ -145,13 +144,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()
@@ -212,21 +208,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
query_params.append(board_id)
# Search term condition
if search_term:
query_conditions += """--sql
AND images.metadata LIKE ?
"""
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
query_pagination = """--sql
ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
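A standalone sketch of how the fuller variant in this hunk assembles its ORDER BY clause from starred_first and order_dir before appending the LIMIT/OFFSET placeholders (the enum here mirrors SQLiteDirection but is defined locally for illustration):
from enum import Enum

class SQLiteDirection(str, Enum):
    Ascending = "ASC"
    Descending = "DESC"

def build_pagination_clause(starred_first: bool, order_dir: SQLiteDirection) -> str:
    # order_dir.value is a trusted enum member, never raw user input, so it is safe
    # to interpolate; the limit and offset are bound later via ? placeholders.
    if starred_first:
        return f"ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?"
    return f"ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?"

print(build_pagination_clause(True, SQLiteDirection.Descending))
print(build_pagination_clause(False, SQLiteDirection.Ascending))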

View File

@@ -12,7 +12,6 @@ from invokeai.app.services.image_records.image_records_common import (
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
class ImageServiceABC(ABC):
@@ -117,13 +116,10 @@ class ImageServiceABC(ABC):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass

View File

@@ -5,7 +5,6 @@ from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from ..image_files.image_files_common import (
ImageFileDeleteException,
@@ -74,12 +73,7 @@ class ImageService(ImageServiceABC):
session_id=session_id,
)
if board_id is not None:
try:
self.__invoker.services.board_image_records.add_image_to_board(
board_id=board_id, image_name=image_name
)
except Exception as e:
self.__invoker.services.logger.warn(f"Failed to add image to board {board_id}: {str(e)}")
self.__invoker.services.board_image_records.add_image_to_board(board_id=board_id, image_name=image_name)
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
)
@@ -208,25 +202,19 @@ class ImageService(ImageServiceABC):
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self.__invoker.services.image_records.get_many(
offset,
limit,
starred_first,
order_dir,
image_origin,
categories,
is_intermediate,
board_id,
search_term,
)
image_dtos = [

View File

@@ -6,7 +6,8 @@ from pathlib import Path
from typing import Callable, Optional
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
@@ -27,6 +28,11 @@ class ModelLoadServiceBase(ABC):
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache used by this loader."""
@property
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None

View File

@@ -2,7 +2,7 @@
"""Implementation of model loader service."""
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, Optional, Type
from picklescan.scanner import scan_file_path
from safetensors.torch import load_file as safetensors_load_file
@@ -11,9 +11,14 @@ from torch import load as torch_load
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load import (
LoadedModel,
LoadedModelWithoutConfig,
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -28,7 +33,8 @@ class ModelLoadService(ModelLoadServiceBase):
self,
app_config: InvokeAIAppConfig,
ram_cache: ModelCacheBase[AnyModel],
registry: ModelLoaderRegistry,
convert_cache: ModelConvertCacheBase,
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
):
"""Initialize the model load service."""
logger = InvokeAILogger.get_logger(self.__class__.__name__)
@@ -36,6 +42,7 @@ class ModelLoadService(ModelLoadServiceBase):
self._logger = logger
self._app_config = app_config
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._registry = registry
def start(self, invoker: Invoker) -> None:
@@ -46,6 +53,11 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the RAM cache used by this loader."""
return self._ram_cache
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the checkpoint convert cache used by this loader."""
return self._convert_cache
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
@@ -64,6 +76,7 @@ class ModelLoadService(ModelLoadServiceBase):
app_config=self._app_config,
logger=self._logger,
ram_cache=self._ram_cache,
convert_cache=self._convert_cache,
).load_model(model_config, submodel_type)
if hasattr(self, "_invoker"):

View File

@@ -0,0 +1,17 @@
"""Initialization file for model manager service."""
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
from .model_manager_default import ModelManagerService, ModelManagerServiceBase
__all__ = [
"ModelManagerServiceBase",
"ModelManagerService",
"AnyModel",
"AnyModelConfig",
"BaseModelType",
"ModelType",
"SubModelType",
"LoadedModel",
]

View File

@@ -7,8 +7,7 @@ import torch
from typing_extensions import Self
from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager.load.model_cache.model_cache_default import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@@ -87,9 +86,11 @@ class ModelManagerService(ModelManagerServiceBase):
logger=logger,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
loader = ModelLoadService(
app_config=app_config,
ram_cache=ram_cache,
convert_cache=convert_cache,
registry=ModelLoaderRegistry,
)
installer = ModelInstallService(

View File

@@ -37,14 +37,10 @@ class SqliteSessionQueue(SessionQueueBase):
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
self._set_in_progress_to_canceled()
if self.__invoker.services.configuration.clear_queue_on_startup:
clear_result = self.clear(DEFAULT_QUEUE_ID)
if clear_result.deleted > 0:
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
else:
prune_result = self.prune(DEFAULT_QUEUE_ID)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
prune_result = self.prune(DEFAULT_QUEUE_ID)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()

View File

@@ -652,7 +652,7 @@ class Graph(BaseModel):
output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]
# Input type must be a list
if get_origin(input_field) is not list:
if get_origin(input_field) != list:
return False
# Validate that all outputs match the input type

View File

@@ -14,8 +14,6 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_11 import build_migration_11
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_12 import build_migration_12
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_13 import build_migration_13
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -47,8 +45,6 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_9())
migrator.register_migration(build_migration_10())
migrator.register_migration(build_migration_11(app_config=config, logger=logger))
migrator.register_migration(build_migration_12(app_config=config))
migrator.register_migration(build_migration_13())
migrator.run_migrations()
return db

View File

@@ -1,35 +0,0 @@
import shutil
import sqlite3
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration12Callback:
def __init__(self, app_config: InvokeAIAppConfig) -> None:
self._app_config = app_config
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._remove_model_convert_cache_dir()
def _remove_model_convert_cache_dir(self) -> None:
"""
Removes unused model convert cache directory
"""
convert_cache = self._app_config.convert_cache_path
shutil.rmtree(convert_cache, ignore_errors=True)
def build_migration_12(app_config: InvokeAIAppConfig) -> Migration:
"""
Build the migration from database version 11 to 12.
This migration removes the now-unused model convert cache directory.
"""
migration_12 = Migration(
from_version=11,
to_version=12,
callback=Migration12Callback(app_config),
)
return migration_12

View File

@@ -1,31 +0,0 @@
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration13Callback:
def __call__(self, cursor: sqlite3.Cursor) -> None:
self._add_archived_col(cursor)
def _add_archived_col(self, cursor: sqlite3.Cursor) -> None:
"""
- Adds an `archived` column to the boards table.
"""
cursor.execute("ALTER TABLE boards ADD COLUMN archived BOOLEAN DEFAULT FALSE;")
def build_migration_13() -> Migration:
"""
Build the migration from database version 12 to 13.
This migration does the following:
- Adds an `archived` column to the boards table.
"""
migration_13 = Migration(
from_version=12,
to_version=13,
callback=Migration13Callback(),
)
return migration_13
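Both migrations above follow the same callback-object pattern: a callable class holding its dependencies, wrapped in a Migration record that names the version step it performs. A simplified, runnable sketch (the Migration dataclass here stands in for the real sqlite_migrator_common.Migration):
import sqlite3
from dataclasses import dataclass
from typing import Callable

@dataclass
class Migration:
    from_version: int
    to_version: int
    callback: Callable[[sqlite3.Cursor], None]

class AddArchivedColumnCallback:
    def __call__(self, cursor: sqlite3.Cursor) -> None:
        cursor.execute("ALTER TABLE boards ADD COLUMN archived BOOLEAN DEFAULT FALSE;")

migration_13 = Migration(from_version=12, to_version=13, callback=AddArchivedColumnCallback())

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE boards (board_id TEXT);")
migration_13.callback(conn.cursor())
print([row[1] for row in conn.execute("PRAGMA table_info(boards);")])  # ['board_id', 'archived']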

View File

@@ -11,7 +11,6 @@ from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
from invokeai.backend.model_manager.load.model_size_utils import calc_module_size
from ..raw_model import RawModel
from .resampler import Resampler
@@ -138,7 +137,10 @@ class IPAdapter(RawModel):
self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
def calc_size(self):
return calc_module_size(self._image_proj_model) + calc_module_size(self.attn_weights)
# workaround for circular import
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
def _init_image_proj_model(
self, state_dict: dict[str, torch.Tensor]

View File

@@ -10,7 +10,6 @@ from safetensors.torch import load_file
from typing_extensions import Self
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.util.devices import TorchDevice
from .raw_model import RawModel
@@ -522,7 +521,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
layer.to(device=device, dtype=dtype, non_blocking=True)
model.layers[layer_key] = layer
return model

View File

@@ -12,9 +12,7 @@ def validate_hash(hash: str):
map = json.loads(b64decode(enc_hash))
if alg in map:
if hash_ == map[alg]:
raise Exception(
"This model can not be loaded. If you're looking for help, consider visiting https://www.redirectionprogram.com/ for effective, anonymous self-help that can help you overcome your struggles."
)
raise Exception("Unrecoverable Model Error")
hashes: list[str] = [

View File

@@ -13,6 +13,7 @@ from .config import (
SchedulerPredictionType,
SubModelType,
)
from .load import LoadedModel
from .probe import ModelProbe
from .search import ModelSearch
@@ -22,6 +23,7 @@ __all__ = [
"BaseModelType",
"ModelRepoVariant",
"InvalidModelConfigException",
"LoadedModel",
"ModelConfigFactory",
"ModelFormat",
"ModelProbe",

View File

@@ -24,21 +24,20 @@ import time
from enum import Enum
from typing import Literal, Optional, Type, TypeAlias, Union
import diffusers
import torch
from diffusers.models.modeling_utils import ModelMixin
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from invokeai.app.invocations.constants import SCHEDULER_NAME_VALUES
from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline]
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
class InvalidModelConfigException(Exception):

View File

@@ -0,0 +1,83 @@
# Adapted for use in InvokeAI by Lincoln Stein, July 2023
#
"""Conversion script for the Stable Diffusion checkpoints."""
from pathlib import Path
from typing import Optional
import torch
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
convert_ldm_vae_checkpoint,
create_vae_diffusers_config,
download_controlnet_from_original_ckpt,
download_from_original_stable_diffusion_ckpt,
)
from omegaconf import DictConfig
from . import AnyModel
def convert_ldm_vae_to_diffusers(
checkpoint: torch.Tensor | dict[str, torch.Tensor],
vae_config: DictConfig,
image_size: int,
dump_path: Optional[Path] = None,
precision: torch.dtype = torch.float16,
) -> AutoencoderKL:
"""Convert a checkpoint-style VAE into a Diffusers VAE"""
vae_config = create_vae_diffusers_config(vae_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
vae.to(precision)
if dump_path:
vae.save_pretrained(dump_path, safe_serialization=True)
return vae
def convert_ckpt_to_diffusers(
checkpoint_path: str | Path,
dump_path: Optional[str | Path] = None,
precision: torch.dtype = torch.float16,
use_safetensors: bool = True,
**kwargs,
) -> AnyModel:
"""
Takes all the arguments of download_from_original_stable_diffusion_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
pipe = pipe.to(precision)
# TO DO: save correct repo variant
if dump_path:
pipe.save_pretrained(
dump_path,
safe_serialization=use_safetensors,
)
return pipe
def convert_controlnet_to_diffusers(
checkpoint_path: Path,
dump_path: Optional[Path] = None,
precision: torch.dtype = torch.float16,
**kwargs,
) -> AnyModel:
"""
Takes all the arguments of download_controlnet_from_original_ckpt(),
and in addition a path-like object indicating the location of the desired diffusers
model to be written.
"""
pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
pipe = pipe.to(precision)
# TO DO: save correct repo variant
if dump_path:
pipe.save_pretrained(dump_path, safe_serialization=True)
return pipe

View File

@@ -1 +1,29 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Init file for the model loader.
"""
from importlib import import_module
from pathlib import Path
from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
from .load_default import ModelLoader
from .model_cache.model_cache_default import ModelCache
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
# This registers the subclasses that implement loaders of specific model types
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
for module in loaders:
import_module(f"{__package__}.model_loaders.{module}")
__all__ = [
"LoadedModel",
"LoadedModelWithoutConfig",
"ModelCache",
"ModelConvertCache",
"ModelLoaderBase",
"ModelLoader",
"ModelLoaderRegistryBase",
"ModelLoaderRegistry",
]
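The glob-and-import loop in this __init__ relies on import side effects: each model_loaders module registers its loader class when imported. A hedged, generic sketch of that mechanism (package and subpackage names are placeholders):
from importlib import import_module
from pathlib import Path

def import_all_submodules(package_name: str, package_init_file: str, subpackage: str) -> None:
    # Import every .py module in <package>/<subpackage> except __init__.py, so that
    # any @Registry.register(...) decorators in those modules execute.
    module_names = [
        p.stem
        for p in Path(Path(package_init_file).parent, subpackage).glob("*.py")
        if p.stem != "__init__"
    ]
    for name in module_names:
        import_module(f"{package_name}.{subpackage}.{name}")

# Typical call from a package's __init__.py:
#   import_all_submodules(__package__, __file__, "model_loaders")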

View File

@@ -0,0 +1,4 @@
from .convert_cache_base import ModelConvertCacheBase
from .convert_cache_default import ModelConvertCache
__all__ = ["ModelConvertCacheBase", "ModelConvertCache"]

View File

@@ -0,0 +1,28 @@
"""
Disk-based converted model cache.
"""
from abc import ABC, abstractmethod
from pathlib import Path
class ModelConvertCacheBase(ABC):
@property
@abstractmethod
def max_size(self) -> float:
"""Return the maximum size of this cache directory."""
pass
@abstractmethod
def make_room(self, size: float) -> None:
"""
Make sufficient room in the cache directory for a model of max_size.
:param size: Size required (GB)
"""
pass
@abstractmethod
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
pass

View File

@@ -0,0 +1,83 @@
"""
Placeholder for convert cache implementation.
"""
import shutil
from pathlib import Path
from invokeai.backend.util import GIG, directory_size
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.backend.util.util import safe_filename
from .convert_cache_base import ModelConvertCacheBase
class ModelConvertCache(ModelConvertCacheBase):
def __init__(self, cache_path: Path, max_size: float = 10.0):
"""Initialize the convert cache with the base directory and a limit on its maximum size (in GBs)."""
if not cache_path.exists():
cache_path.mkdir(parents=True)
self._cache_path = cache_path
self._max_size = max_size
# adjust cache size at startup in case it has been changed
if self._cache_path.exists():
self.make_room(0.0)
@property
def max_size(self) -> float:
"""Return the maximum size of this cache directory (GB)."""
return self._max_size
@max_size.setter
def max_size(self, value: float) -> None:
"""Set the maximum size of this cache directory (GB)."""
self._max_size = value
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
key = safe_filename(self._cache_path, key)
return self._cache_path / key
def make_room(self, size: float) -> None:
"""
Make sufficient room in the cache directory for a model of max_size.
:param size: Size required (GB)
"""
size_needed = directory_size(self._cache_path) + size
max_size = int(self.max_size) * GIG
logger = InvokeAILogger.get_logger()
if size_needed <= max_size:
return
logger.debug(
f"Convert cache has gotten too large {(size_needed / GIG):4.2f} > {(max_size / GIG):4.2f}G.. Trimming."
)
# For this to work, we make the assumption that the directory contains
# a 'model_index.json', 'unet/config.json' file, or a 'config.json' file at top level.
# This should be true for any diffusers model.
def by_atime(path: Path) -> float:
for config in ["model_index.json", "unet/config.json", "config.json"]:
sentinel = path / config
if sentinel.exists():
return sentinel.stat().st_atime
# no sentinel file found! - pick the most recent file in the directory
try:
atimes = sorted([x.stat().st_atime for x in path.iterdir() if x.is_file()], reverse=True)
return atimes[0]
except IndexError:
return 0.0
# sort by last access time - least accessed files will be at the end
lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True)
logger.debug(f"cached models in descending atime order: {lru_models}")
while size_needed > max_size and len(lru_models) > 0:
next_victim = lru_models.pop()
victim_size = directory_size(next_victim)
logger.debug(f"Removing cached converted model {next_victim} to free {victim_size / GIG} GB")
shutil.rmtree(next_victim)
size_needed -= victim_size
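make_room() above trims the cache by last access time. A standalone, simplified sketch of that strategy, working in plain bytes and assuming each cache entry is a directory:
import shutil
from pathlib import Path

GIG = 2**30

def directory_size(path: Path) -> int:
    return sum(p.stat().st_size for p in path.rglob("*") if p.is_file())

def make_room(cache_dir: Path, size_needed: int, max_size: int) -> None:
    """Delete least-recently-accessed entries until size_needed fits under max_size (bytes)."""
    total = directory_size(cache_dir) + size_needed
    if total <= max_size:
        return
    # Most recently accessed first; victims are popped from the tail of the list.
    entries = sorted(
        (p for p in cache_dir.iterdir() if p.is_dir()),
        key=lambda p: p.stat().st_atime,
        reverse=True,
    )
    while total > max_size and entries:
        victim = entries.pop()
        freed = directory_size(victim)
        shutil.rmtree(victim, ignore_errors=True)
        total -= freed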

View File

@@ -1,8 +0,0 @@
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
def _build_model_loader_registry():
return ModelLoaderRegistry()
MODEL_LOADER_REGISTRY = _build_model_loader_registry()

View File

@@ -18,6 +18,7 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
SubModelType,
)
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
@@ -111,6 +112,7 @@ class ModelLoaderBase(ABC):
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
pass
@@ -136,6 +138,12 @@ class ModelLoaderBase(ABC):
"""Return size in bytes of the model, calculated before loading."""
pass
@property
@abstractmethod
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated with this loader."""
pass
@property
@abstractmethod
def ram_cache(self) -> ModelCacheBase[AnyModel]:

View File

@@ -12,10 +12,11 @@ from invokeai.backend.model_manager import (
InvalidModelConfigException,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_size_utils import calc_model_size_by_fs
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import TorchDevice
@@ -29,11 +30,13 @@ class ModelLoader(ModelLoaderBase):
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
self._app_config = app_config
self._logger = logger
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._torch_dtype = TorchDevice.choose_torch_dtype()
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
@@ -47,15 +50,23 @@ class ModelLoader(ModelLoaderBase):
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
if model_config.type is ModelType.Main and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model")
model_path = self._get_model_path(model_config)
if not model_path.exists():
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
with skip_torch_weight_init():
locker = self._load_and_cache(model_config, submodel_type)
locker = self._convert_and_load(model_config, model_path, submodel_type)
return LoadedModel(config=model_config, _locker=locker)
@property
def convert_cache(self) -> ModelConvertCacheBase:
"""Return the convert cache associated with this loader."""
return self._convert_cache
@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the ram cache associated with this loader."""
@@ -65,14 +76,20 @@ class ModelLoader(ModelLoaderBase):
model_base = self._app_config.models_path
return (model_base / config.path).resolve()
def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> ModelLockerBase:
def _convert_and_load(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> ModelLockerBase:
try:
return self._ram_cache.get(config.key, submodel_type)
except IndexError:
pass
config.path = str(self._get_model_path(config))
loaded_model = self._load_model(config, submodel_type)
cache_path: Path = self._convert_cache.cache_path(str(model_path))
if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
else:
config.path = str(cache_path) if cache_path.exists() else str(self._get_model_path(config))
loaded_model = self._load_model(config, submodel_type)
self._ram_cache.put(
config.key,
@@ -96,6 +113,28 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
)
def _do_convert(
self, config: AnyModelConfig, model_path: Path, cache_path: Path, submodel_type: Optional[SubModelType] = None
) -> AnyModel:
self.convert_cache.make_room(calc_model_size_by_fs(model_path))
pipeline = self._convert_model(config, model_path, cache_path if self.convert_cache.max_size > 0 else None)
if submodel_type:
# Proactively load the various submodels into the RAM cache so that we don't have to re-convert
# the entire pipeline every time a new submodel is needed.
for subtype in SubModelType:
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
return getattr(pipeline, submodel_type.value) if submodel_type else pipeline
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False
# This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
raise NotImplementedError
# This needs to be implemented in the subclass
def _load_model(
self,
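Reduced to its control flow, the _convert_and_load() path added above is: RAM-cache hit, else convert when the cached copy is stale, else load whichever copy exists. A standalone sketch with the collaborators passed in as plain callables (names are illustrative):
from pathlib import Path
from typing import Any, Callable

def convert_and_load(
    ram_cache: dict[str, Any],
    key: str,
    model_path: Path,
    cache_path: Path,
    needs_conversion: Callable[[Path, Path], bool],
    do_convert: Callable[[Path, Path], Any],
    load_model: Callable[[Path], Any],
) -> Any:
    if key in ram_cache:                                  # 1. already resident in RAM
        return ram_cache[key]
    if needs_conversion(model_path, cache_path):
        model = do_convert(model_path, cache_path)        # 2. (re)convert, then use the result
    else:
        source = cache_path if cache_path.exists() else model_path
        model = load_model(source)                        # 3. load an existing copy
    ram_cache[key] = model
    return model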

View File

@@ -285,11 +285,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(
target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
)
new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
cache_entry.model.to(target_device, non_blocking=True)
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
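The cache code above moves a model by copying its CPU-resident state dict to the target device and re-binding it with assign=True. A minimal sketch of that pattern (assign=True needs PyTorch 2.1 or newer; non_blocking is omitted here for simplicity):
import torch

def move_via_state_dict(model: torch.nn.Module, target_device: torch.device) -> None:
    # Copy every tensor to the target device, then make the module's parameters
    # point at the new copies instead of copying weights in place.
    new_dict = {k: v.to(target_device, copy=True) for k, v in model.state_dict().items()}
    model.load_state_dict(new_dict, assign=True)
    model.to(target_device)

layer = torch.nn.Linear(2, 2)
move_via_state_dict(layer, torch.device("cpu"))  # use torch.device("cuda") on a GPU machine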

View File

@@ -1,34 +1,48 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
from typing import Optional, Tuple, Type
"""
This module implements a system in which model loaders register the
type, base and format of models that they know how to load.
from invokeai.backend.model_manager.config import BaseModelType, ModelConfigBase, ModelFormat, ModelType
from invokeai.backend.model_manager.load.load_base import AnyModelConfig, ModelLoaderBase, SubModelType
Use like this:
cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model = cls(
app_config=app_config,
logger=logger,
ram_cache=ram_cache,
convert_cache=convert_cache
).load_model(model_config, submodel_type)
"""
from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
from ..config import (
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
SubModelType,
)
from . import ModelLoaderBase
class ModelLoaderRegistry:
"""A registry that tracks which model loader class to use for a given model type/format/base combination."""
def __init__(self):
self._registry: dict[str, Type[ModelLoaderBase]] = {}
class ModelLoaderRegistryBase(ABC):
"""This class allows model loaders to register their type, base and format."""
@classmethod
@abstractmethod
def register(
self,
loader_class: Type[ModelLoaderBase],
type: ModelType,
format: ModelFormat,
base: BaseModelType = BaseModelType.Any,
):
"""Register a model loader class."""
key = self._to_registry_key(base, type, format)
if key in self._registry:
raise RuntimeError(
f"{loader_class.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type "
f"of model has already been registered by {self._registry[key].__name__}"
)
self._registry[key] = loader_class
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
"""Define a decorator which registers the subclass of loader."""
@classmethod
@abstractmethod
def get_implementation(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType]
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""
Get subclass of ModelLoaderBase registered to handle base and type.
@@ -42,13 +56,46 @@ class ModelLoaderRegistry:
in, in the event that a submodel type is provided.
"""
key1 = self._to_registry_key(config.base, config.type, config.format) # for a specific base type
key2 = self._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
implementation = self._registry.get(key1, None) or self._registry.get(key2, None)
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
class ModelLoaderRegistry(ModelLoaderRegistryBase):
"""
This class allows model loaders to register their type, base and format.
"""
_registry: Dict[str, Type[ModelLoaderBase]] = {}
@classmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
"""Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
)
cls._registry[key] = subclass
return subclass
return decorator
@classmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation:
raise NotImplementedError(
f"No subclass of ModelLoaderBase is registered for base={config.base}, type={config.type}, "
f"format={config.format}"
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
)
return implementation, config, submodel_type
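A self-contained sketch of the class-level decorator registry introduced above: loaders register under a (base, type, format) key and lookup falls back to a wildcard base. String keys stand in for the real enums:
from typing import Callable, Dict, Type, TypeVar

TLoader = TypeVar("TLoader")

class Registry:
    _registry: Dict[str, type] = {}

    @classmethod
    def register(cls, type_: str, format_: str, base: str = "any") -> Callable[[Type[TLoader]], Type[TLoader]]:
        def decorator(subclass: Type[TLoader]) -> Type[TLoader]:
            key = f"{base}/{type_}/{format_}"
            if key in cls._registry:
                raise Exception(f"{key} already registered by {cls._registry[key].__name__}")
            cls._registry[key] = subclass
            return subclass
        return decorator

    @classmethod
    def get_implementation(cls, base: str, type_: str, format_: str) -> type:
        key1 = f"{base}/{type_}/{format_}"  # exact base first
        key2 = f"any/{type_}/{format_}"     # then the wildcard base
        implementation = cls._registry.get(key1) or cls._registry.get(key2)
        if implementation is None:
            raise NotImplementedError(f"no loader registered for {key1}")
        return implementation

@Registry.register(type_="vae", format_="checkpoint")
class SketchVAELoader:
    pass

print(Registry.get_implementation("sd-1", "vae", "checkpoint").__name__)  # SketchVAELoader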

View File

@@ -1,10 +1,9 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for ControlNet model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
from diffusers import ControlNetModel
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@@ -12,7 +11,8 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import ControlNetCheckpointConfig, SubModelType
from invokeai.backend.model_manager.config import CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@@ -23,15 +23,36 @@ from .generic_diffusers import GenericDiffusersLoader
class ControlNetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
return ControlNetModel.from_single_file(
config.path,
torch_dtype=self._torch_dtype,
)
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return super()._load_model(config, submodel_type)
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
assert isinstance(config, CheckpointConfigBase)
image_size = (
512
if config.base == BaseModelType.StableDiffusion1
else 768
if config.base == BaseModelType.StableDiffusion2
else 1024
)
self._logger.info(f"Converting {model_path} to diffusers format")
with open(self._app_config.legacy_conf_path / config.config_path, "r") as config_stream:
result = convert_controlnet_to_diffusers(
model_path,
output_path,
original_config_file=config_stream,
image_size=image_size,
precision=self._torch_dtype,
from_safetensors=model_path.suffix == ".safetensors",
)
return result
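The _needs_conversion() override above reuses a previous conversion only when its sentinel config file is newer than both the recorded conversion time and the source checkpoint. A standalone sketch of that staleness test (the sentinel filename is a parameter here; the loaders use config.json or model_index.json):
from pathlib import Path

def needs_conversion(source: Path, dest_dir: Path, sentinel: str, converted_at: float = 0.0) -> bool:
    marker = dest_dir / sentinel
    if not marker.exists():
        return True  # nothing cached yet
    marker_mtime = marker.stat().st_mtime
    # Stale if the cached copy predates the recorded conversion or the source file.
    return marker_mtime < converted_at or marker_mtime < source.stat().st_mtime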

View File

@@ -18,8 +18,8 @@ from invokeai.backend.model_manager import (
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from .. import ModelLoader, ModelLoaderRegistry
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)

View File

@@ -8,8 +8,7 @@ import torch
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
from invokeai.backend.raw_model import RawModel

View File

@@ -15,6 +15,7 @@ from invokeai.backend.model_manager import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from .. import ModelLoader, ModelLoaderRegistry
@@ -31,9 +32,10 @@ class LoRALoader(ModelLoader):
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
super().__init__(app_config, logger, ram_cache)
super().__init__(app_config, logger, ram_cache, convert_cache)
self._model_base: Optional[BaseModelType] = None
def _load_model(

View File

@@ -4,28 +4,22 @@
from pathlib import Path
from typing import Optional
from diffusers import (
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
ModelVariantType,
SchedulerPredictionType,
SubModelType,
)
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
DiffusersConfigBase,
MainCheckpointConfig,
ModelVariantType,
)
from invokeai.backend.util.silence_warnings import SilenceWarnings
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@@ -54,12 +48,8 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, CheckpointConfigBase):
return self._load_from_singlefile(config, submodel_type)
if submodel_type is None:
if not submodel_type is not None:
raise Exception("A submodel type must be provided when loading main pipelines.")
model_path = Path(config.path)
load_class = self.get_hf_load_class(model_path, submodel_type)
repo_variant = config.repo_variant if isinstance(config, DiffusersConfigBase) else None
@@ -81,58 +71,46 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
return result
def _load_from_singlefile(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
load_classes = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: StableDiffusionPipeline,
ModelVariantType.Inpaint: StableDiffusionInpaintPipeline,
},
BaseModelType.StableDiffusion2: {
ModelVariantType.Normal: StableDiffusionPipeline,
ModelVariantType.Inpaint: StableDiffusionInpaintPipeline,
},
BaseModelType.StableDiffusionXL: {
ModelVariantType.Normal: StableDiffusionXLPipeline,
ModelVariantType.Inpaint: StableDiffusionXLInpaintPipeline,
},
}
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "model_index.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
assert isinstance(config, MainCheckpointConfig)
try:
load_class = load_classes[config.base][config.variant]
except KeyError as e:
raise Exception(f"No diffusers pipeline known for base={config.base}, variant={config.variant}") from e
base = config.base
prediction_type = config.prediction_type.value
upcast_attention = config.upcast_attention
image_size = (
1024
if base == BaseModelType.StableDiffusionXL
else 768
if config.prediction_type == SchedulerPredictionType.VPrediction and base == BaseModelType.StableDiffusion2
else 512
)
# Without SilenceWarnings we get log messages like this:
# site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
# warnings.warn(
# Some weights of the model checkpoint were not used when initializing CLIPTextModel:
# ['text_model.embeddings.position_ids']
# Some weights of the model checkpoint were not used when initializing CLIPTextModelWithProjection:
# ['text_model.embeddings.position_ids']
self._logger.info(f"Converting {model_path} to diffusers format")
with SilenceWarnings():
pipeline = load_class.from_single_file(
config.path,
torch_dtype=self._torch_dtype,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
load_safety_checker=False,
)
if not submodel_type:
return pipeline
# Proactively load the various submodels into the RAM cache so that we don't have to re-load
# the entire pipeline every time a new submodel is needed.
for subtype in SubModelType:
if subtype == submodel_type:
continue
if submodel := getattr(pipeline, subtype.value, None):
self._ram_cache.put(config.key, submodel_type=subtype, model=submodel)
return getattr(pipeline, submodel_type.value)
loaded_model = convert_ckpt_to_diffusers(
model_path,
output_path,
model_type=self.model_base_to_model_type[base],
original_config_file=self._app_config.legacy_conf_path / config.config_path,
extract_ema=True,
from_safetensors=model_path.suffix == ".safetensors",
precision=self._torch_dtype,
prediction_type=prediction_type,
image_size=image_size,
upcast_attention=upcast_attention,
load_safety_checker=False,
num_in_channels=VARIANT_TO_IN_CHANNEL_MAP[config.variant],
)
return loaded_model

View File

@@ -1,9 +1,12 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for VAE model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
from diffusers import AutoencoderKL
import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_file as safetensors_load_file
from invokeai.backend.model_manager import (
AnyModelConfig,
@@ -11,26 +14,56 @@ from invokeai.backend.model_manager import (
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.config import AnyModel, SubModelType, VAECheckpointConfig
from invokeai.backend.model_manager.config import AnyModel, CheckpointConfigBase
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from .. import ModelLoaderRegistry
from .generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
class VAELoader(GenericDiffusersLoader):
"""Class to load VAE models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, VAECheckpointConfig):
return AutoencoderKL.from_single_file(
config.path,
torch_dtype=self._torch_dtype,
)
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if not isinstance(config, CheckpointConfigBase):
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.converted_at or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return super()._load_model(config, submodel_type)
return True
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Optional[Path] = None) -> AnyModel:
# TODO(MM2): check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"VAE conversion not supported for model type: {config.base}")
else:
assert isinstance(config, CheckpointConfigBase)
config_file = self._app_config.legacy_conf_path / config.config_path
if model_path.suffix == ".safetensors":
checkpoint = safetensors_load_file(model_path, device="cpu")
else:
checkpoint = torch.load(model_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
ckpt_config = OmegaConf.load(config_file)
assert isinstance(ckpt_config, DictConfig)
self._logger.info(f"Converting {model_path} to diffusers format")
vae_model = convert_ldm_vae_to_diffusers(
checkpoint=checkpoint,
vae_config=ckpt_config,
image_size=512,
precision=self._torch_dtype,
dump_path=output_path,
)
return vae_model
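One detail of the conversion path above worth isolating: checkpoints may be safetensors files or torch pickles, and the weights are sometimes nested under a "state_dict" key. A small sketch of that unwrapping step, assuming the loaded object is a plain dict:
from pathlib import Path

import torch
from safetensors.torch import load_file as safetensors_load_file

def load_checkpoint_state_dict(model_path: Path) -> dict[str, torch.Tensor]:
    if model_path.suffix == ".safetensors":
        checkpoint = safetensors_load_file(model_path, device="cpu")
    else:
        checkpoint = torch.load(model_path, map_location="cpu")
    # Some checkpoints nest the weights under "state_dict"; some do not.
    return checkpoint.get("state_dict", checkpoint)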

View File

@@ -1,79 +0,0 @@
import json
from pathlib import Path
from typing import Optional
import torch
def calc_module_size(model: torch.nn.Module) -> int:
"""Estimate the size of a torch.nn.Module in bytes."""
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
mem: int = mem_params + mem_bufs # in bytes
return mem
def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
"""Estimate the size of a model on disk in bytes."""
if model_path.is_file():
return model_path.stat().st_size
if subfolder is not None:
model_path = model_path / subfolder
# this can happen when, for example, the safety checker is not downloaded.
if not model_path.exists():
return 0
all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()]
fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name}
bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
other_files = set(all_files) - fp16_files - bit8_files
if not variant: # ModelRepoVariant.DEFAULT evaluates to empty string for compatibility with HF
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.name.endswith(index_postfix):
continue
try:
with open(model_path / file, "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except Exception:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.suffix in file_format]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = (model_path / model_file).stat()
model_size += file_stats.st_size
return model_size
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
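A quick sanity check of calc_module_size() above: a torch.nn.Linear(4, 2) in default fp32 has 4*2 weight elements plus 2 bias elements, i.e. 10 parameters at 4 bytes each and no buffers.
import torch
linear = torch.nn.Linear(4, 2)
assert calc_module_size(linear) == 10 * 4  # 40 bytes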

View File

@@ -1,11 +1,14 @@
# Copyright (c) 2024 The InvokeAI Development Team
"""Various utility functions needed by the loader and caching system."""
import json
from pathlib import Path
from typing import Optional
import torch
from diffusers import DiffusionPipeline
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.load.model_size_utils import calc_module_size
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
@@ -14,7 +17,7 @@ def calc_model_size_by_data(model: AnyModel) -> int:
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return calc_module_size(model)
return _calc_model_by_data(model)
elif isinstance(model, IAIOnnxRuntimeModel):
return _calc_onnx_model_by_data(model)
else:
@@ -27,11 +30,84 @@ def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += calc_module_size(submodel)
res += _calc_model_by_data(submodel)
return res
def _calc_model_by_data(model: torch.nn.Module) -> int:
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
mem: int = mem_params + mem_bufs # in bytes
return mem
def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:
tensor_size = model.tensors.size() * 2 # The session doubles this
mem = tensor_size # in bytes
return mem
def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
"""Estimate the size of a model on disk in bytes."""
if model_path.is_file():
return model_path.stat().st_size
if subfolder is not None:
model_path = model_path / subfolder
# this can happen when, for example, the safety checker is not downloaded.
if not model_path.exists():
return 0
all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()]
fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name}
bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
other_files = set(all_files) - fp16_files - bit8_files
if not variant: # ModelRepoVariant.DEFAULT evaluates to empty string for compatibility with HF
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.name.endswith(index_postfix):
continue
try:
with open(model_path / file, "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except Exception:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.suffix in file_format]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = (model_path / model_file).stat()
model_size += file_stats.st_size
return model_size
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu
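A minimal usage sketch of calc_model_size_by_fs() (the model directory, subfolder, and variant are hypothetical): size the fp16 unet of a local diffusers-format model from disk.
from pathlib import Path
unet_bytes = calc_model_size_by_fs(Path("models/sd-1.5"), subfolder="unet", variant="fp16")
print(f"unet (fp16): {unet_bytes / 2**30:.2f} GiB")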

View File

@@ -312,8 +312,6 @@ class ModelProbe(object):
config_file = (
"stable-diffusion/v1-inference.yaml"
if base_type is BaseModelType.StableDiffusion1
else "stable-diffusion/sd_xl_base.yaml"
if base_type is BaseModelType.StableDiffusionXL
else "stable-diffusion/v2-inference.yaml"
)
else:
@@ -453,16 +451,8 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
class VaeCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
# VAEs of all base types have the same structure, so we wimp out and
# guess using the name.
for regexp, basetype in [
(r"xl", BaseModelType.StableDiffusionXL),
(r"sd2", BaseModelType.StableDiffusion2),
(r"vae", BaseModelType.StableDiffusion1),
]:
if re.search(regexp, self.model_path.name, re.IGNORECASE):
return basetype
raise InvalidModelConfigException("Cannot determine base type")
# I can't find any standalone 2.X VAEs to test with!
return BaseModelType.StableDiffusion1
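A standalone sketch of the name-based guess shown above (filenames and return strings here are illustrative only, not the probe's actual API):
import re
def guess_vae_base(filename: str) -> str:
    for pattern, base in [(r"xl", "sdxl"), (r"sd2", "sd-2"), (r"vae", "sd-1")]:
        if re.search(pattern, filename, re.IGNORECASE):
            return base
    return "sd-1"  # fall back to SD 1.x, mirroring the simplified probe
assert guess_vae_base("sdxl_vae.safetensors") == "sdxl"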
class LoRACheckpointProbe(CheckpointProbeBase):

View File

@@ -294,8 +294,8 @@ STARTER_MODELS: list[StarterModel] = [
StarterModel(
name="canny-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlnet-canny-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 with canny conditioning, by Xinsir.",
source="diffusers/controlnet-canny-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 with canny conditioning.",
type=ModelType.ControlNet,
),
StarterModel(
@@ -326,20 +326,6 @@ STARTER_MODELS: list[StarterModel] = [
description="Controlnet weights trained on sdxl-1.0 with Zoe's preprocessor (32 bits).",
type=ModelType.ControlNet,
),
StarterModel(
name="openpose-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlnet-openpose-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 compatible with the DWPose processor by Xinsir.",
type=ModelType.ControlNet,
),
StarterModel(
name="scribble-sdxl",
base=BaseModelType.StableDiffusionXL,
source="xinsir/controlnet-scribble-sdxl-1.0",
description="Controlnet weights trained on sdxl-1.0 compatible with various lineart processors and black/white sketches by Xinsir.",
type=ModelType.ControlNet,
),
# endregion
# region T2I Adapter
StarterModel(

View File

@@ -16,7 +16,6 @@ from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.util.devices import TorchDevice
from .lora import LoRAModelRaw
from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
@@ -140,15 +139,12 @@ class ModelPatcher:
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device))
layer.to(device=device, non_blocking=True)
layer.to(dtype=torch.float32, non_blocking=True)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(
device=TorchDevice.CPU_DEVICE,
non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
)
layer.to(device=torch.device("cpu"), non_blocking=True)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape:
@@ -157,7 +153,7 @@ class ModelPatcher:
layer_weight = layer_weight.reshape(module.weight.shape)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
yield # wait for context manager exit
@@ -165,9 +161,7 @@ class ModelPatcher:
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(
weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
)
model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
@classmethod
@contextmanager

View File

@@ -255,8 +255,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# Validate assumptions about input tensor shapes.
batch_size, latent_channels, latent_height, latent_width = latents.shape
assert latent_channels == 4
assert list(masked_ref_image_latents.shape) == [1, 4, latent_height, latent_width]
assert list(inpainting_mask.shape) == [1, 1, latent_height, latent_width]
assert masked_ref_image_latents.shape == [1, 4, latent_height, latent_width]
assert inpainting_mask == [1, 1, latent_height, latent_width]
# Repeat original_image_latents and inpainting_mask to match the latents batch size.
original_image_latents = masked_ref_image_latents.expand(batch_size, -1, -1, -1)
@@ -299,8 +299,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
same noise used earlier in the pipeline. This should really be handled in a clearer way.
timesteps: The timestep schedule for the denoising process.
init_timestep: The first timestep in the schedule. This is used to determine the initial noise level, so
should be populated if you want noise applied *even* if timesteps is empty.
init_timestep: The first timestep in the schedule.
TODO(ryand): I'm pretty sure this should always be the same as timesteps[0:1]. Confirm that that is the
case, and remove this duplicate param.
callback: A callback function that is called to report progress during the denoising process.
control_data: ControlNet data.
ip_adapter_data: IP-Adapter data.
@@ -315,7 +316,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
SD UNet model.
is_gradient_mask: A flag indicating whether `mask` is a gradient mask or not.
"""
if init_timestep.shape[0] == 0:
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
# cases where denoising_start and denoising_end are set such that there are no timesteps.
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
return latents
orig_latents = latents.clone()

View File

@@ -13,13 +13,17 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import (
StableDiffusionGeneratorPipeline,
)
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
from invokeai.backend.tiles.utils import Tile
from invokeai.backend.tiles.utils import TBLR
# The maximum number of regions with compatible sizes that will be batched together.
# Larger batch sizes improve speed, but require more device memory.
MAX_REGION_BATCH_SIZE = 4
@dataclass
class MultiDiffusionRegionConditioning:
# Region coords in latent space.
region: Tile
region: TBLR
text_conditioning_data: TextConditioningData
control_data: list[ControlNetData]
@@ -27,8 +31,31 @@ class MultiDiffusionRegionConditioning:
class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
"""A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
def _split_into_region_batches(
self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]
) -> list[list[MultiDiffusionRegionConditioning]]:
# Group the regions by shape. Only regions with the same shape can be batched together.
conditioning_by_shape: dict[tuple[int, int], list[MultiDiffusionRegionConditioning]] = {}
for region_conditioning in multi_diffusion_conditioning:
shape_hw = (
region_conditioning.region.bottom - region_conditioning.region.top,
region_conditioning.region.right - region_conditioning.region.left,
)
# In Python, a tuple of hashable objects is hashable, so it can be used as a key in a dict.
if shape_hw not in conditioning_by_shape:
conditioning_by_shape[shape_hw] = []
conditioning_by_shape[shape_hw].append(region_conditioning)
# Split the regions into batches, respecting the MAX_REGION_BATCH_SIZE constraint.
region_conditioning_batches = []
for region_conditioning_batch in conditioning_by_shape.values():
for i in range(0, len(region_conditioning_batch), MAX_REGION_BATCH_SIZE):
region_conditioning_batches.append(region_conditioning_batch[i : i + MAX_REGION_BATCH_SIZE])
return region_conditioning_batches
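A standalone sketch of the grouping rule in _split_into_region_batches() (the region coordinates are made up): regions are keyed by (height, width), and each same-shape group is chunked into batches of at most MAX_REGION_BATCH_SIZE.
MAX_REGION_BATCH_SIZE = 4
# (top, bottom, left, right) in latent space
regions = [(0, 64, 0, 64), (0, 64, 48, 112), (0, 64, 96, 160), (0, 64, 112, 176), (0, 64, 128, 192), (0, 96, 0, 64)]
by_shape: dict[tuple[int, int], list[tuple[int, int, int, int]]] = {}
for top, bottom, left, right in regions:
    by_shape.setdefault((bottom - top, right - left), []).append((top, bottom, left, right))
batches = [group[i : i + MAX_REGION_BATCH_SIZE] for group in by_shape.values() for i in range(0, len(group), MAX_REGION_BATCH_SIZE)]
# Five 64x64 regions split into batches of 4 and 1; the single 96x64 region forms its own batch.
assert [len(b) for b in batches] == [4, 1, 1]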
def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
"""Validate that regional conditioning is not used."""
"""Check the input conditioning and confirm that regional prompting is not used."""
for region_conditioning in multi_diffusion_conditioning:
if (
region_conditioning.text_conditioning_data.cond_regions is not None
@@ -39,7 +66,6 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
def multi_diffusion_denoise(
self,
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
target_overlap: int,
latents: torch.Tensor,
scheduler_step_kwargs: dict[str, Any],
noise: Optional[torch.Tensor],
@@ -49,7 +75,9 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
) -> torch.Tensor:
self._check_regional_prompting(multi_diffusion_conditioning)
if init_timestep.shape[0] == 0:
# TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
# cases where denoising_start and denoising_end are set such that there are no timesteps.
if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
return latents
batch_size, _, latent_height, latent_width = latents.shape
@@ -66,16 +94,24 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
# cropping into regions.
self._adjust_memory_efficient_attention(latents)
# Populate a weighted mask that will be used to combine the results from each region after every step.
# For now, we assume that each region has the same weight (1.0).
region_weight_mask = torch.zeros(
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
)
for region_conditioning in multi_diffusion_conditioning:
region = region_conditioning.region
region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
# Group the region conditioning into batches for faster processing.
# region_conditioning_batches[b][r] is the r'th region in the b'th batch.
region_conditioning_batches = self._split_into_region_batches(multi_diffusion_conditioning)
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
# separate scheduler state for each region batch.
# TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler
# statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect
# as Multi-Diffusion blending is applied (e.g. the PNDMScheduler). This can result in a blurring effect when
# multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each
# scheduler to determine how its state needs to be updated for compatibility with Multi-Diffusion.
region_batch_schedulers: list[SchedulerMixin] = [
copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning
copy.deepcopy(self.scheduler) for _ in region_conditioning_batches
]
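A hedged sketch of the per-batch scheduler isolation described above (the scheduler class and step count are illustrative): each region batch steps its own deep copy, so one batch's step() cannot mutate state seen by the others at the same timestep.
import copy
from diffusers import DDIMScheduler
scheduler = DDIMScheduler()
scheduler.set_timesteps(30)
region_batch_schedulers = [copy.deepcopy(scheduler) for _ in range(3)]  # one copy per region batch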
callback(
@@ -92,68 +128,72 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
batched_t = t.expand(batch_size)
merged_latents = torch.zeros_like(latents)
merged_latents_weights = torch.zeros(
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
)
merged_pred_original: torch.Tensor | None = None
for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
for region_batch_idx, region_conditioning_batch in enumerate(region_conditioning_batches):
# Switch to the scheduler for the region batch.
self.scheduler = region_batch_schedulers[region_idx]
self.scheduler = region_batch_schedulers[region_batch_idx]
# Crop the inputs to the region.
region_latents = latents[
:,
:,
region_conditioning.region.coords.top : region_conditioning.region.coords.bottom,
region_conditioning.region.coords.left : region_conditioning.region.coords.right,
]
# TODO(ryand): This logic has not yet been tested with input latents with a batch_size > 1.
# Prepare the latents for the region batch.
batch_latents = torch.cat(
[
latents[
:,
:,
region_conditioning.region.top : region_conditioning.region.bottom,
region_conditioning.region.left : region_conditioning.region.right,
]
for region_conditioning in region_conditioning_batch
],
)
# TODO(ryand): Do we have to repeat the text_conditioning_data to match the batch size? Or does step()
# handle broadcasting properly?
# TODO(ryand): Resume here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# Run the denoising step on the region.
step_output = self.step(
t=batched_t,
latents=region_latents,
latents=batch_latents,
conditioning_data=region_conditioning.text_conditioning_data,
step_index=i,
total_step_count=len(timesteps),
total_step_count=total_step_count,
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=None,
mask=None,
masked_latents=None,
control_data=region_conditioning.control_data,
)
# Run a denoising step on the region.
# step_output = self._region_step(
# region_conditioning=region_conditioning,
# t=batched_t,
# latents=latents,
# step_index=i,
# total_step_count=len(timesteps),
# scheduler_step_kwargs=scheduler_step_kwargs,
# )
# Store the results from the region.
# If two tiles overlap by more than the target overlap amount, crop the left and top edges of the
# affected tiles to achieve the target overlap.
region = region_conditioning.region
top_adjustment = max(0, region.overlap.top - target_overlap)
left_adjustment = max(0, region.overlap.left - target_overlap)
region_height_slice = slice(region.coords.top + top_adjustment, region.coords.bottom)
region_width_slice = slice(region.coords.left + left_adjustment, region.coords.right)
merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[
:, :, top_adjustment:, left_adjustment:
]
# For now, we treat every region as having the same weight.
merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0
merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
if pred_orig_sample is not None:
# If one region has pred_original_sample, then we can assume that all regions will have it, because
# they all use the same scheduler.
if merged_pred_original is None:
merged_pred_original = torch.zeros_like(latents)
merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[
:, :, top_adjustment:, left_adjustment:
]
merged_pred_original[:, :, region.top : region.bottom, region.left : region.right] += (
pred_orig_sample
)
# Normalize the merged results.
latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)
# For debugging, uncomment this line to visualize the region seams:
# latents = torch.where(merged_latents_weights > 1, 0.0, latents)
latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
predicted_original = None
if merged_pred_original is not None:
predicted_original = torch.where(
merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original
region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
)
callback(
@@ -168,3 +208,35 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
)
return latents
@torch.inference_mode()
def _region_batch_step(
self,
region_conditioning: MultiDiffusionRegionConditioning,
t: torch.Tensor,
latents: torch.Tensor,
step_index: int,
total_step_count: int,
scheduler_step_kwargs: dict[str, Any],
):
# Crop the inputs to the region.
region_latents = latents[
:,
:,
region_conditioning.region.top : region_conditioning.region.bottom,
region_conditioning.region.left : region_conditioning.region.right,
]
# Run the denoising step on the region.
return self.step(
t=t,
latents=region_latents,
conditioning_data=region_conditioning.text_conditioning_data,
step_index=step_index,
total_step_count=total_step_count,
scheduler_step_kwargs=scheduler_step_kwargs,
mask_guidance=None,
mask=None,
masked_latents=None,
control_data=region_conditioning.control_data,
)
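A tiny numeric sketch of the region_weight_mask normalization used above: where two regions overlap, the accumulated latents are divided by the per-pixel weight, so overlapping predictions are averaged.
import torch
merged = torch.zeros(1, 1, 1, 4)
weight = torch.zeros(1, 1, 1, 4)
merged[..., 0:3] += torch.tensor([1.0, 1.0, 1.0])  # region A covers columns 0-2
weight[..., 0:3] += 1.0
merged[..., 2:4] += torch.tensor([3.0, 3.0])       # region B covers columns 2-3
weight[..., 2:4] += 1.0
out = torch.where(weight > 0, merged / weight, merged)
# out == [[[[1.0, 1.0, 2.0, 3.0]]]]  (column 2 averages the two overlapping regions)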

View File

@@ -0,0 +1,3 @@
from .schedulers import SCHEDULER_MAP # noqa: F401
__all__ = ["SCHEDULER_MAP"]

View File

@@ -1,5 +1,3 @@
from typing import Literal
from diffusers import (
DDIMScheduler,
DDPMScheduler,
@@ -45,9 +43,3 @@ SCHEDULER_MAP = {
"lcm": (LCMScheduler, {}),
"tcd": (TCDScheduler, {}),
}
# HACK(ryand): Passing a tuple of keys to Literal works at runtime, but not at type-check time. See the docs here for
# more info: https://typing.readthedocs.io/en/latest/spec/literal.html#parameters-at-runtime. For now, we are ignoring
# this error. In the future, we should fix this type handling.
SCHEDULER_NAME_VALUES = Literal[tuple(SCHEDULER_MAP.keys())] # type: ignore

View File

@@ -1,35 +0,0 @@
from contextlib import contextmanager
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny
@contextmanager
def patch_vae_tiling_params(
vae: AutoencoderKL | AutoencoderTiny,
tile_sample_min_size: int,
tile_latent_min_size: int,
tile_overlap_factor: float,
):
"""Patch the parameters that control the VAE tiling tile size and overlap.
These parameters are not explicitly exposed in the VAE's API, but they have a significant impact on the quality of
the outputs. As a general rule, bigger tiles produce better results, but this comes at the cost of higher memory
usage.
"""
# Record initial config.
orig_tile_sample_min_size = vae.tile_sample_min_size
orig_tile_latent_min_size = vae.tile_latent_min_size
orig_tile_overlap_factor = vae.tile_overlap_factor
try:
# Apply target config.
vae.tile_sample_min_size = tile_sample_min_size
vae.tile_latent_min_size = tile_latent_min_size
vae.tile_overlap_factor = tile_overlap_factor
yield
finally:
# Restore initial config.
vae.tile_sample_min_size = orig_tile_sample_min_size
vae.tile_latent_min_size = orig_tile_latent_min_size
vae.tile_overlap_factor = orig_tile_overlap_factor
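A hedged usage sketch of the patch_vae_tiling_params context manager shown above (the VAE instance and tile sizes are assumptions for illustration); the original tiling parameters are restored when the block exits.
from diffusers import AutoencoderKL
vae = AutoencoderKL()  # randomly initialized; a real pipeline would load pretrained weights
vae.enable_tiling()
with patch_vae_tiling_params(vae, tile_sample_min_size=1024, tile_latent_min_size=128, tile_overlap_factor=0.25):
    pass  # tiled encode/decode calls here use the larger tiles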

View File

@@ -42,10 +42,6 @@ PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NA
class TorchDevice:
"""Abstraction layer for torch devices."""
CPU_DEVICE = torch.device("cpu")
CUDA_DEVICE = torch.device("cuda")
MPS_DEVICE = torch.device("mps")
@classmethod
def choose_torch_device(cls) -> torch.device:
"""Return the torch.device to use for accelerated inference."""
@@ -112,15 +108,3 @@ class TorchDevice:
@classmethod
def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
return NAME_TO_PRECISION[precision_name]
@staticmethod
def get_non_blocking(to_device: torch.device) -> bool:
"""Return the non_blocking flag to be used when moving a tensor to a given device.
MPS may have unexpected errors with non-blocking operations - we should not use non-blocking when moving _to_ MPS.
When moving _from_ MPS, we can use non-blocking operations.
See:
- https://github.com/pytorch/pytorch/issues/107455
- https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/28
"""
return False if to_device.type == "mps" else True

View File

@@ -5,10 +5,9 @@ from typing import Optional, Union
import pytest
import torch
from invokeai.app.services.model_manager.model_manager_base import ModelManagerServiceBase
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.model_records import UnknownModelException
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager import BaseModelType, LoadedModel, ModelType, SubModelType
@pytest.fixture(scope="session")

View File

@@ -17,10 +17,7 @@
},
"boards": {
"addBoard": "Add Board",
"archiveBoard": "Archive Board",
"archived": "Archived",
"autoAddBoard": "Auto-Add Board",
"selectedForAutoAdd": "Selected for Auto-Add",
"bottomMessage": "Deleting this board and its images will reset any features currently using them.",
"cancel": "Cancel",
"changeBoard": "Change Board",
@@ -39,13 +36,8 @@
"searchBoard": "Search Boards...",
"selectBoard": "Select a Board",
"topMessage": "This board contains images used in the following features:",
"unarchiveBoard": "Unarchive Board",
"uncategorized": "Uncategorized",
"downloadBoard": "Download Board",
"imagesWithCount_one": "{{count}} image",
"imagesWithCount_other": "{{count}} images",
"assetsWithCount_one": "{{count}} asset",
"assetsWithCount_other": "{{count}} assets"
"downloadBoard": "Download Board"
},
"accordions": {
"generation": {
@@ -372,10 +364,6 @@
"image": "image",
"loading": "Loading",
"loadMore": "Load More",
"newestFirst": "Newest First",
"oldestFirst": "Oldest First",
"sortDirection": "Sort Direction",
"showStarredImagesFirst": "Show Starred Images First",
"noImageSelected": "No Image Selected",
"noImagesInGallery": "No Images to Display",
"setCurrentImage": "Set as Current Image",
@@ -393,10 +381,6 @@
"viewerImage": "Viewer Image",
"compareImage": "Compare Image",
"openInViewer": "Open in Viewer",
"searchImages": "Search by Metadata",
"selectAllOnPage": "Select All On Page",
"selectAllOnBoard": "Select All On Board",
"showArchivedBoards": "Show Archived Boards",
"selectForCompare": "Select for Compare",
"selectAnImageToCompare": "Select an Image to Compare",
"slider": "Slider",

View File

@@ -23,7 +23,6 @@ import { addEnqueueRequestedCanvasListener } from 'app/store/middleware/listener
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addEnqueueRequestedNodes } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryOffsetChanged';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addRequestedSingleImageDeletionListener } from 'app/store/middleware/listenerMiddleware/listeners/imageDeleted';
@@ -52,8 +51,6 @@ import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddle
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@@ -80,7 +77,6 @@ addImagesUnstarredListener(startAppListening);
// Gallery
addGalleryImageClickedListener(startAppListening);
addGalleryOffsetChangedListener(startAppListening);
// User Invoked
addEnqueueRequestedCanvasListener(startAppListening);
@@ -120,7 +116,6 @@ addControlNetAutoProcessListener(startAppListening);
addImageAddedToBoardFulfilledListener(startAppListening);
addImageRemovedFromBoardFulfilledListener(startAppListening);
addBoardIdSelectedListener(startAppListening);
addArchivedOrDeletedBoardListener(startAppListening);
// Node schemas
addGetOpenAPISchemaListener(startAppListening);

View File

@@ -1,48 +0,0 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
autoAddBoardIdChanged,
boardIdSelected,
galleryViewChanged,
shouldShowArchivedBoardsChanged,
} from 'features/gallery/store/gallerySlice';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: isAnyOf(
// Updating a board may change its archived status
boardsApi.endpoints.updateBoard.matchFulfilled,
// If the selected/auto-add board was deleted from a different session, we'll only know during the list request,
boardsApi.endpoints.listAllBoards.matchFulfilled,
// If a board is deleted, we'll need to reset the auto-add board
imagesApi.endpoints.deleteBoard.matchFulfilled,
imagesApi.endpoints.deleteBoardAndImages.matchFulfilled,
// When we change the visibility of archived boards, we may need to reset the auto-add board
shouldShowArchivedBoardsChanged
),
effect: async (action, { dispatch, getState }) => {
/**
* The auto-add board shouldn't be set to an archived board or deleted board. When we archive a board, delete
* a board, or change the archived board visibility flag, we may need to reset the auto-add board.
*/
const state = getState();
const queryArgs = selectListBoardsQueryArgs(state);
const queryResult = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
const autoAddBoardId = state.gallery.autoAddBoardId;
if (!queryResult.data) {
return;
}
if (!queryResult.data.find((board) => board.board_id === autoAddBoardId)) {
dispatch(autoAddBoardIdChanged('none'));
dispatch(boardIdSelected({ boardId: 'none' }));
dispatch(galleryViewChanged('images'));
}
},
});
};

View File

@@ -2,7 +2,8 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images';
import { getListImagesUrl } from 'services/api/util';
import type { ImageCache } from 'services/api/types';
import { getListImagesUrl, imagesSelectors } from 'services/api/util';
export const addFirstListImagesListener = (startAppListening: AppStartListening) => {
startAppListening({
@@ -17,10 +18,13 @@ export const addFirstListImagesListener = (startAppListening: AppStartListening)
cancelActiveListeners();
unsubscribe();
const data = action.payload;
// TODO: figure out how to type the predicate
const data = action.payload as ImageCache;
if (data.items.length > 0) {
dispatch(imageSelected(data.items[0] ?? null));
if (data.ids.length > 0) {
// Select the first image
const firstImage = imagesSelectors.selectAll(data)[0];
dispatch(imageSelected(firstImage ?? null));
}
},
});

View File

@@ -1,13 +1,9 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
boardIdSelected,
galleryViewChanged,
imageSelected,
selectionChanged,
} from 'features/gallery/store/gallerySlice';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesSelectors } from 'services/api/util';
export const addBoardIdSelectedListener = (startAppListening: AppStartListening) => {
startAppListening({
@@ -18,9 +14,14 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);
const board_id = boardIdSelected.match(action) ? action.payload.boardId : state.gallery.selectedBoardId;
dispatch(selectionChanged([]));
const galleryView = galleryViewChanged.match(action) ? action.payload : state.gallery.galleryView;
// when a board is selected, we need to wait until the board has loaded *some* images, then select the first one
const categories = galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES;
const queryArgs = { board_id: board_id ?? 'none', categories };
// wait until the board has some images - maybe it already has some from a previous fetch
// must use getState() to ensure we do not have stale state
@@ -34,12 +35,11 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
const { data: boardImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
if (boardImagesData && boardIdSelected.match(action) && action.payload.selectedImageName) {
const selectedImage = boardImagesData.items.find(
(item) => item.image_name === action.payload.selectedImageName
);
const selectedImage = imagesSelectors.selectById(boardImagesData, action.payload.selectedImageName);
dispatch(imageSelected(selectedImage || null));
} else if (boardImagesData) {
dispatch(imageSelected(boardImagesData.items[0] || null));
const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
dispatch(imageSelected(firstImage || null));
} else {
// board has no images - deselect
dispatch(imageSelected(null));

View File

@@ -4,6 +4,7 @@ import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelecto
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';
export const galleryImageClicked = createAction<{
imageDTO: ImageDTO;
@@ -31,14 +32,14 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);
const queryResult = imagesApi.endpoints.listImages.select(queryArgs)(state);
const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
if (!queryResult.data) {
if (!listImagesData) {
// Should never happen if we have clicked a gallery image
return;
}
const imageDTOs = queryResult.data.items;
const imageDTOs = imagesSelectors.selectAll(listImagesData);
const selection = state.gallery.selection;
if (altKey) {

View File

@@ -1,119 +0,0 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, offsetChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
export const addGalleryOffsetChangedListener = (startAppListening: AppStartListening) => {
/**
* When the user changes pages in the gallery, we need to wait until the next page of images is loaded, then maybe
* update the selection.
*
* There are three scenarios:
*
* 1. The page is changed by clicking the pagination buttons. No changes to selection are needed.
*
* 2. The page is changed by using the arrow keys (without alt).
* - When going backwards, select the last image.
* - When going forwards, select the first image.
*
* 3. The page is changed by using the arrow keys with alt. This means the user is changing the comparison image.
* - When going backwards, select the last image _as the comparison image_.
* - When going forwards, select the first image _as the comparison image_.
*/
startAppListening({
actionCreator: offsetChanged,
effect: async (action, { dispatch, getState, getOriginalState, take, cancelActiveListeners }) => {
// Cancel any active listeners to prevent the selection from changing without user input
cancelActiveListeners();
const { withHotkey } = action.payload;
if (!withHotkey) {
// User changed pages by clicking the pagination buttons - no changes to selection
return;
}
const originalState = getOriginalState();
const prevOffset = originalState.gallery.offset;
const offset = getState().gallery.offset;
if (offset === prevOffset) {
// The page didn't change - bail
return;
}
/**
* We need to wait until the next page of images is loaded before updating the selection, so we use the correct
* page of images.
*
* The simplest way to do it would be to use `take` to wait for the next fulfilled action, but RTK-Q doesn't
* dispatch an action on cache hits. This means the `take` will only return if the cache is empty. If the user
* changes to a cached page - a common situation - the `take` will never resolve.
*
* So we need to take a two-step approach. First, check if we have data in the cache for the page of images. If
* we have data cached, use it to update the selection. If we don't have data cached, wait for the next fulfilled
* action, which updates the cache, then use the cache to update the selection.
*/
// Check if we have data in the cache for the page of images
const queryArgs = selectListImagesQueryArgs(getState());
let { data } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
// No data yet - wait for the network request to complete
if (!data) {
const takeResult = await take(imagesApi.endpoints.listImages.matchFulfilled, 5000);
if (!takeResult) {
// The request didn't complete in time - bail
return;
}
data = takeResult[0].payload;
}
// We awaited a network request - state could have changed, get fresh state
const state = getState();
const { selection, imageToCompare } = state.gallery;
const imageDTOs = data?.items;
if (!imageDTOs) {
// The page didn't load - bail
return;
}
if (withHotkey === 'arrow') {
// User changed pages by using the arrow keys - selection changes to first or last image depending
if (offset < prevOffset) {
// We've gone backwards
const lastImage = imageDTOs[imageDTOs.length - 1];
if (!selection.some((selectedImage) => selectedImage.image_name === lastImage?.image_name)) {
dispatch(selectionChanged(lastImage ? [lastImage] : []));
}
} else {
// We've gone forwards
const firstImage = imageDTOs[0];
if (!selection.some((selectedImage) => selectedImage.image_name === firstImage?.image_name)) {
dispatch(selectionChanged(firstImage ? [firstImage] : []));
}
}
return;
}
if (withHotkey === 'alt+arrow') {
// User changed pages by using the arrow keys with alt - comparison image changes to first or last depending
if (offset < prevOffset) {
// We've gone backwards
const lastImage = imageDTOs[imageDTOs.length - 1];
if (lastImage && imageToCompare?.image_name !== lastImage.image_name) {
dispatch(imageToCompareChanged(lastImage));
}
} else {
// We've gone forwards
const firstImage = imageDTOs[0];
if (firstImage && imageToCompare?.image_name !== firstImage.image_name) {
dispatch(imageToCompareChanged(firstImage));
}
}
return;
}
},
});
};

View File

@@ -22,10 +22,11 @@ import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { isImageFieldInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { forEach } from 'lodash-es';
import { clamp, forEach } from 'lodash-es';
import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.nodes.present.nodes.forEach((node) => {
@@ -117,7 +118,32 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
}
dispatch(isModalOpenChanged(false));
const state = getState();
const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
const { image_name } = imageDTO;
const baseQueryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
const cachedImageDTOs = data ? imagesSelectors.selectAll(data) : [];
const deletedImageIndex = cachedImageDTOs.findIndex((i) => i.image_name === image_name);
const filteredImageDTOs = cachedImageDTOs.filter((i) => i.image_name !== image_name);
const newSelectedImageIndex = clamp(deletedImageIndex, 0, filteredImageDTOs.length - 1);
const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex];
if (newSelectedImageDTO) {
dispatch(imageSelected(newSelectedImageDTO));
} else {
dispatch(imageSelected(null));
}
}
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
if (imageUsage.isCanvasImage) {
@@ -142,20 +168,6 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
if (wasImageDeleted) {
dispatch(api.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
}
const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
const baseQueryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
if (data && data.items) {
const newlySelectedImage = data?.items.find((img) => img.image_name !== imageDTO?.image_name);
dispatch(imageSelected(newlySelectedImage || null));
} else {
dispatch(imageSelected(null));
}
}
},
});
@@ -176,8 +188,10 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
const queryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(queryArgs)(state);
if (data && data.items[0]) {
dispatch(imageSelected(data.items[0]));
const newSelectedImageDTO = data ? imagesSelectors.selectAll(data)[0] : undefined;
if (newSelectedImageDTO) {
dispatch(imageSelected(newSelectedImageDTO));
} else {
dispatch(imageSelected(null));
}

View File

@@ -15,12 +15,7 @@ import {
} from 'features/controlLayers/store/controlLayersSlice';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import {
imageSelected,
imageToCompareChanged,
isImageViewerOpenChanged,
selectionChanged,
} from 'features/gallery/store/gallerySlice';
import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
@@ -221,7 +216,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
board_id: boardId,
})
);
dispatch(selectionChanged([]));
return;
}
@@ -239,7 +233,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
imageDTO,
})
);
dispatch(selectionChanged([]));
return;
}
@@ -255,7 +248,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
board_id: boardId,
})
);
dispatch(selectionChanged([]));
return;
}
@@ -269,7 +261,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
imageDTOs,
})
);
dispatch(selectionChanged([]));
return;
}
},

View File

@@ -11,7 +11,6 @@ import {
ipaLayerImageChanged,
rgLayerIPAdapterImageChanged,
} from 'features/controlLayers/store/controlLayersSlice';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { toast } from 'features/toast/toast';
@@ -63,8 +62,7 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
);
// Attempt to get the board's name for the toast
const queryArgs = selectListBoardsQueryArgs(state);
const { data } = boardsApi.endpoints.listAllBoards.select(queryArgs)(state);
const { data } = boardsApi.endpoints.listAllBoards.select()(state);
// Fall back to just the board id if we can't find the board for some reason
const board = data?.find((b) => b.board_id === autoAddBoardId);

View File

@@ -8,14 +8,14 @@ import {
galleryViewChanged,
imageSelected,
isImageViewerOpenChanged,
offsetChanged,
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { imagesAdapter } from 'services/api/util';
import { socketInvocationComplete } from 'services/events/actions';
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
@@ -52,6 +52,24 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
}
if (!imageDTO.is_intermediate) {
/**
* Cache updates for when an image result is received
* - add it to the no_board/images
*/
dispatch(
imagesApi.util.updateQueryData(
'listImages',
{
board_id: imageDTO.board_id ?? 'none',
categories: IMAGE_CATEGORIES,
},
(draft) => {
imagesAdapter.addOne(draft, imageDTO);
}
)
);
// update the total images for the board
dispatch(
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
@@ -60,18 +78,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
})
);
dispatch(
imagesApi.util.invalidateTags([
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
{
type: 'ImageList',
id: getListImagesUrl({
board_id: imageDTO.board_id ?? 'none',
categories: getCategories(imageDTO),
}),
},
])
);
dispatch(imagesApi.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
const { shouldAutoSwitch } = gallery;
@@ -91,8 +98,6 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
);
}
dispatch(offsetChanged({ offset: 0 }));
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
dispatch(
boardIdSelected({

View File

@@ -1,37 +1,47 @@
import type { IconButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import type { MouseEvent } from 'react';
import { memo } from 'react';
import type { MouseEvent, ReactElement } from 'react';
import { memo, useMemo } from 'react';
const sx: SystemStyleObject = {
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
},
};
type Props = Omit<IconButtonProps, 'aria-label' | 'onClick' | 'tooltip'> & {
type Props = {
onClick: (event: MouseEvent<HTMLButtonElement>) => void;
tooltip: string;
icon?: ReactElement;
styleOverrides?: SystemStyleObject;
};
const IAIDndImageIcon = (props: Props) => {
const { onClick, tooltip, icon, ...rest } = props;
const { onClick, tooltip, icon, styleOverrides } = props;
const sx = useMemo(
() => ({
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
},
...styleOverrides,
}),
[styleOverrides]
);
return (
<IconButton
onClick={onClick}
aria-label={tooltip}
tooltip={tooltip}
icon={icon}
size="sm"
variant="link"
sx={sx}
data-testid={tooltip}
{...rest}
/>
);
};

View File

@@ -0,0 +1,16 @@
/**
* Comparator function for sorting dates in ascending order
*/
export const dateComparator = (a: string, b: string) => {
const dateA = new Date(a);
const dateB = new Date(b);
// sort in ascending order
if (dateA > dateB) {
return 1;
}
if (dateA < dateB) {
return -1;
}
return 0;
};

View File

@@ -7,7 +7,6 @@ import {
isModalOpenChanged,
selectChangeBoardModalSlice,
} from 'features/changeBoardModal/store/slice';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
@@ -21,8 +20,7 @@ const selectImagesToChange = createMemoizedSelector(
const ChangeBoardModal = () => {
const dispatch = useAppDispatch();
const [selectedBoard, setSelectedBoard] = useState<string | null>();
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { data: boards, isFetching } = useListAllBoardsQuery(queryArgs);
const { data: boards, isFetching } = useListAllBoardsQuery();
const isModalOpen = useAppSelector((s) => s.changeBoardModal.isModalOpen);
const imagesToChange = useAppSelector(selectImagesToChange);
const [addImagesToBoard] = useAddImagesToBoardMutation();

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Spinner } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
@@ -184,25 +185,25 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
/>
</Box>
{controlImage && (
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={<PiFloppyDiskBold size={16} />}
tooltip={t('controlnet.saveControlImage')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />}
tooltip={t('controlnet.setControlImageDimensions')}
/>
</Flex>
)}
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
tooltip={t('controlnet.saveControlImage')}
styleOverrides={saveControlImageStyleOverrides}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
tooltip={t('controlnet.setControlImageDimensions')}
styleOverrides={setControlImageDimensionsStyleOverrides}
/>
</>
{pendingControlImages.includes(id) && (
<Flex
@@ -225,3 +226,6 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
};
export default memo(ControlAdapterImagePreview);
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Spinner, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -159,7 +160,7 @@ export const ControlAdapterImagePreview = memo(
onMouseEnter={handleMouseEnter}
onMouseLeave={handleMouseLeave}
position="relative"
w={36}
w="full"
h={36}
alignItems="center"
justifyContent="center"
@@ -192,27 +193,25 @@ export const ControlAdapterImagePreview = memo(
/>
</Box>
{controlImage && (
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={<PiFloppyDiskBold size={16} />}
tooltip={t('controlnet.saveControlImage')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />}
tooltip={
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
}
/>
</Flex>
)}
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSaveControlImage}
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
tooltip={t('controlnet.saveControlImage')}
styleOverrides={saveControlImageStyleOverrides}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
styleOverrides={setControlImageDimensionsStyleOverrides}
/>
</>
{controlAdapter.processorPendingBatchId !== null && (
<Flex
@@ -236,3 +235,6 @@ export const ControlAdapterImagePreview = memo(
);
ControlAdapterImagePreview.displayName = 'ControlAdapterImagePreview';
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -81,7 +82,7 @@ export const IPAdapterImagePreview = memo(
}, [handleResetControlImage, isConnected, isErrorControlImage]);
return (
<Flex position="relative" w={36} h={36} alignItems="center">
<Flex position="relative" w="full" h={36} alignItems="center" justifyContent="center">
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
@@ -89,25 +90,24 @@ export const IPAdapterImagePreview = memo(
postUploadAction={postUploadAction}
/>
{controlImage && (
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={<PiRulerBold size={16} />}
tooltip={
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
}
/>
</Flex>
)}
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
styleOverrides={setControlImageDimensionsStyleOverrides}
/>
</>
</Flex>
);
}
);
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 6 };

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -78,34 +79,31 @@ export const InitialImagePreview = memo(({ image, onChangeImage, droppableData,
}, [onReset, isConnected, isErrorControlImage]);
return (
<Flex w="full" alignItems="center" justifyContent="center">
<Flex position="relative" w={36} h={36} alignItems="center" justifyContent="center">
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={imageDTO}
postUploadAction={postUploadAction}
/>
<Flex position="relative" w="full" h={36} alignItems="center" justifyContent="center">
<IAIDndImage
draggableData={draggableData}
droppableData={droppableData}
imageDTO={imageDTO}
postUploadAction={postUploadAction}
/>
{imageDTO && (
<Flex position="absolute" flexDir="column" top={1} insetInlineEnd={1} gap={1}>
<IAIDndImageIcon
onClick={onReset}
icon={<PiArrowCounterClockwiseBold size={16} />}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={onUseSize}
icon={<PiRulerBold size={16} />}
tooltip={
shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')
}
/>
</Flex>
)}
</Flex>
<>
<IAIDndImageIcon
onClick={onReset}
icon={imageDTO ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
tooltip={t('controlnet.resetControlImage')}
/>
<IAIDndImageIcon
onClick={onUseSize}
icon={imageDTO ? <PiRulerBold size={16} /> : undefined}
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
styleOverrides={useSizeStyleOverrides}
/>
</>
</Flex>
);
});
InitialImagePreview.displayName = 'InitialImagePreview';
const useSizeStyleOverrides: SystemStyleObject = { mt: 6 };


@@ -11,28 +11,25 @@ const BoardAutoAddSelect = () => {
const { t } = useTranslation();
const autoAddBoardId = useAppSelector((s) => s.gallery.autoAddBoardId);
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const { options, hasBoards } = useListAllBoardsQuery(
{},
{
selectFromResult: ({ data }) => {
const options: ComboboxOption[] = [
{
label: t('controlnet.none'),
value: 'none',
},
].concat(
(data ?? []).map(({ board_id, board_name }) => ({
label: board_name,
value: board_id,
}))
);
return {
options,
hasBoards: options.length > 1,
};
},
}
);
const { options, hasBoards } = useListAllBoardsQuery(undefined, {
selectFromResult: ({ data }) => {
const options: ComboboxOption[] = [
{
label: t('controlnet.none'),
value: 'none',
},
].concat(
(data ?? []).map(({ board_id, board_name }) => ({
label: board_name,
value: board_id,
}))
);
return {
options,
hasBoards: options.length > 1,
};
},
});
const onChange = useCallback<ComboboxOnChange>(
(v) => {
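BoardAutoAddSelect now calls useListAllBoardsQuery(undefined, { selectFromResult: ... }) instead of passing an empty object, matching an endpoint that takes no argument. A hedged sketch of that calling convention with a made-up slice; the base URL, endpoint path, and Board shape below are illustrative, not the project's real definitions:

import { createApi, fetchBaseQuery } from '@reduxjs/toolkit/query/react';

type Board = { board_id: string; board_name: string };

// Illustrative slice only; the real endpoints live under services/api.
const api = createApi({
  baseQuery: fetchBaseQuery({ baseUrl: '/api/v1/' }),
  endpoints: (build) => ({
    // Result type first, query-argument type second: `void` means "no argument".
    listAllBoards: build.query<Board[], void>({ query: () => 'boards/' }),
  }),
});

export const { useListAllBoardsQuery } = api;

// A void-argument hook still accepts the options bag as its second parameter,
// so the first argument is written out as `undefined`:
//
//   const { options } = useListAllBoardsQuery(undefined, {
//     selectFromResult: ({ data }) => ({
//       options: (data ?? []).map(({ board_id, board_name }) => ({ label: board_name, value: board_id })),
//     }),
//   });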


@@ -3,12 +3,11 @@ import { ContextMenu, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-librar
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { autoAddBoardIdChanged, selectGallerySlice } from 'features/gallery/store/gallerySlice';
import type { BoardId } from 'features/gallery/store/types';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { toast } from 'features/toast/toast';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArchiveBold, PiArchiveFill, PiDownloadBold, PiPlusBold } from 'react-icons/pi';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import { PiDownloadBold, PiPlusBold } from 'react-icons/pi';
import { useBulkDownloadImagesMutation } from 'services/api/endpoints/images';
import { useBoardName } from 'services/api/hooks/useBoardName';
import type { BoardDTO } from 'services/api/types';
@@ -16,85 +15,52 @@ import type { BoardDTO } from 'services/api/types';
import GalleryBoardContextMenuItems from './GalleryBoardContextMenuItems';
type Props = {
board: BoardDTO;
board?: BoardDTO;
board_id: BoardId;
children: ContextMenuProps<HTMLDivElement>['children'];
setBoardToDelete: (board?: BoardDTO) => void;
setBoardToDelete?: (board?: BoardDTO) => void;
};
const BoardContextMenu = ({ board, setBoardToDelete, children }: Props) => {
const BoardContextMenu = ({ board, board_id, setBoardToDelete, children }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const selectIsSelectedForAutoAdd = useMemo(
() => createSelector(selectGallerySlice, (gallery) => board.board_id === gallery.autoAddBoardId),
[board.board_id]
() => createSelector(selectGallerySlice, (gallery) => board && board.board_id === gallery.autoAddBoardId),
[board]
);
const [updateBoard] = useUpdateBoardMutation();
const isSelectedForAutoAdd = useAppSelector(selectIsSelectedForAutoAdd);
const boardName = useBoardName(board.board_id);
const boardName = useBoardName(board_id);
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
const [bulkDownload] = useBulkDownloadImagesMutation();
const handleSetAutoAdd = useCallback(() => {
dispatch(autoAddBoardIdChanged(board.board_id));
}, [board.board_id, dispatch]);
dispatch(autoAddBoardIdChanged(board_id));
}, [board_id, dispatch]);
const handleBulkDownload = useCallback(() => {
bulkDownload({ image_names: [], board_id: board.board_id });
}, [board.board_id, bulkDownload]);
const handleArchive = useCallback(async () => {
try {
await updateBoard({
board_id: board.board_id,
changes: { archived: true },
}).unwrap();
} catch (error) {
toast({
status: 'error',
title: 'Unable to archive board',
});
}
}, [board.board_id, updateBoard]);
const handleUnarchive = useCallback(() => {
updateBoard({
board_id: board.board_id,
changes: { archived: false },
});
}, [board.board_id, updateBoard]);
bulkDownload({ image_names: [], board_id: board_id });
}, [board_id, bulkDownload]);
const renderMenuFunc = useCallback(
() => (
<MenuList visibility="visible">
<MenuGroup title={boardName}>
{!autoAssignBoardOnClick && (
<MenuItem icon={<PiPlusBold />} isDisabled={isSelectedForAutoAdd} onClick={handleSetAutoAdd}>
{isSelectedForAutoAdd ? t('boards.selectedForAutoAdd') : t('boards.menuItemAutoAdd')}
</MenuItem>
)}
<MenuItem
icon={<PiPlusBold />}
isDisabled={isSelectedForAutoAdd || autoAssignBoardOnClick}
onClick={handleSetAutoAdd}
>
{t('boards.menuItemAutoAdd')}
</MenuItem>
{isBulkDownloadEnabled && (
<MenuItem icon={<PiDownloadBold />} onClickCapture={handleBulkDownload}>
{t('boards.downloadBoard')}
</MenuItem>
)}
{board.archived && (
<MenuItem icon={<PiArchiveBold />} onClick={handleUnarchive}>
{t('boards.unarchiveBoard')}
</MenuItem>
)}
{!board.archived && (
<MenuItem icon={<PiArchiveFill />} onClick={handleArchive}>
{t('boards.archiveBoard')}
</MenuItem>
)}
<GalleryBoardContextMenuItems board={board} setBoardToDelete={setBoardToDelete} />
{board && <GalleryBoardContextMenuItems board={board} setBoardToDelete={setBoardToDelete} />}
</MenuGroup>
</MenuList>
),
@@ -108,8 +74,6 @@ const BoardContextMenu = ({ board, setBoardToDelete, children }: Props) => {
isSelectedForAutoAdd,
setBoardToDelete,
t,
handleArchive,
handleUnarchive,
]
);
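With board and setBoardToDelete optional and board_id required, the same BoardContextMenu now serves both real boards and the uncategorized tile (the dedicated NoBoardBoardContextMenu is deleted later in this diff). A hedged usage sketch of the two call sites; the wrapper components and their props here are illustrative, while the BoardContextMenu props mirror the GalleryBoard and NoBoardBoard changes below:

import { Flex } from '@invoke-ai/ui-library';
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
import type { BoardDTO } from 'services/api/types';

// A concrete board: pass `board` as well, so the board-specific menu items render.
const ConcreteBoardTile = ({ board, onDelete }: { board: BoardDTO; onDelete: (b?: BoardDTO) => void }) => (
  <BoardContextMenu board={board} board_id={board.board_id} setBoardToDelete={onDelete}>
    {(ref) => <Flex ref={ref}>{board.board_name}</Flex>}
  </BoardContextMenu>
);

// The uncategorized tile: only board_id="none" is needed; the board-specific
// items are skipped because `board` is undefined.
const UncategorizedTile = () => (
  <BoardContextMenu board_id="none">{(ref) => <Flex ref={ref}>Uncategorized</Flex>}</BoardContextMenu>
);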


@@ -1,22 +0,0 @@
import { useTranslation } from 'react-i18next';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
type Props = {
board_id: string;
isArchived: boolean;
};
export const BoardTotalsTooltip = ({ board_id, isArchived }: Props) => {
const { t } = useTranslation();
const { imagesTotal } = useGetBoardImagesTotalQuery(board_id, {
selectFromResult: ({ data }) => {
return { imagesTotal: data?.total ?? 0 };
},
});
const { assetsTotal } = useGetBoardAssetsTotalQuery(board_id, {
selectFromResult: ({ data }) => {
return { assetsTotal: data?.total ?? 0 };
},
});
return `${t('boards.imagesWithCount', { count: imagesTotal })}, ${t('boards.assetsWithCount', { count: assetsTotal })}${isArchived ? ` (${t('boards.archived')})` : ''}`;
};


@@ -2,7 +2,6 @@ import { Collapse, Flex, Grid, GridItem } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { CSSProperties } from 'react';
import { memo, useState } from 'react';
@@ -27,8 +26,7 @@ const BoardsList = (props: Props) => {
const { isOpen } = props;
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const boardSearchText = useAppSelector((s) => s.gallery.boardSearchText);
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { data: boards } = useListAllBoardsQuery(queryArgs);
const { data: boards } = useListAllBoardsQuery();
const filteredBoards = boardSearchText
? boards?.filter((board) => board.board_name.toLowerCase().includes(boardSearchText.toLowerCase()))
: boards;


@@ -8,12 +8,15 @@ import SelectionOverlay from 'common/components/SelectionOverlay';
import type { AddToBoardDropData } from 'features/dnd/types';
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
import { autoAddBoardIdChanged, boardIdSelected, selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArchiveBold, PiImagesSquare } from 'react-icons/pi';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import { PiImagesSquare } from 'react-icons/pi';
import {
useGetBoardAssetsTotalQuery,
useGetBoardImagesTotalQuery,
useUpdateBoardMutation,
} from 'services/api/endpoints/boards';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { BoardDTO } from 'services/api/types';
@@ -25,14 +28,6 @@ const editableInputStyles: SystemStyleObject = {
},
};
const ArchivedIcon = () => {
return (
<Box position="absolute" top={1} insetInlineEnd={2} p={0} minW={0}>
<Icon as={PiArchiveBold} fill="base.300" filter="drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))" />
</Box>
);
};
interface GalleryBoardProps {
board: BoardDTO;
isSelected: boolean;
@@ -41,7 +36,6 @@ interface GalleryBoardProps {
const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const selectIsSelectedForAutoAdd = useMemo(
() => createSelector(selectGallerySlice, (gallery) => board.board_id === gallery.autoAddBoardId),
@@ -57,6 +51,17 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
setIsHovered(false);
}, []);
const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id);
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id);
const tooltip = useMemo(() => {
if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
return undefined;
}
return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
assetsTotal.total
} asset${assetsTotal.total === 1 ? '' : 's'}`;
}, [assetsTotal, imagesTotal]);
const { currentData: coverImage } = useGetImageDTOQuery(board.cover_image_name ?? skipToken);
const { board_name, board_id } = board;
@@ -112,7 +117,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
const handleChange = useCallback((newBoardName: string) => {
setLocalBoardName(newBoardName);
}, []);
const { t } = useTranslation();
return (
<Box w="full" h="full" userSelect="none">
<Flex
@@ -125,12 +130,9 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
w="full"
h="full"
>
<BoardContextMenu board={board} setBoardToDelete={setBoardToDelete}>
<BoardContextMenu board={board} board_id={board_id} setBoardToDelete={setBoardToDelete}>
{(ref) => (
<Tooltip
label={<BoardTotalsTooltip board_id={board.board_id} isArchived={Boolean(board.archived)} />}
openDelay={1000}
>
<Tooltip label={tooltip} openDelay={1000}>
<Flex
ref={ref}
onClick={handleSelectBoard}
@@ -143,7 +145,6 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
cursor="pointer"
bg="base.800"
>
{board.archived && <ArchivedIcon />}
{coverImage?.thumbnail_url ? (
<Image
src={coverImage?.thumbnail_url}


@@ -4,12 +4,12 @@ import IAIDroppable from 'common/components/IAIDroppable';
import SelectionOverlay from 'common/components/SelectionOverlay';
import type { RemoveFromBoardDropData } from 'features/dnd/types';
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
import NoBoardBoardContextMenu from 'features/gallery/components/Boards/NoBoardBoardContextMenu';
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice';
import InvokeLogoSVG from 'public/assets/images/invoke-symbol-wht-lrg.svg';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
import { useBoardName } from 'services/api/hooks/useBoardName';
interface Props {
@@ -29,6 +29,17 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
}, [dispatch, autoAssignBoardOnClick]);
const [isHovered, setIsHovered] = useState(false);
const { data: imagesTotal } = useGetBoardImagesTotalQuery('none');
const { data: assetsTotal } = useGetBoardAssetsTotalQuery('none');
const tooltip = useMemo(() => {
if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
return undefined;
}
return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
assetsTotal.total
} asset${assetsTotal.total === 1 ? '' : 's'}`;
}, [assetsTotal, imagesTotal]);
const handleMouseOver = useCallback(() => {
setIsHovered(true);
}, []);
@@ -58,9 +69,9 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
w="full"
h="full"
>
<NoBoardBoardContextMenu>
<BoardContextMenu board_id="none">
{(ref) => (
<Tooltip label={<BoardTotalsTooltip board_id="none" isArchived={false} />} openDelay={1000}>
<Tooltip label={tooltip} openDelay={1000}>
<Flex
ref={ref}
onClick={handleSelectBoard}
@@ -111,7 +122,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
</Flex>
</Tooltip>
)}
</NoBoardBoardContextMenu>
</BoardContextMenu>
</Flex>
</Box>
);
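Both GalleryBoard and NoBoardBoard above now build the totals tooltip inline from the two count queries, replacing the i18n-pluralized BoardTotalsTooltip component that this diff deletes. A small hedged sketch of the same string logic pulled into a helper; the helper name is hypothetical and not part of the diff:

// Hypothetical helper, equivalent to the inline template strings above.
const formatBoardTotals = (images: number, assets: number): string =>
  `${images} image${images === 1 ? '' : 's'}, ${assets} asset${assets === 1 ? '' : 's'}`;

// formatBoardTotals(1, 0)  -> '1 image, 0 assets'
// formatBoardTotals(3, 12) -> '3 images, 12 assets'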


@@ -1,55 +0,0 @@
import type { ContextMenuProps } from '@invoke-ai/ui-library';
import { ContextMenu, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { autoAddBoardIdChanged } from 'features/gallery/store/gallerySlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiDownloadBold, PiPlusBold } from 'react-icons/pi';
import { useBulkDownloadImagesMutation } from 'services/api/endpoints/images';
type Props = {
children: ContextMenuProps<HTMLDivElement>['children'];
};
const NoBoardBoardContextMenu = ({ children }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const isSelectedForAutoAdd = useAppSelector((s) => s.gallery.autoAddBoardId === 'none');
const isBulkDownloadEnabled = useFeatureStatus('bulkDownload');
const [bulkDownload] = useBulkDownloadImagesMutation();
const handleSetAutoAdd = useCallback(() => {
dispatch(autoAddBoardIdChanged('none'));
}, [dispatch]);
const handleBulkDownload = useCallback(() => {
bulkDownload({ image_names: [], board_id: 'none' });
}, [bulkDownload]);
const renderMenuFunc = useCallback(
() => (
<MenuList visibility="visible">
<MenuGroup title={t('boards.uncategorized')}>
{!autoAssignBoardOnClick && (
<MenuItem icon={<PiPlusBold />} isDisabled={isSelectedForAutoAdd} onClick={handleSetAutoAdd}>
{isSelectedForAutoAdd ? t('boards.selectedForAutoAdd') : t('boards.menuItemAutoAdd')}
</MenuItem>
)}
{isBulkDownloadEnabled && (
<MenuItem icon={<PiDownloadBold />} onClickCapture={handleBulkDownload}>
{t('boards.downloadBoard')}
</MenuItem>
)}
</MenuGroup>
</MenuList>
),
[autoAssignBoardOnClick, handleBulkDownload, handleSetAutoAdd, isBulkDownloadEnabled, isSelectedForAutoAdd, t]
);
return <ContextMenu renderMenu={renderMenuFunc}>{children}</ContextMenu>;
};
export default memo(NoBoardBoardContextMenu);


@@ -0,0 +1,111 @@
import type { FormLabelProps } from '@invoke-ai/ui-library';
import {
Checkbox,
CompositeSlider,
Flex,
FormControl,
FormControlGroup,
FormLabel,
IconButton,
Popover,
PopoverBody,
PopoverContent,
PopoverTrigger,
Switch,
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
alwaysShowImageSizeBadgeChanged,
autoAssignBoardOnClickChanged,
setGalleryImageMinimumWidth,
shouldAutoSwitchChanged,
} from 'features/gallery/store/gallerySlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { RiSettings4Fill } from 'react-icons/ri';
import BoardAutoAddSelect from './Boards/BoardAutoAddSelect';
const formLabelProps: FormLabelProps = {
flexGrow: 1,
};
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);
const shouldAutoSwitch = useAppSelector((s) => s.gallery.shouldAutoSwitch);
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
const handleChangeGalleryImageMinimumWidth = useCallback(
(v: number) => {
dispatch(setGalleryImageMinimumWidth(v));
},
[dispatch]
);
const handleChangeAutoSwitch = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldAutoSwitchChanged(e.target.checked));
},
[dispatch]
);
const handleChangeAutoAssignBoardOnClick = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(autoAssignBoardOnClickChanged(e.target.checked)),
[dispatch]
);
const handleChangeAlwaysShowImageSizeBadgeChanged = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(alwaysShowImageSizeBadgeChanged(e.target.checked)),
[dispatch]
);
return (
<Popover isLazy>
<PopoverTrigger>
<IconButton
tooltip={t('gallery.gallerySettings')}
aria-label={t('gallery.gallerySettings')}
size="sm"
icon={<RiSettings4Fill />}
/>
</PopoverTrigger>
<PopoverContent>
<PopoverBody>
<Flex direction="column" gap={2}>
<FormControl>
<FormLabel>{t('gallery.galleryImageSize')}</FormLabel>
<CompositeSlider
value={galleryImageMinimumWidth}
onChange={handleChangeGalleryImageMinimumWidth}
min={45}
max={256}
defaultValue={90}
/>
</FormControl>
<FormControlGroup formLabelProps={formLabelProps}>
<FormControl>
<FormLabel>{t('gallery.autoSwitchNewImages')}</FormLabel>
<Switch isChecked={shouldAutoSwitch} onChange={handleChangeAutoSwitch} />
</FormControl>
<FormControl>
<FormLabel>{t('gallery.autoAssignBoardOnClick')}</FormLabel>
<Checkbox isChecked={autoAssignBoardOnClick} onChange={handleChangeAutoAssignBoardOnClick} />
</FormControl>
<FormControl>
<FormLabel>{t('gallery.alwaysShowImageSizeBadge')}</FormLabel>
<Checkbox isChecked={alwaysShowImageSizeBadge} onChange={handleChangeAlwaysShowImageSizeBadgeChanged} />
</FormControl>
</FormControlGroup>
<BoardAutoAddSelect />
</Flex>
</PopoverBody>
</PopoverContent>
</Popover>
);
};
export default memo(GallerySettingsPopover);


@@ -1,26 +0,0 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { alwaysShowImageSizeBadgeChanged } from 'features/gallery/store/gallerySlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(alwaysShowImageSizeBadgeChanged(e.target.checked)),
[dispatch]
);
return (
<FormControl>
<FormLabel flexGrow={1}>{t('gallery.alwaysShowImageSizeBadge')}</FormLabel>
<Checkbox isChecked={alwaysShowImageSizeBadge} onChange={onChange} />
</FormControl>
);
};
export default memo(GallerySettingsPopover);


@@ -1,26 +0,0 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { autoAssignBoardOnClickChanged } from 'features/gallery/store/gallerySlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const autoAssignBoardOnClick = useAppSelector((s) => s.gallery.autoAssignBoardOnClick);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => dispatch(autoAssignBoardOnClickChanged(e.target.checked)),
[dispatch]
);
return (
<FormControl>
<FormLabel flexGrow={1}>{t('gallery.autoAssignBoardOnClick')}</FormLabel>
<Checkbox isChecked={autoAssignBoardOnClick} onChange={onChange} />
</FormControl>
);
};
export default memo(GallerySettingsPopover);


@@ -1,28 +0,0 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { shouldAutoSwitchChanged } from 'features/gallery/store/gallerySlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const shouldAutoSwitch = useAppSelector((s) => s.gallery.shouldAutoSwitch);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldAutoSwitchChanged(e.target.checked));
},
[dispatch]
);
return (
<FormControl>
<FormLabel flexGrow={1}>{t('gallery.autoSwitchNewImages')}</FormLabel>
<Checkbox isChecked={shouldAutoSwitch} onChange={onChange} />
</FormControl>
);
};
export default memo(GallerySettingsPopover);


@@ -1,41 +0,0 @@
import { Divider, Flex, IconButton, Popover, PopoverBody, PopoverContent, PopoverTrigger } from '@invoke-ai/ui-library';
import BoardAutoAddSelect from 'features/gallery/components/Boards/BoardAutoAddSelect';
import AlwaysShowImageSizeCheckbox from 'features/gallery/components/GallerySettingsPopover/AlwaysShowImageSizeCheckbox';
import AutoAssignBoardCheckbox from 'features/gallery/components/GallerySettingsPopover/AutoAssignBoardCheckbox';
import AutoSwitchCheckbox from 'features/gallery/components/GallerySettingsPopover/AutoSwitchCheckbox';
import ImageMinimumWidthSlider from 'features/gallery/components/GallerySettingsPopover/ImageMinimumWidthSlider';
import ShowArchivedBoardsCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowArchivedBoardsCheckbox';
import ShowStarredFirstCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowStarredFirstCheckbox';
import SortDirectionCombobox from 'features/gallery/components/GallerySettingsPopover/SortDirectionCombobox';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { RiSettings4Fill } from 'react-icons/ri';
const GallerySettingsPopover = () => {
const { t } = useTranslation();
return (
<Popover isLazy>
<PopoverTrigger>
<IconButton aria-label={t('gallery.gallerySettings')} size="sm" icon={<RiSettings4Fill />} />
</PopoverTrigger>
<PopoverContent>
<PopoverBody>
<Flex direction="column" gap={2}>
<ImageMinimumWidthSlider />
<AutoSwitchCheckbox />
<AutoAssignBoardCheckbox />
<AlwaysShowImageSizeCheckbox />
<ShowArchivedBoardsCheckbox />
<BoardAutoAddSelect />
<Divider pt={2} />
<ShowStarredFirstCheckbox />
<SortDirectionCombobox />
</Flex>
</PopoverBody>
</PopoverContent>
</Popover>
);
};
export default memo(GallerySettingsPopover);


@@ -1,26 +0,0 @@
import { CompositeSlider, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { setGalleryImageMinimumWidth } from 'features/gallery/store/gallerySlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);
const onChange = useCallback(
(v: number) => {
dispatch(setGalleryImageMinimumWidth(v));
},
[dispatch]
);
return (
<FormControl>
<FormLabel>{t('gallery.galleryImageSize')}</FormLabel>
<CompositeSlider value={galleryImageMinimumWidth} onChange={onChange} min={45} max={256} defaultValue={90} />
</FormControl>
);
};
export default memo(GallerySettingsPopover);


@@ -1,28 +0,0 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { shouldShowArchivedBoardsChanged } from 'features/gallery/store/gallerySlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const shouldShowArchivedBoards = useAppSelector((s) => s.gallery.shouldShowArchivedBoards);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(shouldShowArchivedBoardsChanged(e.target.checked));
},
[dispatch]
);
return (
<FormControl>
<FormLabel flexGrow={1}>{t('gallery.showArchivedBoards')}</FormLabel>
<Checkbox isChecked={shouldShowArchivedBoards} onChange={onChange} />
</FormControl>
);
};
export default memo(GallerySettingsPopover);


@@ -1,30 +0,0 @@
import { FormControl, FormLabel, Switch } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { starredFirstChanged } from 'features/gallery/store/gallerySlice';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const starredFirst = useAppSelector((s) => s.gallery.starredFirst);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(starredFirstChanged(e.target.checked));
},
[dispatch]
);
return (
<FormControl w="full">
<FormLabel flexGrow={1} m={0}>
{t('gallery.showStarredImagesFirst')}
</FormLabel>
<Switch size="sm" isChecked={starredFirst} onChange={onChange} />
</FormControl>
);
};
export default memo(GallerySettingsPopover);


@@ -1,45 +0,0 @@
import type { ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { SingleValue } from 'chakra-react-select';
import { orderDirChanged } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { assert } from 'tsafe';
const GallerySettingsPopover = () => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const orderDir = useAppSelector((s) => s.gallery.orderDir);
const options = useMemo<ComboboxOption[]>(
() => [
{ value: 'DESC', label: t('gallery.newestFirst') },
{ value: 'ASC', label: t('gallery.oldestFirst') },
],
[t]
);
const onChange = useCallback(
(v: SingleValue<ComboboxOption>) => {
assert(v?.value === 'ASC' || v?.value === 'DESC');
dispatch(orderDirChanged(v.value));
},
[dispatch]
);
const value = useMemo(() => {
return options.find((opt) => opt.value === orderDir);
}, [orderDir, options]);
return (
<FormControl>
<FormLabel flexGrow={1} m={0}>
{t('gallery.sortDirection')}
</FormLabel>
<Combobox isSearchable={false} value={value} options={options} onChange={onChange} />
</FormControl>
);
};
export default memo(GallerySettingsPopover);

Some files were not shown because too many files have changed in this diff.