Compare commits

..

21 Commits

Author SHA1 Message Date
Lincoln Stein
39881d3d7d fix installer logic for tokenizer_3 and text_encoder_3 2024-06-21 23:34:18 -04:00
Lincoln Stein
28f1d25973 unpin dependencies; fix typo in sd3.py 2024-06-21 15:59:47 -04:00
Lincoln Stein
95377ea159 add non-commercial use message to sd3 starter; rebuild frontend 2024-06-20 21:59:28 -04:00
Lincoln Stein
445561e3a4 add sd3 to starter models 2024-06-20 18:13:46 -04:00
blessedcoolant
66260fd345 fix: Update Clip 3 slot title & lint issues 2024-06-20 08:53:35 +05:30
blessedcoolant
c403efa83f fix: Make TE5 Optional 2024-06-20 08:45:36 +05:30
blessedcoolant
cd99ef2f46 Merge branch 'main' into lstein/feat/sd3-model-loading 2024-06-20 08:43:34 +05:30
Lincoln Stein
9dce4f09ae scale default RAM cache by size of system evirtual memory 2024-06-18 13:49:12 -04:00
blessedcoolant
22b5c036aa Revert "fix: height and weight not working on sd3 node"
This reverts commit be14fd59c9.
2024-06-17 06:41:49 +05:30
blessedcoolant
be14fd59c9 fix: height and weight not working on sd3 node 2024-06-17 06:34:01 +05:30
Lincoln Stein
423057a2e8 add config variable to suppress loading of sd3 text_encoder_3 T5 model 2024-06-16 16:28:39 -04:00
blessedcoolant
f65d50a4dd wip: basic wrapper for generating sd3 images 2024-06-16 04:18:20 +05:30
Lincoln Stein
554809c647 return correct base type for sd3 VAEs 2024-06-15 18:17:03 -04:00
Lincoln Stein
ac0396e6f7 Merge branch 'lstein/feat/sd3-model-loading' of github.com:invoke-ai/InvokeAI into lstein/feat/sd3-model-loading 2024-06-14 16:48:20 -04:00
Lincoln Stein
78f704e7d5 tweak installer to select correct components of HF SD3 diffusers models 2024-06-14 16:46:24 -04:00
blessedcoolant
41236031b2 chore: remove unrequired changes to v1 workflow field types 2024-06-15 00:00:44 +05:30
blessedcoolant
ddbd2ebd9d wip: add Transformer Field to Node UI 2024-06-14 22:25:26 +05:30
blessedcoolant
0c970bc880 wip: add SD3 Model Loader Invocation 2024-06-14 22:21:09 +05:30
blessedcoolant
c79d9b9ecf wip: Add Initial support for select SD3 models in UI 2024-06-14 16:04:16 +05:30
Lincoln Stein
03b9d17d0b draft sd3 loading; probable VRAM leak when using quantized submodels 2024-06-13 00:51:00 -04:00
Lincoln Stein
002f8242a1 add draft SD3 probing; there is an issue with FromOriginalControlNetMixin in backend.util.hotfixes due to new diffusers 2024-06-12 22:44:34 -04:00
76 changed files with 2417 additions and 1150 deletions

View File

@@ -316,7 +316,6 @@ async def list_image_dtos(
),
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of image DTOs"""
@@ -327,7 +326,6 @@ async def list_image_dtos(
categories,
is_intermediate,
board_id,
search_term
)
return image_dtos

View File

@@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
MainModel = "MainModelField"
SDXLMainModel = "SDXLMainModelField"
SDXLRefinerModel = "SDXLRefinerModelField"
SD3MainModel = "SD3MainModelField"
ONNXModel = "ONNXModelField"
VAEModel = "VAEModelField"
LoRAModel = "LoRAModelField"
@@ -125,6 +126,7 @@ class FieldDescriptions:
noise = "Noise tensor"
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
unet = "UNet (scheduler, LoRAs)"
transformer = "Transformer"
vae = "VAE"
cond = "Conditioning tensor"
controlnet_model = "ControlNet model to load"
@@ -133,6 +135,7 @@ class FieldDescriptions:
main_model = "Main model (UNet, VAE, CLIP) to load"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
sd3_main_model = "SD3 Main Model (Transformer, CLIP1, CLIP2, CLIP3, VAE) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
lora_weight = "The weight at which the LoRA is applied to each model"
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"

View File

@@ -12,14 +12,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField, WithBoard, WithMetadata
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext

View File

@@ -8,13 +8,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
class ModelIdentifierField(BaseModel):
@@ -54,6 +48,11 @@ class UNetField(BaseModel):
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load unet submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
class CLIPField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
@@ -61,6 +60,15 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class SD3CLIPField(BaseModel):
tokenizer_1: ModelIdentifierField = Field(description="Info to load tokenizer 1 submodel")
text_encoder_1: ModelIdentifierField = Field(description="Info to load text_encoder 1 submodel")
tokenizer_2: ModelIdentifierField = Field(description="Info to load tokenizer 2 submodel")
text_encoder_2: ModelIdentifierField = Field(description="Info to load text_encoder 2 submodel")
tokenizer_3: Optional[ModelIdentifierField] = Field(description="Info to load tokenizer 3 submodel")
text_encoder_3: Optional[ModelIdentifierField] = Field(description="Info to load text_encoder 3 submodel")
class VAEField(BaseModel):
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')

View File

@@ -0,0 +1,200 @@
from contextlib import ExitStack
from typing import Optional, cast
import torch
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
from pydantic import field_validator
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Input,
invocation,
invocation_output,
)
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.denoise_latents import get_scheduler
from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField, SD3CLIPField, TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.model_manager.config import SubModelType
sd3_pipeline: Optional[StableDiffusion3Pipeline] = None
class FakeVae:
class FakeVaeConfig:
def __init__(self) -> None:
self.block_out_channels = [0]
def __init__(self) -> None:
self.config = FakeVae.FakeVaeConfig()
@invocation_output("sd3_model_loader_output")
class SD3ModelLoaderOutput(BaseInvocationOutput):
"""Stable Diffuion 3 base model loader output"""
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
clip: SD3CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation("sd3_model_loader", title="SD3 Main Model", tags=["model", "sd3"], category="model", version="1.0.0")
class SD3ModelLoaderInvocation(BaseInvocation):
"""Loads an SD3 base model, outputting its submodels."""
model: ModelIdentifierField = InputField(description=FieldDescriptions.sd3_main_model, ui_type=UIType.SD3MainModel)
def invoke(self, context: InvocationContext) -> SD3ModelLoaderOutput:
model_key = self.model.key
if not context.models.exists(model_key):
raise Exception(f"Unknown model: {model_key}")
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
tokenizer_1 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
text_encoder_1 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
tokenizer_2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
text_encoder_2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
try:
tokenizer_3 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
text_encoder_3 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
except Exception:
tokenizer_3 = None
text_encoder_3 = None
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
return SD3ModelLoaderOutput(
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
clip=SD3CLIPField(
tokenizer_1=tokenizer_1,
text_encoder_1=text_encoder_1,
tokenizer_2=tokenizer_2,
text_encoder_2=text_encoder_2,
tokenizer_3=tokenizer_3,
text_encoder_3=text_encoder_3,
),
vae=VAEField(vae=vae),
)
@invocation(
"sd3_image_generator", title="Stable Diffusion 3", tags=["latent", "sd3"], category="latents", version="1.0.0"
)
class StableDiffusion3Invocation(BaseInvocation):
"""Generates an image using Stable Diffusion 3."""
transformer: TransformerField = InputField(
description=FieldDescriptions.transformer,
input=Input.Connection,
title="Transformer",
ui_order=0,
)
clip: SD3CLIPField = InputField(
description=FieldDescriptions.clip,
input=Input.Connection,
title="CLIP",
ui_order=1,
)
noise: Optional[LatentsField] = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
ui_order=2,
)
scheduler: SCHEDULER_NAME_VALUES = InputField(
default="euler_f",
description=FieldDescriptions.scheduler,
ui_type=UIType.Scheduler,
)
positive_prompt: str = InputField(default="", title="Positive Prompt")
negative_prompt: str = InputField(default="", title="Negative Prompt")
steps: int = InputField(default=20, gt=0, description=FieldDescriptions.steps)
guidance_scale: float = InputField(default=7.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
use_clip_3: bool = InputField(default=True, description="Use TE5 Encoder of SD3", title="Use TE5 Encoder")
seed: int = InputField(
default=0,
ge=0,
le=SEED_MAX,
description=FieldDescriptions.seed,
)
width: int = InputField(
default=1024,
multiple_of=LATENT_SCALE_FACTOR,
gt=0,
description=FieldDescriptions.width,
)
height: int = InputField(
default=1024,
multiple_of=LATENT_SCALE_FACTOR,
gt=0,
description=FieldDescriptions.height,
)
@field_validator("seed", mode="before")
def modulo_seed(cls, v: int):
"""Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
return v % (SEED_MAX + 1)
def invoke(self, context: InvocationContext) -> LatentsOutput:
with ExitStack() as stack:
tokenizer_1 = stack.enter_context(context.models.load(self.clip.tokenizer_1))
tokenizer_2 = stack.enter_context(context.models.load(self.clip.tokenizer_2))
text_encoder_1 = stack.enter_context(context.models.load(self.clip.text_encoder_1))
text_encoder_2 = stack.enter_context(context.models.load(self.clip.text_encoder_2))
transformer = stack.enter_context(context.models.load(self.transformer.transformer))
assert isinstance(transformer, SD3Transformer2DModel)
assert isinstance(text_encoder_1, CLIPTextModelWithProjection)
assert isinstance(text_encoder_2, CLIPTextModelWithProjection)
assert isinstance(tokenizer_1, CLIPTokenizer)
assert isinstance(tokenizer_2, CLIPTokenizer)
if self.use_clip_3 and self.clip.tokenizer_3 and self.clip.text_encoder_3:
tokenizer_3 = stack.enter_context(context.models.load(self.clip.tokenizer_3))
text_encoder_3 = stack.enter_context(context.models.load(self.clip.text_encoder_3))
assert isinstance(text_encoder_3, T5EncoderModel)
assert isinstance(tokenizer_3, T5TokenizerFast)
else:
tokenizer_3 = None
text_encoder_3 = None
scheduler = get_scheduler(
context=context,
scheduler_info=self.transformer.scheduler,
scheduler_name=self.scheduler,
seed=self.seed,
)
sd3_pipeline = StableDiffusion3Pipeline(
transformer=transformer,
vae=FakeVae(),
text_encoder=text_encoder_1,
text_encoder_2=text_encoder_2,
text_encoder_3=text_encoder_3,
tokenizer=tokenizer_1,
tokenizer_2=tokenizer_2,
tokenizer_3=tokenizer_3,
scheduler=scheduler,
)
results = sd3_pipeline(
self.positive_prompt,
negative_prompt=self.negative_prompt,
num_inference_steps=self.steps,
guidance_scale=self.guidance_scale,
output_type="latent",
)
latents = cast(torch.Tensor, results.images[0])
latents = latents.unsqueeze(0)
latents_name = context.tensors.save(latents)
return LatentsOutput.build(latents_name, latents=latents, seed=self.seed)

View File

@@ -32,6 +32,7 @@ ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
SYSTEM_RAM_TO_CACHE_SIZE_FACTOR = 0.25 # after 60 GB, default ram cache will scale by this factor
CONFIG_SCHEMA_VERSION = "4.0.1"
@@ -45,7 +46,7 @@ def get_default_ram_cache_size() -> float:
max_ram = psutil.virtual_memory().total / GB
if max_ram >= 60:
return 15.0
return max_ram * SYSTEM_RAM_TO_CACHE_SIZE_FACTOR
if max_ram >= 30:
return 7.5
if max_ram >= 14:

View File

@@ -41,7 +41,6 @@ class ImageRecordStorageBase(ABC):
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass

View File

@@ -148,7 +148,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()
@@ -209,13 +208,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""
query_params.append(board_id)
# Search term condition
if search_term:
query_conditions += """--sql
AND json_extract(images.metadata, '$') LIKE ?
"""
query_params.append(f'%{search_term}%')
query_pagination = """--sql
ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
"""

View File

