mirror of https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-21 19:07:59 -05:00

Compare commits: maryhipp/s...lstein/fea (41 commits)
41 Commits (SHA1):

9dcace7d82, 02957be333, 5d6a77d336, 9b7b182cf7, 2219e3643a, 6932f27b43,
0df018bd4e, 7088d5610b, 589a7959c0, e26360f85b, debef2476e, e57809e1c6,
1c0067f931, c3d1252892, 84f5cbdd97, edac01d4fb, d04c880cce, 763a2e2632,
eaadc55c7d, 89f8326c0b, 99558de178, 77130f108d, 371f5bc782, fb9b7fb63a,
bd833900a3, a84f3058e2, f7436f3bae, 7dd93cb810, 9adb15f86c, 3d69372785,
eca29c41d0, 9df0980c46, cef51ad80d, 83356ec74c, 9336a076de, 32d3e4dc5c,
a1dcab9c38, bd9b00a6bf, eaa2c68693, 24d73280ee, 6b991a5269
@@ -1328,7 +1328,7 @@ from invokeai.app.services.model_load import ModelLoadService, ModelLoaderRegist
 config = InvokeAIAppConfig.get_config()
 ram_cache = ModelCache(
-    max_cache_size=config.ram_cache_size, max_vram_cache_size=config.vram_cache_size, logger=logger
+    max_cache_size=config.ram_cache_size, logger=logger
 )
 convert_cache = ModelConvertCache(
     cache_path=config.models_convert_cache_path, max_size=config.convert_cache_size

@@ -316,7 +316,6 @@ async def list_image_dtos(
     ),
     offset: int = Query(default=0, description="The page offset"),
     limit: int = Query(default=10, description="The number of images per page"),
-    search_term: Optional[str] = Query(default=None, description="The term to search for"),
 ) -> OffsetPaginatedResults[ImageDTO]:
     """Gets a list of image DTOs"""

@@ -327,7 +326,6 @@ async def list_image_dtos(
         categories,
         is_intermediate,
         board_id,
-        search_term
     )

     return image_dtos

@@ -103,6 +103,7 @@ class CompelInvocation(BaseInvocation):
             textual_inversion_manager=ti_manager,
             dtype_for_device_getter=TorchDevice.choose_torch_dtype,
             truncate_long_prompts=False,
+            device=TorchDevice.choose_torch_device(),
         )

         conjunction = Compel.parse_prompt_string(self.prompt)

@@ -117,6 +118,7 @@ class CompelInvocation(BaseInvocation):
         conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])

         conditioning_name = context.conditioning.save(conditioning_data)

         return ConditioningOutput(
             conditioning=ConditioningField(
                 conditioning_name=conditioning_name,

@@ -203,6 +205,7 @@ class SDXLPromptInvocationBase:
             truncate_long_prompts=False,  # TODO:
             returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,  # TODO: clip skip
             requires_pooled=get_pooled,
+            device=TorchDevice.choose_torch_device(),
         )

         conjunction = Compel.parse_prompt_string(prompt)

@@ -313,7 +316,6 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
                 )
             ]
         )

         conditioning_name = context.conditioning.save(conditioning_data)

         return ConditioningOutput(
@@ -1,4 +1,5 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
+import copy
 import inspect
 from contextlib import ExitStack
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

@@ -193,9 +194,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
         text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
         text_embeddings_masks: list[Optional[torch.Tensor]] = []
         for cond in cond_list:
-            cond_data = context.conditioning.load(cond.conditioning_name)
+            cond_data = copy.deepcopy(context.conditioning.load(cond.conditioning_name))
             text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))

             mask = cond.mask
             if mask is not None:
                 mask = context.tensors.load(mask.tensor_name)

@@ -226,6 +226,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
         # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
         mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
         resized_mask = tf(mask)
         assert isinstance(resized_mask, torch.Tensor)
         return resized_mask

     def _concat_regional_text_embeddings(
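The `copy.deepcopy` in the hunk above is the multi-GPU safety fix: `context.conditioning.load()` returns the cached object itself, and calling `.to(device=...)` on it in place would leak one worker's device placement into every other session that loads the same conditioning. A minimal sketch of the hazard, using a plain dict as a stand-in for the conditioning cache:

```python
import copy

import torch

cache = {"cond_abc": torch.ones(1, 77, 768)}  # stand-in for context.conditioning

def load_for_device(name: str, device: torch.device) -> torch.Tensor:
    # Deep-copy before .to(): the cached tensor stays on the storage device,
    # and each worker thread gets a private copy on its own GPU.
    return copy.deepcopy(cache[name]).to(device=device)
```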
@@ -26,13 +26,13 @@ LEGACY_INIT_FILE = Path("invokeai.init")
 DEFAULT_RAM_CACHE = 10.0
 DEFAULT_VRAM_CACHE = 0.25
 DEFAULT_CONVERT_CACHE = 20.0
-DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
-PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
+DEVICE = Literal["auto", "cpu", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5", "cuda:6", "cuda:7", "mps"]
+PRECISION = Literal["auto", "float16", "bfloat16", "float32", "autocast"]
 ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
 ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
 LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
 LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
-CONFIG_SCHEMA_VERSION = "4.0.1"
+CONFIG_SCHEMA_VERSION = "4.0.2"


 def get_default_ram_cache_size() -> float:
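Together with the `devices` field added below, the widened `DEVICE` literal lets a deployment pin work to specific CUDA indices. A sketch of what the new settings accept (the import path and constructor call are assumed from context; the values are illustrative):

```python
from invokeai.app.services.config import InvokeAIAppConfig

config = InvokeAIAppConfig(
    device="auto",                 # or a pinned index such as "cuda:1"
    devices=["cuda:0", "cuda:1"],  # overrides `device` when set
    precision="float16",
)
```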
@@ -105,14 +105,16 @@ class InvokeAIAppConfig(BaseSettings):
         convert_cache: Maximum size of on-disk converted models cache (GB).
         lazy_offload: Keep models in VRAM until their space is needed.
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
-        device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
-        precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
+        device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda:0`, `cuda:1`, `cuda:2`, `cuda:3`, `cuda:4`, `cuda:5`, `cuda:6`, `cuda:7`, `mps`
+        devices: List of execution devices; will override default device selected.
+        precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`, `autocast`
         sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
         attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
         attention_slice_size: Slice size, valid when attention_type=="sliced".<br>Valid values: `auto`, `balanced`, `max`, `1`, `2`, `3`, `4`, `5`, `6`, `7`, `8`
         force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).
         pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.
         max_queue_size: Maximum number of items in the session queue.
+        max_threads: Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.
         clear_queue_on_startup: Empties session queue on startup.
         allow_nodes: List of nodes to allow. Omit to allow all.
         deny_nodes: List of nodes to deny. Omit to deny none.

@@ -178,6 +180,7 @@ class InvokeAIAppConfig(BaseSettings):

     # DEVICE
     device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
+    devices: Optional[list[DEVICE]] = Field(default=None, description="List of execution devices; will override default device selected.")
     precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")

     # GENERATION

@@ -187,6 +190,7 @@ class InvokeAIAppConfig(BaseSettings):
     force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
     pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
     max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
+    max_threads: Optional[int] = Field(default=None, description="Maximum number of session queue execution threads. Autocalculated from number of GPUs if not set.")
     clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.")

     # NODES

@@ -376,9 +380,6 @@ def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
         # `max_cache_size` was renamed to `ram` some time in v3, but both names were used
         if k == "max_cache_size" and "ram" not in category_dict:
             parsed_config_dict["ram"] = v
         # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
         if k == "max_vram_cache_size" and "vram" not in category_dict:
             parsed_config_dict["vram"] = v
-        # autocast was removed in v4.0.1
-        if k == "precision" and v == "autocast":
-            parsed_config_dict["precision"] = "auto"

@@ -426,6 +427,27 @@ def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig
     return config


+def migrate_v4_0_1_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
+    """Migrate v4.0.1 config dictionary to a current config object.
+
+    A few new multi-GPU options were added in 4.0.2, and this simply
+    updates the schema label.
+
+    Args:
+        config_dict: A dictionary of settings from a v4.0.1 config file.
+
+    Returns:
+        An instance of `InvokeAIAppConfig` with the migrated settings.
+    """
+    parsed_config_dict: dict[str, Any] = {}
+    for k, _ in config_dict.items():
+        if k == "schema_version":
+            parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
+    config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
+    return config
+
+
 # TO DO: replace this with a formal registration and migration system
 def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
     """Load and migrate a config file to the latest version.
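An illustrative call to the new migration step. Note that, as written, the loop carries over only `schema_version`; the other 4.0.1 keys fall back to the defaults of `DefaultInvokeAIAppConfig`:

```python
old = {"schema_version": "4.0.1", "ram": 10.0}
config = migrate_v4_0_1_config_dict(old)
assert config.schema_version == "4.0.2"  # relabelled to CONFIG_SCHEMA_VERSION
```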
@@ -457,6 +479,10 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
         loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
         loaded_config_dict.write_file(config_path)

+    elif loaded_config_dict["schema_version"] == "4.0.1":
+        loaded_config_dict = migrate_v4_0_1_config_dict(loaded_config_dict)
+        loaded_config_dict.write_file(config_path)
+
     # Attempt to load as a v4 config file
     try:
         # Meta is not included in the model fields, so we need to validate it separately

@@ -41,7 +41,6 @@ class ImageRecordStorageBase(ABC):
         categories: Optional[list[ImageCategory]] = None,
         is_intermediate: Optional[bool] = None,
         board_id: Optional[str] = None,
-        search_term: Optional[str] = None,
     ) -> OffsetPaginatedResults[ImageRecord]:
         """Gets a page of image records."""
         pass

@@ -148,7 +148,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
         categories: Optional[list[ImageCategory]] = None,
         is_intermediate: Optional[bool] = None,
         board_id: Optional[str] = None,
-        search_term: Optional[str] = None,
     ) -> OffsetPaginatedResults[ImageRecord]:
         try:
             self._lock.acquire()

@@ -209,13 +208,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
             """
             query_params.append(board_id)

-        # Search term condition
-        if search_term:
-            query_conditions += """--sql
-                AND json_extract(images.metadata, '$') LIKE ?
-            """
-            query_params.append(f'%{search_term}%')
-
         query_pagination = """--sql
             ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
         """

@@ -120,7 +120,6 @@ class ImageServiceABC(ABC):
         categories: Optional[list[ImageCategory]] = None,
         is_intermediate: Optional[bool] = None,
         board_id: Optional[str] = None,
-        search_term: Optional[str] = None
     ) -> OffsetPaginatedResults[ImageDTO]:
         """Gets a paginated list of image DTOs."""
         pass

@@ -206,7 +206,6 @@ class ImageService(ImageServiceABC):
         categories: Optional[list[ImageCategory]] = None,
         is_intermediate: Optional[bool] = None,
         board_id: Optional[str] = None,
-        search_term: Optional[str] = None,
     ) -> OffsetPaginatedResults[ImageDTO]:
         try:
             results = self.__invoker.services.image_records.get_many(

@@ -216,7 +215,6 @@ class ImageService(ImageServiceABC):
                 categories,
                 is_intermediate,
                 board_id,
-                search_term
             )

             image_dtos = [

@@ -53,11 +53,11 @@ class InvocationServices:
         model_images: "ModelImageFileStorageBase",
         model_manager: "ModelManagerServiceBase",
         download_queue: "DownloadQueueServiceBase",
-        performance_statistics: "InvocationStatsServiceBase",
         session_queue: "SessionQueueBase",
         session_processor: "SessionProcessorBase",
         invocation_cache: "InvocationCacheBase",
         names: "NameServiceBase",
+        performance_statistics: "InvocationStatsServiceBase",
         urls: "UrlServiceBase",
         workflow_records: "WorkflowRecordsStorageBase",
         tensors: "ObjectSerializerBase[torch.Tensor]",

@@ -77,11 +77,11 @@ class InvocationServices:
         self.model_images = model_images
         self.model_manager = model_manager
         self.download_queue = download_queue
-        self.performance_statistics = performance_statistics
         self.session_queue = session_queue
         self.session_processor = session_processor
         self.invocation_cache = invocation_cache
         self.names = names
+        self.performance_statistics = performance_statistics
         self.urls = urls
         self.workflow_records = workflow_records
         self.tensors = tensors
@@ -74,9 +74,9 @@ class InvocationStatsService(InvocationStatsServiceBase):
         )
         self._stats[graph_execution_state_id].add_node_execution_stats(node_stats)

-    def reset_stats(self):
-        self._stats = {}
-        self._cache_stats = {}
+    def reset_stats(self, graph_execution_state_id: str):
+        self._stats.pop(graph_execution_state_id)
+        self._cache_stats.pop(graph_execution_state_id)

     def get_stats(self, graph_execution_state_id: str) -> InvocationStatsSummary:
         graph_stats_summary = self._get_graph_summary(graph_execution_state_id)
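With several sessions now executing concurrently, a global `reset_stats()` would wipe the statistics of sessions still in flight; resetting per graph id removes only the finished session's entries. A minimal sketch of the difference:

```python
stats: dict[str, list[float]] = {"session-a": [0.12], "session-b": [0.34]}

def reset_stats(graph_execution_state_id: str) -> None:
    stats.pop(graph_execution_state_id)  # touch only the finished session

reset_stats("session-a")
assert "session-b" in stats  # a concurrently running session keeps its stats
```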
@@ -284,9 +284,14 @@ class ModelInstallService(ModelInstallServiceBase):
         unfinished_jobs = [x for x in self._install_jobs if not x.in_terminal_state]
         self._install_jobs = unfinished_jobs

-    def _migrate_yaml(self) -> None:
+    def _migrate_yaml(self, rename_yaml: Optional[bool] = True, overwrite_db: Optional[bool] = False) -> None:
         db_models = self.record_store.all_models()
+
+        if overwrite_db:
+            for model in db_models:
+                self.record_store.del_model(model.key)
+            db_models = self.record_store.all_models()
+
         legacy_models_yaml_path = (
             self._app_config.legacy_models_yaml_path or self._app_config.root_path / "configs" / "models.yaml"
         )

@@ -336,7 +341,8 @@ class ModelInstallService(ModelInstallServiceBase):
                 self._logger.warning(f"Model at {model_path} could not be migrated: {e}")

         # Rename `models.yaml` to `models.yaml.bak` to prevent re-migration
-        legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))
+        if rename_yaml:
+            legacy_models_yaml_path.rename(legacy_models_yaml_path.with_suffix(".yaml.bak"))

         # Unset the path - we are done with it either way
         self._app_config.legacy_models_yaml_path = None

@@ -33,6 +33,11 @@ class ModelLoadServiceBase(ABC):
     def convert_cache(self) -> ModelConvertCacheBase:
         """Return the checkpoint convert cache used by this loader."""

+    @property
+    @abstractmethod
+    def gpu_count(self) -> int:
+        """Return the number of GPUs we are configured to use."""
+
     @abstractmethod
     def load_model_from_path(
         self, model_path: Path, loader: Optional[Callable[[Path], AnyModel]] = None

@@ -46,6 +46,7 @@ class ModelLoadService(ModelLoadServiceBase):
         self._registry = registry

     def start(self, invoker: Invoker) -> None:
         """Start the service."""
         self._invoker = invoker

     @property

@@ -53,6 +54,11 @@ class ModelLoadService(ModelLoadServiceBase):
         """Return the RAM cache used by this loader."""
         return self._ram_cache

+    @property
+    def gpu_count(self) -> int:
+        """Return the number of GPUs available for our uses."""
+        return len(self._ram_cache.execution_devices)
+
     @property
     def convert_cache(self) -> ModelConvertCacheBase:
         """Return the checkpoint convert cache used by this loader."""

@@ -1,6 +1,7 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team

 from abc import ABC, abstractmethod
 from typing import Optional, Set

 import torch
 from typing_extensions import Self

@@ -31,7 +32,7 @@ class ModelManagerServiceBase(ABC):
         model_record_service: ModelRecordServiceBase,
         download_queue: DownloadQueueServiceBase,
         events: EventServiceBase,
-        execution_device: torch.device,
+        execution_devices: Optional[Set[torch.device]] = None,
     ) -> Self:
         """
         Construct the model manager service instance.

@@ -1,14 +1,10 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Team
 """Implementation of ModelManagerServiceBase."""

-from typing import Optional
-
-import torch
 from typing_extensions import Self

 from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache, ModelLoaderRegistry
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

 from ..config import InvokeAIAppConfig

@@ -69,7 +65,6 @@ class ModelManagerService(ModelManagerServiceBase):
         model_record_service: ModelRecordServiceBase,
         download_queue: DownloadQueueServiceBase,
         events: EventServiceBase,
-        execution_device: Optional[torch.device] = None,
     ) -> Self:
         """
         Construct the model manager service instance.

@@ -82,9 +77,7 @@ class ModelManagerService(ModelManagerServiceBase):
         ram_cache = ModelCache(
             max_cache_size=app_config.ram,
-            max_vram_cache_size=app_config.vram,
-            lazy_offloading=app_config.lazy_offload,
             logger=logger,
-            execution_device=execution_device or TorchDevice.choose_torch_device(),
         )
         convert_cache = ModelConvertCache(cache_path=app_config.convert_cache_path, max_size=app_config.convert_cache)
         loader = ModelLoadService(
@@ -1,5 +1,6 @@
 import shutil
 import tempfile
+import threading
 import typing
 from pathlib import Path
 from typing import TYPE_CHECKING, Optional, TypeVar

@@ -9,6 +10,7 @@ import torch
 from invokeai.app.services.object_serializer.object_serializer_base import ObjectSerializerBase
 from invokeai.app.services.object_serializer.object_serializer_common import ObjectNotFoundError
 from invokeai.app.util.misc import uuid_string
+from invokeai.backend.util.devices import TorchDevice

 if TYPE_CHECKING:
     from invokeai.app.services.invoker import Invoker

@@ -70,7 +72,10 @@ class ObjectSerializerDisk(ObjectSerializerBase[T]):
         return self._output_dir / name

     def _new_name(self) -> str:
-        return f"{self._obj_class_name}_{uuid_string()}"
+        tid = threading.current_thread().ident
+        # Add tid to the object name because uuid4 not thread-safe on windows
+        # See https://stackoverflow.com/questions/2759644/python-multiprocessing-doesnt-play-nicely-with-uuid-uuid4
+        return f"{self._obj_class_name}_{tid}-{uuid_string()}"

     def _tempdir_cleanup(self) -> None:
         """Calls `cleanup` on the temporary directory, if it exists."""
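A standalone sketch of the naming scheme above (`uuid_string()` is InvokeAI's helper; `uuid.uuid4().hex` stands in for it here). Prefixing with the thread id keeps names unique even if two worker threads were ever handed the same UUID:

```python
import threading
import uuid

def new_name(obj_class_name: str) -> str:
    # Same idea as _new_name() above: the thread id disambiguates names
    # even in the unlikely event of a uuid4 collision across threads.
    tid = threading.current_thread().ident
    return f"{obj_class_name}_{tid}-{uuid.uuid4().hex}"

print(new_name("ConditioningFieldData"))
```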
@@ -1,8 +1,9 @@
 import traceback
 from contextlib import suppress
-from threading import BoundedSemaphore, Thread
+from queue import Queue
+from threading import BoundedSemaphore, Lock, Thread
 from threading import Event as ThreadEvent
-from typing import Optional
+from typing import Optional, Set

 from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
 from invokeai.app.services.events.events_common import (

@@ -26,6 +27,7 @@ from invokeai.app.services.session_queue.session_queue_common import SessionQueu
 from invokeai.app.services.shared.graph import NodeInputError
 from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
 from invokeai.app.util.profiler import Profiler
+from invokeai.backend.util.devices import TorchDevice

 from ..invoker import Invoker
 from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase

@@ -57,8 +59,11 @@ class DefaultSessionRunner(SessionRunnerBase):
         self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
         self._on_node_error_callbacks = on_node_error_callbacks or []
         self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
+        self._process_lock = Lock()

-    def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
+    def start(
+        self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None
+    ) -> None:
         self._services = services
         self._cancel_event = cancel_event
         self._profiler = profiler

@@ -76,7 +81,8 @@ class DefaultSessionRunner(SessionRunnerBase):
         # Loop over invocations until the session is complete or canceled
         while True:
             try:
-                invocation = queue_item.session.next()
+                with self._process_lock:
+                    invocation = queue_item.session.next()
             # Anything other than a `NodeInputError` is handled as a processor error
             except NodeInputError as e:
                 error_type = e.__class__.__name__

@@ -108,7 +114,7 @@ class DefaultSessionRunner(SessionRunnerBase):

         self._on_after_run_session(queue_item=queue_item)

-    def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
+    def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
         try:
             # Any unhandled exception in this scope is an invocation error & will fail the graph
             with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):

@@ -210,7 +216,7 @@ class DefaultSessionRunner(SessionRunnerBase):
         # we don't care about that - suppress the error.
         with suppress(GESStatsNotFoundError):
             self._services.performance_statistics.log_stats(queue_item.session.id)
-            self._services.performance_statistics.reset_stats()
+            self._services.performance_statistics.reset_stats(queue_item.session.id)

         for callback in self._on_after_run_session_callbacks:
             callback(queue_item=queue_item)

@@ -324,7 +330,7 @@ class DefaultSessionProcessor(SessionProcessorBase):

     def start(self, invoker: Invoker) -> None:
         self._invoker: Invoker = invoker
-        self._queue_item: Optional[SessionQueueItem] = None
+        self._active_queue_items: Set[SessionQueueItem] = set()
         self._invocation: Optional[BaseInvocation] = None

         self._resume_event = ThreadEvent()

@@ -350,7 +356,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
             else None
         )

+        self._worker_thread_count = self._invoker.services.configuration.max_threads or len(
+            TorchDevice.execution_devices()
+        )
+
+        self._session_worker_queue: Queue[SessionQueueItem] = Queue()
+
         self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
+        # Session processor - singlethreaded
         self._thread = Thread(
             name="session_processor",
             target=self._process,
@@ -363,6 +376,16 @@ class DefaultSessionProcessor(SessionProcessorBase):
         )
         self._thread.start()

+        # Session processor workers - multithreaded
+        self._invoker.services.logger.debug(f"Starting {self._worker_thread_count} session processing threads.")
+        for _i in range(0, self._worker_thread_count):
+            worker = Thread(
+                name="session_worker",
+                target=self._process_next_session,
+                daemon=True,
+            )
+            worker.start()
+
     def stop(self, *args, **kwargs) -> None:
         self._stop_event.set()

@@ -370,7 +393,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
         self._poll_now_event.set()

     async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
-        if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
+        if any(item.queue_id == event[1].queue_id for item in self._active_queue_items):
             self._cancel_event.set()
             self._poll_now()

@@ -378,7 +401,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
         self._poll_now()

     async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
-        if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
+        if self._active_queue_items and event[1].status in ["completed", "failed", "canceled"]:
             # When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
             # emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
             # event, which the session runner checks between invocations. If set, the session runner loop is broken.

@@ -403,7 +426,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
     def get_status(self) -> SessionProcessorStatus:
         return SessionProcessorStatus(
             is_started=self._resume_event.is_set(),
-            is_processing=self._queue_item is not None,
+            is_processing=len(self._active_queue_items) > 0,
         )

     def _process(

@@ -428,30 +451,22 @@ class DefaultSessionProcessor(SessionProcessorBase):
                     resume_event.wait()

                     # Get the next session to process
-                    self._queue_item = self._invoker.services.session_queue.dequeue()
+                    queue_item = self._invoker.services.session_queue.dequeue()

-                    if self._queue_item is None:
+                    if queue_item is None:
                         # The queue was empty, wait for next polling interval or event to try again
                         self._invoker.services.logger.debug("Waiting for next polling interval or event")
                         poll_now_event.wait(self._polling_interval)
                         continue

-                    self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
+                    self._session_worker_queue.put(queue_item)
+                    self._invoker.services.logger.debug(f"Scheduling queue item {queue_item.item_id} to run")
                     cancel_event.clear()

-                    # Run the graph
-                    self.session_runner.run(queue_item=self._queue_item)
+                    # self.session_runner.run(queue_item=self._queue_item)

-                except Exception as e:
-                    error_type = e.__class__.__name__
-                    error_message = str(e)
-                    error_traceback = traceback.format_exc()
-                    self._on_non_fatal_processor_error(
-                        queue_item=self._queue_item,
-                        error_type=error_type,
-                        error_message=error_message,
-                        error_traceback=error_traceback,
-                    )
+                except Exception:
                     # Wait for next polling interval or event to try again
                     poll_now_event.wait(self._polling_interval)
                     continue

@@ -466,9 +481,25 @@ class DefaultSessionProcessor(SessionProcessorBase):
             finally:
                 stop_event.clear()
                 poll_now_event.clear()
-                self._queue_item = None
                 self._thread_semaphore.release()

+    def _process_next_session(self) -> None:
+        while True:
+            self._resume_event.wait()
+            queue_item = self._session_worker_queue.get()
+            if queue_item.status == "canceled":
+                continue
+            try:
+                self._active_queue_items.add(queue_item)
+                # reserve a GPU for this session - may block
+                with self._invoker.services.model_manager.load.ram_cache.reserve_execution_device():
+                    # Run the session on the reserved GPU
+                    self.session_runner.run(queue_item=queue_item)
+            except Exception:
+                continue
+            finally:
+                self._active_queue_items.remove(queue_item)
+
     def _on_non_fatal_processor_error(
         self,
         queue_item: Optional[SessionQueueItem],
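The threading shape introduced above is a classic dispatcher/worker split: the single `_process` thread only dequeues and schedules, while each `_process_next_session` worker blocks until it can reserve a GPU, then runs one session end to end. A minimal, self-contained sketch of that pattern (a `BoundedSemaphore` stands in for `reserve_execution_device()`):

```python
import threading
from queue import Queue

work: Queue[str] = Queue()
gpu_pool = threading.BoundedSemaphore(2)  # stand-in for the per-device reservation

def worker() -> None:
    while True:
        item = work.get()
        with gpu_pool:  # may block until a device frees up, as in the hunk
            print(f"{threading.current_thread().name}: running {item}")
        work.task_done()

for i in range(2):
    threading.Thread(target=worker, name=f"session_worker_{i}", daemon=True).start()

for item in ("session-1", "session-2", "session-3"):
    work.put(item)
work.join()
```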
@@ -236,6 +236,9 @@ class SessionQueueItemWithoutGraph(BaseModel):
         }
     )

+    def __hash__(self) -> int:
+        return self.item_id
+

 class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
     pass
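Pydantic models define `__eq__`, which by the usual Python rules suppresses the default `__hash__`; supplying one keyed on the unique `item_id` is what lets queue items be stored in the processor's new `Set[SessionQueueItem]`. An illustrative sketch:

```python
class Item:
    def __init__(self, item_id: int) -> None:
        self.item_id = item_id

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Item) and other.item_id == self.item_id

    def __hash__(self) -> int:  # without this, defining __eq__ makes Item unhashable
        return self.item_id

active = {Item(1), Item(2)}
assert Item(1) in active
```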
@@ -652,7 +652,7 @@ class Graph(BaseModel):
         output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]

         # Input type must be a list
-        if get_origin(input_field) != list:
+        if get_origin(input_field) is not list:
             return False

         # Validate that all outputs match the input type

@@ -2,6 +2,7 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Callable, Optional, Union

+import torch
 from PIL.Image import Image
 from pydantic.networks import AnyHttpUrl
 from torch import Tensor

@@ -26,11 +27,13 @@ from invokeai.backend.model_manager.config import (
 from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
+from invokeai.backend.util.devices import TorchDevice

 if TYPE_CHECKING:
     from invokeai.app.invocations.baseinvocation import BaseInvocation
     from invokeai.app.invocations.model import ModelIdentifierField
     from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
+    from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase

 """
 The InvocationContext provides access to various services and data about the current invocation.

@@ -323,7 +326,6 @@ class ConditioningInterface(InvocationContextInterface):
         Returns:
             The loaded conditioning data.
         """

         return self._services.conditioning.load(name)

@@ -557,6 +559,28 @@ class UtilInterface(InvocationContextInterface):
             is_canceled=self.is_canceled,
         )

+    def torch_device(self) -> torch.device:
+        """
+        Return a torch device to use in the current invocation.
+
+        Returns:
+            A torch.device not currently in use by the system.
+        """
+        ram_cache: "ModelCacheBase[AnyModel]" = self._services.model_manager.load.ram_cache
+        return ram_cache.get_execution_device()
+
+    def torch_dtype(self, device: Optional[torch.device] = None) -> torch.dtype:
+        """
+        Return a precision type to use with the current invocation and torch device.
+
+        Args:
+            device: Optional device.
+
+        Returns:
+            A torch.dtype suited for the current device.
+        """
+        return TorchDevice.choose_torch_dtype(device)
+

 class InvocationContext:
     """Provides access to various services and data for the current invocation.
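With these accessors, node code can stay device-agnostic: it asks the context for the device reserved for the current worker thread instead of calling CUDA directly. A hedged usage sketch (`context` is the `InvocationContext` a node's `invoke()` receives; the tensor shape is illustrative):

```python
import torch

def make_noise(context) -> torch.Tensor:
    device = context.util.torch_device()     # GPU reserved for this worker thread
    dtype = context.util.torch_dtype(device)
    return torch.randn(1, 4, 64, 64, device=device, dtype=dtype)
```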
@@ -25,6 +25,7 @@ from enum import Enum
 from typing import Literal, Optional, Type, TypeAlias, Union

 import torch
+from diffusers.configuration_utils import ConfigMixin
 from diffusers.models.modeling_utils import ModelMixin
 from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
 from typing_extensions import Annotated, Any, Dict

@@ -37,7 +38,7 @@ from ..raw_model import RawModel

 # ModelMixin is the base class for all diffusers and transformers models
 # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
-AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]
+AnyModel = Union[ConfigMixin, ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor]]


 class InvalidModelConfigException(Exception):

@@ -177,6 +178,7 @@ class ModelConfigBase(BaseModel):

     @staticmethod
     def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
         """Extend the pydantic schema from a json."""
         schema["required"].extend(["key", "type", "format"])

     model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)

@@ -443,7 +445,7 @@ class ModelConfigFactory(object):
             model = dest_class.model_validate(model_data)
         else:
             # mypy doesn't typecheck TypeAdapters well?
-            model = AnyModelConfigValidator.validate_python(model_data)  # type: ignore
+            model = AnyModelConfigValidator.validate_python(model_data)
         assert model is not None
         if key:
             model.key = key

@@ -65,8 +65,7 @@ class LoadedModelWithoutConfig:

     def __enter__(self) -> AnyModel:
         """Context entry."""
-        self._locker.lock()
-        return self.model
+        return self._locker.lock()

     def __exit__(self, *args: Any, **kwargs: Any) -> None:
         """Context exit."""
@@ -8,9 +8,10 @@ model will be cleared and (re)loaded from disk when next needed.
 """

 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from logging import Logger
-from typing import Dict, Generic, Optional, TypeVar
+from typing import Dict, Generator, Generic, Optional, Set, TypeVar

 import torch

@@ -51,44 +52,13 @@ class CacheRecord(Generic[T]):
     Elements of the cache:

     key: Unique key for each model, same as used in the models database.
-    model: Model in memory.
-    state_dict: A read-only copy of the model's state dict in RAM. It will be
-        used as a template for creating a copy in the VRAM.
+    model: Read-only copy of the model *without weights* residing in the "meta device"
     size: Size of the model
-    loaded: True if the model's state dict is currently in VRAM
-
-    Before a model is executed, the state_dict template is copied into VRAM,
-    and then injected into the model. When the model is finished, the VRAM
-    copy of the state dict is deleted, and the RAM version is reinjected
-    into the model.
-
-    The state_dict should be treated as a read-only attribute. Do not attempt
-    to patch or otherwise modify it. Instead, patch the copy of the state_dict
-    after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
-    context manager call `model_on_device()`.
     """

     key: str
-    model: T
-    device: torch.device
-    state_dict: Optional[Dict[str, torch.Tensor]]
     size: int
-    loaded: bool = False
-    _locks: int = 0
-
-    def lock(self) -> None:
-        """Lock this record."""
-        self._locks += 1
-
-    def unlock(self) -> None:
-        """Unlock this record."""
-        self._locks -= 1
-        assert self._locks >= 0
-
-    @property
-    def locked(self) -> bool:
-        """Return true if record is locked."""
-        return self._locks > 0
+    model: T


 @dataclass

@@ -115,14 +85,27 @@ class ModelCacheBase(ABC, Generic[T]):

     @property
     @abstractmethod
-    def execution_device(self) -> torch.device:
-        """Return the exection device (e.g. "cuda" for VRAM)."""
+    def execution_devices(self) -> Set[torch.device]:
+        """Return the set of available execution devices."""
         pass

-    @property
+    @contextmanager
     @abstractmethod
-    def lazy_offloading(self) -> bool:
-        """Return true if the cache is configured to lazily offload models in VRAM."""
+    def reserve_execution_device(self, timeout: int = 0) -> Generator[torch.device, None, None]:
+        """Reserve an execution device (GPU) under the current thread id."""
+        pass
+
+    @abstractmethod
+    def get_execution_device(self) -> torch.device:
+        """
+        Return an execution device that has been reserved for current thread.
+
+        Note that reservations are done using the current thread's TID.
+        It might be better to do this using the session ID, but that involves
+        too many detailed changes to model manager calls.
+
+        May generate a ValueError if no GPU has been reserved.
+        """
+        pass

     @property

@@ -131,16 +114,6 @@ class ModelCacheBase(ABC, Generic[T]):
         """Return true if the cache is configured to lazily offload models in VRAM."""
         pass

-    @abstractmethod
-    def offload_unlocked_models(self, size_required: int) -> None:
-        """Offload from VRAM any models not actively in use."""
-        pass
-
-    @abstractmethod
-    def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
-        """Move model into the indicated device."""
-        pass
-
     @property
     @abstractmethod
     def stats(self) -> Optional[CacheStats]:

@@ -202,6 +175,11 @@ class ModelCacheBase(ABC, Generic[T]):
         """Return true if the model identified by key and submodel_type is in the cache."""
         pass

+    @abstractmethod
+    def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
+        """Move a copy of the model into the indicated device and return it."""
+        pass
+
     @abstractmethod
     def cache_size(self) -> int:
         """Get the total size of the models currently cached."""
@@ -18,17 +18,19 @@ context. Use like this:

 """

+import copy
 import gc
-import math
-import time
-from contextlib import suppress
+import sys
+import threading
+from contextlib import contextmanager, suppress
 from logging import Logger
-from typing import Dict, List, Optional
+from threading import BoundedSemaphore
+from typing import Dict, Generator, List, Optional, Set

 import torch

 from invokeai.backend.model_manager import AnyModel, SubModelType
-from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot
 from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
 from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger

@@ -39,9 +41,7 @@ from .model_locker import ModelLocker
 # Maximum size of the cache, in gigs
 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
 DEFAULT_MAX_CACHE_SIZE = 6.0

-# amount of GPU memory to hold in reserve for use by generations (GB)
-DEFAULT_MAX_VRAM_CACHE_SIZE = 2.75
+DEFAULT_MAX_VRAM_CACHE_SIZE = 0.25

 # actual size of a gig
 GIG = 1073741824

@@ -57,12 +57,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self,
         max_cache_size: float = DEFAULT_MAX_CACHE_SIZE,
-        max_vram_cache_size: float = DEFAULT_MAX_VRAM_CACHE_SIZE,
-        execution_device: torch.device = torch.device("cuda"),
         storage_device: torch.device = torch.device("cpu"),
         precision: torch.dtype = torch.float16,
-        sequential_offload: bool = False,
-        lazy_offloading: bool = True,
         sha_chunksize: int = 16777216,
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
     ):

@@ -70,23 +66,19 @@ class ModelCache(ModelCacheBase[AnyModel]):
         Initialize the model RAM cache.

         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
-        :param execution_device: Torch device to load active model into [torch.device('cuda')]
         :param storage_device: Torch device to save inactive model in [torch.device('cpu')]
         :param precision: Precision for loaded models [torch.float16]
-        :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
-        :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
             operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
             snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
             behaviour.
         """
-        # allow lazy offloading only when vram cache enabled
-        self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0
         self._precision: torch.dtype = precision
         self._max_cache_size: float = max_cache_size
-        self._max_vram_cache_size: float = max_vram_cache_size
-        self._execution_device: torch.device = execution_device
         self._storage_device: torch.device = storage_device
+        self._ram_lock = threading.Lock()
         self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
         self._log_memory_usage = log_memory_usage
         self._stats: Optional[CacheStats] = None
@@ -94,25 +86,87 @@ class ModelCache(ModelCacheBase[AnyModel]):
         self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
         self._cache_stack: List[str] = []

+        # device to thread id
+        self._device_lock = threading.Lock()
+        self._execution_devices: Dict[torch.device, int] = {x: 0 for x in TorchDevice.execution_devices()}
+        self._free_execution_device = BoundedSemaphore(len(self._execution_devices))
+
+        self.logger.info(
+            f"Using rendering device(s): {', '.join(sorted([str(x) for x in self._execution_devices.keys()]))}"
+        )
+
     @property
     def logger(self) -> Logger:
         """Return the logger used by the cache."""
         return self._logger

     @property
     def lazy_offloading(self) -> bool:
         """Return true if the cache is configured to lazily offload models in VRAM."""
         return self._lazy_offloading

     @property
     def storage_device(self) -> torch.device:
         """Return the storage device (e.g. "CPU" for RAM)."""
         return self._storage_device

     @property
-    def execution_device(self) -> torch.device:
-        """Return the exection device (e.g. "cuda" for VRAM)."""
-        return self._execution_device
+    def execution_devices(self) -> Set[torch.device]:
+        """Return the set of available execution devices."""
+        devices = self._execution_devices.keys()
+        return set(devices)
+
+    def get_execution_device(self) -> torch.device:
+        """
+        Return an execution device that has been reserved for current thread.
+
+        Note that reservations are done using the current thread's TID.
+        It would be better to do this using the session ID, but that involves
+        too many detailed changes to model manager calls.
+
+        May generate a ValueError if no GPU has been reserved.
+        """
+        current_thread = threading.current_thread().ident
+        assert current_thread is not None
+        assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
+        if not assigned:
+            raise ValueError(f"No GPU has been reserved for the use of thread {current_thread}")
+        return assigned[0]
+
+    @contextmanager
+    def reserve_execution_device(self, timeout: Optional[int] = None) -> Generator[torch.device, None, None]:
+        """Reserve an execution device (e.g. GPU) for exclusive use by a generation thread.
+
+        Note that the reservation is done using the current thread's TID.
+        It would be better to do this using the session ID, but that involves
+        too many detailed changes to model manager calls.
+        """
+        device = None
+        with self._device_lock:
+            current_thread = threading.current_thread().ident
+            assert current_thread is not None
+
+            # look for a device that has already been assigned to this thread
+            assigned = [x for x, tid in self._execution_devices.items() if current_thread == tid]
+            if assigned:
+                device = assigned[0]
+
+        # no device already assigned. Get one.
+        if device is None:
+            self._free_execution_device.acquire(timeout=timeout)
+            with self._device_lock:
+                free_device = [x for x, tid in self._execution_devices.items() if tid == 0]
+                self._execution_devices[free_device[0]] = current_thread
+                device = free_device[0]
+
+        # we are outside the lock region now
+        self.logger.info(f"{current_thread} Reserved torch device {device}")
+
+        # Tell TorchDevice to use this object to get the torch device.
+        TorchDevice.set_model_cache(self)
+        try:
+            yield device
+        finally:
+            with self._device_lock:
+                self.logger.info(f"{current_thread} Released torch device {device}")
+                self._execution_devices[device] = 0
+                self._free_execution_device.release()
+                torch.cuda.empty_cache()

     @property
     def max_cache_size(self) -> float:
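Callers use the reservation as a context manager: the worker thread holds the device for the whole session, and on exit the device is returned to the pool and the CUDA cache is emptied. A hypothetical caller, with `ram_cache` being the `ModelCache` above:

```python
with ram_cache.reserve_execution_device(timeout=None) as device:
    # This thread now owns `device`; get_execution_device() returns the same
    # device anywhere down the call stack until the block exits.
    assert ram_cache.get_execution_device() == device
```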
@@ -157,16 +211,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
         submodel_type: Optional[SubModelType] = None,
     ) -> None:
         """Store model under key and optional submodel_type."""
-        key = self._make_cache_key(key, submodel_type)
-        if key in self._cached_models:
-            return
-        size = calc_model_size_by_data(model)
-        self.make_room(size)
+        with self._ram_lock:
+            key = self._make_cache_key(key, submodel_type)
+            if key in self._cached_models:
+                return
+            size = calc_model_size_by_data(model)
+            self.make_room(size)

-        state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
-        cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
-        self._cached_models[key] = cache_record
-        self._cache_stack.append(key)
+            cache_record = CacheRecord(key=key, model=model, size=size)
+            self._cached_models[key] = cache_record
+            self._cache_stack.append(key)

     def get(
         self,

@@ -184,36 +238,37 @@ class ModelCache(ModelCacheBase[AnyModel]):

         This may raise an IndexError if the model is not in the cache.
         """
-        key = self._make_cache_key(key, submodel_type)
-        if key in self._cached_models:
-            if self.stats:
-                self.stats.hits += 1
-        else:
-            if self.stats:
-                self.stats.misses += 1
-            raise IndexError(f"The model with key {key} is not in the cache.")
+        with self._ram_lock:
+            key = self._make_cache_key(key, submodel_type)
+            if key in self._cached_models:
+                if self.stats:
+                    self.stats.hits += 1
+            else:
+                if self.stats:
+                    self.stats.misses += 1
+                raise IndexError(f"The model with key {key} is not in the cache.")

-        cache_entry = self._cached_models[key]
+            cache_entry = self._cached_models[key]

-        # more stats
-        if self.stats:
-            stats_name = stats_name or key
-            self.stats.cache_size = int(self._max_cache_size * GIG)
-            self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
-            self.stats.in_cache = len(self._cached_models)
-            self.stats.loaded_model_sizes[stats_name] = max(
-                self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
-            )
+            # more stats
+            if self.stats:
+                stats_name = stats_name or key
+                self.stats.cache_size = int(self._max_cache_size * GIG)
+                self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
+                self.stats.in_cache = len(self._cached_models)
+                self.stats.loaded_model_sizes[stats_name] = max(
+                    self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
+                )

-        # this moves the entry to the top (right end) of the stack
-        with suppress(Exception):
-            self._cache_stack.remove(key)
-        self._cache_stack.append(key)
-        return ModelLocker(
-            cache=self,
-            cache_entry=cache_entry,
-        )
+            # this moves the entry to the top (right end) of the stack
+            with suppress(Exception):
+                self._cache_stack.remove(key)
+            self._cache_stack.append(key)
+            return ModelLocker(
+                cache=self,
+                cache_entry=cache_entry,
+            )

     def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
         if self._log_memory_usage:
             return MemorySnapshot.capture()
@@ -225,127 +280,34 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
else:
|
||||
return model_key
|
||||
|
||||
def offload_unlocked_models(self, size_required: int) -> None:
|
||||
"""Move any unused models from VRAM."""
|
||||
reserved = self._max_vram_cache_size * GIG
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB")
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
if not cache_entry.loaded:
|
||||
continue
|
||||
if not cache_entry.locked:
|
||||
self.move_model_to_device(cache_entry, self.storage_device)
|
||||
cache_entry.loaded = False
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
)
|
||||
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
"""Move model into the indicated device.
|
||||
def model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> AnyModel:
    """Move a copy of the model into the indicated device and return it.

    :param cache_entry: The CacheRecord for the model
    :param target_device: The torch.device to move the model into

    May raise a torch.cuda.OutOfMemoryError
    """
    self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
    source_device = cache_entry.device
    with self._ram_lock:
        self.logger.debug(f"Called to move {cache_entry.key} ({type(cache_entry.model)=}) to {target_device}")

        # Note: We compare device types only so that 'cuda' == 'cuda:0'.
        # This would need to be revised to support multi-GPU.
        if torch.device(source_device).type == torch.device(target_device).type:
            return cache_entry.model

        # Some models don't have a `to` method, in which case they run in RAM/CPU.
        if not hasattr(cache_entry.model, "to"):
            return cache_entry.model

        # This roundabout method for moving the model around is done to avoid
        # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
        # When moving to VRAM, we copy (not move) each element of the state dict from
        # RAM to a new state dict in VRAM, and then inject it into the model.
        # This operation is slightly faster than running `to()` on the whole model.
        #
        # When the model needs to be removed from VRAM we simply delete the copy
        # of the state dict in VRAM, and reinject the state dict that is cached
        # in RAM into the model. So this operation is very fast.
        start_model_to_time = time.time()
        snapshot_before = self._capture_memory_snapshot()

        try:
            if cache_entry.state_dict is not None:
                assert hasattr(cache_entry.model, "load_state_dict")
                if target_device == self.storage_device:
                    cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
                else:
                    new_dict: Dict[str, torch.Tensor] = {}
                    for k, v in cache_entry.state_dict.items():
                        new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
                    cache_entry.model.load_state_dict(new_dict, assign=True)
            cache_entry.model.to(target_device, non_blocking=True)
            cache_entry.device = target_device
        except Exception as e:  # blow away cache entry
            self._delete_cache_entry(cache_entry)
            raise e

        snapshot_after = self._capture_memory_snapshot()
        end_model_to_time = time.time()
        self.logger.debug(
            f"Moved model '{cache_entry.key}' from {source_device} to"
            f" {target_device} in {(end_model_to_time - start_model_to_time):.2f}s. "
            f"Estimated model size: {(cache_entry.size / GIG):.3f} GB. "
            f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
        )

        if (
            snapshot_before is not None
            and snapshot_after is not None
            and snapshot_before.vram is not None
            and snapshot_after.vram is not None
        ):
            vram_change = abs(snapshot_before.vram - snapshot_after.vram)

            # If the estimated model size does not match the change in VRAM, log a warning.
            if not math.isclose(
                vram_change,
                cache_entry.size,
                rel_tol=0.1,
                abs_tol=10 * MB,
            ):
                self.logger.debug(
                    f"Moving model '{cache_entry.key}' from {source_device} to"
                    f" {target_device} caused an unexpected change in VRAM usage. The model's"
                    " estimated size may be incorrect. Estimated model size:"
                    f" {(cache_entry.size / GIG):.3f} GB.\n"
                    f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
                )
        # Some models don't have a state dictionary, in which case the
        # stored model will still reside in CPU
        if hasattr(cache_entry.model, "to"):
            model_in_gpu = copy.deepcopy(cache_entry.model)
            assert hasattr(model_in_gpu, "to")
            model_in_gpu.to(target_device)
            return model_in_gpu
        else:
            return cache_entry.model  # what happens in CPU stays in CPU
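The copy-don't-move strategy described in the comments above is easier to see in isolation. A minimal sketch, assuming a plain torch.nn.Module and PyTorch 2.1+ for load_state_dict(assign=True); the helper name is illustrative and not part of InvokeAI:

    import torch

    def send_state_dict_to_device(model: torch.nn.Module, cpu_state_dict: dict, device: torch.device) -> None:
        # Copy (not move) every tensor onto the target device; the CPU originals survive.
        on_device = {k: v.to(device, copy=True) for k, v in cpu_state_dict.items()}
        model.load_state_dict(on_device, assign=True)

    model = torch.nn.Linear(8, 8)
    cpu_copy = {k: v.clone() for k, v in model.state_dict().items()}
    if torch.cuda.is_available():
        send_state_dict_to_device(model, cpu_copy, torch.device("cuda:0"))
        # Moving back is cheap: re-assign the cached CPU tensors and let the VRAM dict be freed.
        model.load_state_dict(cpu_copy, assign=True)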
def print_cuda_stats(self) -> None:
    """Log CUDA diagnostics."""
    vram = "%4.2fG" % (torch.cuda.memory_allocated() / GIG)
    ram = "%4.2fG" % (self.cache_size() / GIG)

    in_ram_models = 0
    in_vram_models = 0
    locked_in_vram_models = 0
    for cache_record in self._cached_models.values():
        if hasattr(cache_record.model, "device"):
            if cache_record.model.device == self.storage_device:
                in_ram_models += 1
            else:
                in_vram_models += 1
                if cache_record.locked:
                    locked_in_vram_models += 1

    self.logger.debug(
        f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) ="
        f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})"
    )
    in_ram_models = len(self._cached_models)
    self.logger.debug(f"Current VRAM/RAM usage for {in_ram_models} models: {vram}/{ram}")

def make_room(self, size: int) -> None:
    """Make enough room in the cache to accommodate a new model of indicated size."""
@@ -368,12 +330,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
    while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
        model_key = self._cache_stack[pos]
        cache_entry = self._cached_models[model_key]
        device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
        self.logger.debug(
            f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
        )

        if not cache_entry.locked:
            refs = sys.getrefcount(cache_entry.model)

            # Expected refs:
            # 1 from cache_entry
            # 1 from getrefcount function
            # 1 from onnx runtime object
            if refs <= (3 if "onnx" in model_key else 2):
                self.logger.debug(
                    f"Removing {model_key} from RAM cache to free at least {(size / GIG):.2f} GB (-{(cache_entry.size / GIG):.2f} GB)"
                )
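The refcount comparison above is the entire eviction gate: a cached model is only safe to drop when nobody outside the cache still holds it. A toy illustration of the expected counts in plain CPython (no InvokeAI types involved):

    import sys

    class CacheEntry:
        def __init__(self, model: object) -> None:
            self.model = model

    entry = CacheEntry(object())
    # Expected: one reference from the cache entry, plus the temporary reference
    # that getrefcount() itself holds while inspecting the object, i.e. 2.
    # ONNX-wrapped models carry one extra internal reference, hence "<= 3" above.
    refs = sys.getrefcount(entry.model)
    print(refs, "evictable:", refs <= 2)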
@@ -400,10 +364,26 @@ class ModelCache(ModelCacheBase[AnyModel]):
    if self.stats:
        self.stats.cleared = models_cleared
    gc.collect()

    TorchDevice.empty_cache()
    self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")

def _check_free_vram(self, target_device: torch.device, needed_size: int) -> None:
    if target_device.type != "cuda":
        return
    vram_device = (  # mem_get_info() needs an indexed device
        target_device if target_device.index is not None else torch.device(str(target_device), index=0)
    )
    free_mem, _ = torch.cuda.mem_get_info(torch.device(vram_device))
    if needed_size > free_mem:
        raise torch.cuda.OutOfMemoryError

def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
    self._cache_stack.remove(cache_entry.key)
    del self._cached_models[cache_entry.key]
    try:
        self._cache_stack.remove(cache_entry.key)
        del self._cached_models[cache_entry.key]
    except ValueError:
        pass

@staticmethod
def _device_name(device: torch.device) -> str:
    return f"{device.type}:{device.index}"

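The pre-flight VRAM check above hinges on one detail: torch.cuda.mem_get_info() requires an indexed device, so a bare "cuda" has to be pinned to an index first. A minimal sketch of the same check as a standalone function (the function name is an assumption, not InvokeAI API):

    import torch

    def has_free_vram(device: torch.device, needed_bytes: int) -> bool:
        if device.type != "cuda":
            return True  # only CUDA devices are checked
        # mem_get_info() needs an index; a bare "cuda" defaults to device 0 here.
        indexed = device if device.index is not None else torch.device("cuda", 0)
        free_bytes, _total = torch.cuda.mem_get_info(indexed)
        return needed_bytes <= free_bytes

    if torch.cuda.is_available():
        print(has_free_vram(torch.device("cuda"), 2 * 1024**3))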
@@ -10,6 +10,8 @@ from invokeai.backend.model_manager import AnyModel

from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase

MAX_GPU_WAIT = 600  # wait up to 10 minutes for a GPU to become free


class ModelLocker(ModelLockerBase):
    """Internal class that mediates movement in and out of GPU."""
@@ -29,33 +31,29 @@ class ModelLocker(ModelLockerBase):
        """Return the model without moving it around."""
        return self._cache_entry.model

    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
        """Return the state dict (if any) for the cached model."""
        return self._cache_entry.state_dict

    def lock(self) -> AnyModel:
        """Move the model into the execution device (GPU) and lock it."""
        self._cache_entry.lock()
        try:
            if self._cache.lazy_offloading:
                self._cache.offload_unlocked_models(self._cache_entry.size)
            self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
            self._cache_entry.loaded = True
            self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
            device = self._cache.get_execution_device()
            model_on_device = self._cache.model_to_device(self._cache_entry, device)
            self._cache.logger.debug(f"Moved {self._cache_entry.key} to {device}")
            self._cache.print_cuda_stats()
        except torch.cuda.OutOfMemoryError:
            self._cache.logger.warning("Insufficient GPU memory to load model. Aborting")
            self._cache_entry.unlock()
            raise
        except Exception:
            self._cache_entry.unlock()
            raise

        return self.model
        return model_on_device

    # It is no longer necessary to move the model out of VRAM
    # because it will be removed when it goes out of scope
    # in the caller's context
    def unlock(self) -> None:
        """Call upon exit from context."""
        self._cache_entry.unlock()
        if not self._cache.lazy_offloading:
            self._cache.offload_unlocked_models(0)
            self._cache.print_cuda_stats()
        self._cache.print_cuda_stats()

    # This is no longer in use in MGPU.
    def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
        """Return the state dict (if any) for the cached model."""
        return None

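The lock()/unlock() pair above is an acquire-on-enter, release-on-exit protocol: lock() yields a per-GPU copy of the model and unlock() lets it fall out of scope. A self-contained sketch of that calling pattern, with a stand-in class (nothing here is InvokeAI's real API):

    from contextlib import contextmanager

    class DummyLocker:
        """Stand-in for the ModelLocker pattern above, not the real class."""
        def __init__(self, model: object) -> None:
            self._model = model
        def lock(self) -> object:
            return self._model  # the real code moves a copy onto an execution device here
        def unlock(self) -> None:
            pass                # the real code lets the per-GPU copy leave scope here

    @contextmanager
    def locked(locker: DummyLocker):
        model = locker.lock()
        try:
            yield model
        finally:
            locker.unlock()

    with locked(DummyLocker(model="fake-model")) as m:
        print(m)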
@@ -4,6 +4,7 @@
from __future__ import annotations

import pickle
import threading
from contextlib import contextmanager
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union

@@ -34,6 +35,8 @@ with LoRAHelper.apply_lora_unet(unet, loras):

# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
    _thread_lock = threading.Lock()

    @staticmethod
    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
        assert "." not in lora_key
@@ -106,7 +109,7 @@ class ModelPatcher:
        """
        original_weights = {}
        try:
            with torch.no_grad():
            with torch.no_grad(), cls._thread_lock:
                for lora, lora_weight in loras:
                    # assert lora.device.type == "cpu"
                    for layer_key, layer in lora.layers.items():
@@ -129,9 +132,7 @@ class ModelPatcher:
                        dtype = module.weight.dtype

                        if module_key not in original_weights:
                            if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
                                original_weights[module_key] = model_state_dict[module_key + ".weight"]
                            else:
                            if model_state_dict is None:  # no CPU copy of the state dict was provided
                                original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)

                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
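The new class-level lock serializes weight patching so two threads cannot stack LoRA deltas on the same module at once, while the saved CPU copy makes the patch reversible. A minimal sketch of that save-patch-restore discipline, assuming an ordinary Linear layer (identifiers here are illustrative):

    import threading
    import torch

    _patch_lock = threading.Lock()

    def scaled_add_(module: torch.nn.Linear, delta: torch.Tensor, weight: float, originals: dict) -> None:
        # Save the unpatched weight once (on CPU) so it can be restored afterwards.
        with torch.no_grad(), _patch_lock:
            if "weight" not in originals:
                originals["weight"] = module.weight.detach().to(device="cpu", copy=True)
            module.weight += delta * weight

    layer = torch.nn.Linear(2, 2)
    saved: dict = {}
    scaled_add_(layer, torch.ones_like(layer.weight), 0.5, saved)
    with torch.no_grad():
        layer.weight.copy_(saved["weight"])  # restore the original weight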
@@ -32,8 +32,11 @@ class SDXLConditioningInfo(BasicConditioningInfo):

    def to(self, device, dtype=None):
        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
        assert self.pooled_embeds.device == device
        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
        return super().to(device=device, dtype=dtype)
        result = super().to(device=device, dtype=dtype)
        assert self.embeds.device == device
        return result


@dataclass

@@ -1,6 +1,7 @@
from __future__ import annotations

import math
import threading
from typing import Any, Callable, Optional, Union

import torch
@@ -293,24 +294,31 @@ class InvokeAIDiffuserComponent:
            cross_attention_kwargs["regional_ip_data"] = regional_ip_data

        added_cond_kwargs = None
        if conditioning_data.is_sdxl():
            added_cond_kwargs = {
                "text_embeds": torch.cat(
                    [
                        # TODO: how to pad? just by zeros? or even truncate?
                        conditioning_data.uncond_text.pooled_embeds,
                        conditioning_data.cond_text.pooled_embeds,
                    ],
                    dim=0,
                ),
                "time_ids": torch.cat(
                    [
                        conditioning_data.uncond_text.add_time_ids,
                        conditioning_data.cond_text.add_time_ids,
                    ],
                    dim=0,
                ),
            }
        try:
            if conditioning_data.is_sdxl():
                # tid = threading.current_thread().ident
                # print(f'DEBUG {tid} {conditioning_data.uncond_text.pooled_embeds.device=} {conditioning_data.cond_text.pooled_embeds.device=}', flush=True),
                added_cond_kwargs = {
                    "text_embeds": torch.cat(
                        [
                            # TODO: how to pad? just by zeros? or even truncate?
                            conditioning_data.uncond_text.pooled_embeds,
                            conditioning_data.cond_text.pooled_embeds,
                        ],
                        dim=0,
                    ),
                    "time_ids": torch.cat(
                        [
                            conditioning_data.uncond_text.add_time_ids,
                            conditioning_data.cond_text.add_time_ids,
                        ],
                        dim=0,
                    ),
                }
        except Exception as e:
            tid = threading.current_thread().ident
            print(f"DEBUG: {tid} {str(e)}")
            raise e

        if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
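The torch.cat calls above batch the unconditioned and conditioned SDXL embeddings along dim 0 so one UNet call serves both halves of classifier-free guidance. A minimal sketch with dummy tensors standing in for the real pooled embeddings and time ids:

    import torch

    # Dummy stand-ins for the uncond/cond pooled embeddings and time ids.
    uncond_pooled = torch.zeros(1, 1280)
    cond_pooled = torch.ones(1, 1280)
    uncond_time_ids = torch.zeros(1, 6)
    cond_time_ids = torch.ones(1, 6)

    added_cond_kwargs = {
        # Batch dimension 0 carries [uncond, cond] in that order.
        "text_embeds": torch.cat([uncond_pooled, cond_pooled], dim=0),
        "time_ids": torch.cat([uncond_time_ids, cond_time_ids], dim=0),
    }
    print(added_cond_kwargs["text_embeds"].shape)  # torch.Size([2, 1280])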
@@ -1,10 +1,16 @@
from typing import Dict, Literal, Optional, Union
"""Torch Device class provides torch device selection services."""

from typing import TYPE_CHECKING, Dict, Literal, Optional, Set, Union

import torch
from deprecated import deprecated

from invokeai.app.services.config.config_default import get_config

if TYPE_CHECKING:
    from invokeai.backend.model_manager.config import AnyModel
    from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase

# legacy APIs
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
CPU_DEVICE = torch.device("cpu")
@@ -42,9 +48,23 @@ PRECISION_TO_NAME: Dict[torch.dtype, TorchPrecisionNames] = {v: k for k, v in NA
class TorchDevice:
    """Abstraction layer for torch devices."""

    _model_cache: Optional["ModelCacheBase[AnyModel]"] = None

    @classmethod
    def set_model_cache(cls, cache: "ModelCacheBase[AnyModel]"):
        """Set the current model cache."""
        cls._model_cache = cache

    @classmethod
    def choose_torch_device(cls) -> torch.device:
        """Return the torch.device to use for accelerated inference."""
        if cls._model_cache:
            return cls._model_cache.get_execution_device()
        else:
            return cls._choose_device()

    @classmethod
    def _choose_device(cls) -> torch.device:
        app_config = get_config()
        if app_config.device != "auto":
            device = torch.device(app_config.device)
@@ -56,11 +76,19 @@ class TorchDevice:
            device = CPU_DEVICE
        return cls.normalize(device)

    @classmethod
    def execution_devices(cls) -> Set[torch.device]:
        """Return a list of torch.devices that can be used for accelerated inference."""
        app_config = get_config()
        if app_config.devices is None:
            return cls._lookup_execution_devices()
        return {torch.device(x) for x in app_config.devices}

    @classmethod
    def choose_torch_dtype(cls, device: Optional[torch.device] = None) -> torch.dtype:
        """Return the precision to use for accelerated inference."""
        device = device or cls.choose_torch_device()
        config = get_config()
        device = device or cls._choose_device()
        if device.type == "cuda" and torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(device)
            if "GeForce GTX 1660" in device_name or "GeForce GTX 1650" in device_name:
@@ -108,3 +136,13 @@ class TorchDevice:
    @classmethod
    def _to_dtype(cls, precision_name: TorchPrecisionNames) -> torch.dtype:
        return NAME_TO_PRECISION[precision_name]

    @classmethod
    def _lookup_execution_devices(cls) -> Set[torch.device]:
        if torch.cuda.is_available():
            devices = {torch.device(f"cuda:{x}") for x in range(0, torch.cuda.device_count())}
        elif torch.backends.mps.is_available():
            devices = {torch.device("mps")}
        else:
            devices = {torch.device("cpu")}
        return devices

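The shape of the change above: TorchDevice now delegates device choice to the model cache when one is registered, and only falls back to static detection otherwise. A self-contained toy version of that dispatch (this is a sketch of the pattern, not the real TorchDevice class):

    import torch

    class DeviceChooser:
        """Toy version of the TorchDevice dispatch above, not the real class."""
        _cache = None  # stand-in for the registered model cache

        @classmethod
        def set_model_cache(cls, cache) -> None:
            cls._cache = cache

        @classmethod
        def choose(cls) -> torch.device:
            if cls._cache is not None:
                return cls._cache.get_execution_device()  # the cache decides in MGPU mode
            # Static fallback mirrors _lookup_execution_devices(): CUDA, then MPS, then CPU.
            if torch.cuda.is_available():
                return torch.device("cuda:0")
            if torch.backends.mps.is_available():
                return torch.device("mps")
            return torch.device("cpu")

    print(DeviceChooser.choose())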
@@ -37,11 +37,7 @@
    "selectBoard": "Select a Board",
    "topMessage": "This board contains images used in the following features:",
    "uncategorized": "Uncategorized",
    "downloadBoard": "Download Board",
    "imagesWithCount_one": "{{count}} image",
    "imagesWithCount_other": "{{count}} images",
    "assetsWithCount_one": "{{count}} asset",
    "assetsWithCount_other": "{{count}} assets"
    "downloadBoard": "Download Board"
  },
  "accordions": {
    "generation": {
@@ -384,11 +380,7 @@
    "problemDeletingImagesDesc": "One or more images could not be deleted",
    "viewerImage": "Viewer Image",
    "compareImage": "Compare Image",
    "noActiveSearch": "No active search",
    "openInViewer": "Open in Viewer",
    "searchingBy": "Searching by",
    "selectAllOnPage": "Select All On Page",
    "selectAllOnBoard": "Select All On Board",
    "selectForCompare": "Select for Compare",
    "selectAnImageToCompare": "Select an Image to Compare",
    "slider": "Slider",
@@ -2,7 +2,8 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images';
import { getListImagesUrl } from 'services/api/util';
import type { ImageCache } from 'services/api/types';
import { getListImagesUrl, imagesSelectors } from 'services/api/util';

export const addFirstListImagesListener = (startAppListening: AppStartListening) => {
  startAppListening({
@@ -17,10 +18,13 @@ export const addFirstListImagesListener = (startAppListening: AppStartListening)
      cancelActiveListeners();
      unsubscribe();

      const data = action.payload;
      // TODO: figure out how to type the predicate
      const data = action.payload as ImageCache;

      if (data.items.length > 0) {
        dispatch(imageSelected(data.items[0] ?? null));
      if (data.ids.length > 0) {
        // Select the first image
        const firstImage = imagesSelectors.selectAll(data)[0];
        dispatch(imageSelected(firstImage ?? null));
      }
    },
  });

@@ -1,13 +1,9 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
  boardIdSelected,
  galleryViewChanged,
  imageSelected,
  selectionChanged,
} from 'features/gallery/store/gallerySlice';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesSelectors } from 'services/api/util';

export const addBoardIdSelectedListener = (startAppListening: AppStartListening) => {
  startAppListening({
@@ -18,9 +14,14 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)

      const state = getState();

      const queryArgs = selectListImagesQueryArgs(state);
      const board_id = boardIdSelected.match(action) ? action.payload.boardId : state.gallery.selectedBoardId;

      dispatch(selectionChanged([]));
      const galleryView = galleryViewChanged.match(action) ? action.payload : state.gallery.galleryView;

      // when a board is selected, we need to wait until the board has loaded *some* images, then select the first one
      const categories = galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES;

      const queryArgs = { board_id: board_id ?? 'none', categories };

      // wait until the board has some images - maybe it already has some from a previous fetch
      // must use getState() to ensure we do not have stale state
@@ -34,12 +35,11 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
      const { data: boardImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(getState());

      if (boardImagesData && boardIdSelected.match(action) && action.payload.selectedImageName) {
        const selectedImage = boardImagesData.items.find(
          (item) => item.image_name === action.payload.selectedImageName
        );
        const selectedImage = imagesSelectors.selectById(boardImagesData, action.payload.selectedImageName);
        dispatch(imageSelected(selectedImage || null));
      } else if (boardImagesData) {
        dispatch(imageSelected(boardImagesData.items[0] || null));
        const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
        dispatch(imageSelected(firstImage || null));
      } else {
        // board has no images - deselect
        dispatch(imageSelected(null));

@@ -4,6 +4,7 @@ import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelecto
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';

export const galleryImageClicked = createAction<{
  imageDTO: ImageDTO;
@@ -31,14 +32,14 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
      const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
      const state = getState();
      const queryArgs = selectListImagesQueryArgs(state);
      const queryResult = imagesApi.endpoints.listImages.select(queryArgs)(state);
      const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);

      if (!queryResult.data) {
      if (!listImagesData) {
        // Should never happen if we have clicked a gallery image
        return;
      }

      const imageDTOs = queryResult.data.items;
      const imageDTOs = imagesSelectors.selectAll(listImagesData);
      const selection = state.gallery.selection;

      if (altKey) {

@@ -22,10 +22,11 @@ import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { isImageFieldInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { forEach } from 'lodash-es';
import { clamp, forEach } from 'lodash-es';
import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';

const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
  state.nodes.present.nodes.forEach((node) => {
@@ -117,7 +118,32 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
      }

      dispatch(isModalOpenChanged(false));

      const state = getState();
      const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;

      if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
        const { image_name } = imageDTO;

        const baseQueryArgs = selectListImagesQueryArgs(state);
        const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);

        const cachedImageDTOs = data ? imagesSelectors.selectAll(data) : [];

        const deletedImageIndex = cachedImageDTOs.findIndex((i) => i.image_name === image_name);

        const filteredImageDTOs = cachedImageDTOs.filter((i) => i.image_name !== image_name);

        const newSelectedImageIndex = clamp(deletedImageIndex, 0, filteredImageDTOs.length - 1);

        const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex];

        if (newSelectedImageDTO) {
          dispatch(imageSelected(newSelectedImageDTO));
        } else {
          dispatch(imageSelected(null));
        }
      }

      // We need to reset the features where the image is in use - none of these work if their image(s) don't exist
      if (imageUsage.isCanvasImage) {
@@ -142,20 +168,6 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
      if (wasImageDeleted) {
        dispatch(api.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
      }

      const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;

      if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
        const baseQueryArgs = selectListImagesQueryArgs(state);
        const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);

        if (data && data.items) {
          const newlySelectedImage = data?.items.find((img) => img.image_name !== imageDTO?.image_name);
          dispatch(imageSelected(newlySelectedImage || null));
        } else {
          dispatch(imageSelected(null));
        }
      }
    },
  });

@@ -176,8 +188,10 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
      const queryArgs = selectListImagesQueryArgs(state);
      const { data } = imagesApi.endpoints.listImages.select(queryArgs)(state);

      if (data && data.items[0]) {
        dispatch(imageSelected(data.items[0]));
      const newSelectedImageDTO = data ? imagesSelectors.selectAll(data)[0] : undefined;

      if (newSelectedImageDTO) {
        dispatch(imageSelected(newSelectedImageDTO));
      } else {
        dispatch(imageSelected(null));
      }

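The selection logic added above is a small algorithm worth stating on its own: after a deletion, pick the item that now occupies the deleted item's index, clamped into the shortened list, so deleting the last item selects the new last item. A language-agnostic sketch of the same idea in Python (names are illustrative):

    def next_selection(items: list[str], deleted: str) -> str | None:
        # Index of the deleted item in the pre-deletion ordering.
        idx = items.index(deleted) if deleted in items else 0
        remaining = [i for i in items if i != deleted]
        if not remaining:
            return None
        # Clamp so deleting the last item selects the new last item.
        return remaining[min(max(idx, 0), len(remaining) - 1)]

    print(next_selection(["a", "b", "c"], "c"))  # "b"
    print(next_selection(["a"], "a"))            # None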
@@ -15,12 +15,7 @@
} from 'features/controlLayers/store/controlLayersSlice';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import {
  imageSelected,
  imageToCompareChanged,
  isImageViewerOpenChanged,
  selectionChanged,
} from 'features/gallery/store/gallerySlice';
import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
@@ -221,7 +216,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
          board_id: boardId,
        })
      );
      dispatch(selectionChanged([]));
      return;
    }

@@ -239,7 +233,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
          imageDTO,
        })
      );
      dispatch(selectionChanged([]));
      return;
    }

@@ -255,7 +248,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
          board_id: boardId,
        })
      );
      dispatch(selectionChanged([]));
      return;
    }

@@ -269,7 +261,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
          imageDTOs,
        })
      );
      dispatch(selectionChanged([]));
      return;
    }
  },

@@ -8,14 +8,14 @@
  galleryViewChanged,
  imageSelected,
  isImageViewerOpenChanged,
  offsetChanged,
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { imagesAdapter } from 'services/api/util';
import { socketInvocationComplete } from 'services/events/actions';

// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
@@ -52,6 +52,24 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
      }

      if (!imageDTO.is_intermediate) {
        /**
         * Cache updates for when an image result is received
         * - add it to the no_board/images
         */

        dispatch(
          imagesApi.util.updateQueryData(
            'listImages',
            {
              board_id: imageDTO.board_id ?? 'none',
              categories: IMAGE_CATEGORIES,
            },
            (draft) => {
              imagesAdapter.addOne(draft, imageDTO);
            }
          )
        );

        // update the total images for the board
        dispatch(
          boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
@@ -60,18 +78,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
          })
        );

        dispatch(
          imagesApi.util.invalidateTags([
            { type: 'Board', id: imageDTO.board_id ?? 'none' },
            {
              type: 'ImageList',
              id: getListImagesUrl({
                board_id: imageDTO.board_id ?? 'none',
                categories: getCategories(imageDTO),
              }),
            },
          ])
        );
        dispatch(imagesApi.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));

        const { shouldAutoSwitch } = gallery;

@@ -91,8 +98,6 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
          );
        }

        dispatch(offsetChanged(0));

        if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
          dispatch(
            boardIdSelected({

@@ -1,37 +1,47 @@
import type { IconButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import type { MouseEvent } from 'react';
import { memo } from 'react';
import type { MouseEvent, ReactElement } from 'react';
import { memo, useMemo } from 'react';

const sx: SystemStyleObject = {
  minW: 0,
  svg: {
    transitionProperty: 'common',
    transitionDuration: 'normal',
    fill: 'base.100',
    _hover: { fill: 'base.50' },
    filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
  },
};

type Props = Omit<IconButtonProps, 'aria-label' | 'onClick' | 'tooltip'> & {
type Props = {
  onClick: (event: MouseEvent<HTMLButtonElement>) => void;
  tooltip: string;
  icon?: ReactElement;
  styleOverrides?: SystemStyleObject;
};

const IAIDndImageIcon = (props: Props) => {
  const { onClick, tooltip, icon, ...rest } = props;
  const { onClick, tooltip, icon, styleOverrides } = props;

  const sx = useMemo(
    () => ({
      position: 'absolute',
      top: 1,
      insetInlineEnd: 1,
      p: 0,
      minW: 0,
      svg: {
        transitionProperty: 'common',
        transitionDuration: 'normal',
        fill: 'base.100',
        _hover: { fill: 'base.50' },
        filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
      },
      ...styleOverrides,
    }),
    [styleOverrides]
  );

  return (
    <IconButton
      onClick={onClick}
      aria-label={tooltip}
      tooltip={tooltip}
      icon={icon}
      size="sm"
      variant="link"
      sx={sx}
      data-testid={tooltip}
      {...rest}
    />
  );
};

invokeai/frontend/web/src/common/util/dateComparator.ts (new file, 16 lines)
@@ -0,0 +1,16 @@
/**
 * Comparator function for sorting dates in ascending order
 */
export const dateComparator = (a: string, b: string) => {
  const dateA = new Date(a);
  const dateB = new Date(b);

  // sort in ascending order
  if (dateA > dateB) {
    return 1;
  }
  if (dateA < dateB) {
    return -1;
  }
  return 0;
};
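The comparator parses both strings to Date objects on every call; the same ascending order can also be expressed with a sort key rather than a pairwise comparator. A brief Python sketch of the equivalent, assuming ISO-8601 timestamps:

    from datetime import datetime

    def date_key(value: str) -> datetime:
        # Equivalent ascending sort using a key instead of a comparator.
        return datetime.fromisoformat(value)

    stamps = ["2024-05-02T10:00:00", "2024-05-01T09:30:00"]
    print(sorted(stamps, key=date_key))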
@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Spinner } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
@@ -184,7 +185,7 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
        />
      </Box>

      <Flex flexDir="column" top={1} insetInlineEnd={1}>
      <>
        <IAIDndImageIcon
          onClick={handleResetControlImage}
          icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
@@ -194,13 +195,15 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
          onClick={handleSaveControlImage}
          icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
          tooltip={t('controlnet.saveControlImage')}
          styleOverrides={saveControlImageStyleOverrides}
        />
        <IAIDndImageIcon
          onClick={handleSetControlImageToDimensions}
          icon={controlImage ? <PiRulerBold size={16} /> : undefined}
          tooltip={t('controlnet.setControlImageDimensions')}
          styleOverrides={setControlImageDimensionsStyleOverrides}
        />
      </Flex>
      </>

      {pendingControlImages.includes(id) && (
        <Flex
@@ -223,3 +226,6 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
};

export default memo(ControlAdapterImagePreview);

const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Spinner, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -202,13 +203,13 @@ export const ControlAdapterImagePreview = memo(
          onClick={handleSaveControlImage}
          icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
          tooltip={t('controlnet.saveControlImage')}
          mt={6}
          styleOverrides={saveControlImageStyleOverrides}
        />
        <IAIDndImageIcon
          onClick={handleSetControlImageToDimensions}
          icon={controlImage ? <PiRulerBold size={16} /> : undefined}
          tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
          mt={12}
          styleOverrides={setControlImageDimensionsStyleOverrides}
        />
      </>

@@ -234,3 +235,6 @@ export const ControlAdapterImagePreview = memo(
);

ControlAdapterImagePreview.displayName = 'ControlAdapterImagePreview';

const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -99,7 +100,7 @@ export const IPAdapterImagePreview = memo(
          onClick={handleSetControlImageToDimensions}
          icon={controlImage ? <PiRulerBold size={16} /> : undefined}
          tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
          mt={6}
          styleOverrides={setControlImageDimensionsStyleOverrides}
        />
      </>
    </Flex>
@@ -108,3 +109,5 @@ export const IPAdapterImagePreview = memo(
);

IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';

const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 6 };

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -96,7 +97,7 @@ export const InitialImagePreview = memo(({ image, onChangeImage, droppableData,
          onClick={onUseSize}
          icon={imageDTO ? <PiRulerBold size={16} /> : undefined}
          tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
          mt={6}
          styleOverrides={useSizeStyleOverrides}
        />
      </>
    </Flex>
@@ -104,3 +105,5 @@ export const InitialImagePreview = memo(({ image, onChangeImage, droppableData,
});

InitialImagePreview.displayName = 'InitialImagePreview';

const useSizeStyleOverrides: SystemStyleObject = { mt: 6 };

@@ -1,21 +0,0 @@
import { useTranslation } from 'react-i18next';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';

type Props = {
  board_id: string;
};

export const BoardTotalsTooltip = ({ board_id }: Props) => {
  const { t } = useTranslation();
  const { imagesTotal } = useGetBoardImagesTotalQuery(board_id, {
    selectFromResult: ({ data }) => {
      return { imagesTotal: data?.total ?? 0 };
    },
  });
  const { assetsTotal } = useGetBoardAssetsTotalQuery(board_id, {
    selectFromResult: ({ data }) => {
      return { assetsTotal: data?.total ?? 0 };
    },
  });
  return `${t('boards.imagesWithCount', { count: imagesTotal })}, ${t('boards.assetsWithCount', { count: assetsTotal })}`;
};
@@ -8,12 +8,15 @@ import SelectionOverlay from 'common/components/SelectionOverlay';
import type { AddToBoardDropData } from 'features/dnd/types';
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
import { autoAddBoardIdChanged, boardIdSelected, selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImagesSquare } from 'react-icons/pi';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import {
  useGetBoardAssetsTotalQuery,
  useGetBoardImagesTotalQuery,
  useUpdateBoardMutation,
} from 'services/api/endpoints/boards';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { BoardDTO } from 'services/api/types';

@@ -48,6 +51,17 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
    setIsHovered(false);
  }, []);

  const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id);
  const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id);
  const tooltip = useMemo(() => {
    if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
      return undefined;
    }
    return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
      assetsTotal.total
    } asset${assetsTotal.total === 1 ? '' : 's'}`;
  }, [assetsTotal, imagesTotal]);

  const { currentData: coverImage } = useGetImageDTOQuery(board.cover_image_name ?? skipToken);

  const { board_name, board_id } = board;
@@ -118,7 +132,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
    >
      <BoardContextMenu board={board} board_id={board_id} setBoardToDelete={setBoardToDelete}>
        {(ref) => (
          <Tooltip label={<BoardTotalsTooltip board_id={board.board_id} />} openDelay={1000}>
          <Tooltip label={tooltip} openDelay={1000}>
            <Flex
              ref={ref}
              onClick={handleSelectBoard}

@@ -5,11 +5,11 @@ import SelectionOverlay from 'common/components/SelectionOverlay';
import type { RemoveFromBoardDropData } from 'features/dnd/types';
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice';
import InvokeLogoSVG from 'public/assets/images/invoke-symbol-wht-lrg.svg';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
import { useBoardName } from 'services/api/hooks/useBoardName';

interface Props {
@@ -29,6 +29,17 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
  }, [dispatch, autoAssignBoardOnClick]);
  const [isHovered, setIsHovered] = useState(false);

  const { data: imagesTotal } = useGetBoardImagesTotalQuery('none');
  const { data: assetsTotal } = useGetBoardAssetsTotalQuery('none');
  const tooltip = useMemo(() => {
    if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
      return undefined;
    }
    return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
      assetsTotal.total
    } asset${assetsTotal.total === 1 ? '' : 's'}`;
  }, [assetsTotal, imagesTotal]);

  const handleMouseOver = useCallback(() => {
    setIsHovered(true);
  }, []);
@@ -60,7 +71,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
    >
      <BoardContextMenu board_id="none">
        {(ref) => (
          <Tooltip label={<BoardTotalsTooltip board_id="none" />} openDelay={1000}>
          <Tooltip label={tooltip} openDelay={1000}>
            <Flex
              ref={ref}
              onClick={handleSelectBoard}

@@ -1,55 +0,0 @@
import { Flex, IconButton, Spacer, Tag, TagCloseButton, TagLabel, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { selectionChanged } from 'features/gallery/store/gallerySlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { BiSelectMultiple } from 'react-icons/bi';

import { GallerySearch } from './GallerySearch';

export const GalleryBulkSelect = () => {
  const dispatch = useAppDispatch();
  const { selection } = useAppSelector((s) => s.gallery);
  const { t } = useTranslation();
  const { imageDTOs } = useGalleryImages();

  const onClickClearSelection = useCallback(() => {
    dispatch(selectionChanged([]));
  }, [dispatch]);

  const onClickSelectAllPage = useCallback(() => {
    dispatch(selectionChanged(selection.concat(imageDTOs)));
  }, [dispatch, imageDTOs, selection]);

  return (
    <Flex alignItems="center" justifyContent="space-between">
      <Flex>
        {selection.length > 0 ? (
          <Tag>
            <TagLabel>
              {selection.length} {t('common.selected')}
            </TagLabel>
            <Tooltip label="Clear selection">
              <TagCloseButton onClick={onClickClearSelection} />
            </Tooltip>
          </Tag>
        ) : (
          <Spacer />
        )}

        <Tooltip label={t('gallery.selectAllOnPage')}>
          <IconButton
            variant="outline"
            size="sm"
            icon={<BiSelectMultiple />}
            aria-label="Bulk select"
            onClick={onClickSelectAllPage}
          />
        </Tooltip>
      </Flex>

      <GallerySearch />
    </Flex>
  );
};
@@ -1,97 +0,0 @@
import { Flex, IconButton, Input, InputGroup, InputRightElement, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { searchTermChanged } from 'features/gallery/store/gallerySlice';
import { motion } from 'framer-motion';
import { debounce } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiMagnifyingGlassBold, PiXBold } from 'react-icons/pi';

export const GallerySearch = () => {
  const dispatch = useAppDispatch();
  const { searchTerm } = useAppSelector((s) => s.gallery);
  const { t } = useTranslation();

  const [expanded, setExpanded] = useState(false);
  const [searchTermInput, setSearchTermInput] = useState('');

  const debouncedSetSearchTerm = useMemo(() => {
    return debounce((value: string) => {
      dispatch(searchTermChanged(value));
    }, 1000);
  }, [dispatch]);

  const onChangeInput = useCallback(
    (e: ChangeEvent<HTMLInputElement>) => {
      setSearchTermInput(e.target.value);
      debouncedSetSearchTerm(e.target.value);
    },
    [debouncedSetSearchTerm]
  );

  const onClearInput = useCallback(() => {
    setSearchTermInput('');
    debouncedSetSearchTerm('');
  }, [debouncedSetSearchTerm]);

  const toggleExpanded = useCallback((newState: boolean) => {
    setExpanded(newState);
  }, []);

  return (
    <Flex>
      {!expanded && (
        <Tooltip
          label={
            searchTerm && searchTerm.length ? `${t('gallery.searchingBy')} ${searchTerm}` : t('gallery.noActiveSearch')
          }
        >
          <IconButton
            aria-label="Close"
            icon={<PiMagnifyingGlassBold />}
            onClick={toggleExpanded.bind(null, true)}
            variant="outline"
            size="sm"
          />
        </Tooltip>
      )}
      <motion.div
        initial={false}
        animate={{ width: expanded ? '200px' : '0px' }}
        transition={{ duration: 0.3 }}
        style={{ overflow: 'hidden' }}
      >
        <InputGroup size="sm">
          <IconButton
            aria-label="Close"
            icon={<PiMagnifyingGlassBold />}
            onClick={toggleExpanded.bind(null, false)}
            variant="ghost"
            size="sm"
          />

          <Input
            type="text"
            placeholder="Search..."
            size="sm"
            variant="outline"
            onChange={onChangeInput}
            value={searchTermInput}
          />
          {searchTermInput && searchTermInput.length && (
            <InputRightElement h="full" pe={2}>
              <IconButton
                onClick={onClearInput}
                size="sm"
                variant="link"
                aria-label={t('boards.clearSearch')}
                icon={<PiXBold />}
              />
            </InputRightElement>
          )}
        </InputGroup>
      </motion.div>
    </Flex>
  );
};
@@ -1,22 +1,22 @@
import { Box, Button, ButtonGroup, Flex, Tab, TabList, Tabs, useDisclosure } from '@invoke-ai/ui-library';
import { Box, Button, ButtonGroup, Flex, Tab, TabList, Tabs, useDisclosure, VStack } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $galleryHeader } from 'app/store/nanostores/galleryHeader';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { galleryViewChanged } from 'features/gallery/store/gallerySlice';
import { memo, useCallback } from 'react';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImagesBold } from 'react-icons/pi';
import { RiServerLine } from 'react-icons/ri';

import BoardsList from './Boards/BoardsList/BoardsList';
import GalleryBoardName from './GalleryBoardName';
import { GalleryBulkSelect } from './GalleryBulkSelect';
import GallerySettingsPopover from './GallerySettingsPopover';
import GalleryImageGrid from './ImageGrid/GalleryImageGrid';
import { GalleryPagination } from './ImageGrid/GalleryPagination';

const ImageGalleryContent = () => {
  const { t } = useTranslation();
  const resizeObserverRef = useRef<HTMLDivElement>(null);
  const galleryGridRef = useRef<HTMLDivElement>(null);
  const galleryView = useAppSelector((s) => s.gallery.galleryView);
  const dispatch = useAppDispatch();
  const galleryHeader = useStore($galleryHeader);
@@ -31,10 +31,10 @@ const ImageGalleryContent = () => {
  }, [dispatch]);

  return (
    <Flex layerStyle="first" flexDirection="column" h="full" w="full" borderRadius="base" p={2} gap={2}>
    <VStack layerStyle="first" flexDirection="column" h="full" w="full" borderRadius="base" p={2}>
      {galleryHeader}
      <Box>
        <Flex alignItems="center" justifyContent="space-between" gap={2}>
      <Box w="full">
        <Flex ref={resizeObserverRef} alignItems="center" justifyContent="space-between" gap={2}>
          <GalleryBoardName isOpen={isBoardListOpen} onToggle={onToggleBoardList} />
          <GallerySettingsPopover />
        </Flex>
@@ -42,41 +42,40 @@ const ImageGalleryContent = () => {
        <BoardsList isOpen={isBoardListOpen} />
      </Box>
      </Box>
      <Flex alignItems="center" justifyContent="space-between" gap={2}>
        <Tabs index={galleryView === 'images' ? 0 : 1} variant="unstyled" size="sm" w="full">
          <TabList>
            <ButtonGroup w="full">
              <Tab
                as={Button}
                size="sm"
                isChecked={galleryView === 'images'}
                onClick={handleClickImages}
                w="full"
                leftIcon={<PiImagesBold size="16px" />}
                data-testid="images-tab"
              >
                {t('parameters.images')}
              </Tab>
              <Tab
                as={Button}
                size="sm"
                isChecked={galleryView === 'assets'}
                onClick={handleClickAssets}
                w="full"
                leftIcon={<RiServerLine size="16px" />}
                data-testid="assets-tab"
              >
                {t('gallery.assets')}
              </Tab>
            </ButtonGroup>
          </TabList>
        </Tabs>
      <Flex ref={galleryGridRef} direction="column" gap={2} h="full" w="full">
        <Flex alignItems="center" justifyContent="space-between" gap={2}>
          <Tabs index={galleryView === 'images' ? 0 : 1} variant="unstyled" size="sm" w="full">
            <TabList>
              <ButtonGroup w="full">
                <Tab
                  as={Button}
                  size="sm"
                  isChecked={galleryView === 'images'}
                  onClick={handleClickImages}
                  w="full"
                  leftIcon={<PiImagesBold size="16px" />}
                  data-testid="images-tab"
                >
                  {t('parameters.images')}
                </Tab>
                <Tab
                  as={Button}
                  size="sm"
                  isChecked={galleryView === 'assets'}
                  onClick={handleClickAssets}
                  w="full"
                  leftIcon={<RiServerLine size="16px" />}
                  data-testid="assets-tab"
                >
                  {t('gallery.assets')}
                </Tab>
              </ButtonGroup>
            </TabList>
          </Tabs>
        </Flex>
        <GalleryImageGrid />
      </Flex>
      <GalleryBulkSelect />

      <GalleryImageGrid />
        <GalleryPagination />
      </Flex>
    </VStack>
  );
};

@@ -16,13 +16,13 @@ import type { MouseEvent } from 'react';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiStarBold, PiStarFill, PiTrashSimpleFill } from 'react-icons/pi';
import { useStarImagesMutation, useUnstarImagesMutation } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';

// This class name is used to calculate the number of images that fit in the gallery
export const GALLERY_IMAGE_CLASS_NAME = 'gallery-image';
import { useGetImageDTOQuery, useStarImagesMutation, useUnstarImagesMutation } from 'services/api/endpoints/images';

const imageSx: SystemStyleObject = { w: 'full', h: 'full' };
const imageIconStyleOverrides: SystemStyleObject = {
  bottom: 2,
  top: 'auto',
};
const boxSx: SystemStyleObject = {
  containerType: 'inline-size',
};
@@ -34,22 +34,24 @@ const badgeSx: SystemStyleObject = {
};

interface HoverableImageProps {
  imageDTO: ImageDTO;
  imageName: string;
  index: number;
}

const GalleryImage = ({ index, imageDTO }: HoverableImageProps) => {
const GalleryImage = (props: HoverableImageProps) => {
  const dispatch = useAppDispatch();
  const { imageName } = props;
  const { currentData: imageDTO } = useGetImageDTOQuery(imageName);
  const shift = useShiftModifier();
  const { t } = useTranslation();
  const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
  const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
  const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageDTO.image_name);
  const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageName);
  const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);

  const customStarUi = useStore($customStarUI);

  const imageContainerRef = useScrollIntoView(isSelected, index, areMultiplesSelected);
  const imageContainerRef = useScrollIntoView(isSelected, props.index, areMultiplesSelected);

  const handleDelete = useCallback(
    (e: MouseEvent<HTMLButtonElement>) => {
@@ -112,32 +114,32 @@ const GalleryImage = ({ index, imageDTO }: HoverableImageProps) => {
  }, []);

  const starIcon = useMemo(() => {
    if (imageDTO.starred) {
    if (imageDTO?.starred) {
      return customStarUi ? customStarUi.on.icon : <PiStarFill size="20" />;
    }
    if (!imageDTO.starred && isHovered) {
    if (!imageDTO?.starred && isHovered) {
      return customStarUi ? customStarUi.off.icon : <PiStarBold size="20" />;
    }
  }, [imageDTO.starred, isHovered, customStarUi]);
  }, [imageDTO?.starred, isHovered, customStarUi]);

  const starTooltip = useMemo(() => {
    if (imageDTO.starred) {
    if (imageDTO?.starred) {
      return customStarUi ? customStarUi.off.text : 'Unstar';
    }
    if (!imageDTO.starred) {
    if (!imageDTO?.starred) {
      return customStarUi ? customStarUi.on.text : 'Star';
    }
    return '';
  }, [imageDTO.starred, customStarUi]);
  }, [imageDTO?.starred, customStarUi]);

  const dataTestId = useMemo(() => getGalleryImageDataTestId(imageDTO.image_name), [imageDTO.image_name]);
  const dataTestId = useMemo(() => getGalleryImageDataTestId(imageDTO?.image_name), [imageDTO?.image_name]);

  if (!imageDTO) {
    return <IAIFillSkeleton />;
  }

  return (
    <Box w="full" h="full" p={1.5} className={GALLERY_IMAGE_CLASS_NAME} data-testid={dataTestId} sx={boxSx}>
    <Box w="full" h="full" className="gallerygrid-image" data-testid={dataTestId} sx={boxSx}>
      <Flex
        ref={imageContainerRef}
        userSelect="none"
@@ -181,23 +183,14 @@ const GalleryImage = ({ index, imageDTO }: HoverableImageProps) => {
            pointerEvents="none"
          >{`${imageDTO.width}x${imageDTO.height}`}</Text>
        )}
        <IAIDndImageIcon
          onClick={toggleStarredState}
          icon={starIcon}
          tooltip={starTooltip}
          position="absolute"
          top={1}
          insetInlineEnd={1}
        />
        <IAIDndImageIcon onClick={toggleStarredState} icon={starIcon} tooltip={starTooltip} />

        {isHovered && shift && (
          <IAIDndImageIcon
            onClick={handleDelete}
            icon={<PiTrashSimpleFill size="16px" />}
            tooltip={t('gallery.deleteImage_one')}
            position="absolute"
            bottom={1}
            insetInlineEnd={1}
            tooltip={t('gallery.deleteImage', { count: 1 })}
            styleOverrides={imageIconStyleOverrides}
          />
        )}
      </>

@@ -1,32 +1,120 @@
import { Box, Flex, Grid } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Box, Button, Flex } from '@invoke-ai/ui-library';
import type { EntityId } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
import { useGalleryHotkeys } from 'features/gallery/hooks/useGalleryHotkeys';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { limitChanged } from 'features/gallery/store/gallerySlice';
import { debounce } from 'lodash-es';
import { memo, useEffect, useMemo, useState } from 'react';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
import type { CSSProperties } from 'react';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImageBold, PiWarningCircleBold } from 'react-icons/pi';
import { useListImagesQuery } from 'services/api/endpoints/images';
import type { GridComponents, ItemContent, ListRange, VirtuosoGridHandle } from 'react-virtuoso';
import { VirtuosoGrid } from 'react-virtuoso';
import { useBoardTotal } from 'services/api/hooks/useBoardTotal';

import { GALLERY_GRID_CLASS_NAME } from './constants';
import GalleryImage, { GALLERY_IMAGE_CLASS_NAME } from './GalleryImage';
import GalleryImage from './GalleryImage';
import ImageGridItemContainer from './ImageGridItemContainer';
import ImageGridListContainer from './ImageGridListContainer';

const components: GridComponents = {
  Item: ImageGridItemContainer,
  List: ImageGridListContainer,
};

const virtuosoStyles: CSSProperties = { height: '100%' };

const GalleryImageGrid = () => {
  useGalleryHotkeys();
  const { t } = useTranslation();
  const queryArgs = useAppSelector(selectListImagesQueryArgs);
  const { imageDTOs, isLoading, isError, isFetching } = useListImagesQuery(queryArgs, {
    selectFromResult: ({ data, isLoading, isSuccess, isError, isFetching }) => ({
      imageDTOs: data?.items ?? EMPTY_ARRAY,
      isLoading,
      isSuccess,
      isError,
      isFetching,
    }),
  });
  const rootRef = useRef<HTMLDivElement>(null);
  const [scroller, setScroller] = useState<HTMLElement | null>(null);
  const [initialize, osInstance] = useOverlayScrollbars(overlayScrollbarsParams);
  const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
  const { currentViewTotal } = useBoardTotal(selectedBoardId);
  const virtuosoRangeRef = useRef<ListRange | null>(null);
  const virtuosoRef = useRef<VirtuosoGridHandle>(null);
  const {
    areMoreImagesAvailable,
    handleLoadMoreImages,
    queryResult: { currentData, isFetching, isSuccess, isError },
  } = useGalleryImages();
  useGalleryHotkeys();
  const itemContentFunc: ItemContent<EntityId, void> = useCallback(
    (index, imageName) => <GalleryImage key={imageName} index={index} imageName={imageName as string} />,
    []
  );

  useEffect(() => {
    // Initialize the gallery's custom scrollbar
    const { current: root } = rootRef;
    if (scroller && root) {
      initialize({
        target: root,
        elements: {
          viewport: scroller,
        },
      });
    }
    return () => osInstance()?.destroy();
  }, [scroller, initialize, osInstance]);

  const onRangeChanged = useCallback((range: ListRange) => {
    virtuosoRangeRef.current = range;
  }, []);

  useEffect(() => {
    virtuosoGridRefs.set({ rootRef, virtuosoRangeRef, virtuosoRef });
    return () => {
      virtuosoGridRefs.set({});
    };
  }, []);

  if (!currentData) {
    return (
      <Flex w="full" h="full" alignItems="center" justifyContent="center">
        <IAINoContentFallback label={t('gallery.loading')} icon={PiImageBold} />
      </Flex>
    );
  }

  if (isSuccess && currentData?.ids.length === 0) {
    return (
      <Flex w="full" h="full" alignItems="center" justifyContent="center">
        <IAINoContentFallback label={t('gallery.noImagesInGallery')} icon={PiImageBold} />
      </Flex>
    );
  }

  if (isSuccess && currentData) {
    return (
      <>
        <Box ref={rootRef} data-overlayscrollbars="" h="100%" id="gallery-grid">
          <VirtuosoGrid
            style={virtuosoStyles}
            data={currentData.ids}
            endReached={handleLoadMoreImages}
            components={components}
            scrollerRef={setScroller}
            itemContent={itemContentFunc}
            ref={virtuosoRef}
            rangeChanged={onRangeChanged}
            overscan={10}
          />
        </Box>
        <Button
          onClick={handleLoadMoreImages}
          isDisabled={!areMoreImagesAvailable}
          isLoading={isFetching}
loadingText={t('gallery.loading')}
|
||||
flexShrink={0}
|
||||
>
|
||||
{`${t('accessibility.loadMore')} (${currentData.ids.length} / ${currentViewTotal})`}
|
||||
</Button>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
if (isError) {
|
||||
return (
|
||||
@@ -36,115 +124,7 @@ const GalleryImageGrid = () => {
|
||||
);
|
||||
}
|
||||
|
||||
if (isLoading || isFetching) {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<IAINoContentFallback label={t('gallery.loading')} icon={PiImageBold} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
if (imageDTOs.length === 0) {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<IAINoContentFallback label={t('gallery.noImagesInGallery')} icon={PiImageBold} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return <Content />;
|
||||
return null;
|
||||
};
|
||||
|
||||
export default memo(GalleryImageGrid);
|
||||
|
||||
const Content = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);
|
||||
|
||||
const queryArgs = useAppSelector(selectListImagesQueryArgs);
|
||||
const { imageDTOs } = useListImagesQuery(queryArgs, {
|
||||
selectFromResult: ({ data }) => ({ imageDTOs: data?.items ?? EMPTY_ARRAY }),
|
||||
});
|
||||
// Use a callback ref to get reactivity on the container element because it is conditionally rendered
|
||||
const [container, containerRef] = useState<HTMLDivElement | null>(null);
|
||||
|
||||
const calculateNewLimit = useMemo(() => {
|
||||
// Debounce this to not thrash the API
|
||||
return debounce(() => {
|
||||
if (!container) {
|
||||
// Container not rendered yet
|
||||
return;
|
||||
}
|
||||
// Managing refs for dynamically rendered components is a bit tedious:
// - https://react.dev/learn/manipulating-the-dom-with-refs#how-to-manage-a-list-of-refs-using-a-ref-callback
// As an easy workaround, we can just grab the first gallery image element directly.
const galleryImageEl = document.querySelector(`.${GALLERY_IMAGE_CLASS_NAME}`);
if (!galleryImageEl) {
// No images in gallery?
return;
}

const galleryImageRect = galleryImageEl.getBoundingClientRect();
const containerRect = container.getBoundingClientRect();

if (!galleryImageRect.width || !galleryImageRect.height || !containerRect.width || !containerRect.height) {
// Gallery is too small to fit images or not rendered yet
return;
}

// Floating-point precision requires we round to get the correct number of images per row
const imagesPerRow = Math.round(containerRect.width / galleryImageRect.width);
// However, when calculating the number of images per column, we want to floor the value to not overflow the container
const imagesPerColumn = Math.floor(containerRect.height / galleryImageRect.height);
// Always load at least 1 row of images
const limit = Math.max(imagesPerRow, imagesPerRow * imagesPerColumn);
dispatch(limitChanged(limit));
}, 300);
}, [container, dispatch]);

useEffect(() => {
// We want to recalculate the limit when image size changes
calculateNewLimit();
}, [calculateNewLimit, galleryImageMinimumWidth]);

useEffect(() => {
if (!container) {
return;
}

const resizeObserver = new ResizeObserver(calculateNewLimit);
resizeObserver.observe(container);

// First render
calculateNewLimit();

return () => {
resizeObserver.disconnect();
};
}, [calculateNewLimit, container, dispatch]);

return (
<Box position="relative" w="full" h="full">
<Box
ref={containerRef}
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
w="full"
h="full"
overflow="hidden"
>
<Grid
className={GALLERY_GRID_CLASS_NAME}
gridTemplateColumns={`repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr))`}
>
{imageDTOs.map((imageDTO, index) => (
<GalleryImage key={imageDTO.image_name} imageDTO={imageDTO} index={index} />
))}
</Grid>
</Box>
</Box>
);
};

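The Content component above derives the listing limit from measured DOM rectangles. A standalone sketch of that calculation follows; the calculateLimit helper is hypothetical (not part of the diff) and simply distills the round/floor logic in the debounced callback:

const calculateLimit = (containerRect: DOMRect, galleryImageRect: DOMRect): number | null => {
  if (!galleryImageRect.width || !galleryImageRect.height || !containerRect.width || !containerRect.height) {
    return null; // not rendered yet, or too small to measure
  }
  // Round per row: floating-point widths can land a hair under a whole number.
  const imagesPerRow = Math.round(containerRect.width / galleryImageRect.width);
  // Floor per column: a partially visible bottom row must not inflate the limit.
  const imagesPerColumn = Math.floor(containerRect.height / galleryImageRect.height);
  // Always request at least one full row of images.
  return Math.max(imagesPerRow, imagesPerRow * imagesPerColumn);
};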
@@ -1,73 +0,0 @@
import { Button, Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { PiCaretDoubleLeftBold, PiCaretDoubleRightBold, PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';

export const GalleryPagination = () => {
const {
goPrev,
goNext,
goToFirst,
goToLast,
isFirstEnabled,
isLastEnabled,
isPrevEnabled,
isNextEnabled,
pageButtons,
goToPage,
currentPage,
rangeDisplay,
total,
} = useGalleryPagination();

if (!total) {
return <Flex flexDir="column" alignItems="center" gap="2" height="48px"></Flex>;
}

return (
<Flex flexDir="column" alignItems="center" gap="2" height="48px">
<Flex gap={2} alignItems="center" w="full">
<IconButton
size="sm"
aria-label="prev"
icon={<PiCaretDoubleLeftBold />}
onClick={goToFirst}
isDisabled={!isFirstEnabled}
/>
<IconButton
size="sm"
aria-label="prev"
icon={<PiCaretLeftBold />}
onClick={goPrev}
isDisabled={!isPrevEnabled}
/>
<Spacer />
{pageButtons.map((page) => (
<Button
size="sm"
key={page}
onClick={goToPage.bind(null, page)}
variant={currentPage === page ? 'solid' : 'outline'}
>
{page + 1}
</Button>
))}
<Spacer />
<IconButton
size="sm"
aria-label="next"
icon={<PiCaretRightBold />}
onClick={goNext}
isDisabled={!isNextEnabled}
/>
<IconButton
size="sm"
aria-label="next"
icon={<PiCaretDoubleRightBold />}
onClick={goToLast}
isDisabled={!isLastEnabled}
/>
</Flex>
<Text>{rangeDisplay}</Text>
</Flex>
);
};
@@ -0,0 +1,15 @@
import type { FlexProps } from '@invoke-ai/ui-library';
import { Box, forwardRef } from '@invoke-ai/ui-library';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';

export const imageItemContainerTestId = 'image-item-container';

type ItemContainerProps = PropsWithChildren & FlexProps;
const ItemContainer = forwardRef((props: ItemContainerProps, ref) => (
<Box className="item-container" ref={ref} p={1.5} data-testid={imageItemContainerTestId}>
{props.children}
</Box>
));

export default memo(ItemContainer);
@@ -0,0 +1,26 @@
import type { FlexProps } from '@invoke-ai/ui-library';
import { forwardRef, Grid } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';

export const imageListContainerTestId = 'image-list-container';

type ListContainerProps = PropsWithChildren & FlexProps;
const ListContainer = forwardRef((props: ListContainerProps, ref) => {
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);

return (
<Grid
{...props}
className="list-container"
ref={ref}
gridTemplateColumns={`repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr))`}
data-testid={imageListContainerTestId}
>
{props.children}
</Grid>
);
});

export default memo(ListContainer);
@@ -1 +0,0 @@
export const GALLERY_GRID_CLASS_NAME = 'gallery-grid';
@@ -2,7 +2,6 @@ import type { ChakraProps } from '@invoke-ai/ui-library';
import { Box, Flex, IconButton, Spinner } from '@invoke-ai/ui-library';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { useGalleryNavigation } from 'features/gallery/hooks/useGalleryNavigation';
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDoubleRightBold, PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
@@ -17,8 +16,11 @@ const NextPrevImageButtons = () => {

const { prevImage, nextImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();

const { isFetching } = useGalleryImages().queryResult;
const { isNextEnabled, goNext } = useGalleryPagination();
const {
areMoreImagesAvailable,
handleLoadMoreImages,
queryResult: { isFetching },
} = useGalleryImages();

return (
<Box pos="relative" h="full" w="full">
@@ -45,17 +47,17 @@ const NextPrevImageButtons = () => {
sx={nextPrevButtonStyles}
/>
)}
{isOnLastImage && isNextEnabled && !isFetching && (
{isOnLastImage && areMoreImagesAvailable && !isFetching && (
<IconButton
aria-label={t('accessibility.loadMore')}
icon={<PiCaretDoubleRightBold size={64} />}
variant="unstyled"
onClick={goNext}
onClick={handleLoadMoreImages}
boxSize={16}
sx={nextPrevButtonStyles}
/>
)}
{isOnLastImage && isNextEnabled && isFetching && (
{isOnLastImage && areMoreImagesAvailable && isFetching && (
<Flex w={16} h={16} alignItems="center" justifyContent="center">
<Spinner opacity={0.5} size="xl" />
</Flex>

@@ -1,12 +1,10 @@
import { useAppSelector } from 'app/store/storeHooks';
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { useGalleryNavigation } from 'features/gallery/hooks/useGalleryNavigation';
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useListImagesQuery } from 'services/api/endpoints/images';

/**
 * Registers gallery hotkeys. This hook is a singleton.
@@ -19,30 +17,21 @@ export const useGalleryHotkeys = () => {
return activeTabName !== 'canvas' || !isStaging;
}, [activeTabName, isStaging]);

const { goNext, goPrev, isNextEnabled, isPrevEnabled } = useGalleryPagination();
const queryArgs = useAppSelector(selectListImagesQueryArgs);
const queryResult = useListImagesQuery(queryArgs);

const {
handleLeftImage,
handleRightImage,
handleUpImage,
handleDownImage,
areImagesBelowCurrent,
isOnFirstImageOfView,
isOnLastImageOfView,
} = useGalleryNavigation();
areMoreImagesAvailable,
handleLoadMoreImages,
queryResult: { isFetching },
} = useGalleryImages();

const { handleLeftImage, handleRightImage, handleUpImage, handleDownImage, isOnLastImage, areImagesBelowCurrent } =
useGalleryNavigation();

useHotkeys(
['left', 'alt+left'],
(e) => {
if (isOnFirstImageOfView && isPrevEnabled && !queryResult.isFetching) {
goPrev();
return;
}
canNavigateGallery && handleLeftImage(e.altKey);
},
[handleLeftImage, canNavigateGallery, isOnFirstImageOfView, goPrev, isPrevEnabled, queryResult.isFetching]
[handleLeftImage, canNavigateGallery]
);

useHotkeys(
@@ -51,15 +40,15 @@ export const useGalleryHotkeys = () => {
if (!canNavigateGallery) {
return;
}
if (isOnLastImageOfView && isNextEnabled && !queryResult.isFetching) {
goNext();
if (isOnLastImage && areMoreImagesAvailable && !isFetching) {
handleLoadMoreImages();
return;
}
if (!isOnLastImageOfView) {
if (!isOnLastImage) {
handleRightImage(e.altKey);
}
},
[isOnLastImageOfView, goNext, isNextEnabled, queryResult.isFetching, handleRightImage, canNavigateGallery]
[isOnLastImage, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleRightImage, canNavigateGallery]
);

useHotkeys(
@@ -74,13 +63,13 @@ export const useGalleryHotkeys = () => {
useHotkeys(
['down', 'alt+down'],
(e) => {
if (!areImagesBelowCurrent && isNextEnabled && !queryResult.isFetching) {
goNext();
if (!areImagesBelowCurrent && areMoreImagesAvailable && !isFetching) {
handleLoadMoreImages();
return;
}
handleDownImage(e.altKey);
},
{ preventDefault: true },
[areImagesBelowCurrent, goNext, isNextEnabled, queryResult.isFetching, handleDownImage]
[areImagesBelowCurrent, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleDownImage]
);
};

@@ -1,15 +1,38 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { useMemo } from 'react';
import { moreImagesLoaded } from 'features/gallery/store/gallerySlice';
import { useCallback, useMemo } from 'react';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
import { useListImagesQuery } from 'services/api/endpoints/images';

/**
 * Provides access to the gallery images and a way to imperatively fetch more.
 */
export const useGalleryImages = () => {
const dispatch = useAppDispatch();
const galleryView = useAppSelector((s) => s.gallery.galleryView);
const queryArgs = useAppSelector(selectListImagesQueryArgs);
const queryResult = useListImagesQuery(queryArgs);
const imageDTOs = useMemo(() => queryResult.data?.items ?? EMPTY_ARRAY, [queryResult.data]);
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(selectedBoardId);
const { data: imagesTotal } = useGetBoardImagesTotalQuery(selectedBoardId);
const currentViewTotal = useMemo(
() => (galleryView === 'images' ? imagesTotal?.total : assetsTotal?.total),
[assetsTotal?.total, galleryView, imagesTotal?.total]
);
const areMoreImagesAvailable = useMemo(() => {
if (!currentViewTotal || !queryResult.data) {
return false;
}
return queryResult.data.ids.length < currentViewTotal;
}, [queryResult.data, currentViewTotal]);
const handleLoadMoreImages = useCallback(() => {
dispatch(moreImagesLoaded());
}, [dispatch]);

return {
imageDTOs,
areMoreImagesAvailable,
handleLoadMoreImages,
queryResult,
};
};

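A minimal sketch of how a component might consume the hook above; the component name is hypothetical, while the hook's surface (imageDTOs, areMoreImagesAvailable, handleLoadMoreImages, queryResult) is taken from the hunk:

import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';

const LoadMoreExample = () => {
  const { imageDTOs, areMoreImagesAvailable, handleLoadMoreImages, queryResult } = useGalleryImages();
  return (
    <div>
      <span>{imageDTOs.length} images loaded</span>
      {/* Disable while a page is in flight or nothing more is available. */}
      <button onClick={handleLoadMoreImages} disabled={!areMoreImagesAvailable || queryResult.isFetching}>
        Load more
      </button>
    </div>
  );
};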
@@ -1,8 +1,8 @@
import { useAltModifier } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { GALLERY_GRID_CLASS_NAME } from 'features/gallery/components/ImageGrid/constants';
import { GALLERY_IMAGE_CLASS_NAME } from 'features/gallery/components/ImageGrid/GalleryImage';
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
import { imageItemContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridItemContainer';
import { imageListContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridListContainer';
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { imageSelected, imageToCompareChanged } from 'features/gallery/store/gallerySlice';
@@ -11,6 +11,7 @@ import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAli
import { clamp } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';

/**
 * This hook is used to navigate the gallery using the arrow keys.
@@ -28,9 +29,10 @@ import type { ImageDTO } from 'services/api/types';
 */
const getImagesPerRow = (): number => {
const widthOfGalleryImage =
document.querySelector(`.${GALLERY_IMAGE_CLASS_NAME}`)?.getBoundingClientRect().width ?? 1;
document.querySelector(`[data-testid="${imageItemContainerTestId}"]`)?.getBoundingClientRect().width ?? 1;

const widthOfGalleryGrid = document.querySelector(`.${GALLERY_GRID_CLASS_NAME}`)?.getBoundingClientRect().width ?? 0;
const widthOfGalleryGrid =
document.querySelector(`[data-testid="${imageListContainerTestId}"]`)?.getBoundingClientRect().width ?? 0;

const imagesPerRow = Math.round(widthOfGalleryGrid / widthOfGalleryImage);

@@ -113,8 +115,6 @@ type UseGalleryNavigationReturn = {
isOnFirstImage: boolean;
isOnLastImage: boolean;
areImagesBelowCurrent: boolean;
isOnFirstImageOfView: boolean;
isOnLastImageOfView: boolean;
};

/**
@@ -134,19 +134,23 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
return lastSelected;
}
});
const { imageDTOs } = useGalleryImages();
const loadedImagesCount = useMemo(() => imageDTOs.length, [imageDTOs.length]);

const {
queryResult: { data },
} = useGalleryImages();
const loadedImagesCount = useMemo(() => data?.ids.length ?? 0, [data?.ids.length]);
const lastSelectedImageIndex = useMemo(() => {
if (imageDTOs.length === 0 || !lastSelectedImage) {
if (!data || !lastSelectedImage) {
return 0;
}
return imageDTOs.findIndex((i) => i.image_name === lastSelectedImage.image_name);
}, [imageDTOs, lastSelectedImage]);
return imagesSelectors.selectAll(data).findIndex((i) => i.image_name === lastSelectedImage.image_name);
}, [lastSelectedImage, data]);

const handleNavigation = useCallback(
(direction: 'left' | 'right' | 'up' | 'down', alt?: boolean) => {
const { index, image } = getImageFuncs[direction](imageDTOs, lastSelectedImageIndex);
if (!data) {
return;
}
const { index, image } = getImageFuncs[direction](imagesSelectors.selectAll(data), lastSelectedImageIndex);
if (!image || index === lastSelectedImageIndex) {
return;
}
@@ -157,7 +161,7 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
}
scrollToImage(image.image_name, index);
},
[imageDTOs, lastSelectedImageIndex, dispatch]
[data, lastSelectedImageIndex, dispatch]
);

const isOnFirstImage = useMemo(() => lastSelectedImageIndex === 0, [lastSelectedImageIndex]);
@@ -172,14 +176,6 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
return lastSelectedImageIndex + imagesPerRow < loadedImagesCount;
}, [lastSelectedImageIndex, loadedImagesCount]);

const isOnFirstImageOfView = useMemo(() => {
return lastSelectedImageIndex === 0;
}, [lastSelectedImageIndex]);

const isOnLastImageOfView = useMemo(() => {
return lastSelectedImageIndex === loadedImagesCount - 1;
}, [lastSelectedImageIndex, loadedImagesCount]);

const handleLeftImage = useCallback(
(alt?: boolean) => {
handleNavigation('left', alt);
@@ -226,7 +222,5 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
areImagesBelowCurrent,
nextImage,
prevImage,
isOnFirstImageOfView,
isOnLastImageOfView,
};
};

@@ -1,131 +0,0 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { offsetChanged } from 'features/gallery/store/gallerySlice';
import { useCallback, useEffect, useMemo } from 'react';
import { useListImagesQuery } from 'services/api/endpoints/images';

export const useGalleryPagination = (pageButtonsPerSide: number = 2) => {
const dispatch = useAppDispatch();
const { offset, limit } = useAppSelector((s) => s.gallery);
const queryArgs = useAppSelector(selectListImagesQueryArgs);

const { count, total } = useListImagesQuery(queryArgs, {
selectFromResult: ({ data }) => ({ count: data?.items.length ?? 0, total: data?.total ?? 0 }),
});

const currentPage = useMemo(() => Math.ceil(offset / (limit || 0)), [offset, limit]);
const pages = useMemo(() => Math.ceil(total / (limit || 0)), [total, limit]);

const isNextEnabled = useMemo(() => {
if (!count) {
return false;
}
return currentPage + 1 < pages;
}, [count, currentPage, pages]);
const isPrevEnabled = useMemo(() => {
if (!count) {
return false;
}
return offset > 0;
}, [count, offset]);

const goNext = useCallback(() => {
dispatch(offsetChanged(offset + (limit || 0)));
}, [dispatch, offset, limit]);

const goPrev = useCallback(() => {
dispatch(offsetChanged(Math.max(offset - (limit || 0), 0)));
}, [dispatch, offset, limit]);

const goToPage = useCallback(
(page: number) => {
dispatch(offsetChanged(page * (limit || 0)));
},
[dispatch, limit]
);
const goToFirst = useCallback(() => {
dispatch(offsetChanged(0));
}, [dispatch]);
const goToLast = useCallback(() => {
dispatch(offsetChanged((pages - 1) * (limit || 0)));
}, [dispatch, pages, limit]);

// handle when total/pages decrease and user is on high page number (ie bulk removing or deleting)
useEffect(() => {
if (pages && currentPage + 1 > pages) {
goToLast();
}
}, [currentPage, pages, goToLast]);

// calculate the page buttons to display - the current page with pageButtonsPerSide buttons on each side
const pageButtons = useMemo(() => {
const buttons = [];
const maxPageButtons = pageButtonsPerSide * 2 + 1;
let startPage = Math.max(currentPage - Math.floor(maxPageButtons / 2), 0);
const endPage = Math.min(startPage + maxPageButtons - 1, pages - 1);

if (endPage - startPage < maxPageButtons - 1) {
startPage = Math.max(endPage - maxPageButtons + 1, 0);
}

for (let i = startPage; i <= endPage; i++) {
buttons.push(i);
}

return buttons;
}, [currentPage, pageButtonsPerSide, pages]);

const isFirstEnabled = useMemo(() => currentPage > 0, [currentPage]);
const isLastEnabled = useMemo(() => currentPage < pages - 1, [currentPage, pages]);

const rangeDisplay = useMemo(() => {
const startItem = currentPage * (limit || 0) + 1;
const endItem = Math.min((currentPage + 1) * (limit || 0), total);
return `${startItem}-${endItem} of ${total}`;
}, [total, currentPage, limit]);

const numberOnPage = useMemo(() => {
return Math.min((currentPage + 1) * (limit || 0), total);
}, [currentPage, limit, total]);

const api = useMemo(
() => ({
count,
total,
currentPage,
pages,
isNextEnabled,
isPrevEnabled,
goNext,
goPrev,
goToPage,
goToFirst,
goToLast,
pageButtons,
isFirstEnabled,
isLastEnabled,
rangeDisplay,
numberOnPage,
}),
[
count,
total,
currentPage,
pages,
isNextEnabled,
isPrevEnabled,
goNext,
goPrev,
goToPage,
goToFirst,
goToLast,
pageButtons,
isFirstEnabled,
isLastEnabled,
rangeDisplay,
numberOnPage,
]
);

return api;
};
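The deleted hook's page-button windowing is worth keeping on record. A self-contained extraction of that logic with worked examples (zero-indexed pages, default pageButtonsPerSide = 2):

const getPageButtons = (currentPage: number, pages: number, pageButtonsPerSide = 2): number[] => {
  const maxPageButtons = pageButtonsPerSide * 2 + 1;
  let startPage = Math.max(currentPage - Math.floor(maxPageButtons / 2), 0);
  const endPage = Math.min(startPage + maxPageButtons - 1, pages - 1);
  // Re-anchor the window when it is clipped at the end of the page range.
  if (endPage - startPage < maxPageButtons - 1) {
    startPage = Math.max(endPage - maxPageButtons + 1, 0);
  }
  const buttons: number[] = [];
  for (let i = startPage; i <= endPage; i++) {
    buttons.push(i);
  }
  return buttons;
};

// getPageButtons(0, 10) -> [0, 1, 2, 3, 4]  (window clamped at the start)
// getPageButtons(5, 10) -> [3, 4, 5, 6, 7]  (current page centered)
// getPageButtons(9, 10) -> [5, 6, 7, 8, 9]  (window re-anchored at the end)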
@@ -1,5 +1,3 @@
import type { SkipToken } from '@reduxjs/toolkit/query';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
@@ -12,15 +10,11 @@ export const selectLastSelectedImage = createMemoizedSelector(

export const selectListImagesQueryArgs = createMemoizedSelector(
selectGallerySlice,
(gallery): ListImagesArgs | SkipToken =>
gallery.limit
? {
board_id: gallery.selectedBoardId,
categories: gallery.galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES,
offset: gallery.offset,
limit: gallery.limit,
is_intermediate: false,
search_term: gallery.searchTerm,
}
: skipToken
(gallery): ListImagesArgs => ({
board_id: gallery.selectedBoardId,
categories: gallery.galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES,
offset: gallery.offset,
limit: gallery.limit,
is_intermediate: false,
})
);

@@ -7,7 +7,7 @@ import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';

import type { BoardId, ComparisonMode, GalleryState, GalleryView } from './types';
import { IMAGE_LIMIT } from './types';
import { IMAGE_LIMIT, INITIAL_IMAGE_LIMIT } from './types';

const initialGalleryState: GalleryState = {
selection: [],
@@ -19,7 +19,7 @@ const initialGalleryState: GalleryState = {
selectedBoardId: 'none',
galleryView: 'images',
boardSearchText: '',
limit: 20,
limit: INITIAL_IMAGE_LIMIT,
offset: 0,
isImageViewerOpen: true,
imageToCompare: null,
@@ -72,6 +72,7 @@ export const gallerySlice = createSlice({
state.selectedBoardId = action.payload.boardId;
state.galleryView = 'images';
state.offset = 0;
state.limit = INITIAL_IMAGE_LIMIT;
},
autoAddBoardIdChanged: (state, action: PayloadAction<BoardId>) => {
if (!action.payload) {
@@ -83,11 +84,20 @@ export const gallerySlice = createSlice({
galleryViewChanged: (state, action: PayloadAction<GalleryView>) => {
state.galleryView = action.payload;
state.offset = 0;
state.limit = IMAGE_LIMIT;
state.limit = INITIAL_IMAGE_LIMIT;
},
boardSearchTextChanged: (state, action: PayloadAction<string>) => {
state.boardSearchText = action.payload;
},
moreImagesLoaded: (state) => {
if (state.offset === 0 && state.limit === INITIAL_IMAGE_LIMIT) {
state.offset = INITIAL_IMAGE_LIMIT;
state.limit = IMAGE_LIMIT;
} else {
state.offset += IMAGE_LIMIT;
state.limit += IMAGE_LIMIT;
}
},
alwaysShowImageSizeBadgeChanged: (state, action: PayloadAction<boolean>) => {
state.alwaysShowImageSizeBadge = action.payload;
},
@@ -104,15 +114,6 @@ export const gallerySlice = createSlice({
comparisonFitChanged: (state, action: PayloadAction<'contain' | 'fill'>) => {
state.comparisonFit = action.payload;
},
offsetChanged: (state, action: PayloadAction<number>) => {
state.offset = action.payload;
},
limitChanged: (state, action: PayloadAction<number>) => {
state.limit = action.payload;
},
searchTermChanged: (state, action: PayloadAction<string | undefined>) => {
state.searchTerm = action.payload;
},
},
extraReducers: (builder) => {
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
@@ -148,6 +149,7 @@ export const {
galleryViewChanged,
selectionChanged,
boardSearchTextChanged,
moreImagesLoaded,
alwaysShowImageSizeBadgeChanged,
isImageViewerOpenChanged,
imageToCompareChanged,
@@ -155,9 +157,6 @@ export const {
comparedImagesSwapped,
comparisonFitChanged,
comparisonModeCycled,
offsetChanged,
limitChanged,
searchTermChanged,
} = gallerySlice.actions;

const isAnyBoardDeleted = isAnyOf(

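A plain-TypeScript trace of the new moreImagesLoaded reducer's arithmetic, using INITIAL_IMAGE_LIMIT = 100 and IMAGE_LIMIT = 20 as defined on this branch: the first call swaps the initial window for a paged one, and each later call grows both offset and limit by one page. A minimal sketch, not the Redux code itself:

let offset = 0;
let limit = 100; // INITIAL_IMAGE_LIMIT

const moreImagesLoaded = (): void => {
  if (offset === 0 && limit === 100) {
    // First "load more" after the initial fetch: page past the initial window.
    offset = 100; // INITIAL_IMAGE_LIMIT
    limit = 20; // IMAGE_LIMIT
  } else {
    // Subsequent calls widen the window by one page each.
    offset += 20;
    limit += 20;
  }
};

moreImagesLoaded(); // offset=100, limit=20 -> fetches items 100-119
moreImagesLoaded(); // offset=120, limit=40 -> fetches items 120-159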
@@ -2,7 +2,8 @@ import type { ImageCategory, ImageDTO } from 'services/api/types';

export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
export const ASSETS_CATEGORIES: ImageCategory[] = ['control', 'mask', 'user', 'other'];
export const IMAGE_LIMIT = 15;
export const INITIAL_IMAGE_LIMIT = 100;
export const IMAGE_LIMIT = 20;

export type GalleryView = 'images' | 'assets';
export type BoardId = 'none' | (string & Record<never, never>);
@@ -20,7 +21,6 @@ export type GalleryState = {
boardSearchText: string;
offset: number;
limit: number;
searchTerm?: string;
alwaysShowImageSizeBadge: boolean;
imageToCompare: ImageDTO | null;
comparisonMode: ComparisonMode;

File diff suppressed because it is too large
@@ -0,0 +1,18 @@
import { useAppSelector } from 'app/store/storeHooks';
import type { BoardId } from 'features/gallery/store/types';
import { useMemo } from 'react';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';

export const useBoardTotal = (board_id: BoardId) => {
const galleryView = useAppSelector((s) => s.gallery.galleryView);

const { data: totalImages } = useGetBoardImagesTotalQuery(board_id);
const { data: totalAssets } = useGetBoardAssetsTotalQuery(board_id);

const currentViewTotal = useMemo(
() => (galleryView === 'images' ? totalImages?.total : totalAssets?.total),
[galleryView, totalAssets, totalImages]
);

return { totalImages, totalAssets, currentViewTotal };
};
@@ -7283,144 +7283,144 @@ export type components = {
project_id: string | null;
};
InvocationOutputMap: {
image_mask_to_tensor: components["schemas"]["MaskOutput"];
sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"];
latents_collection: components["schemas"]["LatentsCollectionOutput"];
metadata: components["schemas"]["MetadataOutput"];
invert_tensor_mask: components["schemas"]["MaskOutput"];
lora_collection_loader: components["schemas"]["LoRALoaderOutput"];
midas_depth_image_processor: components["schemas"]["ImageOutput"];
lscale: components["schemas"]["LatentsOutput"];
string_split: components["schemas"]["String2Output"];
integer_collection: components["schemas"]["IntegerCollectionOutput"];
boolean_collection: components["schemas"]["BooleanCollectionOutput"];
noise: components["schemas"]["NoiseOutput"];
float_math: components["schemas"]["FloatOutput"];
seamless: components["schemas"]["SeamlessModeOutput"];
img_lerp: components["schemas"]["ImageOutput"];
img_blur: components["schemas"]["ImageOutput"];
string_join: components["schemas"]["StringOutput"];
t2i_adapter: components["schemas"]["T2IAdapterOutput"];
mul: components["schemas"]["IntegerOutput"];
l2i: components["schemas"]["ImageOutput"];
img_chan: components["schemas"]["ImageOutput"];
conditioning_collection: components["schemas"]["ConditioningCollectionOutput"];
blank_image: components["schemas"]["ImageOutput"];
ip_adapter: components["schemas"]["IPAdapterOutput"];
tile_image_processor: components["schemas"]["ImageOutput"];
integer_math: components["schemas"]["IntegerOutput"];
infill_tile: components["schemas"]["ImageOutput"];
mask_edge: components["schemas"]["ImageOutput"];
content_shuffle_image_processor: components["schemas"]["ImageOutput"];
color_correct: components["schemas"]["ImageOutput"];
save_image: components["schemas"]["ImageOutput"];
show_image: components["schemas"]["ImageOutput"];
float: components["schemas"]["FloatOutput"];
prompt_from_file: components["schemas"]["StringCollectionOutput"];
merge_metadata: components["schemas"]["MetadataOutput"];
img_scale: components["schemas"]["ImageOutput"];
string_join_three: components["schemas"]["StringOutput"];
dw_openpose_image_processor: components["schemas"]["ImageOutput"];
freeu: components["schemas"]["UNetOutput"];
img_channel_multiply: components["schemas"]["ImageOutput"];
sdxl_compel_prompt: components["schemas"]["ConditioningOutput"];
img_conv: components["schemas"]["ImageOutput"];
segment_anything_processor: components["schemas"]["ImageOutput"];
latents: components["schemas"]["LatentsOutput"];
lineart_image_processor: components["schemas"]["ImageOutput"];
hed_image_processor: components["schemas"]["ImageOutput"];
infill_lama: components["schemas"]["ImageOutput"];
infill_patchmatch: components["schemas"]["ImageOutput"];
float_collection: components["schemas"]["FloatCollectionOutput"];
denoise_latents: components["schemas"]["LatentsOutput"];
metadata: components["schemas"]["MetadataOutput"];
compel: components["schemas"]["ConditioningOutput"];
img_blur: components["schemas"]["ImageOutput"];
img_crop: components["schemas"]["ImageOutput"];
sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"];
img_ilerp: components["schemas"]["ImageOutput"];
img_paste: components["schemas"]["ImageOutput"];
core_metadata: components["schemas"]["MetadataOutput"];
lora_collection_loader: components["schemas"]["LoRALoaderOutput"];
lora_selector: components["schemas"]["LoRASelectorOutput"];
create_denoise_mask: components["schemas"]["DenoiseMaskOutput"];
rectangle_mask: components["schemas"]["MaskOutput"];
noise: components["schemas"]["NoiseOutput"];
float_to_int: components["schemas"]["IntegerOutput"];
esrgan: components["schemas"]["ImageOutput"];
merge_tiles_to_image: components["schemas"]["ImageOutput"];
prompt_from_file: components["schemas"]["StringCollectionOutput"];
infill_rgba: components["schemas"]["ImageOutput"];
sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"];
lora_loader: components["schemas"]["LoRALoaderOutput"];
iterate: components["schemas"]["IterateInvocationOutput"];
t2i_adapter: components["schemas"]["T2IAdapterOutput"];
color_map_image_processor: components["schemas"]["ImageOutput"];
blank_image: components["schemas"]["ImageOutput"];
normalbae_image_processor: components["schemas"]["ImageOutput"];
canvas_paste_back: components["schemas"]["ImageOutput"];
string_split_neg: components["schemas"]["StringPosNegOutput"];
img_channel_offset: components["schemas"]["ImageOutput"];
face_mask_detection: components["schemas"]["FaceMaskOutput"];
cv_inpaint: components["schemas"]["ImageOutput"];
clip_skip: components["schemas"]["CLIPSkipInvocationOutput"];
invert_tensor_mask: components["schemas"]["MaskOutput"];
tomask: components["schemas"]["ImageOutput"];
main_model_loader: components["schemas"]["ModelLoaderOutput"];
img_watermark: components["schemas"]["ImageOutput"];
img_pad_crop: components["schemas"]["ImageOutput"];
random_range: components["schemas"]["IntegerCollectionOutput"];
mlsd_image_processor: components["schemas"]["ImageOutput"];
merge_metadata: components["schemas"]["MetadataOutput"];
string_join: components["schemas"]["StringOutput"];
vae_loader: components["schemas"]["VAEOutput"];
calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"];
calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"];
mask_from_id: components["schemas"]["ImageOutput"];
zoe_depth_image_processor: components["schemas"]["ImageOutput"];
img_resize: components["schemas"]["ImageOutput"];
string_replace: components["schemas"]["StringOutput"];
face_identifier: components["schemas"]["ImageOutput"];
canny_image_processor: components["schemas"]["ImageOutput"];
collect: components["schemas"]["CollectInvocationOutput"];
calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"];
color: components["schemas"]["ColorOutput"];
sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"];
mediapipe_face_processor: components["schemas"]["ImageOutput"];
pidi_image_processor: components["schemas"]["ImageOutput"];
div: components["schemas"]["IntegerOutput"];
range_of_size: components["schemas"]["IntegerCollectionOutput"];
img_resize: components["schemas"]["ImageOutput"];
img_watermark: components["schemas"]["ImageOutput"];
esrgan: components["schemas"]["ImageOutput"];
calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"];
img_paste: components["schemas"]["ImageOutput"];
face_identifier: components["schemas"]["ImageOutput"];
create_denoise_mask: components["schemas"]["DenoiseMaskOutput"];
content_shuffle_image_processor: components["schemas"]["ImageOutput"];
round_float: components["schemas"]["FloatOutput"];
calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"];
lscale: components["schemas"]["LatentsOutput"];
rand_int: components["schemas"]["IntegerOutput"];
infill_cv2: components["schemas"]["ImageOutput"];
sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"];
img_nsfw: components["schemas"]["ImageOutput"];
main_model_loader: components["schemas"]["ModelLoaderOutput"];
tomask: components["schemas"]["ImageOutput"];
string_replace: components["schemas"]["StringOutput"];
face_off: components["schemas"]["FaceOffOutput"];
string: components["schemas"]["StringOutput"];
heuristic_resize: components["schemas"]["ImageOutput"];
midas_depth_image_processor: components["schemas"]["ImageOutput"];
alpha_mask_to_tensor: components["schemas"]["MaskOutput"];
mask_combine: components["schemas"]["ImageOutput"];
clip_skip: components["schemas"]["CLIPSkipInvocationOutput"];
image: components["schemas"]["ImageOutput"];
infill_rgba: components["schemas"]["ImageOutput"];
img_hue_adjust: components["schemas"]["ImageOutput"];
vae_loader: components["schemas"]["VAEOutput"];
sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"];
segment_anything_processor: components["schemas"]["ImageOutput"];
sub: components["schemas"]["IntegerOutput"];
iterate: components["schemas"]["IterateInvocationOutput"];
img_mul: components["schemas"]["ImageOutput"];
denoise_latents: components["schemas"]["LatentsOutput"];
lineart_image_processor: components["schemas"]["ImageOutput"];
rand_float: components["schemas"]["FloatOutput"];
rectangle_mask: components["schemas"]["MaskOutput"];
lora_selector: components["schemas"]["LoRASelectorOutput"];
pair_tile_image: components["schemas"]["PairTileImageOutput"];
cv_inpaint: components["schemas"]["ImageOutput"];
hed_image_processor: components["schemas"]["ImageOutput"];
range: components["schemas"]["IntegerCollectionOutput"];
img_pad_crop: components["schemas"]["ImageOutput"];
string_split_neg: components["schemas"]["StringPosNegOutput"];
string_collection: components["schemas"]["StringCollectionOutput"];
zoe_depth_image_processor: components["schemas"]["ImageOutput"];
save_image: components["schemas"]["ImageOutput"];
img_ilerp: components["schemas"]["ImageOutput"];
compel: components["schemas"]["ConditioningOutput"];
unsharp_mask: components["schemas"]["ImageOutput"];
image_collection: components["schemas"]["ImageCollectionOutput"];
lineart_anime_image_processor: components["schemas"]["ImageOutput"];
float_to_int: components["schemas"]["IntegerOutput"];
random_range: components["schemas"]["IntegerCollectionOutput"];
ideal_size: components["schemas"]["IdealSizeOutput"];
i2l: components["schemas"]["LatentsOutput"];
infill_patchmatch: components["schemas"]["ImageOutput"];
depth_anything_image_processor: components["schemas"]["ImageOutput"];
infill_lama: components["schemas"]["ImageOutput"];
mask_from_id: components["schemas"]["ImageOutput"];
conditioning: components["schemas"]["ConditioningOutput"];
lresize: components["schemas"]["LatentsOutput"];
infill_tile: components["schemas"]["ImageOutput"];
integer_collection: components["schemas"]["IntegerCollectionOutput"];
img_lerp: components["schemas"]["ImageOutput"];
step_param_easing: components["schemas"]["FloatCollectionOutput"];
metadata_item: components["schemas"]["MetadataItemOutput"];
controlnet: components["schemas"]["ControlOutput"];
merge_tiles_to_image: components["schemas"]["ImageOutput"];
boolean: components["schemas"]["BooleanOutput"];
core_metadata: components["schemas"]["MetadataOutput"];
img_channel_offset: components["schemas"]["ImageOutput"];
model_identifier: components["schemas"]["ModelIdentifierOutput"];
scheduler: components["schemas"]["SchedulerOutput"];
lresize: components["schemas"]["LatentsOutput"];
img_mul: components["schemas"]["ImageOutput"];
create_gradient_mask: components["schemas"]["GradientMaskOutput"];
color_map_image_processor: components["schemas"]["ImageOutput"];
canvas_paste_back: components["schemas"]["ImageOutput"];
mask_edge: components["schemas"]["ImageOutput"];
lora_loader: components["schemas"]["LoRALoaderOutput"];
float_collection: components["schemas"]["FloatCollectionOutput"];
float_range: components["schemas"]["FloatCollectionOutput"];
normalbae_image_processor: components["schemas"]["ImageOutput"];
lblend: components["schemas"]["LatentsOutput"];
sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"];
dynamic_prompt: components["schemas"]["StringCollectionOutput"];
leres_image_processor: components["schemas"]["ImageOutput"];
add: components["schemas"]["IntegerOutput"];
img_scale: components["schemas"]["ImageOutput"];
rand_float: components["schemas"]["FloatOutput"];
tile_to_properties: components["schemas"]["TileToPropertiesOutput"];
img_crop: components["schemas"]["ImageOutput"];
integer: components["schemas"]["IntegerOutput"];
calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"];
range_of_size: components["schemas"]["IntegerCollectionOutput"];
sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"];
heuristic_resize: components["schemas"]["ImageOutput"];
controlnet: components["schemas"]["ControlOutput"];
string: components["schemas"]["StringOutput"];
tile_image_processor: components["schemas"]["ImageOutput"];
metadata_item: components["schemas"]["MetadataItemOutput"];
freeu: components["schemas"]["UNetOutput"];
round_float: components["schemas"]["FloatOutput"];
conditioning: components["schemas"]["ConditioningOutput"];
ideal_size: components["schemas"]["IdealSizeOutput"];
float: components["schemas"]["FloatOutput"];
conditioning_collection: components["schemas"]["ConditioningCollectionOutput"];
alpha_mask_to_tensor: components["schemas"]["MaskOutput"];
integer_math: components["schemas"]["IntegerOutput"];
string_collection: components["schemas"]["StringCollectionOutput"];
img_conv: components["schemas"]["ImageOutput"];
img_channel_multiply: components["schemas"]["ImageOutput"];
lblend: components["schemas"]["LatentsOutput"];
color: components["schemas"]["ColorOutput"];
image: components["schemas"]["ImageOutput"];
sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"];
image_collection: components["schemas"]["ImageCollectionOutput"];
model_identifier: components["schemas"]["ModelIdentifierOutput"];
l2i: components["schemas"]["ImageOutput"];
seamless: components["schemas"]["SeamlessModeOutput"];
boolean_collection: components["schemas"]["BooleanCollectionOutput"];
string_join_three: components["schemas"]["StringOutput"];
ip_adapter: components["schemas"]["IPAdapterOutput"];
add: components["schemas"]["IntegerOutput"];
crop_latents: components["schemas"]["LatentsOutput"];
mlsd_image_processor: components["schemas"]["ImageOutput"];
float_range: components["schemas"]["FloatCollectionOutput"];
mul: components["schemas"]["IntegerOutput"];
dw_openpose_image_processor: components["schemas"]["ImageOutput"];
boolean: components["schemas"]["BooleanOutput"];
dynamic_prompt: components["schemas"]["StringCollectionOutput"];
mediapipe_face_processor: components["schemas"]["ImageOutput"];
i2l: components["schemas"]["LatentsOutput"];
latents_collection: components["schemas"]["LatentsCollectionOutput"];
integer: components["schemas"]["IntegerOutput"];
img_chan: components["schemas"]["ImageOutput"];
pair_tile_image: components["schemas"]["PairTileImageOutput"];
unsharp_mask: components["schemas"]["ImageOutput"];
img_hue_adjust: components["schemas"]["ImageOutput"];
lineart_anime_image_processor: components["schemas"]["ImageOutput"];
face_off: components["schemas"]["FaceOffOutput"];
mask_combine: components["schemas"]["ImageOutput"];
leres_image_processor: components["schemas"]["ImageOutput"];
image_mask_to_tensor: components["schemas"]["MaskOutput"];
sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"];
scheduler: components["schemas"]["SchedulerOutput"];
sub: components["schemas"]["IntegerOutput"];
pidi_image_processor: components["schemas"]["ImageOutput"];
infill_cv2: components["schemas"]["ImageOutput"];
div: components["schemas"]["IntegerOutput"];
img_nsfw: components["schemas"]["ImageOutput"];
depth_anything_image_processor: components["schemas"]["ImageOutput"];
sdxl_compel_prompt: components["schemas"]["ConditioningOutput"];
range: components["schemas"]["IntegerCollectionOutput"];
rand_int: components["schemas"]["IntegerOutput"];
float_math: components["schemas"]["FloatOutput"];
};
/**
 * InvocationStartedEvent
@@ -14108,7 +14108,7 @@ export type operations = {
install_hugging_face_model: {
parameters: {
query: {
/** @description HuggingFace repo_id to install */
/** @description Hugging Face repo_id to install */
source: string;
};
};
@@ -14698,8 +14698,6 @@ export type operations = {
offset?: number;
/** @description The number of images per page */
limit?: number;
/** @description The term to search for */
search_term?: string | null;
};
};
responses: {

@@ -1,10 +1,12 @@
import type { EntityState } from '@reduxjs/toolkit';
import type { components, paths } from 'services/api/schema';
import type { O } from 'ts-toolbelt';

export type S = components['schemas'];

export type ImageCache = EntityState<ImageDTO, string>;

export type ListImagesArgs = NonNullable<paths['/api/v1/images/']['get']['parameters']['query']>;
export type ListImagesResponse = paths['/api/v1/images/']['get']['responses']['200']['content']['application/json'];

export type DeleteBoardResult =
paths['/api/v1/boards/{board_id}']['delete']['responses']['200']['content']['application/json'];

@@ -1,8 +1,56 @@
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { dateComparator } from 'common/util/dateComparator';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
import queryString from 'query-string';
import { buildV1Url } from 'services/api';

import type { ImageDTO, ListImagesArgs } from './types';
import type { ImageCache, ImageDTO, ListImagesArgs } from './types';

export const getIsImageInDateRange = (data: ImageCache | undefined, imageDTO: ImageDTO) => {
if (!data) {
return false;
}

const totalCachedImageDtos = imagesSelectors.selectAll(data);

if (totalCachedImageDtos.length <= 1) {
return true;
}

const cachedStarredImages = [];
const cachedUnstarredImages = [];

for (let index = 0; index < totalCachedImageDtos.length; index++) {
const image = totalCachedImageDtos[index];
if (image?.starred) {
cachedStarredImages.push(image);
}
if (!image?.starred) {
cachedUnstarredImages.push(image);
}
}

if (imageDTO.starred) {
const lastStarredImage = cachedStarredImages[cachedStarredImages.length - 1];
// if starring or already starred, want to look in list of starred images
if (!lastStarredImage) {
return true;
} // no starred images showing, so always show this one
const createdDate = new Date(imageDTO.created_at);
const oldestDate = new Date(lastStarredImage.created_at);
return createdDate >= oldestDate;
} else {
const lastUnstarredImage = cachedUnstarredImages[cachedUnstarredImages.length - 1];
// if unstarring or already unstarred, want to look in list of unstarred images
if (!lastUnstarredImage) {
return false;
} // no unstarred images showing, so don't show this one
const createdDate = new Date(imageDTO.created_at);
const oldestDate = new Date(lastUnstarredImage.created_at);
return createdDate >= oldestDate;
}
};

export const getCategories = (imageDTO: ImageDTO) => {
if (IMAGE_CATEGORIES.includes(imageDTO.image_category)) {
@@ -11,6 +59,25 @@ export const getCategories = (imageDTO: ImageDTO) => {
return ASSETS_CATEGORIES;
};

// The adapter is not actually the data store - it just provides helper functions to interact
// with some other store of data. We will use the RTK Query cache as that store.
export const imagesAdapter = createEntityAdapter<ImageDTO, string>({
selectId: (image) => image.image_name,
sortComparer: (a, b) => {
// Compare starred images first
if (a.starred && !b.starred) {
return -1;
}
if (!a.starred && b.starred) {
return 1;
}
return dateComparator(b.created_at, a.created_at);
},
});

// Create selectors for the adapter.
export const imagesSelectors = imagesAdapter.getSelectors(undefined, getSelectorsOptions);

// Helper to create the url for the listImages endpoint. Also we use it to create the cache key.
export const getListImagesUrl = (queryArgs: ListImagesArgs) =>
buildV1Url(`images/?${queryString.stringify(queryArgs, { arrayFormat: 'none' })}`);

54
scripts/populate_model_db_from_yaml.py
Executable file
@@ -0,0 +1,54 @@
#!/bin/env python

from argparse import ArgumentParser, Namespace
from pathlib import Path

from invokeai.app.services.config import InvokeAIAppConfig, get_config
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.model_install import ModelInstallService
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger


def get_args() -> Namespace:
    parser = ArgumentParser(description="Update models database from yaml file")
    parser.add_argument("--root", type=Path, required=False, default=None)
    parser.add_argument("--yaml_file", type=Path, required=False, default=None)
    return parser.parse_args()


def populate_config() -> InvokeAIAppConfig:
    args = get_args()
    config = get_config()
    if args.root:
        config._root = args.root
    if args.yaml_file:
        config.legacy_models_yaml_path = args.yaml_file
    else:
        config.legacy_models_yaml_path = config.root_path / "configs/models.yaml"
    return config


def initialize_installer(config: InvokeAIAppConfig) -> ModelInstallService:
    logger = InvokeAILogger.get_logger(config=config)
    db = SqliteDatabase(config.db_path, logger)
    record_store = ModelRecordServiceSQL(db)
    queue = DownloadQueueService()
    queue.start()
    installer = ModelInstallService(app_config=config, record_store=record_store, download_queue=queue)
    return installer


def main() -> None:
    config = populate_config()
    installer = initialize_installer(config)
    installer._migrate_yaml(rename_yaml=False, overwrite_db=True)
    print("\n<INSTALLED MODELS>")
    print("\t".join(["key", "name", "type", "path"]))
    for model in installer.record_store.all_models():
        print("\t".join([model.key, model.name, model.type, (config.models_path / model.path).as_posix()]))


if __name__ == "__main__":
    main()

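To sanity-check what the migration wrote without re-running the installer, a read-only sketch along these lines should work. It is illustrative rather than part of the script, reuses only the services the script itself imports, and assumes the same root/config:

from invokeai.app.services.config import get_config
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger

config = get_config()
logger = InvokeAILogger.get_logger(config=config)

# Open the same models database the migration populated and list its records,
# mirroring the script's tab-separated output.
record_store = ModelRecordServiceSQL(SqliteDatabase(config.db_path, logger))
for model in record_store.all_models():
    print("\t".join([model.key, model.name, model.type]))
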
@@ -14,13 +14,14 @@ def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Pat
    matches = store.search_by_attr(model_name="test_embedding")
    assert len(matches) == 0
    key = mm2_model_manager.install.register_path(embedding_file)
    loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
    assert loaded_model is not None
    assert loaded_model.config.key == key
    with loaded_model as model:
        assert isinstance(model, TextualInversionModelRaw)
    with mm2_model_manager.load.ram_cache.reserve_execution_device():
        loaded_model = mm2_model_manager.load.load_model(store.get_model(key))
        assert loaded_model is not None
        assert loaded_model.config.key == key
        with loaded_model as model:
            assert isinstance(model, TextualInversionModelRaw)

    config = mm2_model_manager.store.get_model(key)
    loaded_model_2 = mm2_model_manager.load.load_model(config)

    assert loaded_model.config.key == loaded_model_2.config.key

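The change to this test captures the new locking discipline: loads that may touch a GPU now happen inside an execution-device reservation. Condensed to its essentials, the pattern is roughly the following sketch, where model_manager and key stand in for an initialized model manager service and a registered model key:

with model_manager.load.ram_cache.reserve_execution_device():
    loaded_model = model_manager.load.load_model(model_manager.store.get_model(key))
    with loaded_model as model:
        ...  # use the raw model; it stays locked in memory for this block

When the outer block exits, the reserved device is returned to the pool for other sessions.
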
@@ -89,11 +89,10 @@ def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase:

@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase:
    ram_cache = ModelCache(
        logger=InvokeAILogger.get_logger(),
        max_cache_size=mm2_app_config.ram,
        max_vram_cache_size=mm2_app_config.vram,
    )
    convert_cache = ModelConvertCache(mm2_app_config.convert_cache_path)
    return ModelLoadService(

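The fixture change tracks an API shift: ModelCache no longer takes a max_vram_cache_size argument, and the loader fixture no longer needs a record store. Constructing a cache now comes down to a RAM budget plus a logger, roughly as in this sketch (the size value is illustrative; in the app it comes from the config):

from invokeai.backend.model_manager.load import ModelCache
from invokeai.backend.util.logging import InvokeAILogger

# RAM budget only; there is no longer a fixed VRAM cap on the cache itself.
ram_cache = ModelCache(
    logger=InvokeAILogger.get_logger(),
    max_cache_size=7.5,  # illustrative size, in GB
)

VRAM occupancy is presumably governed by the per-device reservations exercised in the device tests below, rather than by a cache-wide cap.
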
@@ -8,7 +8,9 @@ import pytest
import torch

from invokeai.app.services.config import get_config
from invokeai.backend.model_manager.load import ModelCache
from invokeai.backend.util.devices import TorchDevice, choose_precision, choose_torch_device, torch_dtype
from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403

devices = ["cpu", "cuda:0", "cuda:1", "mps"]
device_types_cpu = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", torch.float32)]

@@ -20,6 +22,7 @@ device_types_mps = [("cpu", torch.float32), ("cuda:0", torch.float32), ("mps", t
def test_device_choice(device_name):
    config = get_config()
    config.device = device_name
    TorchDevice.set_model_cache(None)  # disable dynamic selection of GPU device
    torch_device = TorchDevice.choose_torch_device()
    assert torch_device == torch.device(device_name)

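TorchDevice.choose_torch_device() consults the model cache for a dynamically reserved device when one is registered; setting the cache to None, as the test does, falls back to the statically configured device. The static path in isolation looks roughly like this sketch:

import torch

from invokeai.app.services.config import get_config
from invokeai.backend.util.devices import TorchDevice

config = get_config()
config.device = "cuda:0"  # pin the device in the app config
TorchDevice.set_model_cache(None)  # disable dynamic (cache-driven) selection
assert TorchDevice.choose_torch_device() == torch.device("cuda:0")
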
@@ -130,3 +133,32 @@ def test_legacy_precision_name():
    assert "float16" == choose_precision(torch.device("cuda"))
    assert "float16" == choose_precision(torch.device("mps"))
    assert "float32" == choose_precision(torch.device("cpu"))


def test_multi_device_support_1():
    config = get_config()
    config.devices = ["cuda:0", "cuda:1"]
    assert TorchDevice.execution_devices() == {torch.device("cuda:0"), torch.device("cuda:1")}


def test_multi_device_support_2():
    config = get_config()
    config.devices = None
    with (
        patch("torch.cuda.device_count", return_value=3),
        patch("torch.cuda.is_available", return_value=True),
    ):
        assert TorchDevice.execution_devices() == {
            torch.device("cuda:0"),
            torch.device("cuda:1"),
            torch.device("cuda:2"),
        }


def test_multi_device_support_3():
    config = get_config()
    config.devices = ["cuda:0", "cuda:1"]
    cache = ModelCache()
    with cache.reserve_execution_device() as gpu:
        assert gpu in [torch.device(x) for x in config.devices]
        assert TorchDevice.choose_torch_device() == gpu

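Taken together, these tests pin down the multi-GPU contract: config.devices (or CUDA autodetection when it is None) defines the device pool, reserve_execution_device() checks one device out per with block, and choose_torch_device() resolves to that reservation. Day-to-day usage reduces to a sketch like this, with an illustrative two-GPU pool:

import torch

from invokeai.app.services.config import get_config
from invokeai.backend.model_manager.load import ModelCache
from invokeai.backend.util.devices import TorchDevice

config = get_config()
config.devices = ["cuda:0", "cuda:1"]  # illustrative two-GPU pool

cache = ModelCache()
with cache.reserve_execution_device() as gpu:
    # This thread owns `gpu` until the block exits; device-agnostic code that
    # calls TorchDevice.choose_torch_device() lands on the same reservation.
    assert TorchDevice.choose_torch_device() == gpu
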
@@ -17,7 +17,6 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_manager_fixtures import *  # noqa: F403

@@ -49,13 +48,13 @@ def mock_services() -> InvocationServices:
        model_manager=None,  # type: ignore
        download_queue=None,  # type: ignore
        names=None,  # type: ignore
        performance_statistics=InvocationStatsService(),
        session_processor=None,  # type: ignore
        session_queue=None,  # type: ignore
        urls=None,  # type: ignore
        workflow_records=None,  # type: ignore
        tensors=None,  # type: ignore
        conditioning=None,  # type: ignore
        performance_statistics=None,  # type: ignore
    )

@@ -92,7 +92,6 @@ def test_migrate_v3_config_from_file(tmp_path: Path, patch_rootdir: None):
    assert config.host == "192.168.1.1"
    assert config.port == 8080
    assert config.ram == 100
    assert config.vram == 50
    assert config.legacy_models_yaml_path == Path("/custom/models.yaml")
    # This should be stripped out
    assert not hasattr(config, "esrgan")