mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-21 05:58:05 -05:00
Compare commits
21 Commits
maryhipp/s
...
lstein/fea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39881d3d7d | ||
|
|
28f1d25973 | ||
|
|
95377ea159 | ||
|
|
445561e3a4 | ||
|
|
66260fd345 | ||
|
|
c403efa83f | ||
|
|
cd99ef2f46 | ||
|
|
9dce4f09ae | ||
|
|
22b5c036aa | ||
|
|
be14fd59c9 | ||
|
|
423057a2e8 | ||
|
|
f65d50a4dd | ||
|
|
554809c647 | ||
|
|
ac0396e6f7 | ||
|
|
78f704e7d5 | ||
|
|
41236031b2 | ||
|
|
ddbd2ebd9d | ||
|
|
0c970bc880 | ||
|
|
c79d9b9ecf | ||
|
|
03b9d17d0b | ||
|
|
002f8242a1 |
@@ -316,7 +316,6 @@ async def list_image_dtos(
|
||||
),
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of images per page"),
|
||||
search_term: Optional[str] = Query(default=None, description="The term to search for"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a list of image DTOs"""
|
||||
|
||||
@@ -327,7 +326,6 @@ async def list_image_dtos(
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
search_term
|
||||
)
|
||||
|
||||
return image_dtos
|
||||
|
||||
@@ -42,6 +42,7 @@ class UIType(str, Enum, metaclass=MetaEnum):
|
||||
MainModel = "MainModelField"
|
||||
SDXLMainModel = "SDXLMainModelField"
|
||||
SDXLRefinerModel = "SDXLRefinerModelField"
|
||||
SD3MainModel = "SD3MainModelField"
|
||||
ONNXModel = "ONNXModelField"
|
||||
VAEModel = "VAEModelField"
|
||||
LoRAModel = "LoRAModelField"
|
||||
@@ -125,6 +126,7 @@ class FieldDescriptions:
|
||||
noise = "Noise tensor"
|
||||
clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count"
|
||||
unet = "UNet (scheduler, LoRAs)"
|
||||
transformer = "Transformer"
|
||||
vae = "VAE"
|
||||
cond = "Conditioning tensor"
|
||||
controlnet_model = "ControlNet model to load"
|
||||
@@ -133,6 +135,7 @@ class FieldDescriptions:
|
||||
main_model = "Main model (UNet, VAE, CLIP) to load"
|
||||
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
|
||||
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
|
||||
sd3_main_model = "SD3 Main Model (Transformer, CLIP1, CLIP2, CLIP3, VAE) to load"
|
||||
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
|
||||
lora_weight = "The weight at which the LoRA is applied to each model"
|
||||
compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor"
|
||||
|
||||
@@ -12,14 +12,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
||||
from invokeai.app.invocations.constants import DEFAULT_PRECISION
|
||||
from invokeai.app.invocations.fields import (
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
LatentsField,
|
||||
WithBoard,
|
||||
WithMetadata,
|
||||
)
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, LatentsField, WithBoard, WithMetadata
|
||||
from invokeai.app.invocations.model import VAEField
|
||||
from invokeai.app.invocations.primitives import ImageOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
@@ -8,13 +8,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.shared.models import FreeUConfig
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType, SubModelType
|
||||
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
|
||||
|
||||
|
||||
class ModelIdentifierField(BaseModel):
|
||||
@@ -54,6 +48,11 @@ class UNetField(BaseModel):
|
||||
freeu_config: Optional[FreeUConfig] = Field(default=None, description="FreeU configuration")
|
||||
|
||||
|
||||
class TransformerField(BaseModel):
|
||||
transformer: ModelIdentifierField = Field(description="Info to load unet submodel")
|
||||
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
|
||||
|
||||
|
||||
class CLIPField(BaseModel):
|
||||
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
|
||||
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
|
||||
@@ -61,6 +60,15 @@ class CLIPField(BaseModel):
|
||||
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
|
||||
|
||||
|
||||
class SD3CLIPField(BaseModel):
|
||||
tokenizer_1: ModelIdentifierField = Field(description="Info to load tokenizer 1 submodel")
|
||||
text_encoder_1: ModelIdentifierField = Field(description="Info to load text_encoder 1 submodel")
|
||||
tokenizer_2: ModelIdentifierField = Field(description="Info to load tokenizer 2 submodel")
|
||||
text_encoder_2: ModelIdentifierField = Field(description="Info to load text_encoder 2 submodel")
|
||||
tokenizer_3: Optional[ModelIdentifierField] = Field(description="Info to load tokenizer 3 submodel")
|
||||
text_encoder_3: Optional[ModelIdentifierField] = Field(description="Info to load text_encoder 3 submodel")
|
||||
|
||||
|
||||
class VAEField(BaseModel):
|
||||
vae: ModelIdentifierField = Field(description="Info to load vae submodel")
|
||||
seamless_axes: List[str] = Field(default_factory=list, description='Axes("x" and "y") to which apply seamless')
|
||||
|
||||
200
invokeai/app/invocations/sd3.py
Normal file
200
invokeai/app/invocations/sd3.py
Normal file
@@ -0,0 +1,200 @@
|
||||
from contextlib import ExitStack
|
||||
from typing import Optional, cast
|
||||
|
||||
import torch
|
||||
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
|
||||
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
|
||||
from pydantic import field_validator
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Input,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
||||
from invokeai.app.invocations.denoise_latents import get_scheduler
|
||||
from invokeai.app.invocations.fields import FieldDescriptions, InputField, LatentsField, OutputField, UIType
|
||||
from invokeai.app.invocations.model import ModelIdentifierField, SD3CLIPField, TransformerField, VAEField
|
||||
from invokeai.app.invocations.primitives import LatentsOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.misc import SEED_MAX
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
|
||||
sd3_pipeline: Optional[StableDiffusion3Pipeline] = None
|
||||
|
||||
|
||||
class FakeVae:
|
||||
class FakeVaeConfig:
|
||||
def __init__(self) -> None:
|
||||
self.block_out_channels = [0]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.config = FakeVae.FakeVaeConfig()
|
||||
|
||||
|
||||
@invocation_output("sd3_model_loader_output")
|
||||
class SD3ModelLoaderOutput(BaseInvocationOutput):
|
||||
"""Stable Diffuion 3 base model loader output"""
|
||||
|
||||
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
|
||||
clip: SD3CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP")
|
||||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
|
||||
|
||||
|
||||
@invocation("sd3_model_loader", title="SD3 Main Model", tags=["model", "sd3"], category="model", version="1.0.0")
|
||||
class SD3ModelLoaderInvocation(BaseInvocation):
|
||||
"""Loads an SD3 base model, outputting its submodels."""
|
||||
|
||||
model: ModelIdentifierField = InputField(description=FieldDescriptions.sd3_main_model, ui_type=UIType.SD3MainModel)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> SD3ModelLoaderOutput:
|
||||
model_key = self.model.key
|
||||
|
||||
if not context.models.exists(model_key):
|
||||
raise Exception(f"Unknown model: {model_key}")
|
||||
|
||||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})
|
||||
scheduler = self.model.model_copy(update={"submodel_type": SubModelType.Scheduler})
|
||||
tokenizer_1 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
|
||||
text_encoder_1 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder})
|
||||
tokenizer_2 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer2})
|
||||
text_encoder_2 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder2})
|
||||
try:
|
||||
tokenizer_3 = self.model.model_copy(update={"submodel_type": SubModelType.Tokenizer3})
|
||||
text_encoder_3 = self.model.model_copy(update={"submodel_type": SubModelType.TextEncoder3})
|
||||
except Exception:
|
||||
tokenizer_3 = None
|
||||
text_encoder_3 = None
|
||||
vae = self.model.model_copy(update={"submodel_type": SubModelType.VAE})
|
||||
|
||||
return SD3ModelLoaderOutput(
|
||||
transformer=TransformerField(transformer=transformer, scheduler=scheduler),
|
||||
clip=SD3CLIPField(
|
||||
tokenizer_1=tokenizer_1,
|
||||
text_encoder_1=text_encoder_1,
|
||||
tokenizer_2=tokenizer_2,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_3=tokenizer_3,
|
||||
text_encoder_3=text_encoder_3,
|
||||
),
|
||||
vae=VAEField(vae=vae),
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
"sd3_image_generator", title="Stable Diffusion 3", tags=["latent", "sd3"], category="latents", version="1.0.0"
|
||||
)
|
||||
class StableDiffusion3Invocation(BaseInvocation):
|
||||
"""Generates an image using Stable Diffusion 3."""
|
||||
|
||||
transformer: TransformerField = InputField(
|
||||
description=FieldDescriptions.transformer,
|
||||
input=Input.Connection,
|
||||
title="Transformer",
|
||||
ui_order=0,
|
||||
)
|
||||
clip: SD3CLIPField = InputField(
|
||||
description=FieldDescriptions.clip,
|
||||
input=Input.Connection,
|
||||
title="CLIP",
|
||||
ui_order=1,
|
||||
)
|
||||
noise: Optional[LatentsField] = InputField(
|
||||
default=None,
|
||||
description=FieldDescriptions.noise,
|
||||
input=Input.Connection,
|
||||
ui_order=2,
|
||||
)
|
||||
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
||||
default="euler_f",
|
||||
description=FieldDescriptions.scheduler,
|
||||
ui_type=UIType.Scheduler,
|
||||
)
|
||||
positive_prompt: str = InputField(default="", title="Positive Prompt")
|
||||
negative_prompt: str = InputField(default="", title="Negative Prompt")
|
||||
steps: int = InputField(default=20, gt=0, description=FieldDescriptions.steps)
|
||||
guidance_scale: float = InputField(default=7.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||
use_clip_3: bool = InputField(default=True, description="Use TE5 Encoder of SD3", title="Use TE5 Encoder")
|
||||
|
||||
seed: int = InputField(
|
||||
default=0,
|
||||
ge=0,
|
||||
le=SEED_MAX,
|
||||
description=FieldDescriptions.seed,
|
||||
)
|
||||
width: int = InputField(
|
||||
default=1024,
|
||||
multiple_of=LATENT_SCALE_FACTOR,
|
||||
gt=0,
|
||||
description=FieldDescriptions.width,
|
||||
)
|
||||
height: int = InputField(
|
||||
default=1024,
|
||||
multiple_of=LATENT_SCALE_FACTOR,
|
||||
gt=0,
|
||||
description=FieldDescriptions.height,
|
||||
)
|
||||
|
||||
@field_validator("seed", mode="before")
|
||||
def modulo_seed(cls, v: int):
|
||||
"""Return the seed modulo (SEED_MAX + 1) to ensure it is within the valid range."""
|
||||
return v % (SEED_MAX + 1)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
||||
with ExitStack() as stack:
|
||||
tokenizer_1 = stack.enter_context(context.models.load(self.clip.tokenizer_1))
|
||||
tokenizer_2 = stack.enter_context(context.models.load(self.clip.tokenizer_2))
|
||||
text_encoder_1 = stack.enter_context(context.models.load(self.clip.text_encoder_1))
|
||||
text_encoder_2 = stack.enter_context(context.models.load(self.clip.text_encoder_2))
|
||||
transformer = stack.enter_context(context.models.load(self.transformer.transformer))
|
||||
|
||||
assert isinstance(transformer, SD3Transformer2DModel)
|
||||
assert isinstance(text_encoder_1, CLIPTextModelWithProjection)
|
||||
assert isinstance(text_encoder_2, CLIPTextModelWithProjection)
|
||||
assert isinstance(tokenizer_1, CLIPTokenizer)
|
||||
assert isinstance(tokenizer_2, CLIPTokenizer)
|
||||
|
||||
if self.use_clip_3 and self.clip.tokenizer_3 and self.clip.text_encoder_3:
|
||||
tokenizer_3 = stack.enter_context(context.models.load(self.clip.tokenizer_3))
|
||||
text_encoder_3 = stack.enter_context(context.models.load(self.clip.text_encoder_3))
|
||||
assert isinstance(text_encoder_3, T5EncoderModel)
|
||||
assert isinstance(tokenizer_3, T5TokenizerFast)
|
||||
else:
|
||||
tokenizer_3 = None
|
||||
text_encoder_3 = None
|
||||
|
||||
scheduler = get_scheduler(
|
||||
context=context,
|
||||
scheduler_info=self.transformer.scheduler,
|
||||
scheduler_name=self.scheduler,
|
||||
seed=self.seed,
|
||||
)
|
||||
|
||||
sd3_pipeline = StableDiffusion3Pipeline(
|
||||
transformer=transformer,
|
||||
vae=FakeVae(),
|
||||
text_encoder=text_encoder_1,
|
||||
text_encoder_2=text_encoder_2,
|
||||
text_encoder_3=text_encoder_3,
|
||||
tokenizer=tokenizer_1,
|
||||
tokenizer_2=tokenizer_2,
|
||||
tokenizer_3=tokenizer_3,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
results = sd3_pipeline(
|
||||
self.positive_prompt,
|
||||
negative_prompt=self.negative_prompt,
|
||||
num_inference_steps=self.steps,
|
||||
guidance_scale=self.guidance_scale,
|
||||
output_type="latent",
|
||||
)
|
||||
|
||||
latents = cast(torch.Tensor, results.images[0])
|
||||
latents = latents.unsqueeze(0)
|
||||
|
||||
latents_name = context.tensors.save(latents)
|
||||
return LatentsOutput.build(latents_name, latents=latents, seed=self.seed)
|
||||
@@ -32,6 +32,7 @@ ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
|
||||
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
|
||||
LOG_FORMAT = Literal["plain", "color", "syslog", "legacy"]
|
||||
LOG_LEVEL = Literal["debug", "info", "warning", "error", "critical"]
|
||||
SYSTEM_RAM_TO_CACHE_SIZE_FACTOR = 0.25 # after 60 GB, default ram cache will scale by this factor
|
||||
CONFIG_SCHEMA_VERSION = "4.0.1"
|
||||
|
||||
|
||||
@@ -45,7 +46,7 @@ def get_default_ram_cache_size() -> float:
|
||||
max_ram = psutil.virtual_memory().total / GB
|
||||
|
||||
if max_ram >= 60:
|
||||
return 15.0
|
||||
return max_ram * SYSTEM_RAM_TO_CACHE_SIZE_FACTOR
|
||||
if max_ram >= 30:
|
||||
return 7.5
|
||||
if max_ram >= 14:
|
||||
|
||||
@@ -41,7 +41,6 @@ class ImageRecordStorageBase(ABC):
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
|
||||
@@ -148,7 +148,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
@@ -209,13 +208,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
# Search term condition
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND json_extract(images.metadata, '$') LIKE ?
|
||||
"""
|
||||
query_params.append(f'%{search_term}%')
|
||||
|
||||
query_pagination = """--sql
|
||||
ORDER BY images.starred DESC, images.created_at DESC LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
@@ -120,7 +120,6 @@ class ImageServiceABC(ABC):
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
|
||||
@@ -206,7 +206,6 @@ class ImageService(ImageServiceABC):
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self.__invoker.services.image_records.get_many(
|
||||
@@ -216,7 +215,6 @@ class ImageService(ImageServiceABC):
|
||||
categories,
|
||||
is_intermediate,
|
||||
board_id,
|
||||
search_term
|
||||
)
|
||||
|
||||
image_dtos = [
|
||||
|
||||
@@ -52,6 +52,7 @@ class BaseModelType(str, Enum):
|
||||
StableDiffusion2 = "sd-2"
|
||||
StableDiffusionXL = "sdxl"
|
||||
StableDiffusionXLRefiner = "sdxl-refiner"
|
||||
StableDiffusion3 = "sd-3"
|
||||
# Kandinsky2_1 = "kandinsky-2.1"
|
||||
|
||||
|
||||
@@ -75,8 +76,11 @@ class SubModelType(str, Enum):
|
||||
UNet = "unet"
|
||||
TextEncoder = "text_encoder"
|
||||
TextEncoder2 = "text_encoder_2"
|
||||
TextEncoder3 = "text_encoder_3"
|
||||
Tokenizer = "tokenizer"
|
||||
Tokenizer2 = "tokenizer_2"
|
||||
Tokenizer3 = "tokenizer_3"
|
||||
Transformer = "transformer"
|
||||
VAE = "vae"
|
||||
VAEDecoder = "vae_decoder"
|
||||
VAEEncoder = "vae_encoder"
|
||||
|
||||
@@ -84,6 +84,8 @@ class ModelLoader(ModelLoaderBase):
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
self._logger.info(f"Loading {config.key}:{submodel_type}")
|
||||
|
||||
cache_path: Path = self._convert_cache.cache_path(str(model_path))
|
||||
if self._needs_conversion(config, model_path, cache_path):
|
||||
loaded_model = self._do_convert(config, model_path, cache_path, submodel_type)
|
||||
|
||||
@@ -73,6 +73,7 @@ class CacheRecord(Generic[T]):
|
||||
device: torch.device
|
||||
state_dict: Optional[Dict[str, torch.Tensor]]
|
||||
size: int
|
||||
is_quantized: bool = False
|
||||
loaded: bool = False
|
||||
_locks: int = 0
|
||||
|
||||
|
||||
@@ -60,9 +60,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
execution_device: torch.device = torch.device("cuda"),
|
||||
storage_device: torch.device = torch.device("cpu"),
|
||||
precision: torch.dtype = torch.float16,
|
||||
sequential_offload: bool = False,
|
||||
lazy_offloading: bool = True,
|
||||
sha_chunksize: int = 16777216,
|
||||
log_memory_usage: bool = False,
|
||||
logger: Optional[Logger] = None,
|
||||
):
|
||||
@@ -74,7 +72,6 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
:param storage_device: Torch device to save inactive model in [torch.device('cpu')]
|
||||
:param precision: Precision for loaded models [torch.float16]
|
||||
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
|
||||
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
|
||||
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
|
||||
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
|
||||
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
|
||||
@@ -163,8 +160,18 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
size = calc_model_size_by_data(model)
|
||||
self.make_room(size)
|
||||
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
|
||||
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
|
||||
is_quantized = hasattr(model, "is_quantized") and model.is_quantized
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not is_quantized else None
|
||||
cache_record = CacheRecord(
|
||||
key=key,
|
||||
model=model,
|
||||
device=self._execution_device
|
||||
if is_quantized
|
||||
else self._storage_device, # quantized models are loaded directly into CUDA
|
||||
is_quantized=is_quantized,
|
||||
state_dict=state_dict,
|
||||
size=size,
|
||||
)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
|
||||
@@ -233,8 +240,23 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
|
||||
if vram_in_use <= reserved:
|
||||
break
|
||||
|
||||
# Special handling of the stable-diffusion-3:text_encoder_3
|
||||
# submodel, when the user has loaded a quantized model.
|
||||
# The only way to remove the quantized version of this model from VRAM is to
|
||||
# delete it completely - it can't be moved from device to device
|
||||
# This also contains a workaround for quantized models that
|
||||
# persist indefinitely in VRAM
|
||||
if cache_entry.is_quantized:
|
||||
self._empty_quantized_state_dict(cache_entry.model)
|
||||
cache_entry.model = None
|
||||
self._delete_cache_entry(cache_entry)
|
||||
vram_in_use = torch.cuda.memory_allocated() + size_required
|
||||
continue
|
||||
|
||||
if not cache_entry.loaded:
|
||||
continue
|
||||
|
||||
if not cache_entry.locked:
|
||||
self.move_model_to_device(cache_entry, self.storage_device)
|
||||
cache_entry.loaded = False
|
||||
@@ -242,7 +264,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self.logger.debug(
|
||||
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
|
||||
)
|
||||
|
||||
gc.collect()
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
|
||||
@@ -256,7 +278,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
|
||||
source_device = cache_entry.device
|
||||
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# Note: We compare device types so that 'cuda' == 'cuda:0'.
|
||||
# This would need to be revised to support multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
@@ -407,3 +429,20 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
del cache_entry
|
||||
gc.collect()
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
def _empty_quantized_state_dict(self, model: AnyModel) -> None:
|
||||
"""Set all keys of a model's state dict to None.
|
||||
|
||||
This is a partial workaround for a poorly-understood bug in
|
||||
transformers' support for quantized T5EncoderModels (text_encoder_3
|
||||
of SD3). This allows most of the model to be unloaded from VRAM, but
|
||||
still leaks 8K of VRAM each time the model is unloaded. Using the quantized
|
||||
version of stable-diffusion-3-medium is NOT recommended.
|
||||
"""
|
||||
assert isinstance(model, torch.nn.Module)
|
||||
sd = model.state_dict()
|
||||
for k in sd.keys():
|
||||
sd[k] = None
|
||||
|
||||
@@ -36,9 +36,11 @@ VARIANT_TO_IN_CHANNEL_MAP = {
|
||||
class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
"""Class to load main models."""
|
||||
|
||||
# note - will be removed for load_single_file()
|
||||
model_base_to_model_type = {
|
||||
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
|
||||
BaseModelType.StableDiffusion3: "SD3",
|
||||
BaseModelType.StableDiffusionXL: "SDXL",
|
||||
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
|
||||
}
|
||||
@@ -65,7 +67,10 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader):
|
||||
if variant and "no file named" in str(
|
||||
e
|
||||
): # try without the variant, just in case user's preferences changed
|
||||
result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)
|
||||
result = load_class.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=self._torch_dtype,
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
@@ -100,6 +100,7 @@ class ModelProbe(object):
|
||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||
"StableDiffusion3Pipeline": ModelType.Main,
|
||||
"AutoencoderKL": ModelType.VAE,
|
||||
"AutoencoderTiny": ModelType.VAE,
|
||||
"ControlNetModel": ModelType.ControlNet,
|
||||
@@ -298,10 +299,13 @@ class ModelProbe(object):
|
||||
return possible_conf.absolute()
|
||||
|
||||
if model_type is ModelType.Main:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
config_file = f"stable-diffusion/{config_file}"
|
||||
if base_type is BaseModelType.StableDiffusion3:
|
||||
config_file = "stable-diffusion/v3-inference.yaml"
|
||||
else:
|
||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||
config_file = config_file[prediction_type]
|
||||
config_file = f"stable-diffusion/{config_file}"
|
||||
elif model_type is ModelType.ControlNet:
|
||||
config_file = (
|
||||
"controlnet/cldm_v15.yaml"
|
||||
@@ -374,7 +378,7 @@ def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[Con
|
||||
def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
|
||||
if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
|
||||
return MainModelDefaultSettings(width=512, height=512)
|
||||
elif model_base is BaseModelType.StableDiffusionXL:
|
||||
elif model_base in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusion3]:
|
||||
return MainModelDefaultSettings(width=1024, height=1024)
|
||||
# We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
|
||||
return None
|
||||
@@ -398,7 +402,10 @@ class CheckpointProbeBase(ProbeBase):
|
||||
if model_type != ModelType.Main:
|
||||
return ModelVariantType.Normal
|
||||
state_dict = self.checkpoint.get("state_dict") or self.checkpoint
|
||||
in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
|
||||
key = "model.diffusion_model.input_blocks.0.0.weight"
|
||||
if key not in state_dict:
|
||||
return ModelVariantType.Normal
|
||||
in_channels = state_dict[key].shape[1]
|
||||
if in_channels == 9:
|
||||
return ModelVariantType.Inpaint
|
||||
elif in_channels == 5:
|
||||
@@ -425,6 +432,9 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
||||
return BaseModelType.StableDiffusionXLRefiner
|
||||
key_name = "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight"
|
||||
if key_name in state_dict:
|
||||
return BaseModelType.StableDiffusion3
|
||||
else:
|
||||
raise InvalidModelConfigException("Cannot determine base type")
|
||||
|
||||
@@ -596,6 +606,10 @@ class FolderProbeBase(ProbeBase):
|
||||
|
||||
class PipelineFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
with open(self.model_path / "model_index.json", "r") as file:
|
||||
index_conf = json.load(file)
|
||||
if index_conf.get("_class_name") == "StableDiffusion3Pipeline":
|
||||
return BaseModelType.StableDiffusion3
|
||||
with open(self.model_path / "unet" / "config.json", "r") as file:
|
||||
unet_conf = json.load(file)
|
||||
if unet_conf["cross_attention_dim"] == 768:
|
||||
@@ -644,6 +658,8 @@ class VaeFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
if self._config_looks_like_sdxl():
|
||||
return BaseModelType.StableDiffusionXL
|
||||
elif self._config_looks_like_sd3():
|
||||
return BaseModelType.StableDiffusion3
|
||||
elif self._name_looks_like_sdxl():
|
||||
# but SD and SDXL VAE are the same shape (3-channel RGB to 4-channel float scaled down
|
||||
# by a factor of 8), we can't necessarily tell them apart by config hyperparameters.
|
||||
@@ -663,6 +679,15 @@ class VaeFolderProbe(FolderProbeBase):
|
||||
def _name_looks_like_sdxl(self) -> bool:
|
||||
return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))
|
||||
|
||||
def _config_looks_like_sd3(self) -> bool:
|
||||
# config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
|
||||
config_file = self.model_path / "config.json"
|
||||
if not config_file.exists():
|
||||
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
return config.get("scaling_factor", 0) == 1.5305 and config.get("sample_size") in [512, 1024]
|
||||
|
||||
def _guess_name(self) -> str:
|
||||
name = self.model_path.name
|
||||
if name == "vae":
|
||||
|
||||
@@ -122,6 +122,13 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
type=ModelType.Main,
|
||||
dependencies=[sdxl_fp16_vae_fix],
|
||||
),
|
||||
StarterModel(
|
||||
name="Stable Diffusion 3",
|
||||
base=BaseModelType.StableDiffusion3,
|
||||
source="stabilityai/stable-diffusion-3-medium-diffusers",
|
||||
description="The OG Stable Diffusion 3 base model **NOT FOR COMMERCIAL USE**.",
|
||||
type=ModelType.Main,
|
||||
),
|
||||
# endregion
|
||||
# region VAE
|
||||
sdxl_fp16_vae_fix,
|
||||
|
||||
@@ -35,6 +35,18 @@ def filter_files(
|
||||
The file list can be obtained from the `files` field of HuggingFaceMetadata,
|
||||
as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
|
||||
"""
|
||||
|
||||
# BRITTLENESS WARNING!!
|
||||
# The following pattern is designed to match model files that are components of diffusers submodels,
|
||||
# but not to match other random stuff found in huggingface repos.
|
||||
# Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
|
||||
# downloading random checkpoints that might also be in the repo. However there is no guarantee
|
||||
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
|
||||
# will adhere to this naming convention, so this is an area to be careful of.
|
||||
DIFFUSERS_COMPONENT_PATTERN = (
|
||||
r"model(-fp16)?(-\d+-of-\d+)?(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$"
|
||||
)
|
||||
|
||||
variant = variant or ModelRepoVariant.Default
|
||||
paths: List[Path] = []
|
||||
root = files[0].parts[0]
|
||||
@@ -45,31 +57,26 @@ def filter_files(
|
||||
|
||||
# Start by filtering on model file extensions, discarding images, docs, etc
|
||||
for file in files:
|
||||
if file.name.endswith((".json", ".txt")):
|
||||
paths.append(file)
|
||||
elif file.name.endswith(
|
||||
if file.name.endswith(
|
||||
(
|
||||
".json",
|
||||
".txt",
|
||||
"learned_embeds.bin",
|
||||
"ip_adapter.bin",
|
||||
"lora_weights.safetensors",
|
||||
"weights.pb",
|
||||
"onnx_data",
|
||||
"spiece.model",
|
||||
)
|
||||
):
|
||||
paths.append(file)
|
||||
# BRITTLENESS WARNING!!
|
||||
# Diffusers models always seem to have "model" in their name, and the regex filter below is applied to avoid
|
||||
# downloading random checkpoints that might also be in the repo. However there is no guarantee
|
||||
# that a checkpoint doesn't contain "model" in its name, and no guarantee that future diffusers models
|
||||
# will adhere to this naming convention, so this is an area to be careful of.
|
||||
elif re.search(r"model(\.[^.]+)?\.(safetensors|bin|onnx|xml|pth|pt|ckpt|msgpack)$", file.name):
|
||||
elif re.search(DIFFUSERS_COMPONENT_PATTERN, file.name):
|
||||
paths.append(file)
|
||||
|
||||
# limit search to subfolder if requested
|
||||
if subfolder:
|
||||
subfolder = root / subfolder
|
||||
paths = [x for x in paths if x.parent == Path(subfolder)]
|
||||
|
||||
# _filter_by_variant uniquifies the paths and returns a set
|
||||
return sorted(_filter_by_variant(paths, variant))
|
||||
|
||||
@@ -97,9 +104,22 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
if variant == ModelRepoVariant.Flax:
|
||||
result.add(path)
|
||||
|
||||
elif path.suffix in [".json", ".txt"]:
|
||||
elif path.suffix in [".json", ".txt", ".model"]:
|
||||
result.add(path)
|
||||
|
||||
# handle shard patterns
|
||||
elif re.match(r"model\.fp16-\d+-of-\d+\.safetensors", path.name):
|
||||
if variant is ModelRepoVariant.FP16:
|
||||
result.add(path)
|
||||
else:
|
||||
continue
|
||||
|
||||
elif re.match(r"model-\d+-of-\d+\.safetensors", path.name):
|
||||
if variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]:
|
||||
result.add(path)
|
||||
else:
|
||||
continue
|
||||
|
||||
elif variant in [
|
||||
ModelRepoVariant.FP16,
|
||||
ModelRepoVariant.FP32,
|
||||
@@ -123,6 +143,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
score += 1
|
||||
|
||||
candidate_variant_label = path.suffixes[0] if len(path.suffixes) == 2 else None
|
||||
candidate_variant_label, *_ = str(candidate_variant_label).split("-") # handle shard pattern
|
||||
|
||||
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
|
||||
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
|
||||
@@ -139,6 +160,8 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
||||
else:
|
||||
continue
|
||||
|
||||
print(subfolder_weights)
|
||||
|
||||
for candidate_list in subfolder_weights.values():
|
||||
highest_score_candidate = max(candidate_list, key=lambda candidate: candidate.score)
|
||||
if highest_score_candidate:
|
||||
|
||||
@@ -7,6 +7,7 @@ from diffusers import (
|
||||
DPMSolverSinglestepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
@@ -29,6 +30,7 @@ SCHEDULER_MAP = {
|
||||
"euler": (EulerDiscreteScheduler, {"use_karras_sigmas": False}),
|
||||
"euler_k": (EulerDiscreteScheduler, {"use_karras_sigmas": True}),
|
||||
"euler_a": (EulerAncestralDiscreteScheduler, {}),
|
||||
"euler_f": (FlowMatchEulerDiscreteScheduler, {}),
|
||||
"kdpm_2": (KDPM2DiscreteScheduler, {}),
|
||||
"kdpm_2_a": (KDPM2AncestralDiscreteScheduler, {}),
|
||||
"dpmpp_2s": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": False}),
|
||||
|
||||
@@ -3,7 +3,12 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import diffusers
|
||||
import torch
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.loaders import FromOriginalControlNetMixin
|
||||
|
||||
# The following import is
|
||||
# generating import errors with diffusers 028.2
|
||||
# tried diffusers.loaders.controlnet import FromOriginalControlNetMixin, but this
|
||||
# fails as well
|
||||
# from diffusers.loaders import FromOriginalControlNetMixin
|
||||
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
|
||||
from diffusers.models.controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||
from diffusers.models.embeddings import (
|
||||
@@ -32,7 +37,7 @@ from invokeai.backend.util.logging import InvokeAILogger
|
||||
logger = InvokeAILogger.get_logger(__name__)
|
||||
|
||||
|
||||
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlNetMixin):
|
||||
class ControlNetModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A ControlNet model.
|
||||
|
||||
|
||||
@@ -37,11 +37,7 @@
|
||||
"selectBoard": "Select a Board",
|
||||
"topMessage": "This board contains images used in the following features:",
|
||||
"uncategorized": "Uncategorized",
|
||||
"downloadBoard": "Download Board",
|
||||
"imagesWithCount_one": "{{count}} image",
|
||||
"imagesWithCount_other": "{{count}} images",
|
||||
"assetsWithCount_one": "{{count}} asset",
|
||||
"assetsWithCount_other": "{{count}} assets"
|
||||
"downloadBoard": "Download Board"
|
||||
},
|
||||
"accordions": {
|
||||
"generation": {
|
||||
@@ -384,11 +380,7 @@
|
||||
"problemDeletingImagesDesc": "One or more images could not be deleted",
|
||||
"viewerImage": "Viewer Image",
|
||||
"compareImage": "Compare Image",
|
||||
"noActiveSearch": "No active search",
|
||||
"openInViewer": "Open in Viewer",
|
||||
"searchingBy": "Searching by",
|
||||
"selectAllOnPage": "Select All On Page",
|
||||
"selectAllOnBoard": "Select All On Board",
|
||||
"selectForCompare": "Select for Compare",
|
||||
"selectAnImageToCompare": "Select an Image to Compare",
|
||||
"slider": "Slider",
|
||||
|
||||
@@ -2,7 +2,8 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { getListImagesUrl } from 'services/api/util';
|
||||
import type { ImageCache } from 'services/api/types';
|
||||
import { getListImagesUrl, imagesSelectors } from 'services/api/util';
|
||||
|
||||
export const addFirstListImagesListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
@@ -17,10 +18,13 @@ export const addFirstListImagesListener = (startAppListening: AppStartListening)
|
||||
cancelActiveListeners();
|
||||
unsubscribe();
|
||||
|
||||
const data = action.payload;
|
||||
// TODO: figure out how to type the predicate
|
||||
const data = action.payload as ImageCache;
|
||||
|
||||
if (data.items.length > 0) {
|
||||
dispatch(imageSelected(data.items[0] ?? null));
|
||||
if (data.ids.length > 0) {
|
||||
// Select the first image
|
||||
const firstImage = imagesSelectors.selectAll(data)[0];
|
||||
dispatch(imageSelected(firstImage ?? null));
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
boardIdSelected,
|
||||
galleryViewChanged,
|
||||
imageSelected,
|
||||
selectionChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
export const addBoardIdSelectedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
@@ -18,9 +14,14 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
|
||||
|
||||
const state = getState();
|
||||
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
const board_id = boardIdSelected.match(action) ? action.payload.boardId : state.gallery.selectedBoardId;
|
||||
|
||||
dispatch(selectionChanged([]));
|
||||
const galleryView = galleryViewChanged.match(action) ? action.payload : state.gallery.galleryView;
|
||||
|
||||
// when a board is selected, we need to wait until the board has loaded *some* images, then select the first one
|
||||
const categories = galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES;
|
||||
|
||||
const queryArgs = { board_id: board_id ?? 'none', categories };
|
||||
|
||||
// wait until the board has some images - maybe it already has some from a previous fetch
|
||||
// must use getState() to ensure we do not have stale state
|
||||
@@ -34,12 +35,11 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
|
||||
const { data: boardImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
|
||||
|
||||
if (boardImagesData && boardIdSelected.match(action) && action.payload.selectedImageName) {
|
||||
const selectedImage = boardImagesData.items.find(
|
||||
(item) => item.image_name === action.payload.selectedImageName
|
||||
);
|
||||
const selectedImage = imagesSelectors.selectById(boardImagesData, action.payload.selectedImageName);
|
||||
dispatch(imageSelected(selectedImage || null));
|
||||
} else if (boardImagesData) {
|
||||
dispatch(imageSelected(boardImagesData.items[0] || null));
|
||||
const firstImage = imagesSelectors.selectAll(boardImagesData)[0];
|
||||
dispatch(imageSelected(firstImage || null));
|
||||
} else {
|
||||
// board has no images - deselect
|
||||
dispatch(imageSelected(null));
|
||||
|
||||
@@ -4,6 +4,7 @@ import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelecto
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
export const galleryImageClicked = createAction<{
|
||||
imageDTO: ImageDTO;
|
||||
@@ -31,14 +32,14 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
|
||||
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
|
||||
const state = getState();
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
const queryResult = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||
const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||
|
||||
if (!queryResult.data) {
|
||||
if (!listImagesData) {
|
||||
// Should never happen if we have clicked a gallery image
|
||||
return;
|
||||
}
|
||||
|
||||
const imageDTOs = queryResult.data.items;
|
||||
const imageDTOs = imagesSelectors.selectAll(listImagesData);
|
||||
const selection = state.gallery.selection;
|
||||
|
||||
if (altKey) {
|
||||
|
||||
@@ -22,10 +22,11 @@ import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isImageFieldInputInstance } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { clamp, forEach } from 'lodash-es';
|
||||
import { api } from 'services/api';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
|
||||
state.nodes.present.nodes.forEach((node) => {
|
||||
@@ -117,7 +118,32 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
|
||||
}
|
||||
|
||||
dispatch(isModalOpenChanged(false));
|
||||
|
||||
const state = getState();
|
||||
const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
|
||||
|
||||
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
|
||||
const { image_name } = imageDTO;
|
||||
|
||||
const baseQueryArgs = selectListImagesQueryArgs(state);
|
||||
const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
||||
|
||||
const cachedImageDTOs = data ? imagesSelectors.selectAll(data) : [];
|
||||
|
||||
const deletedImageIndex = cachedImageDTOs.findIndex((i) => i.image_name === image_name);
|
||||
|
||||
const filteredImageDTOs = cachedImageDTOs.filter((i) => i.image_name !== image_name);
|
||||
|
||||
const newSelectedImageIndex = clamp(deletedImageIndex, 0, filteredImageDTOs.length - 1);
|
||||
|
||||
const newSelectedImageDTO = filteredImageDTOs[newSelectedImageIndex];
|
||||
|
||||
if (newSelectedImageDTO) {
|
||||
dispatch(imageSelected(newSelectedImageDTO));
|
||||
} else {
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
}
|
||||
|
||||
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
|
||||
if (imageUsage.isCanvasImage) {
|
||||
@@ -142,20 +168,6 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
|
||||
if (wasImageDeleted) {
|
||||
dispatch(api.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
|
||||
}
|
||||
|
||||
const lastSelectedImage = state.gallery.selection[state.gallery.selection.length - 1]?.image_name;
|
||||
|
||||
if (imageDTO && imageDTO?.image_name === lastSelectedImage) {
|
||||
const baseQueryArgs = selectListImagesQueryArgs(state);
|
||||
const { data } = imagesApi.endpoints.listImages.select(baseQueryArgs)(state);
|
||||
|
||||
if (data && data.items) {
|
||||
const newlySelectedImage = data?.items.find((img) => img.image_name !== imageDTO?.image_name);
|
||||
dispatch(imageSelected(newlySelectedImage || null));
|
||||
} else {
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
@@ -176,8 +188,10 @@ export const addRequestedSingleImageDeletionListener = (startAppListening: AppSt
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
const { data } = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||
|
||||
if (data && data.items[0]) {
|
||||
dispatch(imageSelected(data.items[0]));
|
||||
const newSelectedImageDTO = data ? imagesSelectors.selectAll(data)[0] : undefined;
|
||||
|
||||
if (newSelectedImageDTO) {
|
||||
dispatch(imageSelected(newSelectedImageDTO));
|
||||
} else {
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
|
||||
@@ -15,12 +15,7 @@ import {
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { isValidDrop } from 'features/dnd/util/isValidDrop';
|
||||
import {
|
||||
imageSelected,
|
||||
imageToCompareChanged,
|
||||
isImageViewerOpenChanged,
|
||||
selectionChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
@@ -221,7 +216,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
board_id: boardId,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -239,7 +233,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
imageDTO,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -255,7 +248,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
board_id: boardId,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -269,7 +261,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
imageDTOs,
|
||||
})
|
||||
);
|
||||
dispatch(selectionChanged([]));
|
||||
return;
|
||||
}
|
||||
},
|
||||
|
||||
@@ -8,14 +8,14 @@ import {
|
||||
galleryViewChanged,
|
||||
imageSelected,
|
||||
isImageViewerOpenChanged,
|
||||
offsetChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { getCategories, getListImagesUrl } from 'services/api/util';
|
||||
import { imagesAdapter } from 'services/api/util';
|
||||
import { socketInvocationComplete } from 'services/events/actions';
|
||||
|
||||
// These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them
|
||||
@@ -52,6 +52,24 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
}
|
||||
|
||||
if (!imageDTO.is_intermediate) {
|
||||
/**
|
||||
* Cache updates for when an image result is received
|
||||
* - add it to the no_board/images
|
||||
*/
|
||||
|
||||
dispatch(
|
||||
imagesApi.util.updateQueryData(
|
||||
'listImages',
|
||||
{
|
||||
board_id: imageDTO.board_id ?? 'none',
|
||||
categories: IMAGE_CATEGORIES,
|
||||
},
|
||||
(draft) => {
|
||||
imagesAdapter.addOne(draft, imageDTO);
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
// update the total images for the board
|
||||
dispatch(
|
||||
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
|
||||
@@ -60,18 +78,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
})
|
||||
);
|
||||
|
||||
dispatch(
|
||||
imagesApi.util.invalidateTags([
|
||||
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
|
||||
{
|
||||
type: 'ImageList',
|
||||
id: getListImagesUrl({
|
||||
board_id: imageDTO.board_id ?? 'none',
|
||||
categories: getCategories(imageDTO),
|
||||
}),
|
||||
},
|
||||
])
|
||||
);
|
||||
dispatch(imagesApi.util.invalidateTags([{ type: 'Board', id: imageDTO.board_id ?? 'none' }]));
|
||||
|
||||
const { shouldAutoSwitch } = gallery;
|
||||
|
||||
@@ -91,8 +98,6 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(offsetChanged(0));
|
||||
|
||||
if (!imageDTO.board_id && gallery.selectedBoardId !== 'none') {
|
||||
dispatch(
|
||||
boardIdSelected({
|
||||
|
||||
@@ -1,37 +1,47 @@
|
||||
import type { IconButtonProps, SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import type { MouseEvent } from 'react';
|
||||
import { memo } from 'react';
|
||||
import type { MouseEvent, ReactElement } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
|
||||
const sx: SystemStyleObject = {
|
||||
minW: 0,
|
||||
svg: {
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
fill: 'base.100',
|
||||
_hover: { fill: 'base.50' },
|
||||
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
|
||||
},
|
||||
};
|
||||
|
||||
type Props = Omit<IconButtonProps, 'aria-label' | 'onClick' | 'tooltip'> & {
|
||||
type Props = {
|
||||
onClick: (event: MouseEvent<HTMLButtonElement>) => void;
|
||||
tooltip: string;
|
||||
icon?: ReactElement;
|
||||
styleOverrides?: SystemStyleObject;
|
||||
};
|
||||
|
||||
const IAIDndImageIcon = (props: Props) => {
|
||||
const { onClick, tooltip, icon, ...rest } = props;
|
||||
const { onClick, tooltip, icon, styleOverrides } = props;
|
||||
|
||||
const sx = useMemo(
|
||||
() => ({
|
||||
position: 'absolute',
|
||||
top: 1,
|
||||
insetInlineEnd: 1,
|
||||
p: 0,
|
||||
minW: 0,
|
||||
svg: {
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: 'normal',
|
||||
fill: 'base.100',
|
||||
_hover: { fill: 'base.50' },
|
||||
filter: 'drop-shadow(0px 0px 0.1rem var(--invoke-colors-base-800))',
|
||||
},
|
||||
...styleOverrides,
|
||||
}),
|
||||
[styleOverrides]
|
||||
);
|
||||
|
||||
return (
|
||||
<IconButton
|
||||
onClick={onClick}
|
||||
aria-label={tooltip}
|
||||
tooltip={tooltip}
|
||||
icon={icon}
|
||||
size="sm"
|
||||
variant="link"
|
||||
sx={sx}
|
||||
data-testid={tooltip}
|
||||
{...rest}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
16
invokeai/frontend/web/src/common/util/dateComparator.ts
Normal file
16
invokeai/frontend/web/src/common/util/dateComparator.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
/**
|
||||
* Comparator function for sorting dates in ascending order
|
||||
*/
|
||||
export const dateComparator = (a: string, b: string) => {
|
||||
const dateA = new Date(a);
|
||||
const dateB = new Date(b);
|
||||
|
||||
// sort in ascending order
|
||||
if (dateA > dateB) {
|
||||
return 1;
|
||||
}
|
||||
if (dateA < dateB) {
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Spinner } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
@@ -184,7 +185,7 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
|
||||
/>
|
||||
</Box>
|
||||
|
||||
<Flex flexDir="column" top={1} insetInlineEnd={1}>
|
||||
<>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleResetControlImage}
|
||||
icon={controlImage ? <PiArrowCounterClockwiseBold size={16} /> : undefined}
|
||||
@@ -194,13 +195,15 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
|
||||
onClick={handleSaveControlImage}
|
||||
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
|
||||
tooltip={t('controlnet.saveControlImage')}
|
||||
styleOverrides={saveControlImageStyleOverrides}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
|
||||
tooltip={t('controlnet.setControlImageDimensions')}
|
||||
styleOverrides={setControlImageDimensionsStyleOverrides}
|
||||
/>
|
||||
</Flex>
|
||||
</>
|
||||
|
||||
{pendingControlImages.includes(id) && (
|
||||
<Flex
|
||||
@@ -223,3 +226,6 @@ const ControlAdapterImagePreview = ({ isSmall, id }: Props) => {
|
||||
};
|
||||
|
||||
export default memo(ControlAdapterImagePreview);
|
||||
|
||||
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
|
||||
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Spinner, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -202,13 +203,13 @@ export const ControlAdapterImagePreview = memo(
|
||||
onClick={handleSaveControlImage}
|
||||
icon={controlImage ? <PiFloppyDiskBold size={16} /> : undefined}
|
||||
tooltip={t('controlnet.saveControlImage')}
|
||||
mt={6}
|
||||
styleOverrides={saveControlImageStyleOverrides}
|
||||
/>
|
||||
<IAIDndImageIcon
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
|
||||
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
|
||||
mt={12}
|
||||
styleOverrides={setControlImageDimensionsStyleOverrides}
|
||||
/>
|
||||
</>
|
||||
|
||||
@@ -234,3 +235,6 @@ export const ControlAdapterImagePreview = memo(
|
||||
);
|
||||
|
||||
ControlAdapterImagePreview.displayName = 'ControlAdapterImagePreview';
|
||||
|
||||
const saveControlImageStyleOverrides: SystemStyleObject = { mt: 6 };
|
||||
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 12 };
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -99,7 +100,7 @@ export const IPAdapterImagePreview = memo(
|
||||
onClick={handleSetControlImageToDimensions}
|
||||
icon={controlImage ? <PiRulerBold size={16} /> : undefined}
|
||||
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
|
||||
mt={6}
|
||||
styleOverrides={setControlImageDimensionsStyleOverrides}
|
||||
/>
|
||||
</>
|
||||
</Flex>
|
||||
@@ -108,3 +109,5 @@ export const IPAdapterImagePreview = memo(
|
||||
);
|
||||
|
||||
IPAdapterImagePreview.displayName = 'IPAdapterImagePreview';
|
||||
|
||||
const setControlImageDimensionsStyleOverrides: SystemStyleObject = { mt: 6 };
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, useShiftModifier } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
@@ -96,7 +97,7 @@ export const InitialImagePreview = memo(({ image, onChangeImage, droppableData,
|
||||
onClick={onUseSize}
|
||||
icon={imageDTO ? <PiRulerBold size={16} /> : undefined}
|
||||
tooltip={shift ? t('controlnet.setControlImageDimensionsForce') : t('controlnet.setControlImageDimensions')}
|
||||
mt={6}
|
||||
styleOverrides={useSizeStyleOverrides}
|
||||
/>
|
||||
</>
|
||||
</Flex>
|
||||
@@ -104,3 +105,5 @@ export const InitialImagePreview = memo(({ image, onChangeImage, droppableData,
|
||||
});
|
||||
|
||||
InitialImagePreview.displayName = 'InitialImagePreview';
|
||||
|
||||
const useSizeStyleOverrides: SystemStyleObject = { mt: 6 };
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
|
||||
|
||||
type Props = {
|
||||
board_id: string;
|
||||
};
|
||||
|
||||
export const BoardTotalsTooltip = ({ board_id }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const { imagesTotal } = useGetBoardImagesTotalQuery(board_id, {
|
||||
selectFromResult: ({ data }) => {
|
||||
return { imagesTotal: data?.total ?? 0 };
|
||||
},
|
||||
});
|
||||
const { assetsTotal } = useGetBoardAssetsTotalQuery(board_id, {
|
||||
selectFromResult: ({ data }) => {
|
||||
return { assetsTotal: data?.total ?? 0 };
|
||||
},
|
||||
});
|
||||
return `${t('boards.imagesWithCount', { count: imagesTotal })}, ${t('boards.assetsWithCount', { count: assetsTotal })}`;
|
||||
};
|
||||
@@ -8,12 +8,15 @@ import SelectionOverlay from 'common/components/SelectionOverlay';
|
||||
import type { AddToBoardDropData } from 'features/dnd/types';
|
||||
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
|
||||
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
|
||||
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
|
||||
import { autoAddBoardIdChanged, boardIdSelected, selectGallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiImagesSquare } from 'react-icons/pi';
|
||||
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
|
||||
import {
|
||||
useGetBoardAssetsTotalQuery,
|
||||
useGetBoardImagesTotalQuery,
|
||||
useUpdateBoardMutation,
|
||||
} from 'services/api/endpoints/boards';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
import type { BoardDTO } from 'services/api/types';
|
||||
|
||||
@@ -48,6 +51,17 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
||||
setIsHovered(false);
|
||||
}, []);
|
||||
|
||||
const { data: imagesTotal } = useGetBoardImagesTotalQuery(board.board_id);
|
||||
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(board.board_id);
|
||||
const tooltip = useMemo(() => {
|
||||
if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
|
||||
return undefined;
|
||||
}
|
||||
return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
|
||||
assetsTotal.total
|
||||
} asset${assetsTotal.total === 1 ? '' : 's'}`;
|
||||
}, [assetsTotal, imagesTotal]);
|
||||
|
||||
const { currentData: coverImage } = useGetImageDTOQuery(board.cover_image_name ?? skipToken);
|
||||
|
||||
const { board_name, board_id } = board;
|
||||
@@ -118,7 +132,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
||||
>
|
||||
<BoardContextMenu board={board} board_id={board_id} setBoardToDelete={setBoardToDelete}>
|
||||
{(ref) => (
|
||||
<Tooltip label={<BoardTotalsTooltip board_id={board.board_id} />} openDelay={1000}>
|
||||
<Tooltip label={tooltip} openDelay={1000}>
|
||||
<Flex
|
||||
ref={ref}
|
||||
onClick={handleSelectBoard}
|
||||
|
||||
@@ -5,11 +5,11 @@ import SelectionOverlay from 'common/components/SelectionOverlay';
|
||||
import type { RemoveFromBoardDropData } from 'features/dnd/types';
|
||||
import AutoAddIcon from 'features/gallery/components/Boards/AutoAddIcon';
|
||||
import BoardContextMenu from 'features/gallery/components/Boards/BoardContextMenu';
|
||||
import { BoardTotalsTooltip } from 'features/gallery/components/Boards/BoardsList/BoardTotalsTooltip';
|
||||
import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice';
|
||||
import InvokeLogoSVG from 'public/assets/images/invoke-symbol-wht-lrg.svg';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
|
||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
|
||||
interface Props {
|
||||
@@ -29,6 +29,17 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
||||
}, [dispatch, autoAssignBoardOnClick]);
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
|
||||
const { data: imagesTotal } = useGetBoardImagesTotalQuery('none');
|
||||
const { data: assetsTotal } = useGetBoardAssetsTotalQuery('none');
|
||||
const tooltip = useMemo(() => {
|
||||
if (imagesTotal?.total === undefined || assetsTotal?.total === undefined) {
|
||||
return undefined;
|
||||
}
|
||||
return `${imagesTotal.total} image${imagesTotal.total === 1 ? '' : 's'}, ${
|
||||
assetsTotal.total
|
||||
} asset${assetsTotal.total === 1 ? '' : 's'}`;
|
||||
}, [assetsTotal, imagesTotal]);
|
||||
|
||||
const handleMouseOver = useCallback(() => {
|
||||
setIsHovered(true);
|
||||
}, []);
|
||||
@@ -60,7 +71,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
||||
>
|
||||
<BoardContextMenu board_id="none">
|
||||
{(ref) => (
|
||||
<Tooltip label={<BoardTotalsTooltip board_id="none" />} openDelay={1000}>
|
||||
<Tooltip label={tooltip} openDelay={1000}>
|
||||
<Flex
|
||||
ref={ref}
|
||||
onClick={handleSelectBoard}
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
import { Flex, IconButton, Spacer, Tag, TagCloseButton, TagLabel, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
|
||||
import { selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { BiSelectMultiple } from 'react-icons/bi';
|
||||
|
||||
import { GallerySearch } from './GallerySearch';
|
||||
|
||||
export const GalleryBulkSelect = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { selection } = useAppSelector((s) => s.gallery);
|
||||
const { t } = useTranslation();
|
||||
const { imageDTOs } = useGalleryImages();
|
||||
|
||||
const onClickClearSelection = useCallback(() => {
|
||||
dispatch(selectionChanged([]));
|
||||
}, [dispatch]);
|
||||
|
||||
const onClickSelectAllPage = useCallback(() => {
|
||||
dispatch(selectionChanged(selection.concat(imageDTOs)));
|
||||
}, [dispatch, imageDTOs, selection]);
|
||||
|
||||
return (
|
||||
<Flex alignItems="center" justifyContent="space-between">
|
||||
<Flex>
|
||||
{selection.length > 0 ? (
|
||||
<Tag>
|
||||
<TagLabel>
|
||||
{selection.length} {t('common.selected')}
|
||||
</TagLabel>
|
||||
<Tooltip label="Clear selection">
|
||||
<TagCloseButton onClick={onClickClearSelection} />
|
||||
</Tooltip>
|
||||
</Tag>
|
||||
) : (
|
||||
<Spacer />
|
||||
)}
|
||||
|
||||
<Tooltip label={t('gallery.selectAllOnPage')}>
|
||||
<IconButton
|
||||
variant="outline"
|
||||
size="sm"
|
||||
icon={<BiSelectMultiple />}
|
||||
aria-label="Bulk select"
|
||||
onClick={onClickSelectAllPage}
|
||||
/>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
|
||||
<GallerySearch />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -1,97 +0,0 @@
|
||||
import { Flex, IconButton, Input, InputGroup, InputRightElement, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { searchTermChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { motion } from 'framer-motion';
|
||||
import { debounce } from 'lodash-es';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiMagnifyingGlassBold, PiXBold } from 'react-icons/pi';
|
||||
|
||||
export const GallerySearch = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { searchTerm } = useAppSelector((s) => s.gallery);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const [searchTermInput, setSearchTermInput] = useState('');
|
||||
|
||||
const debouncedSetSearchTerm = useMemo(() => {
|
||||
return debounce((value: string) => {
|
||||
dispatch(searchTermChanged(value));
|
||||
}, 1000);
|
||||
}, [dispatch]);
|
||||
|
||||
const onChangeInput = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
setSearchTermInput(e.target.value);
|
||||
debouncedSetSearchTerm(e.target.value);
|
||||
},
|
||||
[debouncedSetSearchTerm]
|
||||
);
|
||||
|
||||
const onClearInput = useCallback(() => {
|
||||
setSearchTermInput('');
|
||||
debouncedSetSearchTerm('');
|
||||
}, [debouncedSetSearchTerm]);
|
||||
|
||||
const toggleExpanded = useCallback((newState: boolean) => {
|
||||
setExpanded(newState);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Flex>
|
||||
{!expanded && (
|
||||
<Tooltip
|
||||
label={
|
||||
searchTerm && searchTerm.length ? `${t('gallery.searchingBy')} ${searchTerm}` : t('gallery.noActiveSearch')
|
||||
}
|
||||
>
|
||||
<IconButton
|
||||
aria-label="Close"
|
||||
icon={<PiMagnifyingGlassBold />}
|
||||
onClick={toggleExpanded.bind(null, true)}
|
||||
variant="outline"
|
||||
size="sm"
|
||||
/>
|
||||
</Tooltip>
|
||||
)}
|
||||
<motion.div
|
||||
initial={false}
|
||||
animate={{ width: expanded ? '200px' : '0px' }}
|
||||
transition={{ duration: 0.3 }}
|
||||
style={{ overflow: 'hidden' }}
|
||||
>
|
||||
<InputGroup size="sm">
|
||||
<IconButton
|
||||
aria-label="Close"
|
||||
icon={<PiMagnifyingGlassBold />}
|
||||
onClick={toggleExpanded.bind(null, false)}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
/>
|
||||
|
||||
<Input
|
||||
type="text"
|
||||
placeholder="Search..."
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onChange={onChangeInput}
|
||||
value={searchTermInput}
|
||||
/>
|
||||
{searchTermInput && searchTermInput.length && (
|
||||
<InputRightElement h="full" pe={2}>
|
||||
<IconButton
|
||||
onClick={onClearInput}
|
||||
size="sm"
|
||||
variant="link"
|
||||
aria-label={t('boards.clearSearch')}
|
||||
icon={<PiXBold />}
|
||||
/>
|
||||
</InputRightElement>
|
||||
)}
|
||||
</InputGroup>
|
||||
</motion.div>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -1,22 +1,22 @@
|
||||
import { Box, Button, ButtonGroup, Flex, Tab, TabList, Tabs, useDisclosure } from '@invoke-ai/ui-library';
|
||||
import { Box, Button, ButtonGroup, Flex, Tab, TabList, Tabs, useDisclosure, VStack } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { $galleryHeader } from 'app/store/nanostores/galleryHeader';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { galleryViewChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiImagesBold } from 'react-icons/pi';
|
||||
import { RiServerLine } from 'react-icons/ri';
|
||||
|
||||
import BoardsList from './Boards/BoardsList/BoardsList';
|
||||
import GalleryBoardName from './GalleryBoardName';
|
||||
import { GalleryBulkSelect } from './GalleryBulkSelect';
|
||||
import GallerySettingsPopover from './GallerySettingsPopover';
|
||||
import GalleryImageGrid from './ImageGrid/GalleryImageGrid';
|
||||
import { GalleryPagination } from './ImageGrid/GalleryPagination';
|
||||
|
||||
const ImageGalleryContent = () => {
|
||||
const { t } = useTranslation();
|
||||
const resizeObserverRef = useRef<HTMLDivElement>(null);
|
||||
const galleryGridRef = useRef<HTMLDivElement>(null);
|
||||
const galleryView = useAppSelector((s) => s.gallery.galleryView);
|
||||
const dispatch = useAppDispatch();
|
||||
const galleryHeader = useStore($galleryHeader);
|
||||
@@ -31,10 +31,10 @@ const ImageGalleryContent = () => {
|
||||
}, [dispatch]);
|
||||
|
||||
return (
|
||||
<Flex layerStyle="first" flexDirection="column" h="full" w="full" borderRadius="base" p={2} gap={2}>
|
||||
<VStack layerStyle="first" flexDirection="column" h="full" w="full" borderRadius="base" p={2}>
|
||||
{galleryHeader}
|
||||
<Box>
|
||||
<Flex alignItems="center" justifyContent="space-between" gap={2}>
|
||||
<Box w="full">
|
||||
<Flex ref={resizeObserverRef} alignItems="center" justifyContent="space-between" gap={2}>
|
||||
<GalleryBoardName isOpen={isBoardListOpen} onToggle={onToggleBoardList} />
|
||||
<GallerySettingsPopover />
|
||||
</Flex>
|
||||
@@ -42,41 +42,40 @@ const ImageGalleryContent = () => {
|
||||
<BoardsList isOpen={isBoardListOpen} />
|
||||
</Box>
|
||||
</Box>
|
||||
<Flex alignItems="center" justifyContent="space-between" gap={2}>
|
||||
<Tabs index={galleryView === 'images' ? 0 : 1} variant="unstyled" size="sm" w="full">
|
||||
<TabList>
|
||||
<ButtonGroup w="full">
|
||||
<Tab
|
||||
as={Button}
|
||||
size="sm"
|
||||
isChecked={galleryView === 'images'}
|
||||
onClick={handleClickImages}
|
||||
w="full"
|
||||
leftIcon={<PiImagesBold size="16px" />}
|
||||
data-testid="images-tab"
|
||||
>
|
||||
{t('parameters.images')}
|
||||
</Tab>
|
||||
<Tab
|
||||
as={Button}
|
||||
size="sm"
|
||||
isChecked={galleryView === 'assets'}
|
||||
onClick={handleClickAssets}
|
||||
w="full"
|
||||
leftIcon={<RiServerLine size="16px" />}
|
||||
data-testid="assets-tab"
|
||||
>
|
||||
{t('gallery.assets')}
|
||||
</Tab>
|
||||
</ButtonGroup>
|
||||
</TabList>
|
||||
</Tabs>
|
||||
<Flex ref={galleryGridRef} direction="column" gap={2} h="full" w="full">
|
||||
<Flex alignItems="center" justifyContent="space-between" gap={2}>
|
||||
<Tabs index={galleryView === 'images' ? 0 : 1} variant="unstyled" size="sm" w="full">
|
||||
<TabList>
|
||||
<ButtonGroup w="full">
|
||||
<Tab
|
||||
as={Button}
|
||||
size="sm"
|
||||
isChecked={galleryView === 'images'}
|
||||
onClick={handleClickImages}
|
||||
w="full"
|
||||
leftIcon={<PiImagesBold size="16px" />}
|
||||
data-testid="images-tab"
|
||||
>
|
||||
{t('parameters.images')}
|
||||
</Tab>
|
||||
<Tab
|
||||
as={Button}
|
||||
size="sm"
|
||||
isChecked={galleryView === 'assets'}
|
||||
onClick={handleClickAssets}
|
||||
w="full"
|
||||
leftIcon={<RiServerLine size="16px" />}
|
||||
data-testid="assets-tab"
|
||||
>
|
||||
{t('gallery.assets')}
|
||||
</Tab>
|
||||
</ButtonGroup>
|
||||
</TabList>
|
||||
</Tabs>
|
||||
</Flex>
|
||||
<GalleryImageGrid />
|
||||
</Flex>
|
||||
<GalleryBulkSelect />
|
||||
|
||||
<GalleryImageGrid />
|
||||
<GalleryPagination />
|
||||
</Flex>
|
||||
</VStack>
|
||||
);
|
||||
};
|
||||
|
||||
|
||||
@@ -16,13 +16,13 @@ import type { MouseEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiStarBold, PiStarFill, PiTrashSimpleFill } from 'react-icons/pi';
|
||||
import { useStarImagesMutation, useUnstarImagesMutation } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
// This class name is used to calculate the number of images that fit in the gallery
|
||||
export const GALLERY_IMAGE_CLASS_NAME = 'gallery-image';
|
||||
import { useGetImageDTOQuery, useStarImagesMutation, useUnstarImagesMutation } from 'services/api/endpoints/images';
|
||||
|
||||
const imageSx: SystemStyleObject = { w: 'full', h: 'full' };
|
||||
const imageIconStyleOverrides: SystemStyleObject = {
|
||||
bottom: 2,
|
||||
top: 'auto',
|
||||
};
|
||||
const boxSx: SystemStyleObject = {
|
||||
containerType: 'inline-size',
|
||||
};
|
||||
@@ -34,22 +34,24 @@ const badgeSx: SystemStyleObject = {
|
||||
};
|
||||
|
||||
interface HoverableImageProps {
|
||||
imageDTO: ImageDTO;
|
||||
imageName: string;
|
||||
index: number;
|
||||
}
|
||||
|
||||
const GalleryImage = ({ index, imageDTO }: HoverableImageProps) => {
|
||||
const GalleryImage = (props: HoverableImageProps) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { imageName } = props;
|
||||
const { currentData: imageDTO } = useGetImageDTOQuery(imageName);
|
||||
const shift = useShiftModifier();
|
||||
const { t } = useTranslation();
|
||||
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
||||
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
|
||||
const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageDTO.image_name);
|
||||
const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageName);
|
||||
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
|
||||
|
||||
const customStarUi = useStore($customStarUI);
|
||||
|
||||
const imageContainerRef = useScrollIntoView(isSelected, index, areMultiplesSelected);
|
||||
const imageContainerRef = useScrollIntoView(isSelected, props.index, areMultiplesSelected);
|
||||
|
||||
const handleDelete = useCallback(
|
||||
(e: MouseEvent<HTMLButtonElement>) => {
|
||||
@@ -112,32 +114,32 @@ const GalleryImage = ({ index, imageDTO }: HoverableImageProps) => {
|
||||
}, []);
|
||||
|
||||
const starIcon = useMemo(() => {
|
||||
if (imageDTO.starred) {
|
||||
if (imageDTO?.starred) {
|
||||
return customStarUi ? customStarUi.on.icon : <PiStarFill size="20" />;
|
||||
}
|
||||
if (!imageDTO.starred && isHovered) {
|
||||
if (!imageDTO?.starred && isHovered) {
|
||||
return customStarUi ? customStarUi.off.icon : <PiStarBold size="20" />;
|
||||
}
|
||||
}, [imageDTO.starred, isHovered, customStarUi]);
|
||||
}, [imageDTO?.starred, isHovered, customStarUi]);
|
||||
|
||||
const starTooltip = useMemo(() => {
|
||||
if (imageDTO.starred) {
|
||||
if (imageDTO?.starred) {
|
||||
return customStarUi ? customStarUi.off.text : 'Unstar';
|
||||
}
|
||||
if (!imageDTO.starred) {
|
||||
if (!imageDTO?.starred) {
|
||||
return customStarUi ? customStarUi.on.text : 'Star';
|
||||
}
|
||||
return '';
|
||||
}, [imageDTO.starred, customStarUi]);
|
||||
}, [imageDTO?.starred, customStarUi]);
|
||||
|
||||
const dataTestId = useMemo(() => getGalleryImageDataTestId(imageDTO.image_name), [imageDTO.image_name]);
|
||||
const dataTestId = useMemo(() => getGalleryImageDataTestId(imageDTO?.image_name), [imageDTO?.image_name]);
|
||||
|
||||
if (!imageDTO) {
|
||||
return <IAIFillSkeleton />;
|
||||
}
|
||||
|
||||
return (
|
||||
<Box w="full" h="full" p={1.5} className={GALLERY_IMAGE_CLASS_NAME} data-testid={dataTestId} sx={boxSx}>
|
||||
<Box w="full" h="full" className="gallerygrid-image" data-testid={dataTestId} sx={boxSx}>
|
||||
<Flex
|
||||
ref={imageContainerRef}
|
||||
userSelect="none"
|
||||
@@ -181,23 +183,14 @@ const GalleryImage = ({ index, imageDTO }: HoverableImageProps) => {
|
||||
pointerEvents="none"
|
||||
>{`${imageDTO.width}x${imageDTO.height}`}</Text>
|
||||
)}
|
||||
<IAIDndImageIcon
|
||||
onClick={toggleStarredState}
|
||||
icon={starIcon}
|
||||
tooltip={starTooltip}
|
||||
position="absolute"
|
||||
top={1}
|
||||
insetInlineEnd={1}
|
||||
/>
|
||||
<IAIDndImageIcon onClick={toggleStarredState} icon={starIcon} tooltip={starTooltip} />
|
||||
|
||||
{isHovered && shift && (
|
||||
<IAIDndImageIcon
|
||||
onClick={handleDelete}
|
||||
icon={<PiTrashSimpleFill size="16px" />}
|
||||
tooltip={t('gallery.deleteImage_one')}
|
||||
position="absolute"
|
||||
bottom={1}
|
||||
insetInlineEnd={1}
|
||||
tooltip={t('gallery.deleteImage', { count: 1 })}
|
||||
styleOverrides={imageIconStyleOverrides}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
|
||||
@@ -1,32 +1,120 @@
|
||||
import { Box, Flex, Grid } from '@invoke-ai/ui-library';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { Box, Button, Flex } from '@invoke-ai/ui-library';
|
||||
import type { EntityId } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import { overlayScrollbarsParams } from 'common/components/OverlayScrollbars/constants';
|
||||
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
|
||||
import { useGalleryHotkeys } from 'features/gallery/hooks/useGalleryHotkeys';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { limitChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { debounce } from 'lodash-es';
|
||||
import { memo, useEffect, useMemo, useState } from 'react';
|
||||
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
|
||||
import { useOverlayScrollbars } from 'overlayscrollbars-react';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiImageBold, PiWarningCircleBold } from 'react-icons/pi';
|
||||
import { useListImagesQuery } from 'services/api/endpoints/images';
|
||||
import type { GridComponents, ItemContent, ListRange, VirtuosoGridHandle } from 'react-virtuoso';
|
||||
import { VirtuosoGrid } from 'react-virtuoso';
|
||||
import { useBoardTotal } from 'services/api/hooks/useBoardTotal';
|
||||
|
||||
import { GALLERY_GRID_CLASS_NAME } from './constants';
|
||||
import GalleryImage, { GALLERY_IMAGE_CLASS_NAME } from './GalleryImage';
|
||||
import GalleryImage from './GalleryImage';
|
||||
import ImageGridItemContainer from './ImageGridItemContainer';
|
||||
import ImageGridListContainer from './ImageGridListContainer';
|
||||
|
||||
const components: GridComponents = {
|
||||
Item: ImageGridItemContainer,
|
||||
List: ImageGridListContainer,
|
||||
};
|
||||
|
||||
const virtuosoStyles: CSSProperties = { height: '100%' };
|
||||
|
||||
const GalleryImageGrid = () => {
|
||||
useGalleryHotkeys();
|
||||
const { t } = useTranslation();
|
||||
const queryArgs = useAppSelector(selectListImagesQueryArgs);
|
||||
const { imageDTOs, isLoading, isError, isFetching } = useListImagesQuery(queryArgs, {
|
||||
selectFromResult: ({ data, isLoading, isSuccess, isError, isFetching }) => ({
|
||||
imageDTOs: data?.items ?? EMPTY_ARRAY,
|
||||
isLoading,
|
||||
isSuccess,
|
||||
isError,
|
||||
isFetching,
|
||||
}),
|
||||
});
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
const [scroller, setScroller] = useState<HTMLElement | null>(null);
|
||||
const [initialize, osInstance] = useOverlayScrollbars(overlayScrollbarsParams);
|
||||
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
||||
const { currentViewTotal } = useBoardTotal(selectedBoardId);
|
||||
const virtuosoRangeRef = useRef<ListRange | null>(null);
|
||||
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
|
||||
const {
|
||||
areMoreImagesAvailable,
|
||||
handleLoadMoreImages,
|
||||
queryResult: { currentData, isFetching, isSuccess, isError },
|
||||
} = useGalleryImages();
|
||||
useGalleryHotkeys();
|
||||
const itemContentFunc: ItemContent<EntityId, void> = useCallback(
|
||||
(index, imageName) => <GalleryImage key={imageName} index={index} imageName={imageName as string} />,
|
||||
[]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
// Initialize the gallery's custom scrollbar
|
||||
const { current: root } = rootRef;
|
||||
if (scroller && root) {
|
||||
initialize({
|
||||
target: root,
|
||||
elements: {
|
||||
viewport: scroller,
|
||||
},
|
||||
});
|
||||
}
|
||||
return () => osInstance()?.destroy();
|
||||
}, [scroller, initialize, osInstance]);
|
||||
|
||||
const onRangeChanged = useCallback((range: ListRange) => {
|
||||
virtuosoRangeRef.current = range;
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
virtuosoGridRefs.set({ rootRef, virtuosoRangeRef, virtuosoRef });
|
||||
return () => {
|
||||
virtuosoGridRefs.set({});
|
||||
};
|
||||
}, []);
|
||||
|
||||
if (!currentData) {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<IAINoContentFallback label={t('gallery.loading')} icon={PiImageBold} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
if (isSuccess && currentData?.ids.length === 0) {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<IAINoContentFallback label={t('gallery.noImagesInGallery')} icon={PiImageBold} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
if (isSuccess && currentData) {
|
||||
return (
|
||||
<>
|
||||
<Box ref={rootRef} data-overlayscrollbars="" h="100%" id="gallery-grid">
|
||||
<VirtuosoGrid
|
||||
style={virtuosoStyles}
|
||||
data={currentData.ids}
|
||||
endReached={handleLoadMoreImages}
|
||||
components={components}
|
||||
scrollerRef={setScroller}
|
||||
itemContent={itemContentFunc}
|
||||
ref={virtuosoRef}
|
||||
rangeChanged={onRangeChanged}
|
||||
overscan={10}
|
||||
/>
|
||||
</Box>
|
||||
<Button
|
||||
onClick={handleLoadMoreImages}
|
||||
isDisabled={!areMoreImagesAvailable}
|
||||
isLoading={isFetching}
|
||||
loadingText={t('gallery.loading')}
|
||||
flexShrink={0}
|
||||
>
|
||||
{`${t('accessibility.loadMore')} (${currentData.ids.length} / ${currentViewTotal})`}
|
||||
</Button>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
if (isError) {
|
||||
return (
|
||||
@@ -36,115 +124,7 @@ const GalleryImageGrid = () => {
|
||||
);
|
||||
}
|
||||
|
||||
if (isLoading || isFetching) {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<IAINoContentFallback label={t('gallery.loading')} icon={PiImageBold} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
if (imageDTOs.length === 0) {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<IAINoContentFallback label={t('gallery.noImagesInGallery')} icon={PiImageBold} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return <Content />;
|
||||
return null;
|
||||
};
|
||||
|
||||
export default memo(GalleryImageGrid);
|
||||
|
||||
const Content = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);
|
||||
|
||||
const queryArgs = useAppSelector(selectListImagesQueryArgs);
|
||||
const { imageDTOs } = useListImagesQuery(queryArgs, {
|
||||
selectFromResult: ({ data }) => ({ imageDTOs: data?.items ?? EMPTY_ARRAY }),
|
||||
});
|
||||
// Use a callback ref to get reactivity on the container element because it is conditionally rendered
|
||||
const [container, containerRef] = useState<HTMLDivElement | null>(null);
|
||||
|
||||
const calculateNewLimit = useMemo(() => {
|
||||
// Debounce this to not thrash the API
|
||||
return debounce(() => {
|
||||
if (!container) {
|
||||
// Container not rendered yet
|
||||
return;
|
||||
}
|
||||
// Managing refs for dynamically rendered components is a bit tedious:
|
||||
// - https://react.dev/learn/manipulating-the-dom-with-refs#how-to-manage-a-list-of-refs-using-a-ref-callback
|
||||
// As a easy workaround, we can just grab the first gallery image element directly.
|
||||
const galleryImageEl = document.querySelector(`.${GALLERY_IMAGE_CLASS_NAME}`);
|
||||
if (!galleryImageEl) {
|
||||
// No images in gallery?
|
||||
return;
|
||||
}
|
||||
|
||||
const galleryImageRect = galleryImageEl.getBoundingClientRect();
|
||||
const containerRect = container.getBoundingClientRect();
|
||||
|
||||
if (!galleryImageRect.width || !galleryImageRect.height || !containerRect.width || !containerRect.height) {
|
||||
// Gallery is too small to fit images or not rendered yet
|
||||
return;
|
||||
}
|
||||
|
||||
// Floating-point precision requires we round to get the correct number of images per row
|
||||
const imagesPerRow = Math.round(containerRect.width / galleryImageRect.width);
|
||||
// However, when calculating the number of images per column, we want to floor the value to not overflow the container
|
||||
const imagesPerColumn = Math.floor(containerRect.height / galleryImageRect.height);
|
||||
// Always load at least 1 row of images
|
||||
const limit = Math.max(imagesPerRow, imagesPerRow * imagesPerColumn);
|
||||
dispatch(limitChanged(limit));
|
||||
}, 300);
|
||||
}, [container, dispatch]);
|
||||
|
||||
useEffect(() => {
|
||||
// We want to recalculate the limit when image size changes
|
||||
calculateNewLimit();
|
||||
}, [calculateNewLimit, galleryImageMinimumWidth]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!container) {
|
||||
return;
|
||||
}
|
||||
|
||||
const resizeObserver = new ResizeObserver(calculateNewLimit);
|
||||
resizeObserver.observe(container);
|
||||
|
||||
// First render
|
||||
calculateNewLimit();
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
};
|
||||
}, [calculateNewLimit, container, dispatch]);
|
||||
|
||||
return (
|
||||
<Box position="relative" w="full" h="full">
|
||||
<Box
|
||||
ref={containerRef}
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
w="full"
|
||||
h="full"
|
||||
overflow="hidden"
|
||||
>
|
||||
<Grid
|
||||
className={GALLERY_GRID_CLASS_NAME}
|
||||
gridTemplateColumns={`repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr))`}
|
||||
>
|
||||
{imageDTOs.map((imageDTO, index) => (
|
||||
<GalleryImage key={imageDTO.image_name} imageDTO={imageDTO} index={index} />
|
||||
))}
|
||||
</Grid>
|
||||
</Box>
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
import { Button, Flex, IconButton, Spacer, Text } from '@invoke-ai/ui-library';
|
||||
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
|
||||
import { PiCaretDoubleLeftBold, PiCaretDoubleRightBold, PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
|
||||
|
||||
export const GalleryPagination = () => {
|
||||
const {
|
||||
goPrev,
|
||||
goNext,
|
||||
goToFirst,
|
||||
goToLast,
|
||||
isFirstEnabled,
|
||||
isLastEnabled,
|
||||
isPrevEnabled,
|
||||
isNextEnabled,
|
||||
pageButtons,
|
||||
goToPage,
|
||||
currentPage,
|
||||
rangeDisplay,
|
||||
total,
|
||||
} = useGalleryPagination();
|
||||
|
||||
if (!total) {
|
||||
return <Flex flexDir="column" alignItems="center" gap="2" height="48px"></Flex>;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" alignItems="center" gap="2" height="48px">
|
||||
<Flex gap={2} alignItems="center" w="full">
|
||||
<IconButton
|
||||
size="sm"
|
||||
aria-label="prev"
|
||||
icon={<PiCaretDoubleLeftBold />}
|
||||
onClick={goToFirst}
|
||||
isDisabled={!isFirstEnabled}
|
||||
/>
|
||||
<IconButton
|
||||
size="sm"
|
||||
aria-label="prev"
|
||||
icon={<PiCaretLeftBold />}
|
||||
onClick={goPrev}
|
||||
isDisabled={!isPrevEnabled}
|
||||
/>
|
||||
<Spacer />
|
||||
{pageButtons.map((page) => (
|
||||
<Button
|
||||
size="sm"
|
||||
key={page}
|
||||
onClick={goToPage.bind(null, page)}
|
||||
variant={currentPage === page ? 'solid' : 'outline'}
|
||||
>
|
||||
{page + 1}
|
||||
</Button>
|
||||
))}
|
||||
<Spacer />
|
||||
<IconButton
|
||||
size="sm"
|
||||
aria-label="next"
|
||||
icon={<PiCaretRightBold />}
|
||||
onClick={goNext}
|
||||
isDisabled={!isNextEnabled}
|
||||
/>
|
||||
<IconButton
|
||||
size="sm"
|
||||
aria-label="next"
|
||||
icon={<PiCaretDoubleRightBold />}
|
||||
onClick={goToLast}
|
||||
isDisabled={!isLastEnabled}
|
||||
/>
|
||||
</Flex>
|
||||
<Text>{rangeDisplay}</Text>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,15 @@
|
||||
import type { FlexProps } from '@invoke-ai/ui-library';
|
||||
import { Box, forwardRef } from '@invoke-ai/ui-library';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const imageItemContainerTestId = 'image-item-container';
|
||||
|
||||
type ItemContainerProps = PropsWithChildren & FlexProps;
|
||||
const ItemContainer = forwardRef((props: ItemContainerProps, ref) => (
|
||||
<Box className="item-container" ref={ref} p={1.5} data-testid={imageItemContainerTestId}>
|
||||
{props.children}
|
||||
</Box>
|
||||
));
|
||||
|
||||
export default memo(ItemContainer);
|
||||
@@ -0,0 +1,26 @@
|
||||
import type { FlexProps } from '@invoke-ai/ui-library';
|
||||
import { forwardRef, Grid } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { PropsWithChildren } from 'react';
|
||||
import { memo } from 'react';
|
||||
|
||||
export const imageListContainerTestId = 'image-list-container';
|
||||
|
||||
type ListContainerProps = PropsWithChildren & FlexProps;
|
||||
const ListContainer = forwardRef((props: ListContainerProps, ref) => {
|
||||
const galleryImageMinimumWidth = useAppSelector((s) => s.gallery.galleryImageMinimumWidth);
|
||||
|
||||
return (
|
||||
<Grid
|
||||
{...props}
|
||||
className="list-container"
|
||||
ref={ref}
|
||||
gridTemplateColumns={`repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr))`}
|
||||
data-testid={imageListContainerTestId}
|
||||
>
|
||||
{props.children}
|
||||
</Grid>
|
||||
);
|
||||
});
|
||||
|
||||
export default memo(ListContainer);
|
||||
@@ -1 +0,0 @@
|
||||
export const GALLERY_GRID_CLASS_NAME = 'gallery-grid';
|
||||
@@ -2,7 +2,6 @@ import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, IconButton, Spinner } from '@invoke-ai/ui-library';
|
||||
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
|
||||
import { useGalleryNavigation } from 'features/gallery/hooks/useGalleryNavigation';
|
||||
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretDoubleRightBold, PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
|
||||
@@ -17,8 +16,11 @@ const NextPrevImageButtons = () => {
|
||||
|
||||
const { prevImage, nextImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();
|
||||
|
||||
const { isFetching } = useGalleryImages().queryResult;
|
||||
const { isNextEnabled, goNext } = useGalleryPagination();
|
||||
const {
|
||||
areMoreImagesAvailable,
|
||||
handleLoadMoreImages,
|
||||
queryResult: { isFetching },
|
||||
} = useGalleryImages();
|
||||
|
||||
return (
|
||||
<Box pos="relative" h="full" w="full">
|
||||
@@ -45,17 +47,17 @@ const NextPrevImageButtons = () => {
|
||||
sx={nextPrevButtonStyles}
|
||||
/>
|
||||
)}
|
||||
{isOnLastImage && isNextEnabled && !isFetching && (
|
||||
{isOnLastImage && areMoreImagesAvailable && !isFetching && (
|
||||
<IconButton
|
||||
aria-label={t('accessibility.loadMore')}
|
||||
icon={<PiCaretDoubleRightBold size={64} />}
|
||||
variant="unstyled"
|
||||
onClick={goNext}
|
||||
onClick={handleLoadMoreImages}
|
||||
boxSize={16}
|
||||
sx={nextPrevButtonStyles}
|
||||
/>
|
||||
)}
|
||||
{isOnLastImage && isNextEnabled && isFetching && (
|
||||
{isOnLastImage && areMoreImagesAvailable && isFetching && (
|
||||
<Flex w={16} h={16} alignItems="center" justifyContent="center">
|
||||
<Spinner opacity={0.5} size="xl" />
|
||||
</Flex>
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { isStagingSelector } from 'features/canvas/store/canvasSelectors';
|
||||
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
|
||||
import { useGalleryNavigation } from 'features/gallery/hooks/useGalleryNavigation';
|
||||
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useMemo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useListImagesQuery } from 'services/api/endpoints/images';
|
||||
|
||||
/**
|
||||
* Registers gallery hotkeys. This hook is a singleton.
|
||||
@@ -19,30 +17,21 @@ export const useGalleryHotkeys = () => {
|
||||
return activeTabName !== 'canvas' || !isStaging;
|
||||
}, [activeTabName, isStaging]);
|
||||
|
||||
const { goNext, goPrev, isNextEnabled, isPrevEnabled } = useGalleryPagination();
|
||||
const queryArgs = useAppSelector(selectListImagesQueryArgs);
|
||||
const queryResult = useListImagesQuery(queryArgs);
|
||||
|
||||
const {
|
||||
handleLeftImage,
|
||||
handleRightImage,
|
||||
handleUpImage,
|
||||
handleDownImage,
|
||||
areImagesBelowCurrent,
|
||||
isOnFirstImageOfView,
|
||||
isOnLastImageOfView,
|
||||
} = useGalleryNavigation();
|
||||
areMoreImagesAvailable,
|
||||
handleLoadMoreImages,
|
||||
queryResult: { isFetching },
|
||||
} = useGalleryImages();
|
||||
|
||||
const { handleLeftImage, handleRightImage, handleUpImage, handleDownImage, isOnLastImage, areImagesBelowCurrent } =
|
||||
useGalleryNavigation();
|
||||
|
||||
useHotkeys(
|
||||
['left', 'alt+left'],
|
||||
(e) => {
|
||||
if (isOnFirstImageOfView && isPrevEnabled && !queryResult.isFetching) {
|
||||
goPrev();
|
||||
return;
|
||||
}
|
||||
canNavigateGallery && handleLeftImage(e.altKey);
|
||||
},
|
||||
[handleLeftImage, canNavigateGallery, isOnFirstImageOfView, goPrev, isPrevEnabled, queryResult.isFetching]
|
||||
[handleLeftImage, canNavigateGallery]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
@@ -51,15 +40,15 @@ export const useGalleryHotkeys = () => {
|
||||
if (!canNavigateGallery) {
|
||||
return;
|
||||
}
|
||||
if (isOnLastImageOfView && isNextEnabled && !queryResult.isFetching) {
|
||||
goNext();
|
||||
if (isOnLastImage && areMoreImagesAvailable && !isFetching) {
|
||||
handleLoadMoreImages();
|
||||
return;
|
||||
}
|
||||
if (!isOnLastImageOfView) {
|
||||
if (!isOnLastImage) {
|
||||
handleRightImage(e.altKey);
|
||||
}
|
||||
},
|
||||
[isOnLastImageOfView, goNext, isNextEnabled, queryResult.isFetching, handleRightImage, canNavigateGallery]
|
||||
[isOnLastImage, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleRightImage, canNavigateGallery]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
@@ -74,13 +63,13 @@ export const useGalleryHotkeys = () => {
|
||||
useHotkeys(
|
||||
['down', 'alt+down'],
|
||||
(e) => {
|
||||
if (!areImagesBelowCurrent && isNextEnabled && !queryResult.isFetching) {
|
||||
goNext();
|
||||
if (!areImagesBelowCurrent && areMoreImagesAvailable && !isFetching) {
|
||||
handleLoadMoreImages();
|
||||
return;
|
||||
}
|
||||
handleDownImage(e.altKey);
|
||||
},
|
||||
{ preventDefault: true },
|
||||
[areImagesBelowCurrent, goNext, isNextEnabled, queryResult.isFetching, handleDownImage]
|
||||
[areImagesBelowCurrent, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleDownImage]
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,15 +1,38 @@
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { useMemo } from 'react';
|
||||
import { moreImagesLoaded } from 'features/gallery/store/gallerySlice';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
|
||||
import { useListImagesQuery } from 'services/api/endpoints/images';
|
||||
|
||||
/**
|
||||
* Provides access to the gallery images and a way to imperatively fetch more.
|
||||
*/
|
||||
export const useGalleryImages = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const galleryView = useAppSelector((s) => s.gallery.galleryView);
|
||||
const queryArgs = useAppSelector(selectListImagesQueryArgs);
|
||||
const queryResult = useListImagesQuery(queryArgs);
|
||||
const imageDTOs = useMemo(() => queryResult.data?.items ?? EMPTY_ARRAY, [queryResult.data]);
|
||||
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
||||
const { data: assetsTotal } = useGetBoardAssetsTotalQuery(selectedBoardId);
|
||||
const { data: imagesTotal } = useGetBoardImagesTotalQuery(selectedBoardId);
|
||||
const currentViewTotal = useMemo(
|
||||
() => (galleryView === 'images' ? imagesTotal?.total : assetsTotal?.total),
|
||||
[assetsTotal?.total, galleryView, imagesTotal?.total]
|
||||
);
|
||||
const areMoreImagesAvailable = useMemo(() => {
|
||||
if (!currentViewTotal || !queryResult.data) {
|
||||
return false;
|
||||
}
|
||||
return queryResult.data.ids.length < currentViewTotal;
|
||||
}, [queryResult.data, currentViewTotal]);
|
||||
const handleLoadMoreImages = useCallback(() => {
|
||||
dispatch(moreImagesLoaded());
|
||||
}, [dispatch]);
|
||||
|
||||
return {
|
||||
imageDTOs,
|
||||
areMoreImagesAvailable,
|
||||
handleLoadMoreImages,
|
||||
queryResult,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { useAltModifier } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { GALLERY_GRID_CLASS_NAME } from 'features/gallery/components/ImageGrid/constants';
|
||||
import { GALLERY_IMAGE_CLASS_NAME } from 'features/gallery/components/ImageGrid/GalleryImage';
|
||||
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
|
||||
import { imageItemContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridItemContainer';
|
||||
import { imageListContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridListContainer';
|
||||
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
|
||||
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
|
||||
import { imageSelected, imageToCompareChanged } from 'features/gallery/store/gallerySlice';
|
||||
@@ -11,6 +11,7 @@ import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAli
|
||||
import { clamp } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
|
||||
/**
|
||||
* This hook is used to navigate the gallery using the arrow keys.
|
||||
@@ -28,9 +29,10 @@ import type { ImageDTO } from 'services/api/types';
|
||||
*/
|
||||
const getImagesPerRow = (): number => {
|
||||
const widthOfGalleryImage =
|
||||
document.querySelector(`.${GALLERY_IMAGE_CLASS_NAME}`)?.getBoundingClientRect().width ?? 1;
|
||||
document.querySelector(`[data-testid="${imageItemContainerTestId}"]`)?.getBoundingClientRect().width ?? 1;
|
||||
|
||||
const widthOfGalleryGrid = document.querySelector(`.${GALLERY_GRID_CLASS_NAME}`)?.getBoundingClientRect().width ?? 0;
|
||||
const widthOfGalleryGrid =
|
||||
document.querySelector(`[data-testid="${imageListContainerTestId}"]`)?.getBoundingClientRect().width ?? 0;
|
||||
|
||||
const imagesPerRow = Math.round(widthOfGalleryGrid / widthOfGalleryImage);
|
||||
|
||||
@@ -113,8 +115,6 @@ type UseGalleryNavigationReturn = {
|
||||
isOnFirstImage: boolean;
|
||||
isOnLastImage: boolean;
|
||||
areImagesBelowCurrent: boolean;
|
||||
isOnFirstImageOfView: boolean;
|
||||
isOnLastImageOfView: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -134,19 +134,23 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
return lastSelected;
|
||||
}
|
||||
});
|
||||
const { imageDTOs } = useGalleryImages();
|
||||
const loadedImagesCount = useMemo(() => imageDTOs.length, [imageDTOs.length]);
|
||||
|
||||
const {
|
||||
queryResult: { data },
|
||||
} = useGalleryImages();
|
||||
const loadedImagesCount = useMemo(() => data?.ids.length ?? 0, [data?.ids.length]);
|
||||
const lastSelectedImageIndex = useMemo(() => {
|
||||
if (imageDTOs.length === 0 || !lastSelectedImage) {
|
||||
if (!data || !lastSelectedImage) {
|
||||
return 0;
|
||||
}
|
||||
return imageDTOs.findIndex((i) => i.image_name === lastSelectedImage.image_name);
|
||||
}, [imageDTOs, lastSelectedImage]);
|
||||
return imagesSelectors.selectAll(data).findIndex((i) => i.image_name === lastSelectedImage.image_name);
|
||||
}, [lastSelectedImage, data]);
|
||||
|
||||
const handleNavigation = useCallback(
|
||||
(direction: 'left' | 'right' | 'up' | 'down', alt?: boolean) => {
|
||||
const { index, image } = getImageFuncs[direction](imageDTOs, lastSelectedImageIndex);
|
||||
if (!data) {
|
||||
return;
|
||||
}
|
||||
const { index, image } = getImageFuncs[direction](imagesSelectors.selectAll(data), lastSelectedImageIndex);
|
||||
if (!image || index === lastSelectedImageIndex) {
|
||||
return;
|
||||
}
|
||||
@@ -157,7 +161,7 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
}
|
||||
scrollToImage(image.image_name, index);
|
||||
},
|
||||
[imageDTOs, lastSelectedImageIndex, dispatch]
|
||||
[data, lastSelectedImageIndex, dispatch]
|
||||
);
|
||||
|
||||
const isOnFirstImage = useMemo(() => lastSelectedImageIndex === 0, [lastSelectedImageIndex]);
|
||||
@@ -172,14 +176,6 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
return lastSelectedImageIndex + imagesPerRow < loadedImagesCount;
|
||||
}, [lastSelectedImageIndex, loadedImagesCount]);
|
||||
|
||||
const isOnFirstImageOfView = useMemo(() => {
|
||||
return lastSelectedImageIndex === 0;
|
||||
}, [lastSelectedImageIndex]);
|
||||
|
||||
const isOnLastImageOfView = useMemo(() => {
|
||||
return lastSelectedImageIndex === loadedImagesCount - 1;
|
||||
}, [lastSelectedImageIndex, loadedImagesCount]);
|
||||
|
||||
const handleLeftImage = useCallback(
|
||||
(alt?: boolean) => {
|
||||
handleNavigation('left', alt);
|
||||
@@ -226,7 +222,5 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
areImagesBelowCurrent,
|
||||
nextImage,
|
||||
prevImage,
|
||||
isOnFirstImageOfView,
|
||||
isOnLastImageOfView,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { offsetChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { useCallback, useEffect, useMemo } from 'react';
|
||||
import { useListImagesQuery } from 'services/api/endpoints/images';
|
||||
|
||||
export const useGalleryPagination = (pageButtonsPerSide: number = 2) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { offset, limit } = useAppSelector((s) => s.gallery);
|
||||
const queryArgs = useAppSelector(selectListImagesQueryArgs);
|
||||
|
||||
const { count, total } = useListImagesQuery(queryArgs, {
|
||||
selectFromResult: ({ data }) => ({ count: data?.items.length ?? 0, total: data?.total ?? 0 }),
|
||||
});
|
||||
|
||||
const currentPage = useMemo(() => Math.ceil(offset / (limit || 0)), [offset, limit]);
|
||||
const pages = useMemo(() => Math.ceil(total / (limit || 0)), [total, limit]);
|
||||
|
||||
const isNextEnabled = useMemo(() => {
|
||||
if (!count) {
|
||||
return false;
|
||||
}
|
||||
return currentPage + 1 < pages;
|
||||
}, [count, currentPage, pages]);
|
||||
const isPrevEnabled = useMemo(() => {
|
||||
if (!count) {
|
||||
return false;
|
||||
}
|
||||
return offset > 0;
|
||||
}, [count, offset]);
|
||||
|
||||
const goNext = useCallback(() => {
|
||||
dispatch(offsetChanged(offset + (limit || 0)));
|
||||
}, [dispatch, offset, limit]);
|
||||
|
||||
const goPrev = useCallback(() => {
|
||||
dispatch(offsetChanged(Math.max(offset - (limit || 0), 0)));
|
||||
}, [dispatch, offset, limit]);
|
||||
|
||||
const goToPage = useCallback(
|
||||
(page: number) => {
|
||||
dispatch(offsetChanged(page * (limit || 0)));
|
||||
},
|
||||
[dispatch, limit]
|
||||
);
|
||||
const goToFirst = useCallback(() => {
|
||||
dispatch(offsetChanged(0));
|
||||
}, [dispatch]);
|
||||
const goToLast = useCallback(() => {
|
||||
dispatch(offsetChanged((pages - 1) * (limit || 0)));
|
||||
}, [dispatch, pages, limit]);
|
||||
|
||||
// handle when total/pages decrease and user is on high page number (ie bulk removing or deleting)
|
||||
useEffect(() => {
|
||||
if (pages && currentPage + 1 > pages) {
|
||||
goToLast();
|
||||
}
|
||||
}, [currentPage, pages, goToLast]);
|
||||
|
||||
// calculate the page buttons to display - current page with 3 around it
|
||||
const pageButtons = useMemo(() => {
|
||||
const buttons = [];
|
||||
const maxPageButtons = pageButtonsPerSide * 2 + 1;
|
||||
let startPage = Math.max(currentPage - Math.floor(maxPageButtons / 2), 0);
|
||||
const endPage = Math.min(startPage + maxPageButtons - 1, pages - 1);
|
||||
|
||||
if (endPage - startPage < maxPageButtons - 1) {
|
||||
startPage = Math.max(endPage - maxPageButtons + 1, 0);
|
||||
}
|
||||
|
||||
for (let i = startPage; i <= endPage; i++) {
|
||||
buttons.push(i);
|
||||
}
|
||||
|
||||
return buttons;
|
||||
}, [currentPage, pageButtonsPerSide, pages]);
|
||||
|
||||
const isFirstEnabled = useMemo(() => currentPage > 0, [currentPage]);
|
||||
const isLastEnabled = useMemo(() => currentPage < pages - 1, [currentPage, pages]);
|
||||
|
||||
const rangeDisplay = useMemo(() => {
|
||||
const startItem = currentPage * (limit || 0) + 1;
|
||||
const endItem = Math.min((currentPage + 1) * (limit || 0), total);
|
||||
return `${startItem}-${endItem} of ${total}`;
|
||||
}, [total, currentPage, limit]);
|
||||
|
||||
const numberOnPage = useMemo(() => {
|
||||
return Math.min((currentPage + 1) * (limit || 0), total);
|
||||
}, [currentPage, limit, total]);
|
||||
|
||||
const api = useMemo(
|
||||
() => ({
|
||||
count,
|
||||
total,
|
||||
currentPage,
|
||||
pages,
|
||||
isNextEnabled,
|
||||
isPrevEnabled,
|
||||
goNext,
|
||||
goPrev,
|
||||
goToPage,
|
||||
goToFirst,
|
||||
goToLast,
|
||||
pageButtons,
|
||||
isFirstEnabled,
|
||||
isLastEnabled,
|
||||
rangeDisplay,
|
||||
numberOnPage,
|
||||
}),
|
||||
[
|
||||
count,
|
||||
total,
|
||||
currentPage,
|
||||
pages,
|
||||
isNextEnabled,
|
||||
isPrevEnabled,
|
||||
goNext,
|
||||
goPrev,
|
||||
goToPage,
|
||||
goToFirst,
|
||||
goToLast,
|
||||
pageButtons,
|
||||
isFirstEnabled,
|
||||
isLastEnabled,
|
||||
rangeDisplay,
|
||||
numberOnPage,
|
||||
]
|
||||
);
|
||||
|
||||
return api;
|
||||
};
|
||||
@@ -1,5 +1,3 @@
|
||||
import type { SkipToken } from '@reduxjs/toolkit/query';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
@@ -12,15 +10,11 @@ export const selectLastSelectedImage = createMemoizedSelector(
|
||||
|
||||
export const selectListImagesQueryArgs = createMemoizedSelector(
|
||||
selectGallerySlice,
|
||||
(gallery): ListImagesArgs | SkipToken =>
|
||||
gallery.limit
|
||||
? {
|
||||
board_id: gallery.selectedBoardId,
|
||||
categories: gallery.galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES,
|
||||
offset: gallery.offset,
|
||||
limit: gallery.limit,
|
||||
is_intermediate: false,
|
||||
search_term: gallery.searchTerm,
|
||||
}
|
||||
: skipToken
|
||||
(gallery): ListImagesArgs => ({
|
||||
board_id: gallery.selectedBoardId,
|
||||
categories: gallery.galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES,
|
||||
offset: gallery.offset,
|
||||
limit: gallery.limit,
|
||||
is_intermediate: false,
|
||||
})
|
||||
);
|
||||
|
||||
@@ -7,7 +7,7 @@ import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import type { BoardId, ComparisonMode, GalleryState, GalleryView } from './types';
|
||||
import { IMAGE_LIMIT } from './types';
|
||||
import { IMAGE_LIMIT, INITIAL_IMAGE_LIMIT } from './types';
|
||||
|
||||
const initialGalleryState: GalleryState = {
|
||||
selection: [],
|
||||
@@ -19,7 +19,7 @@ const initialGalleryState: GalleryState = {
|
||||
selectedBoardId: 'none',
|
||||
galleryView: 'images',
|
||||
boardSearchText: '',
|
||||
limit: 20,
|
||||
limit: INITIAL_IMAGE_LIMIT,
|
||||
offset: 0,
|
||||
isImageViewerOpen: true,
|
||||
imageToCompare: null,
|
||||
@@ -72,6 +72,7 @@ export const gallerySlice = createSlice({
|
||||
state.selectedBoardId = action.payload.boardId;
|
||||
state.galleryView = 'images';
|
||||
state.offset = 0;
|
||||
state.limit = INITIAL_IMAGE_LIMIT;
|
||||
},
|
||||
autoAddBoardIdChanged: (state, action: PayloadAction<BoardId>) => {
|
||||
if (!action.payload) {
|
||||
@@ -83,11 +84,20 @@ export const gallerySlice = createSlice({
|
||||
galleryViewChanged: (state, action: PayloadAction<GalleryView>) => {
|
||||
state.galleryView = action.payload;
|
||||
state.offset = 0;
|
||||
state.limit = IMAGE_LIMIT;
|
||||
state.limit = INITIAL_IMAGE_LIMIT;
|
||||
},
|
||||
boardSearchTextChanged: (state, action: PayloadAction<string>) => {
|
||||
state.boardSearchText = action.payload;
|
||||
},
|
||||
moreImagesLoaded: (state) => {
|
||||
if (state.offset === 0 && state.limit === INITIAL_IMAGE_LIMIT) {
|
||||
state.offset = INITIAL_IMAGE_LIMIT;
|
||||
state.limit = IMAGE_LIMIT;
|
||||
} else {
|
||||
state.offset += IMAGE_LIMIT;
|
||||
state.limit += IMAGE_LIMIT;
|
||||
}
|
||||
},
|
||||
alwaysShowImageSizeBadgeChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.alwaysShowImageSizeBadge = action.payload;
|
||||
},
|
||||
@@ -104,15 +114,6 @@ export const gallerySlice = createSlice({
|
||||
comparisonFitChanged: (state, action: PayloadAction<'contain' | 'fill'>) => {
|
||||
state.comparisonFit = action.payload;
|
||||
},
|
||||
offsetChanged: (state, action: PayloadAction<number>) => {
|
||||
state.offset = action.payload;
|
||||
},
|
||||
limitChanged: (state, action: PayloadAction<number>) => {
|
||||
state.limit = action.payload;
|
||||
},
|
||||
searchTermChanged: (state, action: PayloadAction<string | undefined>) => {
|
||||
state.searchTerm = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
|
||||
@@ -148,6 +149,7 @@ export const {
|
||||
galleryViewChanged,
|
||||
selectionChanged,
|
||||
boardSearchTextChanged,
|
||||
moreImagesLoaded,
|
||||
alwaysShowImageSizeBadgeChanged,
|
||||
isImageViewerOpenChanged,
|
||||
imageToCompareChanged,
|
||||
@@ -155,9 +157,6 @@ export const {
|
||||
comparedImagesSwapped,
|
||||
comparisonFitChanged,
|
||||
comparisonModeCycled,
|
||||
offsetChanged,
|
||||
limitChanged,
|
||||
searchTermChanged,
|
||||
} = gallerySlice.actions;
|
||||
|
||||
const isAnyBoardDeleted = isAnyOf(
|
||||
|
||||
@@ -2,7 +2,8 @@ import type { ImageCategory, ImageDTO } from 'services/api/types';
|
||||
|
||||
export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
|
||||
export const ASSETS_CATEGORIES: ImageCategory[] = ['control', 'mask', 'user', 'other'];
|
||||
export const IMAGE_LIMIT = 15;
|
||||
export const INITIAL_IMAGE_LIMIT = 100;
|
||||
export const IMAGE_LIMIT = 20;
|
||||
|
||||
export type GalleryView = 'images' | 'assets';
|
||||
export type BoardId = 'none' | (string & Record<never, never>);
|
||||
@@ -20,7 +21,6 @@ export type GalleryState = {
|
||||
boardSearchText: string;
|
||||
offset: number;
|
||||
limit: number;
|
||||
searchTerm?: string;
|
||||
alwaysShowImageSizeBadge: boolean;
|
||||
imageToCompare: ImageDTO | null;
|
||||
comparisonMode: ComparisonMode;
|
||||
|
||||
@@ -11,6 +11,7 @@ const BASE_COLOR_MAP: Record<BaseModelType, string> = {
|
||||
any: 'base',
|
||||
'sd-1': 'green',
|
||||
'sd-2': 'teal',
|
||||
'sd-3': 'purple',
|
||||
sdxl: 'invokeBlue',
|
||||
'sdxl-refiner': 'invokeBlue',
|
||||
};
|
||||
|
||||
@@ -10,6 +10,7 @@ import type { UpdateModelArg } from 'services/api/endpoints/models';
|
||||
const options: ComboboxOption[] = [
|
||||
{ value: 'sd-1', label: MODEL_TYPE_MAP['sd-1'] },
|
||||
{ value: 'sd-2', label: MODEL_TYPE_MAP['sd-2'] },
|
||||
{ value: 'sd-3', label: MODEL_TYPE_MAP['sd-3'] },
|
||||
{ value: 'sdxl', label: MODEL_TYPE_MAP['sdxl'] },
|
||||
{ value: 'sdxl-refiner', label: MODEL_TYPE_MAP['sdxl-refiner'] },
|
||||
];
|
||||
|
||||
@@ -28,6 +28,8 @@ import {
|
||||
isModelIdentifierFieldInputTemplate,
|
||||
isSchedulerFieldInputInstance,
|
||||
isSchedulerFieldInputTemplate,
|
||||
isSD3MainModelFieldInputInstance,
|
||||
isSD3MainModelFieldInputTemplate,
|
||||
isSDXLMainModelFieldInputInstance,
|
||||
isSDXLMainModelFieldInputTemplate,
|
||||
isSDXLRefinerModelFieldInputInstance,
|
||||
@@ -53,6 +55,7 @@ import MainModelFieldInputComponent from './inputs/MainModelFieldInputComponent'
|
||||
import NumberFieldInputComponent from './inputs/NumberFieldInputComponent';
|
||||
import RefinerModelFieldInputComponent from './inputs/RefinerModelFieldInputComponent';
|
||||
import SchedulerFieldInputComponent from './inputs/SchedulerFieldInputComponent';
|
||||
import SD3MainModelFieldInputComponent from './inputs/SD3MainModelFieldInputComponent';
|
||||
import SDXLMainModelFieldInputComponent from './inputs/SDXLMainModelFieldInputComponent';
|
||||
import StringFieldInputComponent from './inputs/StringFieldInputComponent';
|
||||
import T2IAdapterModelFieldInputComponent from './inputs/T2IAdapterModelFieldInputComponent';
|
||||
@@ -133,6 +136,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
|
||||
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
|
||||
import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useSD3Models } from 'services/api/hooks/modelsByType';
|
||||
import type { MainModelConfig } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
type Props = FieldComponentProps<SD3MainModelFieldInputInstance, SD3MainModelFieldInputTemplate>;
|
||||
|
||||
const SD3MainModelFieldInputComponent = (props: Props) => {
|
||||
const { nodeId, field } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const [modelConfigs, { isLoading }] = useSD3Models();
|
||||
const _onChange = useCallback(
|
||||
(value: MainModelConfig | null) => {
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldMainModelValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
|
||||
modelConfigs,
|
||||
onChange: _onChange,
|
||||
isLoading,
|
||||
selectedModel: field.value,
|
||||
});
|
||||
|
||||
return (
|
||||
<Flex w="full" alignItems="center" gap={2}>
|
||||
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
|
||||
<Combobox
|
||||
value={value}
|
||||
placeholder={placeholder}
|
||||
options={options}
|
||||
onChange={onChange}
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
/>
|
||||
</FormControl>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(SD3MainModelFieldInputComponent);
|
||||
@@ -631,6 +631,7 @@ export const schema = {
|
||||
'euler',
|
||||
'euler_k',
|
||||
'euler_a',
|
||||
'euler_f',
|
||||
'kdpm_2',
|
||||
'kdpm_2_a',
|
||||
'dpmpp_2s',
|
||||
@@ -694,6 +695,7 @@ export const schema = {
|
||||
'euler',
|
||||
'euler_k',
|
||||
'euler_a',
|
||||
'euler_f',
|
||||
'kdpm_2',
|
||||
'kdpm_2_a',
|
||||
'dpmpp_2s',
|
||||
@@ -839,7 +841,7 @@ export const schema = {
|
||||
},
|
||||
BaseModelType: {
|
||||
description: 'Base model type.',
|
||||
enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'],
|
||||
enum: ['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner'],
|
||||
title: 'BaseModelType',
|
||||
type: 'string',
|
||||
},
|
||||
@@ -855,8 +857,11 @@ export const schema = {
|
||||
'unet',
|
||||
'text_encoder',
|
||||
'text_encoder_2',
|
||||
'text_encoder_3',
|
||||
'tokenizer',
|
||||
'tokenizer_2',
|
||||
'tokenizer_3',
|
||||
'transformer',
|
||||
'vae',
|
||||
'vae_decoder',
|
||||
'vae_encoder',
|
||||
|
||||
@@ -47,6 +47,7 @@ export const zSchedulerField = z.enum([
|
||||
'heun_k',
|
||||
'lms_k',
|
||||
'euler_a',
|
||||
'euler_f',
|
||||
'kdpm_2_a',
|
||||
'lcm',
|
||||
'tcd',
|
||||
@@ -55,7 +56,7 @@ export type SchedulerField = z.infer<typeof zSchedulerField>;
|
||||
// #endregion
|
||||
|
||||
// #region Model-related schemas
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner']);
|
||||
const zModelType = z.enum([
|
||||
'main',
|
||||
'vae',
|
||||
@@ -71,8 +72,11 @@ const zSubModelType = z.enum([
|
||||
'unet',
|
||||
'text_encoder',
|
||||
'text_encoder_2',
|
||||
'text_encoder_3',
|
||||
'tokenizer',
|
||||
'tokenizer_2',
|
||||
'tokenizer_3',
|
||||
'transformer',
|
||||
'vae',
|
||||
'vae_decoder',
|
||||
'vae_encoder',
|
||||
|
||||
@@ -32,11 +32,14 @@ export const MODEL_TYPES = [
|
||||
'LoRAModelField',
|
||||
'MainModelField',
|
||||
'SDXLMainModelField',
|
||||
'SD3MainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'VaeModelField',
|
||||
'UNetField',
|
||||
'TransformerField',
|
||||
'VAEField',
|
||||
'CLIPField',
|
||||
'SD3CLIPField',
|
||||
'T2IAdapterModelField',
|
||||
];
|
||||
|
||||
@@ -47,6 +50,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
BoardField: 'purple.500',
|
||||
BooleanField: 'green.500',
|
||||
CLIPField: 'green.500',
|
||||
SD3CLIPField: 'green.500',
|
||||
ColorField: 'pink.300',
|
||||
ConditioningField: 'cyan.500',
|
||||
ControlField: 'teal.500',
|
||||
@@ -62,10 +66,12 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
MainModelField: 'teal.500',
|
||||
SDXLMainModelField: 'teal.500',
|
||||
SDXLRefinerModelField: 'teal.500',
|
||||
SD3MainModelField: 'teal.500',
|
||||
StringField: 'yellow.500',
|
||||
T2IAdapterField: 'teal.500',
|
||||
T2IAdapterModelField: 'teal.500',
|
||||
UNetField: 'red.500',
|
||||
TransformerField: 'red.500',
|
||||
VAEField: 'blue.500',
|
||||
VAEModelField: 'teal.500',
|
||||
};
|
||||
|
||||
@@ -119,6 +119,10 @@ const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SDXLRefinerModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zSD3MainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SD3MainModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('VAEModelField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -155,6 +159,7 @@ const zStatefulFieldType = z.union([
|
||||
zMainModelFieldType,
|
||||
zSDXLMainModelFieldType,
|
||||
zSDXLRefinerModelFieldType,
|
||||
zSD3MainModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
zLoRAModelFieldType,
|
||||
zControlNetModelFieldType,
|
||||
@@ -466,6 +471,28 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR
|
||||
zSDXLRefinerModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region SD3MainModelField
|
||||
|
||||
const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
|
||||
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zSD3MainModelFieldValue,
|
||||
});
|
||||
const zSD3MainModelFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zSD3MainModelFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zSD3MainModelFieldValue,
|
||||
});
|
||||
const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zSD3MainModelFieldType,
|
||||
});
|
||||
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
|
||||
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
|
||||
export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
|
||||
zSD3MainModelFieldInputInstance.safeParse(val).success;
|
||||
export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
|
||||
zSD3MainModelFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region VAEModelField
|
||||
|
||||
export const zVAEModelFieldValue = zModelIdentifierField.optional();
|
||||
@@ -662,6 +689,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zMainModelFieldValue,
|
||||
zSDXLMainModelFieldValue,
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zSD3MainModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
zControlNetModelFieldValue,
|
||||
@@ -689,6 +717,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zMainModelFieldInputInstance,
|
||||
zSDXLMainModelFieldInputInstance,
|
||||
zSDXLRefinerModelFieldInputInstance,
|
||||
zSD3MainModelFieldInputInstance,
|
||||
zVAEModelFieldInputInstance,
|
||||
zLoRAModelFieldInputInstance,
|
||||
zControlNetModelFieldInputInstance,
|
||||
@@ -717,6 +746,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zMainModelFieldInputTemplate,
|
||||
zSDXLMainModelFieldInputTemplate,
|
||||
zSDXLRefinerModelFieldInputTemplate,
|
||||
zSD3MainModelFieldInputTemplate,
|
||||
zVAEModelFieldInputTemplate,
|
||||
zLoRAModelFieldInputTemplate,
|
||||
zControlNetModelFieldInputTemplate,
|
||||
@@ -746,6 +776,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zMainModelFieldOutputTemplate,
|
||||
zSDXLMainModelFieldOutputTemplate,
|
||||
zSDXLRefinerModelFieldOutputTemplate,
|
||||
zSD3MainModelFieldOutputTemplate,
|
||||
zVAEModelFieldOutputTemplate,
|
||||
zLoRAModelFieldOutputTemplate,
|
||||
zControlNetModelFieldOutputTemplate,
|
||||
|
||||
@@ -44,7 +44,7 @@ export const zSchedulerField = z.enum([
|
||||
// #endregion
|
||||
|
||||
// #region Model-related schemas
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
|
||||
const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sd-3', 'sdxl', 'sdxl-refiner']);
|
||||
const zModelName = z.string().min(3);
|
||||
export const zModelIdentifier = z.object({
|
||||
model_name: zModelName,
|
||||
|
||||
@@ -217,6 +217,20 @@ const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
});
|
||||
// #endregion
|
||||
|
||||
// #region SDXLMainModelField
|
||||
const zSD3MainModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SD3MainModelField'),
|
||||
});
|
||||
const zSD3MainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SD3 models only.
|
||||
const zSD3MainModelFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
type: zSD3MainModelFieldType,
|
||||
value: zSD3MainModelFieldValue,
|
||||
});
|
||||
const zSD3MainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({
|
||||
type: zSD3MainModelFieldType,
|
||||
});
|
||||
// #endregion
|
||||
|
||||
// #region VAEModelField
|
||||
const zVAEModelFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('VAEModelField'),
|
||||
@@ -339,6 +353,7 @@ const zStatefulFieldType = z.union([
|
||||
zMainModelFieldType,
|
||||
zSDXLMainModelFieldType,
|
||||
zSDXLRefinerModelFieldType,
|
||||
zSD3MainModelFieldType,
|
||||
zVAEModelFieldType,
|
||||
zLoRAModelFieldType,
|
||||
zControlNetModelFieldType,
|
||||
@@ -378,6 +393,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zMainModelFieldInputInstance,
|
||||
zSDXLMainModelFieldInputInstance,
|
||||
zSDXLRefinerModelFieldInputInstance,
|
||||
zSD3MainModelFieldInputInstance,
|
||||
zVAEModelFieldInputInstance,
|
||||
zLoRAModelFieldInputInstance,
|
||||
zControlNetModelFieldInputInstance,
|
||||
@@ -402,6 +418,7 @@ const zStatefulFieldOutputInstance = z.union([
|
||||
zMainModelFieldOutputInstance,
|
||||
zSDXLMainModelFieldOutputInstance,
|
||||
zSDXLRefinerModelFieldOutputInstance,
|
||||
zSD3MainModelFieldOutputInstance,
|
||||
zVAEModelFieldOutputInstance,
|
||||
zLoRAModelFieldOutputInstance,
|
||||
zControlNetModelFieldOutputInstance,
|
||||
|
||||
@@ -15,6 +15,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
|
||||
MainModelField: undefined,
|
||||
SchedulerField: 'euler',
|
||||
SDXLMainModelField: undefined,
|
||||
SD3MainModelField: undefined,
|
||||
SDXLRefinerModelField: undefined,
|
||||
StringField: '',
|
||||
T2IAdapterModelField: undefined,
|
||||
|
||||
@@ -15,6 +15,7 @@ import type {
|
||||
MainModelFieldInputTemplate,
|
||||
ModelIdentifierFieldInputTemplate,
|
||||
SchedulerFieldInputTemplate,
|
||||
SD3MainModelFieldInputTemplate,
|
||||
SDXLMainModelFieldInputTemplate,
|
||||
SDXLRefinerModelFieldInputTemplate,
|
||||
StatefulFieldType,
|
||||
@@ -193,6 +194,20 @@ const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder<SDXLRefiner
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildSD3MainModelFieldInputTemplate: FieldInputTemplateBuilder<SD3MainModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: SD3MainModelFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? undefined,
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder<VAEModelFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -375,6 +390,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
SchedulerField: buildSchedulerFieldInputTemplate,
|
||||
SDXLMainModelField: buildSDXLMainModelFieldInputTemplate,
|
||||
SDXLRefinerModelField: buildRefinerModelFieldInputTemplate,
|
||||
SD3MainModelField: buildSD3MainModelFieldInputTemplate,
|
||||
StringField: buildStringFieldInputTemplate,
|
||||
T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate,
|
||||
VAEModelField: buildVAEModelFieldInputTemplate,
|
||||
|
||||
@@ -30,6 +30,7 @@ const MODEL_FIELD_TYPES = [
|
||||
'MainModelField',
|
||||
'SDXLMainModelField',
|
||||
'SDXLRefinerModelField',
|
||||
'SD3MainModelField',
|
||||
'VAEModelField',
|
||||
'LoRAModelField',
|
||||
'ControlNetModelField',
|
||||
|
||||
@@ -39,7 +39,7 @@ const ParamClipSkip = () => {
|
||||
return CLIP_SKIP_MAP[model.base].markers;
|
||||
}, [model]);
|
||||
|
||||
if (model?.base === 'sdxl') {
|
||||
if (model?.base === 'sdxl' || model?.base === 'sd-3') {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ export const MODEL_TYPE_MAP = {
|
||||
any: 'Any',
|
||||
'sd-1': 'Stable Diffusion 1.x',
|
||||
'sd-2': 'Stable Diffusion 2.x',
|
||||
'sd-3': 'Stable Diffusion 3.x',
|
||||
sdxl: 'Stable Diffusion XL',
|
||||
'sdxl-refiner': 'Stable Diffusion XL Refiner',
|
||||
};
|
||||
@@ -18,6 +19,7 @@ export const MODEL_TYPE_SHORT_MAP = {
|
||||
any: 'Any',
|
||||
'sd-1': 'SD1.X',
|
||||
'sd-2': 'SD2.X',
|
||||
'sd-3': 'SD3.X',
|
||||
sdxl: 'SDXL',
|
||||
'sdxl-refiner': 'SDXLR',
|
||||
};
|
||||
@@ -38,6 +40,11 @@ export const CLIP_SKIP_MAP = {
|
||||
maxClip: 24,
|
||||
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
|
||||
},
|
||||
// TODO: Update this when we have more details on how CLIP SKIP works with SD3
|
||||
'sd-3': {
|
||||
maxClip: 24,
|
||||
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
|
||||
},
|
||||
sdxl: {
|
||||
maxClip: 24,
|
||||
markers: [0, 1, 2, 3, 5, 10, 15, 20, 24],
|
||||
@@ -73,6 +80,7 @@ export const SCHEDULER_OPTIONS: ComboboxOption[] = [
|
||||
{ value: 'heun_k', label: 'Heun Karras' },
|
||||
{ value: 'lms_k', label: 'LMS Karras' },
|
||||
{ value: 'euler_a', label: 'Euler Ancestral' },
|
||||
{ value: 'euler_f', label: 'Euler Flow Match' },
|
||||
{ value: 'kdpm_2_a', label: 'KDPM 2 Ancestral' },
|
||||
{ value: 'lcm', label: 'LCM' },
|
||||
{ value: 'tcd', label: 'TCD' },
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -10,6 +10,7 @@ import {
|
||||
isNonRefinerMainModelConfig,
|
||||
isNonSDXLMainModelConfig,
|
||||
isRefinerMainModelModelConfig,
|
||||
isSD3MainModelModelConfig,
|
||||
isSDXLMainModelModelConfig,
|
||||
isT2IAdapterModelConfig,
|
||||
isTIModelConfig,
|
||||
@@ -35,6 +36,7 @@ export const useMainModels = buildModelsHook(isNonRefinerMainModelConfig);
|
||||
export const useNonSDXLMainModels = buildModelsHook(isNonSDXLMainModelConfig);
|
||||
export const useRefinerModels = buildModelsHook(isRefinerMainModelModelConfig);
|
||||
export const useSDXLModels = buildModelsHook(isSDXLMainModelModelConfig);
|
||||
export const useSD3Models = buildModelsHook(isSD3MainModelModelConfig);
|
||||
export const useLoRAModels = buildModelsHook(isLoRAModelConfig);
|
||||
export const useControlNetAndT2IAdapterModels = buildModelsHook(isControlNetOrT2IAdapterModelConfig);
|
||||
export const useControlNetModels = buildModelsHook(isControlNetModelConfig);
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetBoardAssetsTotalQuery, useGetBoardImagesTotalQuery } from 'services/api/endpoints/boards';
|
||||
|
||||
export const useBoardTotal = (board_id: BoardId) => {
|
||||
const galleryView = useAppSelector((s) => s.gallery.galleryView);
|
||||
|
||||
const { data: totalImages } = useGetBoardImagesTotalQuery(board_id);
|
||||
const { data: totalAssets } = useGetBoardAssetsTotalQuery(board_id);
|
||||
|
||||
const currentViewTotal = useMemo(
|
||||
() => (galleryView === 'images' ? totalImages?.total : totalAssets?.total),
|
||||
[galleryView, totalAssets, totalImages]
|
||||
);
|
||||
|
||||
return { totalImages, totalAssets, currentViewTotal };
|
||||
};
|
||||
File diff suppressed because one or more lines are too long
@@ -1,10 +1,12 @@
|
||||
import type { EntityState } from '@reduxjs/toolkit';
|
||||
import type { components, paths } from 'services/api/schema';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
|
||||
export type S = components['schemas'];
|
||||
|
||||
export type ImageCache = EntityState<ImageDTO, string>;
|
||||
|
||||
export type ListImagesArgs = NonNullable<paths['/api/v1/images/']['get']['parameters']['query']>;
|
||||
export type ListImagesResponse = paths['/api/v1/images/']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
export type DeleteBoardResult =
|
||||
paths['/api/v1/boards/{board_id}']['delete']['responses']['200']['content']['application/json'];
|
||||
@@ -107,7 +109,11 @@ export const isSDXLMainModelModelConfig = (config: AnyModelConfig): config is Ma
|
||||
};
|
||||
|
||||
export const isNonSDXLMainModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2');
|
||||
return config.type === 'main' && (config.base === 'sd-1' || config.base === 'sd-2' || config.base === 'sd-3');
|
||||
};
|
||||
|
||||
export const isSD3MainModelModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
return config.type === 'main' && config.base === 'sd-3';
|
||||
};
|
||||
|
||||
export const isTIModelConfig = (config: AnyModelConfig): config is MainModelConfig => {
|
||||
|
||||
@@ -1,8 +1,56 @@
|
||||
import { createEntityAdapter } from '@reduxjs/toolkit';
|
||||
import { getSelectorsOptions } from 'app/store/createMemoizedSelector';
|
||||
import { dateComparator } from 'common/util/dateComparator';
|
||||
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import queryString from 'query-string';
|
||||
import { buildV1Url } from 'services/api';
|
||||
|
||||
import type { ImageDTO, ListImagesArgs } from './types';
|
||||
import type { ImageCache, ImageDTO, ListImagesArgs } from './types';
|
||||
|
||||
export const getIsImageInDateRange = (data: ImageCache | undefined, imageDTO: ImageDTO) => {
|
||||
if (!data) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const totalCachedImageDtos = imagesSelectors.selectAll(data);
|
||||
|
||||
if (totalCachedImageDtos.length <= 1) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const cachedStarredImages = [];
|
||||
const cachedUnstarredImages = [];
|
||||
|
||||
for (let index = 0; index < totalCachedImageDtos.length; index++) {
|
||||
const image = totalCachedImageDtos[index];
|
||||
if (image?.starred) {
|
||||
cachedStarredImages.push(image);
|
||||
}
|
||||
if (!image?.starred) {
|
||||
cachedUnstarredImages.push(image);
|
||||
}
|
||||
}
|
||||
|
||||
if (imageDTO.starred) {
|
||||
const lastStarredImage = cachedStarredImages[cachedStarredImages.length - 1];
|
||||
// if starring or already starred, want to look in list of starred images
|
||||
if (!lastStarredImage) {
|
||||
return true;
|
||||
} // no starred images showing, so always show this one
|
||||
const createdDate = new Date(imageDTO.created_at);
|
||||
const oldestDate = new Date(lastStarredImage.created_at);
|
||||
return createdDate >= oldestDate;
|
||||
} else {
|
||||
const lastUnstarredImage = cachedUnstarredImages[cachedUnstarredImages.length - 1];
|
||||
// if unstarring or already unstarred, want to look in list of unstarred images
|
||||
if (!lastUnstarredImage) {
|
||||
return false;
|
||||
} // no unstarred images showing, so don't show this one
|
||||
const createdDate = new Date(imageDTO.created_at);
|
||||
const oldestDate = new Date(lastUnstarredImage.created_at);
|
||||
return createdDate >= oldestDate;
|
||||
}
|
||||
};
|
||||
|
||||
export const getCategories = (imageDTO: ImageDTO) => {
|
||||
if (IMAGE_CATEGORIES.includes(imageDTO.image_category)) {
|
||||
@@ -11,6 +59,25 @@ export const getCategories = (imageDTO: ImageDTO) => {
|
||||
return ASSETS_CATEGORIES;
|
||||
};
|
||||
|
||||
// The adapter is not actually the data store - it just provides helper functions to interact
|
||||
// with some other store of data. We will use the RTK Query cache as that store.
|
||||
export const imagesAdapter = createEntityAdapter<ImageDTO, string>({
|
||||
selectId: (image) => image.image_name,
|
||||
sortComparer: (a, b) => {
|
||||
// Compare starred images first
|
||||
if (a.starred && !b.starred) {
|
||||
return -1;
|
||||
}
|
||||
if (!a.starred && b.starred) {
|
||||
return 1;
|
||||
}
|
||||
return dateComparator(b.created_at, a.created_at);
|
||||
},
|
||||
});
|
||||
|
||||
// Create selectors for the adapter.
|
||||
export const imagesSelectors = imagesAdapter.getSelectors(undefined, getSelectorsOptions);
|
||||
|
||||
// Helper to create the url for the listImages endpoint. Also we use it to create the cache key.
|
||||
export const getListImagesUrl = (queryArgs: ListImagesArgs) =>
|
||||
buildV1Url(`images/?${queryString.stringify(queryArgs, { arrayFormat: 'none' })}`);
|
||||
|
||||
@@ -39,6 +39,7 @@ from invokeai.app.invocations.model import (
|
||||
ModelIdentifierField,
|
||||
ModelLoaderOutput,
|
||||
SDXLLoRALoaderOutput,
|
||||
TransformerField,
|
||||
UNetField,
|
||||
UNetOutput,
|
||||
VAEField,
|
||||
@@ -117,6 +118,7 @@ __all__ = [
|
||||
# invokeai.app.invocations.model
|
||||
"ModelIdentifierField",
|
||||
"UNetField",
|
||||
"TransformerField",
|
||||
"CLIPField",
|
||||
"VAEField",
|
||||
"UNetOutput",
|
||||
|
||||
@@ -33,30 +33,32 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
# Core generation dependencies, pinned for reproducible builds.
|
||||
"accelerate==0.30.1",
|
||||
"accelerate",
|
||||
"bitsandbytes",
|
||||
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"compel==2.0.2",
|
||||
"controlnet-aux==0.0.7",
|
||||
"diffusers[torch]==0.27.2",
|
||||
"diffusers[torch]",
|
||||
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
|
||||
"mediapipe==0.10.7", # needed for "mediapipeface" controlnet model
|
||||
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
|
||||
"numpy", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
|
||||
"onnx==1.15.0",
|
||||
"onnxruntime==1.16.3",
|
||||
"opencv-python==4.9.0.80",
|
||||
"pytorch-lightning==2.1.3",
|
||||
"pytorch-lightning",
|
||||
"safetensors==0.4.3",
|
||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||
"torch==2.2.2",
|
||||
"torchmetrics==0.11.4",
|
||||
"torch",
|
||||
"torchmetrics",
|
||||
"torchsde==0.2.6",
|
||||
"torchvision==0.17.2",
|
||||
"transformers==4.41.1",
|
||||
"torchvision",
|
||||
"transformers",
|
||||
"sentencepiece==0.1.99",
|
||||
|
||||
# Core application dependencies, pinned for reproducible builds.
|
||||
"fastapi-events==0.11.0",
|
||||
"fastapi==0.111.0",
|
||||
"huggingface-hub==0.23.1",
|
||||
"huggingface-hub",
|
||||
"pydantic-settings==2.2.1",
|
||||
"pydantic==2.7.2",
|
||||
"python-socketio==5.11.1",
|
||||
@@ -73,7 +75,7 @@ dependencies = [
|
||||
"easing-functions",
|
||||
"einops",
|
||||
"facexlib",
|
||||
"matplotlib", # needed for plotting of Penner easing functions
|
||||
"matplotlib", # needed for plotting of Penner easing functions
|
||||
"npyscreen",
|
||||
"omegaconf",
|
||||
"picklescan",
|
||||
|
||||
Reference in New Issue
Block a user