@@ -120,7 +120,6 @@ class ImageServiceABC(ABC):
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass

View File

@@ -206,7 +206,6 @@ class ImageService(ImageServiceABC):
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self.__invoker.services.image_records.get_many(
@@ -216,7 +215,6 @@ class ImageService(ImageServiceABC):
categories,
is_intermediate,
board_id,
search_term
)
image_dtos = [

View File

@@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
StableDiffusion3 = "sd-3"
# Kandinsky2_1 = "kandinsky-2.1"
@@ -75,8 +76,11 @@ class SubModelType(str, Enum):
UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
TextEncoder3 = "text_encoder_3"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Tokenizer3 = "tokenizer_3"
Transformer = "transformer"
VAE = "vae"
VAEDecoder = "vae_decoder"
VAEEncoder = "vae_encoder"

View File

@@ -84,6 +84,8 @@ class ModelLoader(ModelLoaderBase):
except IndexError:
pass
self._logger.info(f"Loading {config.key}:{submodel_type}")
cache_path: Path = self._convert_cache.cache_path(str(model_path))
if self._needs_conversion(config, model_path, cache_path):
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)

View File

@@ -73,6 +73,7 @@ class CacheRecord(Generic[T]):
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
is_quantized: bool = False
loaded: bool = False
_locks: int = 0

View File

@@ -60,9 +60,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
execution_device: torch.device = torch.device("cuda"),
storage_device: torch.device = torch.device("cpu"),
precision: torch.dtype = torch.float16,
sequential_offload: bool = False,
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
log_memory_usage: bool = False,
logger: Optional[Logger] = None,
):
@@ -74,7 +72,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
:param precision: Precision for loaded models [torch.float16]
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
@@ -163,8 +160,18 @@ class ModelCache(ModelCacheBase[AnyModel]):
size = calc_model_size_by_data(model)
self.make_room(size)
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
is_quantized = hasattr(model, "is_quantized") and model.is_quantized
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not is_quantized else None
cache_record = CacheRecord(
key=key,
model=model,
device=self._execution_device
if is_quantized
else self._storage_device, # quantized models are loaded directly into CUDA
is_quantized=is_quantized,
state_dict=state_dict,
size=size,
)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@@ -233,8 +240,23 @@ class ModelCache(ModelCacheBase[AnyModel]):
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
# Special handling of the stable-diffusion-3:text_encoder_3
# submodel, when the user has loaded a quantized model.
# The only way to remove the quantized version of this model from VRAM is to
# delete it completely - it can't be moved from device to device
# This also contains a workaround for quantized models that
# persist indefinitely in VRAM
if cache_entry.is_quantized:
self._empty_quantized_state_dict(cache_entry.model)
cache_entry.model = None
self._delete_cache_entry(cache_entry)
vram_in_use = torch.cuda.memory_allocated() + size_required
continue
if not cache_entry.loaded:
continue
if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device)
cache_entry.loaded = False
@@ -242,7 +264,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
)
gc.collect()
TorchDevice.empty_cache()
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
@@ -256,7 +278,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
source_device = cache_entry.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
# Note: We compare device types so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
@@ -407,3 +429,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
self._cache_stack.remove(cache_entry.key)
del self._cached_models[cache_entry.key]
del cache_entry
gc.collect()
TorchDevice.empty_cache()
def _empty_quantized_state_dict(self, model: AnyModel) -> None:
"""Set all keys of a model's state dict to None.
This is a partial workaround for a poorly-understood bug in
transformers' support for quantized T5EncoderModels (text_encoder_3
of SD3). This allows most of the model to be unloaded from VRAM, but
still leaks 8K of VRAM each time the model is unloaded. Using the quantized
version of stable-diffusion-3-medium is NOT recommended.
"""
assert isinstance(model, torch.nn.Module)
sd = model.state_dict()
for k in sd.keys():
sd[k] = None

View File

@@ -36,9 +36,11 @@ VARIANT_TO_IN_CHANNEL_MAP = {
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
"""Class to load main models."""
# note - will be removed for load_single_file()
model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusion3: "SD3",
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
}
@@ -65,7 +67,10 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
if variant and "no file named" in str(
e
): # try without the variant, just in case user's preferences changed
result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)
result = load_class.from_pretrained(
model_path,
torch_dtype=self._torch_dtype,
)
else:
raise e

View File

@@ -100,6 +100,7 @@ class ModelProbe(object):
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
"StableDiffusionXLInpaintPipeline": ModelType.Main,
"LatentConsistencyModelPipeline": ModelType.Main,
"StableDiffusion3Pipeline": ModelType.Main,
"AutoencoderKL": ModelType.VAE,
"AutoencoderTiny": ModelType.VAE,
"ControlNetModel": ModelType.ControlNet,
@@ -298,10 +299,13 @@ class ModelProbe(object):
return possible_conf.absolute()
if model_type is ModelType.Main:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
config_file = f"stable-diffusion/{config_file}"
if base_type is BaseModelType.StableDiffusion3:
config_file = "stable-diffusion/v3-inference.yaml"
else:
config_file = LEGACY_CONFIGS[base_type][variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
config_file = config_file[prediction_type]
config_file = f"stable-diffusion/{config_file}"
elif model_type is ModelType.ControlNet:
config_file = (
"controlnet/cldm_v15.yaml"
@@ -374,7 +378,7 @@ def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[Con
def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
return MainModelDefaultSettings(width=512, height=512)
elif model_base is BaseModelType.StableDiffusionXL:
elif model_base in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusion3]:
return MainModelDefaultSettings(width=1024, height=1024)
# We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
return None
@@ -398,7 +402,10 @@ class CheckpointProbeBase(ProbeBase):
if model_type != ModelType.Main:
return ModelVariantType.Normal
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
key = "model.diffusion_model.input_blocks.0.0.weight"
if key not in state_dict:
return ModelVariantType.Normal
in_channels = state_dict[key].shape[1]
if in_channels == 9:
return ModelVariantType.Inpaint
elif in_channels == 5:
@@ -425,6 +432,9 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusionXL
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
key_name = "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight"
if key_name in state_dict:
return BaseModelType.StableDiffusion3
else:
raise InvalidModelConfigException("Cannot determine base type")
@@ -596,6 +606,10 @@ class FolderProbeBase(ProbeBase):
class PipelineFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
with open(self.model_path / "model_index.json", "r") as file:
index_conf = json.load(file)
if index_conf.get("_class_name") == "StableDiffusion3Pipeline":
return BaseModelType.StableDiffusion3
with open(self.model_path / "unet" / "config.json", "r") as file:
unet_conf = json.load(file)
if unet_conf["cross_attention_dim"] == 768:
@@ -644,6 +658,8 @@ class VaeFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
if self._config_looks_like_sdxl():
return BaseModelType.StableDiffusionXL
elif self._config_looks_like_sd3():
return BaseModelType.StableDiffusion3
elif self._name_looks_like_sdxl():
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
@@ -663,6 +679,15 @@ class VaeFolderProbe(FolderProbeBase):
def _name_looks_like_sdxl(self) -> bool:
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
def _config_looks_like_sd3(self) -> bool:
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
config_file = self.model_path / "config.json"
if not config_file.exists():
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
return config.get("scaling_factor", 0) == 1.5305 and config.get("sample_size") in [512, 1024]
def _guess_name(self) -> str:
name = self.model_path.name
if name == "vae":

View File

@@ -122,6 +122,13 @@ STARTER_MODELS: list[StarterModel] = [
type=ModelType.Main,
dependencies=[sdxl_fp16_vae_fix],
),
StarterModel(
name="Stable Diffusion 3",
base=BaseModelType.StableDiffusion3,
source="stabilityai/stable-diffusion-3-medium-diffusers",
description="The OG Stable Diffusion 3 base model **NOT FOR COMMERCIAL USE**.",
type=ModelType.Main,
),
# endregion
# region VAE
sdxl_fp16_vae_fix,

View File

@@ -35,6 +35,18 @@ def filter_files(
The file list can be obtained from the `files` field of HuggingFaceMetadata,
as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
"""
# BRITTLENESS WARNING!!
# The following pattern is designed to match model files that are components of diffusers submodels,
# but not to match other random stuff found in huggingface repos.
# Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
# downloading random checkpoints that might also be in the repo. However there is no guarantee
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
# will adhere to this naming convention, so this is an area to be careful of.
DIFFUSERS_COMPONENT_PATTERN = (
r"model(-fp16)?(-\d+-of-\d+)?(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
)
variant = variant or ModelRepoVariant.Default
paths: List[Path] = []
root = files[0].parts[0]
@@ -45,31 +57,26 @@ def filter_files(
# Start by filtering on model file extensions, discarding images, docs, etc
for file in files:
if file.name.endswith((".json", ".txt")):
paths.append(file)
elif file.name.endswith(
if file.name.endswith(
(
".json",
".txt",
"learned_embeds.bin",
"ip_adapter.bin",
"lora_weights.safetensors",
"weights.pb",
"onnx_data",
"spiece.model",
)
):
paths.append(file)
# BRITTLENESS WARNING!!
# Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
# downloading random checkpoints that might also be in the repo. However there is no guarantee
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
# will adhere to this naming convention, so this is an area to be careful of.
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
elif re.search(DIFFUSERS_COMPONENT_PATTERN, file.name):
paths.append(file)
# limit search to subfolder if requested
if subfolder:
subfolder = root / subfolder
paths = [x for x in paths if x.parent == Path(subfolder)]
# _filter_by_variant uniquifies the paths and returns a set
return sorted(_filter_by_variant(paths, variant))
@@ -97,9 +104,22 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
if variant == ModelRepoVariant.Flax:
result.add(path)
elif path.suffix in [".json", ".txt"]:
elif path.suffix in [".json", ".txt", ".model"]:
result.add(path)
# handle shard patterns
elif re.match(r"model\.fp16-\d+-of-\d+\.safetensors", path.name):
if variant is ModelRepoVariant.FP16:
result.add(path)
else:
continue
elif re.match(r"model-\d+-of-\d+\.safetensors", path.name):
if variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]:
result.add(path)
else:
continue
elif variant in [
ModelRepoVariant.FP16,
ModelRepoVariant.FP32,
@@ -123,6 +143,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
score += 1
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
candidate_variant_label, *_ = str(candidate_variant_label).split("-") # handle shard pattern
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
@@ -139,6 +160,8 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
else:
continue
print(subfolder_weights)
for candidate_list in subfolder_weights.values():
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
if highest_score_candidate:

View File

@@ -7,6 +7,7 @@ from diffusers import (
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
FlowMatchEulerDiscreteScheduler,
HeunDiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
@@ -29,6 +30,7 @@ SCHEDULER_MAP = {
"euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
"euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
"euler_a": (EulerAncestralDiscreteScheduler, {}),
"euler_f": (FlowMatchEulerDiscreteScheduler, {}),
"kdpm_2": (KDPM2DiscreteScheduler, {}),
"kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
"dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),

View File

@@ -3,7 +3,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import diffusers
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalControlNetMixin
# The following import is
# generating import errors with diffusers 028.2
# tried diffusers.loaders.controlnet import FromOriginalControlNetMixin, but this
# fails as well
# from diffusers.loaders import FromOriginalControlNetMixin
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
from diffusers.models.embeddings import (
@@ -32,7 +37,7 @@ from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger(__name__)
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
class ControlNetModel(ModelMixin, ConfigMixin):
"""
A ControlNet model.

View File

@@ -37,11 +37,7 @@
"selectBoard": "Select a Board",
"topMessage": "This board contains images used in the following features:",
"uncategorized": "Uncategorized",
"downloadBoard": "Download Board",
"imagesWithCount_one": "{{count}} image",
"imagesWithCount_other": "{{count}} images",
"assetsWithCount_one": "{{count}} asset",
"assetsWithCount_other": "{{count}} assets"
"downloadBoard": "Download Board"
},
"accordions": {
"generation": {
@@ -384,11 +380,7 @@
"problemDeletingImagesDesc": "One or more images could not be deleted",
"viewerImage": "Viewer Image",
"compareImage": "Compare Image",
"noActiveSearch": "No active search",
"openInViewer": "Open in Viewer",
"searchingBy": "Searching by",
"selectAllOnPage": "Select All On Page",
"selectAllOnBoard": "Select All On Board",
"selectForCompare": "Select for Compare",
"selectAnImageToCompare": "Select an Image to Compare",
"slider": "Slider",

View File

@@ -2,7 +2,8 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images';
import { getListImagesUrl } from 'services/api/util';
import type { ImageCache } from 'services/api/types';
import { getListImagesUrl, imagesSelectors } from 'services/api/util';
export const addFirstListImagesListener = (startAppListening: AppStartListening) => {
startAppListening({
@@ -17,10 +18,13 @@ export const addFirstListImagesListener = (startAppListening: AppStartListening)
cancelActiveListeners();
unsubscribe();
const data = action.payload;
// TODO: figure out how to type the predicate
const data = action.payload as ImageCache;
if (data.items.length > 0) {
dispatch(imageSelected(data.items[0] ?? null));
if (data.ids.length > 0) {
// Select the first image
const firstImage = imagesSelectors.selectAll(data)[0];
dispatch(imageSelected(firstImage ?? null));
}
},
});

View File

@@ -1,13 +1,9 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
boardIdSelected,
galleryViewChanged,
imageSelected,
selectionChanged,
} from 'features/gallery/store/gallerySlice';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { imagesApi } from 'services/api/endpoints/images';
import { imagesSelectors } from 'services/api/util';
export const addBoardIdSelectedListener = (startAppListening: AppStartListening) => {
startAppListening({
@@ -18,9 +14,14 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);
const board_id = boardIdSelected.match(action) ? action.payload.boardId : state.gallery.selectedBoardId;
dispatch(selectionChanged([]));
const galleryView = galleryViewChanged.match(action) ? action.payload : state.gallery.galleryView;
// when a board is selected, we need to wait until the board has loaded *some* images, then select the first one
const categories = galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES;
const queryArgs = { board_id: board_id ?? 'none', categories };
// wait until the board has some images - maybe it already has some from a previous fetch
// must use getState() to ensure we do not have stale state
@@ -34,12 +35,11 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
const { data: boardImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
if (boardImagesData && boardIdSelected.match(action) && action.payload.selectedImageName) {
const selectedImage = boardImagesData.items.find(
(item) => item.image_name === action.payload.selectedImageName
);
const selectedImage = imagesSelectors.selectById(boardImagesData, action.payload.selectedImageName);
dispatch(imageSelected(selectedImage || null));
} else if (boardImagesData) {
dispatch(imageSelected(boardImagesData.items[0] || null));
const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
dispatch(imageSelected(firstImage || null));
} else {
// board has no images - deselect
dispatch(imageSelected(null));

View File

@@ -4,6 +4,7 @@ import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelecto
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';
export const galleryImageClicked = createAction<{
imageDTO: ImageDTO;
@@ -31,14 +32,14 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);
const queryResult = imagesApi.endpoints.listImages.select(queryArgs)(state);
const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
if (!queryResult.data) {
if (!listImagesData) {
// Should never happen if we have clicked a gallery image
return;
}
const imageDTOs = queryResult.data.items;
const imageDTOs = imagesSelectors.selectAll(listImagesData);
const selection = state.gallery.selection;
if (altKey) {

View File

@@ -22,10 +22,11 @@ import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { isImageFieldInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { forEach } from 'lodash-es';
import { clamp, forEach } from 'lodash-es';
import { api } from 'services/api';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
state.nodes.present.nodes.forEach((node) => {
@@ -117,7 +118,32 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
}
dispatch(isModalOpenChanged(false));
const state = getState();
const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
const { image_name } = imageDTO;
const baseQueryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
const cachedImageDTOs = data ? imagesSelectors.selectAll(data) : [];
const deletedImageIndex = cachedImageDTOs.findIndex((i) => i.image_name === image_name);
const filteredImageDTOs = cachedImageDTOs.filter((i) => i.image_name !== image_name);
const newSelectedImageIndex = clamp(deletedImageIndex, 0, filteredImageDTOs.length - 1);
const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex];
if (newSelectedImageDTO) {
dispatch(imageSelected(newSelectedImageDTO));
} else {
dispatch(imageSelected(null));
}
}
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
if (imageUsage.isCanvasImage) {
@@ -142,20 +168,6 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
if (wasImageDeleted) {
dispatch(api.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
}
const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
const baseQueryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
if (data && data.items) {
const newlySelectedImage = data?.items.find((img) => img.image_name !== imageDTO?.image_name);
dispatch(imageSelected(newlySelectedImage || null));
} else {
dispatch(imageSelected(null));
}
}
},
});
@@ -176,8 +188,10 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
const queryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(queryArgs)(state);
if (data && data.items[0]) {
dispatch(imageSelected(data.items[0]));
const newSelectedImageDTO = data ? imagesSelectors.selectAll(data)[0] : undefined;
if (newSelectedImageDTO) {
dispatch(imageSelected(newSelectedImageDTO));
} else {
dispatch(imageSelected(null));
}

View File

@@ -15,12 +15,7 @@ import {
} from 'features/controlLayers/store/controlLayersSlice';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import {
imageSelected,
imageToCompareChanged,
isImageViewerOpenChanged,
selectionChanged,
} from 'features/gallery/store/gallerySlice';
import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
@@ -221,7 +216,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
board_id: boardId,
})
);
dispatch(selectionChanged([]));
return;
}
@@ -239,7 +233,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
imageDTO,
})
);
dispatch(selectionChanged([]));
return;
}
@@ -255,7 +248,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
board_id: boardId,
})
);
dispatch(selectionChanged([]));
return;
}
@@ -269,7 +261,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
imageDTOs,
})
);
dispatch(selectionChanged([]));
return;
}
},

