Compare commits


1 Commit

Author: psychedelicious
SHA1: 4a459e2b17
Message: example of drag indicator for layer/gallery tabs
Date: 2024-10-10 21:17:55 +10:00
83 changed files with 505 additions and 3494 deletions

View File

@@ -5,10 +5,9 @@ from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.board_records.board_records_common import BoardChanges
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])
@@ -116,8 +115,6 @@ async def delete_board(
response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
)
async def list_boards(
order_by: BoardRecordOrderBy = Query(default=BoardRecordOrderBy.CreatedAt, description="The attribute to order by"),
direction: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The direction to order by"),
all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
offset: Optional[int] = Query(default=None, description="The page offset"),
limit: Optional[int] = Query(default=None, description="The number of boards per page"),
@@ -125,9 +122,9 @@ async def list_boards(
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
"""Gets a list of boards"""
if all:
return ApiDependencies.invoker.services.boards.get_all(order_by, direction, include_archived)
return ApiDependencies.invoker.services.boards.get_all(include_archived)
elif offset is not None and limit is not None:
return ApiDependencies.invoker.services.boards.get_many(order_by, direction, offset, limit, include_archived)
return ApiDependencies.invoker.services.boards.get_many(offset, limit, include_archived)
else:
raise HTTPException(
status_code=400,

View File

@@ -88,7 +88,7 @@ async def list_workflows(
default=WorkflowRecordOrderBy.Name, description="The attribute to order by"
),
direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
category: WorkflowCategory = Query(default=WorkflowCategory.User, description="The category of workflow to get"),
category: Optional[WorkflowCategory] = Query(default=None, description="The category of workflow to get"),
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets a page of workflows"""

View File

@@ -192,7 +192,6 @@ class FieldDescriptions:
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
class ImageField(BaseModel):

View File

@@ -1,99 +0,0 @@
from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
class FluxControlNetField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
control_weight: float | list[float] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
instantx_control_mode: int | None = Field(default=-1, description=FieldDescriptions.instantx_control_mode)
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@invocation_output("flux_controlnet_output")
class FluxControlNetOutput(BaseInvocationOutput):
"""FLUX ControlNet info"""
control: FluxControlNetField = OutputField(description=FieldDescriptions.control)
@invocation(
"flux_controlnet",
title="FLUX ControlNet",
tags=["controlnet", "flux"],
category="controlnet",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxControlNetInvocation(BaseInvocation):
"""Collect FLUX ControlNet info to pass to other nodes."""
image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: float | list[float] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)
begin_step_percent: float = InputField(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
# Note: We default to -1 instead of None, because in the workflow editor UI None is not currently supported.
instantx_control_mode: int | None = InputField(default=-1, description=FieldDescriptions.instantx_control_mode)
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> FluxControlNetOutput:
return FluxControlNetOutput(
control=FluxControlNetField(
image=self.image,
control_model=self.control_model,
control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
resize_mode=self.resize_mode,
instantx_control_mode=self.instantx_control_mode,
),
)

View File

@@ -16,16 +16,11 @@ from invokeai.app.invocations.fields import (
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.sampling_utils import (
clip_timestep_schedule_fractional,
@@ -49,7 +44,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="3.1.0",
version="3.0.0",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -92,13 +87,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
control: FluxControlNetField | list[FluxControlNetField] | None = InputField(
default=None, input=Input.Connection, description="ControlNet models."
)
controlnet_vae: VAEField | None = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -179,8 +167,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
inpaint_mask = self._prep_inpaint_mask(context, x)
b, _c, latent_h, latent_w = x.shape
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
b, _c, h, w = x.shape
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
@@ -204,21 +192,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
noise=noise,
)
with ExitStack() as exit_stack:
# Prepare ControlNet extensions.
# Note: We do this before loading the transformer model to minimize peak memory (see implementation).
controlnet_extensions = self._prep_controlnet_extensions(
context=context,
exit_stack=exit_stack,
latent_height=latent_h,
latent_width=latent_w,
dtype=inference_dtype,
device=x.device,
)
# Load the transformer model.
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
with (
transformer_info.model_on_device() as (cached_weights, transformer),
ExitStack() as exit_stack,
):
assert isinstance(transformer, Flux)
config = transformer_info.config
assert config is not None
@@ -263,7 +242,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
step_callback=self._build_step_callback(context),
guidance=self.guidance,
inpaint_extension=inpaint_extension,
controlnet_extensions=controlnet_extensions,
)
x = unpack(x.float(), self.height, self.width)
@@ -310,104 +288,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# `latents`.
return mask.expand_as(latents)
def _prep_controlnet_extensions(
self,
context: InvocationContext,
exit_stack: ExitStack,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
) -> list[XLabsControlNetExtension | InstantXControlNetExtension]:
# Normalize the controlnet input to list[ControlField].
controlnets: list[FluxControlNetField]
if self.control is None:
controlnets = []
elif isinstance(self.control, FluxControlNetField):
controlnets = [self.control]
elif isinstance(self.control, list):
controlnets = self.control
else:
raise ValueError(f"Unsupported controlnet type: {type(self.control)}")
# TODO(ryand): Add a field to the model config so that we can distinguish between XLabs and InstantX ControlNets
# before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
# minimize peak memory.
# First, load the ControlNet models so that we can determine the ControlNet types.
controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
# Calculate the controlnet conditioning tensors.
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
# keep peak memory down.
controlnet_conds: list[torch.Tensor] = []
for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
image = context.images.get_pil(controlnet.image.image_name)
if isinstance(controlnet_model.model, InstantXControlNetFlux):
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
vae_info = context.models.load(self.controlnet_vae.vae)
controlnet_conds.append(
InstantXControlNetExtension.prepare_controlnet_cond(
controlnet_image=image,
vae_info=vae_info,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
resize_mode=controlnet.resize_mode,
)
)
elif isinstance(controlnet_model.model, XLabsControlNetFlux):
controlnet_conds.append(
XLabsControlNetExtension.prepare_controlnet_cond(
controlnet_image=image,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
resize_mode=controlnet.resize_mode,
)
)
# Finally, load the ControlNet models and initialize the ControlNet extensions.
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
for controlnet, controlnet_cond, controlnet_model in zip(
controlnets, controlnet_conds, controlnet_models, strict=True
):
model = exit_stack.enter_context(controlnet_model)
if isinstance(model, XLabsControlNetFlux):
controlnet_extensions.append(
XLabsControlNetExtension(
model=model,
controlnet_cond=controlnet_cond,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
)
)
elif isinstance(model, InstantXControlNetFlux):
instantx_control_mode: torch.Tensor | None = None
if controlnet.instantx_control_mode is not None and controlnet.instantx_control_mode >= 0:
instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long)
instantx_control_mode = instantx_control_mode.reshape([-1, 1])
controlnet_extensions.append(
InstantXControlNetExtension(
model=model,
controlnet_cond=controlnet_cond,
instantx_control_mode=instantx_control_mode,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
)
)
else:
raise ValueError(f"Unsupported ControlNet model type: {type(model)}")
return controlnet_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)

View File

@@ -1,8 +1,7 @@
from abc import ABC, abstractmethod
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecord, BoardRecordOrderBy
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecord
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
class BoardRecordStorageBase(ABC):
@@ -40,19 +39,12 @@ class BoardRecordStorageBase(ABC):
@abstractmethod
def get_many(
self,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
limit: int = 10,
include_archived: bool = False,
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardRecord]:
"""Gets many board records."""
pass
@abstractmethod
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
"""Gets all board records."""
pass

View File

@@ -1,10 +1,8 @@
from datetime import datetime
from enum import Enum
from typing import Optional, Union
from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
@@ -62,13 +60,6 @@ class BoardChanges(BaseModel, extra="forbid"):
archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")
class BoardRecordOrderBy(str, Enum, metaclass=MetaEnum):
"""The order by options for board records"""
CreatedAt = "created_at"
Name = "board_name"
class BoardRecordNotFoundException(Exception):
"""Raised when an board record is not found."""

View File

@@ -8,12 +8,10 @@ from invokeai.app.services.board_records.board_records_common import (
BoardRecord,
BoardRecordDeleteException,
BoardRecordNotFoundException,
BoardRecordOrderBy,
BoardRecordSaveException,
deserialize_board_record,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.util.misc import uuid_string
@@ -146,12 +144,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
return self.get(board_id)
def get_many(
self,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
limit: int = 10,
include_archived: bool = False,
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardRecord]:
try:
self._lock.acquire()
@@ -161,16 +154,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
ORDER BY created_at DESC
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
if include_archived:
archived_filter = ""
else:
archived_filter = "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(archived_filter=archived_filter)
# Execute query to fetch boards
self._cursor.execute(final_query, (limit, offset))
@@ -204,32 +198,23 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
finally:
self._lock.release()
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
try:
self._lock.acquire()
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY created_at DESC
"""
if include_archived:
archived_filter = ""
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
archived_filter = "WHERE archived = 0"
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(archived_filter=archived_filter)
self._cursor.execute(final_query)
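
For context, a minimal sketch (illustrative only, not part of this diff) of the query text the simplified get_many builds for each value of include_archived; the table and column names are taken from the hunk above.

# Illustrative only: mirrors the string formatting used by the simplified get_many above.
base_query = """
    SELECT *
    FROM boards
    {archived_filter}
    ORDER BY created_at DESC
    LIMIT ? OFFSET ?;
"""

for include_archived in (True, False):
    archived_filter = "" if include_archived else "WHERE archived = 0"
    print(f"include_archived={include_archived}:")
    print(base_query.format(archived_filter=archived_filter))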

View File

@@ -1,9 +1,8 @@
from abc import ABC, abstractmethod
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.board_records.board_records_common import BoardChanges
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
class BoardServiceABC(ABC):
@@ -44,19 +43,12 @@ class BoardServiceABC(ABC):
@abstractmethod
def get_many(
self,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
limit: int = 10,
include_archived: bool = False,
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardDTO]:
"""Gets many boards."""
pass
@abstractmethod
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardDTO]:
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
"""Gets all boards."""
pass

View File

@@ -1,9 +1,8 @@
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.board_records.board_records_common import BoardChanges
from invokeai.app.services.boards.boards_base import BoardServiceABC
from invokeai.app.services.boards.boards_common import BoardDTO, board_record_to_dto
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
class BoardService(BoardServiceABC):
@@ -48,16 +47,9 @@ class BoardService(BoardServiceABC):
self.__invoker.services.board_records.delete(board_id)
def get_many(
self,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
limit: int = 10,
include_archived: bool = False,
self, offset: int = 0, limit: int = 10, include_archived: bool = False
) -> OffsetPaginatedResults[BoardDTO]:
board_records = self.__invoker.services.board_records.get_many(
order_by, direction, offset, limit, include_archived
)
board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
board_dtos = []
for r in board_records.items:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
@@ -71,10 +63,8 @@ class BoardService(BoardServiceABC):
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all(order_by, direction, include_archived)
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
board_records = self.__invoker.services.board_records.get_all(include_archived)
board_dtos = []
for r in board_records:
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)

View File

@@ -184,8 +184,7 @@ class ModelInstallService(ModelInstallServiceBase):
) # type: ignore
if preferred_name := config.name:
if model_path.suffix:
preferred_name = f"{preferred_name}.{model_path.suffix}"
preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
dest_path = (
self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
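
A small illustration (hypothetical names, not from the repo) of why the two suffix expressions in the hunk above differ: Path.suffix already includes the leading dot, so the f-string form produces a doubled dot, while Path.with_suffix does not.

from pathlib import Path

model_path = Path("/downloads/some-model.safetensors")  # hypothetical source path
preferred_name = "My Model"                             # hypothetical config name

print(f"{preferred_name}.{model_path.suffix}")               # "My Model..safetensors" (doubled dot)
print(Path(preferred_name).with_suffix(model_path.suffix))   # "My Model.safetensors"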

View File

@@ -41,9 +41,9 @@ class WorkflowRecordsStorageBase(ABC):
self,
order_by: WorkflowRecordOrderBy,
direction: SQLiteDirection,
category: WorkflowCategory,
page: int,
per_page: Optional[int],
category: Optional[WorkflowCategory],
query: Optional[str],
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets many workflows."""

View File

@@ -127,9 +127,9 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
self,
order_by: WorkflowRecordOrderBy,
direction: SQLiteDirection,
category: WorkflowCategory,
page: int = 0,
per_page: Optional[int] = None,
category: Optional[WorkflowCategory] = None,
query: Optional[str] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
try:
@@ -137,7 +137,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
assert category in WorkflowCategory
count_query = "SELECT COUNT(*) FROM workflow_library"
main_query = """
SELECT
@@ -149,16 +148,26 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
updated_at,
opened_at
FROM workflow_library
WHERE category = ?
"""
main_params: list[int | str] = [category.value]
count_params: list[int | str] = [category.value]
main_params: list[int | str] = []
count_params: list[int | str] = []
if category:
assert category in WorkflowCategory
main_query += " WHERE category = ?"
count_query += " WHERE category = ?"
main_params.append(category.value)
count_params.append(category.value)
stripped_query = query.strip() if query else None
if stripped_query:
wildcard_query = "%" + stripped_query + "%"
main_query += " AND name LIKE ? OR description LIKE ? "
count_query += " AND name LIKE ? OR description LIKE ?;"
if "WHERE" in main_query:
main_query += " AND (name LIKE ? OR description LIKE ?)"
count_query += " AND (name LIKE ? OR description LIKE ?)"
else:
main_query += " WHERE name LIKE ? OR description LIKE ?"
count_query += " WHERE name LIKE ? OR description LIKE ?"
main_params.extend([wildcard_query, wildcard_query])
count_params.extend([wildcard_query, wildcard_query])
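
A sketch (illustrative inputs, not part of the diff) of the WHERE-clause construction above, showing why the added parentheses matter when a category filter and a text query are combined.

# Hypothetical inputs for illustration.
category_value = "user"
text_query = "portrait"

main_query = "SELECT * FROM workflow_library"
main_params: list[str] = []

if category_value:
    main_query += " WHERE category = ?"
    main_params.append(category_value)

if text_query:
    wildcard = f"%{text_query}%"
    if "WHERE" in main_query:
        # Without the parentheses, "... AND name LIKE ? OR description LIKE ?" lets the OR
        # escape the category filter; the grouped form keeps the category constraint intact.
        main_query += " AND (name LIKE ? OR description LIKE ?)"
    else:
        main_query += " WHERE name LIKE ? OR description LIKE ?"
    main_params.extend([wildcard, wildcard])

print(main_query, main_params)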

View File

@@ -1,58 +0,0 @@
from dataclasses import dataclass
import torch
@dataclass
class ControlNetFluxOutput:
single_block_residuals: list[torch.Tensor] | None
double_block_residuals: list[torch.Tensor] | None
def apply_weight(self, weight: float):
if self.single_block_residuals is not None:
for i in range(len(self.single_block_residuals)):
self.single_block_residuals[i] = self.single_block_residuals[i] * weight
if self.double_block_residuals is not None:
for i in range(len(self.double_block_residuals)):
self.double_block_residuals[i] = self.double_block_residuals[i] * weight
def add_tensor_lists_elementwise(
list1: list[torch.Tensor] | None, list2: list[torch.Tensor] | None
) -> list[torch.Tensor] | None:
"""Add two tensor lists elementwise that could be None."""
if list1 is None and list2 is None:
return None
if list1 is None:
return list2
if list2 is None:
return list1
new_list: list[torch.Tensor] = []
for list1_tensor, list2_tensor in zip(list1, list2, strict=True):
new_list.append(list1_tensor + list2_tensor)
return new_list
def add_controlnet_flux_outputs(
controlnet_output_1: ControlNetFluxOutput, controlnet_output_2: ControlNetFluxOutput
) -> ControlNetFluxOutput:
return ControlNetFluxOutput(
single_block_residuals=add_tensor_lists_elementwise(
controlnet_output_1.single_block_residuals, controlnet_output_2.single_block_residuals
),
double_block_residuals=add_tensor_lists_elementwise(
controlnet_output_1.double_block_residuals, controlnet_output_2.double_block_residuals
),
)
def sum_controlnet_flux_outputs(
controlnet_outputs: list[ControlNetFluxOutput],
) -> ControlNetFluxOutput:
controlnet_output_sum = ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
for controlnet_output in controlnet_outputs:
controlnet_output_sum = add_controlnet_flux_outputs(controlnet_output_sum, controlnet_output)
return controlnet_output_sum
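
A minimal usage sketch (assumed tensor shapes) of the helpers above: each ControlNet's output is scaled by its weight and the residual lists are then summed elementwise across ControlNets. The import path follows this file's location before the deletion.

import torch

from invokeai.backend.flux.controlnet.controlnet_flux_output import (
    ControlNetFluxOutput,
    sum_controlnet_flux_outputs,
)

out_a = ControlNetFluxOutput(
    single_block_residuals=None,
    double_block_residuals=[torch.ones(1, 4), torch.ones(1, 4)],
)
out_b = ControlNetFluxOutput(
    single_block_residuals=[torch.ones(1, 4)],
    double_block_residuals=[torch.ones(1, 4), torch.ones(1, 4)],
)

out_a.apply_weight(0.5)  # scale the first ControlNet's contribution
merged = sum_controlnet_flux_outputs([out_a, out_b])
# Each merged double-block residual is 0.5 + 1.0 elementwise; the single-block
# residuals come from out_b alone, because out_a has None there.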

View File

@@ -1,180 +0,0 @@
# This file was initially copied from:
# https://github.com/huggingface/diffusers/blob/99f608218caa069a2f16dcf9efab46959b15aec0/src/diffusers/models/controlnet_flux.py
from dataclasses import dataclass
import torch
import torch.nn as nn
from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass
class InstantXControlNetFluxOutput:
controlnet_block_samples: list[torch.Tensor] | None
controlnet_single_block_samples: list[torch.Tensor] | None
# NOTE(ryand): Mapping between diffusers FLUX transformer params and BFL FLUX transformer params:
# - Diffusers: BFL
# - in_channels: in_channels
# - num_layers: depth
# - num_single_layers: depth_single_blocks
# - attention_head_dim: hidden_size // num_heads
# - num_attention_heads: num_heads
# - joint_attention_dim: context_in_dim
# - pooled_projection_dim: vec_in_dim
# - guidance_embeds: guidance_embed
# - axes_dims_rope: axes_dim
class InstantXControlNetFlux(torch.nn.Module):
def __init__(self, params: FluxParams, num_control_modes: int | None = None):
"""
Args:
params (FluxParams): The parameters for the FLUX model.
num_control_modes (int | None, optional): The number of controlnet modes. If non-None, then the model is a
'union controlnet' model and expects a mode conditioning input at runtime.
"""
super().__init__()
# The following modules mirror the base FLUX transformer model.
# -------------------------------------------------------------
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(params.depth_single_blocks)
]
)
# The following modules are specific to the ControlNet model.
# -----------------------------------------------------------
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.double_blocks)):
self.controlnet_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(len(self.single_blocks)):
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))
self.is_union = False
if num_control_modes is not None:
self.is_union = True
self.controlnet_mode_embedder = nn.Embedding(num_control_modes, self.hidden_size)
self.controlnet_x_embedder = zero_module(torch.nn.Linear(self.in_channels, self.hidden_size))
def forward(
self,
controlnet_cond: torch.Tensor,
controlnet_mode: torch.Tensor | None,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
timesteps: torch.Tensor,
y: torch.Tensor,
guidance: torch.Tensor | None = None,
) -> InstantXControlNetFluxOutput:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
img = self.img_in(img)
# Add controlnet_cond embedding.
img = img + self.controlnet_x_embedder(controlnet_cond)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
# If this is a union ControlNet, then concat the control mode embedding to the T5 text embedding.
if self.is_union:
if controlnet_mode is None:
# We allow users to enter 'None' as the controlnet_mode if they don't want to worry about this input.
# We've chosen to use a zero-embedding in this case.
zero_index = torch.zeros([1, 1], dtype=torch.long, device=txt.device)
controlnet_mode_emb = torch.zeros_like(self.controlnet_mode_embedder(zero_index))
else:
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
txt = torch.cat([controlnet_mode_emb, txt], dim=1)
txt_ids = torch.cat([txt_ids[:, :1, :], txt_ids], dim=1)
else:
assert controlnet_mode is None
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
double_block_samples: list[torch.Tensor] = []
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
double_block_samples.append(img)
img = torch.cat((txt, img), 1)
single_block_samples: list[torch.Tensor] = []
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
single_block_samples.append(img[:, txt.shape[1] :])
# ControlNet Block
controlnet_double_block_samples: list[torch.Tensor] = []
for double_block_sample, controlnet_block in zip(double_block_samples, self.controlnet_blocks, strict=True):
double_block_sample = controlnet_block(double_block_sample)
controlnet_double_block_samples.append(double_block_sample)
controlnet_single_block_samples: list[torch.Tensor] = []
for single_block_sample, controlnet_block in zip(
single_block_samples, self.controlnet_single_blocks, strict=True
):
single_block_sample = controlnet_block(single_block_sample)
controlnet_single_block_samples.append(single_block_sample)
return InstantXControlNetFluxOutput(
controlnet_block_samples=controlnet_double_block_samples or None,
controlnet_single_block_samples=controlnet_single_block_samples or None,
)
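
For reference, a hedged sketch of the parameter mapping described in the NOTE above, translating a hypothetical diffusers-style config into FluxParams. The mlp_ratio, theta, and qkv_bias values are assumptions; they are not part of the mapping in the note.

from invokeai.backend.flux.model import FluxParams

# Hypothetical diffusers-style config values, for illustration only.
diffusers_config = {
    "in_channels": 64,
    "num_layers": 19,
    "num_single_layers": 38,
    "attention_head_dim": 128,
    "num_attention_heads": 24,
    "joint_attention_dim": 4096,
    "pooled_projection_dim": 768,
    "guidance_embeds": True,
    "axes_dims_rope": [16, 56, 56],
}

params = FluxParams(
    in_channels=diffusers_config["in_channels"],
    vec_in_dim=diffusers_config["pooled_projection_dim"],
    context_in_dim=diffusers_config["joint_attention_dim"],
    # attention_head_dim maps to hidden_size // num_heads, so hidden_size is their product.
    hidden_size=diffusers_config["attention_head_dim"] * diffusers_config["num_attention_heads"],
    mlp_ratio=4.0,  # assumed typical value; not covered by the mapping above
    num_heads=diffusers_config["num_attention_heads"],
    depth=diffusers_config["num_layers"],
    depth_single_blocks=diffusers_config["num_single_layers"],
    axes_dim=diffusers_config["axes_dims_rope"],
    theta=10_000,   # assumed; not covered by the mapping above
    qkv_bias=True,  # assumed; not covered by the mapping above
    guidance_embed=diffusers_config["guidance_embeds"],
)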

View File

@@ -1,295 +0,0 @@
from typing import Any, Dict
import torch
from invokeai.backend.flux.model import FluxParams
def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool:
"""Is the state dict for an XLabs ControlNet model?
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
"""
# If all of the expected keys are present, then this is very likely an XLabs ControlNet model.
expected_keys = {
"controlnet_blocks.0.bias",
"controlnet_blocks.0.weight",
"input_hint_block.0.bias",
"input_hint_block.0.weight",
"pos_embed_input.bias",
"pos_embed_input.weight",
}
if expected_keys.issubset(sd.keys()):
return True
return False
def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
"""Is the state dict for an InstantX ControlNet model?
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
"""
# If all of the expected keys are present, then this is very likely an InstantX ControlNet model.
expected_keys = {
"controlnet_blocks.0.bias",
"controlnet_blocks.0.weight",
"controlnet_x_embedder.bias",
"controlnet_x_embedder.weight",
}
if expected_keys.issubset(sd.keys()):
return True
return False
def _fuse_weights(*t: torch.Tensor) -> torch.Tensor:
"""Fuse weights along dimension 0.
Used to fuse q, k, v attention weights into a single qkv tensor when converting from diffusers to BFL format.
"""
# TODO(ryand): Double check dim=0 is correct.
return torch.cat(t, dim=0)
def _convert_flux_double_block_sd_from_diffusers_to_bfl_format(
sd: Dict[str, torch.Tensor], double_block_index: int
) -> Dict[str, torch.Tensor]:
"""Convert the state dict for a double block from diffusers format to BFL format."""
to_prefix = f"double_blocks.{double_block_index}"
from_prefix = f"transformer_blocks.{double_block_index}"
new_sd: dict[str, torch.Tensor] = {}
# Check one key to determine if this block exists.
if f"{from_prefix}.attn.add_q_proj.bias" not in sd:
return new_sd
# txt_attn.qkv
new_sd[f"{to_prefix}.txt_attn.qkv.bias"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.add_q_proj.bias"),
sd.pop(f"{from_prefix}.attn.add_k_proj.bias"),
sd.pop(f"{from_prefix}.attn.add_v_proj.bias"),
)
new_sd[f"{to_prefix}.txt_attn.qkv.weight"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.add_q_proj.weight"),
sd.pop(f"{from_prefix}.attn.add_k_proj.weight"),
sd.pop(f"{from_prefix}.attn.add_v_proj.weight"),
)
# img_attn.qkv
new_sd[f"{to_prefix}.img_attn.qkv.bias"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.to_q.bias"),
sd.pop(f"{from_prefix}.attn.to_k.bias"),
sd.pop(f"{from_prefix}.attn.to_v.bias"),
)
new_sd[f"{to_prefix}.img_attn.qkv.weight"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.to_q.weight"),
sd.pop(f"{from_prefix}.attn.to_k.weight"),
sd.pop(f"{from_prefix}.attn.to_v.weight"),
)
# Handle basic 1-to-1 key conversions.
key_map = {
# img_attn
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.to_out.0.weight": "img_attn.proj.weight",
"attn.to_out.0.bias": "img_attn.proj.bias",
# img_mlp
"ff.net.0.proj.weight": "img_mlp.0.weight",
"ff.net.0.proj.bias": "img_mlp.0.bias",
"ff.net.2.weight": "img_mlp.2.weight",
"ff.net.2.bias": "img_mlp.2.bias",
# img_mod
"norm1.linear.weight": "img_mod.lin.weight",
"norm1.linear.bias": "img_mod.lin.bias",
# txt_attn
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
"attn.to_add_out.weight": "txt_attn.proj.weight",
"attn.to_add_out.bias": "txt_attn.proj.bias",
# txt_mlp
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
"ff_context.net.2.weight": "txt_mlp.2.weight",
"ff_context.net.2.bias": "txt_mlp.2.bias",
# txt_mod
"norm1_context.linear.weight": "txt_mod.lin.weight",
"norm1_context.linear.bias": "txt_mod.lin.bias",
}
for from_key, to_key in key_map.items():
new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")
return new_sd
def _convert_flux_single_block_sd_from_diffusers_to_bfl_format(
sd: Dict[str, torch.Tensor], single_block_index: int
) -> Dict[str, torch.Tensor]:
"""Convert the state dict for a single block from diffusers format to BFL format."""
to_prefix = f"single_blocks.{single_block_index}"
from_prefix = f"single_transformer_blocks.{single_block_index}"
new_sd: dict[str, torch.Tensor] = {}
# Check one key to determine if this block exists.
if f"{from_prefix}.attn.to_q.bias" not in sd:
return new_sd
# linear1 (qkv)
new_sd[f"{to_prefix}.linear1.bias"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.to_q.bias"),
sd.pop(f"{from_prefix}.attn.to_k.bias"),
sd.pop(f"{from_prefix}.attn.to_v.bias"),
sd.pop(f"{from_prefix}.proj_mlp.bias"),
)
new_sd[f"{to_prefix}.linear1.weight"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.to_q.weight"),
sd.pop(f"{from_prefix}.attn.to_k.weight"),
sd.pop(f"{from_prefix}.attn.to_v.weight"),
sd.pop(f"{from_prefix}.proj_mlp.weight"),
)
# Handle basic 1-to-1 key conversions.
key_map = {
# linear2
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
# modulation
"norm.linear.weight": "modulation.lin.weight",
"norm.linear.bias": "modulation.lin.bias",
# norm
"attn.norm_k.weight": "norm.key_norm.scale",
"attn.norm_q.weight": "norm.query_norm.scale",
}
for from_key, to_key in key_map.items():
new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")
return new_sd
def convert_diffusers_instantx_state_dict_to_bfl_format(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Convert an InstantX ControlNet state dict to the format that can be loaded by our internal
InstantXControlNetFlux model.
The original InstantX ControlNet model was developed to be used in diffusers. We have ported the original
implementation to InstantXControlNetFlux to make it compatible with BFL-style models. This function converts the
original state dict to the format expected by InstantXControlNetFlux.
"""
# Shallow copy sd so that we can pop keys from it without modifying the original.
sd = sd.copy()
new_sd: dict[str, torch.Tensor] = {}
# Handle basic 1-to-1 key conversions.
basic_key_map = {
# Base model keys.
# ----------------
# txt_in keys.
"context_embedder.bias": "txt_in.bias",
"context_embedder.weight": "txt_in.weight",
# guidance_in MLPEmbedder keys.
"time_text_embed.guidance_embedder.linear_1.bias": "guidance_in.in_layer.bias",
"time_text_embed.guidance_embedder.linear_1.weight": "guidance_in.in_layer.weight",
"time_text_embed.guidance_embedder.linear_2.bias": "guidance_in.out_layer.bias",
"time_text_embed.guidance_embedder.linear_2.weight": "guidance_in.out_layer.weight",
# vector_in MLPEmbedder keys.
"time_text_embed.text_embedder.linear_1.bias": "vector_in.in_layer.bias",
"time_text_embed.text_embedder.linear_1.weight": "vector_in.in_layer.weight",
"time_text_embed.text_embedder.linear_2.bias": "vector_in.out_layer.bias",
"time_text_embed.text_embedder.linear_2.weight": "vector_in.out_layer.weight",
# time_in MLPEmbedder keys.
"time_text_embed.timestep_embedder.linear_1.bias": "time_in.in_layer.bias",
"time_text_embed.timestep_embedder.linear_1.weight": "time_in.in_layer.weight",
"time_text_embed.timestep_embedder.linear_2.bias": "time_in.out_layer.bias",
"time_text_embed.timestep_embedder.linear_2.weight": "time_in.out_layer.weight",
# img_in keys.
"x_embedder.bias": "img_in.bias",
"x_embedder.weight": "img_in.weight",
}
for old_key, new_key in basic_key_map.items():
v = sd.pop(old_key, None)
if v is not None:
new_sd[new_key] = v
# Handle the double_blocks.
block_index = 0
while True:
converted_double_block_sd = _convert_flux_double_block_sd_from_diffusers_to_bfl_format(sd, block_index)
if len(converted_double_block_sd) == 0:
break
new_sd.update(converted_double_block_sd)
block_index += 1
# Handle the single_blocks.
block_index = 0
while True:
converted_single_block_sd = _convert_flux_single_block_sd_from_diffusers_to_bfl_format(sd, block_index)
if len(converted_single_block_sd) == 0:
break
new_sd.update(converted_single_block_sd)
block_index += 1
# Transfer controlnet keys as-is.
for k in list(sd.keys()):
if k.startswith("controlnet_"):
new_sd[k] = sd.pop(k)
# Assert that all keys have been handled.
assert len(sd) == 0
return new_sd
def infer_flux_params_from_state_dict(sd: Dict[str, torch.Tensor]) -> FluxParams:
"""Infer the FluxParams from the shape of a FLUX state dict. When a model is distributed in diffusers format, this
information is all contained in the config.json file that accompanies the model. However, being able to infer the
params from the state dict enables us to load models (e.g. an InstantX ControlNet) from a single weight file.
"""
hidden_size = sd["img_in.weight"].shape[0]
mlp_hidden_dim = sd["double_blocks.0.img_mlp.0.weight"].shape[0]
# mlp_ratio is a float, but we treat it as an int here to avoid having to think about possible float precision
# issues. In practice, mlp_ratio is usually 4.
mlp_ratio = mlp_hidden_dim // hidden_size
head_dim = sd["double_blocks.0.img_attn.norm.query_norm.scale"].shape[0]
num_heads = hidden_size // head_dim
# Count the number of double blocks.
double_block_index = 0
while f"double_blocks.{double_block_index}.img_attn.qkv.weight" in sd:
double_block_index += 1
# Count the number of single blocks.
single_block_index = 0
while f"single_blocks.{single_block_index}.linear1.weight" in sd:
single_block_index += 1
return FluxParams(
in_channels=sd["img_in.weight"].shape[1],
vec_in_dim=sd["vector_in.in_layer.weight"].shape[1],
context_in_dim=sd["txt_in.weight"].shape[1],
hidden_size=hidden_size,
mlp_ratio=mlp_ratio,
num_heads=num_heads,
depth=double_block_index,
depth_single_blocks=single_block_index,
# axes_dim cannot be inferred from the state dict. The hard-coded value is correct for dev/schnell models.
axes_dim=[16, 56, 56],
# theta cannot be inferred from the state dict. The hard-coded value is correct for dev/schnell models.
theta=10_000,
qkv_bias="double_blocks.0.img_attn.qkv.bias" in sd,
guidance_embed="guidance_in.in_layer.weight" in sd,
)
def infer_instantx_num_control_modes_from_state_dict(sd: Dict[str, torch.Tensor]) -> int | None:
"""Infer the number of ControlNet Union modes from the shape of a InstantX ControlNet state dict.
Returns None if the model is not a ControlNet Union model. Otherwise returns the number of modes.
"""
mode_embedder_key = "controlnet_mode_embedder.weight"
if mode_embedder_key not in sd:
return None
return sd[mode_embedder_key].shape[0]
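
A small shape sketch (toy sizes) of the q/k/v fusion performed by _fuse_weights above: three diffusers projection weights of shape (hidden, hidden) are concatenated into one BFL qkv weight of shape (3 * hidden, hidden).

import torch

hidden = 8  # toy hidden size; real FLUX models use e.g. 3072

to_q = torch.randn(hidden, hidden)
to_k = torch.randn(hidden, hidden)
to_v = torch.randn(hidden, hidden)

qkv = torch.cat([to_q, to_k, to_v], dim=0)  # the same concatenation _fuse_weights performs
assert qkv.shape == (3 * hidden, hidden)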

View File

@@ -1,130 +0,0 @@
# This file was initially based on:
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/controlnet.py
from dataclasses import dataclass
import torch
from einops import rearrange
from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding
@dataclass
class XLabsControlNetFluxOutput:
controlnet_double_block_residuals: list[torch.Tensor] | None
class XLabsControlNetFlux(torch.nn.Module):
"""A ControlNet model for FLUX.
The architecture is very similar to the base FLUX model, with the following differences:
- A `controlnet_depth` parameter is passed to control the number of double_blocks that the ControlNet is applied to.
In order to keep the ControlNet small, this is typically much less than the depth of the base FLUX model.
- There is a set of `controlnet_blocks` that are applied to the output of each double_block.
"""
def __init__(self, params: FluxParams, controlnet_depth: int = 2):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = torch.nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else torch.nn.Identity()
)
self.txt_in = torch.nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = torch.nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(controlnet_depth)
]
)
# Add ControlNet blocks.
self.controlnet_blocks = torch.nn.ModuleList([])
for _ in range(controlnet_depth):
controlnet_block = torch.nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = torch.nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.input_hint_block = torch.nn.Sequential(
torch.nn.Conv2d(3, 16, 3, padding=1),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
torch.nn.SiLU(),
zero_module(torch.nn.Conv2d(16, 16, 3, padding=1)),
)
def forward(
self,
img: torch.Tensor,
img_ids: torch.Tensor,
controlnet_cond: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
timesteps: torch.Tensor,
y: torch.Tensor,
guidance: torch.Tensor | None = None,
) -> XLabsControlNetFluxOutput:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
controlnet_cond = self.input_hint_block(controlnet_cond)
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
block_res_samples: list[torch.Tensor] = []
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
block_res_samples.append(img)
controlnet_block_res_samples: list[torch.Tensor] = []
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks, strict=True):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples.append(block_res_sample)
return XLabsControlNetFluxOutput(controlnet_double_block_residuals=controlnet_block_res_samples)

View File

@@ -1,12 +0,0 @@
from typing import TypeVar
import torch
T = TypeVar("T", bound=torch.nn.Module)
def zero_module(module: T) -> T:
"""Initialize the parameters of a module to zero."""
for p in module.parameters():
torch.nn.init.zeros_(p)
return module
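
A usage sketch of zero_module above: a zero-initialized projection makes a freshly attached ControlNet branch start as a no-op, so it does not perturb the base model's output until its weights move away from zero. The import path follows this file's location before the deletion.

import torch

from invokeai.backend.flux.controlnet.zero_module import zero_module

proj = zero_module(torch.nn.Linear(16, 16))
x = torch.randn(2, 16)
assert torch.all(proj(x) == 0)  # zero weight and zero bias, so the layer contributes nothing yet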

View File

@@ -3,10 +3,7 @@ from typing import Callable
import torch
from tqdm import tqdm
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -24,7 +21,6 @@ def denoise(
step_callback: Callable[[PipelineIntermediateState], None],
guidance: float,
inpaint_extension: InpaintExtension | None,
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -42,30 +38,6 @@ def denoise(
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# Run ControlNet models.
controlnet_residuals: list[ControlNetFluxOutput] = []
for controlnet_extension in controlnet_extensions:
controlnet_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=step - 1,
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
)
# Merge the ControlNet residuals from multiple ControlNets.
# TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind that the
# controlnet_residuals data structure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
pred = model(
img=img,
img_ids=img_ids,
@@ -74,8 +46,6 @@ def denoise(
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
)
preview_img = img - t_curr * pred

View File

@@ -1,45 +0,0 @@
import math
from abc import ABC, abstractmethod
from typing import List, Union
import torch
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
class BaseControlNetExtension(ABC):
def __init__(
self,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
self._weight = weight
self._begin_step_percent = begin_step_percent
self._end_step_percent = end_step_percent
def _get_weight(self, timestep_index: int, total_num_timesteps: int) -> float:
first_step = math.floor(self._begin_step_percent * total_num_timesteps)
last_step = math.ceil(self._end_step_percent * total_num_timesteps)
if timestep_index < first_step or timestep_index > last_step:
return 0.0
if isinstance(self._weight, list):
return self._weight[timestep_index]
return self._weight
@abstractmethod
def run_controlnet(
self,
timestep_index: int,
total_num_timesteps: int,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> ControlNetFluxOutput: ...
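
A worked example (hypothetical numbers) of the step gating in _get_weight above: with 30 denoising steps, begin_step_percent=0.2 and end_step_percent=0.8, the weight applies from step floor(0.2 * 30) = 6 through step ceil(0.8 * 30) = 24 and is 0.0 outside that window.

import math

weight = 0.75
begin_step_percent, end_step_percent = 0.2, 0.8
total_num_timesteps = 30

first_step = math.floor(begin_step_percent * total_num_timesteps)  # 6
last_step = math.ceil(end_step_percent * total_num_timesteps)      # 24

for timestep_index in (0, 6, 15, 24, 29):
    applied = weight if first_step <= timestep_index <= last_step else 0.0
    print(timestep_index, applied)  # 0 -> 0.0, 6 -> 0.75, 15 -> 0.75, 24 -> 0.75, 29 -> 0.0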

View File

@@ -1,194 +0,0 @@
import math
from typing import List, Union
import torch
from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import (
InstantXControlNetFlux,
InstantXControlNetFluxOutput,
)
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
from invokeai.backend.flux.sampling_utils import pack
from invokeai.backend.model_manager.load.load_base import LoadedModel
class InstantXControlNetExtension(BaseControlNetExtension):
def __init__(
self,
model: InstantXControlNetFlux,
controlnet_cond: torch.Tensor,
instantx_control_mode: torch.Tensor | None,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
super().__init__(
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
self._model = model
# The VAE-encoded and 'packed' control image to pass to the ControlNet model.
self._controlnet_cond = controlnet_cond
# TODO(ryand): Should we define an enum for the instantx_control_mode? Is it likely to change for future models?
# The control mode for InstantX ControlNet union models.
# See the values defined here: https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union#control-mode
# Expected shape: (batch_size, 1), Expected dtype: torch.long
# If None, a zero-embedding will be used.
self._instantx_control_mode = instantx_control_mode
# TODO(ryand): Pass in these params if a new base transformer / InstantX ControlNet pair get released.
self._flux_transformer_num_double_blocks = 19
self._flux_transformer_num_single_blocks = 38
@classmethod
def prepare_controlnet_cond(
cls,
controlnet_image: Image,
vae_info: LoadedModel,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
resize_mode: CONTROLNET_RESIZE_VALUES,
):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
resized_controlnet_image = prepare_control_image(
image=controlnet_image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=device,
dtype=dtype,
control_mode="balanced",
resize_mode=resize_mode,
)
# Shift the image from [0, 1] to [-1, 1].
resized_controlnet_image = resized_controlnet_image * 2 - 1
# Run VAE encoder.
controlnet_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized_controlnet_image)
controlnet_cond = pack(controlnet_cond)
return controlnet_cond
@classmethod
def from_controlnet_image(
cls,
model: InstantXControlNetFlux,
controlnet_image: Image,
instantx_control_mode: torch.Tensor | None,
vae_info: LoadedModel,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
resize_mode: CONTROLNET_RESIZE_VALUES,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
resized_controlnet_image = prepare_control_image(
image=controlnet_image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=device,
dtype=dtype,
control_mode="balanced",
resize_mode=resize_mode,
)
# Shift the image from [0, 1] to [-1, 1].
resized_controlnet_image = resized_controlnet_image * 2 - 1
# Run VAE encoder.
controlnet_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized_controlnet_image)
controlnet_cond = pack(controlnet_cond)
return cls(
model=model,
controlnet_cond=controlnet_cond,
instantx_control_mode=instantx_control_mode,
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
def _instantx_output_to_controlnet_output(
self, instantx_output: InstantXControlNetFluxOutput
) -> ControlNetFluxOutput:
# The `interval_control` logic here is based on
# https://github.com/huggingface/diffusers/blob/31058cdaef63ca660a1a045281d156239fba8192/src/diffusers/models/transformers/transformer_flux.py#L507-L511
# Handle double block residuals.
double_block_residuals: list[torch.Tensor] = []
double_block_samples = instantx_output.controlnet_block_samples
if double_block_samples:
interval_control = self._flux_transformer_num_double_blocks / len(double_block_samples)
interval_control = int(math.ceil(interval_control))
for i in range(self._flux_transformer_num_double_blocks):
double_block_residuals.append(double_block_samples[i // interval_control])
# Handle single block residuals.
single_block_residuals: list[torch.Tensor] = []
single_block_samples = instantx_output.controlnet_single_block_samples
if single_block_samples:
interval_control = self._flux_transformer_num_single_blocks / len(single_block_samples)
interval_control = int(math.ceil(interval_control))
for i in range(self._flux_transformer_num_single_blocks):
single_block_residuals.append(single_block_samples[i // interval_control])
return ControlNetFluxOutput(
double_block_residuals=double_block_residuals or None,
single_block_residuals=single_block_residuals or None,
)
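For intuition, a small worked example of the interval mapping above (the sample count is an assumption for illustration): if the InstantX model returns 5 double-block samples for a transformer with 19 double blocks, interval_control = ceil(19 / 5) = 4, so transformer blocks 0-3 reuse sample 0, blocks 4-7 reuse sample 1, and so on.

import math

num_transformer_blocks = 19          # matches self._flux_transformer_num_double_blocks
controlnet_samples = list(range(5))  # stand-ins for 5 residual tensors
interval_control = int(math.ceil(num_transformer_blocks / len(controlnet_samples)))
mapping = [controlnet_samples[i // interval_control] for i in range(num_transformer_blocks)]
# mapping == [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4]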
def run_controlnet(
self,
timestep_index: int,
total_num_timesteps: int,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> ControlNetFluxOutput:
weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
if weight < 1e-6:
return ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
# Make sure inputs have correct device and dtype.
self._controlnet_cond = self._controlnet_cond.to(device=img.device, dtype=img.dtype)
self._instantx_control_mode = (
self._instantx_control_mode.to(device=img.device) if self._instantx_control_mode is not None else None
)
instantx_output: InstantXControlNetFluxOutput = self._model(
controlnet_cond=self._controlnet_cond,
controlnet_mode=self._instantx_control_mode,
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
timesteps=timesteps,
y=y,
guidance=guidance,
)
controlnet_output = self._instantx_output_to_controlnet_output(instantx_output)
controlnet_output.apply_weight(weight)
return controlnet_output
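The base class's _get_weight is not shown in this diff. As a rough sketch of the gating it implies, based only on the begin_step_percent / end_step_percent parameters and the 1e-6 cutoff above (an assumption, not the actual implementation, and covering only the scalar-weight case):

def _get_weight_sketch(
    weight: float,
    begin_step_percent: float,
    end_step_percent: float,
    timestep_index: int,
    total_num_timesteps: int,
) -> float:
    # Fraction of the denoising process completed at this step.
    step_percent = timestep_index / max(total_num_timesteps - 1, 1)
    if step_percent < begin_step_percent or step_percent > end_step_percent:
        # Outside the active range; run_controlnet() short-circuits on weights < 1e-6.
        return 0.0
    return weight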

View File

@@ -1,150 +0,0 @@
from typing import List, Union
import torch
from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux, XLabsControlNetFluxOutput
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
class XLabsControlNetExtension(BaseControlNetExtension):
def __init__(
self,
model: XLabsControlNetFlux,
controlnet_cond: torch.Tensor,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
super().__init__(
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
self._model = model
# _controlnet_cond is the control image passed to the ControlNet model.
# Pixel values are in the range [-1, 1]. Shape: (batch_size, 3, height, width).
self._controlnet_cond = controlnet_cond
# TODO(ryand): Pass in these params if a new base transformer / XLabs ControlNet pair get released.
self._flux_transformer_num_double_blocks = 19
self._flux_transformer_num_single_blocks = 38
@classmethod
def prepare_controlnet_cond(
cls,
controlnet_image: Image,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
resize_mode: CONTROLNET_RESIZE_VALUES,
):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
controlnet_cond = prepare_control_image(
image=controlnet_image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=device,
dtype=dtype,
control_mode="balanced",
resize_mode=resize_mode,
)
# Map pixel values from [0, 1] to [-1, 1].
controlnet_cond = controlnet_cond * 2 - 1
return controlnet_cond
@classmethod
def from_controlnet_image(
cls,
model: XLabsControlNetFlux,
controlnet_image: Image,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
resize_mode: CONTROLNET_RESIZE_VALUES,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
controlnet_cond = prepare_control_image(
image=controlnet_image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=device,
dtype=dtype,
control_mode="balanced",
resize_mode=resize_mode,
)
# Map pixel values from [0, 1] to [-1, 1].
controlnet_cond = controlnet_cond * 2 - 1
return cls(
model=model,
controlnet_cond=controlnet_cond,
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
def _xlabs_output_to_controlnet_output(self, xlabs_output: XLabsControlNetFluxOutput) -> ControlNetFluxOutput:
# The modulo index logic used here is based on:
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/model.py#L198-L200
# Handle double block residuals.
double_block_residuals: list[torch.Tensor] = []
xlabs_double_block_residuals = xlabs_output.controlnet_double_block_residuals
if xlabs_double_block_residuals is not None:
for i in range(self._flux_transformer_num_double_blocks):
double_block_residuals.append(xlabs_double_block_residuals[i % len(xlabs_double_block_residuals)])
return ControlNetFluxOutput(
double_block_residuals=double_block_residuals,
single_block_residuals=None,
)
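For concreteness, a short illustration of the modulo indexing above (the residual count is assumed for the example): an XLabs ControlNet producing 2 double-block residuals has them cycled across all 19 transformer double blocks, so block i receives residual i % 2.

xlabs_residuals = ["r0", "r1"]  # stand-ins for 2 residual tensors
num_transformer_blocks = 19
mapping = [xlabs_residuals[i % len(xlabs_residuals)] for i in range(num_transformer_blocks)]
# mapping alternates r0, r1, r0, r1, ... and ends on r0 because 18 % 2 == 0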
def run_controlnet(
self,
timestep_index: int,
total_num_timesteps: int,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> ControlNetFluxOutput:
weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
if weight < 1e-6:
return ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
xlabs_output: XLabsControlNetFluxOutput = self._model(
img=img,
img_ids=img_ids,
controlnet_cond=self._controlnet_cond,
txt=txt,
txt_ids=txt_ids,
timesteps=timesteps,
y=y,
guidance=guidance,
)
controlnet_output = self._xlabs_output_to_controlnet_output(xlabs_output)
controlnet_output.apply_weight(weight)
return controlnet_output

View File

@@ -87,9 +87,7 @@ class Flux(nn.Module):
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None,
controlnet_double_block_residuals: list[Tensor] | None,
controlnet_single_block_residuals: list[Tensor] | None,
guidance: Tensor | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -107,27 +105,12 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
# Validate double_block_residuals shape.
if controlnet_double_block_residuals is not None:
assert len(controlnet_double_block_residuals) == len(self.double_blocks)
for block_index, block in enumerate(self.double_blocks):
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
if controlnet_double_block_residuals is not None:
img += controlnet_double_block_residuals[block_index]
img = torch.cat((txt, img), 1)
# Validate single_block_residuals shape.
if controlnet_single_block_residuals is not None:
assert len(controlnet_single_block_residuals) == len(self.single_blocks)
for block_index, block in enumerate(self.single_blocks):
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
if controlnet_single_block_residuals is not None:
img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index]
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
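A toy illustration of the single-block residual slice above (tensor sizes are assumptions for the example): the single blocks operate on the concatenated (txt, img) sequence, so the ControlNet residual is added only to the image tokens that follow the first txt.shape[1] positions.

import torch

txt = torch.zeros(1, 4, 8)      # 4 text tokens
img = torch.zeros(1, 6, 8)      # 6 image tokens
seq = torch.cat((txt, img), 1)  # shape (1, 10, 8), as built in the forward pass
residual = torch.ones(1, 6, 8)  # residual sized to the image tokens only
seq[:, txt.shape[1] :, ...] += residual  # the 4 text tokens are left untouched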

View File

@@ -8,36 +8,17 @@ from diffusers import ControlNetModel
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
)
from invokeai.backend.model_manager.config import (
BaseModelType,
ControlNetCheckpointConfig,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import ControlNetCheckpointConfig, SubModelType
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlNetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""

View File

@@ -10,15 +10,6 @@ from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.state_dict_utils import (
convert_diffusers_instantx_state_dict_to_bfl_format,
infer_flux_params_from_state_dict,
infer_instantx_num_control_modes_from_state_dict,
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, params
@@ -33,8 +24,6 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
CLIPEmbedDiffusersConfig,
ControlNetCheckpointConfig,
ControlNetDiffusersConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
MainGGUFCheckpointConfig,
@@ -304,51 +293,3 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
model.load_state_dict(sd, assign=True)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
class FluxControlnetModel(ModelLoader):
"""Class to load FLUX ControlNet models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
model_path = Path(config.path)
elif isinstance(config, ControlNetDiffusersConfig):
# If this is a diffusers directory, we simply ignore the config file and load from the weight file.
model_path = Path(config.path) / "diffusion_pytorch_model.safetensors"
else:
raise ValueError(f"Unexpected ControlNet model config type: {type(config)}")
sd = load_file(model_path)
# Detect the FLUX ControlNet model type from the state dict.
if is_state_dict_xlabs_controlnet(sd):
return self._load_xlabs_controlnet(sd)
elif is_state_dict_instantx_controlnet(sd):
return self._load_instantx_controlnet(sd)
else:
raise ValueError("Do not recognize the state dict as an XLabs or InstantX ControlNet model.")
def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
with accelerate.init_empty_weights():
# HACK(ryand): Is it safe to assume dev here?
model = XLabsControlNetFlux(params["flux-dev"])
model.load_state_dict(sd, assign=True)
return model
def _load_instantx_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
flux_params = infer_flux_params_from_state_dict(sd)
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
with accelerate.init_empty_weights():
model = InstantXControlNetFlux(flux_params, num_control_modes)
model.load_state_dict(sd, assign=True)
return model

View File

@@ -10,10 +10,6 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
@@ -120,7 +116,6 @@ class ModelProbe(object):
"CLIPModel": ModelType.CLIPEmbed,
"CLIPTextModel": ModelType.CLIPEmbed,
"T5EncoderModel": ModelType.T5Encoder,
"FluxControlNetModel": ModelType.ControlNet,
}
@classmethod
@@ -260,19 +255,7 @@ class ModelProbe(object):
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
return ModelType.LoRA
elif key.startswith(
(
"controlnet",
"control_model",
"input_blocks",
# XLabs FLUX ControlNet models have keys starting with "controlnet_blocks."
# For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
# TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with
# "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
# delicate.
"controlnet_blocks",
)
):
elif key.startswith(("controlnet", "control_model", "input_blocks")):
return ModelType.ControlNet
elif key.startswith(("image_proj.", "ip_adapter.")):
return ModelType.IPAdapter
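The comment above notes that XLabs FLUX ControlNet state dicts contain keys starting with both "double_blocks." and "controlnet_blocks.". A minimal prefix-based sketch in that spirit (an illustration only, not the actual is_state_dict_xlabs_controlnet implementation):

import torch

def looks_like_xlabs_controlnet(sd: dict[str, torch.Tensor]) -> bool:
    # Require both the transformer-style double_blocks.* keys and the
    # controlnet_blocks.* projection keys to reduce false positives.
    has_double_blocks = any(k.startswith("double_blocks.") for k in sd)
    has_controlnet_blocks = any(k.startswith("controlnet_blocks.") for k in sd)
    return has_double_blocks and has_controlnet_blocks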
@@ -455,7 +438,6 @@ MODEL_NAME_TO_PREPROCESSOR = {
"lineart": "lineart_image_processor",
"lineart_anime": "lineart_anime_image_processor",
"softedge": "hed_image_processor",
"hed": "hed_image_processor",
"shuffle": "content_shuffle_image_processor",
"pose": "dw_openpose_image_processor",
"mediapipe": "mediapipe_face_processor",
@@ -467,8 +449,7 @@ MODEL_NAME_TO_PREPROCESSOR = {
def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
model_name_lower = model_name.lower()
if k in model_name_lower:
if k in model_name:
return ControlAdapterDefaultSettings(preprocessor=v)
return None
@@ -642,11 +623,6 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
if is_state_dict_xlabs_controlnet(checkpoint) or is_state_dict_instantx_controlnet(checkpoint):
# TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing
# get_format()?
return BaseModelType.Flux
for key_name in (
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"controlnet_mid_block.bias",
@@ -868,19 +844,22 @@ class ControlNetFolderProbe(FolderProbeBase):
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
if config.get("_class_name", None) == "FluxControlNetModel":
return BaseModelType.Flux
# no obvious way to distinguish between sd2-base and sd2-768
dimension = config["cross_attention_dim"]
if dimension == 768:
return BaseModelType.StableDiffusion1
if dimension == 1024:
return BaseModelType.StableDiffusion2
if dimension == 2048:
return BaseModelType.StableDiffusionXL
raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
base_model = (
BaseModelType.StableDiffusion1
if dimension == 768
else (
BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
)
)
if not base_model:
raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
return base_model
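As a compact restatement of the folder-probe logic shown above (the function name and plain-string base identifiers are illustrative; the real code returns BaseModelType members):

import json
from pathlib import Path

DIM_TO_BASE = {768: "sd-1", 1024: "sd-2", 2048: "sdxl"}

def probe_controlnet_folder_base(config_file: Path) -> str:
    with open(config_file, "r") as f:
        config = json.load(f)
    if config.get("_class_name") == "FluxControlNetModel":
        return "flux"
    base = DIM_TO_BASE.get(config.get("cross_attention_dim"))
    if base is None:
        raise ValueError(f"Unable to determine model base for {config_file}")
    return base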
class LoRAFolderProbe(FolderProbeBase):

View File

@@ -422,13 +422,6 @@ STARTER_MODELS: list[StarterModel] = [
description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
type=ModelType.ControlNet,
),
StarterModel(
name="FLUX.1-dev-Controlnet-Union-Pro",
base=BaseModelType.Flux,
source="Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
description="A unified ControlNet for FLUX.1-dev model that supports 7 control modes, including canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6)",
type=ModelType.ControlNet,
),
# endregion
# region T2I Adapter
StarterModel(

View File

@@ -198,24 +198,20 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self.disable_attention_slicing()
return
elif config.attention_type == "torch-sdp":
# torch-sdp is the default in diffusers.
return
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
# diffusers enables sdp automatically
return
else:
raise Exception("torch-sdp attention slicing not available")
# See https://github.com/invoke-ai/InvokeAI/issues/7049 for context.
# Bumping torch from 2.2.2 to 2.4.1 caused the sliced attention implementation to produce incorrect results.
# For now, if a user is on an MPS device and has not explicitly set the attention_type, then we select the
# non-sliced torch-sdp implementation. This keeps things working on MPS at the cost of increased peak memory
# utilization.
if torch.backends.mps.is_available():
return
# The remainder of this code is called when attention_type=='auto'.
# the remainder of this code is called when attention_type=='auto'
if self.unet.device.type == "cuda":
if is_xformers_available() and prefer_xformers:
self.enable_xformers_memory_efficient_attention()
return
# torch-sdp is the default in diffusers.
return
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
# diffusers enables sdp automatically
return
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
mem_free = psutil.virtual_memory().free

View File

@@ -936,8 +936,7 @@
},
"paramScheduler": {
"paragraphs": [
"Verwendeter Planer währende des Generierungsprozesses.",
"Jeder Planer definiert, wie einem Bild iterativ Rauschen hinzugefügt wird, oder wie ein Sample basierend auf der Ausgabe eines Modells aktualisiert wird."
"\"Planer\" definiert, wie iterativ Rauschen zu einem Bild hinzugefügt wird, oder wie ein Sample bei der Ausgabe eines Modells aktualisiert wird."
],
"heading": "Planer"
},
@@ -963,61 +962,6 @@
},
"ipAdapterMethod": {
"heading": "Methode"
},
"refinerScheduler": {
"heading": "Planer",
"paragraphs": [
"Planer, der während der Veredelungsphase des Generierungsprozesses verwendet wird.",
"Ähnlich wie der Generierungsplaner."
]
},
"compositingCoherenceMode": {
"paragraphs": [
"Verwendete Methode zur Erstellung eines kohärenten Bildes mit dem neu generierten maskierten Bereich."
],
"heading": "Modus"
},
"compositingCoherencePass": {
"heading": "Kohärenzdurchlauf"
},
"controlNet": {
"heading": "ControlNet"
},
"compositingMaskAdjustments": {
"paragraphs": [
"Die Maske anpassen."
],
"heading": "Maskenanpassungen"
},
"compositingMaskBlur": {
"paragraphs": [
"Der Unschärferadius der Maske."
],
"heading": "Maskenunschärfe"
},
"compositingBlurMethod": {
"paragraphs": [
"Die auf den maskierten Bereich angewendete Unschärfemethode."
],
"heading": "Unschärfemethode"
},
"controlNetResizeMode": {
"heading": "Größenänderungsmodus"
},
"paramWidth": {
"heading": "Breite",
"paragraphs": [
"Breite des generierten Bildes. Muss ein Vielfaches von 8 sein."
]
},
"controlNetControlMode": {
"heading": "Kontrollmodus"
},
"controlNetProcessor": {
"heading": "Prozessor"
},
"patchmatchDownScaleSize": {
"heading": "Herunterskalieren"
}
},
"invocationCache": {
@@ -1136,8 +1080,7 @@
"workflowContact": "Kontaktdaten",
"workflowNotes": "Notizen",
"workflowTags": "Tags",
"workflowVersion": "Version",
"saveToGallery": "In Galerie speichern"
"workflowVersion": "Version"
},
"hrf": {
"enableHrf": "Korrektur für hohe Auflösungen",
@@ -1307,16 +1250,7 @@
"searchByName": "Nach Name suchen",
"promptTemplateCleared": "Promptvorlage gelöscht",
"preview": "Vorschau",
"positivePrompt": "Positiv-Prompt",
"active": "Aktiv",
"deleteTemplate2": "Sind Sie sicher, dass Sie diese Vorlage löschen möchten? Dies kann nicht rückgängig gemacht werden.",
"deleteTemplate": "Vorlage löschen",
"copyTemplate": "Vorlage kopieren",
"editTemplate": "Vorlage bearbeiten",
"deleteImage": "Bild löschen",
"defaultTemplates": "Standardvorlagen",
"nameColumn": "'name'",
"exportDownloaded": "Export heruntergeladen"
"positivePrompt": "Positiv-Prompt"
},
"newUserExperience": {
"gettingStartedSeries": "Wünschen Sie weitere Anleitungen? In unserer <LinkComponent>Einführungsserie</LinkComponent> finden Sie Tipps, wie Sie das Potenzial von Invoke Studio voll ausschöpfen können.",
@@ -1329,22 +1263,13 @@
"bbox": "Bbox"
},
"transform": {
"fitToBbox": "An Bbox anpassen",
"reset": "Zurücksetzen",
"apply": "Anwenden",
"cancel": "Abbrechen"
"fitToBbox": "An Bbox anpassen"
},
"pullBboxIntoLayerError": "Problem, Bbox in die Ebene zu ziehen",
"pullBboxIntoLayer": "Bbox in Ebene ziehen",
"HUD": {
"bbox": "Bbox",
"scaledBbox": "Skalierte Bbox",
"entityStatus": {
"isHidden": "{{title}} ist ausgeblendet",
"isDisabled": "{{title}} ist deaktiviert",
"isLocked": "{{title}} ist gesperrt",
"isEmpty": "{{title}} ist leer"
}
"scaledBbox": "Skalierte Bbox"
},
"fitBboxToLayers": "Bbox an Ebenen anpassen",
"pullBboxIntoReferenceImage": "Bbox ins Referenzbild ziehen",
@@ -1354,12 +1279,7 @@
"clipToBbox": "Pinselstriche auf Bbox beschränken",
"canvasContextMenu": {
"saveBboxToGallery": "Bbox in Galerie speichern",
"bboxGroup": "Aus Bbox erstellen",
"canvasGroup": "Leinwand",
"newGlobalReferenceImage": "Neues globales Referenzbild",
"newRegionalReferenceImage": "Neues regionales Referenzbild",
"newControlLayer": "Neue Kontroll-Ebene",
"newRasterLayer": "Neue Raster-Ebene"
"bboxGroup": "Aus Bbox erstellen"
},
"rectangle": "Rechteck",
"saveCanvasToGallery": "Leinwand in Galerie speichern",
@@ -1390,7 +1310,7 @@
"regional": "Regional",
"newGlobalReferenceImageOk": "Globales Referenzbild erstellt",
"savedToGalleryError": "Fehler beim Speichern in der Galerie",
"savedToGalleryOk": "In Galerie gespeichert",
"savedToGalleryOk": "In Galerie speichern",
"newGlobalReferenceImageError": "Problem beim Erstellen eines globalen Referenzbilds",
"newRegionalReferenceImageOk": "Regionales Referenzbild erstellt",
"duplicate": "Duplizieren",
@@ -1423,39 +1343,12 @@
"showProgressOnCanvas": "Fortschritt auf Leinwand anzeigen",
"controlMode": {
"balanced": "Ausgewogen"
},
"globalReferenceImages_withCount_hidden": "Globale Referenzbilder ({{count}} ausgeblendet)",
"sendToGallery": "An Galerie senden",
"stagingArea": {
"accept": "Annehmen",
"next": "Nächste",
"discardAll": "Alle verwerfen",
"discard": "Verwerfen",
"previous": "Vorherige"
},
"regionalGuidance_withCount_visible": "Regionale Führung ({{count}})",
"regionalGuidance_withCount_hidden": "Regionale Führung ({{count}} ausgeblendet)",
"settings": {
"snapToGrid": {
"on": "Ein",
"off": "Aus",
"label": "Am Raster ausrichten"
}
},
"layer_one": "Ebene",
"layer_other": "Ebenen",
"layer_withCount_one": "Ebene ({{count}})",
"layer_withCount_other": "Ebenen ({{count}})"
}
},
"upsell": {
"shareAccess": "Zugang teilen",
"professional": "Professionell",
"inviteTeammates": "Teamkollegen einladen",
"professionalUpsell": "Verfügbar in der Professional Edition von Invoke. Klicken Sie hier oder besuchen Sie invoke.com/pricing für weitere Details."
},
"upscaling": {
"creativity": "Kreativität",
"structure": "Struktur",
"scale": "Maßstab"
}
}

View File

@@ -285,7 +285,6 @@
"assetsTab": "Files youve uploaded for use in your projects.",
"autoAssignBoardOnClick": "Auto-Assign Board on Click",
"autoSwitchNewImages": "Auto-Switch to New Images",
"boardsSettings": "Boards Settings",
"copy": "Copy",
"currentlyInUse": "This image is currently in use in the following features:",
"drop": "Drop",
@@ -305,7 +304,6 @@
"go": "Go",
"image": "image",
"imagesTab": "Images youve created and saved within Invoke.",
"imagesSettings": "Gallery Images Settings",
"jump": "Jump",
"loading": "Loading",
"newestFirst": "Newest First",
@@ -1643,7 +1641,6 @@
"sendToCanvas": "Send To Canvas",
"newLayerFromImage": "New Layer from Image",
"newCanvasFromImage": "New Canvas from Image",
"newImg2ImgCanvasFromImage": "New Img2Img from Image",
"copyToClipboard": "Copy to Clipboard",
"sendToCanvasDesc": "Pressing Invoke stages your work in progress on the canvas.",
"viewProgressInViewer": "View progress and outputs in the <Btn>Image Viewer</Btn>.",

View File

@@ -1730,8 +1730,7 @@
"mlsd_detection": {
"score_threshold": "Soglia di punteggio",
"distance_threshold": "Soglia di distanza",
"description": "Genera una mappa dei segmenti di linea dal livello selezionato utilizzando il modello di rilevamento dei segmenti di linea MLSD.",
"label": "Rilevamento segmenti di linea"
"description": "Genera una mappa dei segmenti di linea dal livello selezionato utilizzando il modello di rilevamento dei segmenti di linea MLSD."
},
"content_shuffle": {
"label": "Mescola contenuto",

View File

@@ -158,9 +158,7 @@
"move": "Двигать",
"gallery": "Галерея",
"openViewer": "Открыть просмотрщик",
"closeViewer": "Закрыть просмотрщик",
"imagesTab": "Изображения, созданные и сохраненные в Invoke.",
"assetsTab": "Файлы, которые вы загрузили для использования в своих проектах."
"closeViewer": "Закрыть просмотрщик"
},
"hotkeys": {
"searchHotkeys": "Поиск горячих клавиш",
@@ -930,10 +928,7 @@
"imageAccessError": "Невозможно найти изображение {{image_name}}, сбрасываем на значение по умолчанию",
"boardAccessError": "Невозможно найти доску {{board_id}}, сбрасываем на значение по умолчанию",
"modelAccessError": "Невозможно найти модель {{key}}, сброс на модель по умолчанию",
"saveToGallery": "Сохранить в галерею",
"noWorkflows": "Нет рабочих процессов",
"noMatchingWorkflows": "Нет совпадающих рабочих процессов",
"workflowHelpText": "Нужна помощь? Ознакомьтесь с нашим руководством <LinkComponent>Getting Started with Workflows</LinkComponent>"
"saveToGallery": "Сохранить в галерею"
},
"boards": {
"autoAddBoard": "Авто добавление Доски",
@@ -1558,10 +1553,7 @@
"autoLayout": "Автоматическое расположение",
"userWorkflows": "Пользовательские рабочие процессы",
"projectWorkflows": "Рабочие процессы проекта",
"defaultWorkflows": "Стандартные рабочие процессы",
"deleteWorkflow2": "Вы уверены, что хотите удалить этот рабочий процесс? Это нельзя отменить.",
"chooseWorkflowFromLibrary": "Выбрать рабочий процесс из библиотеки",
"uploadAndSaveWorkflow": "Загрузить в библиотеку"
"defaultWorkflows": "Стандартные рабочие процессы"
},
"hrf": {
"enableHrf": "Включить исправление высокого разрешения",
@@ -1880,8 +1872,8 @@
"duplicate": "Дублировать",
"inpaintMasks_withCount_visible": "Маски перерисовки ({{count}})",
"layer_one": "Слой",
"layer_few": "Слоя",
"layer_many": "Слоев",
"layer_few": "",
"layer_many": "",
"prompt": "Запрос",
"negativePrompt": "Исключающий запрос",
"beginEndStepPercentShort": "Начало/конец %",
@@ -2043,7 +2035,7 @@
"whatsNewInInvoke": "Что нового в Invoke"
},
"newUserExperience": {
"toGetStarted": "Чтобы начать работу, введите в поле запрос и нажмите <StrongComponent>Invoke</StrongComponent>, чтобы сгенерировать первое изображение. Выберите шаблон запроса, чтобы улучшить результаты. Вы можете сохранить изображения непосредственно в <StrongComponent>Галерею</StrongComponent> или отредактировать их на <StrongComponent>Холсте</StrongComponent>.",
"toGetStarted": "Чтобы начать работу, введите в поле запрос и нажмите <StrongComponent>Invoke</StrongComponent>, чтобы сгенерировать первое изображение. Вы можете сохранить изображения непосредственно в <StrongComponent>Галерею</StrongComponent> или отредактировать их на <StrongComponent>Холсте</StrongComponent>.",
"gettingStartedSeries": "Хотите получить больше рекомендаций? Ознакомьтесь с нашей серией <LinkComponent>Getting Started Series</LinkComponent> для получения советов по раскрытию всего потенциала Invoke Studio."
}
}

View File

@@ -20,7 +20,6 @@ import {
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { ImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { ShareWorkflowModal } from 'features/nodes/components/sidePanel/WorkflowListMenu/ShareWorkflowModal';
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
@@ -121,7 +120,6 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
<GlobalImageHotkeys />
<NewGallerySessionDialog />
<NewCanvasSessionDialog />
<ImageContextMenu />
</ErrorBoundary>
);
};

View File

@@ -4,9 +4,9 @@ import { IAILoadingImageFallback, IAINoContentFallback } from 'common/components
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import type { MouseEvent, ReactElement, ReactNode, SyntheticEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { PiImageBold, PiUploadSimpleBold } from 'react-icons/pi';
import type { ImageDTO, PostUploadAction } from 'services/api/types';
@@ -17,14 +17,7 @@ const defaultUploadElement = <Icon as={PiUploadSimpleBold} boxSize={16} />;
const defaultNoContentFallback = <IAINoContentFallback icon={PiImageBold} />;
const baseStyles: SystemStyleObject = {
touchAction: 'none',
userSelect: 'none',
webkitUserSelect: 'none',
};
const sx: SystemStyleObject = {
...baseStyles,
'.gallery-image-container::before': {
content: '""',
display: 'inline-block',
@@ -109,10 +102,59 @@ const IAIDndImage = (props: IAIDndImageProps) => {
useThumbailFallback,
withHoverOverlay = false,
children,
onMouseOver,
onMouseOut,
dataTestId,
...rest
} = props;
const handleMouseOver = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (onMouseOver) {
onMouseOver(e);
}
},
[onMouseOver]
);
const handleMouseOut = useCallback(
(e: MouseEvent<HTMLDivElement>) => {
if (onMouseOut) {
onMouseOut(e);
}
},
[onMouseOut]
);
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
postUploadAction,
isDisabled: isUploadDisabled,
});
const uploadButtonStyles = useMemo<SystemStyleObject>(() => {
const styles: SystemStyleObject = {
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: 'base.500',
};
if (!isUploadDisabled) {
Object.assign(styles, {
cursor: 'pointer',
bg: 'base.700',
_hover: {
bg: 'base.650',
color: 'base.300',
},
});
}
return styles;
}, [isUploadDisabled, minSize]);
const openInNewTab = useCallback(
(e: MouseEvent) => {
if (!imageDTO) {
@@ -126,126 +168,76 @@ const IAIDndImage = (props: IAIDndImageProps) => {
[imageDTO]
);
const ref = useRef<HTMLDivElement>(null);
useImageContextMenu(imageDTO, ref);
return (
<Flex
ref={ref}
width="full"
height="full"
alignItems="center"
justifyContent="center"
position="relative"
minW={minSize ? minSize : undefined}
minH={minSize ? minSize : undefined}
userSelect="none"
cursor={isDragDisabled || !imageDTO ? 'default' : 'pointer'}
sx={withHoverOverlay ? sx : baseStyles}
data-selected={isSelectedForCompare ? 'selectedForCompare' : isSelected ? 'selected' : undefined}
{...rest}
>
{imageDTO && (
<ImageContextMenu imageDTO={imageDTO}>
{(ref) => (
<Flex
className="gallery-image-container"
w="full"
h="full"
position={fitContainer ? 'absolute' : 'relative'}
ref={ref}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
width="full"
height="full"
alignItems="center"
justifyContent="center"
position="relative"
minW={minSize ? minSize : undefined}
minH={minSize ? minSize : undefined}
userSelect="none"
cursor={isDragDisabled || !imageDTO ? 'default' : 'pointer'}
sx={withHoverOverlay ? sx : undefined}
data-selected={isSelectedForCompare ? 'selectedForCompare' : isSelected ? 'selected' : undefined}
{...rest}
>
<Image
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
fallbackSrc={useThumbailFallback ? imageDTO.thumbnail_url : undefined}
fallback={useThumbailFallback ? undefined : <IAILoadingImageFallback image={imageDTO} />}
onError={onError}
draggable={false}
w={imageDTO.width}
objectFit="contain"
maxW="full"
maxH="full"
borderRadius="base"
sx={imageSx}
data-testid={dataTestId}
/>
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
{imageDTO && (
<Flex
className="gallery-image-container"
w="full"
h="full"
position={fitContainer ? 'absolute' : 'relative'}
alignItems="center"
justifyContent="center"
>
<Image
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
fallbackSrc={useThumbailFallback ? imageDTO.thumbnail_url : undefined}
fallback={useThumbailFallback ? undefined : <IAILoadingImageFallback image={imageDTO} />}
onError={onError}
draggable={false}
w={imageDTO.width}
objectFit="contain"
maxW="full"
maxH="full"
borderRadius="base"
sx={imageSx}
data-testid={dataTestId}
/>
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<>
<Flex sx={uploadButtonStyles} {...getUploadButtonProps()}>
<input {...getUploadInputProps()} />
{uploadElement}
</Flex>
</>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
onAuxClick={openInNewTab}
/>
)}
{children}
{!isDropDisabled && <IAIDroppable data={droppableData} disabled={isDropDisabled} dropLabel={dropLabel} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<UploadButton
isUploadDisabled={isUploadDisabled}
postUploadAction={postUploadAction}
uploadElement={uploadElement}
minSize={minSize}
/>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
onAuxClick={openInNewTab}
/>
)}
{children}
{!isDropDisabled && <IAIDroppable data={droppableData} disabled={isDropDisabled} dropLabel={dropLabel} />}
</Flex>
</ImageContextMenu>
);
};
export default memo(IAIDndImage);
const UploadButton = memo(
({
isUploadDisabled,
postUploadAction,
uploadElement,
minSize,
}: {
isUploadDisabled: boolean;
postUploadAction?: PostUploadAction;
uploadElement: ReactNode;
minSize: number;
}) => {
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
postUploadAction,
isDisabled: isUploadDisabled,
});
const uploadButtonStyles = useMemo<SystemStyleObject>(() => {
const styles: SystemStyleObject = {
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: 'base.500',
};
if (!isUploadDisabled) {
Object.assign(styles, {
cursor: 'pointer',
bg: 'base.700',
_hover: {
bg: 'base.650',
color: 'base.300',
},
});
}
return styles;
}, [isUploadDisabled, minSize]);
return (
<Flex sx={uploadButtonStyles} {...getUploadButtonProps()}>
<input {...getUploadInputProps()} />
{uploadElement}
</Flex>
);
}
);
UploadButton.displayName = 'UploadButton';

View File

@@ -9,6 +9,7 @@ import {
isModalOpenChanged,
selectChangeBoardModalSlice,
} from 'features/changeBoardModal/store/slice';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
@@ -28,7 +29,8 @@ const ChangeBoardModal = () => {
useAssertSingleton('ChangeBoardModal');
const dispatch = useAppDispatch();
const [selectedBoard, setSelectedBoard] = useState<string | null>();
const { data: boards, isFetching } = useListAllBoardsQuery({ include_archived: true });
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { data: boards, isFetching } = useListAllBoardsQuery(queryArgs);
const isModalOpen = useAppSelector(selectIsModalOpen);
const imagesToChange = useAppSelector(selectImagesToChange);
const [addImagesToBoard] = useAddImagesToBoardMutation();

View File

@@ -80,6 +80,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addControlLayer}
isDisabled={isFLUX}
>
{t('controlLayers.controlLayer')}
</Button>

View File

@@ -56,7 +56,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.layer_other')}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer} isDisabled={isFLUX}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>

View File

@@ -99,6 +99,7 @@ const PanelTabs = memo(() => {
<Box as="span" w="full">
{layersTabLabel}
</Box>
{dndCtx.active && <Box position="absolute" top={0} left={0} right={0} bottom={0} border="2px solid red" />}
</Tab>
<Tab position="relative" onMouseOver={onOnMouseOverGalleryTab} onMouseOut={onMouseOut}>
{t('gallery.gallery')}

View File

@@ -16,7 +16,6 @@ import {
controlLayerModelChanged,
controlLayerWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
@@ -43,7 +42,6 @@ export const ControlLayerControlAdapter = memo(() => {
const entityIdentifier = useEntityIdentifierContext('control_layer');
const controlAdapter = useControlLayerControlAdapter(entityIdentifier);
const filter = useEntityFilter(entityIdentifier);
const isFLUX = useAppSelector(selectIsFLUX);
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
@@ -119,7 +117,7 @@ export const ControlLayerControlAdapter = memo(() => {
</Flex>
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
{controlAdapter.type === 'controlnet' && !isFLUX && (
{controlAdapter.type === 'controlnet' && (
<ControlLayerControlAdapterControlMode
controlMode={controlAdapter.controlMode}
onChange={onChangeControlMode}

View File

@@ -18,7 +18,7 @@ export const ControlLayerMenuItems = memo(() => {
<IconMenuItemGroup>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete asIcon />
<CanvasEntityMenuItemsDelete />
</IconMenuItemGroup>
<MenuDivider />
<CanvasEntityMenuItemsTransform />

View File

@@ -9,7 +9,7 @@ export const IPAdapterMenuItems = memo(() => {
<IconMenuItemGroup>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete asIcon />
<CanvasEntityMenuItemsDelete />
</IconMenuItemGroup>
);
});

View File

@@ -13,7 +13,7 @@ export const InpaintMaskMenuItems = memo(() => {
<IconMenuItemGroup>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete asIcon />
<CanvasEntityMenuItemsDelete />
</IconMenuItemGroup>
<MenuDivider />
<CanvasEntityMenuItemsTransform />

View File

@@ -17,7 +17,7 @@ export const RasterLayerMenuItems = memo(() => {
<IconMenuItemGroup>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete asIcon />
<CanvasEntityMenuItemsDelete />
</IconMenuItemGroup>
<MenuDivider />
<CanvasEntityMenuItemsTransform />

View File

@@ -14,7 +14,7 @@ export const RegionalGuidanceMenuItems = memo(() => {
<Flex gap={2}>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete asIcon />
<CanvasEntityMenuItemsDelete />
</Flex>
<MenuDivider />
<RegionalGuidanceMenuItemsAddPromptsAndIPAdapter />

View File

@@ -1,4 +1,3 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { IconMenuItem } from 'common/components/IconMenuItem';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -8,11 +7,7 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
type Props = {
asIcon?: boolean;
};
export const CanvasEntityMenuItemsDelete = memo(({ asIcon = false }: Props) => {
export const CanvasEntityMenuItemsDelete = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext();
@@ -22,23 +17,15 @@ export const CanvasEntityMenuItemsDelete = memo(({ asIcon = false }: Props) => {
dispatch(entityDeleted({ entityIdentifier }));
}, [dispatch, entityIdentifier]);
if (asIcon) {
return (
<IconMenuItem
aria-label={t('common.delete')}
tooltip={t('common.delete')}
onClick={deleteEntity}
icon={<PiTrashSimpleBold />}
isDestructive
isDisabled={!isInteractable}
/>
);
}
return (
<MenuItem onClick={deleteEntity} icon={<PiTrashSimpleBold />} isDestructive isDisabled={!isInteractable}>
{t('common.delete')}
</MenuItem>
<IconMenuItem
aria-label={t('common.delete')}
tooltip={t('common.delete')}
onClick={deleteEntity}
icon={<PiTrashSimpleBold />}
isDestructive
isDisabled={!isInteractable}
/>
);
});

View File

@@ -2,11 +2,8 @@ import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasReset } from 'features/controlLayers/store/actions';
import {
bboxChangedFromCanvas,
controlLayerAdded,
inpaintMaskAdded,
rasterLayerAdded,
@@ -17,32 +14,19 @@ import {
rgPositivePromptChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import {
selectBboxModelBase,
selectBboxRect,
selectCanvasSlice,
selectEntityOrThrow,
} from 'features/controlLayers/store/selectors';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type {
CanvasEntityIdentifier,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
ControlNetConfig,
IPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import {
imageDTOToImageObject,
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlLayers/store/util';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/util';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { useCallback } from 'react';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'services/api/types';
export const selectDefaultControlAdapter = createSelector(
@@ -106,74 +90,6 @@ export const useAddRasterLayer = () => {
return func;
};
export const useNewRasterLayerFromImage = () => {
const dispatch = useAppDispatch();
const bboxRect = useAppSelector(selectBboxRect);
const func = useCallback(
(imageDTO: ImageDTO) => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRasterLayerState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
},
[bboxRect.x, bboxRect.y, dispatch]
);
return func;
};
/**
* Returns a function that adds a new canvas with the given image as the initial image, replicating the img2img flow:
* - Reset the canvas
* - Resize the bbox to the image's aspect ratio at the optimal size for the selected model
* - Add the image as a raster layer
* - Resize the layer to fit the bbox using the 'fill' strategy
*
* This allows the user to immediately generate a new image from the given image without any additional steps.
*/
export const useNewCanvasFromImage = () => {
const dispatch = useAppDispatch();
const bboxRect = useAppSelector(selectBboxRect);
const base = useAppSelector(selectBboxModelBase);
const func = useCallback(
(imageDTO: ImageDTO) => {
// Calculate the new bbox dimensions to fit the image's aspect ratio at the optimal size
const ratio = imageDTO.width / imageDTO.height;
const optimalDimension = getOptimalDimension(base);
const { width, height } = calculateNewSize(ratio, optimalDimension ** 2, base);
// The overrides need to include the layer's ID so we can transform the layer after it is initialized
const overrides = {
id: getPrefixedId('raster_layer'),
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageDTOToImageObject(imageDTO)],
} satisfies Partial<CanvasRasterLayerState>;
CanvasEntityAdapterBase.registerInitCallback(async (adapter) => {
// Skip the callback if the adapter is not the one we are creating
if (adapter.id !== overrides.id) {
return false;
}
// Fit the layer to the bbox w/ fill strategy
await adapter.transformer.startTransform({ silent: true });
adapter.transformer.fitToBboxFill();
await adapter.transformer.applyTransform();
return true;
});
dispatch(canvasReset());
// The `bboxChangedFromCanvas` reducer does no validation! Careful!
dispatch(bboxChangedFromCanvas({ x: 0, y: 0, width, height }));
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
},
[base, bboxRect.x, bboxRect.y, dispatch]
);
return func;
};
export const useAddInpaintMask = () => {
const dispatch = useAppDispatch();
const func = useCallback(() => {

View File

@@ -7,7 +7,6 @@ import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/ko
import type { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
import type { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { getKonvaNodeDebugAttrs, getRectIntersection } from 'features/controlLayers/konva/util';
@@ -16,8 +15,7 @@ import {
selectIsolatedTransformingPreview,
} from 'features/controlLayers/store/canvasSettingsSlice';
import {
buildSelectIsHidden,
buildSelectIsSelected,
buildEntityIsHiddenSelector,
selectBboxRect,
selectCanvasSlice,
selectEntity,
@@ -31,11 +29,6 @@ import type { ImageDTO } from 'services/api/types';
import stableHash from 'stable-hash';
import { assert } from 'tsafe';
// Ideally, we'd type `adapter` as `CanvasEntityAdapterBase`, but the generics make this tricky. `CanvasEntityAdapter`
// is a union of all entity adapters and is functionally identical to `CanvasEntityAdapterBase`. We'll need to do a
// type assertion below in the `onInit` method, which calls these callbacks.
type InitCallback = (adapter: CanvasEntityAdapter) => Promise<boolean>;
export abstract class CanvasEntityAdapterBase<
T extends CanvasRenderableEntityState,
U extends string,
@@ -94,79 +87,7 @@ export abstract class CanvasEntityAdapterBase<
*/
abstract getHashableState: () => SerializableObject;
/**
* Callbacks that are executed when the module is initialized.
*/
private static initCallbacks = new Set<InitCallback>();
/**
* Register a callback to be run when an entity adapter is initialized.
*
* The callback is called for every adapter that is initialized with the adapter as its only argument. Use an early
* return to skip entities that are not of interest, returning `false` to keep the callback registered. Return `true`
* to unregister the callback after it is called.
*
* @param callback The callback to register.
*
* @example
* ```ts
* // A callback that is executed once for a specific entity:
* const myId = 'my_id';
* canvasManager.entityRenderer.registerOnInitCallback(async (adapter) => {
* if (adapter.id !== myId) {
* // These are not the droids you are looking for, move along
* return false;
* }
*
* doSomething();
*
* // Remove the callback
* return true;
* });
* ```
*
* @example
* ```ts
* // A callback that is executed once for the next entity that is initialized:
* canvasManager.entityRenderer.registerOnInitCallback(async (adapter) => {
* doSomething();
*
* // Remove the callback
* return true;
* });
* ```
*
* @example
* ```ts
* // A callback that is executed for every entity and is never removed:
* canvasManager.entityRenderer.registerOnInitCallback(async (adapter) => {
* // Do something with the adapter
* return false;
* });
*/
static registerInitCallback = (callback: InitCallback) => {
const wrapped = async (adapter: CanvasEntityAdapter) => {
const result = await callback(adapter);
if (result) {
this.initCallbacks.delete(wrapped);
}
return result;
};
this.initCallbacks.add(wrapped);
};
/**
* Runs all init callbacks with the given entity adapter.
* @param adapter The adapter of the entity that was initialized.
*/
private static runInitCallbacks = (adapter: CanvasEntityAdapter) => {
for (const callback of this.initCallbacks) {
callback(adapter);
}
};
selectIsHidden: Selector<RootState, boolean>;
selectIsSelected: Selector<RootState, boolean>;
/**
* The Konva nodes that make up the entity adapter:
@@ -250,8 +171,7 @@ export abstract class CanvasEntityAdapterBase<
assert(state !== undefined, 'Missing entity state on creation');
this.state = state;
this.selectIsHidden = buildSelectIsHidden(this.entityIdentifier);
this.selectIsSelected = buildSelectIsSelected(this.entityIdentifier);
this.selectIsHidden = buildEntityIsHiddenSelector(this.entityIdentifier);
/**
* There are a number of reason we may need to show or hide a layer:
@@ -260,7 +180,6 @@ export abstract class CanvasEntityAdapterBase<
* - Staging status changes and `isolatedStagingPreview` is enabled
* - Global filtering status changes and `isolatedFilteringPreview` is enabled
* - Global transforming status changes and `isolatedTransformingPreview` is enabled
* - The entity is selected or deselected (only selected and onscreen entities are rendered)
*/
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectIsHidden, this.syncVisibility));
this.subscriptions.add(
@@ -271,7 +190,6 @@ export abstract class CanvasEntityAdapterBase<
this.manager.stateApi.createStoreSubscription(selectIsolatedTransformingPreview, this.syncVisibility)
);
this.subscriptions.add(this.manager.stateApi.$transformingAdapter.listen(this.syncVisibility));
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectIsSelected, this.syncVisibility));
/**
* The tool preview may need to be updated when the entity is locked or disabled. For example, when we disable the
@@ -310,8 +228,21 @@ export abstract class CanvasEntityAdapterBase<
syncIsOnscreen = () => {
const stageRect = this.manager.stage.getScaledStageRect();
const isOnScreen = this.checkIntersection(stageRect);
const entityRect = this.transformer.$pixelRect.get();
const position = this.manager.stateApi.runSelector(this.selectPosition);
if (!position) {
return;
}
const entityRectRelativeToStage = {
x: entityRect.x + position.x,
y: entityRect.y + position.y,
width: entityRect.width,
height: entityRect.height,
};
const intersection = getRectIntersection(stageRect, entityRectRelativeToStage);
const prevIsOnScreen = this.$isOnScreen.get();
const isOnScreen = intersection.width > 0 && intersection.height > 0;
this.$isOnScreen.set(isOnScreen);
if (prevIsOnScreen !== isOnScreen) {
this.log.trace(`Moved ${isOnScreen ? 'on-screen' : 'off-screen'}`);
@@ -321,19 +252,10 @@ export abstract class CanvasEntityAdapterBase<
syncIntersectsBbox = () => {
const bboxRect = this.manager.stateApi.getBbox().rect;
const intersectsBbox = this.checkIntersection(bboxRect);
const prevIntersectsBbox = this.$intersectsBbox.get();
this.$intersectsBbox.set(intersectsBbox);
if (prevIntersectsBbox !== intersectsBbox) {
this.log.trace(`Moved ${intersectsBbox ? 'into bbox' : 'out of bbox'}`);
}
};
checkIntersection = (rect: Rect): boolean => {
const entityRect = this.transformer.$pixelRect.get();
const position = this.manager.stateApi.runSelector(this.selectPosition);
if (!position) {
return false;
return;
}
const entityRectRelativeToStage = {
x: entityRect.x + position.x,
@@ -341,9 +263,14 @@ export abstract class CanvasEntityAdapterBase<
width: entityRect.width,
height: entityRect.height,
};
const intersection = getRectIntersection(rect, entityRectRelativeToStage);
const doesIntersect = intersection.width > 0 && intersection.height > 0;
return doesIntersect;
const intersection = getRectIntersection(bboxRect, entityRectRelativeToStage);
const prevIntersectsBbox = this.$intersectsBbox.get();
const intersectsBbox = intersection.width > 0 && intersection.height > 0;
this.$intersectsBbox.set(intersectsBbox);
if (prevIntersectsBbox !== intersectsBbox) {
this.log.trace(`Moved ${intersectsBbox ? 'into bbox' : 'out of bbox'}`);
}
};
initialize = async () => {
@@ -372,10 +299,6 @@ export abstract class CanvasEntityAdapterBase<
await this.renderer.initialize();
this.syncZIndices();
this.syncVisibility();
// Call the init callbacks.
// TODO(psyche): Get rid of the cast - see note in type def for `InitCallback`.
CanvasEntityAdapterBase.runInitCallbacks(this as CanvasEntityAdapter);
};
syncZIndices = () => {

View File

@@ -1,4 +1,3 @@
import { Mutex } from 'async-mutex';
import { withResultAsync } from 'common/util/result';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
@@ -167,13 +166,6 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
*/
$silentTransform = atom(false);
/**
* A mutex to prevent concurrent operations.
*
* The mutex is locked during transformation and during rect calculations which are handled in a web worker.
*/
transformMutex = new Mutex();
konva: {
transformer: Konva.Transformer;
proxyRect: Konva.Rect;
@@ -432,7 +424,6 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
return;
}
const { rect } = this.manager.stateApi.getBbox();
const gridSize = this.manager.stateApi.getGridSize();
const width = this.konva.proxyRect.width();
const height = this.konva.proxyRect.height();
const scaleX = rect.width / width;
@@ -446,8 +437,8 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
const offsetY = (rect.height - height * scale) / 2;
this.konva.proxyRect.setAttrs({
x: clamp(roundToMultiple(rect.x + offsetX, gridSize), rect.x, rect.x + rect.width),
y: clamp(roundToMultiple(rect.y + offsetY, gridSize), rect.y, rect.y + rect.height),
x: clamp(Math.round(rect.x + offsetX), rect.x, rect.x + rect.width),
y: clamp(Math.round(rect.y + offsetY), rect.y, rect.y + rect.height),
scaleX: scale,
scaleY: scale,
rotation: 0,
@@ -464,7 +455,6 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
return;
}
const { rect } = this.manager.stateApi.getBbox();
const gridSize = this.manager.stateApi.getGridSize();
const width = this.konva.proxyRect.width();
const height = this.konva.proxyRect.height();
const scaleX = rect.width / width;
@@ -478,8 +468,8 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
const offsetY = (rect.height - height * scale) / 2;
this.konva.proxyRect.setAttrs({
x: roundToMultiple(rect.x + offsetX, gridSize),
y: roundToMultiple(rect.y + offsetY, gridSize),
x: Math.round(rect.x + offsetX),
y: Math.round(rect.y + offsetY),
scaleX: scale,
scaleY: scale,
rotation: 0,
@@ -657,13 +647,11 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
* @param arg.silent Whether the transformation should be silent. If silent, the transform controls will not be shown,
* so you _must_ immediately call `applyTransform` or `stopTransform` to complete the transformation.
*/
startTransform = async (arg?: { silent: boolean }) => {
startTransform = (arg?: { silent: boolean }) => {
const transformingAdapter = this.manager.stateApi.$transformingAdapter.get();
if (transformingAdapter) {
assert(false, `Already transforming an entity: ${transformingAdapter.id}`);
}
// This will be released when the transformation is stopped
await this.transformMutex.acquire();
this.log.debug('Starting transform');
const { silent } = { silent: false, ...arg };
this.$silentTransform.set(silent);
@@ -716,7 +704,6 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.syncInteractionState();
this.manager.stateApi.$transformingAdapter.set(null);
this.$isProcessing.set(false);
this.transformMutex.release();
};
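The transformMutex lines above use Mutex from async-mutex to serialize transforms and the worker-backed rect calculations. A minimal standalone sketch of that acquire-in-start, release-in-stop pattern, with placeholder operations:

import { Mutex } from 'async-mutex';

// Hypothetical sketch: a shared mutex keeps a long-running operation exclusive until it is explicitly stopped.
const mutex = new Mutex();

const startOperation = async () => {
  // Waits until any in-flight operation has released the mutex.
  await mutex.acquire();
  // ... begin work; the mutex stays held until stopOperation() runs.
};

const stopOperation = () => {
  // ... finish work, then let the next waiter proceed.
  mutex.release();
};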
/**
@@ -820,6 +807,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
calculateRect = debounce(() => {
this.log.debug('Calculating bbox');
this.$isPendingRectCalculation.set(true);
const canvas = this.parent.getCanvas();
if (!this.parent.renderer.hasObjects()) {
@@ -829,7 +817,6 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.parent.$canvasCache.set(canvas);
this.$isPendingRectCalculation.set(false);
this.updateBbox();
this.transformMutex.release();
return;
}
@@ -842,7 +829,6 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.parent.$canvasCache.set(canvas);
this.$isPendingRectCalculation.set(false);
this.updateBbox();
this.transformMutex.release();
return;
}
@@ -871,14 +857,11 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.parent.$canvasCache.set(canvas);
this.$isPendingRectCalculation.set(false);
this.updateBbox();
this.transformMutex.release();
}
);
}, this.config.RECT_CALC_DEBOUNCE_MS);
requestRectCalculation = async () => {
// This will be released when the rect calculation is complete
await this.transformMutex.acquire();
requestRectCalculation = () => {
this.$isPendingRectCalculation.set(true);
this.syncInteractionState();
this.calculateRect();

View File

@@ -25,6 +25,7 @@ import {
getScaledBoundingBoxDimensions,
} from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify';
import type { MainModelBase } from 'features/nodes/types/common';
import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
@@ -771,6 +772,11 @@ export const canvasSlice = createSlice({
syncScaledSize(state);
},
bboxModelBaseChanged: (state, action: PayloadAction<{ modelBase: MainModelBase }>) => {
const { modelBase } = action.payload;
state.bbox.modelBase = modelBase;
syncScaledSize(state);
},
bboxSyncedToOptimalDimension: (state) => {
const optimalDimension = getOptimalDimension(state.bbox.modelBase);

View File

@@ -308,7 +308,7 @@ const getSelectIsTypeHidden = (type: CanvasEntityType) => {
/**
* Builds a selector that selects if the entity is hidden.
*/
export const buildSelectIsHidden = (entityIdentifier: CanvasEntityIdentifier) => {
export const buildEntityIsHiddenSelector = (entityIdentifier: CanvasEntityIdentifier) => {
const selectIsTypeHidden = getSelectIsTypeHidden(entityIdentifier.type);
return createSelector(
[selectCanvasSlice, selectIsTypeHidden, selectIsStaging, selectIsolatedStagingPreview],
@@ -339,16 +339,6 @@ export const buildSelectIsHidden = (entityIdentifier: CanvasEntityIdentifier) =>
);
};
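A hypothetical component-side usage of the renamed selector factory; the component, import paths, and everything beyond useMemo/useAppSelector are assumptions for illustration:

import { useMemo } from 'react';
import type { ReactNode } from 'react';
import { useAppSelector } from 'app/store/storeHooks';
import { buildEntityIsHiddenSelector } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';

// Hypothetical gate component: build the memoized selector once per entity, then subscribe to it.
export const EntityVisibilityGate = (props: { entityIdentifier: CanvasEntityIdentifier; children: ReactNode }) => {
  const selectIsHidden = useMemo(() => buildEntityIsHiddenSelector(props.entityIdentifier), [props.entityIdentifier]);
  const isHidden = useAppSelector(selectIsHidden);
  if (isHidden) {
    return null;
  }
  return <>{props.children}</>;
};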
/**
* Builds a selector that selects if the entity is selected.
*/
export const buildSelectIsSelected = (entityIdentifier: CanvasEntityIdentifier) => {
return createSelector(
selectSelectedEntityIdentifier,
(selectedEntityIdentifier) => selectedEntityIdentifier?.id === entityIdentifier.id
);
};
export const selectWidth = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect.width);
export const selectHeight = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect.height);
export const selectAspectRatioID = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.aspectRatio.id);

View File

@@ -1,86 +0,0 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectBoardsListOrderBy, selectBoardsListOrderDir } from 'features/gallery/store/gallerySelectors';
import { boardsListOrderByChanged, boardsListOrderDirChanged } from 'features/gallery/store/gallerySlice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { z } from 'zod';
const zOrderBy = z.enum(['created_at', 'board_name']);
type OrderBy = z.infer<typeof zOrderBy>;
const isOrderBy = (v: unknown): v is OrderBy => zOrderBy.safeParse(v).success;
const zDirection = z.enum(['ASC', 'DESC']);
type Direction = z.infer<typeof zDirection>;
const isDirection = (v: unknown): v is Direction => zDirection.safeParse(v).success;
export const BoardsListSortControls = () => {
const { t } = useTranslation();
const orderBy = useAppSelector(selectBoardsListOrderBy);
const direction = useAppSelector(selectBoardsListOrderDir);
const ORDER_BY_OPTIONS: ComboboxOption[] = useMemo(
() => [
{ value: 'created_at', label: t('workflows.created') },
{ value: 'board_name', label: t('workflows.name') },
],
[t]
);
const DIRECTION_OPTIONS: ComboboxOption[] = useMemo(
() => [
{ value: 'ASC', label: t('workflows.ascending') },
{ value: 'DESC', label: t('workflows.descending') },
],
[t]
);
const dispatch = useAppDispatch();
const onChangeOrderBy = useCallback<ComboboxOnChange>(
(v) => {
if (!isOrderBy(v?.value) || v.value === orderBy) {
return;
}
dispatch(boardsListOrderByChanged(v.value));
},
[orderBy, dispatch]
);
const valueOrderBy = useMemo(() => {
return ORDER_BY_OPTIONS.find((o) => o.value === orderBy) || ORDER_BY_OPTIONS[0];
}, [orderBy, ORDER_BY_OPTIONS]);
const onChangeDirection = useCallback<ComboboxOnChange>(
(v) => {
if (!isDirection(v?.value) || v.value === direction) {
return;
}
dispatch(boardsListOrderDirChanged(v.value));
},
[direction, dispatch]
);
const valueDirection = useMemo(
() => DIRECTION_OPTIONS.find((o) => o.value === direction),
[direction, DIRECTION_OPTIONS]
);
return (
<Flex flexDir="column" gap={4}>
<FormControl orientation="horizontal" gap={1}>
<FormLabel>{t('common.orderBy')}</FormLabel>
<Combobox isSearchable={false} value={valueOrderBy} options={ORDER_BY_OPTIONS} onChange={onChangeOrderBy} />
</FormControl>
<FormControl orientation="horizontal" gap={1}>
<FormLabel>{t('common.direction')}</FormLabel>
<Combobox
isSearchable={false}
value={valueDirection}
options={DIRECTION_OPTIONS}
onChange={onChangeDirection}
/>
</FormControl>
</Flex>
);
};

View File

@@ -1,53 +0,0 @@
import {
Box,
Divider,
Flex,
IconButton,
Popover,
PopoverBody,
PopoverContent,
PopoverTrigger,
} from '@invoke-ai/ui-library';
import BoardAutoAddSelect from 'features/gallery/components/Boards/BoardAutoAddSelect';
import AutoAssignBoardCheckbox from 'features/gallery/components/GallerySettingsPopover/AutoAssignBoardCheckbox';
import ShowArchivedBoardsCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowArchivedBoardsCheckbox';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiGearSixFill } from 'react-icons/pi';
import { BoardsListSortControls } from './BoardsListSortControls';
const BoardsSettingsPopover = () => {
const { t } = useTranslation();
return (
<Popover isLazy>
<PopoverTrigger>
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
aria-label={t('gallery.boardsSettings')}
icon={<PiGearSixFill />}
tooltip={t('gallery.boardsSettings')}
/>
</PopoverTrigger>
<PopoverContent>
<PopoverBody>
<Flex direction="column" gap={2}>
<AutoAssignBoardCheckbox />
<ShowArchivedBoardsCheckbox />
<BoardAutoAddSelect />
<Box py={2}>
<Divider />
</Box>
<BoardsListSortControls />
</Flex>
</PopoverBody>
</PopoverContent>
</Popover>
);
};
export default memo(BoardsSettingsPopover);

View File

@@ -23,7 +23,6 @@ import { useTranslation } from 'react-i18next';
import { PiMagnifyingGlassBold } from 'react-icons/pi';
import { useBoardName } from 'services/api/hooks/useBoardName';
import GallerySettingsPopover from './GallerySettingsPopover/GallerySettingsPopover';
import GalleryImageGrid from './ImageGrid/GalleryImageGrid';
import { GalleryPagination } from './ImageGrid/GalleryPagination';
import { GallerySearch } from './ImageGrid/GallerySearch';
@@ -86,18 +85,15 @@ export const Gallery = () => {
{t('gallery.assets')}
</Tab>
</Tooltip>
<Flex h="full" justifyContent="flex-end">
<GallerySettingsPopover />
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
onClick={handleClickSearch}
tooltip={searchDisclosure.isOpen ? `${t('gallery.exitSearch')}` : `${t('gallery.displaySearch')}`}
aria-label={t('gallery.displaySearch')}
icon={<PiMagnifyingGlassBold />}
/>
</Flex>
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
onClick={handleClickSearch}
tooltip={searchDisclosure.isOpen ? `${t('gallery.exitSearch')}` : `${t('gallery.displaySearch')}`}
aria-label={t('gallery.displaySearch')}
icon={<PiMagnifyingGlassBold />}
/>
</TabList>
</Tabs>

View File

@@ -15,8 +15,8 @@ import { Panel, PanelGroup } from 'react-resizable-panels';
import BoardsListWrapper from './Boards/BoardsList/BoardsListWrapper';
import BoardsSearch from './Boards/BoardsList/BoardsSearch';
import BoardsSettingsPopover from './Boards/BoardsSettingsPopover';
import { Gallery } from './Gallery';
import GallerySettingsPopover from './GallerySettingsPopover/GallerySettingsPopover';
const COLLAPSE_STYLES: CSSProperties = { flexShrink: 0, minHeight: 0 };
@@ -64,7 +64,7 @@ const GalleryPanelContent = () => {
</Flex>
<GalleryHeader />
<Flex h="full" w="25%" justifyContent="flex-end">
<BoardsSettingsPopover />
<GallerySettingsPopover />
<IconButton
size="sm"
variant="link"

View File

@@ -1,7 +1,10 @@
import { Divider, Flex, IconButton, Popover, PopoverBody, PopoverContent, PopoverTrigger } from '@invoke-ai/ui-library';
import BoardAutoAddSelect from 'features/gallery/components/Boards/BoardAutoAddSelect';
import AlwaysShowImageSizeCheckbox from 'features/gallery/components/GallerySettingsPopover/AlwaysShowImageSizeCheckbox';
import AutoAssignBoardCheckbox from 'features/gallery/components/GallerySettingsPopover/AutoAssignBoardCheckbox';
import AutoSwitchCheckbox from 'features/gallery/components/GallerySettingsPopover/AutoSwitchCheckbox';
import ImageMinimumWidthSlider from 'features/gallery/components/GallerySettingsPopover/ImageMinimumWidthSlider';
import ShowArchivedBoardsCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowArchivedBoardsCheckbox';
import ShowStarredFirstCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowStarredFirstCheckbox';
import SortDirectionCombobox from 'features/gallery/components/GallerySettingsPopover/SortDirectionCombobox';
import { memo } from 'react';
@@ -18,9 +21,8 @@ const GallerySettingsPopover = () => {
size="sm"
variant="link"
alignSelf="stretch"
aria-label={t('gallery.imagesSettings')}
aria-label={t('gallery.gallerySettings')}
icon={<PiGearSixFill />}
tooltip={t('gallery.imagesSettings')}
/>
</PopoverTrigger>
<PopoverContent>
@@ -28,7 +30,10 @@ const GallerySettingsPopover = () => {
<Flex direction="column" gap={2}>
<ImageMinimumWidthSlider />
<AutoSwitchCheckbox />
<AutoAssignBoardCheckbox />
<AlwaysShowImageSizeCheckbox />
<ShowArchivedBoardsCheckbox />
<BoardAutoAddSelect />
<Divider pt={2} />
<ShowStarredFirstCheckbox />
<SortDirectionCombobox />

View File

@@ -1,276 +1,42 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Menu, MenuButton, MenuList, Portal, useGlobalMenuClose } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import type { ContextMenuProps } from '@invoke-ai/ui-library';
import { ContextMenu, MenuList } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import MultipleSelectionMenuItems from 'features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems';
import SingleSelectionMenuItems from 'features/gallery/components/ImageContextMenu/SingleSelectionMenuItems';
import { selectSelectionCount } from 'features/gallery/store/gallerySelectors';
import { map } from 'nanostores';
import type { RefObject } from 'react';
import { memo, useCallback, useEffect, useRef } from 'react';
import { memo, useCallback } from 'react';
import type { ImageDTO } from 'services/api/types';
/**
* The delay in milliseconds before the context menu opens on long press.
*/
const LONGPRESS_DELAY_MS = 500;
/**
* The threshold in pixels that the pointer must move before the long press is cancelled.
*/
const LONGPRESS_MOVE_THRESHOLD_PX = 10;
import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
import SingleSelectionMenuItems from './SingleSelectionMenuItems';
/**
* The singleton state of the context menu.
*/
const $imageContextMenuState = map<{
isOpen: boolean;
imageDTO: ImageDTO | null;
position: { x: number; y: number };
}>({
isOpen: false,
imageDTO: null,
position: { x: -1, y: -1 },
});
/**
* Convenience function to close the context menu.
*/
const onClose = () => {
$imageContextMenuState.setKey('isOpen', false);
type Props = {
imageDTO: ImageDTO | undefined;
children: ContextMenuProps<HTMLDivElement>['children'];
};
/**
* Map of elements to image DTOs. This is used to determine which image DTO to show the context menu for, depending on
* the target of the context menu or long press event.
*/
const elToImageMap = new Map<HTMLDivElement, ImageDTO>();
/**
* Given a target node, find the first registered parent element that contains the target node and return the imageDTO
* associated with it.
*/
const getImageDTOFromMap = (target: Node): ImageDTO | undefined => {
const entry = Array.from(elToImageMap.entries()).find((entry) => entry[0].contains(target));
return entry?.[1];
};
/**
* Register a context menu for an image DTO on a target element.
* @param imageDTO The image DTO to register the context menu for.
* @param targetRef The ref of the target element that should trigger the context menu.
*/
export const useImageContextMenu = (imageDTO: ImageDTO | undefined, targetRef: RefObject<HTMLDivElement>) => {
useEffect(() => {
if (!targetRef.current || !imageDTO) {
return;
}
const el = targetRef.current;
elToImageMap.set(el, imageDTO);
return () => {
elToImageMap.delete(el);
};
}, [imageDTO, targetRef]);
};
/**
* Singleton component that renders the context menu for images.
*/
export const ImageContextMenu = memo(() => {
useAssertSingleton('ImageContextMenu');
const state = useStore($imageContextMenuState);
useGlobalMenuClose(onClose);
return (
<Portal>
<Menu isOpen={state.isOpen} gutter={0} placement="auto-end" onClose={onClose}>
<MenuButton
aria-hidden={true}
w={1}
h={1}
position="absolute"
left={state.position.x}
top={state.position.y}
cursor="default"
bg="transparent"
_hover={_hover}
pointerEvents="none"
/>
<MenuContent />
</Menu>
<ImageContextMenuEventLogical />
</Portal>
);
});
ImageContextMenu.displayName = 'ImageContextMenu';
const _hover: ChakraProps['_hover'] = { bg: 'transparent' };
/**
* A logical component that listens for context menu events and opens the context menu. It's kept separate from the
* ImageContextMenu component to avoid re-rendering the whole context menu on every context menu event.
*/
const ImageContextMenuEventLogical = memo(() => {
const lastPositionRef = useRef<{ x: number; y: number }>({ x: -1, y: -1 });
const longPressTimeoutRef = useRef(0);
const animationTimeoutRef = useRef(0);
const onContextMenu = useCallback((e: MouseEvent | PointerEvent) => {
if (e.shiftKey) {
// This is a shift + right click event, which should open the native context menu
onClose();
return;
}
const imageDTO = getImageDTOFromMap(e.target as Node);
if (!imageDTO) {
// Can't find the image DTO, close the context menu
onClose();
return;
}
// clear pending delayed open
window.clearTimeout(animationTimeoutRef.current);
e.preventDefault();
if (lastPositionRef.current.x !== e.pageX || lastPositionRef.current.y !== e.pageY) {
// if the mouse moved, we need to close, wait for animation and reopen the menu at the new position
if ($imageContextMenuState.get().isOpen) {
onClose();
}
animationTimeoutRef.current = window.setTimeout(() => {
// Open the menu after the animation with the new state
$imageContextMenuState.set({
isOpen: true,
position: { x: e.pageX, y: e.pageY },
imageDTO,
});
}, 100);
} else {
// else we can just open the menu at the current position w/ new state
$imageContextMenuState.set({
isOpen: true,
position: { x: e.pageX, y: e.pageY },
imageDTO,
});
}
// Always sync the last position
lastPositionRef.current = { x: e.pageX, y: e.pageY };
}, []);
// Use a long press to open the context menu on touch devices
const onPointerDown = useCallback(
(e: PointerEvent) => {
if (e.pointerType === 'mouse') {
// Bail out if it's a mouse event - this is for touch/pen only
return;
}
longPressTimeoutRef.current = window.setTimeout(() => {
onContextMenu(e);
}, LONGPRESS_DELAY_MS);
lastPositionRef.current = { x: e.pageX, y: e.pageY };
},
[onContextMenu]
);
const onPointerMove = useCallback((e: PointerEvent) => {
if (e.pointerType === 'mouse') {
// Bail out if it's a mouse event - this is for touch/pen only
return;
}
if (longPressTimeoutRef.current === null) {
return;
}
// If the pointer has moved more than the threshold, cancel the long press
const lastPosition = lastPositionRef.current;
const distanceFromLastPosition = Math.hypot(e.pageX - lastPosition.x, e.pageY - lastPosition.y);
if (distanceFromLastPosition > LONGPRESS_MOVE_THRESHOLD_PX) {
clearTimeout(longPressTimeoutRef.current);
}
}, []);
const onPointerUp = useCallback((e: PointerEvent) => {
if (e.pointerType === 'mouse') {
// Bail out if it's a mouse event - this is for touch/pen only
return;
}
if (longPressTimeoutRef.current) {
clearTimeout(longPressTimeoutRef.current);
}
}, []);
const onPointerCancel = useCallback((e: PointerEvent) => {
if (e.pointerType === 'mouse') {
// Bail out if it's a mouse event - this is for touch/pen only
return;
}
if (longPressTimeoutRef.current) {
clearTimeout(longPressTimeoutRef.current);
}
}, []);
useEffect(() => {
const controller = new AbortController();
// Context menu events
window.addEventListener('contextmenu', onContextMenu, { signal: controller.signal });
// Long press events
window.addEventListener('pointerdown', onPointerDown, { signal: controller.signal });
window.addEventListener('pointerup', onPointerUp, { signal: controller.signal });
window.addEventListener('pointercancel', onPointerCancel, { signal: controller.signal });
window.addEventListener('pointermove', onPointerMove, { signal: controller.signal });
return () => {
controller.abort();
};
}, [onContextMenu, onPointerCancel, onPointerDown, onPointerMove, onPointerUp]);
useEffect(
() => () => {
// Clean up any timeouts when we unmount
window.clearTimeout(animationTimeoutRef.current);
window.clearTimeout(longPressTimeoutRef.current);
},
[]
);
return null;
});
ImageContextMenuEventLogical.displayName = 'ImageContextMenuEventLogical';
// The content of the context menu, which changes based on the selection count. Split out and memoized to avoid
// re-rendering the whole context menu too often.
const MenuContent = memo(() => {
const ImageContextMenu = ({ imageDTO, children }: Props) => {
const selectionCount = useAppSelector(selectSelectionCount);
const state = useStore($imageContextMenuState);
if (!state.imageDTO) {
return null;
}
const renderMenuFunc = useCallback(() => {
if (!imageDTO) {
return null;
}
if (selectionCount > 1) {
return (
<MenuList visibility="visible">
<MultipleSelectionMenuItems />
</MenuList>
);
}
if (selectionCount > 1) {
return (
<MenuList visibility="visible">
<MultipleSelectionMenuItems />
<SingleSelectionMenuItems imageDTO={imageDTO} />
</MenuList>
);
}
}, [imageDTO, selectionCount]);
return (
<MenuList visibility="visible">
<SingleSelectionMenuItems imageDTO={state.imageDTO} />
</MenuList>
);
});
return <ContextMenu renderMenu={renderMenuFunc}>{children}</ContextMenu>;
};
MenuContent.displayName = 'MenuContent';
export default memo(ImageContextMenu);

View File

@@ -1,7 +1,11 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useNewCanvasFromImage } from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { canvasReset } from 'features/controlLayers/store/actions';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { toast } from 'features/toast/toast';
@@ -10,16 +14,23 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFileBold } from 'react-icons/pi';
const selectBboxRect = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect);
export const ImageMenuItemNewCanvasFromImage = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const imageDTO = useImageDTOContext();
const bboxRect = useAppSelector(selectBboxRect);
const imageViewer = useImageViewer();
const newCanvasFromImage = useNewCanvasFromImage();
const isBusy = useCanvasIsBusy();
const onClick = useCallback(() => {
newCanvasFromImage(imageDTO);
const handleSendToCanvas = useCallback(() => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRasterLayerState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
dispatch(canvasReset());
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
dispatch(setActiveTab('canvas'));
imageViewer.close();
toast({
@@ -27,10 +38,10 @@ export const ImageMenuItemNewCanvasFromImage = memo(() => {
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [dispatch, imageDTO, imageViewer, newCanvasFromImage, t]);
}, [bboxRect.x, bboxRect.y, dispatch, imageDTO, imageViewer, t]);
return (
<MenuItem icon={<PiFileBold />} onClickCapture={onClick} isDisabled={isBusy}>
<MenuItem icon={<PiFileBold />} onClickCapture={handleSendToCanvas}>
{t('controlLayers.newCanvasFromImage')}
</MenuItem>
);

View File

@@ -1,8 +1,11 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
import { useNewRasterLayerFromImage } from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { sentImageToCanvas } from 'features/gallery/store/actions';
@@ -11,17 +14,23 @@ import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selectBboxRect = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect);
export const ImageMenuItemNewLayerFromImage = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const imageDTO = useImageDTOContext();
const bboxRect = useAppSelector(selectBboxRect);
const imageViewer = useImageViewer();
const newRasterLayerFromImage = useNewRasterLayerFromImage();
const isBusy = useCanvasIsBusy();
const onClick = useCallback(() => {
const handleSendToCanvas = useCallback(() => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRasterLayerState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
dispatch(sentImageToCanvas());
newRasterLayerFromImage(imageDTO);
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
dispatch(setActiveTab('canvas'));
imageViewer.close();
toast({
@@ -29,10 +38,10 @@ export const ImageMenuItemNewLayerFromImage = memo(() => {
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [dispatch, imageDTO, imageViewer, newRasterLayerFromImage, t]);
}, [bboxRect.x, bboxRect.y, dispatch, imageDTO, imageViewer, t]);
return (
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClick} isDisabled={isBusy}>
<MenuItem icon={<NewLayerIcon />} onClickCapture={handleSendToCanvas}>
{t('controlLayers.newLayerFromImage')}
</MenuItem>
);

View File

@@ -1,6 +1,5 @@
import { MenuDivider } from '@invoke-ai/ui-library';
import { IconMenuItemGroup } from 'common/components/IconMenuItem';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { ImageMenuItemChangeBoard } from 'features/gallery/components/ImageContextMenu/ImageMenuItemChangeBoard';
import { ImageMenuItemCopy } from 'features/gallery/components/ImageContextMenu/ImageMenuItemCopy';
import { ImageMenuItemDelete } from 'features/gallery/components/ImageContextMenu/ImageMenuItemDelete';
@@ -38,10 +37,8 @@ const SingleSelectionMenuItems = ({ imageDTO }: SingleSelectionMenuItemsProps) =
<ImageMenuItemMetadataRecallActions />
<MenuDivider />
<ImageMenuItemSendToUpscale />
<CanvasManagerProviderGate>
<ImageMenuItemNewLayerFromImage />
<ImageMenuItemNewCanvasFromImage />
</CanvasManagerProviderGate>
<ImageMenuItemNewLayerFromImage />
<ImageMenuItemNewCanvasFromImage />
<MenuDivider />
<ImageMenuItemChangeBoard />
<ImageMenuItemStarUnstar />

View File

@@ -63,9 +63,12 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
() => createSelector(selectGallerySlice, (gallery) => gallery.imageToCompare?.image_name === imageDTO.image_name),
[imageDTO.image_name]
);
const alwaysShowImageSizeBadge = useAppSelector(selectAlwaysShouldImageSizeBadge);
const isSelectedForCompare = useAppSelector(selectIsSelectedForCompare);
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
const customStarUi = useStore($customStarUI);
const imageContainerRef = useScrollIntoView(isSelected, index, areMultiplesSelected);
const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
@@ -88,6 +91,20 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
}
}, [imageDTO, selectedBoardId, areMultiplesSelected]);
const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();
const toggleStarredState = useCallback(() => {
if (imageDTO) {
if (imageDTO.starred) {
unstarImages({ imageDTOs: [imageDTO] });
}
if (!imageDTO.starred) {
starImages({ imageDTOs: [imageDTO] });
}
}
}, [starImages, unstarImages, imageDTO]);
const [isHovered, setIsHovered] = useState(false);
const handleMouseOver = useCallback(() => {
@@ -104,6 +121,25 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
setIsHovered(false);
}, []);
const starIcon = useMemo(() => {
if (imageDTO.starred) {
return customStarUi ? customStarUi.on.icon : <PiStarFill />;
}
if (!imageDTO.starred && isHovered) {
return customStarUi ? customStarUi.off.icon : <PiStarBold />;
}
}, [imageDTO.starred, isHovered, customStarUi]);
const starTooltip = useMemo(() => {
if (imageDTO.starred) {
return customStarUi ? customStarUi.off.text : 'Unstar';
}
if (!imageDTO.starred) {
return customStarUi ? customStarUi.on.text : 'Star';
}
return '';
}, [imageDTO.starred, customStarUi]);
const dataTestId = useMemo(() => getGalleryImageDataTestId(imageDTO.image_name), [imageDTO.image_name]);
if (!imageDTO) {
@@ -119,8 +155,6 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
justifyContent="center"
alignItems="center"
aspectRatio="1/1"
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
<IAIDndImage
onClick={handleClick}
@@ -135,8 +169,38 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
isUploadDisabled={true}
thumbnail={true}
withHoverOverlay
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
<HoverIcons imageDTO={imageDTO} isHovered={isHovered} />
<>
{(isHovered || alwaysShowImageSizeBadge) && (
<Text
position="absolute"
background="base.900"
color="base.50"
fontSize="sm"
fontWeight="semibold"
bottom={1}
left={1}
opacity={0.7}
px={2}
lineHeight={1.25}
borderTopEndRadius="base"
sx={badgeSx}
pointerEvents="none"
>{`${imageDTO.width}x${imageDTO.height}`}</Text>
)}
<IAIDndImageIcon
onClick={toggleStarredState}
icon={starIcon}
tooltip={starTooltip}
position="absolute"
top={2}
insetInlineEnd={2}
/>
{isHovered && <DeleteIcon imageDTO={imageDTO} />}
{isHovered && <OpenInViewerIconButton imageDTO={imageDTO} />}
</>
</IAIDndImage>
</Flex>
</Box>
@@ -145,21 +209,7 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
GalleryImageContent.displayName = 'GalleryImageContent';
const HoverIcons = memo(({ imageDTO, isHovered }: { imageDTO: ImageDTO; isHovered: boolean }) => {
const alwaysShowImageSizeBadge = useAppSelector(selectAlwaysShouldImageSizeBadge);
return (
<>
{(isHovered || alwaysShowImageSizeBadge) && <SizeBadge imageDTO={imageDTO} />}
{(isHovered || imageDTO.starred) && <StarIcon imageDTO={imageDTO} />}
{isHovered && <DeleteIcon imageDTO={imageDTO} />}
{isHovered && <OpenInViewerIconButton imageDTO={imageDTO} />}
</>
);
});
HoverIcons.displayName = 'HoverIcons';
const DeleteIcon = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const DeleteIcon = ({ imageDTO }: { imageDTO: ImageDTO }) => {
const shift = useShiftModifier();
const { t } = useTranslation();
const dispatch = useAppDispatch();
@@ -188,11 +238,9 @@ const DeleteIcon = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
insetInlineEnd={2}
/>
);
});
};
DeleteIcon.displayName = 'DeleteIcon';
const OpenInViewerIconButton = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const OpenInViewerIconButton = ({ imageDTO }: { imageDTO: ImageDTO }) => {
const imageViewer = useImageViewer();
const { t } = useTranslation();
@@ -210,77 +258,4 @@ const OpenInViewerIconButton = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
insetInlineStart={2}
/>
);
});
OpenInViewerIconButton.displayName = 'OpenInViewerIconButton';
const StarIcon = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const customStarUi = useStore($customStarUI);
const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();
const toggleStarredState = useCallback(() => {
if (imageDTO) {
if (imageDTO.starred) {
unstarImages({ imageDTOs: [imageDTO] });
}
if (!imageDTO.starred) {
starImages({ imageDTOs: [imageDTO] });
}
}
}, [starImages, unstarImages, imageDTO]);
const starIcon = useMemo(() => {
if (imageDTO.starred) {
return customStarUi ? customStarUi.on.icon : <PiStarFill />;
}
if (!imageDTO.starred) {
return customStarUi ? customStarUi.off.icon : <PiStarBold />;
}
}, [imageDTO.starred, customStarUi]);
const starTooltip = useMemo(() => {
if (imageDTO.starred) {
return customStarUi ? customStarUi.off.text : 'Unstar';
}
if (!imageDTO.starred) {
return customStarUi ? customStarUi.on.text : 'Star';
}
return '';
}, [imageDTO.starred, customStarUi]);
return (
<IAIDndImageIcon
onClick={toggleStarredState}
icon={starIcon}
tooltip={starTooltip}
position="absolute"
top={2}
insetInlineEnd={2}
/>
);
});
StarIcon.displayName = 'StarIcon';
const SizeBadge = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
return (
<Text
position="absolute"
background="base.900"
color="base.50"
fontSize="sm"
fontWeight="semibold"
bottom={1}
left={1}
opacity={0.7}
px={2}
lineHeight={1.25}
borderTopEndRadius="base"
sx={badgeSx}
pointerEvents="none"
>{`${imageDTO.width}x${imageDTO.height}`}</Text>
);
});
SizeBadge.displayName = 'SizeBadge';
};

View File

@@ -1,11 +1,11 @@
import { Button, Flex, IconButton, Spacer } from '@invoke-ai/ui-library';
import { ELLIPSIS, useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { memo, useCallback } from 'react';
import { useCallback } from 'react';
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
import { JumpTo } from './JumpTo';
export const GalleryPagination = memo(() => {
export const GalleryPagination = () => {
const { goPrev, goNext, isPrevEnabled, isNextEnabled, pageButtons, goToPage, currentPage, total } =
useGalleryPagination();
@@ -47,9 +47,7 @@ export const GalleryPagination = memo(() => {
<JumpTo />
</Flex>
);
});
GalleryPagination.displayName = 'GalleryPagination';
};
type PageButtonProps = {
page: number | typeof ELLIPSIS;
@@ -57,7 +55,7 @@ type PageButtonProps = {
goToPage: (page: number) => void;
};
const PageButton = memo(({ page, currentPage, goToPage }: PageButtonProps) => {
const PageButton = ({ page, currentPage, goToPage }: PageButtonProps) => {
if (page === ELLIPSIS) {
return (
<Button size="sm" variant="link" isDisabled>
@@ -70,6 +68,4 @@ const PageButton = memo(({ page, currentPage, goToPage }: PageButtonProps) => {
{page}
</Button>
);
});
PageButton.displayName = 'PageButton';
};

View File

@@ -11,11 +11,11 @@ import {
useDisclosure,
} from '@invoke-ai/ui-library';
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
import { useCallback, useEffect, useRef, useState } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
export const JumpTo = memo(() => {
export const JumpTo = () => {
const { t } = useTranslation();
const { goToPage, currentPage, pages } = useGalleryPagination();
const [newPage, setNewPage] = useState(currentPage);
@@ -64,7 +64,7 @@ export const JumpTo = memo(() => {
}, [currentPage]);
return (
<Popover isOpen={isOpen} onClose={onClose} onOpen={onOpen} isLazy lazyBehavior="unmount">
<Popover isOpen={isOpen} onClose={onClose} onOpen={onOpen}>
<PopoverTrigger>
<Button aria-label={t('gallery.jump')} size="sm" onClick={onToggle} variant="outline">
{t('gallery.jump')}
@@ -94,6 +94,4 @@ export const JumpTo = memo(() => {
</PopoverContent>
</Popover>
);
});
JumpTo.displayName = 'JumpTo';
};

View File

@@ -32,8 +32,6 @@ export const selectListImagesQueryArgs = createMemoizedSelector(
export const selectListBoardsQueryArgs = createMemoizedSelector(
selectGallerySlice,
(gallery): ListBoardsArgs => ({
order_by: gallery.boardsListOrderBy,
direction: gallery.boardsListOrderDir,
include_archived: gallery.shouldShowArchivedBoards ? true : undefined,
})
);
@@ -46,9 +44,6 @@ export const selectAutoAssignBoardOnClick = createSelector(
);
export const selectBoardSearchText = createSelector(selectGallerySlice, (gallery) => gallery.boardSearchText);
export const selectSearchTerm = createSelector(selectGallerySlice, (gallery) => gallery.searchTerm);
export const selectBoardsListOrderBy = createSelector(selectGallerySlice, (gallery) => gallery.boardsListOrderBy);
export const selectBoardsListOrderDir = createSelector(selectGallerySlice, (gallery) => gallery.boardsListOrderDir);
export const selectSelectionCount = createSelector(selectGallerySlice, (gallery) => gallery.selection.length);
export const selectHasMultipleImagesSelected = createSelector(selectSelectionCount, (count) => count > 1);
export const selectGalleryImageMinimumWidth = createSelector(

View File

@@ -2,7 +2,7 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { isEqual, uniqBy } from 'lodash-es';
import type { BoardRecordOrderBy, ImageDTO } from 'services/api/types';
import type { ImageDTO } from 'services/api/types';
import type { BoardId, ComparisonMode, GalleryState, GalleryView, OrderDir } from './types';
@@ -25,8 +25,6 @@ const initialGalleryState: GalleryState = {
comparisonMode: 'slider',
comparisonFit: 'fill',
shouldShowArchivedBoards: false,
boardsListOrderBy: 'created_at',
boardsListOrderDir: 'DESC',
};
export const gallerySlice = createSlice({
@@ -163,12 +161,6 @@ export const gallerySlice = createSlice({
state.searchTerm = action.payload;
state.offset = 0;
},
boardsListOrderByChanged: (state, action: PayloadAction<BoardRecordOrderBy>) => {
state.boardsListOrderBy = action.payload;
},
boardsListOrderDirChanged: (state, action: PayloadAction<OrderDir>) => {
state.boardsListOrderDir = action.payload;
},
},
});
@@ -194,8 +186,6 @@ export const {
starredFirstChanged,
shouldShowArchivedBoardsChanged,
searchTermChanged,
boardsListOrderByChanged,
boardsListOrderDirChanged,
} = gallerySlice.actions;
export const selectGallerySlice = (state: RootState) => state.gallery;

View File

@@ -1,4 +1,4 @@
import type { BoardRecordOrderBy, ImageCategory, ImageDTO } from 'services/api/types';
import type { ImageCategory, ImageDTO } from 'services/api/types';
export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
export const ASSETS_CATEGORIES: ImageCategory[] = ['control', 'mask', 'user', 'other'];
@@ -28,6 +28,4 @@ export type GalleryState = {
comparisonMode: ComparisonMode;
comparisonFit: ComparisonFit;
shouldShowArchivedBoards: boolean;
boardsListOrderBy: BoardRecordOrderBy;
boardsListOrderDir: OrderDir;
};

View File

@@ -1,6 +1,7 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { fieldBoardValueChanged } from 'features/nodes/store/nodesSlice';
import type { BoardFieldInputInstance, BoardFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback, useMemo } from 'react';
@@ -13,28 +14,26 @@ const BoardFieldInputComponent = (props: FieldComponentProps<BoardFieldInputInst
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const { options, hasBoards } = useListAllBoardsQuery(
{ include_archived: true },
{
selectFromResult: ({ data }) => {
const options: ComboboxOption[] = [
{
label: 'None',
value: 'none',
},
].concat(
(data ?? []).map(({ board_id, board_name }) => ({
label: board_name,
value: board_id,
}))
);
return {
options,
hasBoards: options.length > 1,
};
},
}
);
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { options, hasBoards } = useListAllBoardsQuery(queryArgs, {
selectFromResult: ({ data }) => {
const options: ComboboxOption[] = [
{
label: 'None',
value: 'none',
},
].concat(
(data ?? []).map(({ board_id, board_name }) => ({
label: board_name,
value: board_id,
}))
);
return {
options,
hasBoards: options.length > 1,
};
},
});
const onChange = useCallback<ComboboxOnChange>(
(v) => {

View File

@@ -43,7 +43,7 @@ export const ShareWorkflowModal = () => {
if (!workflowToShare || !projectUrl) {
return null;
}
return `${window.location.origin}/${projectUrl}/studio?selectedWorkflowId=${workflowToShare.workflow_id}`;
return `${projectUrl}/studio?selectedWorkflowId=${workflowToShare.workflow_id}`;
}, [projectUrl, workflowToShare]);
const handleCopy = useCallback(() => {

View File

@@ -36,6 +36,8 @@ export const addControlNets = async (
};
for (const layer of validControlLayers) {
result.addedControlNets++;
const getImageDTOResult = await withResultAsync(() => {
const adapter = manager.adapters.controlLayers.get(layer.id);
assert(adapter, 'Adapter not found');
@@ -48,7 +50,6 @@ export const addControlNets = async (
const imageDTO = getImageDTOResult.value;
addControlNetToGraph(g, layer, imageDTO, collector);
result.addedControlNets++;
}
return result;
@@ -76,6 +77,8 @@ export const addT2IAdapters = async (
};
for (const layer of validControlLayers) {
result.addedT2IAdapters++;
const getImageDTOResult = await withResultAsync(() => {
const adapter = manager.adapters.controlLayers.get(layer.id);
assert(adapter, 'Adapter not found');
@@ -88,7 +91,6 @@ export const addT2IAdapters = async (
const imageDTO = getImageDTOResult.value;
addT2IAdapterToGraph(g, layer, imageDTO, collector);
result.addedT2IAdapters++;
}
return result;
@@ -108,10 +110,10 @@ const addControlNetToGraph = (
const controlNet = g.addNode({
id: `control_net_${id}`,
type: model.base === 'flux' ? 'flux_controlnet' : 'controlnet',
type: 'controlnet',
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
control_mode: model.base === 'flux' ? undefined : controlMode,
control_mode: controlMode,
resize_mode: 'just_resize',
control_model: model,
control_weight: weight,

View File

@@ -19,8 +19,6 @@ import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { addControlNets } from './addControlAdapters';
const log = logger('system');
export const buildFLUXGraph = async (
@@ -95,7 +93,6 @@ export const buildFLUXGraph = async (
> = l2i;
g.addEdge(modelLoader, 'transformer', noise, 'transformer');
g.addEdge(modelLoader, 'vae', noise, 'controlnet_vae');
g.addEdge(modelLoader, 'vae', l2i, 'vae');
g.addEdge(modelLoader, 'clip', posCond, 'clip');
@@ -180,24 +177,6 @@ export const buildFLUXGraph = async (
);
}
const controlNetCollector = g.addNode({
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets(
manager,
canvas.controlLayers.entities,
g,
canvas.bbox.rect,
controlNetCollector,
modelConfig.base
);
if (controlNetResult.addedControlNets > 0) {
g.addEdge(controlNetCollector, 'collection', noise, 'control');
} else {
g.deleteNode(controlNetCollector.id);
}
if (state.system.shouldUseNSFWChecker) {
canvasOutput = addNSFWChecker(g, canvasOutput);
}

View File

@@ -1,4 +1,5 @@
import { getStore } from 'app/store/nanostores/store';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { modelsApi } from 'services/api/endpoints/models';
@@ -43,9 +44,10 @@ export const checkImageAccess = async (name: string): Promise<boolean> => {
* @returns A promise that resolves to true if the client has access, else false.
*/
export const checkBoardAccess = async (id: string): Promise<boolean> => {
const { dispatch } = getStore();
const { dispatch, getState } = getStore();
try {
const req = dispatch(boardsApi.endpoints.listAllBoards.initiate({ include_archived: true }));
const queryArgs = selectListBoardsQueryArgs(getState());
const req = dispatch(boardsApi.endpoints.listAllBoards.initiate(queryArgs));
req.unsubscribe();
const result = await req.unwrap();
return result.some((b) => b.board_id === id);

View File

@@ -1,19 +1,19 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import type { BoardId } from 'features/gallery/store/types';
import { t } from 'i18next';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
export const useBoardName = (board_id: BoardId) => {
const { boardName } = useListAllBoardsQuery(
{ include_archived: true },
{
selectFromResult: ({ data }) => {
const selectedBoard = data?.find((b) => b.board_id === board_id);
const boardName = selectedBoard?.board_name || t('boards.uncategorized');
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { boardName } = useListAllBoardsQuery(queryArgs, {
selectFromResult: ({ data }) => {
const selectedBoard = data?.find((b) => b.board_id === board_id);
const boardName = selectedBoard?.board_name || t('boards.uncategorized');
return { boardName };
},
}
);
return { boardName };
},
});
return boardName;
};

File diff suppressed because one or more lines are too long

View File

@@ -241,5 +241,3 @@ export type PostUploadAction =
| RGIPAdapterImagePostUploadAction
| UpscaleInitialImageAction
| ReplaceLayerWithImagePostUploadAction;
export type BoardRecordOrderBy = S['BoardRecordOrderBy'];

View File

@@ -24,15 +24,6 @@ export default defineConfig(({ mode }) => {
cssInjectedByJsPlugin(),
],
build: {
/**
* zone.js (via faro) requires max ES2015 to prevent spamming unhandled promise rejections.
*
* See:
* - https://github.com/grafana/faro-web-sdk/issues/566
* - https://github.com/angular/angular/issues/51328
* - https://github.com/open-telemetry/opentelemetry-js/issues/3030
*/
target: 'ES2015',
cssCodeSplit: true,
lib: {
entry: path.resolve(__dirname, './src/index.ts'),

View File

@@ -1 +1 @@
__version__ = "5.2.0rc1"
__version__ = "5.1.1"

View File

@@ -43,8 +43,8 @@ dependencies = [
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe>=0.10.7", # needed for "mediapipeface" controlnet model
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
"onnx==1.16.1",
"onnxruntime==1.19.2",
"onnx>=1.15.0",
"onnxruntime>=1.16.3",
"opencv-python==4.9.0.80",
"pytorch-lightning==2.1.3",
"safetensors==0.4.3",

View File

@@ -1,30 +0,0 @@
import argparse
import json
from safetensors.torch import load_file
def extract_sd_keys_and_shapes(safetensors_file: str):
sd = load_file(safetensors_file)
keys_to_shapes = {k: v.shape for k, v in sd.items()}
out_file = "keys_and_shapes.json"
with open(out_file, "w") as f:
json.dump(keys_to_shapes, f, indent=4)
print(f"Keys and shapes written to '{out_file}'.")
def main():
parser = argparse.ArgumentParser(
description="Extracts the keys and shapes from the state dict in a safetensors file. Intended for creating "
+ "dummy state dicts for use in unit tests."
)
parser.add_argument("safetensors_file", type=str, help="Path to the safetensors file.")
args = parser.parse_args()
extract_sd_keys_and_shapes(args.safetensors_file)
if __name__ == "__main__":
main()

View File

@@ -1,374 +0,0 @@
# State dict keys and shapes for an InstantX FLUX ControlNet Union model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/4f32d6f2b220f8873d49bb8acc073e1df180c994/diffusion_pytorch_model.safetensors
instantx_sd_shapes = {
"context_embedder.bias": [3072],
"context_embedder.weight": [3072, 4096],
"controlnet_blocks.0.bias": [3072],
"controlnet_blocks.0.weight": [3072, 3072],
"controlnet_blocks.1.bias": [3072],
"controlnet_blocks.1.weight": [3072, 3072],
"controlnet_blocks.2.bias": [3072],
"controlnet_blocks.2.weight": [3072, 3072],
"controlnet_blocks.3.bias": [3072],
"controlnet_blocks.3.weight": [3072, 3072],
"controlnet_blocks.4.bias": [3072],
"controlnet_blocks.4.weight": [3072, 3072],
"controlnet_mode_embedder.weight": [10, 3072],
"controlnet_single_blocks.0.bias": [3072],
"controlnet_single_blocks.0.weight": [3072, 3072],
"controlnet_single_blocks.1.bias": [3072],
"controlnet_single_blocks.1.weight": [3072, 3072],
"controlnet_single_blocks.2.bias": [3072],
"controlnet_single_blocks.2.weight": [3072, 3072],
"controlnet_single_blocks.3.bias": [3072],
"controlnet_single_blocks.3.weight": [3072, 3072],
"controlnet_single_blocks.4.bias": [3072],
"controlnet_single_blocks.4.weight": [3072, 3072],
"controlnet_single_blocks.5.bias": [3072],
"controlnet_single_blocks.5.weight": [3072, 3072],
"controlnet_single_blocks.6.bias": [3072],
"controlnet_single_blocks.6.weight": [3072, 3072],
"controlnet_single_blocks.7.bias": [3072],
"controlnet_single_blocks.7.weight": [3072, 3072],
"controlnet_single_blocks.8.bias": [3072],
"controlnet_single_blocks.8.weight": [3072, 3072],
"controlnet_single_blocks.9.bias": [3072],
"controlnet_single_blocks.9.weight": [3072, 3072],
"controlnet_x_embedder.bias": [3072],
"controlnet_x_embedder.weight": [3072, 64],
"single_transformer_blocks.0.attn.norm_k.weight": [128],
"single_transformer_blocks.0.attn.norm_q.weight": [128],
"single_transformer_blocks.0.attn.to_k.bias": [3072],
"single_transformer_blocks.0.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.0.attn.to_q.bias": [3072],
"single_transformer_blocks.0.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.0.attn.to_v.bias": [3072],
"single_transformer_blocks.0.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.0.norm.linear.bias": [9216],
"single_transformer_blocks.0.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.0.proj_mlp.bias": [12288],
"single_transformer_blocks.0.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.0.proj_out.bias": [3072],
"single_transformer_blocks.0.proj_out.weight": [3072, 15360],
"single_transformer_blocks.1.attn.norm_k.weight": [128],
"single_transformer_blocks.1.attn.norm_q.weight": [128],
"single_transformer_blocks.1.attn.to_k.bias": [3072],
"single_transformer_blocks.1.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.1.attn.to_q.bias": [3072],
"single_transformer_blocks.1.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.1.attn.to_v.bias": [3072],
"single_transformer_blocks.1.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.1.norm.linear.bias": [9216],
"single_transformer_blocks.1.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.1.proj_mlp.bias": [12288],
"single_transformer_blocks.1.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.1.proj_out.bias": [3072],
"single_transformer_blocks.1.proj_out.weight": [3072, 15360],
"single_transformer_blocks.2.attn.norm_k.weight": [128],
"single_transformer_blocks.2.attn.norm_q.weight": [128],
"single_transformer_blocks.2.attn.to_k.bias": [3072],
"single_transformer_blocks.2.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.2.attn.to_q.bias": [3072],
"single_transformer_blocks.2.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.2.attn.to_v.bias": [3072],
"single_transformer_blocks.2.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.2.norm.linear.bias": [9216],
"single_transformer_blocks.2.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.2.proj_mlp.bias": [12288],
"single_transformer_blocks.2.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.2.proj_out.bias": [3072],
"single_transformer_blocks.2.proj_out.weight": [3072, 15360],
"single_transformer_blocks.3.attn.norm_k.weight": [128],
"single_transformer_blocks.3.attn.norm_q.weight": [128],
"single_transformer_blocks.3.attn.to_k.bias": [3072],
"single_transformer_blocks.3.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.3.attn.to_q.bias": [3072],
"single_transformer_blocks.3.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.3.attn.to_v.bias": [3072],
"single_transformer_blocks.3.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.3.norm.linear.bias": [9216],
"single_transformer_blocks.3.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.3.proj_mlp.bias": [12288],
"single_transformer_blocks.3.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.3.proj_out.bias": [3072],
"single_transformer_blocks.3.proj_out.weight": [3072, 15360],
"single_transformer_blocks.4.attn.norm_k.weight": [128],
"single_transformer_blocks.4.attn.norm_q.weight": [128],
"single_transformer_blocks.4.attn.to_k.bias": [3072],
"single_transformer_blocks.4.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.4.attn.to_q.bias": [3072],
"single_transformer_blocks.4.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.4.attn.to_v.bias": [3072],
"single_transformer_blocks.4.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.4.norm.linear.bias": [9216],
"single_transformer_blocks.4.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.4.proj_mlp.bias": [12288],
"single_transformer_blocks.4.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.4.proj_out.bias": [3072],
"single_transformer_blocks.4.proj_out.weight": [3072, 15360],
"single_transformer_blocks.5.attn.norm_k.weight": [128],
"single_transformer_blocks.5.attn.norm_q.weight": [128],
"single_transformer_blocks.5.attn.to_k.bias": [3072],
"single_transformer_blocks.5.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.5.attn.to_q.bias": [3072],
"single_transformer_blocks.5.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.5.attn.to_v.bias": [3072],
"single_transformer_blocks.5.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.5.norm.linear.bias": [9216],
"single_transformer_blocks.5.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.5.proj_mlp.bias": [12288],
"single_transformer_blocks.5.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.5.proj_out.bias": [3072],
"single_transformer_blocks.5.proj_out.weight": [3072, 15360],
"single_transformer_blocks.6.attn.norm_k.weight": [128],
"single_transformer_blocks.6.attn.norm_q.weight": [128],
"single_transformer_blocks.6.attn.to_k.bias": [3072],
"single_transformer_blocks.6.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.6.attn.to_q.bias": [3072],
"single_transformer_blocks.6.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.6.attn.to_v.bias": [3072],
"single_transformer_blocks.6.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.6.norm.linear.bias": [9216],
"single_transformer_blocks.6.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.6.proj_mlp.bias": [12288],
"single_transformer_blocks.6.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.6.proj_out.bias": [3072],
"single_transformer_blocks.6.proj_out.weight": [3072, 15360],
"single_transformer_blocks.7.attn.norm_k.weight": [128],
"single_transformer_blocks.7.attn.norm_q.weight": [128],
"single_transformer_blocks.7.attn.to_k.bias": [3072],
"single_transformer_blocks.7.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.7.attn.to_q.bias": [3072],
"single_transformer_blocks.7.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.7.attn.to_v.bias": [3072],
"single_transformer_blocks.7.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.7.norm.linear.bias": [9216],
"single_transformer_blocks.7.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.7.proj_mlp.bias": [12288],
"single_transformer_blocks.7.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.7.proj_out.bias": [3072],
"single_transformer_blocks.7.proj_out.weight": [3072, 15360],
"single_transformer_blocks.8.attn.norm_k.weight": [128],
"single_transformer_blocks.8.attn.norm_q.weight": [128],
"single_transformer_blocks.8.attn.to_k.bias": [3072],
"single_transformer_blocks.8.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.8.attn.to_q.bias": [3072],
"single_transformer_blocks.8.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.8.attn.to_v.bias": [3072],
"single_transformer_blocks.8.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.8.norm.linear.bias": [9216],
"single_transformer_blocks.8.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.8.proj_mlp.bias": [12288],
"single_transformer_blocks.8.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.8.proj_out.bias": [3072],
"single_transformer_blocks.8.proj_out.weight": [3072, 15360],
"single_transformer_blocks.9.attn.norm_k.weight": [128],
"single_transformer_blocks.9.attn.norm_q.weight": [128],
"single_transformer_blocks.9.attn.to_k.bias": [3072],
"single_transformer_blocks.9.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.9.attn.to_q.bias": [3072],
"single_transformer_blocks.9.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.9.attn.to_v.bias": [3072],
"single_transformer_blocks.9.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.9.norm.linear.bias": [9216],
"single_transformer_blocks.9.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.9.proj_mlp.bias": [12288],
"single_transformer_blocks.9.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.9.proj_out.bias": [3072],
"single_transformer_blocks.9.proj_out.weight": [3072, 15360],
"time_text_embed.guidance_embedder.linear_1.bias": [3072],
"time_text_embed.guidance_embedder.linear_1.weight": [3072, 256],
"time_text_embed.guidance_embedder.linear_2.bias": [3072],
"time_text_embed.guidance_embedder.linear_2.weight": [3072, 3072],
"time_text_embed.text_embedder.linear_1.bias": [3072],
"time_text_embed.text_embedder.linear_1.weight": [3072, 768],
"time_text_embed.text_embedder.linear_2.bias": [3072],
"time_text_embed.text_embedder.linear_2.weight": [3072, 3072],
"time_text_embed.timestep_embedder.linear_1.bias": [3072],
"time_text_embed.timestep_embedder.linear_1.weight": [3072, 256],
"time_text_embed.timestep_embedder.linear_2.bias": [3072],
"time_text_embed.timestep_embedder.linear_2.weight": [3072, 3072],
"transformer_blocks.0.attn.add_k_proj.bias": [3072],
"transformer_blocks.0.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.0.attn.add_q_proj.bias": [3072],
"transformer_blocks.0.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.0.attn.add_v_proj.bias": [3072],
"transformer_blocks.0.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.0.attn.norm_added_k.weight": [128],
"transformer_blocks.0.attn.norm_added_q.weight": [128],
"transformer_blocks.0.attn.norm_k.weight": [128],
"transformer_blocks.0.attn.norm_q.weight": [128],
"transformer_blocks.0.attn.to_add_out.bias": [3072],
"transformer_blocks.0.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.0.attn.to_k.bias": [3072],
"transformer_blocks.0.attn.to_k.weight": [3072, 3072],
"transformer_blocks.0.attn.to_out.0.bias": [3072],
"transformer_blocks.0.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.0.attn.to_q.bias": [3072],
"transformer_blocks.0.attn.to_q.weight": [3072, 3072],
"transformer_blocks.0.attn.to_v.bias": [3072],
"transformer_blocks.0.attn.to_v.weight": [3072, 3072],
"transformer_blocks.0.ff.net.0.proj.bias": [12288],
"transformer_blocks.0.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.0.ff.net.2.bias": [3072],
"transformer_blocks.0.ff.net.2.weight": [3072, 12288],
"transformer_blocks.0.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.0.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.0.ff_context.net.2.bias": [3072],
"transformer_blocks.0.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.0.norm1.linear.bias": [18432],
"transformer_blocks.0.norm1.linear.weight": [18432, 3072],
"transformer_blocks.0.norm1_context.linear.bias": [18432],
"transformer_blocks.0.norm1_context.linear.weight": [18432, 3072],
"transformer_blocks.1.attn.add_k_proj.bias": [3072],
"transformer_blocks.1.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.1.attn.add_q_proj.bias": [3072],
"transformer_blocks.1.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.1.attn.add_v_proj.bias": [3072],
"transformer_blocks.1.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.1.attn.norm_added_k.weight": [128],
"transformer_blocks.1.attn.norm_added_q.weight": [128],
"transformer_blocks.1.attn.norm_k.weight": [128],
"transformer_blocks.1.attn.norm_q.weight": [128],
"transformer_blocks.1.attn.to_add_out.bias": [3072],
"transformer_blocks.1.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.1.attn.to_k.bias": [3072],
"transformer_blocks.1.attn.to_k.weight": [3072, 3072],
"transformer_blocks.1.attn.to_out.0.bias": [3072],
"transformer_blocks.1.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.1.attn.to_q.bias": [3072],
"transformer_blocks.1.attn.to_q.weight": [3072, 3072],
"transformer_blocks.1.attn.to_v.bias": [3072],
"transformer_blocks.1.attn.to_v.weight": [3072, 3072],
"transformer_blocks.1.ff.net.0.proj.bias": [12288],
"transformer_blocks.1.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.1.ff.net.2.bias": [3072],
"transformer_blocks.1.ff.net.2.weight": [3072, 12288],
"transformer_blocks.1.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.1.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.1.ff_context.net.2.bias": [3072],
"transformer_blocks.1.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.1.norm1.linear.bias": [18432],
"transformer_blocks.1.norm1.linear.weight": [18432, 3072],
"transformer_blocks.1.norm1_context.linear.bias": [18432],
"transformer_blocks.1.norm1_context.linear.weight": [18432, 3072],
"transformer_blocks.2.attn.add_k_proj.bias": [3072],
"transformer_blocks.2.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.2.attn.add_q_proj.bias": [3072],
"transformer_blocks.2.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.2.attn.add_v_proj.bias": [3072],
"transformer_blocks.2.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.2.attn.norm_added_k.weight": [128],
"transformer_blocks.2.attn.norm_added_q.weight": [128],
"transformer_blocks.2.attn.norm_k.weight": [128],
"transformer_blocks.2.attn.norm_q.weight": [128],
"transformer_blocks.2.attn.to_add_out.bias": [3072],
"transformer_blocks.2.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.2.attn.to_k.bias": [3072],
"transformer_blocks.2.attn.to_k.weight": [3072, 3072],
"transformer_blocks.2.attn.to_out.0.bias": [3072],
"transformer_blocks.2.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.2.attn.to_q.bias": [3072],
"transformer_blocks.2.attn.to_q.weight": [3072, 3072],
"transformer_blocks.2.attn.to_v.bias": [3072],
"transformer_blocks.2.attn.to_v.weight": [3072, 3072],
"transformer_blocks.2.ff.net.0.proj.bias": [12288],
"transformer_blocks.2.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.2.ff.net.2.bias": [3072],
"transformer_blocks.2.ff.net.2.weight": [3072, 12288],
"transformer_blocks.2.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.2.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.2.ff_context.net.2.bias": [3072],
"transformer_blocks.2.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.2.norm1.linear.bias": [18432],
"transformer_blocks.2.norm1.linear.weight": [18432, 3072],
"transformer_blocks.2.norm1_context.linear.bias": [18432],
"transformer_blocks.2.norm1_context.linear.weight": [18432, 3072],
"transformer_blocks.3.attn.add_k_proj.bias": [3072],
"transformer_blocks.3.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.3.attn.add_q_proj.bias": [3072],
"transformer_blocks.3.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.3.attn.add_v_proj.bias": [3072],
"transformer_blocks.3.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.3.attn.norm_added_k.weight": [128],
"transformer_blocks.3.attn.norm_added_q.weight": [128],
"transformer_blocks.3.attn.norm_k.weight": [128],
"transformer_blocks.3.attn.norm_q.weight": [128],
"transformer_blocks.3.attn.to_add_out.bias": [3072],
"transformer_blocks.3.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.3.attn.to_k.bias": [3072],
"transformer_blocks.3.attn.to_k.weight": [3072, 3072],
"transformer_blocks.3.attn.to_out.0.bias": [3072],
"transformer_blocks.3.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.3.attn.to_q.bias": [3072],
"transformer_blocks.3.attn.to_q.weight": [3072, 3072],
"transformer_blocks.3.attn.to_v.bias": [3072],
"transformer_blocks.3.attn.to_v.weight": [3072, 3072],
"transformer_blocks.3.ff.net.0.proj.bias": [12288],
"transformer_blocks.3.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.3.ff.net.2.bias": [3072],
"transformer_blocks.3.ff.net.2.weight": [3072, 12288],
"transformer_blocks.3.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.3.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.3.ff_context.net.2.bias": [3072],
"transformer_blocks.3.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.3.norm1.linear.bias": [18432],
"transformer_blocks.3.norm1.linear.weight": [18432, 3072],
"transformer_blocks.3.norm1_context.linear.bias": [18432],
"transformer_blocks.3.norm1_context.linear.weight": [18432, 3072],
"transformer_blocks.4.attn.add_k_proj.bias": [3072],
"transformer_blocks.4.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.4.attn.add_q_proj.bias": [3072],
"transformer_blocks.4.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.4.attn.add_v_proj.bias": [3072],
"transformer_blocks.4.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.4.attn.norm_added_k.weight": [128],
"transformer_blocks.4.attn.norm_added_q.weight": [128],
"transformer_blocks.4.attn.norm_k.weight": [128],
"transformer_blocks.4.attn.norm_q.weight": [128],
"transformer_blocks.4.attn.to_add_out.bias": [3072],
"transformer_blocks.4.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.4.attn.to_k.bias": [3072],
"transformer_blocks.4.attn.to_k.weight": [3072, 3072],
"transformer_blocks.4.attn.to_out.0.bias": [3072],
"transformer_blocks.4.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.4.attn.to_q.bias": [3072],
"transformer_blocks.4.attn.to_q.weight": [3072, 3072],
"transformer_blocks.4.attn.to_v.bias": [3072],
"transformer_blocks.4.attn.to_v.weight": [3072, 3072],
"transformer_blocks.4.ff.net.0.proj.bias": [12288],
"transformer_blocks.4.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.4.ff.net.2.bias": [3072],
"transformer_blocks.4.ff.net.2.weight": [3072, 12288],
"transformer_blocks.4.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.4.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.4.ff_context.net.2.bias": [3072],
"transformer_blocks.4.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.4.norm1.linear.bias": [18432],
"transformer_blocks.4.norm1.linear.weight": [18432, 3072],
"transformer_blocks.4.norm1_context.linear.bias": [18432],
"transformer_blocks.4.norm1_context.linear.weight": [18432, 3072],
"x_embedder.bias": [3072],
"x_embedder.weight": [3072, 64],
}
# InstantX FLUX ControlNet config for unit tests.
# Copied from https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/main/config.json
instantx_config = {
"_class_name": "FluxControlNetModel",
"_diffusers_version": "0.30.0.dev0",
"_name_or_path": "/mnt/wangqixun/",
"attention_head_dim": 128,
"axes_dims_rope": [16, 56, 56],
"guidance_embeds": True,
"in_channels": 64,
"joint_attention_dim": 4096,
"num_attention_heads": 24,
"num_layers": 5,
"num_mode": 10,
"num_single_layers": 10,
"patch_size": 1,
"pooled_projection_dim": 768,
}

View File

@@ -1,108 +0,0 @@
import sys

import pytest
import torch

from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.state_dict_utils import (
    convert_diffusers_instantx_state_dict_to_bfl_format,
    infer_flux_params_from_state_dict,
    infer_instantx_num_control_modes_from_state_dict,
    is_state_dict_instantx_controlnet,
    is_state_dict_xlabs_controlnet,
)
from tests.backend.flux.controlnet.instantx_flux_controlnet_state_dict import instantx_config, instantx_sd_shapes
from tests.backend.flux.controlnet.xlabs_flux_controlnet_state_dict import xlabs_sd_shapes


@pytest.mark.parametrize(
    ["sd_shapes", "expected"],
    [
        (xlabs_sd_shapes, True),
        (instantx_sd_shapes, False),
        (["foo"], False),
    ],
)
def test_is_state_dict_xlabs_controlnet(sd_shapes: dict[str, list[int]], expected: bool):
    sd = {k: None for k in sd_shapes}
    assert is_state_dict_xlabs_controlnet(sd) == expected


@pytest.mark.parametrize(
    ["sd_keys", "expected"],
    [
        (instantx_sd_shapes, True),
        (xlabs_sd_shapes, False),
        (["foo"], False),
    ],
)
def test_is_state_dict_instantx_controlnet(sd_keys: list[str], expected: bool):
    sd = {k: None for k in sd_keys}
    assert is_state_dict_instantx_controlnet(sd) == expected


def test_convert_diffusers_instantx_state_dict_to_bfl_format():
    """Smoke test convert_diffusers_instantx_state_dict_to_bfl_format() to ensure that it handles all of the keys."""
    sd = {k: torch.zeros(1) for k in instantx_sd_shapes}
    bfl_sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
    assert bfl_sd is not None


# TODO(ryand): Figure out why some tests in this file are failing on the MacOS CI runners. It seems to be related to
# using the meta device. I can't reproduce the issue on my local MacOS system.
@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_infer_flux_params_from_state_dict():
    # Construct a dummy state_dict with tensors of the correct shape on the meta device.
    with torch.device("meta"):
        sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}

    sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
    flux_params = infer_flux_params_from_state_dict(sd)

    assert flux_params.in_channels == instantx_config["in_channels"]
    assert flux_params.vec_in_dim == instantx_config["pooled_projection_dim"]
    assert flux_params.context_in_dim == instantx_config["joint_attention_dim"]
    assert flux_params.hidden_size // flux_params.num_heads == instantx_config["attention_head_dim"]
    assert flux_params.num_heads == instantx_config["num_attention_heads"]
    assert flux_params.mlp_ratio == 4
    assert flux_params.depth == instantx_config["num_layers"]
    assert flux_params.depth_single_blocks == instantx_config["num_single_layers"]
    assert flux_params.axes_dim == instantx_config["axes_dims_rope"]
    assert flux_params.theta == 10000
    assert flux_params.qkv_bias
    assert flux_params.guidance_embed == instantx_config["guidance_embeds"]


@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_infer_instantx_num_control_modes_from_state_dict():
    # Construct a dummy state_dict with tensors of the correct shape on the meta device.
    with torch.device("meta"):
        sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}

    sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
    num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)

    assert num_control_modes == instantx_config["num_mode"]


@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_load_instantx_from_state_dict():
    # Construct a dummy state_dict with tensors of the correct shape on the meta device.
    with torch.device("meta"):
        sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}

    sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
    flux_params = infer_flux_params_from_state_dict(sd)
    num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)

    with torch.device("meta"):
        model = InstantXControlNetFlux(flux_params, num_control_modes)

    model_sd = model.state_dict()
    assert set(model_sd.keys()) == set(sd.keys())
    for key, tensor in model_sd.items():
        assert isinstance(tensor, torch.Tensor)
        assert tensor.shape == sd[key].shape

View File

@@ -1,91 +0,0 @@
# State dict keys and shapes for an XLabs FLUX ControlNet model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
xlabs_sd_shapes = {
"controlnet_blocks.0.bias": [3072],
"controlnet_blocks.0.weight": [3072, 3072],
"controlnet_blocks.1.bias": [3072],
"controlnet_blocks.1.weight": [3072, 3072],
"double_blocks.0.img_attn.norm.key_norm.scale": [128],
"double_blocks.0.img_attn.norm.query_norm.scale": [128],
"double_blocks.0.img_attn.proj.bias": [3072],
"double_blocks.0.img_attn.proj.weight": [3072, 3072],
"double_blocks.0.img_attn.qkv.bias": [9216],
"double_blocks.0.img_attn.qkv.weight": [9216, 3072],
"double_blocks.0.img_mlp.0.bias": [12288],
"double_blocks.0.img_mlp.0.weight": [12288, 3072],
"double_blocks.0.img_mlp.2.bias": [3072],
"double_blocks.0.img_mlp.2.weight": [3072, 12288],
"double_blocks.0.img_mod.lin.bias": [18432],
"double_blocks.0.img_mod.lin.weight": [18432, 3072],
"double_blocks.0.txt_attn.norm.key_norm.scale": [128],
"double_blocks.0.txt_attn.norm.query_norm.scale": [128],
"double_blocks.0.txt_attn.proj.bias": [3072],
"double_blocks.0.txt_attn.proj.weight": [3072, 3072],
"double_blocks.0.txt_attn.qkv.bias": [9216],
"double_blocks.0.txt_attn.qkv.weight": [9216, 3072],
"double_blocks.0.txt_mlp.0.bias": [12288],
"double_blocks.0.txt_mlp.0.weight": [12288, 3072],
"double_blocks.0.txt_mlp.2.bias": [3072],
"double_blocks.0.txt_mlp.2.weight": [3072, 12288],
"double_blocks.0.txt_mod.lin.bias": [18432],
"double_blocks.0.txt_mod.lin.weight": [18432, 3072],
"double_blocks.1.img_attn.norm.key_norm.scale": [128],
"double_blocks.1.img_attn.norm.query_norm.scale": [128],
"double_blocks.1.img_attn.proj.bias": [3072],
"double_blocks.1.img_attn.proj.weight": [3072, 3072],
"double_blocks.1.img_attn.qkv.bias": [9216],
"double_blocks.1.img_attn.qkv.weight": [9216, 3072],
"double_blocks.1.img_mlp.0.bias": [12288],
"double_blocks.1.img_mlp.0.weight": [12288, 3072],
"double_blocks.1.img_mlp.2.bias": [3072],
"double_blocks.1.img_mlp.2.weight": [3072, 12288],
"double_blocks.1.img_mod.lin.bias": [18432],
"double_blocks.1.img_mod.lin.weight": [18432, 3072],
"double_blocks.1.txt_attn.norm.key_norm.scale": [128],
"double_blocks.1.txt_attn.norm.query_norm.scale": [128],
"double_blocks.1.txt_attn.proj.bias": [3072],
"double_blocks.1.txt_attn.proj.weight": [3072, 3072],
"double_blocks.1.txt_attn.qkv.bias": [9216],
"double_blocks.1.txt_attn.qkv.weight": [9216, 3072],
"double_blocks.1.txt_mlp.0.bias": [12288],
"double_blocks.1.txt_mlp.0.weight": [12288, 3072],
"double_blocks.1.txt_mlp.2.bias": [3072],
"double_blocks.1.txt_mlp.2.weight": [3072, 12288],
"double_blocks.1.txt_mod.lin.bias": [18432],
"double_blocks.1.txt_mod.lin.weight": [18432, 3072],
"guidance_in.in_layer.bias": [3072],
"guidance_in.in_layer.weight": [3072, 256],
"guidance_in.out_layer.bias": [3072],
"guidance_in.out_layer.weight": [3072, 3072],
"img_in.bias": [3072],
"img_in.weight": [3072, 64],
"input_hint_block.0.bias": [16],
"input_hint_block.0.weight": [16, 3, 3, 3],
"input_hint_block.10.bias": [16],
"input_hint_block.10.weight": [16, 16, 3, 3],
"input_hint_block.12.bias": [16],
"input_hint_block.12.weight": [16, 16, 3, 3],
"input_hint_block.14.bias": [16],
"input_hint_block.14.weight": [16, 16, 3, 3],
"input_hint_block.2.bias": [16],
"input_hint_block.2.weight": [16, 16, 3, 3],
"input_hint_block.4.bias": [16],
"input_hint_block.4.weight": [16, 16, 3, 3],
"input_hint_block.6.bias": [16],
"input_hint_block.6.weight": [16, 16, 3, 3],
"input_hint_block.8.bias": [16],
"input_hint_block.8.weight": [16, 16, 3, 3],
"pos_embed_input.bias": [3072],
"pos_embed_input.weight": [3072, 64],
"time_in.in_layer.bias": [3072],
"time_in.in_layer.weight": [3072, 256],
"time_in.out_layer.bias": [3072],
"time_in.out_layer.weight": [3072, 3072],
"txt_in.bias": [3072],
"txt_in.weight": [3072, 4096],
"vector_in.in_layer.bias": [3072],
"vector_in.in_layer.weight": [3072, 768],
"vector_in.out_layer.bias": [3072],
"vector_in.out_layer.weight": [3072, 3072],
}
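A minimal usage sketch (not part of the original fixture): assuming the state_dict_utils helpers imported by the test file above behave as exercised there (key-based checks), the shape map can be materialised as dummy meta-device tensors and passed to the model-detection helpers:

import torch

from invokeai.backend.flux.controlnet.state_dict_utils import (
    is_state_dict_instantx_controlnet,
    is_state_dict_xlabs_controlnet,
)
from tests.backend.flux.controlnet.xlabs_flux_controlnet_state_dict import xlabs_sd_shapes

# Build shape-correct dummy tensors without allocating real memory.
with torch.device("meta"):
    sd = {k: torch.zeros(v) for k, v in xlabs_sd_shapes.items()}

assert is_state_dict_xlabs_controlnet(sd)
assert not is_state_dict_instantx_controlnet(sd)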