Mirror of https://github.com/invoke-ai/InvokeAI.git (synced 2026-01-21 03:28:25 -05:00)

Compare commits: psyche/fea...psyche/fix (168 commits)
@@ -5,9 +5,10 @@ from fastapi.routing import APIRouter
from pydantic import BaseModel, Field

from invokeai.app.api.dependencies import ApiDependencies
-from invokeai.app.services.board_records.board_records_common import BoardChanges
+from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection

boards_router = APIRouter(prefix="/v1/boards", tags=["boards"])

@@ -115,6 +116,8 @@ async def delete_board(
    response_model=Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]],
)
async def list_boards(
+    order_by: BoardRecordOrderBy = Query(default=BoardRecordOrderBy.CreatedAt, description="The attribute to order by"),
+    direction: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The direction to order by"),
    all: Optional[bool] = Query(default=None, description="Whether to list all boards"),
    offset: Optional[int] = Query(default=None, description="The page offset"),
    limit: Optional[int] = Query(default=None, description="The number of boards per page"),
@@ -122,9 +125,9 @@ async def list_boards(
) -> Union[OffsetPaginatedResults[BoardDTO], list[BoardDTO]]:
    """Gets a list of boards"""
    if all:
-        return ApiDependencies.invoker.services.boards.get_all(include_archived)
+        return ApiDependencies.invoker.services.boards.get_all(order_by, direction, include_archived)
    elif offset is not None and limit is not None:
-        return ApiDependencies.invoker.services.boards.get_many(offset, limit, include_archived)
+        return ApiDependencies.invoker.services.boards.get_many(order_by, direction, offset, limit, include_archived)
    else:
        raise HTTPException(
            status_code=400,
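For orientation, a call against the updated list_boards route might look like the sketch below. The query-parameter names and the "created_at"/"board_name" values come from the code above; the host, port, "/api" mount prefix, and the "ASC"/"DESC" direction strings are assumptions about the deployment, not part of this diff.

```python
import requests

# Hypothetical client call: list all boards ordered by name, ascending.
resp = requests.get(
    "http://localhost:9090/api/v1/boards/",  # host/port and "/api" prefix are assumptions
    params={"order_by": "board_name", "direction": "ASC", "all": True},
)
resp.raise_for_status()
boards = resp.json()  # with all=True the route returns a plain list of BoardDTOs
print(len(boards))
```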
@@ -88,7 +88,7 @@ async def list_workflows(
        default=WorkflowRecordOrderBy.Name, description="The attribute to order by"
    ),
    direction: SQLiteDirection = Query(default=SQLiteDirection.Ascending, description="The direction to order by"),
-    category: Optional[WorkflowCategory] = Query(default=None, description="The category of workflow to get"),
+    category: WorkflowCategory = Query(default=WorkflowCategory.User, description="The category of workflow to get"),
    query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
) -> PaginatedResults[WorkflowRecordListItemDTO]:
    """Gets a page of workflows"""
@@ -192,6 +192,7 @@ class FieldDescriptions:
    freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
    freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
    freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
+    instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."


class ImageField(BaseModel):
invokeai/app/invocations/flux_controlnet.py (new file, 99 lines)
@@ -0,0 +1,99 @@
from pydantic import BaseModel, Field, field_validator, model_validator

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES


class FluxControlNetField(BaseModel):
    image: ImageField = Field(description="The control image")
    control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
    control_weight: float | list[float] = Field(default=1, description="The weight given to the ControlNet")
    begin_step_percent: float = Field(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = Field(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
    instantx_control_mode: int | None = Field(default=-1, description=FieldDescriptions.instantx_control_mode)

    @field_validator("control_weight")
    @classmethod
    def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
        validate_weights(v)
        return v

    @model_validator(mode="after")
    def validate_begin_end_step_percent(self):
        validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
        return self


@invocation_output("flux_controlnet_output")
class FluxControlNetOutput(BaseInvocationOutput):
    """FLUX ControlNet info"""

    control: FluxControlNetField = OutputField(description=FieldDescriptions.control)


@invocation(
    "flux_controlnet",
    title="FLUX ControlNet",
    tags=["controlnet", "flux"],
    category="controlnet",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxControlNetInvocation(BaseInvocation):
    """Collect FLUX ControlNet info to pass to other nodes."""

    image: ImageField = InputField(description="The control image")
    control_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
    )
    control_weight: float | list[float] = InputField(
        default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
    )
    begin_step_percent: float = InputField(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = InputField(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
    # Note: We default to -1 instead of None, because in the workflow editor UI None is not currently supported.
    instantx_control_mode: int | None = InputField(default=-1, description=FieldDescriptions.instantx_control_mode)

    @field_validator("control_weight")
    @classmethod
    def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
        validate_weights(v)
        return v

    @model_validator(mode="after")
    def validate_begin_end_step_percent(self):
        validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
        return self

    def invoke(self, context: InvocationContext) -> FluxControlNetOutput:
        return FluxControlNetOutput(
            control=FluxControlNetField(
                image=self.image,
                control_model=self.control_model,
                control_weight=self.control_weight,
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
                resize_mode=self.resize_mode,
                instantx_control_mode=self.instantx_control_mode,
            ),
        )
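The new invocation leans on a common pydantic v2 validation pattern: a field_validator for the per-field weight check plus an "after"-mode model_validator for the cross-field begin/end constraint. Below is a minimal standalone sketch of that pattern, with stand-ins for InvokeAI's validate_weights/validate_begin_end_step helpers (the exact bounds used are assumptions for illustration).

```python
from pydantic import BaseModel, Field, field_validator, model_validator


class StepRange(BaseModel):
    # Mirrors the weight/begin/end fields used by FluxControlNetField above.
    weight: float | list[float] = 1.0
    begin_step_percent: float = Field(default=0, ge=0, le=1)
    end_step_percent: float = Field(default=1, ge=0, le=1)

    @field_validator("weight")
    @classmethod
    def _check_weight(cls, v: float | list[float]) -> float | list[float]:
        # Stand-in for validate_weights(): reject out-of-range values.
        values = v if isinstance(v, list) else [v]
        if any(w < -1 or w > 2 for w in values):
            raise ValueError("weights must be in [-1, 2]")
        return v

    @model_validator(mode="after")
    def _check_step_order(self):
        # Stand-in for validate_begin_end_step(): begin must not come after end.
        if self.begin_step_percent > self.end_step_percent:
            raise ValueError("begin_step_percent must be <= end_step_percent")
        return self


StepRange(weight=[0.5, 1.0], begin_step_percent=0.2, end_step_percent=0.8)  # validates
```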
@@ -16,11 +16,16 @@ from invokeai.app.invocations.fields import (
    WithBoard,
    WithMetadata,
)
-from invokeai.app.invocations.model import TransformerField
+from invokeai.app.invocations.flux_controlnet import FluxControlNetField
+from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
+from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
-from invokeai.backend.flux.inpaint_extension import InpaintExtension
+from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
+from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.sampling_utils import (
    clip_timestep_schedule_fractional,

@@ -44,7 +49,7 @@ from invokeai.backend.util.devices import TorchDevice
    title="FLUX Denoise",
    tags=["image", "flux"],
    category="image",
-    version="3.0.0",
+    version="3.1.0",
    classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

@@ -87,6 +92,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
        description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
    )
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
+    control: FluxControlNetField | list[FluxControlNetField] | None = InputField(
+        default=None, input=Input.Connection, description="ControlNet models."
+    )
+    controlnet_vae: VAEField | None = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:

@@ -167,8 +179,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

        inpaint_mask = self._prep_inpaint_mask(context, x)

-        b, _c, h, w = x.shape
-        img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
+        b, _c, latent_h, latent_w = x.shape
+        img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)

        bs, t5_seq_len, _ = t5_embeddings.shape
        txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())

@@ -192,12 +204,21 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                noise=noise,
            )

-        with (
-            transformer_info.model_on_device() as (cached_weights, transformer),
-            ExitStack() as exit_stack,
-        ):
-            assert isinstance(transformer, Flux)
+        with ExitStack() as exit_stack:
+            # Prepare ControlNet extensions.
+            # Note: We do this before loading the transformer model to minimize peak memory (see implementation).
+            controlnet_extensions = self._prep_controlnet_extensions(
+                context=context,
+                exit_stack=exit_stack,
+                latent_height=latent_h,
+                latent_width=latent_w,
+                dtype=inference_dtype,
+                device=x.device,
+            )
+
+            # Load the transformer model.
+            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
+            assert isinstance(transformer, Flux)
            config = transformer_info.config
            assert config is not None

@@ -242,6 +263,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                step_callback=self._build_step_callback(context),
                guidance=self.guidance,
                inpaint_extension=inpaint_extension,
+                controlnet_extensions=controlnet_extensions,
            )

        x = unpack(x.float(), self.height, self.width)

@@ -288,6 +310,104 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
        # `latents`.
        return mask.expand_as(latents)

+    def _prep_controlnet_extensions(
+        self,
+        context: InvocationContext,
+        exit_stack: ExitStack,
+        latent_height: int,
+        latent_width: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> list[XLabsControlNetExtension | InstantXControlNetExtension]:
+        # Normalize the controlnet input to list[ControlField].
+        controlnets: list[FluxControlNetField]
+        if self.control is None:
+            controlnets = []
+        elif isinstance(self.control, FluxControlNetField):
+            controlnets = [self.control]
+        elif isinstance(self.control, list):
+            controlnets = self.control
+        else:
+            raise ValueError(f"Unsupported controlnet type: {type(self.control)}")
+
+        # TODO(ryand): Add a field to the model config so that we can distinguish between XLabs and InstantX ControlNets
+        # before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
+        # minimize peak memory.
+
+        # First, load the ControlNet models so that we can determine the ControlNet types.
+        controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
+
+        # Calculate the controlnet conditioning tensors.
+        # We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
+        # keep peak memory down.
+        controlnet_conds: list[torch.Tensor] = []
+        for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
+            image = context.images.get_pil(controlnet.image.image_name)
+            if isinstance(controlnet_model.model, InstantXControlNetFlux):
+                if self.controlnet_vae is None:
+                    raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
+                vae_info = context.models.load(self.controlnet_vae.vae)
+                controlnet_conds.append(
+                    InstantXControlNetExtension.prepare_controlnet_cond(
+                        controlnet_image=image,
+                        vae_info=vae_info,
+                        latent_height=latent_height,
+                        latent_width=latent_width,
+                        dtype=dtype,
+                        device=device,
+                        resize_mode=controlnet.resize_mode,
+                    )
+                )
+            elif isinstance(controlnet_model.model, XLabsControlNetFlux):
+                controlnet_conds.append(
+                    XLabsControlNetExtension.prepare_controlnet_cond(
+                        controlnet_image=image,
+                        latent_height=latent_height,
+                        latent_width=latent_width,
+                        dtype=dtype,
+                        device=device,
+                        resize_mode=controlnet.resize_mode,
+                    )
+                )
+
+        # Finally, load the ControlNet models and initialize the ControlNet extensions.
+        controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
+        for controlnet, controlnet_cond, controlnet_model in zip(
+            controlnets, controlnet_conds, controlnet_models, strict=True
+        ):
+            model = exit_stack.enter_context(controlnet_model)
+
+            if isinstance(model, XLabsControlNetFlux):
+                controlnet_extensions.append(
+                    XLabsControlNetExtension(
+                        model=model,
+                        controlnet_cond=controlnet_cond,
+                        weight=controlnet.control_weight,
+                        begin_step_percent=controlnet.begin_step_percent,
+                        end_step_percent=controlnet.end_step_percent,
+                    )
+                )
+            elif isinstance(model, InstantXControlNetFlux):
+                instantx_control_mode: torch.Tensor | None = None
+                if controlnet.instantx_control_mode is not None and controlnet.instantx_control_mode >= 0:
+                    instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long)
+                    instantx_control_mode = instantx_control_mode.reshape([-1, 1])
+
+                controlnet_extensions.append(
+                    InstantXControlNetExtension(
+                        model=model,
+                        controlnet_cond=controlnet_cond,
+                        instantx_control_mode=instantx_control_mode,
+                        weight=controlnet.control_weight,
+                        begin_step_percent=controlnet.begin_step_percent,
+                        end_step_percent=controlnet.end_step_percent,
+                    )
+                )
+            else:
+                raise ValueError(f"Unsupported ControlNet model type: {type(model)}")
+
+        return controlnet_extensions
+
    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
        for lora in self.transformer.loras:
            lora_info = context.models.load(lora.lora)
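The new _prep_controlnet_extensions method begins by normalizing an input that may be a single field, a list of fields, or None. A minimal standalone sketch of that normalization idiom, detached from the InvokeAI types:

```python
from __future__ import annotations

from typing import TypeVar

T = TypeVar("T")


def normalize_to_list(value: T | list[T] | None) -> list[T]:
    # None -> empty list, single item -> one-element list, list -> unchanged.
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]


assert normalize_to_list(None) == []
assert normalize_to_list(3) == [3]
assert normalize_to_list([1, 2]) == [1, 2]
```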
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod

-from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecord
+from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecord, BoardRecordOrderBy
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection


class BoardRecordStorageBase(ABC):

@@ -39,12 +40,19 @@ class BoardRecordStorageBase(ABC):

    @abstractmethod
    def get_many(
-        self, offset: int = 0, limit: int = 10, include_archived: bool = False
+        self,
+        order_by: BoardRecordOrderBy,
+        direction: SQLiteDirection,
+        offset: int = 0,
+        limit: int = 10,
+        include_archived: bool = False,
    ) -> OffsetPaginatedResults[BoardRecord]:
        """Gets many board records."""
        pass

    @abstractmethod
-    def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
+    def get_all(
+        self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
+    ) -> list[BoardRecord]:
        """Gets all board records."""
        pass
@@ -1,8 +1,10 @@
from datetime import datetime
+from enum import Enum
from typing import Optional, Union

from pydantic import BaseModel, Field

+from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull

@@ -60,6 +62,13 @@ class BoardChanges(BaseModel, extra="forbid"):
    archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")


+class BoardRecordOrderBy(str, Enum, metaclass=MetaEnum):
+    """The order by options for board records"""
+
+    CreatedAt = "created_at"
+    Name = "board_name"
+
+
class BoardRecordNotFoundException(Exception):
    """Raised when an board record is not found."""
@@ -8,10 +8,12 @@ from invokeai.app.services.board_records.board_records_common import (
    BoardRecord,
    BoardRecordDeleteException,
    BoardRecordNotFoundException,
+    BoardRecordOrderBy,
    BoardRecordSaveException,
    deserialize_board_record,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.util.misc import uuid_string

@@ -144,7 +146,12 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
        return self.get(board_id)

    def get_many(
-        self, offset: int = 0, limit: int = 10, include_archived: bool = False
+        self,
+        order_by: BoardRecordOrderBy,
+        direction: SQLiteDirection,
+        offset: int = 0,
+        limit: int = 10,
+        include_archived: bool = False,
    ) -> OffsetPaginatedResults[BoardRecord]:
        try:
            self._lock.acquire()

@@ -154,17 +161,16 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
                SELECT *
                FROM boards
                {archived_filter}
-                ORDER BY created_at DESC
+                ORDER BY {order_by} {direction}
                LIMIT ? OFFSET ?;
                """

            # Determine archived filter condition
-            if include_archived:
-                archived_filter = ""
-            else:
-                archived_filter = "WHERE archived = 0"
+            archived_filter = "" if include_archived else "WHERE archived = 0"

-            final_query = base_query.format(archived_filter=archived_filter)
+            final_query = base_query.format(
+                archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
+            )

            # Execute query to fetch boards
            self._cursor.execute(final_query, (limit, offset))

@@ -198,23 +204,32 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
        finally:
            self._lock.release()

-    def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
+    def get_all(
+        self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
+    ) -> list[BoardRecord]:
        try:
            self._lock.acquire()

-            base_query = """
-                SELECT *
-                FROM boards
-                {archived_filter}
-                ORDER BY created_at DESC
-                """
-
-            if include_archived:
-                archived_filter = ""
-            else:
-                archived_filter = "WHERE archived = 0"
-
-            final_query = base_query.format(archived_filter=archived_filter)
+            if order_by == BoardRecordOrderBy.Name:
+                base_query = """
+                    SELECT *
+                    FROM boards
+                    {archived_filter}
+                    ORDER BY LOWER(board_name) {direction}
+                    """
+            else:
+                base_query = """
+                    SELECT *
+                    FROM boards
+                    {archived_filter}
+                    ORDER BY {order_by} {direction}
+                    """
+
+            archived_filter = "" if include_archived else "WHERE archived = 0"
+
+            final_query = base_query.format(
+                archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
+            )

            self._cursor.execute(final_query)
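The queries above interpolate order_by.value and direction.value directly into the SQL text, because identifiers and keywords cannot be bound as "?" parameters; this stays safe only because both values come from closed enums, while the user-supplied limit and offset remain bound parameters. A standalone sketch of that pattern (table and column names are illustrative, not InvokeAI's schema):

```python
import sqlite3
from enum import Enum


class OrderBy(str, Enum):
    CreatedAt = "created_at"
    Name = "board_name"


class Direction(str, Enum):
    Asc = "ASC"
    Desc = "DESC"


def list_boards(conn: sqlite3.Connection, order_by: OrderBy, direction: Direction, limit: int, offset: int):
    # Enum values are trusted identifiers/keywords, so string formatting is safe here;
    # limit/offset are still bound as parameters.
    query = f"SELECT * FROM boards ORDER BY {order_by.value} {direction.value} LIMIT ? OFFSET ?;"
    return conn.execute(query, (limit, offset)).fetchall()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE boards (board_id TEXT, board_name TEXT, created_at TEXT);")
conn.execute("INSERT INTO boards VALUES ('1', 'b', '2024-01-01'), ('2', 'a', '2024-02-01');")
print(list_boards(conn, OrderBy.Name, Direction.Asc, limit=10, offset=0))
```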
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod

-from invokeai.app.services.board_records.board_records_common import BoardChanges
+from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection


class BoardServiceABC(ABC):

@@ -43,12 +44,19 @@ class BoardServiceABC(ABC):

    @abstractmethod
    def get_many(
-        self, offset: int = 0, limit: int = 10, include_archived: bool = False
+        self,
+        order_by: BoardRecordOrderBy,
+        direction: SQLiteDirection,
+        offset: int = 0,
+        limit: int = 10,
+        include_archived: bool = False,
    ) -> OffsetPaginatedResults[BoardDTO]:
        """Gets many boards."""
        pass

    @abstractmethod
-    def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
+    def get_all(
+        self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
+    ) -> list[BoardDTO]:
        """Gets all boards."""
        pass
@@ -1,8 +1,9 @@
-from invokeai.app.services.board_records.board_records_common import BoardChanges
+from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.boards.boards_base import BoardServiceABC
from invokeai.app.services.boards.boards_common import BoardDTO, board_record_to_dto
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
+from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection


class BoardService(BoardServiceABC):

@@ -47,9 +48,16 @@ class BoardService(BoardServiceABC):
        self.__invoker.services.board_records.delete(board_id)

    def get_many(
-        self, offset: int = 0, limit: int = 10, include_archived: bool = False
+        self,
+        order_by: BoardRecordOrderBy,
+        direction: SQLiteDirection,
+        offset: int = 0,
+        limit: int = 10,
+        include_archived: bool = False,
    ) -> OffsetPaginatedResults[BoardDTO]:
-        board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
+        board_records = self.__invoker.services.board_records.get_many(
+            order_by, direction, offset, limit, include_archived
+        )
        board_dtos = []
        for r in board_records.items:
            cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)

@@ -63,8 +71,10 @@ class BoardService(BoardServiceABC):

        return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))

-    def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
-        board_records = self.__invoker.services.board_records.get_all(include_archived)
+    def get_all(
+        self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
+    ) -> list[BoardDTO]:
+        board_records = self.__invoker.services.board_records.get_all(order_by, direction, include_archived)
        board_dtos = []
        for r in board_records:
            cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
@@ -184,7 +184,8 @@ class ModelInstallService(ModelInstallServiceBase):
        ) # type: ignore

        if preferred_name := config.name:
-            preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
+            if model_path.suffix:
+                preferred_name = f"{preferred_name}.{model_path.suffix}"

        dest_path = (
            self.app_config.models_path / info.base.value / info.type.value / (preferred_name or model_path.name)
@@ -41,9 +41,9 @@ class WorkflowRecordsStorageBase(ABC):
        self,
        order_by: WorkflowRecordOrderBy,
        direction: SQLiteDirection,
+        category: WorkflowCategory,
        page: int,
        per_page: Optional[int],
-        category: Optional[WorkflowCategory],
        query: Optional[str],
    ) -> PaginatedResults[WorkflowRecordListItemDTO]:
        """Gets many workflows."""

@@ -127,9 +127,9 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
        self,
        order_by: WorkflowRecordOrderBy,
        direction: SQLiteDirection,
+        category: WorkflowCategory,
        page: int = 0,
        per_page: Optional[int] = None,
-        category: Optional[WorkflowCategory] = None,
        query: Optional[str] = None,
    ) -> PaginatedResults[WorkflowRecordListItemDTO]:
        try:

@@ -137,6 +137,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
            # sanitize!
            assert order_by in WorkflowRecordOrderBy
            assert direction in SQLiteDirection
+            assert category in WorkflowCategory
            count_query = "SELECT COUNT(*) FROM workflow_library"
            main_query = """
                SELECT

@@ -148,26 +149,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
                    updated_at,
                    opened_at
                FROM workflow_library
+                WHERE category = ?
                """
-            main_params: list[int | str] = []
-            count_params: list[int | str] = []
-
-            if category:
-                assert category in WorkflowCategory
-                main_query += " WHERE category = ?"
-                count_query += " WHERE category = ?"
-                main_params.append(category.value)
-                count_params.append(category.value)
+            main_params: list[int | str] = [category.value]
+            count_params: list[int | str] = [category.value]

            stripped_query = query.strip() if query else None
            if stripped_query:
                wildcard_query = "%" + stripped_query + "%"
-                if "WHERE" in main_query:
-                    main_query += " AND (name LIKE ? OR description LIKE ?)"
-                    count_query += " AND (name LIKE ? OR description LIKE ?)"
-                else:
-                    main_query += " WHERE name LIKE ? OR description LIKE ?"
-                    count_query += " WHERE name LIKE ? OR description LIKE ?"
+                main_query += " AND name LIKE ? OR description LIKE ? "
+                count_query += " AND name LIKE ? OR description LIKE ?;"
                main_params.extend([wildcard_query, wildcard_query])
                count_params.extend([wildcard_query, wildcard_query])
invokeai/backend/flux/controlnet/__init__.py (new file, empty)

invokeai/backend/flux/controlnet/controlnet_flux_output.py (new file, 58 lines)
@@ -0,0 +1,58 @@
from dataclasses import dataclass

import torch


@dataclass
class ControlNetFluxOutput:
    single_block_residuals: list[torch.Tensor] | None
    double_block_residuals: list[torch.Tensor] | None

    def apply_weight(self, weight: float):
        if self.single_block_residuals is not None:
            for i in range(len(self.single_block_residuals)):
                self.single_block_residuals[i] = self.single_block_residuals[i] * weight
        if self.double_block_residuals is not None:
            for i in range(len(self.double_block_residuals)):
                self.double_block_residuals[i] = self.double_block_residuals[i] * weight


def add_tensor_lists_elementwise(
    list1: list[torch.Tensor] | None, list2: list[torch.Tensor] | None
) -> list[torch.Tensor] | None:
    """Add two tensor lists elementwise that could be None."""
    if list1 is None and list2 is None:
        return None
    if list1 is None:
        return list2
    if list2 is None:
        return list1

    new_list: list[torch.Tensor] = []
    for list1_tensor, list2_tensor in zip(list1, list2, strict=True):
        new_list.append(list1_tensor + list2_tensor)
    return new_list


def add_controlnet_flux_outputs(
    controlnet_output_1: ControlNetFluxOutput, controlnet_output_2: ControlNetFluxOutput
) -> ControlNetFluxOutput:
    return ControlNetFluxOutput(
        single_block_residuals=add_tensor_lists_elementwise(
            controlnet_output_1.single_block_residuals, controlnet_output_2.single_block_residuals
        ),
        double_block_residuals=add_tensor_lists_elementwise(
            controlnet_output_1.double_block_residuals, controlnet_output_2.double_block_residuals
        ),
    )


def sum_controlnet_flux_outputs(
    controlnet_outputs: list[ControlNetFluxOutput],
) -> ControlNetFluxOutput:
    controlnet_output_sum = ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)

    for controlnet_output in controlnet_outputs:
        controlnet_output_sum = add_controlnet_flux_outputs(controlnet_output_sum, controlnet_output)

    return controlnet_output_sum
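A short usage sketch of the helpers above; it assumes the invokeai package at this revision is importable, and the tensor shapes are illustrative only:

```python
import torch

from invokeai.backend.flux.controlnet.controlnet_flux_output import (
    ControlNetFluxOutput,
    sum_controlnet_flux_outputs,
)

# Two per-ControlNet outputs with matching residual shapes.
out_a = ControlNetFluxOutput(
    single_block_residuals=None,
    double_block_residuals=[torch.ones(1, 4, 8) for _ in range(2)],
)
out_b = ControlNetFluxOutput(
    single_block_residuals=None,
    double_block_residuals=[torch.ones(1, 4, 8) for _ in range(2)],
)

out_a.apply_weight(0.5)  # scale one ControlNet's contribution in place
combined = sum_controlnet_flux_outputs([out_a, out_b])
assert combined.double_block_residuals is not None
assert torch.allclose(combined.double_block_residuals[0], torch.full((1, 4, 8), 1.5))
```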
invokeai/backend/flux/controlnet/instantx_controlnet_flux.py (new file, 180 lines)
@@ -0,0 +1,180 @@
# This file was initially copied from:
# https://github.com/huggingface/diffusers/blob/99f608218caa069a2f16dcf9efab46959b15aec0/src/diffusers/models/controlnet_flux.py


from dataclasses import dataclass

import torch
import torch.nn as nn

from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import (
    DoubleStreamBlock,
    EmbedND,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)


@dataclass
class InstantXControlNetFluxOutput:
    controlnet_block_samples: list[torch.Tensor] | None
    controlnet_single_block_samples: list[torch.Tensor] | None


# NOTE(ryand): Mapping between diffusers FLUX transformer params and BFL FLUX transformer params:
# - Diffusers: BFL
# - in_channels: in_channels
# - num_layers: depth
# - num_single_layers: depth_single_blocks
# - attention_head_dim: hidden_size // num_heads
# - num_attention_heads: num_heads
# - joint_attention_dim: context_in_dim
# - pooled_projection_dim: vec_in_dim
# - guidance_embeds: guidance_embed
# - axes_dims_rope: axes_dim


class InstantXControlNetFlux(torch.nn.Module):
    def __init__(self, params: FluxParams, num_control_modes: int | None = None):
        """
        Args:
            params (FluxParams): The parameters for the FLUX model.
            num_control_modes (int | None, optional): The number of controlnet modes. If non-None, then the model is a
                'union controlnet' model and expects a mode conditioning input at runtime.
        """
        super().__init__()

        # The following modules mirror the base FLUX transformer model.
        # -------------------------------------------------------------
        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        # The following modules are specific to the ControlNet model.
        # -----------------------------------------------------------
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.double_blocks)):
            self.controlnet_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(len(self.single_blocks)):
            self.controlnet_single_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))

        self.is_union = False
        if num_control_modes is not None:
            self.is_union = True
            self.controlnet_mode_embedder = nn.Embedding(num_control_modes, self.hidden_size)

        self.controlnet_x_embedder = zero_module(torch.nn.Linear(self.in_channels, self.hidden_size))

    def forward(
        self,
        controlnet_cond: torch.Tensor,
        controlnet_mode: torch.Tensor | None,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        txt: torch.Tensor,
        txt_ids: torch.Tensor,
        timesteps: torch.Tensor,
        y: torch.Tensor,
        guidance: torch.Tensor | None = None,
    ) -> InstantXControlNetFluxOutput:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        img = self.img_in(img)

        # Add controlnet_cond embedding.
        img = img + self.controlnet_x_embedder(controlnet_cond)

        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # If this is a union ControlNet, then concat the control mode embedding to the T5 text embedding.
        if self.is_union:
            if controlnet_mode is None:
                # We allow users to enter 'None' as the controlnet_mode if they don't want to worry about this input.
                # We've chosen to use a zero-embedding in this case.
                zero_index = torch.zeros([1, 1], dtype=torch.long, device=txt.device)
                controlnet_mode_emb = torch.zeros_like(self.controlnet_mode_embedder(zero_index))
            else:
                controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
            txt = torch.cat([controlnet_mode_emb, txt], dim=1)
            txt_ids = torch.cat([txt_ids[:, :1, :], txt_ids], dim=1)
        else:
            assert controlnet_mode is None

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        double_block_samples: list[torch.Tensor] = []
        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
            double_block_samples.append(img)

        img = torch.cat((txt, img), 1)

        single_block_samples: list[torch.Tensor] = []
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
            single_block_samples.append(img[:, txt.shape[1] :])

        # ControlNet Block
        controlnet_double_block_samples: list[torch.Tensor] = []
        for double_block_sample, controlnet_block in zip(double_block_samples, self.controlnet_blocks, strict=True):
            double_block_sample = controlnet_block(double_block_sample)
            controlnet_double_block_samples.append(double_block_sample)

        controlnet_single_block_samples: list[torch.Tensor] = []
        for single_block_sample, controlnet_block in zip(
            single_block_samples, self.controlnet_single_blocks, strict=True
        ):
            single_block_sample = controlnet_block(single_block_sample)
            controlnet_single_block_samples.append(single_block_sample)

        return InstantXControlNetFluxOutput(
            controlnet_block_samples=controlnet_double_block_samples or None,
            controlnet_single_block_samples=controlnet_single_block_samples or None,
        )
invokeai/backend/flux/controlnet/state_dict_utils.py (new file, 295 lines)
@@ -0,0 +1,295 @@
from typing import Any, Dict

import torch

from invokeai.backend.flux.model import FluxParams


def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool:
    """Is the state dict for an XLabs ControlNet model?

    This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
    """
    # If all of the expected keys are present, then this is very likely an XLabs ControlNet model.
    expected_keys = {
        "controlnet_blocks.0.bias",
        "controlnet_blocks.0.weight",
        "input_hint_block.0.bias",
        "input_hint_block.0.weight",
        "pos_embed_input.bias",
        "pos_embed_input.weight",
    }

    if expected_keys.issubset(sd.keys()):
        return True
    return False


def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
    """Is the state dict for an InstantX ControlNet model?

    This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
    """
    # If all of the expected keys are present, then this is very likely an InstantX ControlNet model.
    expected_keys = {
        "controlnet_blocks.0.bias",
        "controlnet_blocks.0.weight",
        "controlnet_x_embedder.bias",
        "controlnet_x_embedder.weight",
    }

    if expected_keys.issubset(sd.keys()):
        return True
    return False


def _fuse_weights(*t: torch.Tensor) -> torch.Tensor:
    """Fuse weights along dimension 0.

    Used to fuse q, k, v attention weights into a single qkv tensor when converting from diffusers to BFL format.
    """
    # TODO(ryand): Double check dim=0 is correct.
    return torch.cat(t, dim=0)


def _convert_flux_double_block_sd_from_diffusers_to_bfl_format(
    sd: Dict[str, torch.Tensor], double_block_index: int
) -> Dict[str, torch.Tensor]:
    """Convert the state dict for a double block from diffusers format to BFL format."""
    to_prefix = f"double_blocks.{double_block_index}"
    from_prefix = f"transformer_blocks.{double_block_index}"

    new_sd: dict[str, torch.Tensor] = {}

    # Check one key to determine if this block exists.
    if f"{from_prefix}.attn.add_q_proj.bias" not in sd:
        return new_sd

    # txt_attn.qkv
    new_sd[f"{to_prefix}.txt_attn.qkv.bias"] = _fuse_weights(
        sd.pop(f"{from_prefix}.attn.add_q_proj.bias"),
        sd.pop(f"{from_prefix}.attn.add_k_proj.bias"),
        sd.pop(f"{from_prefix}.attn.add_v_proj.bias"),
    )
    new_sd[f"{to_prefix}.txt_attn.qkv.weight"] = _fuse_weights(
        sd.pop(f"{from_prefix}.attn.add_q_proj.weight"),
        sd.pop(f"{from_prefix}.attn.add_k_proj.weight"),
        sd.pop(f"{from_prefix}.attn.add_v_proj.weight"),
    )

    # img_attn.qkv
    new_sd[f"{to_prefix}.img_attn.qkv.bias"] = _fuse_weights(
        sd.pop(f"{from_prefix}.attn.to_q.bias"),
        sd.pop(f"{from_prefix}.attn.to_k.bias"),
        sd.pop(f"{from_prefix}.attn.to_v.bias"),
    )
    new_sd[f"{to_prefix}.img_attn.qkv.weight"] = _fuse_weights(
        sd.pop(f"{from_prefix}.attn.to_q.weight"),
        sd.pop(f"{from_prefix}.attn.to_k.weight"),
        sd.pop(f"{from_prefix}.attn.to_v.weight"),
    )

    # Handle basic 1-to-1 key conversions.
    key_map = {
        # img_attn
        "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
        "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
        "attn.to_out.0.weight": "img_attn.proj.weight",
        "attn.to_out.0.bias": "img_attn.proj.bias",
        # img_mlp
        "ff.net.0.proj.weight": "img_mlp.0.weight",
        "ff.net.0.proj.bias": "img_mlp.0.bias",
        "ff.net.2.weight": "img_mlp.2.weight",
        "ff.net.2.bias": "img_mlp.2.bias",
        # img_mod
        "norm1.linear.weight": "img_mod.lin.weight",
        "norm1.linear.bias": "img_mod.lin.bias",
        # txt_attn
        "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
        "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
        "attn.to_add_out.weight": "txt_attn.proj.weight",
        "attn.to_add_out.bias": "txt_attn.proj.bias",
        # txt_mlp
        "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
        "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
        "ff_context.net.2.weight": "txt_mlp.2.weight",
        "ff_context.net.2.bias": "txt_mlp.2.bias",
        # txt_mod
        "norm1_context.linear.weight": "txt_mod.lin.weight",
        "norm1_context.linear.bias": "txt_mod.lin.bias",
    }
    for from_key, to_key in key_map.items():
        new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")

    return new_sd


def _convert_flux_single_block_sd_from_diffusers_to_bfl_format(
    sd: Dict[str, torch.Tensor], single_block_index: int
) -> Dict[str, torch.Tensor]:
    """Convert the state dict for a single block from diffusers format to BFL format."""
    to_prefix = f"single_blocks.{single_block_index}"
    from_prefix = f"single_transformer_blocks.{single_block_index}"

    new_sd: dict[str, torch.Tensor] = {}

    # Check one key to determine if this block exists.
    if f"{from_prefix}.attn.to_q.bias" not in sd:
        return new_sd

    # linear1 (qkv)
    new_sd[f"{to_prefix}.linear1.bias"] = _fuse_weights(
        sd.pop(f"{from_prefix}.attn.to_q.bias"),
        sd.pop(f"{from_prefix}.attn.to_k.bias"),
        sd.pop(f"{from_prefix}.attn.to_v.bias"),
        sd.pop(f"{from_prefix}.proj_mlp.bias"),
    )
    new_sd[f"{to_prefix}.linear1.weight"] = _fuse_weights(
        sd.pop(f"{from_prefix}.attn.to_q.weight"),
        sd.pop(f"{from_prefix}.attn.to_k.weight"),
        sd.pop(f"{from_prefix}.attn.to_v.weight"),
        sd.pop(f"{from_prefix}.proj_mlp.weight"),
    )

    # Handle basic 1-to-1 key conversions.
    key_map = {
        # linear2
        "proj_out.weight": "linear2.weight",
        "proj_out.bias": "linear2.bias",
        # modulation
        "norm.linear.weight": "modulation.lin.weight",
        "norm.linear.bias": "modulation.lin.bias",
        # norm
        "attn.norm_k.weight": "norm.key_norm.scale",
        "attn.norm_q.weight": "norm.query_norm.scale",
    }
    for from_key, to_key in key_map.items():
        new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")

    return new_sd


def convert_diffusers_instantx_state_dict_to_bfl_format(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Convert an InstantX ControlNet state dict to the format that can be loaded by our internal
    InstantXControlNetFlux model.

    The original InstantX ControlNet model was developed to be used in diffusers. We have ported the original
    implementation to InstantXControlNetFlux to make it compatible with BFL-style models. This function converts the
    original state dict to the format expected by InstantXControlNetFlux.
    """
    # Shallow copy sd so that we can pop keys from it without modifying the original.
    sd = sd.copy()

    new_sd: dict[str, torch.Tensor] = {}

    # Handle basic 1-to-1 key conversions.
    basic_key_map = {
        # Base model keys.
        # ----------------
        # txt_in keys.
        "context_embedder.bias": "txt_in.bias",
        "context_embedder.weight": "txt_in.weight",
        # guidance_in MLPEmbedder keys.
        "time_text_embed.guidance_embedder.linear_1.bias": "guidance_in.in_layer.bias",
        "time_text_embed.guidance_embedder.linear_1.weight": "guidance_in.in_layer.weight",
        "time_text_embed.guidance_embedder.linear_2.bias": "guidance_in.out_layer.bias",
        "time_text_embed.guidance_embedder.linear_2.weight": "guidance_in.out_layer.weight",
        # vector_in MLPEmbedder keys.
        "time_text_embed.text_embedder.linear_1.bias": "vector_in.in_layer.bias",
        "time_text_embed.text_embedder.linear_1.weight": "vector_in.in_layer.weight",
        "time_text_embed.text_embedder.linear_2.bias": "vector_in.out_layer.bias",
        "time_text_embed.text_embedder.linear_2.weight": "vector_in.out_layer.weight",
        # time_in MLPEmbedder keys.
        "time_text_embed.timestep_embedder.linear_1.bias": "time_in.in_layer.bias",
        "time_text_embed.timestep_embedder.linear_1.weight": "time_in.in_layer.weight",
        "time_text_embed.timestep_embedder.linear_2.bias": "time_in.out_layer.bias",
        "time_text_embed.timestep_embedder.linear_2.weight": "time_in.out_layer.weight",
        # img_in keys.
        "x_embedder.bias": "img_in.bias",
        "x_embedder.weight": "img_in.weight",
    }
    for old_key, new_key in basic_key_map.items():
        v = sd.pop(old_key, None)
        if v is not None:
            new_sd[new_key] = v

    # Handle the double_blocks.
    block_index = 0
    while True:
        converted_double_block_sd = _convert_flux_double_block_sd_from_diffusers_to_bfl_format(sd, block_index)
        if len(converted_double_block_sd) == 0:
            break
        new_sd.update(converted_double_block_sd)
        block_index += 1

    # Handle the single_blocks.
    block_index = 0
    while True:
        converted_singe_block_sd = _convert_flux_single_block_sd_from_diffusers_to_bfl_format(sd, block_index)
        if len(converted_singe_block_sd) == 0:
            break
        new_sd.update(converted_singe_block_sd)
        block_index += 1

    # Transfer controlnet keys as-is.
    for k in list(sd.keys()):
        if k.startswith("controlnet_"):
            new_sd[k] = sd.pop(k)

    # Assert that all keys have been handled.
    assert len(sd) == 0
    return new_sd


def infer_flux_params_from_state_dict(sd: Dict[str, torch.Tensor]) -> FluxParams:
    """Infer the FluxParams from the shape of a FLUX state dict. When a model is distributed in diffusers format, this
    information is all contained in the config.json file that accompanies the model. However, being apple to infer the
    params from the state dict enables us to load models (e.g. an InstantX ControlNet) from a single weight file.
    """
    hidden_size = sd["img_in.weight"].shape[0]
    mlp_hidden_dim = sd["double_blocks.0.img_mlp.0.weight"].shape[0]
    # mlp_ratio is a float, but we treat it as an int here to avoid having to think about possible float precision
    # issues. In practice, mlp_ratio is usually 4.
    mlp_ratio = mlp_hidden_dim // hidden_size

    head_dim = sd["double_blocks.0.img_attn.norm.query_norm.scale"].shape[0]
    num_heads = hidden_size // head_dim

    # Count the number of double blocks.
    double_block_index = 0
    while f"double_blocks.{double_block_index}.img_attn.qkv.weight" in sd:
        double_block_index += 1

    # Count the number of single blocks.
    single_block_index = 0
    while f"single_blocks.{single_block_index}.linear1.weight" in sd:
        single_block_index += 1

    return FluxParams(
        in_channels=sd["img_in.weight"].shape[1],
        vec_in_dim=sd["vector_in.in_layer.weight"].shape[1],
        context_in_dim=sd["txt_in.weight"].shape[1],
        hidden_size=hidden_size,
        mlp_ratio=mlp_ratio,
        num_heads=num_heads,
        depth=double_block_index,
        depth_single_blocks=single_block_index,
        # axes_dim cannot be inferred from the state dict. The hard-coded value is correct for dev/schnell models.
        axes_dim=[16, 56, 56],
        # theta cannot be inferred from the state dict. The hard-coded value is correct for dev/schnell models.
        theta=10_000,
        qkv_bias="double_blocks.0.img_attn.qkv.bias" in sd,
        guidance_embed="guidance_in.in_layer.weight" in sd,
    )


def infer_instantx_num_control_modes_from_state_dict(sd: Dict[str, torch.Tensor]) -> int | None:
    """Infer the number of ControlNet Union modes from the shape of a InstantX ControlNet state dict.

    Returns None if the model is not a ControlNet Union model. Otherwise returns the number of modes.
    """
    mode_embedder_key = "controlnet_mode_embedder.weight"
    if mode_embedder_key not in sd:
        return None

    return sd[mode_embedder_key].shape[0]
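A sketch of how these helpers can fit together when loading a single-file InstantX checkpoint. This is not the model loader from this diff; the checkpoint path is hypothetical and the invokeai package at this revision is assumed to be importable:

```python
import torch

from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.state_dict_utils import (
    convert_diffusers_instantx_state_dict_to_bfl_format,
    infer_flux_params_from_state_dict,
    infer_instantx_num_control_modes_from_state_dict,
    is_state_dict_instantx_controlnet,
)

sd = torch.load("instantx_flux_controlnet.pt", map_location="cpu")  # hypothetical checkpoint path
assert is_state_dict_instantx_controlnet(sd)

bfl_sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
params = infer_flux_params_from_state_dict(bfl_sd)
num_modes = infer_instantx_num_control_modes_from_state_dict(bfl_sd)

model = InstantXControlNetFlux(params, num_control_modes=num_modes)
model.load_state_dict(bfl_sd)
```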
130
invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py
Normal file
130
invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# This file was initially based on:
|
||||
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/controlnet.py
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from invokeai.backend.flux.controlnet.zero_module import zero_module
|
||||
from invokeai.backend.flux.model import FluxParams
|
||||
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding
|
||||
|
||||
|
||||
@dataclass
|
||||
class XLabsControlNetFluxOutput:
|
||||
controlnet_double_block_residuals: list[torch.Tensor] | None
|
||||
|
||||
|
||||
class XLabsControlNetFlux(torch.nn.Module):
|
||||
"""A ControlNet model for FLUX.
|
||||
|
||||
The architecture is very similar to the base FLUX model, with the following differences:
|
||||
- A `controlnet_depth` parameter is passed to control the number of double_blocks that the ControlNet is applied to.
|
||||
In order to keep the ControlNet small, this is typically much less than the depth of the base FLUX model.
|
||||
- There is a set of `controlnet_blocks` that are applied to the output of each double_block.
|
||||
"""
|
||||
|
||||
def __init__(self, params: FluxParams, controlnet_depth: int = 2):
|
||||
super().__init__()
|
||||
|
||||
self.params = params
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = self.in_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = torch.nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else torch.nn.Identity()
|
||||
)
|
||||
self.txt_in = torch.nn.Linear(params.context_in_dim, self.hidden_size)
|
||||
|
||||
self.double_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
)
|
||||
for _ in range(controlnet_depth)
|
||||
]
|
||||
)
|
||||
|
||||
# Add ControlNet blocks.
|
||||
self.controlnet_blocks = torch.nn.ModuleList([])
|
||||
for _ in range(controlnet_depth):
|
||||
controlnet_block = torch.nn.Linear(self.hidden_size, self.hidden_size)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
self.pos_embed_input = torch.nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
||||
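# Editor's note (assumption, not in the original diff): input_hint_block below maps the 3-channel control
# image to 16 channels while downsampling 8x via three stride-2 convs; a 2x2 patchify of its output in
# forward() then yields 64 features per token, matching in_channels for the standard FLUX config, so
# pos_embed_input can project it to hidden_size.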
self.input_hint_block = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(3, 16, 3, padding=1),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Conv2d(16, 16, 3, padding=1),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Conv2d(16, 16, 3, padding=1),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Conv2d(16, 16, 3, padding=1),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
|
||||
torch.nn.SiLU(),
|
||||
zero_module(torch.nn.Conv2d(16, 16, 3, padding=1)),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
guidance: torch.Tensor | None = None,
|
||||
) -> XLabsControlNetFluxOutput:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
||||
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||
img = img + controlnet_cond
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
||||
vec = vec + self.vector_in(y)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
block_res_samples: list[torch.Tensor] = []
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
block_res_samples.append(img)
|
||||
|
||||
controlnet_block_res_samples: list[torch.Tensor] = []
|
||||
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks, strict=True):
|
||||
block_res_sample = controlnet_block(block_res_sample)
|
||||
controlnet_block_res_samples.append(block_res_sample)
|
||||
|
||||
return XLabsControlNetFluxOutput(controlnet_double_block_residuals=controlnet_block_res_samples)
|
||||
12 invokeai/backend/flux/controlnet/zero_module.py Normal file
@@ -0,0 +1,12 @@
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
T = TypeVar("T", bound=torch.nn.Module)
|
||||
|
||||
|
||||
def zero_module(module: T) -> T:
|
||||
"""Initialize the parameters of a module to zero."""
|
||||
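# Editor's note: zero-initialising the ControlNet's output projections (controlnet_blocks and the final
# conv of input_hint_block) is the standard ControlNet trick: at the start of training the control branch
# contributes nothing, so the base model's behaviour is preserved.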
for p in module.parameters():
|
||||
torch.nn.init.zeros_(p)
|
||||
return module
|
||||
@@ -3,7 +3,10 @@ from typing import Callable
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.flux.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
|
||||
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
|
||||
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
|
||||
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
@@ -21,6 +24,7 @@ def denoise(
|
||||
step_callback: Callable[[PipelineIntermediateState], None],
|
||||
guidance: float,
|
||||
inpaint_extension: InpaintExtension | None,
|
||||
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
|
||||
):
|
||||
# step 0 is the initial state
|
||||
total_steps = len(timesteps) - 1
|
||||
@@ -38,6 +42,30 @@ def denoise(
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
|
||||
# Run ControlNet models.
|
||||
controlnet_residuals: list[ControlNetFluxOutput] = []
|
||||
for controlnet_extension in controlnet_extensions:
|
||||
controlnet_residuals.append(
|
||||
controlnet_extension.run_controlnet(
|
||||
timestep_index=step - 1,
|
||||
total_num_timesteps=total_steps,
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
)
|
||||
)
|
||||
|
||||
# Merge the ControlNet residuals from multiple ControlNets.
|
||||
# TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind that the
# controlnet_residuals data structure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
|
||||
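# Editor's note (assumption, not stated in the diff): sum_controlnet_flux_outputs is expected to add the
# per-block residual tensors element-wise across all ControlNet outputs, producing a single
# ControlNetFluxOutput that is passed to the transformer below.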
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
|
||||
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
@@ -46,6 +74,8 @@ def denoise(
|
||||
y=vec,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
|
||||
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
|
||||
)
|
||||
|
||||
preview_img = img - t_curr * pred
|
||||
|
||||
0 invokeai/backend/flux/extensions/__init__.py Normal file
@@ -0,0 +1,45 @@
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
|
||||
|
||||
|
||||
class BaseControlNetExtension(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
weight: Union[float, List[float]],
|
||||
begin_step_percent: float,
|
||||
end_step_percent: float,
|
||||
):
|
||||
self._weight = weight
|
||||
self._begin_step_percent = begin_step_percent
|
||||
self._end_step_percent = end_step_percent
|
||||
|
||||
def _get_weight(self, timestep_index: int, total_num_timesteps: int) -> float:
|
||||
first_step = math.floor(self._begin_step_percent * total_num_timesteps)
|
||||
last_step = math.ceil(self._end_step_percent * total_num_timesteps)
|
||||
|
||||
if timestep_index < first_step or timestep_index > last_step:
|
||||
return 0.0
|
||||
|
||||
if isinstance(self._weight, list):
|
||||
return self._weight[timestep_index]
|
||||
|
||||
return self._weight
|
||||
|
||||
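# Editor's note: a small worked example of the gating above (illustrative values). With
# begin_step_percent=0.0, end_step_percent=0.5 and total_num_timesteps=30, first_step=0 and last_step=15,
# so a non-zero weight is returned for timestep indices 0..15 and 0.0 afterwards. When `weight` is a
# list, it is indexed per timestep, so it must cover every index inside that window.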
@abstractmethod
|
||||
def run_controlnet(
|
||||
self,
|
||||
timestep_index: int,
|
||||
total_num_timesteps: int,
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
guidance: torch.Tensor | None,
|
||||
) -> ControlNetFluxOutput: ...
|
||||
@@ -0,0 +1,194 @@
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
|
||||
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import (
|
||||
InstantXControlNetFlux,
|
||||
InstantXControlNetFluxOutput,
|
||||
)
|
||||
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
|
||||
from invokeai.backend.flux.sampling_utils import pack
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
|
||||
|
||||
class InstantXControlNetExtension(BaseControlNetExtension):
|
||||
def __init__(
|
||||
self,
|
||||
model: InstantXControlNetFlux,
|
||||
controlnet_cond: torch.Tensor,
|
||||
instantx_control_mode: torch.Tensor | None,
|
||||
weight: Union[float, List[float]],
|
||||
begin_step_percent: float,
|
||||
end_step_percent: float,
|
||||
):
|
||||
super().__init__(
|
||||
weight=weight,
|
||||
begin_step_percent=begin_step_percent,
|
||||
end_step_percent=end_step_percent,
|
||||
)
|
||||
self._model = model
|
||||
# The VAE-encoded and 'packed' control image to pass to the ControlNet model.
|
||||
self._controlnet_cond = controlnet_cond
|
||||
# TODO(ryand): Should we define an enum for the instantx_control_mode? Is it likely to change for future models?
|
||||
# The control mode for InstantX ControlNet union models.
|
||||
# See the values defined here: https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union#control-mode
|
||||
# Expected shape: (batch_size, 1), Expected dtype: torch.long
|
||||
# If None, a zero-embedding will be used.
|
||||
self._instantx_control_mode = instantx_control_mode
|
||||
|
||||
# TODO(ryand): Pass in these params if a new base transformer / InstantX ControlNet pair get released.
|
||||
self._flux_transformer_num_double_blocks = 19
|
||||
self._flux_transformer_num_single_blocks = 38
|
||||
|
||||
@classmethod
|
||||
def prepare_controlnet_cond(
|
||||
cls,
|
||||
controlnet_image: Image,
|
||||
vae_info: LoadedModel,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES,
|
||||
):
|
||||
image_height = latent_height * LATENT_SCALE_FACTOR
|
||||
image_width = latent_width * LATENT_SCALE_FACTOR
|
||||
|
||||
resized_controlnet_image = prepare_control_image(
|
||||
image=controlnet_image,
|
||||
do_classifier_free_guidance=False,
|
||||
width=image_width,
|
||||
height=image_height,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
control_mode="balanced",
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
|
||||
# Shift the image from [0, 1] to [-1, 1].
|
||||
resized_controlnet_image = resized_controlnet_image * 2 - 1
|
||||
|
||||
# Run VAE encoder.
|
||||
controlnet_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized_controlnet_image)
|
||||
controlnet_cond = pack(controlnet_cond)
|
||||
|
||||
return controlnet_cond
|
||||
|
||||
@classmethod
|
||||
def from_controlnet_image(
|
||||
cls,
|
||||
model: InstantXControlNetFlux,
|
||||
controlnet_image: Image,
|
||||
instantx_control_mode: torch.Tensor | None,
|
||||
vae_info: LoadedModel,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES,
|
||||
weight: Union[float, List[float]],
|
||||
begin_step_percent: float,
|
||||
end_step_percent: float,
|
||||
):
|
||||
image_height = latent_height * LATENT_SCALE_FACTOR
|
||||
image_width = latent_width * LATENT_SCALE_FACTOR
|
||||
|
||||
resized_controlnet_image = prepare_control_image(
|
||||
image=controlnet_image,
|
||||
do_classifier_free_guidance=False,
|
||||
width=image_width,
|
||||
height=image_height,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
control_mode="balanced",
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
|
||||
# Shift the image from [0, 1] to [-1, 1].
|
||||
resized_controlnet_image = resized_controlnet_image * 2 - 1
|
||||
|
||||
# Run VAE encoder.
|
||||
controlnet_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized_controlnet_image)
|
||||
controlnet_cond = pack(controlnet_cond)
|
||||
|
||||
return cls(
|
||||
model=model,
|
||||
controlnet_cond=controlnet_cond,
|
||||
instantx_control_mode=instantx_control_mode,
|
||||
weight=weight,
|
||||
begin_step_percent=begin_step_percent,
|
||||
end_step_percent=end_step_percent,
|
||||
)
|
||||
|
||||
def _instantx_output_to_controlnet_output(
|
||||
self, instantx_output: InstantXControlNetFluxOutput
|
||||
) -> ControlNetFluxOutput:
|
||||
# The `interval_control` logic here is based on
|
||||
# https://github.com/huggingface/diffusers/blob/31058cdaef63ca660a1a045281d156239fba8192/src/diffusers/models/transformers/transformer_flux.py#L507-L511
|
||||
|
||||
# Handle double block residuals.
|
||||
double_block_residuals: list[torch.Tensor] = []
|
||||
double_block_samples = instantx_output.controlnet_block_samples
|
||||
if double_block_samples:
|
||||
interval_control = self._flux_transformer_num_double_blocks / len(double_block_samples)
|
||||
interval_control = int(math.ceil(interval_control))
|
||||
for i in range(self._flux_transformer_num_double_blocks):
|
||||
double_block_residuals.append(double_block_samples[i // interval_control])
|
||||
|
||||
# Handle single block residuals.
|
||||
single_block_residuals: list[torch.Tensor] = []
|
||||
single_block_samples = instantx_output.controlnet_single_block_samples
|
||||
if single_block_samples:
|
||||
interval_control = self._flux_transformer_num_single_blocks / len(single_block_samples)
|
||||
interval_control = int(math.ceil(interval_control))
|
||||
for i in range(self._flux_transformer_num_single_blocks):
|
||||
single_block_residuals.append(single_block_samples[i // interval_control])
|
||||
|
||||
return ControlNetFluxOutput(
|
||||
double_block_residuals=double_block_residuals or None,
|
||||
single_block_residuals=single_block_residuals or None,
|
||||
)
|
||||
|
||||
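# Editor's note: a worked example of the interval_control mapping above (illustrative numbers). With 19
# transformer double blocks and, say, 4 ControlNet double-block samples, interval_control = ceil(19 / 4) = 5,
# so transformer blocks 0-4 receive sample 0, blocks 5-9 sample 1, 10-14 sample 2 and 15-18 sample 3.
# Every transformer block therefore gets a residual even though the ControlNet produces fewer of them.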
def run_controlnet(
|
||||
self,
|
||||
timestep_index: int,
|
||||
total_num_timesteps: int,
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
guidance: torch.Tensor | None,
|
||||
) -> ControlNetFluxOutput:
|
||||
weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
|
||||
if weight < 1e-6:
|
||||
return ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
|
||||
|
||||
# Make sure inputs have correct device and dtype.
|
||||
self._controlnet_cond = self._controlnet_cond.to(device=img.device, dtype=img.dtype)
|
||||
self._instantx_control_mode = (
|
||||
self._instantx_control_mode.to(device=img.device) if self._instantx_control_mode is not None else None
|
||||
)
|
||||
|
||||
instantx_output: InstantXControlNetFluxOutput = self._model(
|
||||
controlnet_cond=self._controlnet_cond,
|
||||
controlnet_mode=self._instantx_control_mode,
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
timesteps=timesteps,
|
||||
y=y,
|
||||
guidance=guidance,
|
||||
)
|
||||
|
||||
controlnet_output = self._instantx_output_to_controlnet_output(instantx_output)
|
||||
controlnet_output.apply_weight(weight)
|
||||
return controlnet_output
|
||||
150 invokeai/backend/flux/extensions/xlabs_controlnet_extension.py Normal file
@@ -0,0 +1,150 @@
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
|
||||
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
|
||||
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
|
||||
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux, XLabsControlNetFluxOutput
|
||||
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
|
||||
|
||||
|
||||
class XLabsControlNetExtension(BaseControlNetExtension):
|
||||
def __init__(
|
||||
self,
|
||||
model: XLabsControlNetFlux,
|
||||
controlnet_cond: torch.Tensor,
|
||||
weight: Union[float, List[float]],
|
||||
begin_step_percent: float,
|
||||
end_step_percent: float,
|
||||
):
|
||||
super().__init__(
|
||||
weight=weight,
|
||||
begin_step_percent=begin_step_percent,
|
||||
end_step_percent=end_step_percent,
|
||||
)
|
||||
|
||||
self._model = model
|
||||
# _controlnet_cond is the control image passed to the ControlNet model.
|
||||
# Pixel values are in the range [-1, 1]. Shape: (batch_size, 3, height, width).
|
||||
self._controlnet_cond = controlnet_cond
|
||||
|
||||
# TODO(ryand): Pass in these params if a new base transformer / XLabs ControlNet pair get released.
|
||||
self._flux_transformer_num_double_blocks = 19
|
||||
self._flux_transformer_num_single_blocks = 38
|
||||
|
||||
@classmethod
|
||||
def prepare_controlnet_cond(
|
||||
cls,
|
||||
controlnet_image: Image,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES,
|
||||
):
|
||||
image_height = latent_height * LATENT_SCALE_FACTOR
|
||||
image_width = latent_width * LATENT_SCALE_FACTOR
|
||||
|
||||
controlnet_cond = prepare_control_image(
|
||||
image=controlnet_image,
|
||||
do_classifier_free_guidance=False,
|
||||
width=image_width,
|
||||
height=image_height,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
control_mode="balanced",
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
|
||||
# Map pixel values from [0, 1] to [-1, 1].
|
||||
controlnet_cond = controlnet_cond * 2 - 1
|
||||
|
||||
return controlnet_cond
|
||||
|
||||
@classmethod
|
||||
def from_controlnet_image(
|
||||
cls,
|
||||
model: XLabsControlNetFlux,
|
||||
controlnet_image: Image,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES,
|
||||
weight: Union[float, List[float]],
|
||||
begin_step_percent: float,
|
||||
end_step_percent: float,
|
||||
):
|
||||
image_height = latent_height * LATENT_SCALE_FACTOR
|
||||
image_width = latent_width * LATENT_SCALE_FACTOR
|
||||
|
||||
controlnet_cond = prepare_control_image(
|
||||
image=controlnet_image,
|
||||
do_classifier_free_guidance=False,
|
||||
width=image_width,
|
||||
height=image_height,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
control_mode="balanced",
|
||||
resize_mode=resize_mode,
|
||||
)
|
||||
|
||||
# Map pixel values from [0, 1] to [-1, 1].
|
||||
controlnet_cond = controlnet_cond * 2 - 1
|
||||
|
||||
return cls(
|
||||
model=model,
|
||||
controlnet_cond=controlnet_cond,
|
||||
weight=weight,
|
||||
begin_step_percent=begin_step_percent,
|
||||
end_step_percent=end_step_percent,
|
||||
)
|
||||
|
||||
def _xlabs_output_to_controlnet_output(self, xlabs_output: XLabsControlNetFluxOutput) -> ControlNetFluxOutput:
|
||||
# The modulo index logic used here is based on:
|
||||
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/model.py#L198-L200
|
||||
|
||||
# Handle double block residuals.
|
||||
double_block_residuals: list[torch.Tensor] = []
|
||||
xlabs_double_block_residuals = xlabs_output.controlnet_double_block_residuals
|
||||
if xlabs_double_block_residuals is not None:
|
||||
for i in range(self._flux_transformer_num_double_blocks):
|
||||
double_block_residuals.append(xlabs_double_block_residuals[i % len(xlabs_double_block_residuals)])
|
||||
|
||||
return ControlNetFluxOutput(
|
||||
double_block_residuals=double_block_residuals,
|
||||
single_block_residuals=None,
|
||||
)
|
||||
|
||||
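# Editor's note: a worked example of the modulo mapping above (illustrative numbers). An XLabs ControlNet
# built with controlnet_depth=2 produces 2 residuals; with 19 transformer double blocks, block i receives
# residual i % 2, i.e. the two residuals are reused in alternation across all 19 blocks.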
def run_controlnet(
|
||||
self,
|
||||
timestep_index: int,
|
||||
total_num_timesteps: int,
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
txt_ids: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
timesteps: torch.Tensor,
|
||||
guidance: torch.Tensor | None,
|
||||
) -> ControlNetFluxOutput:
|
||||
weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
|
||||
if weight < 1e-6:
|
||||
return ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
|
||||
|
||||
xlabs_output: XLabsControlNetFluxOutput = self._model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
controlnet_cond=self._controlnet_cond,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
timesteps=timesteps,
|
||||
y=y,
|
||||
guidance=guidance,
|
||||
)
|
||||
|
||||
controlnet_output = self._xlabs_output_to_controlnet_output(xlabs_output)
|
||||
controlnet_output.apply_weight(weight)
|
||||
return controlnet_output
|
||||
@@ -87,7 +87,9 @@ class Flux(nn.Module):
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor | None = None,
|
||||
guidance: Tensor | None,
|
||||
controlnet_double_block_residuals: list[Tensor] | None,
|
||||
controlnet_single_block_residuals: list[Tensor] | None,
|
||||
) -> Tensor:
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
||||
@@ -105,12 +107,27 @@ class Flux(nn.Module):
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
for block in self.double_blocks:
|
||||
# Validate double_block_residuals shape.
|
||||
if controlnet_double_block_residuals is not None:
|
||||
assert len(controlnet_double_block_residuals) == len(self.double_blocks)
|
||||
for block_index, block in enumerate(self.double_blocks):
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
||||
|
||||
if controlnet_double_block_residuals is not None:
|
||||
img += controlnet_double_block_residuals[block_index]
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
|
||||
# Validate single_block_residuals shape.
|
||||
if controlnet_single_block_residuals is not None:
|
||||
assert len(controlnet_single_block_residuals) == len(self.single_blocks)
|
||||
|
||||
for block_index, block in enumerate(self.single_blocks):
|
||||
img = block(img, vec=vec, pe=pe)
|
||||
|
||||
if controlnet_single_block_residuals is not None:
|
||||
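# Editor's note: txt and img tokens were concatenated above, so the residual is added only to the
# image-token slice (positions txt.shape[1] onwards); the text tokens are left untouched.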
img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index]
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
@@ -8,17 +8,36 @@ from diffusers import ControlNetModel
|
||||
from invokeai.backend.model_manager import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
ControlNetCheckpointConfig,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import ControlNetCheckpointConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Diffusers
|
||||
)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
|
||||
)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Diffusers
|
||||
)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
|
||||
)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Diffusers
|
||||
)
|
||||
@ModelLoaderRegistry.register(
|
||||
base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
|
||||
)
|
||||
class ControlNetLoader(GenericDiffusersLoader):
|
||||
"""Class to load ControlNet models."""
|
||||
|
||||
|
||||
@@ -10,6 +10,15 @@ from safetensors.torch import load_file
|
||||
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
|
||||
from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
convert_diffusers_instantx_state_dict_to_bfl_format,
|
||||
infer_flux_params_from_state_dict,
|
||||
infer_instantx_num_control_modes_from_state_dict,
|
||||
is_state_dict_instantx_controlnet,
|
||||
is_state_dict_xlabs_controlnet,
|
||||
)
|
||||
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
|
||||
from invokeai.backend.flux.util import ae_params, params
|
||||
@@ -24,6 +33,8 @@ from invokeai.backend.model_manager import (
|
||||
from invokeai.backend.model_manager.config import (
|
||||
CheckpointConfigBase,
|
||||
CLIPEmbedDiffusersConfig,
|
||||
ControlNetCheckpointConfig,
|
||||
ControlNetDiffusersConfig,
|
||||
MainBnbQuantized4bCheckpointConfig,
|
||||
MainCheckpointConfig,
|
||||
MainGGUFCheckpointConfig,
|
||||
@@ -293,3 +304,51 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
|
||||
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
|
||||
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
|
||||
class FluxControlnetModel(ModelLoader):
|
||||
"""Class to load FLUX ControlNet models."""
|
||||
|
||||
def _load_model(
|
||||
self,
|
||||
config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> AnyModel:
|
||||
if isinstance(config, ControlNetCheckpointConfig):
|
||||
model_path = Path(config.path)
|
||||
elif isinstance(config, ControlNetDiffusersConfig):
|
||||
# If this is a diffusers directory, we simply ignore the config file and load from the weight file.
|
||||
model_path = Path(config.path) / "diffusion_pytorch_model.safetensors"
|
||||
else:
|
||||
raise ValueError(f"Unexpected ControlNet model config type: {type(config)}")
|
||||
|
||||
sd = load_file(model_path)
|
||||
|
||||
# Detect the FLUX ControlNet model type from the state dict.
|
||||
if is_state_dict_xlabs_controlnet(sd):
|
||||
return self._load_xlabs_controlnet(sd)
|
||||
elif is_state_dict_instantx_controlnet(sd):
|
||||
return self._load_instantx_controlnet(sd)
|
||||
else:
|
||||
raise ValueError("Do not recognize the state dict as an XLabs or InstantX ControlNet model.")
|
||||
|
||||
def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
|
||||
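# Editor's note: init_empty_weights() builds the module on the meta device without allocating real
# tensors; load_state_dict(..., assign=True) below then attaches the checkpoint tensors directly,
# avoiding a second full-size allocation.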
with accelerate.init_empty_weights():
|
||||
# HACK(ryand): Is it safe to assume dev here?
|
||||
model = XLabsControlNetFlux(params["flux-dev"])
|
||||
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
def _load_instantx_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
|
||||
sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
|
||||
flux_params = infer_flux_params_from_state_dict(sd)
|
||||
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
|
||||
|
||||
with accelerate.init_empty_weights():
|
||||
model = InstantXControlNetFlux(flux_params, num_control_modes)
|
||||
|
||||
model.load_state_dict(sd, assign=True)
|
||||
return model
|
||||
|
||||
@@ -10,6 +10,10 @@ from picklescan.scanner import scan_file_path
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
from invokeai.backend.flux.controlnet.state_dict_utils import (
|
||||
is_state_dict_instantx_controlnet,
|
||||
is_state_dict_xlabs_controlnet,
|
||||
)
|
||||
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
)
|
||||
@@ -116,6 +120,7 @@ class ModelProbe(object):
|
||||
"CLIPModel": ModelType.CLIPEmbed,
|
||||
"CLIPTextModel": ModelType.CLIPEmbed,
|
||||
"T5EncoderModel": ModelType.T5Encoder,
|
||||
"FluxControlNetModel": ModelType.ControlNet,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -255,7 +260,19 @@ class ModelProbe(object):
|
||||
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
|
||||
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
|
||||
return ModelType.LoRA
|
||||
elif key.startswith(("controlnet", "control_model", "input_blocks")):
|
||||
elif key.startswith(
|
||||
(
|
||||
"controlnet",
|
||||
"control_model",
|
||||
"input_blocks",
|
||||
# XLabs FLUX ControlNet models have keys starting with "controlnet_blocks."
|
||||
# For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
|
||||
# TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with
|
||||
# "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
|
||||
# delicate.
|
||||
"controlnet_blocks",
|
||||
)
|
||||
):
|
||||
return ModelType.ControlNet
|
||||
elif key.startswith(("image_proj.", "ip_adapter.")):
|
||||
return ModelType.IPAdapter
|
||||
@@ -438,6 +455,7 @@ MODEL_NAME_TO_PREPROCESSOR = {
|
||||
"lineart": "lineart_image_processor",
|
||||
"lineart_anime": "lineart_anime_image_processor",
|
||||
"softedge": "hed_image_processor",
|
||||
"hed": "hed_image_processor",
|
||||
"shuffle": "content_shuffle_image_processor",
|
||||
"pose": "dw_openpose_image_processor",
|
||||
"mediapipe": "mediapipe_face_processor",
|
||||
@@ -449,7 +467,8 @@ MODEL_NAME_TO_PREPROCESSOR = {
|
||||
|
||||
def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
|
||||
for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
|
||||
if k in model_name:
|
||||
model_name_lower = model_name.lower()
|
||||
if k in model_name_lower:
|
||||
return ControlAdapterDefaultSettings(preprocessor=v)
|
||||
return None
|
||||
|
||||
@@ -623,6 +642,11 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
checkpoint = self.checkpoint
|
||||
if is_state_dict_xlabs_controlnet(checkpoint) or is_state_dict_instantx_controlnet(checkpoint):
|
||||
# TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing
|
||||
# get_format()?
|
||||
return BaseModelType.Flux
|
||||
|
||||
for key_name in (
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
"controlnet_mid_block.bias",
|
||||
@@ -844,22 +868,19 @@ class ControlNetFolderProbe(FolderProbeBase):
|
||||
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
|
||||
with open(config_file, "r") as file:
|
||||
config = json.load(file)
|
||||
|
||||
if config.get("_class_name", None) == "FluxControlNetModel":
|
||||
return BaseModelType.Flux
|
||||
|
||||
# no obvious way to distinguish between sd2-base and sd2-768
|
||||
dimension = config["cross_attention_dim"]
|
||||
base_model = (
|
||||
BaseModelType.StableDiffusion1
|
||||
if dimension == 768
|
||||
else (
|
||||
BaseModelType.StableDiffusion2
|
||||
if dimension == 1024
|
||||
else BaseModelType.StableDiffusionXL
|
||||
if dimension == 2048
|
||||
else None
|
||||
)
|
||||
)
|
||||
if not base_model:
|
||||
raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
|
||||
return base_model
|
||||
if dimension == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
if dimension == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
if dimension == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
|
||||
|
||||
|
||||
class LoRAFolderProbe(FolderProbeBase):
|
||||
|
||||
@@ -422,6 +422,13 @@ STARTER_MODELS: list[StarterModel] = [
|
||||
description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
StarterModel(
|
||||
name="FLUX.1-dev-Controlnet-Union-Pro",
|
||||
base=BaseModelType.Flux,
|
||||
source="Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
|
||||
description="A unified ControlNet for FLUX.1-dev model that supports 7 control modes, including canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6)",
|
||||
type=ModelType.ControlNet,
|
||||
),
|
||||
# endregion
|
||||
# region T2I Adapter
|
||||
StarterModel(
|
||||
|
||||
@@ -198,20 +198,24 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
self.disable_attention_slicing()
|
||||
return
|
||||
elif config.attention_type == "torch-sdp":
|
||||
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
# diffusers enables sdp automatically
|
||||
return
|
||||
else:
|
||||
raise Exception("torch-sdp attention slicing not available")
|
||||
# torch-sdp is the default in diffusers.
|
||||
return
|
||||
|
||||
# the remainder of this code is called when attention_type=='auto'
|
||||
# See https://github.com/invoke-ai/InvokeAI/issues/7049 for context.
|
||||
# Bumping torch from 2.2.2 to 2.4.1 caused the sliced attention implementation to produce incorrect results.
|
||||
# For now, if a user is on an MPS device and has not explicitly set the attention_type, then we select the
|
||||
# non-sliced torch-sdp implementation. This keeps things working on MPS at the cost of increased peak memory
|
||||
# utilization.
|
||||
if torch.backends.mps.is_available():
|
||||
return
|
||||
|
||||
# The remainder of this code is called when attention_type=='auto'.
|
||||
if self.unet.device.type == "cuda":
|
||||
if is_xformers_available() and prefer_xformers:
|
||||
self.enable_xformers_memory_efficient_attention()
|
||||
return
|
||||
elif hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
# diffusers enables sdp automatically
|
||||
return
|
||||
# torch-sdp is the default in diffusers.
|
||||
return
|
||||
|
||||
if self.unet.device.type == "cpu" or self.unet.device.type == "mps":
|
||||
mem_free = psutil.virtual_memory().free
|
||||
|
||||
@@ -936,7 +936,8 @@
|
||||
},
|
||||
"paramScheduler": {
|
||||
"paragraphs": [
|
||||
"\"Planer\" definiert, wie iterativ Rauschen zu einem Bild hinzugefügt wird, oder wie ein Sample bei der Ausgabe eines Modells aktualisiert wird."
|
||||
"Verwendeter Planer währende des Generierungsprozesses.",
|
||||
"Jeder Planer definiert, wie einem Bild iterativ Rauschen hinzugefügt wird, oder wie ein Sample basierend auf der Ausgabe eines Modells aktualisiert wird."
|
||||
],
|
||||
"heading": "Planer"
|
||||
},
|
||||
@@ -962,6 +963,61 @@
|
||||
},
|
||||
"ipAdapterMethod": {
|
||||
"heading": "Methode"
|
||||
},
|
||||
"refinerScheduler": {
|
||||
"heading": "Planer",
|
||||
"paragraphs": [
|
||||
"Planer, der während der Veredelungsphase des Generierungsprozesses verwendet wird.",
|
||||
"Ähnlich wie der Generierungsplaner."
|
||||
]
|
||||
},
|
||||
"compositingCoherenceMode": {
|
||||
"paragraphs": [
|
||||
"Verwendete Methode zur Erstellung eines kohärenten Bildes mit dem neu generierten maskierten Bereich."
|
||||
],
|
||||
"heading": "Modus"
|
||||
},
|
||||
"compositingCoherencePass": {
|
||||
"heading": "Kohärenzdurchlauf"
|
||||
},
|
||||
"controlNet": {
|
||||
"heading": "ControlNet"
|
||||
},
|
||||
"compositingMaskAdjustments": {
|
||||
"paragraphs": [
|
||||
"Die Maske anpassen."
|
||||
],
|
||||
"heading": "Maskenanpassungen"
|
||||
},
|
||||
"compositingMaskBlur": {
|
||||
"paragraphs": [
|
||||
"Der Unschärferadius der Maske."
|
||||
],
|
||||
"heading": "Maskenunschärfe"
|
||||
},
|
||||
"compositingBlurMethod": {
|
||||
"paragraphs": [
|
||||
"Die auf den maskierten Bereich angewendete Unschärfemethode."
|
||||
],
|
||||
"heading": "Unschärfemethode"
|
||||
},
|
||||
"controlNetResizeMode": {
|
||||
"heading": "Größenänderungsmodus"
|
||||
},
|
||||
"paramWidth": {
|
||||
"heading": "Breite",
|
||||
"paragraphs": [
|
||||
"Breite des generierten Bildes. Muss ein Vielfaches von 8 sein."
|
||||
]
|
||||
},
|
||||
"controlNetControlMode": {
|
||||
"heading": "Kontrollmodus"
|
||||
},
|
||||
"controlNetProcessor": {
|
||||
"heading": "Prozessor"
|
||||
},
|
||||
"patchmatchDownScaleSize": {
|
||||
"heading": "Herunterskalieren"
|
||||
}
|
||||
},
|
||||
"invocationCache": {
|
||||
@@ -1080,7 +1136,8 @@
|
||||
"workflowContact": "Kontaktdaten",
|
||||
"workflowNotes": "Notizen",
|
||||
"workflowTags": "Tags",
|
||||
"workflowVersion": "Version"
|
||||
"workflowVersion": "Version",
|
||||
"saveToGallery": "In Galerie speichern"
|
||||
},
|
||||
"hrf": {
|
||||
"enableHrf": "Korrektur für hohe Auflösungen",
|
||||
@@ -1250,7 +1307,16 @@
|
||||
"searchByName": "Nach Name suchen",
|
||||
"promptTemplateCleared": "Promptvorlage gelöscht",
|
||||
"preview": "Vorschau",
|
||||
"positivePrompt": "Positiv-Prompt"
|
||||
"positivePrompt": "Positiv-Prompt",
|
||||
"active": "Aktiv",
|
||||
"deleteTemplate2": "Sind Sie sicher, dass Sie diese Vorlage löschen möchten? Dies kann nicht rückgängig gemacht werden.",
|
||||
"deleteTemplate": "Vorlage löschen",
|
||||
"copyTemplate": "Vorlage kopieren",
|
||||
"editTemplate": "Vorlage bearbeiten",
|
||||
"deleteImage": "Bild löschen",
|
||||
"defaultTemplates": "Standardvorlagen",
|
||||
"nameColumn": "'name'",
|
||||
"exportDownloaded": "Export heruntergeladen"
|
||||
},
|
||||
"newUserExperience": {
|
||||
"gettingStartedSeries": "Wünschen Sie weitere Anleitungen? In unserer <LinkComponent>Einführungsserie</LinkComponent> finden Sie Tipps, wie Sie das Potenzial von Invoke Studio voll ausschöpfen können.",
|
||||
@@ -1263,13 +1329,22 @@
|
||||
"bbox": "Bbox"
|
||||
},
|
||||
"transform": {
|
||||
"fitToBbox": "An Bbox anpassen"
|
||||
"fitToBbox": "An Bbox anpassen",
|
||||
"reset": "Zurücksetzen",
|
||||
"apply": "Anwenden",
|
||||
"cancel": "Abbrechen"
|
||||
},
|
||||
"pullBboxIntoLayerError": "Problem, Bbox in die Ebene zu ziehen",
|
||||
"pullBboxIntoLayer": "Bbox in Ebene ziehen",
|
||||
"HUD": {
|
||||
"bbox": "Bbox",
|
||||
"scaledBbox": "Skalierte Bbox"
|
||||
"scaledBbox": "Skalierte Bbox",
|
||||
"entityStatus": {
|
||||
"isHidden": "{{title}} ist ausgeblendet",
|
||||
"isDisabled": "{{title}} ist deaktiviert",
|
||||
"isLocked": "{{title}} ist gesperrt",
|
||||
"isEmpty": "{{title}} ist leer"
|
||||
}
|
||||
},
|
||||
"fitBboxToLayers": "Bbox an Ebenen anpassen",
|
||||
"pullBboxIntoReferenceImage": "Bbox ins Referenzbild ziehen",
|
||||
@@ -1279,7 +1354,12 @@
|
||||
"clipToBbox": "Pinselstriche auf Bbox beschränken",
|
||||
"canvasContextMenu": {
|
||||
"saveBboxToGallery": "Bbox in Galerie speichern",
|
||||
"bboxGroup": "Aus Bbox erstellen"
|
||||
"bboxGroup": "Aus Bbox erstellen",
|
||||
"canvasGroup": "Leinwand",
|
||||
"newGlobalReferenceImage": "Neues globales Referenzbild",
|
||||
"newRegionalReferenceImage": "Neues regionales Referenzbild",
|
||||
"newControlLayer": "Neue Kontroll-Ebene",
|
||||
"newRasterLayer": "Neue Raster-Ebene"
|
||||
},
|
||||
"rectangle": "Rechteck",
|
||||
"saveCanvasToGallery": "Leinwand in Galerie speichern",
|
||||
@@ -1310,7 +1390,7 @@
|
||||
"regional": "Regional",
|
||||
"newGlobalReferenceImageOk": "Globales Referenzbild erstellt",
|
||||
"savedToGalleryError": "Fehler beim Speichern in der Galerie",
|
||||
"savedToGalleryOk": "In Galerie speichern",
|
||||
"savedToGalleryOk": "In Galerie gespeichert",
|
||||
"newGlobalReferenceImageError": "Problem beim Erstellen eines globalen Referenzbilds",
|
||||
"newRegionalReferenceImageOk": "Regionales Referenzbild erstellt",
|
||||
"duplicate": "Duplizieren",
|
||||
@@ -1343,12 +1423,39 @@
|
||||
"showProgressOnCanvas": "Fortschritt auf Leinwand anzeigen",
|
||||
"controlMode": {
|
||||
"balanced": "Ausgewogen"
|
||||
}
|
||||
},
|
||||
"globalReferenceImages_withCount_hidden": "Globale Referenzbilder ({{count}} ausgeblendet)",
|
||||
"sendToGallery": "An Galerie senden",
|
||||
"stagingArea": {
|
||||
"accept": "Annehmen",
|
||||
"next": "Nächste",
|
||||
"discardAll": "Alle verwerfen",
|
||||
"discard": "Verwerfen",
|
||||
"previous": "Vorherige"
|
||||
},
|
||||
"regionalGuidance_withCount_visible": "Regionale Führung ({{count}})",
|
||||
"regionalGuidance_withCount_hidden": "Regionale Führung ({{count}} ausgeblendet)",
|
||||
"settings": {
|
||||
"snapToGrid": {
|
||||
"on": "Ein",
|
||||
"off": "Aus",
|
||||
"label": "Am Raster ausrichten"
|
||||
}
|
||||
},
|
||||
"layer_one": "Ebene",
|
||||
"layer_other": "Ebenen",
|
||||
"layer_withCount_one": "Ebene ({{count}})",
|
||||
"layer_withCount_other": "Ebenen ({{count}})"
|
||||
},
|
||||
"upsell": {
|
||||
"shareAccess": "Zugang teilen",
|
||||
"professional": "Professionell",
|
||||
"inviteTeammates": "Teamkollegen einladen",
|
||||
"professionalUpsell": "Verfügbar in der Professional Edition von Invoke. Klicken Sie hier oder besuchen Sie invoke.com/pricing für weitere Details."
|
||||
},
|
||||
"upscaling": {
|
||||
"creativity": "Kreativität",
|
||||
"structure": "Struktur",
|
||||
"scale": "Maßstab"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -285,6 +285,7 @@
|
||||
"assetsTab": "Files you’ve uploaded for use in your projects.",
|
||||
"autoAssignBoardOnClick": "Auto-Assign Board on Click",
|
||||
"autoSwitchNewImages": "Auto-Switch to New Images",
|
||||
"boardsSettings": "Boards Settings",
|
||||
"copy": "Copy",
|
||||
"currentlyInUse": "This image is currently in use in the following features:",
|
||||
"drop": "Drop",
|
||||
@@ -304,6 +305,7 @@
|
||||
"go": "Go",
|
||||
"image": "image",
|
||||
"imagesTab": "Images you’ve created and saved within Invoke.",
|
||||
"imagesSettings": "Gallery Images Settings",
|
||||
"jump": "Jump",
|
||||
"loading": "Loading",
|
||||
"newestFirst": "Newest First",
|
||||
@@ -1641,6 +1643,7 @@
|
||||
"sendToCanvas": "Send To Canvas",
|
||||
"newLayerFromImage": "New Layer from Image",
|
||||
"newCanvasFromImage": "New Canvas from Image",
|
||||
"newImg2ImgCanvasFromImage": "New Img2Img from Image",
|
||||
"copyToClipboard": "Copy to Clipboard",
|
||||
"sendToCanvasDesc": "Pressing Invoke stages your work in progress on the canvas.",
|
||||
"viewProgressInViewer": "View progress and outputs in the <Btn>Image Viewer</Btn>.",
|
||||
|
||||
@@ -1730,7 +1730,8 @@
|
||||
"mlsd_detection": {
|
||||
"score_threshold": "Soglia di punteggio",
|
||||
"distance_threshold": "Soglia di distanza",
|
||||
"description": "Genera una mappa dei segmenti di linea dal livello selezionato utilizzando il modello di rilevamento dei segmenti di linea MLSD."
|
||||
"description": "Genera una mappa dei segmenti di linea dal livello selezionato utilizzando il modello di rilevamento dei segmenti di linea MLSD.",
|
||||
"label": "Rilevamento segmenti di linea"
|
||||
},
|
||||
"content_shuffle": {
|
||||
"label": "Mescola contenuto",
|
||||
|
||||
@@ -158,7 +158,9 @@
|
||||
"move": "Двигать",
|
||||
"gallery": "Галерея",
|
||||
"openViewer": "Открыть просмотрщик",
|
||||
"closeViewer": "Закрыть просмотрщик"
|
||||
"closeViewer": "Закрыть просмотрщик",
|
||||
"imagesTab": "Изображения, созданные и сохраненные в Invoke.",
|
||||
"assetsTab": "Файлы, которые вы загрузили для использования в своих проектах."
|
||||
},
|
||||
"hotkeys": {
|
||||
"searchHotkeys": "Поиск горячих клавиш",
|
||||
@@ -928,7 +930,10 @@
|
||||
"imageAccessError": "Невозможно найти изображение {{image_name}}, сбрасываем на значение по умолчанию",
|
||||
"boardAccessError": "Невозможно найти доску {{board_id}}, сбрасываем на значение по умолчанию",
|
||||
"modelAccessError": "Невозможно найти модель {{key}}, сброс на модель по умолчанию",
|
||||
"saveToGallery": "Сохранить в галерею"
|
||||
"saveToGallery": "Сохранить в галерею",
|
||||
"noWorkflows": "Нет рабочих процессов",
|
||||
"noMatchingWorkflows": "Нет совпадающих рабочих процессов",
|
||||
"workflowHelpText": "Нужна помощь? Ознакомьтесь с нашим руководством <LinkComponent>Getting Started with Workflows</LinkComponent>"
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Авто добавление Доски",
|
||||
@@ -1553,7 +1558,10 @@
|
||||
"autoLayout": "Автоматическое расположение",
|
||||
"userWorkflows": "Пользовательские рабочие процессы",
|
||||
"projectWorkflows": "Рабочие процессы проекта",
|
||||
"defaultWorkflows": "Стандартные рабочие процессы"
|
||||
"defaultWorkflows": "Стандартные рабочие процессы",
|
||||
"deleteWorkflow2": "Вы уверены, что хотите удалить этот рабочий процесс? Это нельзя отменить.",
|
||||
"chooseWorkflowFromLibrary": "Выбрать рабочий процесс из библиотеки",
|
||||
"uploadAndSaveWorkflow": "Загрузить в библиотеку"
|
||||
},
|
||||
"hrf": {
|
||||
"enableHrf": "Включить исправление высокого разрешения",
|
||||
@@ -1872,8 +1880,8 @@
|
||||
"duplicate": "Дублировать",
|
||||
"inpaintMasks_withCount_visible": "Маски перерисовки ({{count}})",
|
||||
"layer_one": "Слой",
|
||||
"layer_few": "",
|
||||
"layer_many": "",
|
||||
"layer_few": "Слоя",
|
||||
"layer_many": "Слоев",
|
||||
"prompt": "Запрос",
|
||||
"negativePrompt": "Исключающий запрос",
|
||||
"beginEndStepPercentShort": "Начало/конец %",
|
||||
@@ -2035,7 +2043,7 @@
|
||||
"whatsNewInInvoke": "Что нового в Invoke"
|
||||
},
|
||||
"newUserExperience": {
|
||||
"toGetStarted": "Чтобы начать работу, введите в поле запрос и нажмите <StrongComponent>Invoke</StrongComponent>, чтобы сгенерировать первое изображение. Вы можете сохранить изображения непосредственно в <StrongComponent>Галерею</StrongComponent> или отредактировать их на <StrongComponent>Холсте</StrongComponent>.",
|
||||
"toGetStarted": "Чтобы начать работу, введите в поле запрос и нажмите <StrongComponent>Invoke</StrongComponent>, чтобы сгенерировать первое изображение. Выберите шаблон запроса, чтобы улучшить результаты. Вы можете сохранить изображения непосредственно в <StrongComponent>Галерею</StrongComponent> или отредактировать их на <StrongComponent>Холсте</StrongComponent>.",
|
||||
"gettingStartedSeries": "Хотите получить больше рекомендаций? Ознакомьтесь с нашей серией <LinkComponent>Getting Started Series</LinkComponent> для получения советов по раскрытию всего потенциала Invoke Studio."
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
|
||||
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
|
||||
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
|
||||
import { ImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
|
||||
import { ShareWorkflowModal } from 'features/nodes/components/sidePanel/WorkflowListMenu/ShareWorkflowModal';
|
||||
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
@@ -120,6 +121,7 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
|
||||
<GlobalImageHotkeys />
|
||||
<NewGallerySessionDialog />
|
||||
<NewCanvasSessionDialog />
|
||||
<ImageContextMenu />
|
||||
</ErrorBoundary>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -4,9 +4,9 @@ import { IAILoadingImageFallback, IAINoContentFallback } from 'common/components
|
||||
import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay';
|
||||
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||
import type { MouseEvent, ReactElement, ReactNode, SyntheticEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { PiImageBold, PiUploadSimpleBold } from 'react-icons/pi';
|
||||
import type { ImageDTO, PostUploadAction } from 'services/api/types';
|
||||
|
||||
@@ -17,7 +17,14 @@ const defaultUploadElement = <Icon as={PiUploadSimpleBold} boxSize={16} />;
|
||||
|
||||
const defaultNoContentFallback = <IAINoContentFallback icon={PiImageBold} />;
|
||||
|
||||
const baseStyles: SystemStyleObject = {
|
||||
touchAction: 'none',
|
||||
userSelect: 'none',
|
||||
webkitUserSelect: 'none',
|
||||
};
|
||||
|
||||
const sx: SystemStyleObject = {
|
||||
...baseStyles,
|
||||
'.gallery-image-container::before': {
|
||||
content: '""',
|
||||
display: 'inline-block',
|
||||
@@ -102,59 +109,10 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
useThumbailFallback,
|
||||
withHoverOverlay = false,
|
||||
children,
|
||||
onMouseOver,
|
||||
onMouseOut,
|
||||
dataTestId,
|
||||
...rest
|
||||
} = props;
|
||||
|
||||
const handleMouseOver = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
if (onMouseOver) {
|
||||
onMouseOver(e);
|
||||
}
|
||||
},
|
||||
[onMouseOver]
|
||||
);
|
||||
const handleMouseOut = useCallback(
|
||||
(e: MouseEvent<HTMLDivElement>) => {
|
||||
if (onMouseOut) {
|
||||
onMouseOut(e);
|
||||
}
|
||||
},
|
||||
[onMouseOut]
|
||||
);
|
||||
|
||||
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
|
||||
postUploadAction,
|
||||
isDisabled: isUploadDisabled,
|
||||
});
|
||||
|
||||
const uploadButtonStyles = useMemo<SystemStyleObject>(() => {
|
||||
const styles: SystemStyleObject = {
|
||||
minH: minSize,
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
borderRadius: 'base',
|
||||
transitionProperty: 'common',
|
||||
transitionDuration: '0.1s',
|
||||
color: 'base.500',
|
||||
};
|
||||
if (!isUploadDisabled) {
|
||||
Object.assign(styles, {
|
||||
cursor: 'pointer',
|
||||
bg: 'base.700',
|
||||
_hover: {
|
||||
bg: 'base.650',
|
||||
color: 'base.300',
|
||||
},
|
||||
});
|
||||
}
|
||||
return styles;
|
||||
}, [isUploadDisabled, minSize]);
|
||||
|
||||
const openInNewTab = useCallback(
|
||||
(e: MouseEvent) => {
|
||||
if (!imageDTO) {
|
||||
@@ -168,76 +126,126 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
[imageDTO]
|
||||
);
|
||||
|
||||
const ref = useRef<HTMLDivElement>(null);
|
||||
useImageContextMenu(imageDTO, ref);
|
||||
|
||||
return (
|
||||
<ImageContextMenu imageDTO={imageDTO}>
{(ref) => (
<Flex
ref={ref}
width="full"
height="full"
alignItems="center"
justifyContent="center"
position="relative"
minW={minSize ? minSize : undefined}
minH={minSize ? minSize : undefined}
userSelect="none"
cursor={isDragDisabled || !imageDTO ? 'default' : 'pointer'}
sx={withHoverOverlay ? sx : baseStyles}
data-selected={isSelectedForCompare ? 'selectedForCompare' : isSelected ? 'selected' : undefined}
{...rest}
>
{imageDTO && (
<Flex
ref={ref}
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
width="full"
height="full"
className="gallery-image-container"
w="full"
h="full"
position={fitContainer ? 'absolute' : 'relative'}
alignItems="center"
justifyContent="center"
position="relative"
minW={minSize ? minSize : undefined}
minH={minSize ? minSize : undefined}
userSelect="none"
cursor={isDragDisabled || !imageDTO ? 'default' : 'pointer'}
sx={withHoverOverlay ? sx : undefined}
data-selected={isSelectedForCompare ? 'selectedForCompare' : isSelected ? 'selected' : undefined}
{...rest}
>
{imageDTO && (
<Flex
className="gallery-image-container"
w="full"
h="full"
position={fitContainer ? 'absolute' : 'relative'}
alignItems="center"
justifyContent="center"
>
<Image
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
fallbackSrc={useThumbailFallback ? imageDTO.thumbnail_url : undefined}
fallback={useThumbailFallback ? undefined : <IAILoadingImageFallback image={imageDTO} />}
onError={onError}
draggable={false}
w={imageDTO.width}
objectFit="contain"
maxW="full"
maxH="full"
borderRadius="base"
sx={imageSx}
data-testid={dataTestId}
/>
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
</Flex>
)}
{!imageDTO && !isUploadDisabled && (
<>
<Flex sx={uploadButtonStyles} {...getUploadButtonProps()}>
<input {...getUploadInputProps()} />
{uploadElement}
</Flex>
</>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
onAuxClick={openInNewTab}
/>
)}
{children}
{!isDropDisabled && <IAIDroppable data={droppableData} disabled={isDropDisabled} dropLabel={dropLabel} />}
<Image
src={thumbnail ? imageDTO.thumbnail_url : imageDTO.image_url}
fallbackStrategy="beforeLoadOrError"
fallbackSrc={useThumbailFallback ? imageDTO.thumbnail_url : undefined}
fallback={useThumbailFallback ? undefined : <IAILoadingImageFallback image={imageDTO} />}
onError={onError}
draggable={false}
w={imageDTO.width}
objectFit="contain"
maxW="full"
maxH="full"
borderRadius="base"
sx={imageSx}
data-testid={dataTestId}
/>
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
</Flex>
)}
</ImageContextMenu>
{!imageDTO && !isUploadDisabled && (
<UploadButton
isUploadDisabled={isUploadDisabled}
postUploadAction={postUploadAction}
uploadElement={uploadElement}
minSize={minSize}
/>
)}
{!imageDTO && isUploadDisabled && noContentFallback}
{imageDTO && !isDragDisabled && (
<IAIDraggable
data={draggableData}
disabled={isDragDisabled || !imageDTO}
onClick={onClick}
onAuxClick={openInNewTab}
/>
)}
{children}
{!isDropDisabled && <IAIDroppable data={droppableData} disabled={isDropDisabled} dropLabel={dropLabel} />}
</Flex>
);
};

export default memo(IAIDndImage);

const UploadButton = memo(
({
isUploadDisabled,
postUploadAction,
uploadElement,
minSize,
}: {
isUploadDisabled: boolean;
postUploadAction?: PostUploadAction;
uploadElement: ReactNode;
minSize: number;
}) => {
const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
postUploadAction,
isDisabled: isUploadDisabled,
});

const uploadButtonStyles = useMemo<SystemStyleObject>(() => {
const styles: SystemStyleObject = {
minH: minSize,
w: 'full',
h: 'full',
alignItems: 'center',
justifyContent: 'center',
borderRadius: 'base',
transitionProperty: 'common',
transitionDuration: '0.1s',
color: 'base.500',
};
if (!isUploadDisabled) {
Object.assign(styles, {
cursor: 'pointer',
bg: 'base.700',
_hover: {
bg: 'base.650',
color: 'base.300',
},
});
}
return styles;
}, [isUploadDisabled, minSize]);

return (
<Flex sx={uploadButtonStyles} {...getUploadButtonProps()}>
<input {...getUploadInputProps()} />
{uploadElement}
</Flex>
);
}
);

UploadButton.displayName = 'UploadButton';

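The extracted UploadButton above leans entirely on `useImageUploadButton` for its button and input props. A minimal sketch of the same pattern in another component, assuming only the hook shape shown in the diff; the hook's import path is an assumption:

```tsx
import { Flex } from '@invoke-ai/ui-library';
// Assumed path; the diff only shows the hook's name and return shape.
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';

const UploadTargetSketch = () => {
  const { getUploadButtonProps, getUploadInputProps } = useImageUploadButton({
    isDisabled: false,
  });
  return (
    // The Flex is the clickable/droppable target; the hidden input receives the files.
    <Flex {...getUploadButtonProps()}>
      <input {...getUploadInputProps()} />
      Upload an image
    </Flex>
  );
};
```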
@@ -9,7 +9,6 @@ import {
isModalOpenChanged,
selectChangeBoardModalSlice,
} from 'features/changeBoardModal/store/slice';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
@@ -29,8 +28,7 @@ const ChangeBoardModal = () => {
useAssertSingleton('ChangeBoardModal');
const dispatch = useAppDispatch();
const [selectedBoard, setSelectedBoard] = useState<string | null>();
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { data: boards, isFetching } = useListAllBoardsQuery(queryArgs);
const { data: boards, isFetching } = useListAllBoardsQuery({ include_archived: true });
const isModalOpen = useAppSelector(selectIsModalOpen);
const imagesToChange = useAppSelector(selectImagesToChange);
const [addImagesToBoard] = useAddImagesToBoardMutation();

@@ -80,7 +80,6 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addControlLayer}
isDisabled={isFLUX}
>
{t('controlLayers.controlLayer')}
</Button>

@@ -56,7 +56,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.layer_other')}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer} isDisabled={isFLUX}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>

@@ -99,7 +99,6 @@ const PanelTabs = memo(() => {
<Box as="span" w="full">
{layersTabLabel}
</Box>
{dndCtx.active && <Box position="absolute" top={0} left={0} right={0} bottom={0} border="2px solid red" />}
</Tab>
<Tab position="relative" onMouseOver={onOnMouseOverGalleryTab} onMouseOut={onMouseOut}>
{t('gallery.gallery')}

@@ -16,6 +16,7 @@ import {
controlLayerModelChanged,
controlLayerWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
@@ -42,6 +43,7 @@ export const ControlLayerControlAdapter = memo(() => {
const entityIdentifier = useEntityIdentifierContext('control_layer');
const controlAdapter = useControlLayerControlAdapter(entityIdentifier);
const filter = useEntityFilter(entityIdentifier);
const isFLUX = useAppSelector(selectIsFLUX);

const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
@@ -117,7 +119,7 @@ export const ControlLayerControlAdapter = memo(() => {
</Flex>
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
{controlAdapter.type === 'controlnet' && (
{controlAdapter.type === 'controlnet' && !isFLUX && (
<ControlLayerControlAdapterControlMode
controlMode={controlAdapter.controlMode}
onChange={onChangeControlMode}

@@ -18,7 +18,7 @@ export const ControlLayerMenuItems = memo(() => {
<IconMenuItemGroup>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete />
<CanvasEntityMenuItemsDelete asIcon />
</IconMenuItemGroup>
<MenuDivider />
<CanvasEntityMenuItemsTransform />

@@ -9,7 +9,7 @@ export const IPAdapterMenuItems = memo(() => {
<IconMenuItemGroup>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete />
<CanvasEntityMenuItemsDelete asIcon />
</IconMenuItemGroup>
);
});

@@ -13,7 +13,7 @@ export const InpaintMaskMenuItems = memo(() => {
<IconMenuItemGroup>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete />
<CanvasEntityMenuItemsDelete asIcon />
</IconMenuItemGroup>
<MenuDivider />
<CanvasEntityMenuItemsTransform />

@@ -17,7 +17,7 @@ export const RasterLayerMenuItems = memo(() => {
<IconMenuItemGroup>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete />
<CanvasEntityMenuItemsDelete asIcon />
</IconMenuItemGroup>
<MenuDivider />
<CanvasEntityMenuItemsTransform />

@@ -14,7 +14,7 @@ export const RegionalGuidanceMenuItems = memo(() => {
<Flex gap={2}>
<CanvasEntityMenuItemsArrange />
<CanvasEntityMenuItemsDuplicate />
<CanvasEntityMenuItemsDelete />
<CanvasEntityMenuItemsDelete asIcon />
</Flex>
<MenuDivider />
<RegionalGuidanceMenuItemsAddPromptsAndIPAdapter />

@@ -1,3 +1,4 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { IconMenuItem } from 'common/components/IconMenuItem';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
@@ -7,7 +8,11 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';

export const CanvasEntityMenuItemsDelete = memo(() => {
type Props = {
asIcon?: boolean;
};

export const CanvasEntityMenuItemsDelete = memo(({ asIcon = false }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const entityIdentifier = useEntityIdentifierContext();
@@ -17,15 +22,23 @@ export const CanvasEntityMenuItemsDelete = memo(() => {
dispatch(entityDeleted({ entityIdentifier }));
}, [dispatch, entityIdentifier]);

if (asIcon) {
return (
<IconMenuItem
aria-label={t('common.delete')}
tooltip={t('common.delete')}
onClick={deleteEntity}
icon={<PiTrashSimpleBold />}
isDestructive
isDisabled={!isInteractable}
/>
);
}

return (
<IconMenuItem
aria-label={t('common.delete')}
tooltip={t('common.delete')}
onClick={deleteEntity}
icon={<PiTrashSimpleBold />}
isDestructive
isDisabled={!isInteractable}
/>
<MenuItem onClick={deleteEntity} icon={<PiTrashSimpleBold />} isDestructive isDisabled={!isInteractable}>
{t('common.delete')}
</MenuItem>
);
});

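The new `asIcon` prop lets the delete entry render either as a compact icon inside an icon row or as a regular labelled menu item. A sketch of both usages, matching the menus updated above; the component import paths are assumptions:

```tsx
import { MenuDivider } from '@invoke-ai/ui-library';
import { IconMenuItemGroup } from 'common/components/IconMenuItem';
// Assumed paths; the diff shows only the component names.
import { CanvasEntityMenuItemsArrange } from 'features/controlLayers/components/common/CanvasEntityMenuItemsArrange';
import { CanvasEntityMenuItemsDelete } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDelete';
import { CanvasEntityMenuItemsDuplicate } from 'features/controlLayers/components/common/CanvasEntityMenuItemsDuplicate';
import { memo } from 'react';

// Sketch: icon variant inside the icon group, default labelled variant elsewhere in the menu.
const ExampleEntityMenuItems = memo(() => (
  <>
    <IconMenuItemGroup>
      <CanvasEntityMenuItemsArrange />
      <CanvasEntityMenuItemsDuplicate />
      <CanvasEntityMenuItemsDelete asIcon />
    </IconMenuItemGroup>
    <MenuDivider />
    <CanvasEntityMenuItemsDelete />
  </>
));
ExampleEntityMenuItems.displayName = 'ExampleEntityMenuItems';
```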
@@ -2,8 +2,11 @@ import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { CanvasEntityAdapterBase } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityAdapterBase';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasReset } from 'features/controlLayers/store/actions';
import {
bboxChangedFromCanvas,
controlLayerAdded,
inpaintMaskAdded,
rasterLayerAdded,
@@ -14,19 +17,32 @@ import {
rgPositivePromptChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import {
selectBboxModelBase,
selectBboxRect,
selectCanvasSlice,
selectEntityOrThrow,
} from 'features/controlLayers/store/selectors';
import type {
CanvasEntityIdentifier,
CanvasRasterLayerState,
CanvasRegionalGuidanceState,
ControlNetConfig,
IPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
import { initialControlNet, initialIPAdapter, initialT2IAdapter } from 'features/controlLayers/store/util';
import {
imageDTOToImageObject,
initialControlNet,
initialIPAdapter,
initialT2IAdapter,
} from 'features/controlLayers/store/util';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { useCallback } from 'react';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import type { ControlNetModelConfig, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import type { ControlNetModelConfig, ImageDTO, IPAdapterModelConfig, T2IAdapterModelConfig } from 'services/api/types';
import { isControlNetOrT2IAdapterModelConfig, isIPAdapterModelConfig } from 'services/api/types';

export const selectDefaultControlAdapter = createSelector(
@@ -90,6 +106,74 @@ export const useAddRasterLayer = () => {
return func;
};

export const useNewRasterLayerFromImage = () => {
const dispatch = useAppDispatch();
const bboxRect = useAppSelector(selectBboxRect);
const func = useCallback(
(imageDTO: ImageDTO) => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRasterLayerState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
},
[bboxRect.x, bboxRect.y, dispatch]
);

return func;
};

/**
* Returns a function that adds a new canvas with the given image as the initial image, replicating the img2img flow:
* - Reset the canvas
* - Resize the bbox to the image's aspect ratio at the optimal size for the selected model
* - Add the image as a raster layer
* - Resizes the layer to fit the bbox using the 'fill' strategy
*
* This allows the user to immediately generate a new image from the given image without any additional steps.
*/
export const useNewCanvasFromImage = () => {
const dispatch = useAppDispatch();
const bboxRect = useAppSelector(selectBboxRect);
const base = useAppSelector(selectBboxModelBase);
const func = useCallback(
(imageDTO: ImageDTO) => {
// Calculate the new bbox dimensions to fit the image's aspect ratio at the optimal size
const ratio = imageDTO.width / imageDTO.height;
const optimalDimension = getOptimalDimension(base);
const { width, height } = calculateNewSize(ratio, optimalDimension ** 2, base);

// The overrides need to include the layer's ID so we can transform the layer once it is initialized
const overrides = {
id: getPrefixedId('raster_layer'),
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageDTOToImageObject(imageDTO)],
} satisfies Partial<CanvasRasterLayerState>;

CanvasEntityAdapterBase.registerInitCallback(async (adapter) => {
// Skip the callback if the adapter is not the one we are creating
if (adapter.id !== overrides.id) {
return false;
}
// Fit the layer to the bbox w/ fill strategy
await adapter.transformer.startTransform({ silent: true });
adapter.transformer.fitToBboxFill();
await adapter.transformer.applyTransform();
return true;
});

dispatch(canvasReset());
// The `bboxChangedFromCanvas` reducer does no validation! Careful!
dispatch(bboxChangedFromCanvas({ x: 0, y: 0, width, height }));
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
},
[base, bboxRect.x, bboxRect.y, dispatch]
);

return func;
};

export const useAddInpaintMask = () => {
const dispatch = useAppDispatch();
const func = useCallback(() => {

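The two hooks above are the new img2img-style entry points. A minimal sketch of how a menu item can consume `useNewCanvasFromImage`, using only the hook shape shown above and the import path confirmed by the refactored menu items later in this diff:

```tsx
import { MenuItem } from '@invoke-ai/ui-library';
import { useNewCanvasFromImage } from 'features/controlLayers/hooks/addLayerHooks';
import { memo, useCallback } from 'react';
import type { ImageDTO } from 'services/api/types';

type Props = { imageDTO: ImageDTO };

// Sketch: one call resets the canvas, sizes the bbox to the image's aspect ratio
// and seeds a raster layer with the image.
const NewCanvasFromImageExample = memo(({ imageDTO }: Props) => {
  const newCanvasFromImage = useNewCanvasFromImage();
  const onClick = useCallback(() => {
    newCanvasFromImage(imageDTO);
  }, [imageDTO, newCanvasFromImage]);
  return <MenuItem onClick={onClick}>New canvas from image</MenuItem>;
});
NewCanvasFromImageExample.displayName = 'NewCanvasFromImageExample';
```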
@@ -7,6 +7,7 @@ import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/ko
import type { CanvasEntityFilterer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityFilterer';
import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer';
import type { CanvasEntityTransformer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityTransformer';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase';
import { getKonvaNodeDebugAttrs, getRectIntersection } from 'features/controlLayers/konva/util';
@@ -15,7 +16,8 @@ import {
selectIsolatedTransformingPreview,
} from 'features/controlLayers/store/canvasSettingsSlice';
import {
buildEntityIsHiddenSelector,
buildSelectIsHidden,
buildSelectIsSelected,
selectBboxRect,
selectCanvasSlice,
selectEntity,
@@ -29,6 +31,11 @@ import type { ImageDTO } from 'services/api/types';
import stableHash from 'stable-hash';
import { assert } from 'tsafe';

// Ideally, we'd type `adapter` as `CanvasEntityAdapterBase`, but the generics make this tricky. `CanvasEntityAdapter`
// is a union of all entity adapters and is functionally identical to `CanvasEntityAdapterBase`. We'll need to do a
// type assertion below in the `onInit` method, which calls these callbacks.
type InitCallback = (adapter: CanvasEntityAdapter) => Promise<boolean>;

export abstract class CanvasEntityAdapterBase<
T extends CanvasRenderableEntityState,
U extends string,
@@ -87,7 +94,79 @@ export abstract class CanvasEntityAdapterBase<
*/
abstract getHashableState: () => SerializableObject;

/**
* Callbacks that are executed when the module is initialized.
*/
private static initCallbacks = new Set<InitCallback>();

/**
* Register a callback to be run when an entity adapter is initialized.
*
* The callback is called for every adapter that is initialized with the adapter as its only argument. Use an early
* return to skip entities that are not of interest, returning `false` to keep the callback registered. Return `true`
* to unregister the callback after it is called.
*
* @param callback The callback to register.
*
* @example
* ```ts
* // A callback that is executed once for a specific entity:
* const myId = 'my_id';
* canvasManager.entityRenderer.registerOnInitCallback(async (adapter) => {
* if (adapter.id !== myId) {
* // These are not the droids you are looking for, move along
* return false;
* }
*
* doSomething();
*
* // Remove the callback
* return true;
* });
* ```
*
* @example
* ```ts
* // A callback that is executed once for the next entity that is initialized:
* canvasManager.entityRenderer.registerOnInitCallback(async (adapter) => {
* doSomething();
*
* // Remove the callback
* return true;
* });
* ```
*
* @example
* ```ts
* // A callback that is executed for every entity and is never removed:
* canvasManager.entityRenderer.registerOnInitCallback(async (adapter) => {
* // Do something with the adapter
* return false;
* });
*/
static registerInitCallback = (callback: InitCallback) => {
const wrapped = async (adapter: CanvasEntityAdapter) => {
const result = await callback(adapter);
if (result) {
this.initCallbacks.delete(wrapped);
}
return result;
};
this.initCallbacks.add(wrapped);
};

/**
* Runs all init callbacks with the given entity adapter.
* @param adapter The adapter of the entity that was initialized.
*/
private static runInitCallbacks = (adapter: CanvasEntityAdapter) => {
for (const callback of this.initCallbacks) {
callback(adapter);
}
};

selectIsHidden: Selector<RootState, boolean>;
selectIsSelected: Selector<RootState, boolean>;

/**
* The Konva nodes that make up the entity adapter:
@@ -171,7 +250,8 @@ export abstract class CanvasEntityAdapterBase<
assert(state !== undefined, 'Missing entity state on creation');
this.state = state;

this.selectIsHidden = buildEntityIsHiddenSelector(this.entityIdentifier);
this.selectIsHidden = buildSelectIsHidden(this.entityIdentifier);
this.selectIsSelected = buildSelectIsSelected(this.entityIdentifier);

/**
* There are a number of reasons we may need to show or hide a layer:
@@ -180,6 +260,7 @@ export abstract class CanvasEntityAdapterBase<
* - Staging status changes and `isolatedStagingPreview` is enabled
* - Global filtering status changes and `isolatedFilteringPreview` is enabled
* - Global transforming status changes and `isolatedTransformingPreview` is enabled
* - The entity is selected or deselected (only selected and onscreen entities are rendered)
*/
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectIsHidden, this.syncVisibility));
this.subscriptions.add(
@@ -190,6 +271,7 @@ export abstract class CanvasEntityAdapterBase<
this.manager.stateApi.createStoreSubscription(selectIsolatedTransformingPreview, this.syncVisibility)
);
this.subscriptions.add(this.manager.stateApi.$transformingAdapter.listen(this.syncVisibility));
this.subscriptions.add(this.manager.stateApi.createStoreSubscription(this.selectIsSelected, this.syncVisibility));

/**
* The tool preview may need to be updated when the entity is locked or disabled. For example, when we disable the
@@ -228,21 +310,8 @@ export abstract class CanvasEntityAdapterBase<

syncIsOnscreen = () => {
const stageRect = this.manager.stage.getScaledStageRect();
const entityRect = this.transformer.$pixelRect.get();
const position = this.manager.stateApi.runSelector(this.selectPosition);
if (!position) {
return;
}
const entityRectRelativeToStage = {
x: entityRect.x + position.x,
y: entityRect.y + position.y,
width: entityRect.width,
height: entityRect.height,
};

const intersection = getRectIntersection(stageRect, entityRectRelativeToStage);
const isOnScreen = this.checkIntersection(stageRect);
const prevIsOnScreen = this.$isOnScreen.get();
const isOnScreen = intersection.width > 0 && intersection.height > 0;
this.$isOnScreen.set(isOnScreen);
if (prevIsOnScreen !== isOnScreen) {
this.log.trace(`Moved ${isOnScreen ? 'on-screen' : 'off-screen'}`);
@@ -252,10 +321,19 @@ export abstract class CanvasEntityAdapterBase<

syncIntersectsBbox = () => {
const bboxRect = this.manager.stateApi.getBbox().rect;
const intersectsBbox = this.checkIntersection(bboxRect);
const prevIntersectsBbox = this.$intersectsBbox.get();
this.$intersectsBbox.set(intersectsBbox);
if (prevIntersectsBbox !== intersectsBbox) {
this.log.trace(`Moved ${intersectsBbox ? 'into bbox' : 'out of bbox'}`);
}
};

checkIntersection = (rect: Rect): boolean => {
const entityRect = this.transformer.$pixelRect.get();
const position = this.manager.stateApi.runSelector(this.selectPosition);
if (!position) {
return;
return false;
}
const entityRectRelativeToStage = {
x: entityRect.x + position.x,
@@ -263,14 +341,9 @@ export abstract class CanvasEntityAdapterBase<
width: entityRect.width,
height: entityRect.height,
};

const intersection = getRectIntersection(bboxRect, entityRectRelativeToStage);
const prevIntersectsBbox = this.$intersectsBbox.get();
const intersectsBbox = intersection.width > 0 && intersection.height > 0;
this.$intersectsBbox.set(intersectsBbox);
if (prevIntersectsBbox !== intersectsBbox) {
this.log.trace(`Moved ${intersectsBbox ? 'into bbox' : 'out of bbox'}`);
}
const intersection = getRectIntersection(rect, entityRectRelativeToStage);
const doesIntersect = intersection.width > 0 && intersection.height > 0;
return doesIntersect;
};

initialize = async () => {
@@ -299,6 +372,10 @@ export abstract class CanvasEntityAdapterBase<
await this.renderer.initialize();
this.syncZIndices();
this.syncVisibility();

// Call the init callbacks.
// TODO(psyche): Get rid of the cast - see note in type def for `InitCallback`.
CanvasEntityAdapterBase.runInitCallbacks(this as CanvasEntityAdapter);
};

syncZIndices = () => {

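The refactor above funnels both the on-screen check and the bbox check through one `checkIntersection(rect)` helper built on `getRectIntersection`. A sketch of the intersection semantics being relied on, assuming a `Rect` of `{ x, y, width, height }`; the util's actual implementation may differ:

```ts
type Rect = { x: number; y: number; width: number; height: number };

// Sketch of the semantics the adapter depends on: the overlap of two rects,
// with zero width/height meaning "no intersection".
const getRectIntersectionSketch = (a: Rect, b: Rect): Rect => {
  const x = Math.max(a.x, b.x);
  const y = Math.max(a.y, b.y);
  const width = Math.min(a.x + a.width, b.x + b.width) - x;
  const height = Math.min(a.y + a.height, b.y + b.height) - y;
  return { x, y, width: Math.max(width, 0), height: Math.max(height, 0) };
};

// Both syncIsOnscreen (stage rect) and syncIntersectsBbox (bbox rect) reduce to this test.
const intersects = (a: Rect, b: Rect): boolean => {
  const i = getRectIntersectionSketch(a, b);
  return i.width > 0 && i.height > 0;
};
```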
@@ -1,3 +1,4 @@
import { Mutex } from 'async-mutex';
import { withResultAsync } from 'common/util/result';
import { roundToMultiple } from 'common/util/roundDownToMultiple';
import type { CanvasEntityAdapter } from 'features/controlLayers/konva/CanvasEntity/types';
@@ -166,6 +167,13 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
*/
$silentTransform = atom(false);

/**
* A mutex to prevent concurrent operations.
*
* The mutex is locked during transformation and during rect calculations which are handled in a web worker.
*/
transformMutex = new Mutex();

konva: {
transformer: Konva.Transformer;
proxyRect: Konva.Rect;
@@ -424,6 +432,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
return;
}
const { rect } = this.manager.stateApi.getBbox();
const gridSize = this.manager.stateApi.getGridSize();
const width = this.konva.proxyRect.width();
const height = this.konva.proxyRect.height();
const scaleX = rect.width / width;
@@ -437,8 +446,8 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
const offsetY = (rect.height - height * scale) / 2;

this.konva.proxyRect.setAttrs({
x: clamp(Math.round(rect.x + offsetX), rect.x, rect.x + rect.width),
y: clamp(Math.round(rect.y + offsetY), rect.y, rect.y + rect.height),
x: clamp(roundToMultiple(rect.x + offsetX, gridSize), rect.x, rect.x + rect.width),
y: clamp(roundToMultiple(rect.y + offsetY, gridSize), rect.y, rect.y + rect.height),
scaleX: scale,
scaleY: scale,
rotation: 0,
@@ -455,6 +464,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
return;
}
const { rect } = this.manager.stateApi.getBbox();
const gridSize = this.manager.stateApi.getGridSize();
const width = this.konva.proxyRect.width();
const height = this.konva.proxyRect.height();
const scaleX = rect.width / width;
@@ -468,8 +478,8 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
const offsetY = (rect.height - height * scale) / 2;

this.konva.proxyRect.setAttrs({
x: Math.round(rect.x + offsetX),
y: Math.round(rect.y + offsetY),
x: roundToMultiple(rect.x + offsetX, gridSize),
y: roundToMultiple(rect.y + offsetY, gridSize),
scaleX: scale,
scaleY: scale,
rotation: 0,
@@ -647,11 +657,13 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
* @param arg.silent Whether the transformation should be silent. If silent, the transform controls will not be shown,
* so you _must_ immediately call `applyTransform` or `stopTransform` to complete the transformation.
*/
startTransform = (arg?: { silent: boolean }) => {
startTransform = async (arg?: { silent: boolean }) => {
const transformingAdapter = this.manager.stateApi.$transformingAdapter.get();
if (transformingAdapter) {
assert(false, `Already transforming an entity: ${transformingAdapter.id}`);
}
// This will be released when the transformation is stopped
await this.transformMutex.acquire();
this.log.debug('Starting transform');
const { silent } = { silent: false, ...arg };
this.$silentTransform.set(silent);
@@ -704,6 +716,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.syncInteractionState();
this.manager.stateApi.$transformingAdapter.set(null);
this.$isProcessing.set(false);
this.transformMutex.release();
};

/**
@@ -807,7 +820,6 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
calculateRect = debounce(() => {
this.log.debug('Calculating bbox');

this.$isPendingRectCalculation.set(true);
const canvas = this.parent.getCanvas();

if (!this.parent.renderer.hasObjects()) {
@@ -817,6 +829,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.parent.$canvasCache.set(canvas);
this.$isPendingRectCalculation.set(false);
this.updateBbox();
this.transformMutex.release();
return;
}

@@ -829,6 +842,7 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.parent.$canvasCache.set(canvas);
this.$isPendingRectCalculation.set(false);
this.updateBbox();
this.transformMutex.release();
return;
}

@@ -857,11 +871,14 @@ export class CanvasEntityTransformer extends CanvasModuleBase {
this.parent.$canvasCache.set(canvas);
this.$isPendingRectCalculation.set(false);
this.updateBbox();
this.transformMutex.release();
}
);
}, this.config.RECT_CALC_DEBOUNCE_MS);

requestRectCalculation = () => {
requestRectCalculation = async () => {
// This will be released when the rect calculation is complete
await this.transformMutex.acquire();
this.$isPendingRectCalculation.set(true);
this.syncInteractionState();
this.calculateRect();

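The transformer changes above serialize transforms and rect calculations with an `async-mutex` lock: `startTransform` and `requestRectCalculation` acquire it, and the matching stop/completion paths release it. A stripped-down sketch of that acquire/release pairing, using only the `Mutex` calls that appear in the diff:

```ts
import { Mutex } from 'async-mutex';

const transformMutex = new Mutex();

// Sketch: the lock taken when a transform starts is only released when the transform
// stops, so a concurrent rect calculation waits its turn instead of racing the transform.
const startTransformSketch = async () => {
  await transformMutex.acquire(); // held until stopTransformSketch runs
  // ...enter the transforming state, show or hide the transform controls, etc.
};

const stopTransformSketch = () => {
  // ...reset interaction state...
  transformMutex.release();
};

const requestRectCalculationSketch = async () => {
  await transformMutex.acquire(); // released when the (worker-backed) calculation completes
  // ...kick off the debounced rect calculation...
};
```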
@@ -25,7 +25,6 @@ import {
getScaledBoundingBoxDimensions,
} from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify';
import type { MainModelBase } from 'features/nodes/types/common';
import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
@@ -772,11 +771,6 @@ export const canvasSlice = createSlice({

syncScaledSize(state);
},
bboxModelBaseChanged: (state, action: PayloadAction<{ modelBase: MainModelBase }>) => {
const { modelBase } = action.payload;
state.bbox.modelBase = modelBase;
syncScaledSize(state);
},
bboxSyncedToOptimalDimension: (state) => {
const optimalDimension = getOptimalDimension(state.bbox.modelBase);

@@ -308,7 +308,7 @@ const getSelectIsTypeHidden = (type: CanvasEntityType) => {
/**
* Builds a selector that selects if the entity is hidden.
*/
export const buildEntityIsHiddenSelector = (entityIdentifier: CanvasEntityIdentifier) => {
export const buildSelectIsHidden = (entityIdentifier: CanvasEntityIdentifier) => {
const selectIsTypeHidden = getSelectIsTypeHidden(entityIdentifier.type);
return createSelector(
[selectCanvasSlice, selectIsTypeHidden, selectIsStaging, selectIsolatedStagingPreview],
@@ -339,6 +339,16 @@ export const buildEntityIsHiddenSelector = (entityIdentifier: CanvasEntityIdenti
);
};

/**
* Builds a selector that selects if the entity is selected.
*/
export const buildSelectIsSelected = (entityIdentifier: CanvasEntityIdentifier) => {
return createSelector(
selectSelectedEntityIdentifier,
(selectedEntityIdentifier) => selectedEntityIdentifier?.id === entityIdentifier.id
);
};

export const selectWidth = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect.width);
export const selectHeight = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect.height);
export const selectAspectRatioID = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.aspectRatio.id);

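`buildSelectIsHidden` and `buildSelectIsSelected` are per-entity selector factories; the adapter base class earlier in this diff builds one of each and wires them to `syncVisibility`. A minimal sketch of using the selection selector from a component, assuming only the store hooks shown elsewhere in this diff:

```ts
import { useAppSelector } from 'app/store/storeHooks';
import { buildSelectIsSelected } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
import { useMemo } from 'react';

// Sketch: memoize the factory output per entity so the component only re-renders
// when this particular entity's selected state changes.
export const useIsEntitySelected = (entityIdentifier: CanvasEntityIdentifier): boolean => {
  const selectIsSelected = useMemo(() => buildSelectIsSelected(entityIdentifier), [entityIdentifier]);
  return useAppSelector(selectIsSelected);
};
```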
@@ -0,0 +1,86 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectBoardsListOrderBy, selectBoardsListOrderDir } from 'features/gallery/store/gallerySelectors';
import { boardsListOrderByChanged, boardsListOrderDirChanged } from 'features/gallery/store/gallerySlice';
import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { z } from 'zod';

const zOrderBy = z.enum(['created_at', 'board_name']);
type OrderBy = z.infer<typeof zOrderBy>;
const isOrderBy = (v: unknown): v is OrderBy => zOrderBy.safeParse(v).success;

const zDirection = z.enum(['ASC', 'DESC']);
type Direction = z.infer<typeof zDirection>;
const isDirection = (v: unknown): v is Direction => zDirection.safeParse(v).success;

export const BoardsListSortControls = () => {
const { t } = useTranslation();

const orderBy = useAppSelector(selectBoardsListOrderBy);
const direction = useAppSelector(selectBoardsListOrderDir);

const ORDER_BY_OPTIONS: ComboboxOption[] = useMemo(
() => [
{ value: 'created_at', label: t('workflows.created') },
{ value: 'board_name', label: t('workflows.name') },
],
[t]
);

const DIRECTION_OPTIONS: ComboboxOption[] = useMemo(
() => [
{ value: 'ASC', label: t('workflows.ascending') },
{ value: 'DESC', label: t('workflows.descending') },
],
[t]
);

const dispatch = useAppDispatch();

const onChangeOrderBy = useCallback<ComboboxOnChange>(
(v) => {
if (!isOrderBy(v?.value) || v.value === orderBy) {
return;
}
dispatch(boardsListOrderByChanged(v.value));
},
[orderBy, dispatch]
);
const valueOrderBy = useMemo(() => {
return ORDER_BY_OPTIONS.find((o) => o.value === orderBy) || ORDER_BY_OPTIONS[0];
}, [orderBy, ORDER_BY_OPTIONS]);

const onChangeDirection = useCallback<ComboboxOnChange>(
(v) => {
if (!isDirection(v?.value) || v.value === direction) {
return;
}
dispatch(boardsListOrderDirChanged(v.value));
},
[direction, dispatch]
);
const valueDirection = useMemo(
() => DIRECTION_OPTIONS.find((o) => o.value === direction),
[direction, DIRECTION_OPTIONS]
);

return (
<Flex flexDir="column" gap={4}>
<FormControl orientation="horizontal" gap={1}>
<FormLabel>{t('common.orderBy')}</FormLabel>
<Combobox isSearchable={false} value={valueOrderBy} options={ORDER_BY_OPTIONS} onChange={onChangeOrderBy} />
</FormControl>
<FormControl orientation="horizontal" gap={1}>
<FormLabel>{t('common.direction')}</FormLabel>
<Combobox
isSearchable={false}
value={valueDirection}
options={DIRECTION_OPTIONS}
onChange={onChangeDirection}
/>
</FormControl>
</Flex>
);
};
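The new sort controls validate combobox values with small zod enums instead of casting. A sketch of the same guard pattern in isolation, using only what the component above defines:

```ts
import { z } from 'zod';

const zOrderBy = z.enum(['created_at', 'board_name']);
type OrderBy = z.infer<typeof zOrderBy>;
const isOrderBy = (v: unknown): v is OrderBy => zOrderBy.safeParse(v).success;

// A combobox hands back `{ value: string } | null`; the guard narrows the value safely.
const handleOrderByChange = (
  v: { value: string } | null,
  current: OrderBy,
  apply: (next: OrderBy) => void
): void => {
  if (!v || !isOrderBy(v.value) || v.value === current) {
    return; // ignore unknown options and no-op changes
  }
  apply(v.value);
};
```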
@@ -0,0 +1,53 @@
import {
Box,
Divider,
Flex,
IconButton,
Popover,
PopoverBody,
PopoverContent,
PopoverTrigger,
} from '@invoke-ai/ui-library';
import BoardAutoAddSelect from 'features/gallery/components/Boards/BoardAutoAddSelect';
import AutoAssignBoardCheckbox from 'features/gallery/components/GallerySettingsPopover/AutoAssignBoardCheckbox';
import ShowArchivedBoardsCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowArchivedBoardsCheckbox';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiGearSixFill } from 'react-icons/pi';

import { BoardsListSortControls } from './BoardsListSortControls';

const BoardsSettingsPopover = () => {
const { t } = useTranslation();

return (
<Popover isLazy>
<PopoverTrigger>
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
aria-label={t('gallery.boardsSettings')}
icon={<PiGearSixFill />}
tooltip={t('gallery.boardsSettings')}
/>
</PopoverTrigger>
<PopoverContent>
<PopoverBody>
<Flex direction="column" gap={2}>
<AutoAssignBoardCheckbox />
<ShowArchivedBoardsCheckbox />
<BoardAutoAddSelect />
<Box py={2}>
<Divider />
</Box>

<BoardsListSortControls />
</Flex>
</PopoverBody>
</PopoverContent>
</Popover>
);
};

export default memo(BoardsSettingsPopover);
@@ -23,6 +23,7 @@ import { useTranslation } from 'react-i18next';
import { PiMagnifyingGlassBold } from 'react-icons/pi';
import { useBoardName } from 'services/api/hooks/useBoardName';

import GallerySettingsPopover from './GallerySettingsPopover/GallerySettingsPopover';
import GalleryImageGrid from './ImageGrid/GalleryImageGrid';
import { GalleryPagination } from './ImageGrid/GalleryPagination';
import { GallerySearch } from './ImageGrid/GallerySearch';
@@ -85,15 +86,18 @@ export const Gallery = () => {
{t('gallery.assets')}
</Tab>
</Tooltip>
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
onClick={handleClickSearch}
tooltip={searchDisclosure.isOpen ? `${t('gallery.exitSearch')}` : `${t('gallery.displaySearch')}`}
aria-label={t('gallery.displaySearch')}
icon={<PiMagnifyingGlassBold />}
/>
<Flex h="full" justifyContent="flex-end">
<GallerySettingsPopover />
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
onClick={handleClickSearch}
tooltip={searchDisclosure.isOpen ? `${t('gallery.exitSearch')}` : `${t('gallery.displaySearch')}`}
aria-label={t('gallery.displaySearch')}
icon={<PiMagnifyingGlassBold />}
/>
</Flex>
</TabList>
</Tabs>

@@ -15,8 +15,8 @@ import { Panel, PanelGroup } from 'react-resizable-panels';

import BoardsListWrapper from './Boards/BoardsList/BoardsListWrapper';
import BoardsSearch from './Boards/BoardsList/BoardsSearch';
import BoardsSettingsPopover from './Boards/BoardsSettingsPopover';
import { Gallery } from './Gallery';
import GallerySettingsPopover from './GallerySettingsPopover/GallerySettingsPopover';

const COLLAPSE_STYLES: CSSProperties = { flexShrink: 0, minHeight: 0 };

@@ -64,7 +64,7 @@ const GalleryPanelContent = () => {
</Flex>
<GalleryHeader />
<Flex h="full" w="25%" justifyContent="flex-end">
<GallerySettingsPopover />
<BoardsSettingsPopover />
<IconButton
size="sm"
variant="link"

@@ -1,10 +1,7 @@
import { Divider, Flex, IconButton, Popover, PopoverBody, PopoverContent, PopoverTrigger } from '@invoke-ai/ui-library';
import BoardAutoAddSelect from 'features/gallery/components/Boards/BoardAutoAddSelect';
import AlwaysShowImageSizeCheckbox from 'features/gallery/components/GallerySettingsPopover/AlwaysShowImageSizeCheckbox';
import AutoAssignBoardCheckbox from 'features/gallery/components/GallerySettingsPopover/AutoAssignBoardCheckbox';
import AutoSwitchCheckbox from 'features/gallery/components/GallerySettingsPopover/AutoSwitchCheckbox';
import ImageMinimumWidthSlider from 'features/gallery/components/GallerySettingsPopover/ImageMinimumWidthSlider';
import ShowArchivedBoardsCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowArchivedBoardsCheckbox';
import ShowStarredFirstCheckbox from 'features/gallery/components/GallerySettingsPopover/ShowStarredFirstCheckbox';
import SortDirectionCombobox from 'features/gallery/components/GallerySettingsPopover/SortDirectionCombobox';
import { memo } from 'react';
@@ -21,8 +18,9 @@ const GallerySettingsPopover = () => {
size="sm"
variant="link"
alignSelf="stretch"
aria-label={t('gallery.gallerySettings')}
aria-label={t('gallery.imagesSettings')}
icon={<PiGearSixFill />}
tooltip={t('gallery.imagesSettings')}
/>
</PopoverTrigger>
<PopoverContent>
@@ -30,10 +28,7 @@ const GallerySettingsPopover = () => {
<Flex direction="column" gap={2}>
<ImageMinimumWidthSlider />
<AutoSwitchCheckbox />
<AutoAssignBoardCheckbox />
<AlwaysShowImageSizeCheckbox />
<ShowArchivedBoardsCheckbox />
<BoardAutoAddSelect />
<Divider pt={2} />
<ShowStarredFirstCheckbox />
<SortDirectionCombobox />

@@ -1,42 +1,276 @@
import type { ContextMenuProps } from '@invoke-ai/ui-library';
import { ContextMenu, MenuList } from '@invoke-ai/ui-library';
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Menu, MenuButton, MenuList, Portal, useGlobalMenuClose } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import MultipleSelectionMenuItems from 'features/gallery/components/ImageContextMenu/MultipleSelectionMenuItems';
import SingleSelectionMenuItems from 'features/gallery/components/ImageContextMenu/SingleSelectionMenuItems';
import { selectSelectionCount } from 'features/gallery/store/gallerySelectors';
import { memo, useCallback } from 'react';
import { map } from 'nanostores';
import type { RefObject } from 'react';
import { memo, useCallback, useEffect, useRef } from 'react';
import type { ImageDTO } from 'services/api/types';

import MultipleSelectionMenuItems from './MultipleSelectionMenuItems';
import SingleSelectionMenuItems from './SingleSelectionMenuItems';
/**
* The delay in milliseconds before the context menu opens on long press.
*/
const LONGPRESS_DELAY_MS = 500;
/**
* The threshold in pixels that the pointer must move before the long press is cancelled.
*/
const LONGPRESS_MOVE_THRESHOLD_PX = 10;

type Props = {
imageDTO: ImageDTO | undefined;
children: ContextMenuProps<HTMLDivElement>['children'];
/**
* The singleton state of the context menu.
*/
const $imageContextMenuState = map<{
isOpen: boolean;
imageDTO: ImageDTO | null;
position: { x: number; y: number };
}>({
isOpen: false,
imageDTO: null,
position: { x: -1, y: -1 },
});

/**
* Convenience function to close the context menu.
*/
const onClose = () => {
$imageContextMenuState.setKey('isOpen', false);
};

const ImageContextMenu = ({ imageDTO, children }: Props) => {
const selectionCount = useAppSelector(selectSelectionCount);
/**
* Map of elements to image DTOs. This is used to determine which image DTO to show the context menu for, depending on
* the target of the context menu or long press event.
*/
const elToImageMap = new Map<HTMLDivElement, ImageDTO>();

/**
* Given a target node, find the first registered parent element that contains the target node and return the imageDTO
* associated with it.
*/
const getImageDTOFromMap = (target: Node): ImageDTO | undefined => {
const entry = Array.from(elToImageMap.entries()).find((entry) => entry[0].contains(target));
return entry?.[1];
};

/**
* Register a context menu for an image DTO on a target element.
* @param imageDTO The image DTO to register the context menu for.
* @param targetRef The ref of the target element that should trigger the context menu.
*/
export const useImageContextMenu = (imageDTO: ImageDTO | undefined, targetRef: RefObject<HTMLDivElement>) => {
useEffect(() => {
if (!targetRef.current || !imageDTO) {
return;
}
const el = targetRef.current;
elToImageMap.set(el, imageDTO);
return () => {
elToImageMap.delete(el);
};
}, [imageDTO, targetRef]);
};

/**
* Singleton component that renders the context menu for images.
*/
export const ImageContextMenu = memo(() => {
useAssertSingleton('ImageContextMenu');
const state = useStore($imageContextMenuState);
useGlobalMenuClose(onClose);

return (
<Portal>
<Menu isOpen={state.isOpen} gutter={0} placement="auto-end" onClose={onClose}>
<MenuButton
aria-hidden={true}
w={1}
h={1}
position="absolute"
left={state.position.x}
top={state.position.y}
cursor="default"
bg="transparent"
_hover={_hover}
pointerEvents="none"
/>
<MenuContent />
</Menu>
<ImageContextMenuEventLogical />
</Portal>
);
});

ImageContextMenu.displayName = 'ImageContextMenu';

const _hover: ChakraProps['_hover'] = { bg: 'transparent' };

/**
* A logical component that listens for context menu events and opens the context menu. It's separate from
* ImageContextMenu component to avoid re-rendering the whole context menu on every context menu event.
*/
const ImageContextMenuEventLogical = memo(() => {
const lastPositionRef = useRef<{ x: number; y: number }>({ x: -1, y: -1 });
const longPressTimeoutRef = useRef(0);
const animationTimeoutRef = useRef(0);

const onContextMenu = useCallback((e: MouseEvent | PointerEvent) => {
if (e.shiftKey) {
// This is a shift + right click event, which should open the native context menu
onClose();
return;
}

const imageDTO = getImageDTOFromMap(e.target as Node);

const renderMenuFunc = useCallback(() => {
if (!imageDTO) {
return null;
// Can't find the image DTO, close the context menu
onClose();
return;
}

if (selectionCount > 1) {
return (
<MenuList visibility="visible">
<MultipleSelectionMenuItems />
</MenuList>
);
// clear pending delayed open
window.clearTimeout(animationTimeoutRef.current);
e.preventDefault();

if (lastPositionRef.current.x !== e.pageX || lastPositionRef.current.y !== e.pageY) {
// if the mouse moved, we need to close, wait for animation and reopen the menu at the new position
if ($imageContextMenuState.get().isOpen) {
onClose();
}
animationTimeoutRef.current = window.setTimeout(() => {
// Open the menu after the animation with the new state
$imageContextMenuState.set({
isOpen: true,
position: { x: e.pageX, y: e.pageY },
imageDTO,
});
}, 100);
} else {
// else we can just open the menu at the current position w/ new state
$imageContextMenuState.set({
isOpen: true,
position: { x: e.pageX, y: e.pageY },
imageDTO,
});
}

// Always sync the last position
lastPositionRef.current = { x: e.pageX, y: e.pageY };
}, []);

// Use a long press to open the context menu on touch devices
const onPointerDown = useCallback(
(e: PointerEvent) => {
if (e.pointerType === 'mouse') {
// Bail out if it's a mouse event - this is for touch/pen only
return;
}

longPressTimeoutRef.current = window.setTimeout(() => {
onContextMenu(e);
}, LONGPRESS_DELAY_MS);

lastPositionRef.current = { x: e.pageX, y: e.pageY };
},
[onContextMenu]
);

const onPointerMove = useCallback((e: PointerEvent) => {
if (e.pointerType === 'mouse') {
// Bail out if it's a mouse event - this is for touch/pen only
return;
}
if (longPressTimeoutRef.current === null) {
return;
}

// If the pointer has moved more than the threshold, cancel the long press
const lastPosition = lastPositionRef.current;

const distanceFromLastPosition = Math.hypot(e.pageX - lastPosition.x, e.pageY - lastPosition.y);

if (distanceFromLastPosition > LONGPRESS_MOVE_THRESHOLD_PX) {
clearTimeout(longPressTimeoutRef.current);
}
}, []);

const onPointerUp = useCallback((e: PointerEvent) => {
if (e.pointerType === 'mouse') {
// Bail out if it's a mouse event - this is for touch/pen only
return;
}
if (longPressTimeoutRef.current) {
clearTimeout(longPressTimeoutRef.current);
}
}, []);

const onPointerCancel = useCallback((e: PointerEvent) => {
if (e.pointerType === 'mouse') {
// Bail out if it's a mouse event - this is for touch/pen only
return;
}
if (longPressTimeoutRef.current) {
clearTimeout(longPressTimeoutRef.current);
}
}, []);

useEffect(() => {
const controller = new AbortController();

// Context menu events
window.addEventListener('contextmenu', onContextMenu, { signal: controller.signal });

// Long press events
window.addEventListener('pointerdown', onPointerDown, { signal: controller.signal });
window.addEventListener('pointerup', onPointerUp, { signal: controller.signal });
window.addEventListener('pointercancel', onPointerCancel, { signal: controller.signal });
window.addEventListener('pointermove', onPointerMove, { signal: controller.signal });

return () => {
controller.abort();
};
}, [onContextMenu, onPointerCancel, onPointerDown, onPointerMove, onPointerUp]);

useEffect(
() => () => {
// Clean up any timeouts when we unmount
window.clearTimeout(animationTimeoutRef.current);
window.clearTimeout(longPressTimeoutRef.current);
},
[]
);

return null;
});

ImageContextMenuEventLogical.displayName = 'ImageContextMenuEventLogical';

// The content of the context menu, which changes based on the selection count. Split out and memoized to avoid
// re-rendering the whole context menu too often.
const MenuContent = memo(() => {
const selectionCount = useAppSelector(selectSelectionCount);
const state = useStore($imageContextMenuState);

if (!state.imageDTO) {
return null;
}

if (selectionCount > 1) {
return (
<MenuList visibility="visible">
<SingleSelectionMenuItems imageDTO={imageDTO} />
<MultipleSelectionMenuItems />
</MenuList>
);
}, [imageDTO, selectionCount]);
}

return <ContextMenu renderMenu={renderMenuFunc}>{children}</ContextMenu>;
};
return (
<MenuList visibility="visible">
<SingleSelectionMenuItems imageDTO={state.imageDTO} />
</MenuList>
);
});

export default memo(ImageContextMenu);
MenuContent.displayName = 'MenuContent';

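The rewritten context menu is now one app-level singleton plus a registration hook rather than a per-image wrapper. A sketch of the intended wiring, assuming only the `useImageContextMenu(imageDTO, ref)` signature and the named exports shown above; the import path is an assumption based on the existing ImageContextMenu folder:

```tsx
import { Flex } from '@invoke-ai/ui-library';
// Assumed path; the diff only shows the folder and the export names.
import { ImageContextMenu, useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { memo, useRef } from 'react';
import type { ImageDTO } from 'services/api/types';

// Each image registers its element; right-click or long-press anywhere inside it
// opens the single shared menu with that image's DTO.
const GalleryImageExample = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
  const ref = useRef<HTMLDivElement>(null);
  useImageContextMenu(imageDTO, ref);
  return <Flex ref={ref}>{/* thumbnail, overlays, etc. */}</Flex>;
});
GalleryImageExample.displayName = 'GalleryImageExample';

// The singleton is mounted once near the app root; useAssertSingleton guards against duplicates.
// <ImageContextMenu />
```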
@@ -1,11 +1,7 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { canvasReset } from 'features/controlLayers/store/actions';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { useAppDispatch } from 'app/store/storeHooks';
import { useNewCanvasFromImage } from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { toast } from 'features/toast/toast';
@@ -14,23 +10,16 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFileBold } from 'react-icons/pi';

const selectBboxRect = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect);

export const ImageMenuItemNewCanvasFromImage = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const imageDTO = useImageDTOContext();
const bboxRect = useAppSelector(selectBboxRect);
const imageViewer = useImageViewer();
const newCanvasFromImage = useNewCanvasFromImage();
const isBusy = useCanvasIsBusy();

const handleSendToCanvas = useCallback(() => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRasterLayerState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
dispatch(canvasReset());
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
const onClick = useCallback(() => {
newCanvasFromImage(imageDTO);
dispatch(setActiveTab('canvas'));
imageViewer.close();
toast({
@@ -38,10 +27,10 @@ export const ImageMenuItemNewCanvasFromImage = memo(() => {
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [bboxRect.x, bboxRect.y, dispatch, imageDTO, imageViewer, t]);
}, [dispatch, imageDTO, imageViewer, newCanvasFromImage, t]);

return (
<MenuItem icon={<PiFileBold />} onClickCapture={handleSendToCanvas}>
<MenuItem icon={<PiFileBold />} onClickCapture={onClick} isDisabled={isBusy}>
{t('controlLayers.newCanvasFromImage')}
</MenuItem>
);

@@ -1,11 +1,8 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch } from 'app/store/storeHooks';
import { NewLayerIcon } from 'features/controlLayers/components/common/icons';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { useNewRasterLayerFromImage } from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { sentImageToCanvas } from 'features/gallery/store/actions';
@@ -14,23 +11,17 @@ import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';

const selectBboxRect = createSelector(selectCanvasSlice, (canvas) => canvas.bbox.rect);

export const ImageMenuItemNewLayerFromImage = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const imageDTO = useImageDTOContext();
const bboxRect = useAppSelector(selectBboxRect);
const imageViewer = useImageViewer();
const newRasterLayerFromImage = useNewRasterLayerFromImage();
const isBusy = useCanvasIsBusy();

const handleSendToCanvas = useCallback(() => {
const imageObject = imageDTOToImageObject(imageDTO);
const overrides: Partial<CanvasRasterLayerState> = {
position: { x: bboxRect.x, y: bboxRect.y },
objects: [imageObject],
};
const onClick = useCallback(() => {
dispatch(sentImageToCanvas());
dispatch(rasterLayerAdded({ overrides, isSelected: true }));
newRasterLayerFromImage(imageDTO);
dispatch(setActiveTab('canvas'));
imageViewer.close();
toast({
@@ -38,10 +29,10 @@ export const ImageMenuItemNewLayerFromImage = memo(() => {
title: t('toast.sentToCanvas'),
status: 'success',
});
}, [bboxRect.x, bboxRect.y, dispatch, imageDTO, imageViewer, t]);
}, [dispatch, imageDTO, imageViewer, newRasterLayerFromImage, t]);

return (
<MenuItem icon={<NewLayerIcon />} onClickCapture={handleSendToCanvas}>
<MenuItem icon={<NewLayerIcon />} onClickCapture={onClick} isDisabled={isBusy}>
{t('controlLayers.newLayerFromImage')}
</MenuItem>
);

@@ -1,5 +1,6 @@
import { MenuDivider } from '@invoke-ai/ui-library';
import { IconMenuItemGroup } from 'common/components/IconMenuItem';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { ImageMenuItemChangeBoard } from 'features/gallery/components/ImageContextMenu/ImageMenuItemChangeBoard';
import { ImageMenuItemCopy } from 'features/gallery/components/ImageContextMenu/ImageMenuItemCopy';
import { ImageMenuItemDelete } from 'features/gallery/components/ImageContextMenu/ImageMenuItemDelete';
@@ -37,8 +38,10 @@ const SingleSelectionMenuItems = ({ imageDTO }: SingleSelectionMenuItemsProps) =
<ImageMenuItemMetadataRecallActions />
<MenuDivider />
<ImageMenuItemSendToUpscale />
<ImageMenuItemNewLayerFromImage />
<ImageMenuItemNewCanvasFromImage />
<CanvasManagerProviderGate>
<ImageMenuItemNewLayerFromImage />
<ImageMenuItemNewCanvasFromImage />
</CanvasManagerProviderGate>
<MenuDivider />
<ImageMenuItemChangeBoard />
<ImageMenuItemStarUnstar />
@@ -63,12 +63,9 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
() => createSelector(selectGallerySlice, (gallery) => gallery.imageToCompare?.image_name === imageDTO.image_name),
[imageDTO.image_name]
);
const alwaysShowImageSizeBadge = useAppSelector(selectAlwaysShouldImageSizeBadge);
const isSelectedForCompare = useAppSelector(selectIsSelectedForCompare);
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);

const customStarUi = useStore($customStarUI);

const imageContainerRef = useScrollIntoView(isSelected, index, areMultiplesSelected);

const draggableData = useMemo<TypesafeDraggableData | undefined>(() => {
@@ -91,20 +88,6 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
}
}, [imageDTO, selectedBoardId, areMultiplesSelected]);

const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();

const toggleStarredState = useCallback(() => {
if (imageDTO) {
if (imageDTO.starred) {
unstarImages({ imageDTOs: [imageDTO] });
}
if (!imageDTO.starred) {
starImages({ imageDTOs: [imageDTO] });
}
}
}, [starImages, unstarImages, imageDTO]);

const [isHovered, setIsHovered] = useState(false);

const handleMouseOver = useCallback(() => {
@@ -121,25 +104,6 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
setIsHovered(false);
}, []);

const starIcon = useMemo(() => {
if (imageDTO.starred) {
return customStarUi ? customStarUi.on.icon : <PiStarFill />;
}
if (!imageDTO.starred && isHovered) {
return customStarUi ? customStarUi.off.icon : <PiStarBold />;
}
}, [imageDTO.starred, isHovered, customStarUi]);

const starTooltip = useMemo(() => {
if (imageDTO.starred) {
return customStarUi ? customStarUi.off.text : 'Unstar';
}
if (!imageDTO.starred) {
return customStarUi ? customStarUi.on.text : 'Star';
}
return '';
}, [imageDTO.starred, customStarUi]);

const dataTestId = useMemo(() => getGalleryImageDataTestId(imageDTO.image_name), [imageDTO.image_name]);

if (!imageDTO) {
@@ -155,6 +119,8 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
justifyContent="center"
alignItems="center"
aspectRatio="1/1"
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
<IAIDndImage
onClick={handleClick}
@@ -169,38 +135,8 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {
isUploadDisabled={true}
thumbnail={true}
withHoverOverlay
onMouseOver={handleMouseOver}
onMouseOut={handleMouseOut}
>
<>
{(isHovered || alwaysShowImageSizeBadge) && (
<Text
position="absolute"
background="base.900"
color="base.50"
fontSize="sm"
fontWeight="semibold"
bottom={1}
left={1}
opacity={0.7}
px={2}
lineHeight={1.25}
borderTopEndRadius="base"
sx={badgeSx}
pointerEvents="none"
>{`${imageDTO.width}x${imageDTO.height}`}</Text>
)}
<IAIDndImageIcon
onClick={toggleStarredState}
icon={starIcon}
tooltip={starTooltip}
position="absolute"
top={2}
insetInlineEnd={2}
/>
{isHovered && <DeleteIcon imageDTO={imageDTO} />}
{isHovered && <OpenInViewerIconButton imageDTO={imageDTO} />}
</>
<HoverIcons imageDTO={imageDTO} isHovered={isHovered} />
</IAIDndImage>
</Flex>
</Box>
@@ -209,7 +145,21 @@ const GalleryImageContent = memo(({ index, imageDTO }: HoverableImageProps) => {

GalleryImageContent.displayName = 'GalleryImageContent';

const DeleteIcon = ({ imageDTO }: { imageDTO: ImageDTO }) => {
const HoverIcons = memo(({ imageDTO, isHovered }: { imageDTO: ImageDTO; isHovered: boolean }) => {
const alwaysShowImageSizeBadge = useAppSelector(selectAlwaysShouldImageSizeBadge);

return (
<>
{(isHovered || alwaysShowImageSizeBadge) && <SizeBadge imageDTO={imageDTO} />}
{(isHovered || imageDTO.starred) && <StarIcon imageDTO={imageDTO} />}
{isHovered && <DeleteIcon imageDTO={imageDTO} />}
{isHovered && <OpenInViewerIconButton imageDTO={imageDTO} />}
</>
);
});
HoverIcons.displayName = 'HoverIcons';

const DeleteIcon = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const shift = useShiftModifier();
const { t } = useTranslation();
const dispatch = useAppDispatch();
@@ -238,9 +188,11 @@ const DeleteIcon = ({ imageDTO }: { imageDTO: ImageDTO }) => {
insetInlineEnd={2}
/>
);
};
});

const OpenInViewerIconButton = ({ imageDTO }: { imageDTO: ImageDTO }) => {
DeleteIcon.displayName = 'DeleteIcon';

const OpenInViewerIconButton = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const imageViewer = useImageViewer();
const { t } = useTranslation();

@@ -258,4 +210,77 @@ const OpenInViewerIconButton = ({ imageDTO }: { imageDTO: ImageDTO }) => {
insetInlineStart={2}
/>
);
};
});

OpenInViewerIconButton.displayName = 'OpenInViewerIconButton';

const StarIcon = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const customStarUi = useStore($customStarUI);
const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();

const toggleStarredState = useCallback(() => {
if (imageDTO) {
if (imageDTO.starred) {
unstarImages({ imageDTOs: [imageDTO] });
}
if (!imageDTO.starred) {
starImages({ imageDTOs: [imageDTO] });
}
}
}, [starImages, unstarImages, imageDTO]);

const starIcon = useMemo(() => {
if (imageDTO.starred) {
return customStarUi ? customStarUi.on.icon : <PiStarFill />;
}
if (!imageDTO.starred) {
return customStarUi ? customStarUi.off.icon : <PiStarBold />;
}
}, [imageDTO.starred, customStarUi]);

const starTooltip = useMemo(() => {
if (imageDTO.starred) {
return customStarUi ? customStarUi.off.text : 'Unstar';
}
if (!imageDTO.starred) {
return customStarUi ? customStarUi.on.text : 'Star';
}
return '';
}, [imageDTO.starred, customStarUi]);

return (
<IAIDndImageIcon
onClick={toggleStarredState}
icon={starIcon}
tooltip={starTooltip}
position="absolute"
top={2}
insetInlineEnd={2}
/>
);
});

StarIcon.displayName = 'StarIcon';

const SizeBadge = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
return (
<Text
position="absolute"
background="base.900"
color="base.50"
fontSize="sm"
fontWeight="semibold"
bottom={1}
left={1}
opacity={0.7}
px={2}
lineHeight={1.25}
borderTopEndRadius="base"
sx={badgeSx}
pointerEvents="none"
>{`${imageDTO.width}x${imageDTO.height}`}</Text>
);
});

SizeBadge.displayName = 'SizeBadge';
@@ -1,11 +1,11 @@
import { Button, Flex, IconButton, Spacer } from '@invoke-ai/ui-library';
import { ELLIPSIS, useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { useCallback } from 'react';
import { memo, useCallback } from 'react';
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';

import { JumpTo } from './JumpTo';

export const GalleryPagination = () => {
export const GalleryPagination = memo(() => {
const { goPrev, goNext, isPrevEnabled, isNextEnabled, pageButtons, goToPage, currentPage, total } =
useGalleryPagination();

@@ -47,7 +47,9 @@ export const GalleryPagination = () => {
<JumpTo />
</Flex>
);
};
});

GalleryPagination.displayName = 'GalleryPagination';

type PageButtonProps = {
page: number | typeof ELLIPSIS;
@@ -55,7 +57,7 @@ type PageButtonProps = {
goToPage: (page: number) => void;
};

const PageButton = ({ page, currentPage, goToPage }: PageButtonProps) => {
const PageButton = memo(({ page, currentPage, goToPage }: PageButtonProps) => {
if (page === ELLIPSIS) {
return (
<Button size="sm" variant="link" isDisabled>
@@ -68,4 +70,6 @@ const PageButton = ({ page, currentPage, goToPage }: PageButtonProps) => {
{page}
</Button>
);
};
});

PageButton.displayName = 'PageButton';
@@ -11,11 +11,11 @@ import {
useDisclosure,
} from '@invoke-ai/ui-library';
import { useGalleryPagination } from 'features/gallery/hooks/useGalleryPagination';
import { useCallback, useEffect, useRef, useState } from 'react';
import { memo, useCallback, useEffect, useRef, useState } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';

export const JumpTo = () => {
export const JumpTo = memo(() => {
const { t } = useTranslation();
const { goToPage, currentPage, pages } = useGalleryPagination();
const [newPage, setNewPage] = useState(currentPage);
@@ -64,7 +64,7 @@ export const JumpTo = () => {
}, [currentPage]);

return (
<Popover isOpen={isOpen} onClose={onClose} onOpen={onOpen}>
<Popover isOpen={isOpen} onClose={onClose} onOpen={onOpen} isLazy lazyBehavior="unmount">
<PopoverTrigger>
<Button aria-label={t('gallery.jump')} size="sm" onClick={onToggle} variant="outline">
{t('gallery.jump')}
@@ -94,4 +94,6 @@ export const JumpTo = () => {
</PopoverContent>
</Popover>
);
};
});

JumpTo.displayName = 'JumpTo';
@@ -32,6 +32,8 @@ export const selectListImagesQueryArgs = createMemoizedSelector(
export const selectListBoardsQueryArgs = createMemoizedSelector(
selectGallerySlice,
(gallery): ListBoardsArgs => ({
order_by: gallery.boardsListOrderBy,
direction: gallery.boardsListOrderDir,
include_archived: gallery.shouldShowArchivedBoards ? true : undefined,
})
);
@@ -44,6 +46,9 @@ export const selectAutoAssignBoardOnClick = createSelector(
);
export const selectBoardSearchText = createSelector(selectGallerySlice, (gallery) => gallery.boardSearchText);
export const selectSearchTerm = createSelector(selectGallerySlice, (gallery) => gallery.searchTerm);
export const selectBoardsListOrderBy = createSelector(selectGallerySlice, (gallery) => gallery.boardsListOrderBy);
export const selectBoardsListOrderDir = createSelector(selectGallerySlice, (gallery) => gallery.boardsListOrderDir);

export const selectSelectionCount = createSelector(selectGallerySlice, (gallery) => gallery.selection.length);
export const selectHasMultipleImagesSelected = createSelector(selectSelectionCount, (count) => count > 1);
export const selectGalleryImageMinimumWidth = createSelector(
@@ -2,7 +2,7 @@ import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { isEqual, uniqBy } from 'lodash-es';
import type { ImageDTO } from 'services/api/types';
import type { BoardRecordOrderBy, ImageDTO } from 'services/api/types';

import type { BoardId, ComparisonMode, GalleryState, GalleryView, OrderDir } from './types';

@@ -25,6 +25,8 @@ const initialGalleryState: GalleryState = {
comparisonMode: 'slider',
comparisonFit: 'fill',
shouldShowArchivedBoards: false,
boardsListOrderBy: 'created_at',
boardsListOrderDir: 'DESC',
};

export const gallerySlice = createSlice({
@@ -161,6 +163,12 @@ export const gallerySlice = createSlice({
state.searchTerm = action.payload;
state.offset = 0;
},
boardsListOrderByChanged: (state, action: PayloadAction<BoardRecordOrderBy>) => {
state.boardsListOrderBy = action.payload;
},
boardsListOrderDirChanged: (state, action: PayloadAction<OrderDir>) => {
state.boardsListOrderDir = action.payload;
},
},
});

@@ -186,6 +194,8 @@ export const {
starredFirstChanged,
shouldShowArchivedBoardsChanged,
searchTermChanged,
boardsListOrderByChanged,
boardsListOrderDirChanged,
} = gallerySlice.actions;

export const selectGallerySlice = (state: RootState) => state.gallery;
@@ -1,4 +1,4 @@
import type { ImageCategory, ImageDTO } from 'services/api/types';
import type { BoardRecordOrderBy, ImageCategory, ImageDTO } from 'services/api/types';

export const IMAGE_CATEGORIES: ImageCategory[] = ['general'];
export const ASSETS_CATEGORIES: ImageCategory[] = ['control', 'mask', 'user', 'other'];
@@ -28,4 +28,6 @@ export type GalleryState = {
comparisonMode: ComparisonMode;
comparisonFit: ComparisonFit;
shouldShowArchivedBoards: boolean;
boardsListOrderBy: BoardRecordOrderBy;
boardsListOrderDir: OrderDir;
};
@@ -1,7 +1,6 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { useAppDispatch } from 'app/store/storeHooks';
import { fieldBoardValueChanged } from 'features/nodes/store/nodesSlice';
import type { BoardFieldInputInstance, BoardFieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback, useMemo } from 'react';
@@ -14,26 +13,28 @@ const BoardFieldInputComponent = (props: FieldComponentProps<BoardFieldInputInst
const { nodeId, field } = props;
const dispatch = useAppDispatch();
const { t } = useTranslation();
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { options, hasBoards } = useListAllBoardsQuery(queryArgs, {
selectFromResult: ({ data }) => {
const options: ComboboxOption[] = [
{
label: 'None',
value: 'none',
},
].concat(
(data ?? []).map(({ board_id, board_name }) => ({
label: board_name,
value: board_id,
}))
);
return {
options,
hasBoards: options.length > 1,
};
},
});
const { options, hasBoards } = useListAllBoardsQuery(
{ include_archived: true },
{
selectFromResult: ({ data }) => {
const options: ComboboxOption[] = [
{
label: 'None',
value: 'none',
},
].concat(
(data ?? []).map(({ board_id, board_name }) => ({
label: board_name,
value: board_id,
}))
);
return {
options,
hasBoards: options.length > 1,
};
},
}
);

const onChange = useCallback<ComboboxOnChange>(
(v) => {
@@ -43,7 +43,7 @@ export const ShareWorkflowModal = () => {
if (!workflowToShare || !projectUrl) {
return null;
}
return `${projectUrl}/studio?selectedWorkflowId=${workflowToShare.workflow_id}`;
return `${window.location.origin}/${projectUrl}/studio?selectedWorkflowId=${workflowToShare.workflow_id}`;
}, [projectUrl, workflowToShare]);

const handleCopy = useCallback(() => {
@@ -36,8 +36,6 @@ export const addControlNets = async (
};

for (const layer of validControlLayers) {
result.addedControlNets++;

const getImageDTOResult = await withResultAsync(() => {
const adapter = manager.adapters.controlLayers.get(layer.id);
assert(adapter, 'Adapter not found');
@@ -50,6 +48,7 @@ export const addControlNets = async (

const imageDTO = getImageDTOResult.value;
addControlNetToGraph(g, layer, imageDTO, collector);
result.addedControlNets++;
}

return result;
@@ -77,8 +76,6 @@ export const addT2IAdapters = async (
};

for (const layer of validControlLayers) {
result.addedT2IAdapters++;

const getImageDTOResult = await withResultAsync(() => {
const adapter = manager.adapters.controlLayers.get(layer.id);
assert(adapter, 'Adapter not found');
@@ -91,6 +88,7 @@ export const addT2IAdapters = async (

const imageDTO = getImageDTOResult.value;
addT2IAdapterToGraph(g, layer, imageDTO, collector);
result.addedT2IAdapters++;
}

return result;
@@ -110,10 +108,10 @@ const addControlNetToGraph = (

const controlNet = g.addNode({
id: `control_net_${id}`,
type: 'controlnet',
type: model.base === 'flux' ? 'flux_controlnet' : 'controlnet',
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
control_mode: controlMode,
control_mode: model.base === 'flux' ? undefined : controlMode,
resize_mode: 'just_resize',
control_model: model,
control_weight: weight,
@@ -19,6 +19,8 @@ import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';

import { addControlNets } from './addControlAdapters';

const log = logger('system');

export const buildFLUXGraph = async (
@@ -93,6 +95,7 @@ export const buildFLUXGraph = async (
> = l2i;

g.addEdge(modelLoader, 'transformer', noise, 'transformer');
g.addEdge(modelLoader, 'vae', noise, 'controlnet_vae');
g.addEdge(modelLoader, 'vae', l2i, 'vae');

g.addEdge(modelLoader, 'clip', posCond, 'clip');
@@ -177,6 +180,24 @@ export const buildFLUXGraph = async (
);
}

const controlNetCollector = g.addNode({
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets(
manager,
canvas.controlLayers.entities,
g,
canvas.bbox.rect,
controlNetCollector,
modelConfig.base
);
if (controlNetResult.addedControlNets > 0) {
g.addEdge(controlNetCollector, 'collection', noise, 'control');
} else {
g.deleteNode(controlNetCollector.id);
}

if (state.system.shouldUseNSFWChecker) {
canvasOutput = addNSFWChecker(g, canvasOutput);
}
@@ -1,5 +1,4 @@
import { getStore } from 'app/store/nanostores/store';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { modelsApi } from 'services/api/endpoints/models';
@@ -44,10 +43,9 @@ export const checkImageAccess = async (name: string): Promise<boolean> => {
* @returns A promise that resolves to true if the client has access, else false.
*/
export const checkBoardAccess = async (id: string): Promise<boolean> => {
const { dispatch, getState } = getStore();
const { dispatch } = getStore();
try {
const queryArgs = selectListBoardsQueryArgs(getState());
const req = dispatch(boardsApi.endpoints.listAllBoards.initiate(queryArgs));
const req = dispatch(boardsApi.endpoints.listAllBoards.initiate({ include_archived: true }));
req.unsubscribe();
const result = await req.unwrap();
return result.some((b) => b.board_id === id);
@@ -1,19 +1,19 @@
import { useAppSelector } from 'app/store/storeHooks';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import type { BoardId } from 'features/gallery/store/types';
import { t } from 'i18next';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';

export const useBoardName = (board_id: BoardId) => {
const queryArgs = useAppSelector(selectListBoardsQueryArgs);
const { boardName } = useListAllBoardsQuery(queryArgs, {
selectFromResult: ({ data }) => {
const selectedBoard = data?.find((b) => b.board_id === board_id);
const boardName = selectedBoard?.board_name || t('boards.uncategorized');
const { boardName } = useListAllBoardsQuery(
{ include_archived: true },
{
selectFromResult: ({ data }) => {
const selectedBoard = data?.find((b) => b.board_id === board_id);
const boardName = selectedBoard?.board_name || t('boards.uncategorized');

return { boardName };
},
});
return { boardName };
},
}
);

return boardName;
};
File diff suppressed because one or more lines are too long
@@ -241,3 +241,5 @@ export type PostUploadAction =
| RGIPAdapterImagePostUploadAction
| UpscaleInitialImageAction
| ReplaceLayerWithImagePostUploadAction;

export type BoardRecordOrderBy = S['BoardRecordOrderBy'];
@@ -24,6 +24,15 @@ export default defineConfig(({ mode }) => {
cssInjectedByJsPlugin(),
],
build: {
/**
 * zone.js (via faro) requires max ES2015 to prevent spamming unhandled promise rejections.
 *
 * See:
 * - https://github.com/grafana/faro-web-sdk/issues/566
 * - https://github.com/angular/angular/issues/51328
 * - https://github.com/open-telemetry/opentelemetry-js/issues/3030
 */
target: 'ES2015',
cssCodeSplit: true,
lib: {
entry: path.resolve(__dirname, './src/index.ts'),
@@ -1 +1 @@
__version__ = "5.1.1"
__version__ = "5.2.0rc1"
@@ -43,8 +43,8 @@ dependencies = [
"invisible-watermark==0.2.0", # needed to install SDXL base and refiner using their repo_ids
"mediapipe>=0.10.7", # needed for "mediapipeface" controlnet model
"numpy==1.26.4", # >1.24.0 is needed to use the 'strict' argument to np.testing.assert_array_equal()
"onnx>=1.15.0",
"onnxruntime>=1.16.3",
"onnx==1.16.1",
"onnxruntime==1.19.2",
"opencv-python==4.9.0.80",
"pytorch-lightning==2.1.3",
"safetensors==0.4.3",
scripts/extract_sd_keys_and_shapes.py (new file, 30 lines)
@@ -0,0 +1,30 @@
import argparse
import json

from safetensors.torch import load_file


def extract_sd_keys_and_shapes(safetensors_file: str):
    sd = load_file(safetensors_file)

    keys_to_shapes = {k: v.shape for k, v in sd.items()}

    out_file = "keys_and_shapes.json"
    with open(out_file, "w") as f:
        json.dump(keys_to_shapes, f, indent=4)

    print(f"Keys and shapes written to '{out_file}'.")


def main():
    parser = argparse.ArgumentParser(
        description="Extracts the keys and shapes from the state dict in a safetensors file. Intended for creating "
        + "dummy state dicts for use in unit tests."
    )
    parser.add_argument("safetensors_file", type=str, help="Path to the safetensors file.")
    args = parser.parse_args()
    extract_sd_keys_and_shapes(args.safetensors_file)


if __name__ == "__main__":
    main()
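A rough sketch of how the JSON emitted by this script can be consumed. The helper name build_dummy_state_dict and the "keys_and_shapes.json" path are illustrative only (the path is the script's default output name above); the meta-device pattern mirrors the unit tests added later in this diff:

import json

import torch


def build_dummy_state_dict(path: str = "keys_and_shapes.json") -> dict[str, torch.Tensor]:
    # Load the key -> shape mapping written by extract_sd_keys_and_shapes.py.
    with open(path) as f:
        keys_to_shapes = json.load(f)
    # Meta-device tensors carry shape/dtype metadata without allocating data,
    # so large dummy state dicts stay cheap to build in unit tests.
    with torch.device("meta"):
        return {k: torch.zeros(v) for k, v in keys_to_shapes.items()}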
@@ -0,0 +1,374 @@
|
||||
# State dict keys and shapes for an InstantX FLUX ControlNet Union model. Intended to be used for unit tests.
|
||||
# These keys were extracted from:
|
||||
# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/4f32d6f2b220f8873d49bb8acc073e1df180c994/diffusion_pytorch_model.safetensors
|
||||
instantx_sd_shapes = {
|
||||
"context_embedder.bias": [3072],
|
||||
"context_embedder.weight": [3072, 4096],
|
||||
"controlnet_blocks.0.bias": [3072],
|
||||
"controlnet_blocks.0.weight": [3072, 3072],
|
||||
"controlnet_blocks.1.bias": [3072],
|
||||
"controlnet_blocks.1.weight": [3072, 3072],
|
||||
"controlnet_blocks.2.bias": [3072],
|
||||
"controlnet_blocks.2.weight": [3072, 3072],
|
||||
"controlnet_blocks.3.bias": [3072],
|
||||
"controlnet_blocks.3.weight": [3072, 3072],
|
||||
"controlnet_blocks.4.bias": [3072],
|
||||
"controlnet_blocks.4.weight": [3072, 3072],
|
||||
"controlnet_mode_embedder.weight": [10, 3072],
|
||||
"controlnet_single_blocks.0.bias": [3072],
|
||||
"controlnet_single_blocks.0.weight": [3072, 3072],
|
||||
"controlnet_single_blocks.1.bias": [3072],
|
||||
"controlnet_single_blocks.1.weight": [3072, 3072],
|
||||
"controlnet_single_blocks.2.bias": [3072],
|
||||
"controlnet_single_blocks.2.weight": [3072, 3072],
|
||||
"controlnet_single_blocks.3.bias": [3072],
|
||||
"controlnet_single_blocks.3.weight": [3072, 3072],
|
||||
"controlnet_single_blocks.4.bias": [3072],
|
||||
"controlnet_single_blocks.4.weight": [3072, 3072],
|
||||
"controlnet_single_blocks.5.bias": [3072],
|
||||
"controlnet_single_blocks.5.weight": [3072, 3072],
|
||||
"controlnet_single_blocks.6.bias": [3072],
|
||||
"controlnet_single_blocks.6.weight": [3072, 3072],
|
||||
"controlnet_single_blocks.7.bias": [3072],
|
||||
"controlnet_single_blocks.7.weight": [3072, 3072],
|
||||
"controlnet_single_blocks.8.bias": [3072],
|
||||
"controlnet_single_blocks.8.weight": [3072, 3072],
|
||||
"controlnet_single_blocks.9.bias": [3072],
|
||||
"controlnet_single_blocks.9.weight": [3072, 3072],
|
||||
"controlnet_x_embedder.bias": [3072],
|
||||
"controlnet_x_embedder.weight": [3072, 64],
|
||||
"single_transformer_blocks.0.attn.norm_k.weight": [128],
|
||||
"single_transformer_blocks.0.attn.norm_q.weight": [128],
|
||||
"single_transformer_blocks.0.attn.to_k.bias": [3072],
|
||||
"single_transformer_blocks.0.attn.to_k.weight": [3072, 3072],
|
||||
"single_transformer_blocks.0.attn.to_q.bias": [3072],
|
||||
"single_transformer_blocks.0.attn.to_q.weight": [3072, 3072],
|
||||
"single_transformer_blocks.0.attn.to_v.bias": [3072],
|
||||
"single_transformer_blocks.0.attn.to_v.weight": [3072, 3072],
|
||||
"single_transformer_blocks.0.norm.linear.bias": [9216],
|
||||
"single_transformer_blocks.0.norm.linear.weight": [9216, 3072],
|
||||
"single_transformer_blocks.0.proj_mlp.bias": [12288],
|
||||
"single_transformer_blocks.0.proj_mlp.weight": [12288, 3072],
|
||||
"single_transformer_blocks.0.proj_out.bias": [3072],
|
||||
"single_transformer_blocks.0.proj_out.weight": [3072, 15360],
|
||||
"single_transformer_blocks.1.attn.norm_k.weight": [128],
|
||||
"single_transformer_blocks.1.attn.norm_q.weight": [128],
|
||||
"single_transformer_blocks.1.attn.to_k.bias": [3072],
|
||||
"single_transformer_blocks.1.attn.to_k.weight": [3072, 3072],
|
||||
"single_transformer_blocks.1.attn.to_q.bias": [3072],
|
||||
"single_transformer_blocks.1.attn.to_q.weight": [3072, 3072],
|
||||
"single_transformer_blocks.1.attn.to_v.bias": [3072],
|
||||
"single_transformer_blocks.1.attn.to_v.weight": [3072, 3072],
|
||||
"single_transformer_blocks.1.norm.linear.bias": [9216],
|
||||
"single_transformer_blocks.1.norm.linear.weight": [9216, 3072],
|
||||
"single_transformer_blocks.1.proj_mlp.bias": [12288],
|
||||
"single_transformer_blocks.1.proj_mlp.weight": [12288, 3072],
|
||||
"single_transformer_blocks.1.proj_out.bias": [3072],
|
||||
"single_transformer_blocks.1.proj_out.weight": [3072, 15360],
|
||||
"single_transformer_blocks.2.attn.norm_k.weight": [128],
|
||||
"single_transformer_blocks.2.attn.norm_q.weight": [128],
|
||||
"single_transformer_blocks.2.attn.to_k.bias": [3072],
|
||||
"single_transformer_blocks.2.attn.to_k.weight": [3072, 3072],
|
||||
"single_transformer_blocks.2.attn.to_q.bias": [3072],
|
||||
"single_transformer_blocks.2.attn.to_q.weight": [3072, 3072],
|
||||
"single_transformer_blocks.2.attn.to_v.bias": [3072],
|
||||
"single_transformer_blocks.2.attn.to_v.weight": [3072, 3072],
|
||||
"single_transformer_blocks.2.norm.linear.bias": [9216],
|
||||
"single_transformer_blocks.2.norm.linear.weight": [9216, 3072],
|
||||
"single_transformer_blocks.2.proj_mlp.bias": [12288],
|
||||
"single_transformer_blocks.2.proj_mlp.weight": [12288, 3072],
|
||||
"single_transformer_blocks.2.proj_out.bias": [3072],
|
||||
"single_transformer_blocks.2.proj_out.weight": [3072, 15360],
|
||||
"single_transformer_blocks.3.attn.norm_k.weight": [128],
|
||||
"single_transformer_blocks.3.attn.norm_q.weight": [128],
|
||||
"single_transformer_blocks.3.attn.to_k.bias": [3072],
|
||||
"single_transformer_blocks.3.attn.to_k.weight": [3072, 3072],
|
||||
"single_transformer_blocks.3.attn.to_q.bias": [3072],
|
||||
"single_transformer_blocks.3.attn.to_q.weight": [3072, 3072],
|
||||
"single_transformer_blocks.3.attn.to_v.bias": [3072],
|
||||
"single_transformer_blocks.3.attn.to_v.weight": [3072, 3072],
|
||||
"single_transformer_blocks.3.norm.linear.bias": [9216],
|
||||
"single_transformer_blocks.3.norm.linear.weight": [9216, 3072],
|
||||
"single_transformer_blocks.3.proj_mlp.bias": [12288],
|
||||
"single_transformer_blocks.3.proj_mlp.weight": [12288, 3072],
|
||||
"single_transformer_blocks.3.proj_out.bias": [3072],
|
||||
"single_transformer_blocks.3.proj_out.weight": [3072, 15360],
|
||||
"single_transformer_blocks.4.attn.norm_k.weight": [128],
|
||||
"single_transformer_blocks.4.attn.norm_q.weight": [128],
|
||||
"single_transformer_blocks.4.attn.to_k.bias": [3072],
|
||||
"single_transformer_blocks.4.attn.to_k.weight": [3072, 3072],
|
||||
"single_transformer_blocks.4.attn.to_q.bias": [3072],
|
||||
"single_transformer_blocks.4.attn.to_q.weight": [3072, 3072],
|
||||
"single_transformer_blocks.4.attn.to_v.bias": [3072],
|
||||
"single_transformer_blocks.4.attn.to_v.weight": [3072, 3072],
|
||||
"single_transformer_blocks.4.norm.linear.bias": [9216],
|
||||
"single_transformer_blocks.4.norm.linear.weight": [9216, 3072],
|
||||
"single_transformer_blocks.4.proj_mlp.bias": [12288],
|
||||
"single_transformer_blocks.4.proj_mlp.weight": [12288, 3072],
|
||||
"single_transformer_blocks.4.proj_out.bias": [3072],
|
||||
"single_transformer_blocks.4.proj_out.weight": [3072, 15360],
|
||||
"single_transformer_blocks.5.attn.norm_k.weight": [128],
|
||||
"single_transformer_blocks.5.attn.norm_q.weight": [128],
|
||||
"single_transformer_blocks.5.attn.to_k.bias": [3072],
|
||||
"single_transformer_blocks.5.attn.to_k.weight": [3072, 3072],
|
||||
"single_transformer_blocks.5.attn.to_q.bias": [3072],
|
||||
"single_transformer_blocks.5.attn.to_q.weight": [3072, 3072],
|
||||
"single_transformer_blocks.5.attn.to_v.bias": [3072],
|
||||
"single_transformer_blocks.5.attn.to_v.weight": [3072, 3072],
|
||||
"single_transformer_blocks.5.norm.linear.bias": [9216],
|
||||
"single_transformer_blocks.5.norm.linear.weight": [9216, 3072],
|
||||
"single_transformer_blocks.5.proj_mlp.bias": [12288],
|
||||
"single_transformer_blocks.5.proj_mlp.weight": [12288, 3072],
|
||||
"single_transformer_blocks.5.proj_out.bias": [3072],
|
||||
"single_transformer_blocks.5.proj_out.weight": [3072, 15360],
|
||||
"single_transformer_blocks.6.attn.norm_k.weight": [128],
|
||||
"single_transformer_blocks.6.attn.norm_q.weight": [128],
|
||||
"single_transformer_blocks.6.attn.to_k.bias": [3072],
|
||||
"single_transformer_blocks.6.attn.to_k.weight": [3072, 3072],
|
||||
"single_transformer_blocks.6.attn.to_q.bias": [3072],
|
||||
"single_transformer_blocks.6.attn.to_q.weight": [3072, 3072],
|
||||
"single_transformer_blocks.6.attn.to_v.bias": [3072],
|
||||
"single_transformer_blocks.6.attn.to_v.weight": [3072, 3072],
|
||||
"single_transformer_blocks.6.norm.linear.bias": [9216],
|
||||
"single_transformer_blocks.6.norm.linear.weight": [9216, 3072],
|
||||
"single_transformer_blocks.6.proj_mlp.bias": [12288],
|
||||
"single_transformer_blocks.6.proj_mlp.weight": [12288, 3072],
|
||||
"single_transformer_blocks.6.proj_out.bias": [3072],
|
||||
"single_transformer_blocks.6.proj_out.weight": [3072, 15360],
|
||||
"single_transformer_blocks.7.attn.norm_k.weight": [128],
|
||||
"single_transformer_blocks.7.attn.norm_q.weight": [128],
|
||||
"single_transformer_blocks.7.attn.to_k.bias": [3072],
|
||||
"single_transformer_blocks.7.attn.to_k.weight": [3072, 3072],
|
||||
"single_transformer_blocks.7.attn.to_q.bias": [3072],
|
||||
"single_transformer_blocks.7.attn.to_q.weight": [3072, 3072],
|
||||
"single_transformer_blocks.7.attn.to_v.bias": [3072],
|
||||
"single_transformer_blocks.7.attn.to_v.weight": [3072, 3072],
|
||||
"single_transformer_blocks.7.norm.linear.bias": [9216],
|
||||
"single_transformer_blocks.7.norm.linear.weight": [9216, 3072],
|
||||
"single_transformer_blocks.7.proj_mlp.bias": [12288],
|
||||
"single_transformer_blocks.7.proj_mlp.weight": [12288, 3072],
|
||||
"single_transformer_blocks.7.proj_out.bias": [3072],
|
||||
"single_transformer_blocks.7.proj_out.weight": [3072, 15360],
|
||||
"single_transformer_blocks.8.attn.norm_k.weight": [128],
|
||||
"single_transformer_blocks.8.attn.norm_q.weight": [128],
|
||||
"single_transformer_blocks.8.attn.to_k.bias": [3072],
|
||||
"single_transformer_blocks.8.attn.to_k.weight": [3072, 3072],
|
||||
"single_transformer_blocks.8.attn.to_q.bias": [3072],
|
||||
"single_transformer_blocks.8.attn.to_q.weight": [3072, 3072],
|
||||
"single_transformer_blocks.8.attn.to_v.bias": [3072],
|
||||
"single_transformer_blocks.8.attn.to_v.weight": [3072, 3072],
|
||||
"single_transformer_blocks.8.norm.linear.bias": [9216],
|
||||
"single_transformer_blocks.8.norm.linear.weight": [9216, 3072],
|
||||
"single_transformer_blocks.8.proj_mlp.bias": [12288],
|
||||
"single_transformer_blocks.8.proj_mlp.weight": [12288, 3072],
|
||||
"single_transformer_blocks.8.proj_out.bias": [3072],
|
||||
"single_transformer_blocks.8.proj_out.weight": [3072, 15360],
|
||||
"single_transformer_blocks.9.attn.norm_k.weight": [128],
|
||||
"single_transformer_blocks.9.attn.norm_q.weight": [128],
|
||||
"single_transformer_blocks.9.attn.to_k.bias": [3072],
|
||||
"single_transformer_blocks.9.attn.to_k.weight": [3072, 3072],
|
||||
"single_transformer_blocks.9.attn.to_q.bias": [3072],
|
||||
"single_transformer_blocks.9.attn.to_q.weight": [3072, 3072],
|
||||
"single_transformer_blocks.9.attn.to_v.bias": [3072],
|
||||
"single_transformer_blocks.9.attn.to_v.weight": [3072, 3072],
|
||||
"single_transformer_blocks.9.norm.linear.bias": [9216],
|
||||
"single_transformer_blocks.9.norm.linear.weight": [9216, 3072],
|
||||
"single_transformer_blocks.9.proj_mlp.bias": [12288],
|
||||
"single_transformer_blocks.9.proj_mlp.weight": [12288, 3072],
|
||||
"single_transformer_blocks.9.proj_out.bias": [3072],
|
||||
"single_transformer_blocks.9.proj_out.weight": [3072, 15360],
|
||||
"time_text_embed.guidance_embedder.linear_1.bias": [3072],
|
||||
"time_text_embed.guidance_embedder.linear_1.weight": [3072, 256],
|
||||
"time_text_embed.guidance_embedder.linear_2.bias": [3072],
|
||||
"time_text_embed.guidance_embedder.linear_2.weight": [3072, 3072],
|
||||
"time_text_embed.text_embedder.linear_1.bias": [3072],
|
||||
"time_text_embed.text_embedder.linear_1.weight": [3072, 768],
|
||||
"time_text_embed.text_embedder.linear_2.bias": [3072],
|
||||
"time_text_embed.text_embedder.linear_2.weight": [3072, 3072],
|
||||
"time_text_embed.timestep_embedder.linear_1.bias": [3072],
|
||||
"time_text_embed.timestep_embedder.linear_1.weight": [3072, 256],
|
||||
"time_text_embed.timestep_embedder.linear_2.bias": [3072],
|
||||
"time_text_embed.timestep_embedder.linear_2.weight": [3072, 3072],
|
||||
"transformer_blocks.0.attn.add_k_proj.bias": [3072],
|
||||
"transformer_blocks.0.attn.add_k_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.0.attn.add_q_proj.bias": [3072],
|
||||
"transformer_blocks.0.attn.add_q_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.0.attn.add_v_proj.bias": [3072],
|
||||
"transformer_blocks.0.attn.add_v_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.0.attn.norm_added_k.weight": [128],
|
||||
"transformer_blocks.0.attn.norm_added_q.weight": [128],
|
||||
"transformer_blocks.0.attn.norm_k.weight": [128],
|
||||
"transformer_blocks.0.attn.norm_q.weight": [128],
|
||||
"transformer_blocks.0.attn.to_add_out.bias": [3072],
|
||||
"transformer_blocks.0.attn.to_add_out.weight": [3072, 3072],
|
||||
"transformer_blocks.0.attn.to_k.bias": [3072],
|
||||
"transformer_blocks.0.attn.to_k.weight": [3072, 3072],
|
||||
"transformer_blocks.0.attn.to_out.0.bias": [3072],
|
||||
"transformer_blocks.0.attn.to_out.0.weight": [3072, 3072],
|
||||
"transformer_blocks.0.attn.to_q.bias": [3072],
|
||||
"transformer_blocks.0.attn.to_q.weight": [3072, 3072],
|
||||
"transformer_blocks.0.attn.to_v.bias": [3072],
|
||||
"transformer_blocks.0.attn.to_v.weight": [3072, 3072],
|
||||
"transformer_blocks.0.ff.net.0.proj.bias": [12288],
|
||||
"transformer_blocks.0.ff.net.0.proj.weight": [12288, 3072],
|
||||
"transformer_blocks.0.ff.net.2.bias": [3072],
|
||||
"transformer_blocks.0.ff.net.2.weight": [3072, 12288],
|
||||
"transformer_blocks.0.ff_context.net.0.proj.bias": [12288],
|
||||
"transformer_blocks.0.ff_context.net.0.proj.weight": [12288, 3072],
|
||||
"transformer_blocks.0.ff_context.net.2.bias": [3072],
|
||||
"transformer_blocks.0.ff_context.net.2.weight": [3072, 12288],
|
||||
"transformer_blocks.0.norm1.linear.bias": [18432],
|
||||
"transformer_blocks.0.norm1.linear.weight": [18432, 3072],
|
||||
"transformer_blocks.0.norm1_context.linear.bias": [18432],
|
||||
"transformer_blocks.0.norm1_context.linear.weight": [18432, 3072],
|
||||
"transformer_blocks.1.attn.add_k_proj.bias": [3072],
|
||||
"transformer_blocks.1.attn.add_k_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.1.attn.add_q_proj.bias": [3072],
|
||||
"transformer_blocks.1.attn.add_q_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.1.attn.add_v_proj.bias": [3072],
|
||||
"transformer_blocks.1.attn.add_v_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.1.attn.norm_added_k.weight": [128],
|
||||
"transformer_blocks.1.attn.norm_added_q.weight": [128],
|
||||
"transformer_blocks.1.attn.norm_k.weight": [128],
|
||||
"transformer_blocks.1.attn.norm_q.weight": [128],
|
||||
"transformer_blocks.1.attn.to_add_out.bias": [3072],
|
||||
"transformer_blocks.1.attn.to_add_out.weight": [3072, 3072],
|
||||
"transformer_blocks.1.attn.to_k.bias": [3072],
|
||||
"transformer_blocks.1.attn.to_k.weight": [3072, 3072],
|
||||
"transformer_blocks.1.attn.to_out.0.bias": [3072],
|
||||
"transformer_blocks.1.attn.to_out.0.weight": [3072, 3072],
|
||||
"transformer_blocks.1.attn.to_q.bias": [3072],
|
||||
"transformer_blocks.1.attn.to_q.weight": [3072, 3072],
|
||||
"transformer_blocks.1.attn.to_v.bias": [3072],
|
||||
"transformer_blocks.1.attn.to_v.weight": [3072, 3072],
|
||||
"transformer_blocks.1.ff.net.0.proj.bias": [12288],
|
||||
"transformer_blocks.1.ff.net.0.proj.weight": [12288, 3072],
|
||||
"transformer_blocks.1.ff.net.2.bias": [3072],
|
||||
"transformer_blocks.1.ff.net.2.weight": [3072, 12288],
|
||||
"transformer_blocks.1.ff_context.net.0.proj.bias": [12288],
|
||||
"transformer_blocks.1.ff_context.net.0.proj.weight": [12288, 3072],
|
||||
"transformer_blocks.1.ff_context.net.2.bias": [3072],
|
||||
"transformer_blocks.1.ff_context.net.2.weight": [3072, 12288],
|
||||
"transformer_blocks.1.norm1.linear.bias": [18432],
|
||||
"transformer_blocks.1.norm1.linear.weight": [18432, 3072],
|
||||
"transformer_blocks.1.norm1_context.linear.bias": [18432],
|
||||
"transformer_blocks.1.norm1_context.linear.weight": [18432, 3072],
|
||||
"transformer_blocks.2.attn.add_k_proj.bias": [3072],
|
||||
"transformer_blocks.2.attn.add_k_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.2.attn.add_q_proj.bias": [3072],
|
||||
"transformer_blocks.2.attn.add_q_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.2.attn.add_v_proj.bias": [3072],
|
||||
"transformer_blocks.2.attn.add_v_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.2.attn.norm_added_k.weight": [128],
|
||||
"transformer_blocks.2.attn.norm_added_q.weight": [128],
|
||||
"transformer_blocks.2.attn.norm_k.weight": [128],
|
||||
"transformer_blocks.2.attn.norm_q.weight": [128],
|
||||
"transformer_blocks.2.attn.to_add_out.bias": [3072],
|
||||
"transformer_blocks.2.attn.to_add_out.weight": [3072, 3072],
|
||||
"transformer_blocks.2.attn.to_k.bias": [3072],
|
||||
"transformer_blocks.2.attn.to_k.weight": [3072, 3072],
|
||||
"transformer_blocks.2.attn.to_out.0.bias": [3072],
|
||||
"transformer_blocks.2.attn.to_out.0.weight": [3072, 3072],
|
||||
"transformer_blocks.2.attn.to_q.bias": [3072],
|
||||
"transformer_blocks.2.attn.to_q.weight": [3072, 3072],
|
||||
"transformer_blocks.2.attn.to_v.bias": [3072],
|
||||
"transformer_blocks.2.attn.to_v.weight": [3072, 3072],
|
||||
"transformer_blocks.2.ff.net.0.proj.bias": [12288],
|
||||
"transformer_blocks.2.ff.net.0.proj.weight": [12288, 3072],
|
||||
"transformer_blocks.2.ff.net.2.bias": [3072],
|
||||
"transformer_blocks.2.ff.net.2.weight": [3072, 12288],
|
||||
"transformer_blocks.2.ff_context.net.0.proj.bias": [12288],
|
||||
"transformer_blocks.2.ff_context.net.0.proj.weight": [12288, 3072],
|
||||
"transformer_blocks.2.ff_context.net.2.bias": [3072],
|
||||
"transformer_blocks.2.ff_context.net.2.weight": [3072, 12288],
|
||||
"transformer_blocks.2.norm1.linear.bias": [18432],
|
||||
"transformer_blocks.2.norm1.linear.weight": [18432, 3072],
|
||||
"transformer_blocks.2.norm1_context.linear.bias": [18432],
|
||||
"transformer_blocks.2.norm1_context.linear.weight": [18432, 3072],
|
||||
"transformer_blocks.3.attn.add_k_proj.bias": [3072],
|
||||
"transformer_blocks.3.attn.add_k_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.3.attn.add_q_proj.bias": [3072],
|
||||
"transformer_blocks.3.attn.add_q_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.3.attn.add_v_proj.bias": [3072],
|
||||
"transformer_blocks.3.attn.add_v_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.3.attn.norm_added_k.weight": [128],
|
||||
"transformer_blocks.3.attn.norm_added_q.weight": [128],
|
||||
"transformer_blocks.3.attn.norm_k.weight": [128],
|
||||
"transformer_blocks.3.attn.norm_q.weight": [128],
|
||||
"transformer_blocks.3.attn.to_add_out.bias": [3072],
|
||||
"transformer_blocks.3.attn.to_add_out.weight": [3072, 3072],
|
||||
"transformer_blocks.3.attn.to_k.bias": [3072],
|
||||
"transformer_blocks.3.attn.to_k.weight": [3072, 3072],
|
||||
"transformer_blocks.3.attn.to_out.0.bias": [3072],
|
||||
"transformer_blocks.3.attn.to_out.0.weight": [3072, 3072],
|
||||
"transformer_blocks.3.attn.to_q.bias": [3072],
|
||||
"transformer_blocks.3.attn.to_q.weight": [3072, 3072],
|
||||
"transformer_blocks.3.attn.to_v.bias": [3072],
|
||||
"transformer_blocks.3.attn.to_v.weight": [3072, 3072],
|
||||
"transformer_blocks.3.ff.net.0.proj.bias": [12288],
|
||||
"transformer_blocks.3.ff.net.0.proj.weight": [12288, 3072],
|
||||
"transformer_blocks.3.ff.net.2.bias": [3072],
|
||||
"transformer_blocks.3.ff.net.2.weight": [3072, 12288],
|
||||
"transformer_blocks.3.ff_context.net.0.proj.bias": [12288],
|
||||
"transformer_blocks.3.ff_context.net.0.proj.weight": [12288, 3072],
|
||||
"transformer_blocks.3.ff_context.net.2.bias": [3072],
|
||||
"transformer_blocks.3.ff_context.net.2.weight": [3072, 12288],
|
||||
"transformer_blocks.3.norm1.linear.bias": [18432],
|
||||
"transformer_blocks.3.norm1.linear.weight": [18432, 3072],
|
||||
"transformer_blocks.3.norm1_context.linear.bias": [18432],
|
||||
"transformer_blocks.3.norm1_context.linear.weight": [18432, 3072],
|
||||
"transformer_blocks.4.attn.add_k_proj.bias": [3072],
|
||||
"transformer_blocks.4.attn.add_k_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.4.attn.add_q_proj.bias": [3072],
|
||||
"transformer_blocks.4.attn.add_q_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.4.attn.add_v_proj.bias": [3072],
|
||||
"transformer_blocks.4.attn.add_v_proj.weight": [3072, 3072],
|
||||
"transformer_blocks.4.attn.norm_added_k.weight": [128],
|
||||
"transformer_blocks.4.attn.norm_added_q.weight": [128],
|
||||
"transformer_blocks.4.attn.norm_k.weight": [128],
|
||||
"transformer_blocks.4.attn.norm_q.weight": [128],
|
||||
"transformer_blocks.4.attn.to_add_out.bias": [3072],
|
||||
"transformer_blocks.4.attn.to_add_out.weight": [3072, 3072],
|
||||
"transformer_blocks.4.attn.to_k.bias": [3072],
|
||||
"transformer_blocks.4.attn.to_k.weight": [3072, 3072],
|
||||
"transformer_blocks.4.attn.to_out.0.bias": [3072],
|
||||
"transformer_blocks.4.attn.to_out.0.weight": [3072, 3072],
|
||||
"transformer_blocks.4.attn.to_q.bias": [3072],
|
||||
"transformer_blocks.4.attn.to_q.weight": [3072, 3072],
|
||||
"transformer_blocks.4.attn.to_v.bias": [3072],
|
||||
"transformer_blocks.4.attn.to_v.weight": [3072, 3072],
|
||||
"transformer_blocks.4.ff.net.0.proj.bias": [12288],
|
||||
"transformer_blocks.4.ff.net.0.proj.weight": [12288, 3072],
|
||||
"transformer_blocks.4.ff.net.2.bias": [3072],
|
||||
"transformer_blocks.4.ff.net.2.weight": [3072, 12288],
|
||||
"transformer_blocks.4.ff_context.net.0.proj.bias": [12288],
|
||||
"transformer_blocks.4.ff_context.net.0.proj.weight": [12288, 3072],
|
||||
"transformer_blocks.4.ff_context.net.2.bias": [3072],
|
||||
"transformer_blocks.4.ff_context.net.2.weight": [3072, 12288],
|
||||
"transformer_blocks.4.norm1.linear.bias": [18432],
|
||||
"transformer_blocks.4.norm1.linear.weight": [18432, 3072],
|
||||
"transformer_blocks.4.norm1_context.linear.bias": [18432],
|
||||
"transformer_blocks.4.norm1_context.linear.weight": [18432, 3072],
|
||||
"x_embedder.bias": [3072],
|
||||
"x_embedder.weight": [3072, 64],
|
||||
}
|
||||
|
||||
|
||||
# InstantX FLUX ControlNet config for unit tests.
# Copied from https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/main/config.json
instantx_config = {
    "_class_name": "FluxControlNetModel",
    "_diffusers_version": "0.30.0.dev0",
    "_name_or_path": "/mnt/wangqixun/",
    "attention_head_dim": 128,
    "axes_dims_rope": [16, 56, 56],
    "guidance_embeds": True,
    "in_channels": 64,
    "joint_attention_dim": 4096,
    "num_attention_heads": 24,
    "num_layers": 5,
    "num_mode": 10,
    "num_single_layers": 10,
    "patch_size": 1,
    "pooled_projection_dim": 768,
}
tests/backend/flux/controlnet/test_state_dict_utils.py (new file, 108 lines)
@@ -0,0 +1,108 @@
import sys

import pytest
import torch

from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.state_dict_utils import (
    convert_diffusers_instantx_state_dict_to_bfl_format,
    infer_flux_params_from_state_dict,
    infer_instantx_num_control_modes_from_state_dict,
    is_state_dict_instantx_controlnet,
    is_state_dict_xlabs_controlnet,
)
from tests.backend.flux.controlnet.instantx_flux_controlnet_state_dict import instantx_config, instantx_sd_shapes
from tests.backend.flux.controlnet.xlabs_flux_controlnet_state_dict import xlabs_sd_shapes


@pytest.mark.parametrize(
    ["sd_shapes", "expected"],
    [
        (xlabs_sd_shapes, True),
        (instantx_sd_shapes, False),
        (["foo"], False),
    ],
)
def test_is_state_dict_xlabs_controlnet(sd_shapes: dict[str, list[int]], expected: bool):
    sd = {k: None for k in sd_shapes}
    assert is_state_dict_xlabs_controlnet(sd) == expected


@pytest.mark.parametrize(
    ["sd_keys", "expected"],
    [
        (instantx_sd_shapes, True),
        (xlabs_sd_shapes, False),
        (["foo"], False),
    ],
)
def test_is_state_dict_instantx_controlnet(sd_keys: list[str], expected: bool):
    sd = {k: None for k in sd_keys}
    assert is_state_dict_instantx_controlnet(sd) == expected


def test_convert_diffusers_instantx_state_dict_to_bfl_format():
    """Smoke test convert_diffusers_instantx_state_dict_to_bfl_format() to ensure that it handles all of the keys."""
    sd = {k: torch.zeros(1) for k in instantx_sd_shapes}
    bfl_sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
    assert bfl_sd is not None


# TODO(ryand): Figure out why some tests in this file are failing on the MacOS CI runners. It seems to be related to
# using the meta device. I can't reproduce the issue on my local MacOS system.


@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_infer_flux_params_from_state_dict():
    # Construct a dummy state_dict with tensors of the correct shape on the meta device.
    with torch.device("meta"):
        sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}

    sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
    flux_params = infer_flux_params_from_state_dict(sd)

    assert flux_params.in_channels == instantx_config["in_channels"]
    assert flux_params.vec_in_dim == instantx_config["pooled_projection_dim"]
    assert flux_params.context_in_dim == instantx_config["joint_attention_dim"]
    assert flux_params.hidden_size // flux_params.num_heads == instantx_config["attention_head_dim"]
    assert flux_params.num_heads == instantx_config["num_attention_heads"]
    assert flux_params.mlp_ratio == 4
    assert flux_params.depth == instantx_config["num_layers"]
    assert flux_params.depth_single_blocks == instantx_config["num_single_layers"]
    assert flux_params.axes_dim == instantx_config["axes_dims_rope"]
    assert flux_params.theta == 10000
    assert flux_params.qkv_bias
    assert flux_params.guidance_embed == instantx_config["guidance_embeds"]


@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_infer_instantx_num_control_modes_from_state_dict():
    # Construct a dummy state_dict with tensors of the correct shape on the meta device.
    with torch.device("meta"):
        sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}

    sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
    num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)

    assert num_control_modes == instantx_config["num_mode"]


@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_load_instantx_from_state_dict():
    # Construct a dummy state_dict with tensors of the correct shape on the meta device.
    with torch.device("meta"):
        sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}

    sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
    flux_params = infer_flux_params_from_state_dict(sd)
    num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)

    with torch.device("meta"):
        model = InstantXControlNetFlux(flux_params, num_control_modes)

    model_sd = model.state_dict()

    assert set(model_sd.keys()) == set(sd.keys())
    for key, tensor in model_sd.items():
        assert isinstance(tensor, torch.Tensor)
        assert tensor.shape == sd[key].shape
@@ -0,0 +1,91 @@
|
||||
# State dict keys and shapes for an XLabs FLUX ControlNet model. Intended to be used for unit tests.
|
||||
# These keys were extracted from:
|
||||
# https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
|
||||
xlabs_sd_shapes = {
|
||||
"controlnet_blocks.0.bias": [3072],
|
||||
"controlnet_blocks.0.weight": [3072, 3072],
|
||||
"controlnet_blocks.1.bias": [3072],
|
||||
"controlnet_blocks.1.weight": [3072, 3072],
|
||||
"double_blocks.0.img_attn.norm.key_norm.scale": [128],
|
||||
"double_blocks.0.img_attn.norm.query_norm.scale": [128],
|
||||
"double_blocks.0.img_attn.proj.bias": [3072],
|
||||
"double_blocks.0.img_attn.proj.weight": [3072, 3072],
|
||||
"double_blocks.0.img_attn.qkv.bias": [9216],
|
||||
"double_blocks.0.img_attn.qkv.weight": [9216, 3072],
|
||||
"double_blocks.0.img_mlp.0.bias": [12288],
|
||||
"double_blocks.0.img_mlp.0.weight": [12288, 3072],
|
||||
"double_blocks.0.img_mlp.2.bias": [3072],
|
||||
"double_blocks.0.img_mlp.2.weight": [3072, 12288],
|
||||
"double_blocks.0.img_mod.lin.bias": [18432],
|
||||
"double_blocks.0.img_mod.lin.weight": [18432, 3072],
|
||||
"double_blocks.0.txt_attn.norm.key_norm.scale": [128],
|
||||
"double_blocks.0.txt_attn.norm.query_norm.scale": [128],
|
||||
"double_blocks.0.txt_attn.proj.bias": [3072],
|
||||
"double_blocks.0.txt_attn.proj.weight": [3072, 3072],
|
||||
"double_blocks.0.txt_attn.qkv.bias": [9216],
|
||||
"double_blocks.0.txt_attn.qkv.weight": [9216, 3072],
|
||||
"double_blocks.0.txt_mlp.0.bias": [12288],
|
||||
"double_blocks.0.txt_mlp.0.weight": [12288, 3072],
|
||||
"double_blocks.0.txt_mlp.2.bias": [3072],
|
||||
"double_blocks.0.txt_mlp.2.weight": [3072, 12288],
|
||||
"double_blocks.0.txt_mod.lin.bias": [18432],
|
||||
"double_blocks.0.txt_mod.lin.weight": [18432, 3072],
|
||||
"double_blocks.1.img_attn.norm.key_norm.scale": [128],
|
||||
"double_blocks.1.img_attn.norm.query_norm.scale": [128],
|
||||
"double_blocks.1.img_attn.proj.bias": [3072],
|
||||
"double_blocks.1.img_attn.proj.weight": [3072, 3072],
|
||||
"double_blocks.1.img_attn.qkv.bias": [9216],
|
||||
"double_blocks.1.img_attn.qkv.weight": [9216, 3072],
|
||||
"double_blocks.1.img_mlp.0.bias": [12288],
|
||||
"double_blocks.1.img_mlp.0.weight": [12288, 3072],
|
||||
"double_blocks.1.img_mlp.2.bias": [3072],
|
||||
"double_blocks.1.img_mlp.2.weight": [3072, 12288],
|
||||
"double_blocks.1.img_mod.lin.bias": [18432],
|
||||
"double_blocks.1.img_mod.lin.weight": [18432, 3072],
|
||||
"double_blocks.1.txt_attn.norm.key_norm.scale": [128],
|
||||
"double_blocks.1.txt_attn.norm.query_norm.scale": [128],
|
||||
"double_blocks.1.txt_attn.proj.bias": [3072],
|
||||
"double_blocks.1.txt_attn.proj.weight": [3072, 3072],
|
||||
"double_blocks.1.txt_attn.qkv.bias": [9216],
|
||||
"double_blocks.1.txt_attn.qkv.weight": [9216, 3072],
|
||||
"double_blocks.1.txt_mlp.0.bias": [12288],
|
||||
"double_blocks.1.txt_mlp.0.weight": [12288, 3072],
|
||||
"double_blocks.1.txt_mlp.2.bias": [3072],
|
||||
"double_blocks.1.txt_mlp.2.weight": [3072, 12288],
|
||||
"double_blocks.1.txt_mod.lin.bias": [18432],
|
||||
"double_blocks.1.txt_mod.lin.weight": [18432, 3072],
|
||||
"guidance_in.in_layer.bias": [3072],
|
||||
"guidance_in.in_layer.weight": [3072, 256],
|
||||
"guidance_in.out_layer.bias": [3072],
|
||||
"guidance_in.out_layer.weight": [3072, 3072],
|
||||
"img_in.bias": [3072],
|
||||
"img_in.weight": [3072, 64],
|
||||
"input_hint_block.0.bias": [16],
|
||||
"input_hint_block.0.weight": [16, 3, 3, 3],
|
||||
"input_hint_block.10.bias": [16],
|
||||
"input_hint_block.10.weight": [16, 16, 3, 3],
|
||||
"input_hint_block.12.bias": [16],
|
||||
"input_hint_block.12.weight": [16, 16, 3, 3],
|
||||
"input_hint_block.14.bias": [16],
|
||||
"input_hint_block.14.weight": [16, 16, 3, 3],
|
||||
"input_hint_block.2.bias": [16],
|
||||
"input_hint_block.2.weight": [16, 16, 3, 3],
|
||||
"input_hint_block.4.bias": [16],
|
||||
"input_hint_block.4.weight": [16, 16, 3, 3],
|
||||
"input_hint_block.6.bias": [16],
|
||||
"input_hint_block.6.weight": [16, 16, 3, 3],
|
||||
"input_hint_block.8.bias": [16],
|
||||
"input_hint_block.8.weight": [16, 16, 3, 3],
|
||||
"pos_embed_input.bias": [3072],
|
||||
"pos_embed_input.weight": [3072, 64],
|
||||
"time_in.in_layer.bias": [3072],
|
||||
"time_in.in_layer.weight": [3072, 256],
|
||||
"time_in.out_layer.bias": [3072],
|
||||
"time_in.out_layer.weight": [3072, 3072],
|
||||
"txt_in.bias": [3072],
|
||||
"txt_in.weight": [3072, 4096],
|
||||
"vector_in.in_layer.bias": [3072],
|
||||
"vector_in.in_layer.weight": [3072, 768],
|
||||
"vector_in.out_layer.bias": [3072],
|
||||
"vector_in.out_layer.weight": [3072, 3072],
|
||||
}
|
||||