View File

@@ -8,14 +8,14 @@ import {
galleryViewChanged,
imageSelected,
isImageViewerOpenChanged,
offsetChanged,
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { imagesAdapter } from 'services/api/util';
import { socketInvocationComplete } from 'services/events/actions';
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
@@ -52,6 +52,24 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
}
if (!imageDTO.is_intermediate) {
/**
* Cache updates for when an image result is received
* - add it to the no_board/images
*/
dispatch(
imagesApi.util.updateQueryData(
'listImages',
{
board_id: imageDTO.board_id ?? 'none',
categories: IMAGE_CATEGORIES,
},
(draft) => {
imagesAdapter.addOne(draft, imageDTO);
}
)
);
// update the total images for the board
dispatch(
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
@@ -60,18 +78,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
})
);
dispatch(
imagesApi.util.invalidateTags([
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
{
type: 'ImageList',
id: getListImagesUrl({
board_id: imageDTO.board_id ?? 'none',
categories: getCategories(imageDTO),
}),
},
])
);
dispatch(imagesApi.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
const { shouldAutoSwitch } = gallery;
@@ -91,8 +98,6 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
);
}
dispatch(offsetChanged(0));
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
dispatch(
boardIdSelected({

View File

@@ -1,37 +1,47 @@
import type { IconButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { IconButton } from '@invoke-ai/ui-library';
import type { MouseEvent } from 'react';
import { memo } from 'react';
import type { MouseEvent, ReactElement } from 'react';
import { memo, useMemo } from 'react';
const sx: SystemStyleObject = {
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
},
};
type Props = Omit<IconButtonProps, 'aria-label' | 'onClick' | 'tooltip'> & {
type Props = {
onClick: (event: MouseEvent<HTMLButtonElement>) => void;
tooltip: string;
icon?: ReactElement;
styleOverrides?: SystemStyleObject;
};
const IAIDndImageIcon = (props: Props) => {
const { onClick, tooltip, icon, ...rest } = props;
const { onClick, tooltip, icon, styleOverrides } = props;
const sx = useMemo(
() => ({
position: 'absolute',
top: 1,
insetInlineEnd: 1,
p: 0,
minW: 0,
svg: {
transitionProperty: 'common',
transitionDuration: 'normal',
fill: 'base.100',
_hover: { fill: 'base.50' },
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
},
...styleOverrides,
}),
[styleOverrides]
);
return (
<IconButton
onClick={onClick}
aria-label={tooltip}
tooltip={tooltip}
icon={icon}
size="sm"
variant="link"
sx={sx}
data-testid={tooltip}
{...rest}
/>
);
};

View File

@@ -0,0 +1,16 @@
/**
* Comparator function for sorting dates in ascending order
*/
export const dateComparator = (a: string, b: string) => {
const dateA = new Date(a);
const dateB = new Date(b);
// sort in ascending order
if (dateA > dateB) {
return 1;
}
if (dateA < dateB) {
return -1;
}
return 0;
};

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Spinner } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
@@ -184,7 +185,7 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
/>
</Box>
<Flex flexDir="column" top={1} insetInlineEnd={1}>
<>
<IAIDndImageIcon
onClick={handleResetControlImage}
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
@@ -194,13 +195,15 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
onClick={handleSaveControlImage}
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
tooltip={t('controlnet.saveControlImage')}
styleOverrides={saveControlImageStyleOverrides}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
tooltip={t('controlnet.setControlImageDimensions')}
styleOverrides={setControlImageDimensionsStyleOverrides}
/>
</Flex>
</>
{pendingControlImages.includes(id) && (
<Flex
@@ -223,3 +226,6 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
};
export default memo(ControlAdapterImagePreview);
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Spinner, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -202,13 +203,13 @@ export const ControlAdapterImagePreview = memo(
onClick={handleSaveControlImage}
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
tooltip={t('controlnet.saveControlImage')}
mt={6}
styleOverrides={saveControlImageStyleOverrides}
/>
<IAIDndImageIcon
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
mt={12}
styleOverrides={setControlImageDimensionsStyleOverrides}
/>
</>
@@ -234,3 +235,6 @@ export const ControlAdapterImagePreview = memo(
);
ControlAdapterImagePreview.displayName = 'ControlAdapterImagePreview';
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -99,7 +100,7 @@ export const IPAdapterImagePreview = memo(
onClick={handleSetControlImageToDimensions}
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
mt={6}
styleOverrides={setControlImageDimensionsStyleOverrides}
/>
</>
</Flex>
@@ -108,3 +109,5 @@ export const IPAdapterImagePreview = memo(
);
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 6 };

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
@@ -96,7 +97,7 @@ export const InitialImagePreview = memo(({ image, onChangeImage, droppableData,
onClick={onUseSize}
icon={imageDTO ? <PiRulerBold size={16} /> : undefined}
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
mt={6}
styleOverrides={useSizeStyleOverrides}
/>
</>
</Flex>
@@ -104,3 +105,5 @@ export const InitialImagePreview = memo(({ image, onChangeImage, droppableData,
});
InitialImagePreview.displayName = 'InitialImagePreview';
const useSizeStyleOverrides: SystemStyleObject = { mt: 6 };

View File

@@ -1,21 +0,0 @@
import { useTranslation } from 'react-i18next';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
type Props = {
board_id: string;
};
export const BoardTotalsTooltip = ({ board_id }: Props) => {
const { t } = useTranslation();
const { imagesTotal } = useGetBoardImagesTotalQuery(board_id, {
selectFromResult: ({ data }) => {
return { imagesTotal: data?.total ?? 0 };
},
});
const { assetsTotal } = useGetBoardAssetsTotalQuery(board_id, {
selectFromResult: ({ data }) => {
return { assetsTotal: data?.total ?? 0 };
},
});
return `${t('boards.imagesWithCount', { count: imagesTotal })}, ${t('boards.assetsWithCount', { count: assetsTotal })}`;
};

View File

@@ -8,12 +8,15 @@ import SelectionOverlay from 'common/components/SelectionOverlay';
import type { AddToBoardDropData } from 'features/dnd/types';
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
import { autoAddBoardIdChanged, boardIdSelected, selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImagesSquare } from 'react-icons/pi';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import {
useGetBoardAssetsTotalQuery,
useGetBoardImagesTotalQuery,
useUpdateBoardMutation,
} from 'services/api/endpoints/boards';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import type { BoardDTO } from 'services/api/types';
@@ -48,6 +51,17 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
setIsHovered(false);
}, []);
const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id);
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id);
const tooltip = useMemo(() => {
if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
return undefined;
}
return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
assetsTotal.total
} asset${assetsTotal.total === 1 ? '' : 's'}`;
}, [assetsTotal, imagesTotal]);
const { currentData: coverImage } = useGetImageDTOQuery(board.cover_image_name ?? skipToken);
const { board_name, board_id } = board;
@@ -118,7 +132,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
>
<BoardContextMenu board={board} board_id={board_id} setBoardToDelete={setBoardToDelete}>
{(ref) => (
<Tooltip label={<BoardTotalsTooltip board_id={board.board_id} />} openDelay={1000}>
<Tooltip label={tooltip} openDelay={1000}>
<Flex
ref={ref}
onClick={handleSelectBoard}

View File

@@ -5,11 +5,11 @@ import SelectionOverlay from 'common/components/SelectionOverlay';
import type { RemoveFromBoardDropData } from 'features/dnd/types';
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice';
import InvokeLogoSVG from 'public/assets/images/invoke-symbol-wht-lrg.svg';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
import { useBoardName } from 'services/api/hooks/useBoardName';
interface Props {
@@ -29,6 +29,17 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
}, [dispatch, autoAssignBoardOnClick]);
const [isHovered, setIsHovered] = useState(false);
const { data: imagesTotal } = useGetBoardImagesTotalQuery('none');
const { data: assetsTotal } = useGetBoardAssetsTotalQuery('none');
const tooltip = useMemo(() => {
if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
return undefined;
}
return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
assetsTotal.total
} asset${assetsTotal.total === 1 ? '' : 's'}`;
}, [assetsTotal, imagesTotal]);
const handleMouseOver = useCallback(() => {
setIsHovered(true);
}, []);
@@ -60,7 +71,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
>
<BoardContextMenu board_id="none">
{(ref) => (
<Tooltip label={<BoardTotalsTooltip board_id="none" />} openDelay={1000}>
<Tooltip label={tooltip} openDelay={1000}>
<Flex
ref={ref}
onClick={handleSelectBoard}

View File

@@ -1,55 +0,0 @@
import { Flex, IconButton, Spacer, Tag, TagCloseButton, TagLabel, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { selectionChanged } from 'features/gallery/store/gallerySlice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { BiSelectMultiple } from 'react-icons/bi';
import { GallerySearch } from './GallerySearch';
export const GalleryBulkSelect = () => {
const dispatch = useAppDispatch();
const { selection } = useAppSelector((s) => s.gallery);
const { t } = useTranslation();
const { imageDTOs } = useGalleryImages();
const onClickClearSelection = useCallback(() => {
dispatch(selectionChanged([]));
}, [dispatch]);
const onClickSelectAllPage = useCallback(() => {
dispatch(selectionChanged(selection.concat(imageDTOs)));
}, [dispatch, imageDTOs, selection]);
return (
<Flex alignItems="center" justifyContent="space-between">
<Flex>
{selection.length > 0 ? (
<Tag>
<TagLabel>
{selection.length} {t('common.selected')}
</TagLabel>
<Tooltip label="Clear selection">
<TagCloseButton onClick={onClickClearSelection} />
</Tooltip>
</Tag>
) : (
<Spacer />
)}
<Tooltip label={t('gallery.selectAllOnPage')}>
<IconButton
variant="outline"
size="sm"
icon={<BiSelectMultiple />}
aria-label="Bulk select"
onClick={onClickSelectAllPage}
/>
</Tooltip>
</Flex>
<GallerySearch />
</Flex>
);
};

View File

@@ -1,97 +0,0 @@
import { Flex, IconButton, Input, InputGroup, InputRightElement, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { searchTermChanged } from 'features/gallery/store/gallerySlice';
import { motion } from 'framer-motion';
import { debounce } from 'lodash-es';
import type { ChangeEvent } from 'react';
import { useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiMagnifyingGlassBold, PiXBold } from 'react-icons/pi';
export const GallerySearch = () => {
const dispatch = useAppDispatch();
const { searchTerm } = useAppSelector((s) => s.gallery);
const { t } = useTranslation();
const [expanded, setExpanded] = useState(false);
const [searchTermInput, setSearchTermInput] = useState('');
const debouncedSetSearchTerm = useMemo(() => {
return debounce((value: string) => {
dispatch(searchTermChanged(value));
}, 1000);
}, [dispatch]);
const onChangeInput = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
setSearchTermInput(e.target.value);
debouncedSetSearchTerm(e.target.value);
},
[debouncedSetSearchTerm]
);
const onClearInput = useCallback(() => {
setSearchTermInput('');
debouncedSetSearchTerm('');
}, [debouncedSetSearchTerm]);
const toggleExpanded = useCallback((newState: boolean) => {
setExpanded(newState);
}, []);
return (
<Flex>
{!expanded && (
<Tooltip
label={
searchTerm && searchTerm.length ? `${t('gallery.searchingBy')} ${searchTerm}` : t('gallery.noActiveSearch')
}
>
<IconButton
aria-label="Close"
icon={<PiMagnifyingGlassBold />}
onClick={toggleExpanded.bind(null, true)}
variant="outline"
size="sm"
/>
</Tooltip>
)}
<motion.div
initial={false}
animate={{ width: expanded ? '200px' : '0px' }}
transition={{ duration: 0.3 }}
style={{ overflow: 'hidden' }}
>
<InputGroup size="sm">
<IconButton
aria-label="Close"
icon={<PiMagnifyingGlassBold />}
onClick={toggleExpanded.bind(null, false)}
variant="ghost"
size="sm"
/>
<Input
type="text"
placeholder="Search..."
size="sm"
variant="outline"
onChange={onChangeInput}
value={searchTermInput}
/>
{searchTermInput && searchTermInput.length && (
<InputRightElement h="full" pe={2}>
<IconButton
onClick={onClearInput}
size="sm"
variant="link"
aria-label={t('boards.clearSearch')}
icon={<PiXBold />}
/>
</InputRightElement>
)}
</InputGroup>
</motion.div>
</Flex>
);
};

View File

@@ -1,22 +1,22 @@
import { Box, Button, ButtonGroup, Flex, Tab, TabList, Tabs, useDisclosure } from '@invoke-ai/ui-library';
import { Box, Button, ButtonGroup, Flex, Tab, TabList, Tabs, useDisclosure, VStack } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { $galleryHeader } from 'app/store/nanostores/galleryHeader';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { galleryViewChanged } from 'features/gallery/store/gallerySlice';
import { memo, useCallback } from 'react';
import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImagesBold } from 'react-icons/pi';
import { RiServerLine } from 'react-icons/ri';
import BoardsList from './Boards/BoardsList/BoardsList';
import GalleryBoardName from './GalleryBoardName';
import { GalleryBulkSelect } from './GalleryBulkSelect';
import GallerySettingsPopover from './GallerySettingsPopover';
import GalleryImageGrid from './ImageGrid/GalleryImageGrid';
import { GalleryPagination } from './ImageGrid/GalleryPagination';
const ImageGalleryContent = () => {
const { t } = useTranslation();
const resizeObserverRef = useRef<HTMLDivElement>(null);
const galleryGridRef = useRef<HTMLDivElement>(null);
const galleryView = useAppSelector((s) => s.gallery.galleryView);
const dispatch = useAppDispatch();
const galleryHeader = useStore($galleryHeader);
@@ -31,10 +31,10 @@ const ImageGalleryContent = () => {
}, [dispatch]);
return (
<Flex layerStyle="first" flexDirection="column" h="full" w="full" borderRadius="base" p={2} gap={2}>
<VStack layerStyle="first" flexDirection="column" h="full" w="full" borderRadius="base" p={2}>
{galleryHeader}
<Box>
<Flex alignItems="center" justifyContent="space-between" gap={2}>
<Box w="full">
<Flex ref={resizeObserverRef} alignItems="center" justifyContent="space-between" gap={2}>
<GalleryBoardName isOpen={isBoardListOpen} onToggle={onToggleBoardList} />
<GallerySettingsPopover />
</Flex>
@@ -42,41 +42,40 @@ const ImageGalleryContent = () => {
<BoardsList isOpen={isBoardListOpen} />
</Box>
</Box>
<Flex alignItems="center" justifyContent="space-between" gap={2}>
<Tabs index={galleryView === 'images' ? 0 : 1} variant="unstyled" size="sm" w="full">
<TabList>
<ButtonGroup w="full">
<Tab
as={Button}
size="sm"
isChecked={galleryView === 'images'}
onClick={handleClickImages}
w="full"
leftIcon={<PiImagesBold size="16px" />}
data-testid="images-tab"
>
{t('parameters.images')}
</Tab>
<Tab
as={Button}
size="sm"
isChecked={galleryView === 'assets'}
onClick={handleClickAssets}
w="full"
leftIcon={<RiServerLine size="16px" />}
data-testid="assets-tab"
>
{t('gallery.assets')}
</Tab>
</ButtonGroup>
</TabList>
</Tabs>
<Flex ref={galleryGridRef} direction="column" gap={2} h="full" w="full">
<Flex alignItems="center" justifyContent="space-between" gap={2}>
<Tabs index={galleryView === 'images' ? 0 : 1} variant="unstyled" size="sm" w="full">
<TabList>
<ButtonGroup w="full">
<Tab
as={Button}
size="sm"
isChecked={galleryView === 'images'}
onClick={handleClickImages}
w="full"
leftIcon={<PiImagesBold size="16px" />}
data-testid="images-tab"
>
{t('parameters.images')}
</Tab>
<Tab
as={Button}
size="sm"
isChecked={galleryView === 'assets'}
onClick={handleClickAssets}
w="full"
leftIcon={<RiServerLine size="16px" />}
data-testid="assets-tab"
>
{t('gallery.assets')}
</Tab>
</ButtonGroup>
</TabList>
</Tabs>
</Flex>
<GalleryImageGrid />
</Flex>
<GalleryBulkSelect />
<GalleryImageGrid />
<GalleryPagination />
</Flex>
</VStack>
);
};

View File

@@ -16,13 +16,13 @@ import type { MouseEvent } from 'react';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiStarBold, PiStarFill, PiTrashSimpleFill } from 'react-icons/pi';
import { useStarImagesMutation, useUnstarImagesMutation } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
// This class name is used to calculate the number of images that fit in the gallery
export const GALLERY_IMAGE_CLASS_NAME = 'gallery-image';
import { useGetImageDTOQuery, useStarImagesMutation, useUnstarImagesMutation } from 'services/api/endpoints/images';
const imageSx: SystemStyleObject = { w: 'full', h: 'full' };
const imageIconStyleOverrides: SystemStyleObject = {
bottom: 2,
top: 'auto',
};
const boxSx: SystemStyleObject = {
containerType: 'inline-size',
};
@@ -34,22 +34,24 @@ const badgeSx: SystemStyleObject = {
};
interface HoverableImageProps {
imageDTO: ImageDTO;
imageName: string;
index: number;
}
const GalleryImage = ({ index, imageDTO }: HoverableImageProps) => {
const GalleryImage = (props: HoverableImageProps) => {
const dispatch = useAppDispatch();
const { imageName } = props;
const { currentData: imageDTO } = useGetImageDTOQuery(imageName);
const shift = useShiftModifier();
const { t } = useTranslation();
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageDTO.image_name);
const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageName);
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
const customStarUi = useStore($customStarUI);
const imageContainerRef = useScrollIntoView(isSelected, index, areMultiplesSelected);
const imageContainerRef = useScrollIntoView(isSelected, props.index, areMultiplesSelected);
const handleDelete = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
@@ -112,32 +114,32 @@ const GalleryImage = ({ index, imageDTO }: HoverableImageProps) => {
}, []);
const starIcon = useMemo(() => {
if (imageDTO.starred) {
if (imageDTO?.starred) {
return customStarUi ? customStarUi.on.icon : <PiStarFill size="20" />;
}
if (!imageDTO.starred && isHovered) {
if (!imageDTO?.starred && isHovered) {
return customStarUi ? customStarUi.off.icon : <PiStarBold size="20" />;
}
}, [imageDTO.starred, isHovered, customStarUi]);
}, [imageDTO?.starred, isHovered, customStarUi]);
const starTooltip = useMemo(() => {
if (imageDTO.starred) {
if (imageDTO?.starred) {
return customStarUi ? customStarUi.off.text : 'Unstar';
}
if (!imageDTO.starred) {
if (!imageDTO?.starred) {
return customStarUi ? customStarUi.on.text : 'Star';
}
return '';
}, [imageDTO.starred, customStarUi]);
}, [imageDTO?.starred, customStarUi]);
const dataTestId = useMemo(() => getGalleryImageDataTestId(imageDTO.image_name), [imageDTO.image_name]);
const dataTestId = useMemo(() => getGalleryImageDataTestId(imageDTO?.image_name), [imageDTO?.image_name]);
if (!imageDTO) {
return <IAIFillSkeleton />;
}
return (
<Box w="full" h="full" p={1.5} className={GALLERY_IMAGE_CLASS_NAME} data-testid={dataTestId} sx={boxSx}>
<Box w="full" h="full" className="gallerygrid-image" data-testid={dataTestId} sx={boxSx}>
<Flex
ref={imageContainerRef}
userSelect="none"
@@ -181,23 +183,14 @@ const GalleryImage = ({ index, imageDTO }: HoverableImageProps) => {
pointerEvents="none"
>{`${imageDTO.width}x${imageDTO.height}`}</Text>
)}
<IAIDndImageIcon
onClick={toggleStarredState}
icon={starIcon}
tooltip={starTooltip}
position="absolute"
top={1}
insetInlineEnd={1}
/>
<IAIDndImageIcon onClick={toggleStarredState} icon={starIcon} tooltip={starTooltip} />
{isHovered && shift && (
<IAIDndImageIcon
onClick={handleDelete}
icon={<PiTrashSimpleFill size="16px" />}
tooltip={t('gallery.deleteImage_one')}
position="absolute"
bottom={1}
insetInlineEnd={1}
tooltip={t('gallery.deleteImage', { count: 1 })}
styleOverrides={imageIconStyleOverrides}
/>
)}
</>

View File

@@ -1,32 +1,120 @@
import { Box, Flex, Grid } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { Box, Button, Flex } from '@invoke-ai/ui-library';
import type { EntityId } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
import { useGalleryHotkeys } from 'features/gallery/hooks/useGalleryHotkeys';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { limitChanged } from 'features/gallery/store/gallerySlice';
import { debounce } from 'lodash-es';
import { memo, useEffect, useMemo, useState } from 'react';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
import type { CSSProperties } from 'react';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImageBold, PiWarningCircleBold } from 'react-icons/pi';
import { useListImagesQuery } from 'services/api/endpoints/images';
import type { GridComponents, ItemContent, ListRange, VirtuosoGridHandle } from 'react-virtuoso';
import { VirtuosoGrid } from 'react-virtuoso';
import { useBoardTotal } from 'services/api/hooks/useBoardTotal';
import { GALLERY_GRID_CLASS_NAME } from './constants';
import GalleryImage, { GALLERY_IMAGE_CLASS_NAME } from './GalleryImage';
import GalleryImage from './GalleryImage';
import ImageGridItemContainer from './ImageGridItemContainer';
import ImageGridListContainer from './ImageGridListContainer';
const components: GridComponents = {
Item: ImageGridItemContainer,
List: ImageGridListContainer,
};
const virtuosoStyles: CSSProperties = { height: '100%' };
const GalleryImageGrid = () => {
useGalleryHotkeys();
const { t } = useTranslation();
const queryArgs = useAppSelector(selectListImagesQueryArgs);
const { imageDTOs, isLoading, isError, isFetching } = useListImagesQuery(queryArgs, {
selectFromResult: ({ data, isLoading, isSuccess, isError, isFetching }) => ({
imageDTOs: data?.items ?? EMPTY_ARRAY,
isLoading,
isSuccess,
isError,
isFetching,
}),
});
const rootRef = useRef<HTMLDivElement>(null);
const [scroller, setScroller] = useState<HTMLElement | null>(null);
const [initialize, osInstance] = useOverlayScrollbars(overlayScrollbarsParams);
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const { currentViewTotal } = useBoardTotal(selectedBoardId);
const virtuosoRangeRef = useRef<ListRange | null>(null);
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
const {
areMoreImagesAvailable,
handleLoadMoreImages,
queryResult: { currentData, isFetching, isSuccess, isError },
} = useGalleryImages();
useGalleryHotkeys();
const itemContentFunc: ItemContent<EntityId, void> = useCallback(
(index, imageName) => <GalleryImage key={imageName} index={index} imageName={imageName as string} />,
[]
);
useEffect(() => {
// Initialize the gallery's custom scrollbar
const { current: root } = rootRef;
if (scroller && root) {
initialize({
target: root,
elements: {
viewport: scroller,
},
});
}
return () => osInstance()?.destroy();
}, [scroller, initialize, osInstance]);
const onRangeChanged = useCallback((range: ListRange) => {
virtuosoRangeRef.current = range;
}, []);
useEffect(() => {
virtuosoGridRefs.set({ rootRef, virtuosoRangeRef, virtuosoRef });
return () => {
virtuosoGridRefs.set({});
};
}, []);
if (!currentData) {
return (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<IAINoContentFallback label={t('gallery.loading')} icon={PiImageBold} />
</Flex>
);
}
if (isSuccess && currentData?.ids.length === 0) {
return (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<IAINoContentFallback label={t('gallery.noImagesInGallery')} icon={PiImageBold} />
</Flex>
);
}
if (isSuccess && currentData) {
return (
<>
<Box ref={rootRef} data-overlayscrollbars="" h="100%" id="gallery-grid">
<VirtuosoGrid
style={virtuosoStyles}
data={currentData.ids}
endReached={handleLoadMoreImages}
components={components}
scrollerRef={setScroller}
itemContent={itemContentFunc}
ref={virtuosoRef}
rangeChanged={onRangeChanged}
overscan={10}
/>
</Box>
<Button
onClick={handleLoadMoreImages}
isDisabled={!areMoreImagesAvailable}
isLoading={isFetching}
loadingText={t('gallery.loading')}
flexShrink={0}
>
{`${t('accessibility.loadMore')} (${currentData.ids.length} / ${currentViewTotal})`}
</Button>
</>
);
}
if (isError) {
return (
@@ -36,115 +124,7 @@ const GalleryImageGrid = () => {
);
}
if (isLoading || isFetching) {
return (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<IAINoContentFallback label={t('gallery.loading')} icon={PiImageBold} />
</Flex>
);
}
if (imageDTOs.length === 0) {
return (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<IAINoContentFallback label={t('gallery.noImagesInGallery')} icon={PiImageBold} />
</Flex>
);
}
return <Content />;
return null;
};
export default memo(GalleryImageGrid);
const Content = () => {
const dispatch = useAppDispatch();
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);
const queryArgs = useAppSelector(selectListImagesQueryArgs);
const { imageDTOs } = useListImagesQuery(queryArgs, {
selectFromResult: ({ data }) => ({ imageDTOs: data?.items ?? EMPTY_ARRAY }),
});
// Use a callback ref to get reactivity on the container element because it is conditionally rendered
const [container, containerRef] = useState<HTMLDivElement | null>(null);
const calculateNewLimit = useMemo(() => {
// Debounce this to not thrash the API
return debounce(() => {
if (!container) {
// Container not rendered yet
return;
}
// Managing refs for dynamically rendered components is a bit tedious:
// - https://react.dev/learn/manipulating-the-dom-with-refs#how-to-manage-a-list-of-refs-using-a-ref-callback
// As a easy workaround, we can just grab the first gallery image element directly.
const galleryImageEl = document.querySelector(`.${GALLERY_IMAGE_CLASS_NAME}`);
if (!galleryImageEl) {
// No images in gallery?
return;
}
const galleryImageRect = galleryImageEl.getBoundingClientRect();
const containerRect = container.getBoundingClientRect();
if (!galleryImageRect.width || !galleryImageRect.height || !containerRect.width || !containerRect.height) {
// Gallery is too small to fit images or not rendered yet
return;
}
// Floating-point precision requires we round to get the correct number of images per row
const imagesPerRow = Math.round(containerRect.width / galleryImageRect.width);
// However, when calculating the number of images per column, we want to floor the value to not overflow the container
const imagesPerColumn = Math.floor(containerRect.height / galleryImageRect.height);
// Always load at least 1 row of images
const limit = Math.max(imagesPerRow, imagesPerRow * imagesPerColumn);
dispatch(limitChanged(limit));
}, 300);
}, [container, dispatch]);
useEffect(() => {
// We want to recalculate the limit when image size changes
calculateNewLimit();
}, [calculateNewLimit, galleryImageMinimumWidth]);
useEffect(() => {
if (!container) {
return;
}
const resizeObserver = new ResizeObserver(calculateNewLimit);
resizeObserver.observe(container);
// First render
calculateNewLimit();
return () => {
resizeObserver.disconnect();
};
}, [calculateNewLimit, container, dispatch]);
return (
<Box position="relative" w="full" h="full">
<Box
ref={containerRef}
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
w="full"
h="full"
overflow="hidden"
>
<Grid
className={GALLERY_GRID_CLASS_NAME}
gridTemplateColumns={`repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr))`}
>
{imageDTOs.map((imageDTO, index) => (
<GalleryImage key={imageDTO.image_name} imageDTO={imageDTO} index={index} />
))}
</Grid>
</Box>
</Box>
);
};

View File

@@ -1,73 +0,0 @@
import { Button, Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { PiCaretDoubleLeftBold, PiCaretDoubleRightBold, PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
export const GalleryPagination = () => {
const {
goPrev,
goNext,
goToFirst,
goToLast,
isFirstEnabled,
isLastEnabled,
isPrevEnabled,
isNextEnabled,
pageButtons,
goToPage,
currentPage,
rangeDisplay,
total,
} = useGalleryPagination();
if (!total) {
return <Flex flexDir="column" alignItems="center" gap="2" height="48px"></Flex>;
}
return (
<Flex flexDir="column" alignItems="center" gap="2" height="48px">
<Flex gap={2} alignItems="center" w="full">
<IconButton
size="sm"
aria-label="prev"
icon={<PiCaretDoubleLeftBold />}
onClick={goToFirst}
isDisabled={!isFirstEnabled}
/>
<IconButton
size="sm"
aria-label="prev"
icon={<PiCaretLeftBold />}
onClick={goPrev}
isDisabled={!isPrevEnabled}
/>
<Spacer />
{pageButtons.map((page) => (
<Button
size="sm"
key={page}
onClick={goToPage.bind(null, page)}
variant={currentPage === page ? 'solid' : 'outline'}
>
{page + 1}
</Button>
))}
<Spacer />
<IconButton
size="sm"
aria-label="next"
icon={<PiCaretRightBold />}
onClick={goNext}
isDisabled={!isNextEnabled}
/>
<IconButton
size="sm"
aria-label="next"
icon={<PiCaretDoubleRightBold />}
onClick={goToLast}
isDisabled={!isLastEnabled}
/>
</Flex>
<Text>{rangeDisplay}</Text>
</Flex>
);
};

View File

@@ -0,0 +1,15 @@
import type { FlexProps } from '@invoke-ai/ui-library';
import { Box, forwardRef } from '@invoke-ai/ui-library';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
export const imageItemContainerTestId = 'image-item-container';
type ItemContainerProps = PropsWithChildren & FlexProps;
const ItemContainer = forwardRef((props: ItemContainerProps, ref) => (
<Box className="item-container" ref={ref} p={1.5} data-testid={imageItemContainerTestId}>
{props.children}
</Box>
));
export default memo(ItemContainer);

View File

@@ -0,0 +1,26 @@
import type { FlexProps } from '@invoke-ai/ui-library';
import { forwardRef, Grid } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
export const imageListContainerTestId = 'image-list-container';
type ListContainerProps = PropsWithChildren & FlexProps;
const ListContainer = forwardRef((props: ListContainerProps, ref) => {
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);
return (
<Grid
{...props}
className="list-container"
ref={ref}
gridTemplateColumns={`repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr))`}
data-testid={imageListContainerTestId}
>
{props.children}
</Grid>
);
});
export default memo(ListContainer);

View File

@@ -1 +0,0 @@
export const GALLERY_GRID_CLASS_NAME = 'gallery-grid';

View File

@@ -2,7 +2,6 @@ import type { ChakraProps } from '@invoke-ai/ui-library';
import { Box, Flex, IconButton, Spinner } from '@invoke-ai/ui-library';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { useGalleryNavigation } from 'features/gallery/hooks/useGalleryNavigation';
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCaretDoubleRightBold, PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
@@ -17,8 +16,11 @@ const NextPrevImageButtons = () => {
const { prevImage, nextImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();
const { isFetching } = useGalleryImages().queryResult;
const { isNextEnabled, goNext } = useGalleryPagination();
const {
areMoreImagesAvailable,
handleLoadMoreImages,
queryResult: { isFetching },
} = useGalleryImages();
return (
<Box pos="relative" h="full" w="full">
@@ -45,17 +47,17 @@ const NextPrevImageButtons = () => {
sx={nextPrevButtonStyles}
/>
)}
{isOnLastImage && isNextEnabled && !isFetching && (
{isOnLastImage && areMoreImagesAvailable && !isFetching && (
<IconButton
aria-label={t('accessibility.loadMore')}
icon={<PiCaretDoubleRightBold size={64} />}
variant="unstyled"
onClick={goNext}
onClick={handleLoadMoreImages}
boxSize={16}
sx={nextPrevButtonStyles}
/>
)}
{isOnLastImage && isNextEnabled && isFetching && (
{isOnLastImage && areMoreImagesAvailable && isFetching && (
<Flex w={16} h={16} alignItems="center" justifyContent="center">
<Spinner opacity={0.5} size="xl" />
</Flex>

View File

@@ -1,12 +1,10 @@
import { useAppSelector } from 'app/store/storeHooks';
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { useGalleryNavigation } from 'features/gallery/hooks/useGalleryNavigation';
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useListImagesQuery } from 'services/api/endpoints/images';
/**
* Registers gallery hotkeys. This hook is a singleton.
@@ -19,30 +17,21 @@ export const useGalleryHotkeys = () => {
return activeTabName !== 'canvas' || !isStaging;
}, [activeTabName, isStaging]);
const { goNext, goPrev, isNextEnabled, isPrevEnabled } = useGalleryPagination();
const queryArgs = useAppSelector(selectListImagesQueryArgs);
const queryResult = useListImagesQuery(queryArgs);
const {
handleLeftImage,
handleRightImage,
handleUpImage,
handleDownImage,
areImagesBelowCurrent,
isOnFirstImageOfView,
isOnLastImageOfView,
} = useGalleryNavigation();
areMoreImagesAvailable,
handleLoadMoreImages,
queryResult: { isFetching },
} = useGalleryImages();
const { handleLeftImage, handleRightImage, handleUpImage, handleDownImage, isOnLastImage, areImagesBelowCurrent } =
useGalleryNavigation();
useHotkeys(
['left', 'alt+left'],
(e) => {
if (isOnFirstImageOfView && isPrevEnabled && !queryResult.isFetching) {
goPrev();
return;
}
canNavigateGallery && handleLeftImage(e.altKey);
},
[handleLeftImage, canNavigateGallery, isOnFirstImageOfView, goPrev, isPrevEnabled, queryResult.isFetching]
[handleLeftImage, canNavigateGallery]
);
useHotkeys(
@@ -51,15 +40,15 @@ export const useGalleryHotkeys = () => {
if (!canNavigateGallery) {
return;
}
if (isOnLastImageOfView && isNextEnabled && !queryResult.isFetching) {
goNext();
if (isOnLastImage && areMoreImagesAvailable && !isFetching) {
handleLoadMoreImages();
return;
}
if (!isOnLastImageOfView) {
if (!isOnLastImage) {
handleRightImage(e.altKey);
}
},
[isOnLastImageOfView, goNext, isNextEnabled, queryResult.isFetching, handleRightImage, canNavigateGallery]
[isOnLastImage, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleRightImage, canNavigateGallery]
);
useHotkeys(
@@ -74,13 +63,13 @@ export const useGalleryHotkeys = () => {
useHotkeys(
['down', 'alt+down'],
(e) => {
if (!areImagesBelowCurrent && isNextEnabled && !queryResult.isFetching) {
goNext();
if (!areImagesBelowCurrent && areMoreImagesAvailable && !isFetching) {
handleLoadMoreImages();
return;
}
handleDownImage(e.altKey);
},
{ preventDefault: true },
[areImagesBelowCurrent, goNext, isNextEnabled, queryResult.isFetching, handleDownImage]
[areImagesBelowCurrent, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleDownImage]
);
};

View File

@@ -1,15 +1,38 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { useMemo } from 'react';
import { moreImagesLoaded } from 'features/gallery/store/gallerySlice';
import { useCallback, useMemo } from 'react';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
import { useListImagesQuery } from 'services/api/endpoints/images';
/**
* Provides access to the gallery images and a way to imperatively fetch more.
*/
export const useGalleryImages = () => {
const dispatch = useAppDispatch();
const galleryView = useAppSelector((s) => s.gallery.galleryView);
const queryArgs = useAppSelector(selectListImagesQueryArgs);
const queryResult = useListImagesQuery(queryArgs);
const imageDTOs = useMemo(() => queryResult.data?.items ?? EMPTY_ARRAY, [queryResult.data]);
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(selectedBoardId);
const { data: imagesTotal } = useGetBoardImagesTotalQuery(selectedBoardId);
const currentViewTotal = useMemo(
() => (galleryView === 'images' ? imagesTotal?.total : assetsTotal?.total),
[assetsTotal?.total, galleryView, imagesTotal?.total]
);
const areMoreImagesAvailable = useMemo(() => {
if (!currentViewTotal || !queryResult.data) {
return false;
}
return queryResult.data.ids.length < currentViewTotal;
}, [queryResult.data, currentViewTotal]);
const handleLoadMoreImages = useCallback(() => {
dispatch(moreImagesLoaded());
}, [dispatch]);
return {
imageDTOs,
areMoreImagesAvailable,
handleLoadMoreImages,
queryResult,
};
};

View File

@@ -1,8 +1,8 @@
import { useAltModifier } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { GALLERY_GRID_CLASS_NAME } from 'features/gallery/components/ImageGrid/constants';
import { GALLERY_IMAGE_CLASS_NAME } from 'features/gallery/components/ImageGrid/GalleryImage';
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
import { imageItemContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridItemContainer';
import { imageListContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridListContainer';
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { imageSelected, imageToCompareChanged } from 'features/gallery/store/gallerySlice';
@@ -11,6 +11,7 @@ import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAli
import { clamp } from 'lodash-es';
import { useCallback, useMemo } from 'react';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';
/**
* This hook is used to navigate the gallery using the arrow keys.
@@ -28,9 +29,10 @@ import type { ImageDTO } from 'services/api/types';
*/
const getImagesPerRow = (): number => {
const widthOfGalleryImage =
document.querySelector(`.${GALLERY_IMAGE_CLASS_NAME}`)?.getBoundingClientRect().width ?? 1;
document.querySelector(`[data-testid="${imageItemContainerTestId}"]`)?.getBoundingClientRect().width ?? 1;
const widthOfGalleryGrid = document.querySelector(`.${GALLERY_GRID_CLASS_NAME}`)?.getBoundingClientRect().width ?? 0;
const widthOfGalleryGrid =
document.querySelector(`[data-testid="${imageListContainerTestId}"]`)?.getBoundingClientRect().width ?? 0;
const imagesPerRow = Math.round(widthOfGalleryGrid / widthOfGalleryImage);
@@ -113,8 +115,6 @@ type UseGalleryNavigationReturn = {
isOnFirstImage: boolean;
isOnLastImage: boolean;
areImagesBelowCurrent: boolean;
isOnFirstImageOfView: boolean;
isOnLastImageOfView: boolean;
};
/**
@@ -134,19 +134,23 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
return lastSelected;
}
});
const { imageDTOs } = useGalleryImages();
const loadedImagesCount = useMemo(() => imageDTOs.length, [imageDTOs.length]);
const {
queryResult: { data },
} = useGalleryImages();
const loadedImagesCount = useMemo(() => data?.ids.length ?? 0, [data?.ids.length]);
const lastSelectedImageIndex = useMemo(() => {
if (imageDTOs.length === 0 || !lastSelectedImage) {
if (!data || !lastSelectedImage) {
return 0;
}
return imageDTOs.findIndex((i) => i.image_name === lastSelectedImage.image_name);
}, [imageDTOs, lastSelectedImage]);
return imagesSelectors.selectAll(data).findIndex((i) => i.image_name === lastSelectedImage.image_name);
}, [lastSelectedImage, data]);
const handleNavigation = useCallback(
(direction: 'left' | 'right' | 'up' | 'down', alt?: boolean) => {
const { index, image } = getImageFuncs[direction](imageDTOs, lastSelectedImageIndex);
if (!data) {
return;
}
const { index, image } = getImageFuncs[direction](imagesSelectors.selectAll(data), lastSelectedImageIndex);
if (!image || index === lastSelectedImageIndex) {
return;
}
@@ -157,7 +161,7 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
}
scrollToImage(image.image_name, index);
},
[imageDTOs, lastSelectedImageIndex, dispatch]
[data, lastSelectedImageIndex, dispatch]
);
const isOnFirstImage = useMemo(() => lastSelectedImageIndex === 0, [lastSelectedImageIndex]);
@@ -172,14 +176,6 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
return lastSelectedImageIndex + imagesPerRow < loadedImagesCount;
}, [lastSelectedImageIndex, loadedImagesCount]);
const isOnFirstImageOfView = useMemo(() => {
return lastSelectedImageIndex === 0;
}, [lastSelectedImageIndex]);
const isOnLastImageOfView = useMemo(() => {
return lastSelectedImageIndex === loadedImagesCount - 1;
}, [lastSelectedImageIndex, loadedImagesCount]);
const handleLeftImage = useCallback(
(alt?: boolean) => {
handleNavigation('left', alt);
@@ -226,7 +222,5 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
areImagesBelowCurrent,
nextImage,
prevImage,
isOnFirstImageOfView,
isOnLastImageOfView,
};
};

View File

@@ -1,131 +0,0 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { offsetChanged } from 'features/gallery/store/gallerySlice';
import { useCallback, useEffect, useMemo } from 'react';
import { useListImagesQuery } from 'services/api/endpoints/images';
export const useGalleryPagination = (pageButtonsPerSide: number = 2) => {
const dispatch = useAppDispatch();
const { offset, limit } = useAppSelector((s) => s.gallery);
const queryArgs = useAppSelector(selectListImagesQueryArgs);
const { count, total } = useListImagesQuery(queryArgs, {
selectFromResult: ({ data }) => ({ count: data?.items.length ?? 0, total: data?.total ?? 0 }),
});
const currentPage = useMemo(() => Math.ceil(offset / (limit || 0)), [offset, limit]);
const pages = useMemo(() => Math.ceil(total / (limit || 0)), [total, limit]);
const isNextEnabled = useMemo(() => {
if (!count) {
return false;
}
return currentPage + 1 < pages;
}, [count, currentPage, pages]);
const isPrevEnabled = useMemo(() => {
if (!count) {
return false;
}
return offset > 0;
}, [count, offset]);
const goNext = useCallback(() => {
dispatch(offsetChanged(offset + (limit || 0)));
}, [dispatch, offset, limit]);
const goPrev = useCallback(() => {
dispatch(offsetChanged(Math.max(offset - (limit || 0), 0)));
}, [dispatch, offset, limit]);
const goToPage = useCallback(
(page: number) => {
dispatch(offsetChanged(page * (limit || 0)));
},
[dispatch, limit]
);
const goToFirst = useCallback(() => {
dispatch(offsetChanged(0));
}, [dispatch]);
const goToLast = useCallback(() => {
dispatch(offsetChanged((pages - 1) * (limit || 0)));
}, [dispatch, pages, limit]);
// handle when total/pages decrease and user is on high page number (ie bulk removing or deleting)
useEffect(() => {
if (pages && currentPage + 1 > pages) {
goToLast();
}
}, [currentPage, pages, goToLast]);
// calculate the page buttons to display - current page with 3 around it
const pageButtons = useMemo(() => {
const buttons = [];
const maxPageButtons = pageButtonsPerSide * 2 + 1;
let startPage = Math.max(currentPage - Math.floor(maxPageButtons / 2), 0);
const endPage = Math.min(startPage + maxPageButtons - 1, pages - 1);
if (endPage - startPage < maxPageButtons - 1) {
startPage = Math.max(endPage - maxPageButtons + 1, 0);
}
for (let i = startPage; i <= endPage; i++) {
buttons.push(i);
}
return buttons;
}, [currentPage, pageButtonsPerSide, pages]);
const isFirstEnabled = useMemo(() => currentPage > 0, [currentPage]);
const isLastEnabled = useMemo(() => currentPage < pages - 1, [currentPage, pages]);
const rangeDisplay = useMemo(() => {
const startItem = currentPage * (limit || 0) + 1;
const endItem = Math.min((currentPage + 1) * (limit || 0), total);
return `${startItem}-${endItem} of ${total}`;
}, [total, currentPage, limit]);
const numberOnPage = useMemo(() => {
return Math.min((currentPage + 1) * (limit || 0), total);
}, [currentPage, limit, total]);
const api = useMemo(
() => ({
count,
total,
currentPage,
pages,
isNextEnabled,
isPrevEnabled,
goNext,
goPrev,
goToPage,
goToFirst,
goToLast,
pageButtons,
isFirstEnabled,
isLastEnabled,
rangeDisplay,
numberOnPage,
}),
[
count,
total,
currentPage,
pages,
isNextEnabled,
isPrevEnabled,
goNext,
goPrev,
goToPage,
goToFirst,
goToLast,
pageButtons,
isFirstEnabled,
isLastEnabled,
rangeDisplay,
numberOnPage,
]
);
return api;
};

View File

@@ -1,5 +1,3 @@
import type { SkipToken } from '@reduxjs/toolkit/query';
import { skipToken } from '@reduxjs/toolkit/query';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
@@ -12,15 +10,11 @@ export const selectLastSelectedImage = createMemoizedSelector(
export const selectListImagesQueryArgs = createMemoizedSelector(
selectGallerySlice,
(gallery): ListImagesArgs | SkipToken =>
gallery.limit
? {
board_id: gallery.selectedBoardId,
categories: gallery.galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES,
offset: gallery.offset,
limit: gallery.limit,
is_intermediate: false,
search_term: gallery.searchTerm,
}
: skipToken
(gallery): ListImagesArgs => ({
board_id: gallery.selectedBoardId,
categories: gallery.galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES,
offset: gallery.offset,
limit: gallery.limit,
is_intermediate: false,
})
);

View File

@@ -7,7 +7,7 @@ import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import type { BoardId, ComparisonMode, GalleryState, GalleryView } from './types';
import { IMAGE_LIMIT } from './types';
import { IMAGE_LIMIT, INITIAL_IMAGE_LIMIT } from './types';
const initialGalleryState: GalleryState = {
selection: [],
@@ -19,7 +19,7 @@ const initialGalleryState: GalleryState = {
selectedBoardId: 'none',
galleryView: 'images',
boardSearchText: '',
limit: 20,
limit: INITIAL_IMAGE_LIMIT,
offset: 0,
isImageViewerOpen: true,
imageToCompare: null,
@@ -72,6 +72,7 @@ export const gallerySlice = createSlice({
state.selectedBoardId = action.payload.boardId;
state.galleryView = 'images';
state.offset = 0;
state.limit = INITIAL_IMAGE_LIMIT;
},
autoAddBoardIdChanged: (state, action: PayloadAction<BoardId>) => {
if (!action.payload) {
@@ -83,11 +84,20 @@ export const gallerySlice = createSlice({
galleryViewChanged: (state, action: PayloadAction<GalleryView>) => {
state.galleryView = action.payload;
state.offset = 0;
state.limit = IMAGE_LIMIT;
state.limit = INITIAL_IMAGE_LIMIT;
},
boardSearchTextChanged: (state, action: PayloadAction<string>) => {
state.boardSearchText = action.payload;
},
moreImagesLoaded: (state) => {
if (state.offset === 0 && state.limit === INITIAL_IMAGE_LIMIT) {
state.offset = INITIAL_IMAGE_LIMIT;
state.limit = IMAGE_LIMIT;
} else {
state.offset += IMAGE_LIMIT;
state.limit += IMAGE_LIMIT;
}
},
alwaysShowImageSizeBadgeChanged: (state, action: PayloadAction<boolean>) => {
state.alwaysShowImageSizeBadge = action.payload;
},
@@ -104,15 +114,6 @@ export const gallerySlice = createSlice({
comparisonFitChanged: (state, action: PayloadAction<'contain' | 'fill'>) => {
state.comparisonFit = action.payload;
},
offsetChanged: (state, action: PayloadAction<number>) => {
state.offset = action.payload;
},
limitChanged: (state, action: PayloadAction<number>) => {
state.limit = action.payload;
},
searchTermChanged: (state, action: PayloadAction<string | undefined>) => {
state.searchTerm = action.payload;
},
},
extraReducers: (builder) => {
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
@@ -148,6 +149,7 @@ export const {
galleryViewChanged,
selectionChanged,
boardSearchTextChanged,
moreImagesLoaded,
alwaysShowImageSizeBadgeChanged,
isImageViewerOpenChanged,
imageToCompareChanged,
@@ -155,9 +157,6 @@ export const {
comparedImagesSwapped,
comparisonFitChanged,
comparisonModeCycled,
offsetChanged,
limitChanged,
searchTermChanged,
} = gallerySlice.actions;
const isAnyBoardDeleted = isAnyOf(

View File

@@ -2,7 +2,8 @@ import type { ImageCategory, ImageDTO } from 'services/api/types';
export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
export const ASSETS_CATEGORIES: ImageCategory[] = ['control', 'mask', 'user', 'other'];
export const IMAGE_LIMIT = 15;
export const INITIAL_IMAGE_LIMIT = 100;
export const IMAGE_LIMIT = 20;
export type GalleryView = 'images' | 'assets';
export type BoardId = 'none' | (string & Record<never, never>);
@@ -20,7 +21,6 @@ export type GalleryState = {
boardSearchText: string;
offset: number;
limit: number;
searchTerm?: string;
alwaysShowImageSizeBadge: boolean;
imageToCompare: ImageDTO | null;
comparisonMode: ComparisonMode;

View File

@@ -11,6 +11,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
any: 'base',
'sd-1': 'green',
'sd-2': 'teal',
'sd-3': 'purple',
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
};

View File

@@ -10,6 +10,7 @@ import type { UpdateModelArg } from 'services/api/endpoints/models';
const options: ComboboxOption[] = [
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
{ value: 'sd-3', label: MODEL_TYPE_MAP['sd-3'] },
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
];

View File

@@ -28,6 +28,8 @@ import {
isModelIdentifierFieldInputTemplate,
isSchedulerFieldInputInstance,
isSchedulerFieldInputTemplate,
isSD3MainModelFieldInputInstance,
isSD3MainModelFieldInputTemplate,
isSDXLMainModelFieldInputInstance,
isSDXLMainModelFieldInputTemplate,
isSDXLRefinerModelFieldInputInstance,
@@ -53,6 +55,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent'
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
@@ -133,6 +136,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@@ -0,0 +1,55 @@
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useSD3Models } from 'services/api/hooks/modelsByType';
import type { MainModelConfig } from 'services/api/types';
import type { FieldComponentProps } from './types';
type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;
const SD3MainModelFieldInputComponent = (props: Props) => {
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const [modelConfigs, { isLoading }] = useSD3Models();
const _onChange = useCallback(
(value: MainModelConfig | null) => {
if (!value) {
return;
}
dispatch(
fieldMainModelValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
modelConfigs,
onChange: _onChange,
isLoading,
selectedModel: field.value,
});
return (
<Flex w="full" alignItems="center" gap={2}>
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
<Combobox
value={value}
placeholder={placeholder}
options={options}
onChange={onChange}
noOptionsMessage={noOptionsMessage}
/>
</FormControl>
</Flex>
);
};
export default memo(SD3MainModelFieldInputComponent);

View File

@@ -631,6 +631,7 @@ export const schema = {
'euler',
'euler_k',
'euler_a',
'euler_f',
'kdpm_2',
'kdpm_2_a',
'dpmpp_2s',
@@ -694,6 +695,7 @@ export const schema = {
'euler',
'euler_k',
'euler_a',
'euler_f',
'kdpm_2',
'kdpm_2_a',
'dpmpp_2s',
@@ -839,7 +841,7 @@ export const schema = {
},
BaseModelType: {
description: 'Base model type.',
enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
enum: ['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner'],
title: 'BaseModelType',
type: 'string',
},
@@ -855,8 +857,11 @@ export const schema = {
'unet',
'text_encoder',
'text_encoder_2',
'text_encoder_3',
'tokenizer',
'tokenizer_2',
'tokenizer_3',
'transformer',
'vae',
'vae_decoder',
'vae_encoder',

View File

@@ -47,6 +47,7 @@ export const zSchedulerField = z.enum([
'heun_k',
'lms_k',
'euler_a',
'euler_f',
'kdpm_2_a',
'lcm',
'tcd',
@@ -55,7 +56,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
// #endregion
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner']);
const zModelType = z.enum([
'main',
'vae',
@@ -71,8 +72,11 @@ const zSubModelType = z.enum([
'unet',
'text_encoder',
'text_encoder_2',
'text_encoder_3',
'tokenizer',
'tokenizer_2',
'tokenizer_3',
'transformer',
'vae',
'vae_decoder',
'vae_encoder',

View File

@@ -32,11 +32,14 @@ export const MODEL_TYPES = [
'LoRAModelField',
'MainModelField',
'SDXLMainModelField',
'SD3MainModelField',
'SDXLRefinerModelField',
'VaeModelField',
'UNetField',
'TransformerField',
'VAEField',
'CLIPField',
'SD3CLIPField',
'T2IAdapterModelField',
];
@@ -47,6 +50,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
BoardField: 'purple.500',
BooleanField: 'green.500',
CLIPField: 'green.500',
SD3CLIPField: 'green.500',
ColorField: 'pink.300',
ConditioningField: 'cyan.500',
ControlField: 'teal.500',
@@ -62,10 +66,12 @@ export const FIELD_COLORS: { [key: string]: string } = {
MainModelField: 'teal.500',
SDXLMainModelField: 'teal.500',
SDXLRefinerModelField: 'teal.500',
SD3MainModelField: 'teal.500',
StringField: 'yellow.500',
T2IAdapterField: 'teal.500',
T2IAdapterModelField: 'teal.500',
UNetField: 'red.500',
TransformerField: 'red.500',
VAEField: 'blue.500',
VAEModelField: 'teal.500',
};

View File

@@ -119,6 +119,10 @@ const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
name: z.literal('SDXLRefinerModelField'),
originalType: zStatelessFieldType.optional(),
});
const zSD3MainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SD3MainModelField'),
originalType: zStatelessFieldType.optional(),
});
const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'),
originalType: zStatelessFieldType.optional(),
@@ -155,6 +159,7 @@ const zStatefulFieldType = z.union([
zMainModelFieldType,
zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType,
zSD3MainModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zControlNetModelFieldType,
@@ -466,6 +471,28 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
zSDXLRefinerModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region SD3MainModelField
const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
value: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zSD3MainModelFieldType,
originalType: zFieldType.optional(),
default: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zSD3MainModelFieldType,
});
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
zSD3MainModelFieldInputInstance.safeParse(val).success;
export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
zSD3MainModelFieldInputTemplate.safeParse(val).success;
// #endregion
// #region VAEModelField
export const zVAEModelFieldValue = zModelIdentifierField.optional();
@@ -662,6 +689,7 @@ export const zStatefulFieldValue = z.union([
zMainModelFieldValue,
zSDXLMainModelFieldValue,
zSDXLRefinerModelFieldValue,
zSD3MainModelFieldValue,
zVAEModelFieldValue,
zLoRAModelFieldValue,
zControlNetModelFieldValue,
@@ -689,6 +717,7 @@ const zStatefulFieldInputInstance = z.union([
zMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zVAEModelFieldInputInstance,
zLoRAModelFieldInputInstance,
zControlNetModelFieldInputInstance,
@@ -717,6 +746,7 @@ const zStatefulFieldInputTemplate = z.union([
zMainModelFieldInputTemplate,
zSDXLMainModelFieldInputTemplate,
zSDXLRefinerModelFieldInputTemplate,
zSD3MainModelFieldInputTemplate,
zVAEModelFieldInputTemplate,
zLoRAModelFieldInputTemplate,
zControlNetModelFieldInputTemplate,
@@ -746,6 +776,7 @@ const zStatefulFieldOutputTemplate = z.union([
zMainModelFieldOutputTemplate,
zSDXLMainModelFieldOutputTemplate,
zSDXLRefinerModelFieldOutputTemplate,
zSD3MainModelFieldOutputTemplate,
zVAEModelFieldOutputTemplate,
zLoRAModelFieldOutputTemplate,
zControlNetModelFieldOutputTemplate,

View File

@@ -44,7 +44,7 @@ export const zSchedulerField = z.enum([
// #endregion
// #region Model-related schemas
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner']);
const zModelName = z.string().min(3);
export const zModelIdentifier = z.object({
model_name: zModelName,

View File

@@ -217,6 +217,20 @@ const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
});
// #endregion
// #region SDXLMainModelField
const zSD3MainModelFieldType = zFieldTypeBase.extend({
name: z.literal('SD3MainModelField'),
});
const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
type: zSD3MainModelFieldType,
value: zSD3MainModelFieldValue,
});
const zSD3MainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
type: zSD3MainModelFieldType,
});
// #endregion
// #region VAEModelField
const zVAEModelFieldType = zFieldTypeBase.extend({
name: z.literal('VAEModelField'),
@@ -339,6 +353,7 @@ const zStatefulFieldType = z.union([
zMainModelFieldType,
zSDXLMainModelFieldType,
zSDXLRefinerModelFieldType,
zSD3MainModelFieldType,
zVAEModelFieldType,
zLoRAModelFieldType,
zControlNetModelFieldType,
@@ -378,6 +393,7 @@ const zStatefulFieldInputInstance = z.union([
zMainModelFieldInputInstance,
zSDXLMainModelFieldInputInstance,
zSDXLRefinerModelFieldInputInstance,
zSD3MainModelFieldInputInstance,
zVAEModelFieldInputInstance,
zLoRAModelFieldInputInstance,
zControlNetModelFieldInputInstance,
@@ -402,6 +418,7 @@ const zStatefulFieldOutputInstance = z.union([
zMainModelFieldOutputInstance,
zSDXLMainModelFieldOutputInstance,
zSDXLRefinerModelFieldOutputInstance,
zSD3MainModelFieldOutputInstance,
zVAEModelFieldOutputInstance,
zLoRAModelFieldOutputInstance,
zControlNetModelFieldOutputInstance,

View File

@@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
MainModelField: undefined,
SchedulerField: 'euler',
SDXLMainModelField: undefined,
SD3MainModelField: undefined,
SDXLRefinerModelField: undefined,
StringField: '',
T2IAdapterModelField: undefined,

View File

@@ -15,6 +15,7 @@ import type {
MainModelFieldInputTemplate,
ModelIdentifierFieldInputTemplate,
SchedulerFieldInputTemplate,
SD3MainModelFieldInputTemplate,
SDXLMainModelFieldInputTemplate,
SDXLRefinerModelFieldInputTemplate,
StatefulFieldType,
@@ -193,6 +194,20 @@ const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefiner
return template;
};
const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: SD3MainModelFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? undefined,
};
return template;
};
const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -375,6 +390,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
SchedulerField: buildSchedulerFieldInputTemplate,
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
StringField: buildStringFieldInputTemplate,
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
VAEModelField: buildVAEModelFieldInputTemplate,

View File

@@ -30,6 +30,7 @@ const MODEL_FIELD_TYPES = [
'MainModelField',
'SDXLMainModelField',
'SDXLRefinerModelField',
'SD3MainModelField',
'VAEModelField',
'LoRAModelField',
'ControlNetModelField',

View File

@@ -39,7 +39,7 @@ const ParamClipSkip = () => {
return CLIP_SKIP_MAP[model.base].markers;
}, [model]);
if (model?.base === 'sdxl') {
if (model?.base === 'sdxl' || model?.base === 'sd-3') {
return null;
}

View File

@@ -7,6 +7,7 @@ export const MODEL_TYPE_MAP = {
any: 'Any',
'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x',
'sd-3': 'Stable Diffusion 3.x',
sdxl: 'Stable Diffusion XL',
'sdxl-refiner': 'Stable Diffusion XL Refiner',
};
@@ -18,6 +19,7 @@ export const MODEL_TYPE_SHORT_MAP = {
any: 'Any',
'sd-1': 'SD1.X',
'sd-2': 'SD2.X',
'sd-3': 'SD3.X',
sdxl: 'SDXL',
'sdxl-refiner': 'SDXLR',
};
@@ -38,6 +40,11 @@ export const CLIP_SKIP_MAP = {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
// TODO: Update this when we have more details on how CLIP SKIP works with SD3
'sd-3': {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
},
sdxl: {
maxClip: 24,
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
@@ -73,6 +80,7 @@ export const SCHEDULER_OPTIONS: ComboboxOption[] = [
{ value: 'heun_k', label: 'Heun Karras' },
{ value: 'lms_k', label: 'LMS Karras' },
{ value: 'euler_a', label: 'Euler Ancestral' },
{ value: 'euler_f', label: 'Euler Flow Match' },
{ value: 'kdpm_2_a', label: 'KDPM 2 Ancestral' },
{ value: 'lcm', label: 'LCM' },
{ value: 'tcd', label: 'TCD' },

File diff suppressed because it is too large Load Diff

View File

@@ -10,6 +10,7 @@ import {
isNonRefinerMainModelConfig,
isNonSDXLMainModelConfig,
isRefinerMainModelModelConfig,
isSD3MainModelModelConfig,
isSDXLMainModelModelConfig,
isT2IAdapterModelConfig,
isTIModelConfig,
@@ -35,6 +36,7 @@ export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);

View File

@@ -0,0 +1,18 @@
import { useAppSelector } from 'app/store/storeHooks';
import type { BoardId } from 'features/gallery/store/types';
import { useMemo } from 'react';
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
export const useBoardTotal = (board_id: BoardId) => {
const galleryView = useAppSelector((s) => s.gallery.galleryView);
const { data: totalImages } = useGetBoardImagesTotalQuery(board_id);
const { data: totalAssets } = useGetBoardAssetsTotalQuery(board_id);
const currentViewTotal = useMemo(
() => (galleryView === 'images' ? totalImages?.total : totalAssets?.total),
[galleryView, totalAssets, totalImages]
);
return { totalImages, totalAssets, currentViewTotal };
};

File diff suppressed because one or more lines are too long

View File

@@ -1,10 +1,12 @@
import type { EntityState } from '@reduxjs/toolkit';
import type { components, paths } from 'services/api/schema';
import type { O } from 'ts-toolbelt';
export type S = components['schemas'];
export type ImageCache = EntityState<ImageDTO, string>;
export type ListImagesArgs = NonNullable<paths['/api/v1/images/']['get']['parameters']['query']>;
export type ListImagesResponse = paths['/api/v1/images/']['get']['responses']['200']['content']['application/json'];
export type DeleteBoardResult =
paths['/api/v1/boards/{board_id}']['delete']['responses']['200']['content']['application/json'];
@@ -107,7 +109,11 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma
};
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2' || config.base === 'sd-3');
};
export const isSD3MainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
return config.type === 'main' && config.base === 'sd-3';
};
export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => {

View File

@@ -1,8 +1,56 @@
import { createEntityAdapter } from '@reduxjs/toolkit';
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
import { dateComparator } from 'common/util/dateComparator';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
import queryString from 'query-string';
import { buildV1Url } from 'services/api';
import type { ImageDTO, ListImagesArgs } from './types';
import type { ImageCache, ImageDTO, ListImagesArgs } from './types';
export const getIsImageInDateRange = (data: ImageCache | undefined, imageDTO: ImageDTO) => {
if (!data) {
return false;
}
const totalCachedImageDtos = imagesSelectors.selectAll(data);
if (totalCachedImageDtos.length <= 1) {
return true;
}
const cachedStarredImages = [];
const cachedUnstarredImages = [];
for (let index = 0; index < totalCachedImageDtos.length; index++) {
const image = totalCachedImageDtos[index];
if (image?.starred) {
cachedStarredImages.push(image);
}
if (!image?.starred) {
cachedUnstarredImages.push(image);
}
}
if (imageDTO.starred) {
const lastStarredImage = cachedStarredImages[cachedStarredImages.length - 1];
// if starring or already starred, want to look in list of starred images
if (!lastStarredImage) {
return true;
} // no starred images showing, so always show this one
const createdDate = new Date(imageDTO.created_at);
const oldestDate = new Date(lastStarredImage.created_at);
return createdDate >= oldestDate;
} else {
const lastUnstarredImage = cachedUnstarredImages[cachedUnstarredImages.length - 1];
// if unstarring or already unstarred, want to look in list of unstarred images
if (!lastUnstarredImage) {
return false;
} // no unstarred images showing, so don't show this one
const createdDate = new Date(imageDTO.created_at);
const oldestDate = new Date(lastUnstarredImage.created_at);
return createdDate >= oldestDate;
}
};
export const getCategories = (imageDTO: ImageDTO) => {
if (IMAGE_CATEGORIES.includes(imageDTO.image_category)) {
@@ -11,6 +59,25 @@ export const getCategories = (imageDTO: ImageDTO) => {
return ASSETS_CATEGORIES;
};
// The adapter is not actually the data store - it just provides helper functions to interact
// with some other store of data. We will use the RTK Query cache as that store.
export const imagesAdapter = createEntityAdapter<ImageDTO, string>({
selectId: (image) => image.image_name,
sortComparer: (a, b) => {
// Compare starred images first
if (a.starred && !b.starred) {
return -1;
}
if (!a.starred && b.starred) {
return 1;
}
return dateComparator(b.created_at, a.created_at);
},
});
// Create selectors for the adapter.
export const imagesSelectors = imagesAdapter.getSelectors(undefined, getSelectorsOptions);
// Helper to create the url for the listImages endpoint. Also we use it to create the cache key.
export const getListImagesUrl = (queryArgs: ListImagesArgs) =>
buildV1Url(`images/?${queryString.stringify(queryArgs, { arrayFormat: 'none' })}`);

View File

@@ -39,6 +39,7 @@ from invokeai.app.invocations.model import (
ModelIdentifierField,
ModelLoaderOutput,
SDXLLoRALoaderOutput,
TransformerField,
UNetField,
UNetOutput,
VAEField,
@@ -117,6 +118,7 @@ __all__ = [
# invokeai.app.invocations.model
"ModelIdentifierField",
"UNetField",
"TransformerField",
"CLIPField",
"VAEField",
"UNetOutput",

View File

@@ -33,30 +33,32 @@ classifiers = [
]
dependencies = [
# Core generation dependencies, pinned for reproducible builds.
"accelerate==0.30.1",
"accelerate",
"bitsandbytes",
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
"compel==2.0.2",
"controlnet-aux==0.0.7",
"diffusers[torch]==0.27.2",
"diffusers[torch]",
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
"numpy", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
"onnx==1.15.0",
"onnxruntime==1.16.3",
"opencv-python==4.9.0.80",
"pytorch-lightning==2.1.3",
"pytorch-lightning",
"safetensors==0.4.3",
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
"torch==2.2.2",
"torchmetrics==0.11.4",
"torch",
"torchmetrics",
"torchsde==0.2.6",
"torchvision==0.17.2",
"transformers==4.41.1",
"torchvision",
"transformers",
"sentencepiece==0.1.99",
# Core application dependencies, pinned for reproducible builds.
"fastapi-events==0.11.0",
"fastapi==0.111.0",
"huggingface-hub==0.23.1",
"huggingface-hub",
"pydantic-settings==2.2.1",
"pydantic==2.7.2",
"python-socketio==5.11.1",
@@ -73,7 +75,7 @@ dependencies = [
"easing-functions",
"einops",
"facexlib",
"matplotlib", # needed for plotting of Penner easing functions
"matplotlib", # needed for plotting of Penner easing functions
"npyscreen",
"omegaconf",
"picklescan",