Compare commits

..

1 Commits

Author SHA1 Message Date
psychedelicious
143621dbfd chore: bump version to v6.0.0rc1 2025-07-02 00:07:29 +10:00
315 changed files with 13634 additions and 16572 deletions

View File

@@ -3,15 +3,15 @@ description: Installs frontend dependencies with pnpm, with caching
runs:
using: 'composite'
steps:
- name: setup node 20
- name: setup node 18
uses: actions/setup-node@v4
with:
node-version: '20'
node-version: '18'
- name: setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10
version: 8.15.6
run_install: false
- name: get pnpm store directory

View File

@@ -72,7 +72,7 @@ async def upload_image(
resize_to: Optional[str] = Body(
default=None,
description=f"Dimensions to resize the image to, must be stringified tuple of 2 integers. Max total pixel count: {ResizeToDimensions.MAX_SIZE}",
examples=['"[1024,1024]"'],
example='"[1024,1024]"',
),
metadata: Optional[str] = Body(
default=None,

View File

@@ -292,7 +292,7 @@ async def get_hugging_face_models(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
) -> AnyModelConfig:
"""Update a model's config."""
logger = ApiDependencies.invoker.services.logger
@@ -450,7 +450,7 @@ async def install_model(
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
config: ModelRecordChanges = Body(
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
examples=[{"name": "string", "description": "string"}],
example={"name": "string", "description": "string"},
),
) -> ModelInstallJob:
"""Install a model using a string identifier.

View File

@@ -1,6 +1,6 @@
from typing import Optional
from fastapi import Body, HTTPException, Path, Query
from fastapi import Body, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
@@ -22,7 +22,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemNotFoundError,
SessionQueueStatus,
)
from invokeai.app.services.shared.pagination import CursorPaginatedResults
@@ -60,12 +59,10 @@ async def enqueue_batch(
),
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
try:
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while enqueuing batch: {e}")
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
@session_queue_router.get(
@@ -85,17 +82,14 @@ async def list_queue_items(
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets cursor-paginated queue items"""
try:
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id,
limit=limit,
status=status,
cursor=cursor,
priority=priority,
destination=destination,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all items: {e}")
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id,
limit=limit,
status=status,
cursor=cursor,
priority=priority,
destination=destination,
)
@session_queue_router.get(
@@ -110,13 +104,11 @@ async def list_all_queue_items(
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> list[SessionQueueItem]:
"""Gets all queue items"""
try:
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
queue_id=queue_id,
destination=destination,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
queue_id=queue_id,
destination=destination,
)
@session_queue_router.put(
@@ -128,10 +120,7 @@ async def resume(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Resumes session processor"""
try:
return ApiDependencies.invoker.services.session_processor.resume()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while resuming queue: {e}")
return ApiDependencies.invoker.services.session_processor.resume()
@session_queue_router.put(
@@ -143,10 +132,7 @@ async def Pause(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Pauses session processor"""
try:
return ApiDependencies.invoker.services.session_processor.pause()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while pausing queue: {e}")
return ApiDependencies.invoker.services.session_processor.pause()
@session_queue_router.put(
@@ -158,10 +144,7 @@ async def cancel_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> CancelAllExceptCurrentResult:
"""Immediately cancels all queue items except in-processing items"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling all except current: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
@session_queue_router.put(
@@ -173,10 +156,7 @@ async def delete_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> DeleteAllExceptCurrentResult:
"""Immediately deletes all queue items except in-processing items"""
try:
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting all except current: {e}")
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
@session_queue_router.put(
@@ -189,12 +169,7 @@ async def cancel_by_batch_ids(
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
) -> CancelByBatchIDsResult:
"""Immediately cancels all queue items from the given batch ids"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(
queue_id=queue_id, batch_ids=batch_ids
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by batch id: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
@session_queue_router.put(
@@ -207,12 +182,9 @@ async def cancel_by_destination(
destination: str = Query(description="The destination to cancel all queue items for"),
) -> CancelByDestinationResult:
"""Immediately cancels all queue items with the given origin"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by destination: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
queue_id=queue_id, destination=destination
)
@session_queue_router.put(
@@ -225,10 +197,7 @@ async def retry_items_by_id(
item_ids: list[int] = Body(description="The queue item ids to retry"),
) -> RetryItemsResult:
"""Immediately cancels all queue items with the given origin"""
try:
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while retrying queue items: {e}")
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
@session_queue_router.put(
@@ -242,14 +211,11 @@ async def clear(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> ClearResult:
"""Clears the queue entirely, immediately canceling the currently-executing session"""
try:
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while clearing queue: {e}")
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
@session_queue_router.put(
@@ -263,10 +229,7 @@ async def prune(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> PruneResult:
"""Prunes all completed or errored queue items"""
try:
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while pruning queue: {e}")
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
@session_queue_router.get(
@@ -280,10 +243,7 @@ async def get_current_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the currently execution queue item"""
try:
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
@session_queue_router.get(
@@ -297,10 +257,7 @@ async def get_next_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the next queue item, without executing it"""
try:
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}")
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
@session_queue_router.get(
@@ -314,12 +271,9 @@ async def get_queue_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndProcessorStatus:
"""Gets the status of the session queue"""
try:
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting queue status: {e}")
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
@session_queue_router.get(
@@ -334,10 +288,7 @@ async def get_batch_status(
batch_id: str = Path(description="The batch to get the status of"),
) -> BatchStatus:
"""Gets the status of the session queue"""
try:
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}")
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
@session_queue_router.get(
@@ -353,12 +304,7 @@ async def get_queue_item(
item_id: int = Path(description="The queue item to get"),
) -> SessionQueueItem:
"""Gets a queue item"""
try:
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching queue item: {e}")
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
@session_queue_router.delete(
@@ -370,10 +316,7 @@ async def delete_queue_item(
item_id: int = Path(description="The queue item to delete"),
) -> None:
"""Deletes a queue item"""
try:
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting queue item: {e}")
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
@session_queue_router.put(
@@ -388,12 +331,8 @@ async def cancel_queue_item(
item_id: int = Path(description="The queue item to cancel"),
) -> SessionQueueItem:
"""Deletes a queue item"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling queue item: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
@session_queue_router.get(
@@ -406,12 +345,9 @@ async def counts_by_destination(
destination: str = Query(description="The destination to query"),
) -> SessionQueueCountsByDestination:
"""Gets the counts of queue items by destination"""
try:
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}")
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
)
@session_queue_router.delete(
@@ -424,9 +360,6 @@ async def delete_by_destination(
destination: str = Path(description="The destination to query"),
) -> DeleteByDestinationResult:
"""Deletes all items with the given destination"""
try:
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting by destination: {e}")
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
queue_id=queue_id, destination=destination
)

View File

@@ -391,29 +391,28 @@ class FluxDenoiseInvocation(BaseInvocation):
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
kontext_extension = KontextExtension(
kontext_field=self.kontext_conditioning,
context=context,
kontext_conditioning=self.kontext_conditioning,
vae_field=self.controlnet_vae,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
)
# Prepare Kontext conditioning if provided
img_cond_seq = None
img_cond_seq_ids = None
final_img, final_img_ids = x, img_ids
original_seq_len = x.shape[1]
if kontext_extension is not None:
# Ensure batch sizes match
kontext_extension.ensure_batch_size(x.shape[0])
img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids
final_img, final_img_ids = kontext_extension.apply(final_img, final_img_ids)
x = denoise(
model=transformer,
img=x,
img_ids=img_ids,
img=final_img,
img_ids=final_img_ids,
pos_regional_prompting_extension=pos_regional_prompting_extension,
neg_regional_prompting_extension=neg_regional_prompting_extension,
timesteps=timesteps,
step_callback=self._build_step_callback(context),
step_callback=self._build_step_callback(
context, original_seq_len if kontext_extension is not None else None
),
guidance=self.guidance,
cfg_scale=cfg_scale,
inpaint_extension=inpaint_extension,
@@ -421,10 +420,11 @@ class FluxDenoiseInvocation(BaseInvocation):
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
img_cond=img_cond,
img_cond_seq=img_cond_seq,
img_cond_seq_ids=img_cond_seq_ids,
)
if kontext_extension is not None:
x = x[:, :original_seq_len, :] # Keep only the first original_seq_len tokens
x = unpack(x.float(), self.height, self.width)
return x
@@ -895,11 +895,14 @@ class FluxDenoiseInvocation(BaseInvocation):
yield (lora_info.model, lora.weight)
del lora_info
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def _build_step_callback(
self, context: InvocationContext, original_seq_len: Optional[int] = None
) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
# The denoise function now handles Kontext conditioning correctly,
# so we don't need to slice the latents here
# Extract only main image tokens if Kontext conditioning was applied
latents = state.latents.float()
if original_seq_len is not None:
latents = latents[:, :original_seq_len, :]
state.latents = unpack(latents, self.height, self.width).squeeze()
context.util.flux_step_callback(state)

View File

@@ -404,8 +404,6 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = [queue_id] + batch_ids
cursor.execute(
@@ -444,8 +442,6 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = (queue_id, destination)
cursor.execute(
@@ -548,8 +544,6 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = [queue_id]
cursor.execute(
@@ -570,9 +564,12 @@ class SqliteSessionQueue(SessionQueueBase):
tuple(params),
)
self._conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
)
except Exception:
self._conn.rollback()
raise
@@ -743,7 +740,7 @@ class SqliteSessionQueue(SessionQueueBase):
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] or 0 for row in counts_result)
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueStatus(
queue_id=queue_id,
@@ -772,7 +769,7 @@ class SqliteSessionQueue(SessionQueueBase):
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in result)
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
destination = result[0]["destination"] if result else None
@@ -804,7 +801,7 @@ class SqliteSessionQueue(SessionQueueBase):
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in counts_result)
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueCountsByDestination(

View File

@@ -30,11 +30,8 @@ def denoise(
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
# extra img tokens (channel-wise)
# extra img tokens
img_cond: torch.Tensor | None,
# extra img tokens (sequence-wise) - for Kontext conditioning
img_cond_seq: torch.Tensor | None = None,
img_cond_seq_ids: torch.Tensor | None = None,
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -49,10 +46,6 @@ def denoise(
)
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
# Store original sequence length for slicing predictions
original_seq_len = img.shape[1]
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -78,26 +71,10 @@ def denoise(
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
# Prepare input for model - concatenate fresh each step
img_input = img
img_input_ids = img_ids
# Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
if img_cond is not None:
img_input = torch.cat((img_input, img_cond), dim=-1)
# Add sequence-wise conditioning (for Kontext)
if img_cond_seq is not None:
assert img_cond_seq_ids is not None, (
"You need to provide either both or neither of the sequence conditioning"
)
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
pred = model(
img=img_input,
img_ids=img_input_ids,
img=pred_img,
img_ids=img_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
@@ -111,10 +88,6 @@ def denoise(
regional_prompting_extension=pos_regional_prompting_extension,
)
# Slice prediction to only include the main image tokens
if img_input_ids is not None:
pred = pred[:, :original_seq_len]
step_cfg_scale = cfg_scale[step_index]
# If step_cfg_scale, is 1.0, then we don't need to run the negative prediction.

View File

@@ -1,15 +1,13 @@
import einops
import numpy as np
import torch
from einops import repeat
from PIL import Image
from invokeai.app.invocations.fields import FluxKontextConditioningField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.invocations.model import VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import pack
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
def generate_img_ids_with_offset(
@@ -73,7 +71,7 @@ class KontextExtension:
def __init__(
self,
kontext_conditioning: FluxKontextConditioningField,
kontext_field: FluxKontextConditioningField,
context: InvocationContext,
vae_field: VAEField,
device: torch.device,
@@ -87,49 +85,30 @@ class KontextExtension:
self._device = device
self._dtype = dtype
self._vae_field = vae_field
self.kontext_conditioning = kontext_conditioning
self.kontext_field = kontext_field
# Pre-process and cache the kontext latents and ids upon initialization.
self.kontext_latents, self.kontext_ids = self._prepare_kontext()
def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Encodes the reference image and prepares its latents and IDs."""
image = self._context.images.get_pil(self.kontext_conditioning.image.image_name)
image = self._context.images.get_pil(self.kontext_field.image.image_name)
# Calculate aspect ratio of input image
width, height = image.size
aspect_ratio = width / height
# Find the closest preferred resolution by aspect ratio
_, target_width, target_height = min(
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
)
# Apply BFL's scaling formula
# This ensures compatibility with the model's training
scaled_width = 2 * int(target_width / 16)
scaled_height = 2 * int(target_height / 16)
# Resize to the exact resolution used during training
image = image.convert("RGB")
final_width = 8 * scaled_width
final_height = 8 * scaled_height
image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
# Convert to tensor with same normalization as BFL
image_np = np.array(image)
image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0
image_tensor = einops.rearrange(image_tensor, "h w c -> 1 c h w")
# Reuse VAE encoding logic from FluxVaeEncodeInvocation
vae_info = self._context.models.load(self._vae_field.vae)
image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
image_tensor = image_tensor.to(self._device)
# Continue with VAE encoding
vae_info = self._context.models.load(self._vae_field.vae)
kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
# Extract tensor dimensions
# Extract tensor dimensions with descriptive names
# Latent tensor shape: [batch_size, channels, latent_height, latent_width]
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
# Pack the latents and generate IDs
# Pack the latents and generate IDs. The idx_offset distinguishes these
# tokens from the main image's tokens, which have an index of 0.
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
kontext_ids = generate_img_ids_with_offset(
latent_height=latent_height,
@@ -137,13 +116,24 @@ class KontextExtension:
batch_size=batch_size,
device=self._device,
dtype=self._dtype,
idx_offset=1,
idx_offset=1, # Distinguishes reference tokens from main image tokens
)
return kontext_latents_packed, kontext_ids
def ensure_batch_size(self, target_batch_size: int) -> None:
"""Ensures the kontext latents and IDs match the target batch size by repeating if necessary."""
if self.kontext_latents.shape[0] != target_batch_size:
self.kontext_latents = self.kontext_latents.repeat(target_batch_size, 1, 1)
self.kontext_ids = self.kontext_ids.repeat(target_batch_size, 1, 1)
def apply(
self,
img: torch.Tensor,
img_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Concatenates the pre-processed kontext data to the main image sequence."""
# Ensure batch sizes match, repeating kontext data if necessary for batch operations.
if img.shape[0] != self.kontext_latents.shape[0]:
self.kontext_latents = self.kontext_latents.repeat(img.shape[0], 1, 1)
self.kontext_ids = self.kontext_ids.repeat(img.shape[0], 1, 1)
# Concatenate along the sequence dimension (dim=1)
combined_img = torch.cat([img, self.kontext_latents], dim=1)
combined_img_ids = torch.cat([img_ids, self.kontext_ids], dim=1)
return combined_img, combined_img_ids

View File

@@ -174,13 +174,11 @@ def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtyp
dtype = torch.float16
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
# Set batch offset to 0 for main image tokens
img_ids[..., 0] = 0
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
if device.type == "mps":
img_ids = img_ids.to(orig_dtype)
img_ids.to(orig_dtype)
return img_ids

View File

@@ -18,29 +18,6 @@ class ModelSpec:
repo_ae: str | None
# Preferred resolutions for Kontext models to avoid tiling artifacts
# These are the specific resolutions the model was trained on
PREFERED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
(752, 1392),
(800, 1328),
(832, 1248),
(880, 1184),
(944, 1104),
(1024, 1024),
(1104, 944),
(1184, 880),
(1248, 832),
(1328, 800),
(1392, 752),
(1456, 720),
(1504, 688),
(1568, 672),
]
max_seq_lengths: Dict[str, Literal[256, 512]] = {
"flux-dev": 512,
"flux-dev-fill": 512,

View File

@@ -7,14 +7,7 @@ from typing import Optional
import accelerate
import torch
from safetensors.torch import load_file
from transformers import (
AutoConfig,
AutoModelForTextEncoding,
CLIPTextModel,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
@@ -146,7 +139,7 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
)
match submodel_type:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
@@ -190,7 +183,7 @@ class T5EncoderCheckpointModel(ModelLoader):
match submodel_type:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
return T5EncoderModel.from_pretrained(
Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True

View File

@@ -143,19 +143,11 @@ flux_dev = StarterModel(
flux_kontext = StarterModel(
name="FLUX.1 Kontext dev",
base=BaseModelType.Flux,
source="https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/resolve/main/flux1-kontext-dev.safetensors",
source="black-forest-labs/FLUX.1-Kontext-dev::flux1-kontext-dev.safetensors",
description="FLUX.1 Kontext dev transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_kontext_quantized = StarterModel(
name="FLUX.1 Kontext dev (Quantized)",
base=BaseModelType.Flux,
source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
sd35_medium = StarterModel(
name="SD3.5 Medium",
base=BaseModelType.StableDiffusion3,
@@ -672,7 +664,7 @@ flux_fill = StarterModel(
# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
flux_kontext_quantized,
flux_kontext,
flux_schnell_quantized,
flux_dev_quantized,
flux_schnell,
@@ -793,7 +785,7 @@ flux_bundle: list[StarterModel] = [
flux_depth_control_lora,
flux_redux,
flux_fill,
flux_kontext_quantized,
flux_kontext,
]
STARTER_BUNDLES: dict[str, StarterModelBundle] = {

View File

@@ -17,15 +17,6 @@ module.exports = {
'no-promise-executor-return': 'error',
// https://eslint.org/docs/latest/rules/require-await
'require-await': 'error',
// Restrict setActiveTab calls to only use-navigation-api.tsx
'no-restricted-syntax': [
'error',
{
selector: 'CallExpression[callee.name="setActiveTab"]',
message:
'setActiveTab() can only be called from use-navigation-api.tsx. Use navigationApi.switchToTab() instead.',
},
],
// TODO: ENABLE THIS RULE BEFORE v6.0.0
'react/display-name': 'off',
'no-restricted-properties': [
@@ -65,15 +56,6 @@ module.exports = {
],
},
overrides: [
/**
* Allow setActiveTab calls only in use-navigation-api.tsx
*/
{
files: ['**/use-navigation-api.tsx'],
rules: {
'no-restricted-syntax': 'off',
},
},
/**
* Overrides for stories
*/

View File

@@ -3,6 +3,8 @@ import type { KnipConfig } from 'knip';
const config: KnipConfig = {
project: ['src/**/*.{ts,tsx}!'],
ignore: [
// TODO(psyche): temporarily ignored all files for test build purposes
'src/**',
// This file is only used during debugging
'src/app/store/middleware/debugLoggerMiddleware.ts',
// Autogenerated types - shouldn't ever touch these
@@ -12,8 +14,10 @@ const config: KnipConfig = {
'src/features/parameters/types/parameterSchemas.ts',
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
'src/features/controlLayers/konva/util.ts',
// Will be using this
'src/common/hooks/useAsyncState.ts',
// TODO(psyche): restore HRF functionality?
'src/features/hrf/**',
// This feature is (temprarily?) disabled
'src/features/controlLayers/components/InpaintMask/InpaintMaskAddButtons.tsx',
],
ignoreBinaries: ['only-allow'],
paths: {

View File

@@ -38,6 +38,19 @@
"test:ui": "vitest --coverage --ui",
"test:no-watch": "vitest --no-watch"
},
"madge": {
"excludeRegExp": [
"^index.ts$"
],
"detectiveOptions": {
"ts": {
"skipTypeImports": true
},
"tsx": {
"skipTypeImports": true
}
}
},
"dependencies": {
"@atlaskit/pragmatic-drag-and-drop": "^1.7.4",
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^2.1.1",
@@ -134,7 +147,7 @@
"eslint": "^8.57.1",
"eslint-plugin-i18next": "^6.1.1",
"eslint-plugin-path": "^1.3.0",
"knip": "^5.61.3",
"knip": "^5.50.5",
"openapi-types": "^12.1.3",
"openapi-typescript": "^7.6.1",
"prettier": "^3.5.3",
@@ -143,7 +156,7 @@
"tsafe": "^1.8.5",
"type-fest": "^4.40.0",
"typescript": "^5.8.3",
"vite": "^7.0.2",
"vite": "^6.3.3",
"vite-plugin-css-injected-by-js": "^3.5.2",
"vite-plugin-dts": "^4.5.3",
"vite-plugin-eslint": "^1.8.1",
@@ -151,7 +164,7 @@
"vitest": "^3.1.2"
},
"engines": {
"pnpm": "10"
"pnpm": "8"
},
"packageManager": "pnpm@10.12.4"
"packageManager": "pnpm@8.15.9+sha512.499434c9d8fdd1a2794ebf4552b3b25c0a633abcee5bb15e7b5de90f32f47b513aca98cd5cfd001c31f0db454bc3804edccd578501e4ca293a6816166bbd9f81"
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +0,0 @@
onlyBuiltDependencies:
- '@swc/core'
- esbuild

View File

@@ -225,16 +225,7 @@
"prompt": {
"addPromptTrigger": "Add Prompt Trigger",
"compatibleEmbeddings": "Compatible Embeddings",
"noMatchingTriggers": "No matching triggers",
"generateFromImage": "Generate prompt from image",
"expandCurrentPrompt": "Expand Current Prompt",
"uploadImageForPromptGeneration": "Upload Image for Prompt Generation",
"expandingPrompt": "Expanding prompt...",
"resultTitle": "Prompt Expansion Complete",
"resultSubtitle": "Choose how to handle the expanded prompt:",
"replace": "Replace",
"insert": "Insert",
"discard": "Discard"
"noMatchingTriggers": "No matching triggers"
},
"queue": {
"queue": "Queue",
@@ -351,7 +342,7 @@
"copy": "Copy",
"currentlyInUse": "This image is currently in use in the following features:",
"drop": "Drop",
"dropOrUpload": "Drop or Upload",
"dropOrUpload": "$t(gallery.drop) or Upload",
"dropToUpload": "$t(gallery.drop) to Upload",
"deleteImage_one": "Delete Image",
"deleteImage_other": "Delete {{count}} Images",
@@ -405,8 +396,7 @@
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit.",
"openViewer": "Open Viewer",
"closeViewer": "Close Viewer",
"move": "Move",
"useForPromptGeneration": "Use for Prompt Generation"
"move": "Move"
},
"hotkeys": {
"hotkeys": "Hotkeys",
@@ -762,7 +752,7 @@
"vae": "VAE",
"width": "Width",
"workflow": "Workflow",
"canvasV2Metadata": "Canvas Layers"
"canvasV2Metadata": "Canvas"
},
"modelManager": {
"active": "active",
@@ -948,8 +938,7 @@
"selectModel": "Select a Model",
"noLoRAsInstalled": "No LoRAs installed",
"noRefinerModelsInstalled": "No SDXL Refiner models installed",
"defaultVAE": "Default VAE",
"noCompatibleLoRAs": "No Compatible LoRAs"
"defaultVAE": "Default VAE"
},
"nodes": {
"arithmeticSequence": "Arithmetic Sequence",
@@ -1199,9 +1188,7 @@
"canvasIsSelectingObject": "Canvas is busy (selecting object)",
"noPrompts": "No prompts generated",
"noNodesInGraph": "No nodes in graph",
"systemDisconnected": "System disconnected",
"promptExpansionPending": "Prompt expansion in progress",
"promptExpansionResultPending": "Please accept or discard your prompt expansion result"
"systemDisconnected": "System disconnected"
},
"maskBlur": "Mask Blur",
"negativePromptPlaceholder": "Negative Prompt",
@@ -1399,15 +1386,10 @@
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
"imagenIncompatibleGenerationMode": "Google {{model}} supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models Inpainting and Outpainting tasks.",
"fluxKontextIncompatibleGenerationMode": "FLUX Kontext does not support generation from images placed on the canvas. Re-try using the Reference Image section and disable any Raster Layers.",
"fluxKontextIncompatibleGenerationMode": "Flux Kontext supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
"workflowUnpublished": "Workflow Unpublished",
"sentToCanvas": "Sent to Canvas",
"sentToUpscale": "Sent to Upscale",
"promptGenerationStarted": "Prompt generation started",
"uploadAndPromptGenerationFailed": "Failed to upload image and generate prompt",
"promptExpansionFailed": "We ran into an issue. Please try prompt expansion again."
"workflowUnpublished": "Workflow Unpublished"
},
"popovers": {
"clipSkip": {
@@ -1962,7 +1944,6 @@
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"outputOnlyMaskedRegions": "Output Only Generated Regions",
"saveAllImagesToGallery": "Save All Images to Gallery",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",
@@ -2331,9 +2312,6 @@
"label": "Preserve Masked Region",
"alert": "Preserving Masked Region"
},
"saveAllImagesToGallery": {
"alert": "Saving All Images to Gallery"
},
"isolatedStagingPreview": "Isolated Staging Preview",
"isolatedPreview": "Isolated Preview",
"isolatedLayerPreview": "Isolated Layer Preview",
@@ -2362,7 +2340,6 @@
"newGlobalReferenceImage": "New Global Reference Image",
"newRegionalReferenceImage": "New Regional Reference Image",
"newControlLayer": "New Control Layer",
"newResizedControlLayer": "New Resized Control Layer",
"newRasterLayer": "New Raster Layer",
"newInpaintMask": "New Inpaint Mask",
"newRegionalGuidance": "New Regional Guidance",
@@ -2380,11 +2357,6 @@
"saveToGallery": "Save To Gallery",
"showResultsOn": "Showing Results",
"showResultsOff": "Hiding Results"
},
"autoSwitch": {
"off": "Off",
"switchOnStart": "On Start",
"switchOnFinish": "On Finish"
}
},
"upscaling": {
@@ -2560,9 +2532,8 @@
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"Generate images faster with new Launchpads and a simplified Generate tab.",
"Edit with prompts using Flux Kontext Dev.",
"Export to PSD, bulk-hide overlays, organize models & images — all in a reimagined interface built for control."
"Inpainting: Per-mask noise levels and denoise limits.",
"Canvas: Smarter aspect ratios for SDXL and improved scroll-to-zoom."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
@@ -2571,16 +2542,62 @@
"supportVideos": {
"supportVideos": "Support Videos",
"gettingStarted": "Getting Started",
"controlCanvas": "Control Canvas",
"watch": "Watch",
"studioSessionsDesc": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
"studioSessionsDesc1": "Check out the <StudioSessionsPlaylistLink /> for Invoke deep dives.",
"studioSessionsDesc2": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
"videos": {
"gettingStarted": {
"title": "Getting Started with Invoke",
"description": "Complete video series covering everything you need to know to get started with Invoke, from creating your first image to advanced techniques."
"creatingYourFirstImage": {
"title": "Creating Your First Image",
"description": "Introduction to creating an image from scratch using Invoke's tools."
},
"studioSessions": {
"title": "Studio Sessions",
"description": "Deep dive sessions exploring advanced Invoke features, creative workflows, and community discussions."
"usingControlLayersAndReferenceGuides": {
"title": "Using Control Layers and Reference Guides",
"description": "Learn how to guide your image creation with control layers and reference images."
},
"understandingImageToImageAndDenoising": {
"title": "Understanding Image-to-Image and Denoising",
"description": "Overview of image-to-image transformations and denoising in Invoke."
},
"exploringAIModelsAndConceptAdapters": {
"title": "Exploring AI Models and Concept Adapters",
"description": "Dive into AI models and how to use concept adapters for creative control."
},
"creatingAndComposingOnInvokesControlCanvas": {
"title": "Creating and Composing on Invoke's Control Canvas",
"description": "Learn to compose images using Invoke's control canvas."
},
"upscaling": {
"title": "Upscaling",
"description": "How to upscale images with Invoke's tools to enhance resolution."
},
"howDoIGenerateAndSaveToTheGallery": {
"title": "How Do I Generate and Save to the Gallery?",
"description": "Steps to generate and save images to the gallery."
},
"howDoIEditOnTheCanvas": {
"title": "How Do I Edit on the Canvas?",
"description": "Guide to editing images directly on the canvas."
},
"howDoIDoImageToImageTransformation": {
"title": "How Do I Do Image-to-Image Transformation?",
"description": "Tutorial on performing image-to-image transformations in Invoke."
},
"howDoIUseControlNetsAndControlLayers": {
"title": "How Do I Use Control Nets and Control Layers?",
"description": "Learn to apply control layers and controlnets to your images."
},
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
"title": "How Do I Use Global IP Adapters and Reference Images?",
"description": "Introduction to adding reference images and global IP adapters."
},
"howDoIUseInpaintMasks": {
"title": "How Do I Use Inpaint Masks?",
"description": "How to apply inpaint masks for image correction and variation."
},
"howDoIOutpaint": {
"title": "How Do I Outpaint?",
"description": "Guide to outpainting beyond the original image borders."
}
}
}

View File

@@ -2,7 +2,8 @@ import { Box } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $globalIsLoading } from 'app/store/nanostores/globalIsLoading';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import { useClearStorage } from 'common/hooks/useClearStorage';
@@ -11,7 +12,6 @@ import { memo, useCallback } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import ThemeLocaleProvider from './ThemeLocaleProvider';
const DEFAULT_CONFIG = {};
interface Props {
@@ -20,7 +20,7 @@ interface Props {
}
const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
const didStudioInit = useStore($didStudioInit);
const globalIsLoading = useStore($globalIsLoading);
const clearStorage = useClearStorage();
const handleReset = useCallback(() => {
@@ -31,14 +31,12 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
return (
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
<ThemeLocaleProvider>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{!didStudioInit && <Loading />}
</Box>
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
<GlobalModalIsolator />
</ThemeLocaleProvider>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{globalIsLoading && <Loading />}
</Box>
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
<GlobalModalIsolator />
</ErrorBoundary>
);
};

View File

@@ -1,5 +1,4 @@
import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
import { setupListeners } from '@reduxjs/toolkit/query';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { useStudioInitAction } from 'app/hooks/useStudioInitAction';
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
@@ -11,7 +10,7 @@ import type { PartialAppConfig } from 'app/types/invokeai';
import { useFocusRegionWatcher } from 'common/hooks/focus';
import { useCloseChakraTooltipsOnDragFix } from 'common/hooks/useCloseChakraTooltipsOnDragFix';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import { useDndMonitor } from 'features/dnd/useDndMonitor';
import { size } from 'es-toolkit/compat';
import { useDynamicPromptsWatcher } from 'features/dynamicPrompts/hooks/useDynamicPromptsWatcher';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { useWorkflowBuilderWatcher } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
@@ -46,7 +45,6 @@ export const GlobalHookIsolator = memo(
useSyncLoggingConfig();
useCloseChakraTooltipsOnDragFix();
useNavigationApi();
useDndMonitor();
// Persistent subscription to the queue counts query - canvas relies on this to know if there are pending
// and/or in progress canvas sessions.
@@ -57,18 +55,16 @@ export const GlobalHookIsolator = memo(
}, [language]);
useEffect(() => {
logger.info({ config }, 'Received config');
dispatch(configChanged(config));
if (size(config)) {
logger.info({ config }, 'Received config');
dispatch(configChanged(config));
}
}, [dispatch, config, logger]);
useEffect(() => {
dispatch(appStarted());
}, [dispatch]);
useEffect(() => {
return setupListeners(dispatch);
}, [dispatch]);
useStudioInitAction(studioInitAction);
useStarterModelsToast();
useSyncQueueStatus();

View File

@@ -1,14 +1,11 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useIsRegionFocused } from 'common/hooks/focus';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { useLoadWorkflow } from 'features/gallery/hooks/useLoadWorkflow';
import { useRecallAll } from 'features/gallery/hooks/useRecallAll';
import { useRecallDimensions } from 'features/gallery/hooks/useRecallDimensions';
import { useRecallPrompts } from 'features/gallery/hooks/useRecallPrompts';
import { useRecallRemix } from 'features/gallery/hooks/useRecallRemix';
import { useRecallSeed } from 'features/gallery/hooks/useRecallSeed';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
@@ -30,64 +27,59 @@ GlobalImageHotkeys.displayName = 'GlobalImageHotkeys';
const GlobalImageHotkeysInternal = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const isGalleryFocused = useIsRegionFocused('gallery');
const isViewerFocused = useIsRegionFocused('viewer');
const isFocusOK = isGalleryFocused || isViewerFocused;
const recallAll = useRecallAll(imageDTO);
const recallRemix = useRecallRemix(imageDTO);
const recallPrompts = useRecallPrompts(imageDTO);
const recallSeed = useRecallSeed(imageDTO);
const recallDimensions = useRecallDimensions(imageDTO);
const loadWorkflow = useLoadWorkflow(imageDTO);
const imageActions = useImageActions(imageDTO);
const isStaging = useAppSelector(selectIsStaging);
const isUpscalingEnabled = useFeatureStatus('upscaling');
useRegisteredHotkeys({
id: 'loadWorkflow',
category: 'viewer',
callback: loadWorkflow.load,
options: { enabled: loadWorkflow.isEnabled && isFocusOK },
dependencies: [loadWorkflow, isFocusOK],
callback: imageActions.loadWorkflow,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.loadWorkflow, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'recallAll',
category: 'viewer',
callback: recallAll.recall,
options: { enabled: recallAll.isEnabled && isFocusOK },
dependencies: [recallAll, isFocusOK],
callback: imageActions.recallAll,
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
dependencies: [imageActions.recallAll, isStaging, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'recallSeed',
category: 'viewer',
callback: recallSeed.recall,
options: { enabled: recallSeed.isEnabled && isFocusOK },
dependencies: [recallSeed, isFocusOK],
callback: imageActions.recallSeed,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.recallSeed, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'recallPrompts',
category: 'viewer',
callback: recallPrompts.recall,
options: { enabled: recallPrompts.isEnabled && isFocusOK },
dependencies: [recallPrompts, isFocusOK],
callback: imageActions.recallPrompts,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.recallPrompts, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'remix',
category: 'viewer',
callback: recallRemix.recall,
options: { enabled: recallRemix.isEnabled && isFocusOK },
dependencies: [recallRemix, isFocusOK],
callback: imageActions.remix,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.remix, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'useSize',
category: 'viewer',
callback: recallDimensions.recall,
options: { enabled: recallDimensions.isEnabled && isFocusOK },
dependencies: [recallDimensions, isFocusOK],
callback: imageActions.recallSize,
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
dependencies: [imageActions.recallSize, isStaging, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'runPostprocessing',
category: 'viewer',
callback: imageActions.upscale,
options: { enabled: isUpscalingEnabled && isViewerFocused },
dependencies: [isUpscalingEnabled, imageDTO, isViewerFocused],
});
return null;
});

View File

@@ -42,6 +42,7 @@ import { $socketOptions } from 'services/events/stores';
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
interface Props extends PropsWithChildren {
apiUrl?: string;
@@ -329,7 +330,9 @@ const InvokeAIUI = ({
<React.StrictMode>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<App config={config} studioInitAction={studioInitAction} />
<ThemeLocaleProvider>
<App config={config} studioInitAction={studioInitAction} />
</ThemeLocaleProvider>
</React.Suspense>
</Provider>
</React.StrictMode>

View File

@@ -8,7 +8,7 @@ import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { sentImageToCanvas } from 'features/gallery/store/actions';
import { MetadataUtils } from 'features/metadata/parsing';
import { parseAndRecallAllMetadata } from 'features/metadata/util/handlers';
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
import { $isWorkflowLibraryModalOpen } from 'features/nodes/store/workflowLibraryModal';
import {
@@ -19,9 +19,7 @@ import {
} from 'features/nodes/store/workflowLibrarySlice';
import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { LAUNCHPAD_PANEL_ID, WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { activeTabCanvasRightPanelChanged, setActiveTab } from 'features/ui/store/uiSlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { atom } from 'nanostores';
import { useCallback, useEffect } from 'react';
@@ -92,7 +90,6 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
const overrides: Partial<CanvasRasterLayerState> = {
objects: [imageObject],
};
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
store.dispatch(canvasReset());
store.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
store.dispatch(sentImageToCanvas());
@@ -119,23 +116,23 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
const metadata = getImageMetadataResult.value;
store.dispatch(canvasReset());
// This shows a toast
await MetadataUtils.recallAll(metadata, store);
await parseAndRecallAllMetadata(metadata, true);
},
[store, t]
);
const handleLoadWorkflow = useCallback(
(workflowId: string) => {
async (workflowId: string) => {
// This shows a toast
loadWorkflowWithDialog({
await loadWorkflowWithDialog({
type: 'library',
data: workflowId,
onSuccess: () => {
navigationApi.switchToTab('workflows');
store.dispatch(setActiveTab('workflows'));
},
});
},
[loadWorkflowWithDialog]
[loadWorkflowWithDialog, store]
);
const handleSelectStylePreset = useCallback(
@@ -149,7 +146,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
return;
}
store.dispatch(activeStylePresetIdChanged(stylePresetId));
navigationApi.switchToTab('canvas');
store.dispatch(setActiveTab('canvas'));
toast({
title: t('toast.stylePresetLoaded'),
status: 'info',
@@ -159,34 +156,33 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
);
const handleGoToDestination = useCallback(
async (destination: StudioDestinationAction['data']['destination']) => {
(destination: StudioDestinationAction['data']['destination']) => {
switch (destination) {
case 'generation':
// Go to the generate tab, open the launchpad
await navigationApi.focusPanel('generate', LAUNCHPAD_PANEL_ID);
// Go to the canvas tab, open the image viewer, and enable send-to-gallery mode
store.dispatch(paramsReset());
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
break;
case 'canvas':
// Go to the canvas tab, open the launchpad
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
// Go to the canvas tab, close the image viewer, and disable send-to-gallery mode
store.dispatch(canvasReset());
break;
case 'workflows':
// Go to the workflows tab
navigationApi.switchToTab('workflows');
store.dispatch(setActiveTab('workflows'));
break;
case 'upscaling':
// Go to the upscaling tab
navigationApi.switchToTab('upscaling');
store.dispatch(setActiveTab('upscaling'));
break;
case 'viewAllWorkflows':
// Go to the workflows tab and open the workflow library modal
navigationApi.switchToTab('workflows');
store.dispatch(setActiveTab('workflows'));
$isWorkflowLibraryModalOpen.set(true);
break;
case 'viewAllWorkflowsRecommended':
// Go to the workflows tab and open the workflow library modal with the recommended workflows view
navigationApi.switchToTab('workflows');
store.dispatch(setActiveTab('workflows'));
$isWorkflowLibraryModalOpen.set(true);
store.dispatch(workflowLibraryViewChanged('defaults'));
store.dispatch(workflowLibraryTagsReset());
@@ -198,7 +194,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
break;
case 'viewAllStylePresets':
// Go to the canvas tab and open the style presets menu
navigationApi.switchToTab('canvas');
store.dispatch(setActiveTab('canvas'));
$isStylePresetsMenuOpen.set(true);
break;
}

View File

@@ -1,6 +1,6 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectGetImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { selectListImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
@@ -20,7 +20,7 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
const board_id = selectSelectedBoardId(state);
const queryArgs = { ...selectGetImageNamesQueryArgs(state), board_id };
const queryArgs = { ...selectListImageNamesQueryArgs(state), board_id };
// wait until the board has some images - maybe it already has some from a previous fetch
// must use getState() to ensure we do not have stale state

View File

@@ -1,28 +1,14 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { bboxSyncedToOptimalDimension, rgRefImageModelChanged } from 'features/controlLayers/store/canvasSlice';
import { bboxSyncedToOptimalDimension } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { loraDeleted } from 'features/controlLayers/store/lorasSlice';
import { modelChanged, syncedToOptimalDimension, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { refImageModelChanged, selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
import {
selectAllEntitiesOfType,
selectBboxModelBase,
selectCanvasSlice,
} from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { modelChanged, vaeSelected } from 'features/controlLayers/store/paramsSlice';
import { selectBboxModelBase } from 'features/controlLayers/store/selectors';
import { modelSelected } from 'features/parameters/store/actions';
import { zParameterModel } from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { selectGlobalRefImageModels, selectRegionalRefImageModels } from 'services/api/hooks/modelsByType';
import type { AnyModelConfig } from 'services/api/types';
import {
isChatGPT4oModelConfig,
isFluxKontextApiModelConfig,
isFluxKontextModelConfig,
isFluxReduxModelConfig,
} from 'services/api/types';
const log = logger('models');
@@ -39,8 +25,9 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
const newModel = result.data;
const newBase = newModel.base;
const didBaseModelChange = state.params.model?.base !== newBase;
const newBaseModel = newModel.base;
const didBaseModelChange = state.params.model?.base !== newBaseModel;
if (didBaseModelChange) {
// we may need to reset some incompatible submodels
@@ -48,7 +35,7 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
// handle incompatible loras
state.loras.loras.forEach((lora) => {
if (lora.model.base !== newBase) {
if (lora.model.base !== newBaseModel) {
dispatch(loraDeleted({ id: lora.id }));
modelsCleared += 1;
}
@@ -56,82 +43,20 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
// handle incompatible vae
const { vae } = state.params;
if (vae && vae.base !== newBase) {
if (vae && vae.base !== newBaseModel) {
dispatch(vaeSelected(null));
modelsCleared += 1;
}
// Handle incompatible reference image models - switch to first compatible model, with some smart logic
// to choose the best available model based on the new main model.
const allRefImageModels = selectGlobalRefImageModels(state).filter(({ base }) => base === newBase);
let newGlobalRefImageModel = null;
// Certain models require the ref image model to be the same as the main model - others just need a matching
// base. Helper to grab the first exact match or the first available model if no exact match is found.
const exactMatchOrFirst = <T extends AnyModelConfig>(candidates: T[]): T | null =>
candidates.find(({ key }) => key === newModel.key) ?? candidates[0] ?? null;
// The only way we can differentiate between FLUX and FLUX Kontext is to check for "kontext" in the name
if (newModel.base === 'flux' && newModel.name.toLowerCase().includes('kontext')) {
const fluxKontextDevModels = allRefImageModels.filter(isFluxKontextModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextDevModels);
} else if (newModel.base === 'chatgpt-4o') {
const chatGPT4oModels = allRefImageModels.filter(isChatGPT4oModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(chatGPT4oModels);
} else if (newModel.base === 'flux-kontext') {
const fluxKontextApiModels = allRefImageModels.filter(isFluxKontextApiModelConfig);
newGlobalRefImageModel = exactMatchOrFirst(fluxKontextApiModels);
} else if (newModel.base === 'flux') {
const fluxReduxModels = allRefImageModels.filter(isFluxReduxModelConfig);
newGlobalRefImageModel = fluxReduxModels[0] ?? null;
} else {
newGlobalRefImageModel = allRefImageModels[0] ?? null;
}
// All ref image entities are updated to use the same new model
const refImageEntities = selectReferenceImageEntities(state);
for (const entity of refImageEntities) {
const shouldUpdateModel =
(entity.config.model && entity.config.model.base !== newBase) ||
(!entity.config.model && newGlobalRefImageModel);
if (shouldUpdateModel) {
dispatch(
refImageModelChanged({
id: entity.id,
modelConfig: newGlobalRefImageModel,
})
);
modelsCleared += 1;
}
}
// For regional guidance, there is no smart logic - we just pick the first available model.
const newRegionalRefImageModel = selectRegionalRefImageModels(state)[0] ?? null;
// All regional guidance entities are updated to use the same new model.
const canvasState = selectCanvasSlice(state);
const canvasRegionalGuidanceEntities = selectAllEntitiesOfType(canvasState, 'regional_guidance');
for (const entity of canvasRegionalGuidanceEntities) {
for (const refImage of entity.referenceImages) {
// Only change the model if the current one is not compatible with the new base model.
const shouldUpdateModel =
(refImage.config.model && refImage.config.model.base !== newBase) ||
(!refImage.config.model && newRegionalRefImageModel);
if (shouldUpdateModel) {
dispatch(
rgRefImageModelChanged({
entityIdentifier: getEntityIdentifier(entity),
referenceImageId: refImage.id,
modelConfig: newRegionalRefImageModel,
})
);
modelsCleared += 1;
}
}
}
// handle incompatible controlnets
// state.canvas.present.controlAdapters.entities.forEach((ca) => {
// if (ca.model?.base !== newBaseModel) {
// modelsCleared += 1;
// if (ca.isEnabled) {
// dispatch(entityIsEnabledToggled({ entityIdentifier: { id: ca.id, type: 'control_adapter' } }));
// }
// }
// });
if (modelsCleared > 0) {
toast({
@@ -146,16 +71,9 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
dispatch(modelChanged({ model: newModel, previousModel: state.params.model }));
const modelBase = selectBboxModelBase(state);
if (modelBase !== state.params.model?.base) {
// Sync generate tab settings whenever the model base changes
dispatch(syncedToOptimalDimension());
if (!selectIsStaging(state)) {
// Canvas tab only syncs if not staging
dispatch(bboxSyncedToOptimalDimension());
}
if (!selectIsStaging(state) && modelBase !== state.params.model?.base) {
dispatch(bboxSyncedToOptimalDimension());
}
},
});

View File

@@ -1,9 +1,7 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { isNil } from 'es-toolkit';
import { bboxHeightChanged, bboxWidthChanged } from 'features/controlLayers/store/canvasSlice';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import {
heightChanged,
setCfgRescaleMultiplier,
setCfgScale,
setGuidance,
@@ -11,7 +9,6 @@ import {
setSteps,
vaePrecisionChanged,
vaeSelected,
widthChanged,
} from 'features/controlLayers/store/paramsSlice';
import { setDefaultSettings } from 'features/parameters/store/actions';
import {
@@ -26,7 +23,6 @@ import {
zParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { t } from 'i18next';
import { modelConfigsAdapterSelectors, modelsApi } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';
@@ -90,16 +86,10 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
}
}
if (!isNil(cfg_rescale_multiplier)) {
if (cfg_rescale_multiplier) {
if (isParameterCFGRescaleMultiplier(cfg_rescale_multiplier)) {
dispatch(setCfgRescaleMultiplier(cfg_rescale_multiplier));
}
} else {
// Set this to 0 if it doesn't have a default. This value is
// easy to miss in the UI when users are resetting defaults
// and leaving it non-zero could lead to detrimental
// effects.
dispatch(setCfgRescaleMultiplier(0));
}
if (steps) {
@@ -116,24 +106,15 @@ export const addSetDefaultSettingsListener = (startAppListening: AppStartListeni
const setSizeOptions = { updateAspectRatio: true, clamp: true };
const isStaging = selectIsStaging(getState());
const activeTab = selectActiveTab(getState());
if (activeTab === 'generate') {
if (!isStaging && width) {
if (isParameterWidth(width)) {
dispatch(widthChanged({ width, ...setSizeOptions }));
}
if (isParameterHeight(height)) {
dispatch(heightChanged({ height, ...setSizeOptions }));
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
}
}
if (activeTab === 'canvas') {
if (!isStaging) {
if (isParameterWidth(width)) {
dispatch(bboxWidthChanged({ width, ...setSizeOptions }));
}
if (isParameterHeight(height)) {
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
}
if (!isStaging && height) {
if (isParameterHeight(height)) {
dispatch(bboxHeightChanged({ height, ...setSizeOptions }));
}
}

View File

@@ -0,0 +1,13 @@
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import { atom, computed } from 'nanostores';
import { flushSync } from 'react-dom';
export const $isLayoutLoading = atom(false);
export const setIsLayoutLoading = (isLoading: boolean) => {
flushSync(() => {
$isLayoutLoading.set(isLoading);
});
};
export const $globalIsLoading = computed([$didStudioInit, $isLayoutLoading], (didStudioInit, isLayoutLoading) => {
return !didStudioInit || isLayoutLoading;
});

View File

@@ -11,7 +11,5 @@ export const $false: ReadableAtom<boolean> = atom(false);
/**
* A fallback non-writable atom that always returns `true`, used when a nanostores atom is only conditionally available
* in a hook or component.
*
* @knipignore
*/
export const $true: ReadableAtom<boolean> = atom(true);

View File

@@ -17,6 +17,7 @@ import { paramsPersistConfig, paramsSlice } from 'features/controlLayers/store/p
import { refImagesPersistConfig, refImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { dynamicPromptsPersistConfig, dynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { galleryPersistConfig, gallerySlice } from 'features/gallery/store/gallerySlice';
import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice';
import { modelManagerV2PersistConfig, modelManagerV2Slice } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { nodesPersistConfig, nodesSlice, nodesUndoableConfig } from 'features/nodes/store/nodesSlice';
import { workflowLibraryPersistConfig, workflowLibrarySlice } from 'features/nodes/store/workflowLibrarySlice';
@@ -56,6 +57,7 @@ const allReducers = {
[changeBoardModalSlice.name]: changeBoardModalSlice.reducer,
[modelManagerV2Slice.name]: modelManagerV2Slice.reducer,
[queueSlice.name]: queueSlice.reducer,
[hrfSlice.name]: hrfSlice.reducer,
[canvasSlice.name]: undoable(canvasSlice.reducer, canvasUndoableConfig),
[workflowSettingsSlice.name]: workflowSettingsSlice.reducer,
[upscaleSlice.name]: upscaleSlice.reducer,
@@ -101,6 +103,7 @@ const persistConfigs: { [key in keyof typeof allReducers]?: PersistConfig } = {
[uiPersistConfig.name]: uiPersistConfig,
[dynamicPromptsPersistConfig.name]: dynamicPromptsPersistConfig,
[modelManagerV2PersistConfig.name]: modelManagerV2PersistConfig,
[hrfPersistConfig.name]: hrfPersistConfig,
[canvasPersistConfig.name]: canvasPersistConfig,
[workflowSettingsPersistConfig.name]: workflowSettingsPersistConfig,
[upscalePersistConfig.name]: upscalePersistConfig,

View File

@@ -14,7 +14,6 @@ export type AppFeature =
| 'githubLink'
| 'discordLink'
| 'bugLink'
| 'aboutModal'
| 'localization'
| 'consoleLogging'
| 'dynamicPrompting'
@@ -30,8 +29,7 @@ export type AppFeature =
| 'hfToken'
| 'retryQueueItem'
| 'cancelAndClearAll'
| 'chatGPT4oHigh'
| 'modelRelationships';
| 'chatGPT4oHigh';
/**
* A disable-able Stable Diffusion feature
*/
@@ -78,7 +76,6 @@ export type AppConfig = {
allowPrivateStylePresets: boolean;
allowClientSideUpload: boolean;
allowPublishWorkflows: boolean;
allowPromptExpansion: boolean;
disabledTabs: TabName[];
disabledFeatures: AppFeature[];
disabledSDFeatures: SDFeature[];

View File

@@ -8,16 +8,21 @@ const Loading = () => {
return (
<Flex
position="absolute"
width="100dvw"
height="100dvh"
alignItems="center"
justifyContent="center"
bg="hsl(220 12% 10% / 1)" // base.900
inset={0}
bg="#151519"
top={0}
right={0}
bottom={0}
left={0}
zIndex={99999}
>
<Image src={InvokeLogoWhite} w="8rem" h="8rem" />
<Spinner
label="Loading"
color="hsl(220 12% 68% / 1)" // base.300
color="grey"
position="absolute"
size="sm"
width="24px !important"

View File

@@ -87,10 +87,14 @@ export const buildGroup = <T extends object>(group: Omit<Group<T>, typeof unique
[uniqueGroupKey]: true,
});
export const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is Group<T> => {
const isGroup = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is Group<T> => {
return uniqueGroupKey in optionOrGroup && optionOrGroup[uniqueGroupKey] === true;
};
export const isOption = <T extends object>(optionOrGroup: OptionOrGroup<T>): optionOrGroup is T => {
return !(uniqueGroupKey in optionOrGroup);
};
const DefaultOptionComponent = typedMemo(<T extends object>({ option }: { option: T }) => {
const { getOptionId } = usePickerContext();
return <Text fontWeight="bold">{getOptionId(option)}</Text>;

View File

@@ -1,15 +1,20 @@
import { MenuItem } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import {
useNewCanvasSession,
useNewGallerySession,
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
import { allEntitiesDeleted } from 'features/controlLayers/store/canvasSlice';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsCounterClockwiseBold } from 'react-icons/pi';
import { PiArrowsCounterClockwiseBold, PiFilePlusBold } from 'react-icons/pi';
export const SessionMenuItems = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { newGallerySessionWithDialog } = useNewGallerySession();
const { newCanvasSessionWithDialog } = useNewCanvasSession();
const resetCanvasLayers = useCallback(() => {
dispatch(allEntitiesDeleted());
}, [dispatch]);
@@ -18,6 +23,12 @@ export const SessionMenuItems = memo(() => {
}, [dispatch]);
return (
<>
<MenuItem icon={<PiFilePlusBold />} onClick={newGallerySessionWithDialog}>
{t('controlLayers.newGallerySession')}
</MenuItem>
<MenuItem icon={<PiFilePlusBold />} onClick={newCanvasSessionWithDialog}>
{t('controlLayers.newCanvasSession')}
</MenuItem>
<MenuItem icon={<PiArrowsCounterClockwiseBold />} onClick={resetCanvasLayers}>
{t('controlLayers.resetCanvasLayers')}
</MenuItem>

View File

@@ -1,115 +0,0 @@
import { useStore } from '@nanostores/react';
import { WrappedError } from 'common/util/result';
import type { Atom } from 'nanostores';
import { atom } from 'nanostores';
import { useCallback, useEffect, useMemo, useState } from 'react';
type SuccessState<T> = {
status: 'success';
value: T;
error: null;
};
type ErrorState = {
status: 'error';
value: null;
error: Error;
};
type PendingState = {
status: 'pending';
value: null;
error: null;
};
type IdleState = {
status: 'idle';
value: null;
error: null;
};
export type State<T> = IdleState | PendingState | SuccessState<T> | ErrorState;
type UseAsyncStateOptions = {
immediate?: boolean;
};
type UseAsyncReturn<T> = {
$state: Atom<State<T>>;
trigger: () => Promise<void>;
reset: () => void;
};
export const useAsyncState = <T>(execute: () => Promise<T>, options?: UseAsyncStateOptions): UseAsyncReturn<T> => {
const $state = useState(() =>
atom<State<T>>({
status: 'idle',
value: null,
error: null,
})
)[0];
const trigger = useCallback(async () => {
$state.set({
status: 'pending',
value: null,
error: null,
});
try {
const value = await execute();
$state.set({
status: 'success',
value,
error: null,
});
} catch (error) {
$state.set({
status: 'error',
value: null,
error: WrappedError.wrap(error),
});
}
}, [$state, execute]);
const reset = useCallback(() => {
$state.set({
status: 'idle',
value: null,
error: null,
});
}, [$state]);
useEffect(() => {
if (options?.immediate) {
trigger();
}
}, [options?.immediate, trigger]);
const api = useMemo(
() =>
({
$state,
trigger,
reset,
}) satisfies UseAsyncReturn<T>,
[$state, trigger, reset]
);
return api;
};
type UseAsyncReturnReactive<T> = {
state: State<T>;
trigger: () => Promise<void>;
reset: () => void;
};
export const useAsyncStateReactive = <T>(
execute: () => Promise<T>,
options?: UseAsyncStateOptions
): UseAsyncReturnReactive<T> => {
const { $state, trigger, reset } = useAsyncState(execute, options);
const state = useStore($state);
return { state, trigger, reset };
};

View File

@@ -73,7 +73,7 @@ export const useBoolean = (initialValue: boolean): UseBoolean => {
};
};
type UseDisclosure = {
export type UseDisclosure = {
isOpen: boolean;
open: () => void;
close: () => void;

View File

@@ -0,0 +1,165 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/**
* Adapted from https://github.com/chakra-ui/chakra-ui/blob/v2/packages/hooks/src/use-outside-click.ts
*
* The main change here is to support filtering of outside clicks via a `filter` function.
*
* This lets us work around issues with portals and components like popovers, which typically close on an outside click.
*
* For example, consider a popover that has a custom drop-down component inside it, which uses a portal to render
* the drop-down options. The original outside click handler would close the popover when clicking on the drop-down options,
* because the click is outside the popover - but we expect the popover to stay open in this case.
*
* A filter function like this can fix that:
*
* ```ts
* const filter = (el: HTMLElement) => el.className.includes('chakra-portal') || el.id.includes('react-select')
* ```
*
* This ignores clicks on react-select-based drop-downs and Chakra UI portals and is used as the default filter.
*/
import { useCallback, useEffect, useRef } from 'react';
type FilterFunction = (el: HTMLElement | SVGElement) => boolean;
export function useCallbackRef<T extends (...args: any[]) => any>(
callback: T | undefined,
deps: React.DependencyList = []
) {
const callbackRef = useRef(callback);
useEffect(() => {
callbackRef.current = callback;
});
// eslint-disable-next-line react-hooks/exhaustive-deps
return useCallback(((...args) => callbackRef.current?.(...args)) as T, deps);
}
export interface UseOutsideClickProps {
/**
* Whether the hook is enabled
*/
enabled?: boolean;
/**
* The reference to a DOM element.
*/
ref: React.RefObject<HTMLElement | null>;
/**
* Function invoked when a click is triggered outside the referenced element.
*/
handler?: (e: Event) => void;
/**
* A function that filters the elements that should be considered as outside clicks.
*
* If omitted, a default filter function that ignores clicks in Chakra UI portals and react-select components is used.
*/
filter?: FilterFunction;
}
export const DEFAULT_FILTER: FilterFunction = (el) => {
if (el instanceof SVGElement) {
// SVGElement's type appears to be incorrect. Its className is not a string, which causes `includes` to fail.
// Let's assume that SVG elements with a class name are not part of the portal and should not be filtered.
return false;
}
return el.className.includes('chakra-portal') || el.id.includes('react-select');
};
/**
* Example, used in components like Dialogs and Popovers, so they can close
* when a user clicks outside them.
*/
export function useFilterableOutsideClick(props: UseOutsideClickProps) {
const { ref, handler, enabled = true, filter = DEFAULT_FILTER } = props;
const savedHandler = useCallbackRef(handler);
const stateRef = useRef({
isPointerDown: false,
ignoreEmulatedMouseEvents: false,
});
const state = stateRef.current;
useEffect(() => {
if (!enabled) {
return;
}
const onPointerDown: any = (e: PointerEvent) => {
if (isValidEvent(e, ref, filter)) {
state.isPointerDown = true;
}
};
const onMouseUp: any = (event: MouseEvent) => {
if (state.ignoreEmulatedMouseEvents) {
state.ignoreEmulatedMouseEvents = false;
return;
}
if (state.isPointerDown && handler && isValidEvent(event, ref)) {
state.isPointerDown = false;
savedHandler(event);
}
};
const onTouchEnd = (event: TouchEvent) => {
state.ignoreEmulatedMouseEvents = true;
if (handler && state.isPointerDown && isValidEvent(event, ref)) {
state.isPointerDown = false;
savedHandler(event);
}
};
const doc = getOwnerDocument(ref.current);
doc.addEventListener('mousedown', onPointerDown, true);
doc.addEventListener('mouseup', onMouseUp, true);
doc.addEventListener('touchstart', onPointerDown, true);
doc.addEventListener('touchend', onTouchEnd, true);
return () => {
doc.removeEventListener('mousedown', onPointerDown, true);
doc.removeEventListener('mouseup', onMouseUp, true);
doc.removeEventListener('touchstart', onPointerDown, true);
doc.removeEventListener('touchend', onTouchEnd, true);
};
}, [handler, ref, savedHandler, state, enabled, filter]);
}
function isValidEvent(event: Event, ref: React.RefObject<HTMLElement | null>, filter?: FilterFunction): boolean {
const target = (event.composedPath?.()[0] ?? event.target) as HTMLElement;
if (target) {
const doc = getOwnerDocument(target);
if (!doc.contains(target)) {
return false;
}
}
if (ref.current?.contains(target)) {
return false;
}
// This is the main logic change from the original hook.
if (filter) {
// Check if the click is inside an element matching the filter.
// This is used for portal-awareness or other general exclusion cases.
let currentElement: HTMLElement | null = target;
// Traverse up the DOM tree from the target element.
while (currentElement && currentElement !== document.body) {
if (filter(currentElement)) {
return false;
}
currentElement = currentElement.parentElement;
}
}
// If the click is not inside the ref and not inside a portal, it's a valid outside click.
return true;
}
function getOwnerDocument(node?: Element | null): Document {
return node?.ownerDocument ?? document;
}

View File

@@ -6,7 +6,7 @@ import { useDeleteCurrentQueueItem } from 'features/queue/hooks/useDeleteCurrent
import { useInvoke } from 'features/queue/hooks/useInvoke';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { getFocusedRegion } from './focus';
@@ -69,7 +69,7 @@ export const useGlobalHotkeys = () => {
id: 'selectGenerateTab',
category: 'app',
callback: () => {
navigationApi.switchToTab('generate');
dispatch(setActiveTab('generate'));
},
dependencies: [dispatch],
});
@@ -78,7 +78,7 @@ export const useGlobalHotkeys = () => {
id: 'selectCanvasTab',
category: 'app',
callback: () => {
navigationApi.switchToTab('canvas');
dispatch(setActiveTab('canvas'));
},
dependencies: [dispatch],
});
@@ -87,7 +87,7 @@ export const useGlobalHotkeys = () => {
id: 'selectUpscalingTab',
category: 'app',
callback: () => {
navigationApi.switchToTab('upscaling');
dispatch(setActiveTab('upscaling'));
},
dependencies: [dispatch],
});
@@ -96,7 +96,7 @@ export const useGlobalHotkeys = () => {
id: 'selectWorkflowsTab',
category: 'app',
callback: () => {
navigationApi.switchToTab('workflows');
dispatch(setActiveTab('workflows'));
},
dependencies: [dispatch],
});
@@ -105,7 +105,7 @@ export const useGlobalHotkeys = () => {
id: 'selectModelsTab',
category: 'app',
callback: () => {
navigationApi.switchToTab('models');
dispatch(setActiveTab('models'));
},
options: {
enabled: isModelManagerEnabled,
@@ -117,7 +117,7 @@ export const useGlobalHotkeys = () => {
id: 'selectQueueTab',
category: 'app',
callback: () => {
navigationApi.switchToTab('queue');
dispatch(setActiveTab('queue'));
},
dependencies: [dispatch, isModelManagerEnabled],
});

View File

@@ -21,15 +21,11 @@ type UseImageUploadButtonArgs =
isDisabled?: boolean;
allowMultiple: false;
onUpload?: (imageDTO: ImageDTO) => void;
onUploadStarted?: (files: File) => void;
onError?: (error: unknown) => void;
}
| {
isDisabled?: boolean;
allowMultiple: true;
onUpload?: (imageDTOs: ImageDTO[]) => void;
onUploadStarted?: (files: File[]) => void;
onError?: (error: unknown) => void;
};
const log = logger('gallery');
@@ -53,13 +49,7 @@ const log = logger('gallery');
* <Button {...getUploadButtonProps()} /> // will open the file dialog on click
* <input {...getUploadInputProps()} /> // hidden, handles native upload functionality
*/
export const useImageUploadButton = ({
onUpload,
isDisabled,
allowMultiple,
onUploadStarted,
onError,
}: UseImageUploadButtonArgs) => {
export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: UseImageUploadButtonArgs) => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const isClientSideUploadEnabled = useAppSelector(selectIsClientSideUploadEnabled);
const [uploadImage, request] = useUploadImageMutation();
@@ -81,7 +71,6 @@ export const useImageUploadButton = ({
}
const file = files[0];
assert(file !== undefined); // should never happen
onUploadStarted?.(file);
const imageDTO = await uploadImage({
file,
image_category: 'user',
@@ -93,8 +82,6 @@ export const useImageUploadButton = ({
onUpload(imageDTO);
}
} else {
onUploadStarted?.(files);
let imageDTOs: ImageDTO[] = [];
if (isClientSideUploadEnabled && files.length > 1) {
imageDTOs = await Promise.all(files.map((file, i) => clientSideUpload(file, i)));
@@ -115,7 +102,6 @@ export const useImageUploadButton = ({
}
}
} catch (error) {
onError?.(error);
toast({
id: 'UPLOAD_FAILED',
title: t('toast.imageUploadFailed'),
@@ -123,17 +109,7 @@ export const useImageUploadButton = ({
});
}
},
[
allowMultiple,
onUploadStarted,
uploadImage,
autoAddBoardId,
onUpload,
isClientSideUploadEnabled,
clientSideUpload,
onError,
t,
]
[allowMultiple, autoAddBoardId, onUpload, uploadImage, isClientSideUploadEnabled, clientSideUpload, t]
);
const onDropRejected = useCallback(

View File

@@ -0,0 +1,132 @@
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { EMPTY_ARRAY } from 'app/store/constants';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import type { GroupBase } from 'chakra-react-select';
import { uniq } from 'es-toolkit/compat';
import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetRelatedModelIdsBatchQuery } from 'services/api/endpoints/modelRelationships';
import type { AnyModelConfig } from 'services/api/types';
import { useGroupedModelCombobox } from './useGroupedModelCombobox';
type UseRelatedGroupedModelComboboxArg<T extends AnyModelConfig> = {
modelConfigs: T[];
selectedModel?: ModelIdentifierField | null;
onChange: (value: T | null) => void;
getIsDisabled?: (model: T) => boolean;
isLoading?: boolean;
groupByType?: boolean;
};
// Custom hook to overlay the grouped model combobox with related models on top!
// Cleaner than hooking into useGroupedModelCombobox with a flag to enable/disable the related models
// Also allows for related models to be shown conditionally with some pretty simple logic if it ends up as a config flag.
type UseRelatedGroupedModelComboboxReturn = {
value: ComboboxOption | undefined | null;
options: GroupBase<ComboboxOption>[];
onChange: ComboboxOnChange;
placeholder: string;
noOptionsMessage: () => string;
};
const selectSelectedModelKeys = createMemoizedSelector(selectParamsSlice, selectLoRAsSlice, (params, loras) => {
const keys: string[] = [];
const main = params.model;
const vae = params.vae;
const refiner = params.refinerModel;
const controlnet = params.controlLora;
if (main) {
keys.push(main.key);
}
if (vae) {
keys.push(vae.key);
}
if (refiner) {
keys.push(refiner.key);
}
if (controlnet) {
keys.push(controlnet.key);
}
for (const { model } of loras.loras) {
keys.push(model.key);
}
return uniq(keys);
});
export function useRelatedGroupedModelCombobox<T extends AnyModelConfig>({
modelConfigs,
selectedModel,
onChange,
isLoading = false,
getIsDisabled,
groupByType,
}: UseRelatedGroupedModelComboboxArg<T>): UseRelatedGroupedModelComboboxReturn {
const { t } = useTranslation();
const selectedKeys = useAppSelector(selectSelectedModelKeys);
const { relatedKeys } = useGetRelatedModelIdsBatchQuery(selectedKeys, {
selectFromResult: ({ data }) => {
if (!data) {
return { relatedKeys: EMPTY_ARRAY };
}
return { relatedKeys: data };
},
});
// Base grouped options
const base = useGroupedModelCombobox({
modelConfigs,
selectedModel,
onChange,
getIsDisabled,
isLoading,
groupByType,
});
const options = useMemo(() => {
if (relatedKeys.length === 0) {
return base.options;
}
const relatedOptions: ComboboxOption[] = [];
const updatedGroups: GroupBase<ComboboxOption>[] = [];
for (const group of base.options) {
const remainingOptions: ComboboxOption[] = [];
for (const option of group.options) {
if (relatedKeys.includes(option.value)) {
relatedOptions.push({ ...option, label: `* ${option.label}` });
} else {
remainingOptions.push(option);
}
}
if (remainingOptions.length > 0) {
updatedGroups.push({
label: group.label,
options: remainingOptions,
});
}
}
if (relatedOptions.length > 0) {
return [{ label: t('modelManager.relatedModels'), options: relatedOptions }, ...updatedGroups];
} else {
return updatedGroups;
}
}, [base.options, relatedKeys, t]);
return {
...base,
options,
};
}

View File

@@ -0,0 +1,28 @@
import type { Selector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import type { Atom, WritableAtom } from 'nanostores';
import { atom } from 'nanostores';
import { useEffect, useState } from 'react';
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
export const useSelectorAsAtom = <T extends Selector<RootState, any>>(selector: T): Atom<ReturnType<T>> => {
const store = useAppStore();
const $atom = useState<WritableAtom<ReturnType<T>>>(() => atom<ReturnType<T>>(selector(store.getState())))[0];
useEffect(() => {
const unsubscribe = store.subscribe(() => {
const prev = $atom.get();
const next = selector(store.getState());
if (prev !== next) {
$atom.set(next);
}
});
return () => {
unsubscribe();
};
}, [$atom, selector, store]);
return $atom;
};

View File

@@ -1,20 +0,0 @@
export type Deferred<T> = {
promise: Promise<T>;
resolve: (value: T) => void;
reject: (error: Error) => void;
};
/**
* Create a promise and expose its resolve and reject callbacks.
*/
export const createDeferredPromise = <T>(): Deferred<T> => {
let resolve!: (value: T) => void;
let reject!: (error: Error) => void;
const promise = new Promise<T>((res, rej) => {
resolve = res;
reject = rej;
});
return { promise, resolve, reject };
};

View File

@@ -0,0 +1,6 @@
/**
* Get the keys of an object. This is a wrapper around `Object.keys` that types the result as an array of the keys of the object.
* @param obj The object to get the keys of.
* @returns The keys of the object.
*/
export const objectKeys = <T extends Record<string, unknown>>(obj: T) => Object.keys(obj) as Array<keyof T>;

View File

@@ -57,7 +57,7 @@ export class Err<E> {
* @template T The type of the value in the `Ok` case.
* @template E The type of the error in the `Err` case.
*/
type Result<T, E = Error> = Ok<T> | Err<E>;
export type Result<T, E = Error> = Ok<T> | Err<E>;
/**
* Creates a successful result.

View File

@@ -1,23 +0,0 @@
import { Alert, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectSaveAllImagesToGallery } from 'features/controlLayers/store/canvasSettingsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasAlertsSaveAllImagesToGallery = memo(() => {
const { t } = useTranslation();
const saveAllImagesToGallery = useAppSelector(selectSaveAllImagesToGallery);
if (!saveAllImagesToGallery) {
return null;
}
return (
<Alert status="info" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<AlertIcon />
<AlertTitle>{t('controlLayers.settings.saveAllImagesToGallery.alert')}</AlertTitle>
</Alert>
);
});
CanvasAlertsSaveAllImagesToGallery.displayName = 'CanvasAlertsSaveAllImagesToGallery';

View File

@@ -1,4 +1,3 @@
import type { SpinnerProps } from '@invoke-ai/ui-library';
import { Spinner } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
@@ -6,7 +5,7 @@ import { useAllEntityAdapters } from 'features/controlLayers/contexts/EntityAdap
import { computed } from 'nanostores';
import { memo, useMemo } from 'react';
export const CanvasBusySpinner = memo((props: SpinnerProps) => {
export const CanvasBusySpinner = memo(() => {
const canvasManager = useCanvasManager();
const allEntityAdapters = useAllEntityAdapters();
const $isPendingRectCalculation = useMemo(
@@ -22,7 +21,7 @@ export const CanvasBusySpinner = memo((props: SpinnerProps) => {
const isCompositing = useStore(canvasManager.compositor.$isBusy);
if (isRasterizing || isCompositing || isPendingRectCalculation) {
return <Spinner opacity={0.3} {...props} />;
return <Spinner opacity={0.3} />;
}
return null;
});

View File

@@ -12,10 +12,6 @@ const addControlLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.
const addRegionalGuidanceReferenceImageFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'regional_guidance_with_reference_image',
});
const addResizedControlLayerFromImageDndTargetData = newCanvasEntityFromImageDndTarget.getData({
type: 'control_layer',
withResize: true,
});
export const CanvasDropArea = memo(() => {
const { t } = useTranslation();
@@ -49,6 +45,7 @@ export const CanvasDropArea = memo(() => {
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
@@ -57,14 +54,6 @@ export const CanvasDropArea = memo(() => {
isDisabled={isBusy}
/>
</GridItem>
<GridItem position="relative">
<DndDropTarget
dndTarget={newCanvasEntityFromImageDndTarget}
dndTargetData={addResizedControlLayerFromImageDndTargetData}
label={t('controlLayers.canvasContextMenu.newResizedControlLayer')}
isDisabled={isBusy}
/>
</GridItem>
</Grid>
</>
);

View File

@@ -0,0 +1,27 @@
// import { Button, Flex } from '@invoke-ai/ui-library';
// import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
// import { useAddInpaintMaskDenoiseLimit, useAddInpaintMaskNoise } from 'features/controlLayers/hooks/addLayerHooks';
// import { useTranslation } from 'react-i18next';
// import { PiPlusBold } from 'react-icons/pi';
// Removed buttons because denosie limit is not helpful for many architectures
// Users can access with right click menu instead.
// If buttons for noise or new features are deemed important in the future, add them back here.
export const InpaintMaskAddButtons = () => {
// Buttons are temporarily hidden. To restore, uncomment the code below.
return null;
// const entityIdentifier = useEntityIdentifierContext('inpaint_mask');
// const { t } = useTranslation();
// const addInpaintMaskDenoiseLimit = useAddInpaintMaskDenoiseLimit(entityIdentifier);
// const addInpaintMaskNoise = useAddInpaintMaskNoise(entityIdentifier);
// return (
// <Flex w="full" p={2} justifyContent="center">
// <Button size="sm" variant="ghost" leftIcon={<PiPlusBold />} onClick={addInpaintMaskDenoiseLimit}>
// {t('controlLayers.denoiseLimit')}
// </Button>
// <Button size="sm" variant="ghost" leftIcon={<PiPlusBold />} onClick={addInpaintMaskNoise}>
// {t('controlLayers.imageNoise')}
// </Button>
// </Flex>
// );
};

View File

@@ -14,7 +14,7 @@ import { useTranslation } from 'react-i18next';
const [useNewGallerySessionDialog] = buildUseBoolean(false);
const [useNewCanvasSessionDialog] = buildUseBoolean(false);
const useNewGallerySession = () => {
export const useNewGallerySession = () => {
const dispatch = useAppDispatch();
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const newSessionDialog = useNewGallerySessionDialog();
@@ -35,7 +35,7 @@ const useNewGallerySession = () => {
return { newGallerySessionImmediate, newGallerySessionWithDialog };
};
const useNewCanvasSession = () => {
export const useNewCanvasSession = () => {
const dispatch = useAppDispatch();
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const newSessionDialog = useNewCanvasSessionDialog();

View File

@@ -4,17 +4,9 @@ import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity';
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import {
refImageDeleted,
refImageIsEnabledToggled,
selectRefImageEntityIds,
} from 'features/controlLayers/store/refImagesSlice';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import { refImageDeleted, selectRefImageEntityIds } from 'features/controlLayers/store/refImagesSlice';
import { memo, useCallback, useMemo } from 'react';
import { PiCircleBold, PiCircleFill, PiTrashBold, PiWarningBold } from 'react-icons/pi';
import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';
import { PiTrashBold } from 'react-icons/pi';
const textSx: SystemStyleObject = {
color: 'base.300',
@@ -32,63 +24,25 @@ export const RefImageHeader = memo(() => {
);
const refImageNumber = useAppSelector(selectRefImageNumber);
const entity = useRefImageEntity(id);
const mainModelConfig = useAppSelector(selectMainModelConfig);
const warnings = useMemo(() => {
return getGlobalReferenceImageWarnings(entity, mainModelConfig);
}, [entity, mainModelConfig]);
const deleteRefImage = useCallback(() => {
dispatch(refImageDeleted({ id }));
}, [dispatch, id]);
const toggleIsEnabled = useCallback(() => {
dispatch(refImageIsEnabledToggled({ id }));
}, [dispatch, id]);
return (
<Flex justifyContent="space-between" alignItems="center" w="full" ps={2}>
<Text fontWeight="semibold" sx={textSx} data-is-error={!entity.config.image}>
Reference Image #{refImageNumber}
</Text>
<Flex alignItems="center" gap={1}>
{warnings.length > 0 && (
<IconButton
as="span"
size="sm"
variant="link"
alignSelf="stretch"
aria-label="warnings"
tooltip={<RefImageWarningTooltipContent warnings={warnings} />}
icon={<PiWarningBold />}
colorScheme="warning"
/>
)}
{!entity.isEnabled && (
<Text fontSize="xs" fontStyle="italic" color="base.400">
Disabled
</Text>
)}
<IconButton
tooltip={entity.isEnabled ? 'Disable Reference Image' : 'Enable Reference Image'}
size="xs"
variant="link"
alignSelf="stretch"
aria-label={entity.isEnabled ? 'Disable ref image' : 'Enable ref image'}
onClick={toggleIsEnabled}
icon={entity.isEnabled ? <PiCircleFill /> : <PiCircleBold />}
/>
<IconButton
tooltip="Delete Reference Image"
size="xs"
variant="link"
alignSelf="stretch"
aria-label="Delete ref image"
onClick={deleteRefImage}
icon={<PiTrashBold />}
colorScheme="error"
/>
</Flex>
<IconButton
tooltip="Delete Reference Image"
size="xs"
variant="link"
alignSelf="stretch"
aria-label="Delete ref image"
onClick={deleteRefImage}
icon={<PiTrashBold />}
colorScheme="error"
/>
</Flex>
);
});

View File

@@ -61,7 +61,7 @@ export const RefImageImage = memo(
)}
{imageDTO && (
<>
<DndImage imageDTO={imageDTO} borderRadius="base" borderWidth={1} borderStyle="solid" w="full" />
<DndImage imageDTO={imageDTO} borderWidth={1} borderStyle="solid" w="full" />
<Flex position="absolute" flexDir="column" top={2} insetInlineEnd={2} gap={1}>
<DndImageIcon
onClick={handleResetControlImage}

View File

@@ -1,12 +1,9 @@
import { Button, Collapse, Divider, Flex, IconButton } from '@invoke-ai/ui-library';
import { Button, Collapse, Divider, Flex } from '@invoke-ai/ui-library';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
import { RefImagePreview } from 'features/controlLayers/components/RefImage/RefImagePreview';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { RefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { getDefaultRefImageConfig } from 'features/controlLayers/hooks/addLayerHooks';
import { useNewGlobalReferenceImageFromBbox } from 'features/controlLayers/hooks/saveCanvasHooks';
import { useCanvasIsBusySafe } from 'features/controlLayers/hooks/useCanvasIsBusy';
import {
refImageAdded,
selectIsRefImagePanelOpen,
@@ -16,10 +13,8 @@ import {
import { imageDTOToImageWithDims } from 'features/controlLayers/store/util';
import { addGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiBoundingBoxBold, PiUploadBold } from 'react-icons/pi';
import { PiUploadBold } from 'react-icons/pi';
import type { ImageDTO } from 'services/api/types';
import { RefImageHeader } from './RefImageHeader';
@@ -83,7 +78,6 @@ MaxRefImages.displayName = 'MaxRefImages';
const AddRefImageDropTargetAndButton = memo(() => {
const { dispatch, getState } = useAppStore();
const tab = useAppSelector(selectActiveTab);
const uploadOptions = useMemo(
() =>
@@ -101,7 +95,7 @@ const AddRefImageDropTargetAndButton = memo(() => {
const uploadApi = useImageUploadButton(uploadOptions);
return (
<Flex gap={1} h="full" w="full">
<>
<Button
position="relative"
size="sm"
@@ -118,31 +112,7 @@ const AddRefImageDropTargetAndButton = memo(() => {
<input {...uploadApi.getUploadInputProps()} />
<DndDropTarget label="Drop" dndTarget={addGlobalReferenceImageDndTarget} dndTargetData={dndTargetData} />
</Button>
{tab === 'canvas' && (
<CanvasManagerProviderGate>
<BboxButton />
</CanvasManagerProviderGate>
)}
</Flex>
);
});
const BboxButton = memo(() => {
const { t } = useTranslation();
const isBusy = useCanvasIsBusySafe();
const newGlobalReferenceImageFromBbox = useNewGlobalReferenceImageFromBbox();
return (
<IconButton
size="lg"
variant="outline"
h="full"
icon={<PiBoundingBoxBold />}
onClick={newGlobalReferenceImageFromBbox}
isDisabled={isBusy}
aria-label={t('controlLayers.pullBboxIntoReferenceImage')}
tooltip={t('controlLayers.pullBboxIntoReferenceImage')}
/>
</>
);
});
AddRefImageDropTargetAndButton.displayName = 'AddRefImageDropTargetAndButton';

View File

@@ -1,35 +1,25 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import { areBasesCompatibleForRefImage } from 'features/controlLayers/store/validators';
import { selectBase } from 'features/controlLayers/store/paramsSlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGlobalReferenceImageModels } from 'services/api/hooks/modelsByType';
import type {
ChatGPT4oModelConfig,
FLUXKontextModelConfig,
FLUXReduxModelConfig,
IPAdapterModelConfig,
} from 'services/api/types';
import type { AnyModelConfig, ApiModelConfig, FLUXReduxModelConfig, IPAdapterModelConfig } from 'services/api/types';
type Props = {
modelKey: string | null;
onChangeModel: (
modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig
) => void;
onChangeModel: (modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => void;
};
export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => {
const { t } = useTranslation();
const mainModelConfig = useAppSelector(selectMainModelConfig);
const currentBaseModel = useAppSelector(selectBase);
const [modelConfigs, { isLoading }] = useGlobalReferenceImageModels();
const selectedModel = useMemo(() => modelConfigs.find((m) => m.key === modelKey), [modelConfigs, modelKey]);
const _onChangeModel = useCallback(
(
modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig | null
) => {
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null) => {
if (!modelConfig) {
return;
}
@@ -39,10 +29,12 @@ export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => {
);
const getIsDisabled = useCallback(
(model: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig): boolean => {
return !areBasesCompatibleForRefImage(mainModelConfig, model);
(model: AnyModelConfig): boolean => {
const hasMainModel = Boolean(currentBaseModel);
const hasSameBase = currentBaseModel === model.base;
return !hasMainModel || !hasSameBase;
},
[mainModelConfig]
[currentBaseModel]
);
const { options, value, onChange, noOptionsMessage } = useGroupedModelCombobox({
@@ -55,11 +47,7 @@ export const RefImageModel = memo(({ modelKey, onChangeModel }: Props) => {
return (
<Tooltip label={selectedModel?.description}>
<FormControl
isInvalid={!value || !areBasesCompatibleForRefImage(mainModelConfig, selectedModel)}
w="full"
minW={0}
>
<FormControl isInvalid={!value || currentBaseModel !== selectedModel?.base} w="full" minW={0}>
<Combobox
options={options}
placeholder={t('common.placeholderSelectAModel')}

View File

@@ -1,41 +1,24 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Icon, IconButton, Image, Skeleton, Text, Tooltip } from '@invoke-ai/ui-library';
import { Flex, Icon, IconButton, Image, Skeleton, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { round } from 'es-toolkit/compat';
import { useRefImageEntity } from 'features/controlLayers/components/RefImage/useRefImageEntity';
import { useRefImageIdContext } from 'features/controlLayers/contexts/RefImageIdContext';
import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import {
refImageSelected,
selectIsRefImagePanelOpen,
selectSelectedRefEntityId,
} from 'features/controlLayers/store/refImagesSlice';
import { isIPAdapterConfig } from 'features/controlLayers/store/types';
import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { PiExclamationMarkBold, PiEyeSlashBold, PiImageBold } from 'react-icons/pi';
import { PiExclamationMarkBold, PiImageBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { RefImageWarningTooltipContent } from './RefImageWarningTooltipContent';
const baseSx: SystemStyleObject = {
'&[data-is-open="true"]': {
borderColor: 'invokeBlue.300',
},
'&[data-is-disabled="true"]': {
img: {
opacity: 0.4,
filter: 'grayscale(100%)',
},
},
'&[data-is-error="true"]': {
borderColor: 'error.500',
img: {
opacity: 0.4,
filter: 'grayscale(100%)',
},
},
};
const weightDisplaySx: SystemStyleObject = {
@@ -68,7 +51,6 @@ export const RefImagePreview = memo(() => {
const dispatch = useAppDispatch();
const id = useRefImageIdContext();
const entity = useRefImageEntity(id);
const mainModelConfig = useAppSelector(selectMainModelConfig);
const selectedEntityId = useAppSelector(selectSelectedRefEntityId);
const isPanelOpen = useAppSelector(selectIsRefImagePanelOpen);
const [showWeightDisplay, setShowWeightDisplay] = useState(false);
@@ -94,10 +76,6 @@ export const RefImagePreview = memo(() => {
};
}, [entity.config]);
const warnings = useMemo(() => {
return getGlobalReferenceImageWarnings(entity, mainModelConfig);
}, [entity, mainModelConfig]);
const onClick = useCallback(() => {
dispatch(refImageSelected({ id }));
}, [dispatch, id]);
@@ -119,82 +97,66 @@ export const RefImagePreview = memo(() => {
flexShrink={0}
data-is-open={selectedEntityId === id && isPanelOpen}
data-is-error={true}
data-is-disabled={!entity.isEnabled}
sx={sx}
/>
);
}
return (
<Tooltip label={warnings.length > 0 ? <RefImageWarningTooltipContent warnings={warnings} /> : undefined}>
<Flex
position="relative"
borderWidth={1}
borderStyle="solid"
borderRadius="base"
<Flex
position="relative"
borderWidth={1}
borderStyle="solid"
borderRadius="base"
aspectRatio="1/1"
maxW="full"
maxH="full"
flexShrink={0}
sx={sx}
data-is-open={selectedEntityId === id && isPanelOpen}
data-is-error={!entity.config.model}
role="button"
onClick={onClick}
cursor="pointer"
>
<Image
src={imageDTO?.thumbnail_url}
objectFit="contain"
aspectRatio="1/1"
height={imageDTO?.height}
fallback={<Skeleton h="full" aspectRatio="1/1" />}
maxW="full"
maxH="full"
flexShrink={0}
sx={sx}
data-is-open={selectedEntityId === id && isPanelOpen}
data-is-error={warnings.length > 0}
data-is-disabled={!entity.isEnabled}
role="button"
onClick={onClick}
cursor="pointer"
overflow="hidden"
>
<Image
src={imageDTO?.thumbnail_url}
objectFit="contain"
aspectRatio="1/1"
height={imageDTO?.height}
fallback={<Skeleton h="full" aspectRatio="1/1" />}
maxW="full"
maxH="full"
borderRadius="base"
/>
{isIPAdapterConfig(entity.config) && (
<Flex
position="absolute"
inset={0}
fontWeight="semibold"
alignItems="center"
justifyContent="center"
zIndex={1}
data-visible={showWeightDisplay}
sx={weightDisplaySx}
>
<Text filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))">
{`${round(entity.config.weight * 100, 2)}%`}
</Text>
</Flex>
)}
{!entity.config.model && (
<Icon
position="absolute"
top="50%"
left="50%"
transform="translateX(-50%) translateY(-50%)"
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
color="error.500"
boxSize={16}
as={PiExclamationMarkBold}
/>
{isIPAdapterConfig(entity.config) && (
<Flex
position="absolute"
inset={0}
fontWeight="semibold"
alignItems="center"
justifyContent="center"
zIndex={1}
data-visible={showWeightDisplay}
sx={weightDisplaySx}
>
<Text filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))">
{`${round(entity.config.weight * 100, 2)}%`}
</Text>
</Flex>
)}
{!entity.isEnabled && (
<Icon
position="absolute"
top="50%"
left="50%"
transform="translateX(-50%) translateY(-50%)"
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
color="base.300"
boxSize={8}
as={PiEyeSlashBold}
/>
)}
{entity.isEnabled && warnings.length > 0 && (
<Icon
position="absolute"
top="50%"
left="50%"
transform="translateX(-50%) translateY(-50%)"
filter="drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 2px rgba(0, 0, 0, 1))"
color="error.500"
boxSize={12}
as={PiExclamationMarkBold}
/>
)}
</Flex>
</Tooltip>
)}
</Flex>
);
});
RefImagePreview.displayName = 'RefImagePreview';

View File

@@ -38,13 +38,7 @@ import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { memo, useCallback, useMemo } from 'react';
import type {
ChatGPT4oModelConfig,
FLUXKontextModelConfig,
FLUXReduxModelConfig,
ImageDTO,
IPAdapterModelConfig,
} from 'services/api/types';
import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { RefImageImage } from './RefImageImage';
@@ -90,7 +84,7 @@ const RefImageSettingsContent = memo(() => {
);
const onChangeModel = useCallback(
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig) => {
(modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig) => {
dispatch(refImageModelChanged({ id, modelConfig }));
},
[dispatch, id]

View File

@@ -1,18 +0,0 @@
import { Flex, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
import { upperFirst } from 'es-toolkit/compat';
import { useTranslation } from 'react-i18next';
export const RefImageWarningTooltipContent = ({ warnings }: { warnings: string[] }) => {
const { t } = useTranslation();
return (
<Flex flexDir="column">
<Text fontWeight="semibold">Invalid Reference Image:</Text>
<UnorderedList>
{warnings.map((tKey) => (
<ListItem key={tKey}>{upperFirst(t(tKey))}</ListItem>
))}
</UnorderedList>
</Flex>
);
};

View File

@@ -83,7 +83,7 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
</Flex>
<Flex alignItems="center" gap={2} p={4}>
<Text textAlign="center" color="base.300">
<Trans i18nKey="controlLayers.referenceImageEmptyStateWithCanvasOptions" components={components} />
<Trans i18nKey="controlLayers.referenceImageEmptyStateWithCanvasTab" components={components} />
</Text>
</Flex>
<input {...uploadApi.getUploadInputProps()} />

View File

@@ -26,7 +26,6 @@ import { CanvasSettingsPreserveMaskCheckbox } from 'features/controlLayers/compo
import { CanvasSettingsPressureSensitivityCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsPressureSensitivity';
import { CanvasSettingsRecalculateRectsButton } from 'features/controlLayers/components/Settings/CanvasSettingsRecalculateRectsButton';
import { CanvasSettingsRuleOfThirdsSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsRuleOfThirdsGuideSwitch';
import { CanvasSettingsSaveAllImagesToGalleryCheckbox } from 'features/controlLayers/components/Settings/CanvasSettingsSaveAllImagesToGalleryCheckbox';
import { CanvasSettingsShowHUDSwitch } from 'features/controlLayers/components/Settings/CanvasSettingsShowHUDSwitch';
import { CanvasSettingsShowProgressOnCanvas } from 'features/controlLayers/components/Settings/CanvasSettingsShowProgressOnCanvasSwitch';
import { memo } from 'react';
@@ -62,7 +61,6 @@ export const CanvasSettingsPopover = memo(() => {
<CanvasSettingsPreserveMaskCheckbox />
<CanvasSettingsClipToBboxCheckbox />
<CanvasSettingsOutputOnlyMaskedRegionsCheckbox />
<CanvasSettingsSaveAllImagesToGalleryCheckbox />
</Flex>
<Divider />

View File

@@ -1,25 +0,0 @@
import { Checkbox, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectSaveAllImagesToGallery,
settingsSaveAllImagesToGalleryToggled,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
export const CanvasSettingsSaveAllImagesToGalleryCheckbox = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const saveAllImagesToGallery = useAppSelector(selectSaveAllImagesToGallery);
const onChange = useCallback(() => {
dispatch(settingsSaveAllImagesToGalleryToggled());
}, [dispatch]);
return (
<FormControl w="full">
<FormLabel flexGrow={1}>{t('controlLayers.saveAllImagesToGallery')}</FormLabel>
<Checkbox isChecked={saveAllImagesToGallery} onChange={onChange} />
</FormControl>
);
});
CanvasSettingsSaveAllImagesToGalleryCheckbox.displayName = 'CanvasSettingsSaveAllImagesToGalleryCheckbox';

View File

@@ -1,4 +1,5 @@
import { Button, Flex, Grid, Text } from '@invoke-ai/ui-library';
import { Button, Flex, Grid, Heading, Text } from '@invoke-ai/ui-library';
import { useAutoLayoutContext } from 'features/ui/layouts/auto-layout-context';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
import { memo, useCallback } from 'react';
@@ -6,41 +7,44 @@ import { useTranslation } from 'react-i18next';
import { InitialStateMainModelPicker } from './InitialStateMainModelPicker';
import { LaunchpadAddStyleReference } from './LaunchpadAddStyleReference';
import { LaunchpadContainer } from './LaunchpadContainer';
import { LaunchpadEditImageButton } from './LaunchpadEditImageButton';
import { LaunchpadGenerateFromTextButton } from './LaunchpadGenerateFromTextButton';
import { LaunchpadUseALayoutImageButton } from './LaunchpadUseALayoutImageButton';
export const CanvasLaunchpadPanel = memo(() => {
const { t } = useTranslation();
const { tab } = useAutoLayoutContext();
const focusCanvas = useCallback(() => {
navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
}, []);
navigationApi.focusPanelInTab(tab, WORKSPACE_PANEL_ID);
}, [tab]);
return (
<LaunchpadContainer heading={t('ui.launchpad.canvasTitle')}>
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
<InitialStateMainModelPicker />
<Flex flexDir="column" gap={2} justifyContent="center">
<Text>
{t('ui.launchpad.modelGuideText')}{' '}
<Button
as="a"
variant="link"
href="https://support.invoke.ai/support/solutions/articles/151000216086-model-guide"
target="_blank"
rel="noopener noreferrer"
size="sm"
>
{t('ui.launchpad.modelGuideLink')}
</Button>
</Text>
<Flex flexDir="column" h="full" w="full" alignItems="center" gap={2}>
<Flex flexDir="column" w="full" gap={4} px={14} maxW={768} pt="20vh">
<Heading mb={4}>{t('ui.launchpad.canvasTitle')}</Heading>
<Flex flexDir="column" gap={8}>
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
<InitialStateMainModelPicker />
<Flex flexDir="column" gap={2} justifyContent="center">
<Text>
{t('ui.launchpad.modelGuideText')}{' '}
<Button
as="a"
variant="link"
href="https://support.invoke.ai/support/solutions/articles/151000216086-model-guide"
size="sm"
>
{t('ui.launchpad.modelGuideLink')}
</Button>
</Text>
</Flex>
</Grid>
<LaunchpadGenerateFromTextButton extraAction={focusCanvas} />
<LaunchpadAddStyleReference extraAction={focusCanvas} />
<LaunchpadEditImageButton extraAction={focusCanvas} />
<LaunchpadUseALayoutImageButton extraAction={focusCanvas} />
</Flex>
</Grid>
<LaunchpadGenerateFromTextButton extraAction={focusCanvas} />
<LaunchpadAddStyleReference extraAction={focusCanvas} />
<LaunchpadEditImageButton extraAction={focusCanvas} />
<LaunchpadUseALayoutImageButton extraAction={focusCanvas} />
</LaunchpadContainer>
</Flex>
</Flex>
);
});
CanvasLaunchpadPanel.displayName = 'CanvasLaunchpadPanel';

View File

@@ -1,48 +1,52 @@
import { Alert, Button, Flex, Grid, Text } from '@invoke-ai/ui-library';
import { Alert, Button, Flex, Grid, Heading, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { InitialStateMainModelPicker } from 'features/controlLayers/components/SimpleSession/InitialStateMainModelPicker';
import { LaunchpadAddStyleReference } from 'features/controlLayers/components/SimpleSession/LaunchpadAddStyleReference';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { setActiveTab } from 'features/ui/store/uiSlice';
import { memo, useCallback } from 'react';
import { LaunchpadContainer } from './LaunchpadContainer';
import { LaunchpadGenerateFromTextButton } from './LaunchpadGenerateFromTextButton';
export const GenerateLaunchpadPanel = memo(() => {
const dispatch = useAppDispatch();
const newCanvasSession = useCallback(() => {
navigationApi.switchToTab('canvas');
}, []);
dispatch(setActiveTab('canvas'));
}, [dispatch]);
return (
<LaunchpadContainer heading="Generate images from text prompts.">
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
<InitialStateMainModelPicker />
<Flex flexDir="column" gap={2} justifyContent="center">
<Text>
Want to learn what prompts work best for each model?{' '}
<Button
as="a"
variant="link"
href="https://support.invoke.ai/support/solutions/articles/151000216086-model-guide"
target="_blank"
rel="noopener noreferrer"
size="sm"
>
Check out our Model Guide.
<Flex flexDir="column" h="full" w="full" alignItems="center" gap={2}>
<Flex flexDir="column" w="full" gap={4} px={14} maxW={768} pt="20vh">
<Heading mb={4}>Generate images from text prompts.</Heading>
<Flex flexDir="column" gap={8}>
<Grid gridTemplateColumns="1fr 1fr" gap={8}>
<InitialStateMainModelPicker />
<Flex flexDir="column" gap={2} justifyContent="center">
<Text>
Want to learn what prompts work best for each model?{' '}
<Button
as="a"
variant="link"
href="https://support.invoke.ai/support/solutions/articles/151000216086-model-guide"
size="sm"
>
Check out our Model Guide.
</Button>
</Text>
</Flex>
</Grid>
<LaunchpadGenerateFromTextButton />
<LaunchpadAddStyleReference />
<Alert status="info" borderRadius="base" flexDir="column" gap={2} overflow="unset">
<Text fontSize="md" fontWeight="semibold">
Looking to get more control, edit, and iterate on your images?
</Text>
<Button variant="link" onClick={newCanvasSession}>
Navigate to Canvas for more capabilities.
</Button>
</Text>
</Alert>
</Flex>
</Grid>
<LaunchpadGenerateFromTextButton />
<LaunchpadAddStyleReference />
<Alert status="info" borderRadius="base" flexDir="column" gap={2} overflow="unset">
<Text fontSize="md" fontWeight="semibold">
Looking to get more control, edit, and iterate on your images?
</Text>
<Button variant="link" onClick={newCanvasSession}>
Navigate to Canvas for more capabilities.
</Button>
</Alert>
</LaunchpadContainer>
</Flex>
</Flex>
);
});
GenerateLaunchpadPanel.displayName = 'GenerateLaunchpad';

View File

@@ -0,0 +1,28 @@
import type { ButtonGroupProps } from '@invoke-ai/ui-library';
import { Button, ButtonGroup } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/storeHooks';
import { newCanvasFromImage } from 'features/imageActions/actions';
import { memo, useCallback } from 'react';
import type { ImageDTO } from 'services/api/types';
export const ImageActions = memo(({ imageDTO, ...rest }: { imageDTO: ImageDTO } & ButtonGroupProps) => {
const { getState, dispatch } = useAppStore();
const edit = useCallback(() => {
newCanvasFromImage({
imageDTO,
type: 'raster_layer',
withInpaintMask: true,
getState,
dispatch,
});
}, [dispatch, getState, imageDTO]);
return (
<ButtonGroup isAttached={false} size="sm" {...rest}>
<Button onClick={edit} tooltip="Edit parts of this image with Inpainting">
Edit
</Button>
</ButtonGroup>
);
});
ImageActions.displayName = 'ImageActions';

View File

@@ -1,17 +0,0 @@
import { Flex, Heading } from '@invoke-ai/ui-library';
import type { PropsWithChildren } from 'react';
import { memo } from 'react';
export const LaunchpadContainer = memo((props: PropsWithChildren<{ heading: string }>) => {
return (
<Flex flexDir="column" h="full" w="full" alignItems="center" justifyContent="center" gap={2}>
<Flex flexDir="column" w="full" gap={4} px={14} maxW={768}>
<Heading>{props.heading}</Heading>
<Flex flexDir="column" gap={4}>
{props.children}
</Flex>
</Flex>
</Flex>
);
});
LaunchpadContainer.displayName = 'LaunchpadContainer';

View File

@@ -1,10 +1,11 @@
import { Flex, Heading, Icon, Text } from '@invoke-ai/ui-library';
import { LaunchpadButton } from 'features/controlLayers/components/SimpleSession/LaunchpadButton';
import { useAutoLayoutContext } from 'features/ui/layouts/auto-layout-context';
import { memo, useCallback } from 'react';
import { PiCursorTextBold, PiTextAaBold } from 'react-icons/pi';
const focusOnPrompt = () => {
const promptElement = document.querySelector('.positive-prompt-textarea');
const focusOnPrompt = (el: HTMLElement) => {
const promptElement = el.querySelector('.positive-prompt-textarea');
if (promptElement instanceof HTMLTextAreaElement) {
promptElement.focus();
promptElement.select();
@@ -12,10 +13,15 @@ const focusOnPrompt = () => {
};
export const LaunchpadGenerateFromTextButton = memo((props: { extraAction?: () => void }) => {
const { rootRef } = useAutoLayoutContext();
const onClick = useCallback(() => {
focusOnPrompt();
const el = rootRef.current;
if (!el) {
return;
}
focusOnPrompt(el);
props.extraAction?.();
}, [props]);
}, [props, rootRef]);
return (
<LaunchpadButton onClick={onClick} position="relative" gap={8}>
<Icon as={PiTextAaBold} boxSize={8} color="base.500" />

View File

@@ -0,0 +1,56 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import {
useCanvasSessionContext,
useOutputImageDTO,
useProgressData,
} from 'features/controlLayers/components/SimpleSession/context';
import { ImageActions } from 'features/controlLayers/components/SimpleSession/ImageActions';
import { QueueItemCircularProgress } from 'features/controlLayers/components/SimpleSession/QueueItemCircularProgress';
import { QueueItemNumber } from 'features/controlLayers/components/SimpleSession/QueueItemNumber';
import { QueueItemProgressImage } from 'features/controlLayers/components/SimpleSession/QueueItemProgressImage';
import { QueueItemStatusLabel } from 'features/controlLayers/components/SimpleSession/QueueItemStatusLabel';
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
import { DndImage } from 'features/dnd/DndImage';
import { memo } from 'react';
import type { S } from 'services/api/types';
type Props = {
item: S['SessionQueueItem'];
number: number;
};
const sx = {
userSelect: 'none',
pos: 'relative',
alignItems: 'center',
justifyContent: 'center',
overflow: 'hidden',
h: 'full',
w: 'full',
} satisfies SystemStyleObject;
export const QueueItemPreviewFull = memo(({ item, number }: Props) => {
const ctx = useCanvasSessionContext();
const imageDTO = useOutputImageDTO(item);
const { imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
return (
<Flex id={getQueueItemElementId(item.item_id)} sx={sx}>
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
{imageDTO && <DndImage imageDTO={imageDTO} />}
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
{imageDTO && <ImageActions imageDTO={imageDTO} position="absolute" top={1} right={2} />}
<QueueItemNumber number={number} position="absolute" top={1} left={2} />
<QueueItemCircularProgress
itemId={item.item_id}
status={item.status}
position="absolute"
top={1}
right={2}
size={8}
/>
</Flex>
);
});
QueueItemPreviewFull.displayName = 'QueueItemPreviewFull';

View File

@@ -1,6 +1,5 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
useCanvasSessionContext,
useOutputImageDTO,
@@ -11,10 +10,6 @@ import { QueueItemNumber } from 'features/controlLayers/components/SimpleSession
import { QueueItemProgressImage } from 'features/controlLayers/components/SimpleSession/QueueItemProgressImage';
import { QueueItemStatusLabel } from 'features/controlLayers/components/SimpleSession/QueueItemStatusLabel';
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
import {
selectStagingAreaAutoSwitch,
settingsStagingAreaAutoSwitchChanged,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { DndImage } from 'features/dnd/DndImage';
import { toast } from 'features/toast/toast';
import { memo, useCallback } from 'react';
@@ -26,13 +21,12 @@ const sx = {
pos: 'relative',
alignItems: 'center',
justifyContent: 'center',
h: 108,
w: 108,
flexShrink: 0,
h: 'full',
aspectRatio: '1/1',
borderWidth: 2,
borderRadius: 'base',
bg: 'base.900',
overflow: 'hidden',
'&[data-selected="true"]': {
borderColor: 'invokeBlue.300',
},
@@ -40,29 +34,28 @@ const sx = {
type Props = {
item: S['SessionQueueItem'];
index: number;
number: number;
isSelected: boolean;
};
export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) => {
const dispatch = useAppDispatch();
export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) => {
const ctx = useCanvasSessionContext();
const { imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
const imageDTO = useOutputImageDTO(item);
const autoSwitch = useAppSelector(selectStagingAreaAutoSwitch);
const onClick = useCallback(() => {
ctx.$selectedItemId.set(item.item_id);
}, [ctx.$selectedItemId, item.item_id]);
const onDoubleClick = useCallback(() => {
const autoSwitch = ctx.$autoSwitch.get();
if (autoSwitch !== 'off') {
dispatch(settingsStagingAreaAutoSwitchChanged('off'));
ctx.$autoSwitch.set('off');
toast({
title: 'Auto-Switch Disabled',
});
}
}, [autoSwitch, dispatch]);
}, [ctx.$autoSwitch]);
const onLoad = useCallback(() => {
ctx.onImageLoad(item.item_id);
@@ -70,7 +63,7 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) =>
return (
<Flex
id={getQueueItemElementId(index)}
id={getQueueItemElementId(item.item_id)}
sx={sx}
data-selected={isSelected}
onClick={onClick}
@@ -79,7 +72,7 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) =>
<QueueItemStatusLabel item={item} position="absolute" margin="auto" />
{imageDTO && <DndImage imageDTO={imageDTO} onLoad={onLoad} asThumbnail />}
{!imageLoaded && <QueueItemProgressImage itemId={item.item_id} position="absolute" />}
<QueueItemNumber number={index + 1} position="absolute" top={0} left={1} />
<QueueItemNumber number={number} position="absolute" top={0} left={1} />
<QueueItemCircularProgress itemId={item.item_id} status={item.status} position="absolute" top={1} right={2} />
</Flex>
);

View File

@@ -0,0 +1,32 @@
import type { TextProps } from '@invoke-ai/ui-library';
import { Text } from '@invoke-ai/ui-library';
import { useCanvasSessionContext, useProgressData } from 'features/controlLayers/components/SimpleSession/context';
import { DROP_SHADOW, getProgressMessage } from 'features/controlLayers/components/SimpleSession/shared';
import { memo } from 'react';
import type { S } from 'services/api/types';
type Props = { itemId: number; status: S['SessionQueueItem']['status'] } & TextProps;
export const QueueItemProgressMessage = memo(({ itemId, status, ...rest }: Props) => {
const ctx = useCanvasSessionContext();
const { progressEvent } = useProgressData(ctx.$progressData, itemId);
if (status === 'completed' || status === 'failed' || status === 'canceled') {
return null;
}
if (status === 'pending') {
return (
<Text pointerEvents="none" userSelect="none" filter={DROP_SHADOW} {...rest}>
Waiting to start...
</Text>
);
}
return (
<Text pointerEvents="none" userSelect="none" filter={DROP_SHADOW} {...rest}>
{getProgressMessage(progressEvent)}
</Text>
);
});
QueueItemProgressMessage.displayName = 'QueueItemProgressMessage';

View File

@@ -16,21 +16,21 @@ export const QueueItemStatusLabel = memo(({ item, ...rest }: Props) => {
if (item.status === 'pending') {
return (
<Text fontSize="xs" pointerEvents="none" userSelect="none" fontWeight="semibold" color="base.300" {...rest}>
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="base.300" {...rest}>
Pending
</Text>
);
}
if (item.status === 'canceled') {
return (
<Text fontSize="xs" pointerEvents="none" userSelect="none" fontWeight="semibold" color="warning.300" {...rest}>
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="warning.300" {...rest}>
Canceled
</Text>
);
}
if (item.status === 'failed') {
return (
<Text fontSize="xs" pointerEvents="none" userSelect="none" fontWeight="semibold" color="error.300" {...rest}>
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="error.300" {...rest}>
Failed
</Text>
);
@@ -38,7 +38,7 @@ export const QueueItemStatusLabel = memo(({ item, ...rest }: Props) => {
if (item.status === 'in_progress') {
return (
<Text fontSize="xs" pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeBlue.300" {...rest}>
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeBlue.300" {...rest}>
In Progress
</Text>
);
@@ -46,14 +46,7 @@ export const QueueItemStatusLabel = memo(({ item, ...rest }: Props) => {
if (item.status === 'completed') {
return (
<Text
fontSize="xs"
pointerEvents="none"
userSelect="none"
fontWeight="semibold"
color="invokeGreen.300"
{...rest}
>
<Text pointerEvents="none" userSelect="none" fontWeight="semibold" color="invokeGreen.300" {...rest}>
Completed
</Text>
);

View File

@@ -0,0 +1,15 @@
import { Divider, Flex } from '@invoke-ai/ui-library';
import { StagingAreaHeader } from 'features/controlLayers/components/SimpleSession/StagingAreaHeader';
import { StagingAreaNoItems } from 'features/controlLayers/components/SimpleSession/StagingAreaNoItems';
import { memo } from 'react';
export const SimpleSessionNoId = memo(() => {
return (
<Flex flexDir="column" gap={2} w="full" h="full" minW={0} minH={0}>
<StagingAreaHeader />
<Divider />
<StagingAreaNoItems />
</Flex>
);
});
SimpleSessionNoId.displayName = 'StSimpleSessionNoIdagingArea';

View File

@@ -0,0 +1,33 @@
import { Divider, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
import { StagingAreaContent } from 'features/controlLayers/components/SimpleSession/StagingAreaContent';
import { StagingAreaHeader } from 'features/controlLayers/components/SimpleSession/StagingAreaHeader';
import { StagingAreaNoItems } from 'features/controlLayers/components/SimpleSession/StagingAreaNoItems';
import { useStagingAreaKeyboardNav } from 'features/controlLayers/components/SimpleSession/use-staging-keyboard-nav';
import { memo, useEffect } from 'react';
export const StagingArea = memo(() => {
const ctx = useCanvasSessionContext();
const hasItems = useStore(ctx.$hasItems);
useStagingAreaKeyboardNav();
useEffect(() => {
return ctx.$selectedItemId.listen((id) => {
if (id !== null) {
document.getElementById(getQueueItemElementId(id))?.scrollIntoView();
}
});
}, [ctx.$selectedItemId]);
return (
<Flex flexDir="column" gap={2} w="full" h="full" minW={0} minH={0}>
<StagingAreaHeader />
<Divider />
{hasItems && <StagingAreaContent />}
{!hasItems && <StagingAreaNoItems />}
</Flex>
);
});
StagingArea.displayName = 'StagingArea';

View File

@@ -0,0 +1,23 @@
import { Divider, Flex } from '@invoke-ai/ui-library';
import { StagingAreaItemsList } from 'features/controlLayers/components/SimpleSession/StagingAreaItemsList';
import { StagingAreaSelectedItem } from 'features/controlLayers/components/SimpleSession/StagingAreaSelectedItem';
import { SimpleStagingAreaToolbar } from 'features/controlLayers/components/StagingArea/SimpleStagingAreaToolbar';
import { memo } from 'react';
export const StagingAreaContent = memo(() => {
return (
<>
<Flex position="relative" w="full" h="full" maxH="full" alignItems="center" justifyContent="center" minH={0}>
<StagingAreaSelectedItem />
</Flex>
<Divider />
<Flex position="relative" maxW="full" w="full" h={108} flexShrink={0}>
<StagingAreaItemsList />
</Flex>
<Flex gap={2} w="full" justifyContent="safe center">
<SimpleStagingAreaToolbar />
</Flex>
</>
);
});
StagingAreaContent.displayName = 'StagingAreaContent';

View File

@@ -0,0 +1,12 @@
import { Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { memo } from 'react';
export const StagingAreaHeader = memo(() => {
return (
<Flex gap={2} w="full" alignItems="center" px={2}>
<Heading size="sm">Review Session</Heading>
<Spacer />
</Flex>
);
});
StagingAreaHeader.displayName = 'StagingAreaHeader';

View File

@@ -1,214 +1,38 @@
import { Box, Flex, forwardRef } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { logger } from 'app/logging/logger';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { QueueItemPreviewMini } from 'features/controlLayers/components/SimpleSession/QueueItemPreviewMini';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
import type { CSSProperties, RefObject } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import type { Components, ItemContent, ListRange, VirtuosoHandle, VirtuosoProps } from 'react-virtuoso';
import { Virtuoso } from 'react-virtuoso';
import type { S } from 'services/api/types';
import { getQueueItemElementId } from './shared';
const log = logger('system');
const virtuosoStyles = {
width: '100%',
height: '72px',
} satisfies CSSProperties;
type VirtuosoContext = { selectedItemId: number | null };
/**
* Scroll the item at the given index into view if it is not currently visible.
*/
const scrollIntoView = (
targetIndex: number,
rootEl: HTMLDivElement,
virtuosoHandle: VirtuosoHandle,
range: ListRange
) => {
if (range.endIndex === 0) {
// No range is rendered; no need to scroll to anything.
return;
}
const targetItem = rootEl.querySelector(`#${getQueueItemElementId(targetIndex)}`);
if (!targetItem) {
if (targetIndex > range.endIndex) {
virtuosoHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'end',
});
} else if (targetIndex < range.startIndex) {
virtuosoHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'start',
});
} else {
log.debug(
`Unable to find queue item at index ${targetIndex} but it is in the rendered range ${range.startIndex}-${range.endIndex}`
);
}
return;
}
// We found the image in the DOM, but it might be in the overscan range - rendered but not in the visible viewport.
// Check if it is in the viewport and scroll if necessary.
const itemRect = targetItem.getBoundingClientRect();
const rootRect = rootEl.getBoundingClientRect();
if (itemRect.left < rootRect.left) {
virtuosoHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'start',
});
} else if (itemRect.right > rootRect.right) {
virtuosoHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'end',
});
} else {
// Image is already in view
}
return;
};
const useScrollableStagingArea = (rootRef: RefObject<HTMLDivElement>) => {
const [scroller, scrollerRef] = useState<HTMLElement | null>(null);
const [initialize, osInstance] = useOverlayScrollbars({
defer: true,
events: {
initialized(osInstance) {
// force overflow styles
const { viewport } = osInstance.elements();
viewport.style.overflowX = `var(--os-viewport-overflow-x)`;
viewport.style.overflowY = `var(--os-viewport-overflow-y)`;
},
},
options: {
scrollbars: {
visibility: 'auto',
autoHide: 'scroll',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
overflow: {
y: 'hidden',
x: 'scroll',
},
},
});
useEffect(() => {
const { current: root } = rootRef;
if (scroller && root) {
initialize({
target: root,
elements: {
viewport: scroller,
},
});
}
return () => {
osInstance()?.destroy();
};
}, [scroller, initialize, osInstance, rootRef]);
return scrollerRef;
};
import { memo, useEffect } from 'react';
export const StagingAreaItemsList = memo(() => {
const canvasManager = useCanvasManagerSafe();
const ctx = useCanvasSessionContext();
const virtuosoRef = useRef<VirtuosoHandle>(null);
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
const rootRef = useRef<HTMLDivElement>(null);
const items = useStore(ctx.$items);
const selectedItemId = useStore(ctx.$selectedItemId);
const context = useMemo(() => ({ selectedItemId }), [selectedItemId]);
const scrollerRef = useScrollableStagingArea(rootRef);
useEffect(() => {
if (!canvasManager) {
return;
}
return canvasManager.stagingArea.connectToSession(ctx.$selectedItemId, ctx.$progressData, ctx.$isPending);
}, [canvasManager, ctx.$progressData, ctx.$selectedItemId, ctx.$isPending]);
useEffect(() => {
return ctx.$selectedItemIndex.listen((index) => {
if (!virtuosoRef.current) {
return;
}
if (!rootRef.current) {
return;
}
if (index === null) {
return;
}
scrollIntoView(index, rootRef.current, virtuosoRef.current, rangeRef.current);
});
}, [ctx.$selectedItemIndex]);
const onRangeChanged = useCallback((range: ListRange) => {
rangeRef.current = range;
}, []);
return canvasManager.stagingArea.connectToSession(ctx.$selectedItemId, ctx.$progressData);
}, [canvasManager, ctx.$progressData, ctx.$selectedItemId]);
return (
<Box data-overlayscrollbars-initialize="" ref={rootRef} position="relative" w="full" h="full">
<Virtuoso<S['SessionQueueItem'], VirtuosoContext>
ref={virtuosoRef}
context={context}
data={items}
horizontalDirection
style={virtuosoStyles}
itemContent={itemContent}
components={components}
rangeChanged={onRangeChanged}
// Virtuoso expects the ref to be of HTMLElement | null | Window, but overlayscrollbars doesn't allow Window
scrollerRef={scrollerRef as VirtuosoProps<S['SessionQueueItem'], VirtuosoContext>['scrollerRef']}
/>
</Box>
<ScrollableContent overflowX="scroll" overflowY="hidden">
<Flex gap={2} w="full" h="full" justifyContent="safe center">
{items.map((item, i) => (
<QueueItemPreviewMini
key={`${item.item_id}-mini`}
item={item}
number={i + 1}
isSelected={selectedItemId === item.item_id}
/>
))}
</Flex>
</ScrollableContent>
);
});
StagingAreaItemsList.displayName = 'StagingAreaItemsList';
const itemContent: ItemContent<S['SessionQueueItem'], VirtuosoContext> = (index, item, { selectedItemId }) => (
<QueueItemPreviewMini
key={`${item.item_id}-mini`}
item={item}
index={index}
isSelected={selectedItemId === item.item_id}
/>
);
const listSx = {
'& > * + *': {
pl: 2,
},
};
const components: Components<S['SessionQueueItem'], VirtuosoContext> = {
List: forwardRef(({ context: _, ...rest }, ref) => {
return <Flex ref={ref} sx={listSx} {...rest} />;
}),
};

View File

@@ -0,0 +1,11 @@
import { Flex, Text } from '@invoke-ai/ui-library';
import { memo } from 'react';
export const StagingAreaNoItems = memo(() => {
return (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Text>No generations</Text>
</Flex>
);
});
StagingAreaNoItems.displayName = 'StagingAreaNoItems';

View File

@@ -0,0 +1,20 @@
import { Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { QueueItemPreviewFull } from 'features/controlLayers/components/SimpleSession/QueueItemPreviewFull';
import { memo } from 'react';
export const StagingAreaSelectedItem = memo(() => {
const ctx = useCanvasSessionContext();
const selectedItem = useStore(ctx.$selectedItem);
const selectedItemIndex = useStore(ctx.$selectedItemIndex);
if (selectedItem && selectedItemIndex !== null) {
return (
<QueueItemPreviewFull key={`${selectedItem.item_id}-full`} item={selectedItem} number={selectedItemIndex + 1} />
);
}
return <Text>No generation selected</Text>;
});
StagingAreaSelectedItem.displayName = 'StagingAreaSelectedItem';

View File

@@ -24,7 +24,6 @@ import {
import type { ImageDTO } from 'services/api/types';
import { LaunchpadButton } from './LaunchpadButton';
import { LaunchpadContainer } from './LaunchpadContainer';
export const UpscalingLaunchpadPanel = memo(() => {
const { t } = useTranslation();
@@ -66,104 +65,108 @@ export const UpscalingLaunchpadPanel = memo(() => {
}, [dispatch]);
return (
<LaunchpadContainer heading={t('ui.launchpad.upscalingTitle')}>
{/* Upload Area */}
<LaunchpadButton {...uploadApi.getUploadButtonProps()} position="relative" gap={8}>
{!upscaleInitialImage ? (
<>
<Icon as={PiImageBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.upscaling.uploadImage.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.upscaling.uploadImage.description')}</Text>
</Flex>
<Flex position="absolute" right={3} bottom={3}>
<PiUploadBold />
<input {...uploadApi.getUploadInputProps()} />
</Flex>
</>
) : (
<>
<Icon as={PiImageBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.upscaling.replaceImage.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.upscaling.replaceImage.description')}</Text>
</Flex>
<Flex position="absolute" right={3} bottom={3}>
<PiUploadBold />
<input {...uploadApi.getUploadInputProps()} />
</Flex>
</>
<Flex flexDir="column" h="full" w="full" alignItems="center" gap={2}>
<Flex flexDir="column" w="full" gap={8} px={14} maxW={768} pt="20vh">
<Heading>{t('ui.launchpad.upscalingTitle')}</Heading>
{/* Upload Area */}
<LaunchpadButton {...uploadApi.getUploadButtonProps()} position="relative" gap={8}>
{!upscaleInitialImage ? (
<>
<Icon as={PiImageBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.upscaling.uploadImage.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.upscaling.uploadImage.description')}</Text>
</Flex>
<Flex position="absolute" right={3} bottom={3}>
<PiUploadBold />
<input {...uploadApi.getUploadInputProps()} />
</Flex>
</>
) : (
<>
<Icon as={PiImageBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.upscaling.replaceImage.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.upscaling.replaceImage.description')}</Text>
</Flex>
<Flex position="absolute" right={3} bottom={3}>
<PiUploadBold />
<input {...uploadApi.getUploadInputProps()} />
</Flex>
</>
)}
<DndDropTarget
dndTarget={setUpscaleInitialImageDndTarget}
dndTargetData={dndTargetData}
label={t('gallery.drop')}
/>
</LaunchpadButton>
{/* Guidance text */}
{upscaleInitialImage && (
<Flex bg="base.800" p={4} borderRadius="base" border="1px solid" borderColor="base.700">
<Text variant="subtext" fontSize="sm" lineHeight="1.6">
<strong>{t('ui.launchpad.upscaling.readyToUpscale.title')}</strong>{' '}
{t('ui.launchpad.upscaling.readyToUpscale.description')}
</Text>
</Flex>
)}
<DndDropTarget
dndTarget={setUpscaleInitialImageDndTarget}
dndTargetData={dndTargetData}
label={t('gallery.drop')}
/>
</LaunchpadButton>
{/* Guidance text */}
{upscaleInitialImage && (
<Flex bg="base.800" p={4} borderRadius="base" border="1px solid" borderColor="base.700">
<Text variant="subtext" fontSize="sm" lineHeight="1.6">
<strong>{t('ui.launchpad.upscaling.readyToUpscale.title')}</strong>{' '}
{t('ui.launchpad.upscaling.readyToUpscale.description')}
</Text>
</Flex>
)}
{/* Controls */}
<Grid gridTemplateColumns="1fr 1fr" gap={8} alignItems="start">
{/* Left Column: Creativity and Structural Defaults */}
<Box>
<Text fontWeight="semibold" fontSize="sm" mb={3}>
Creativity & Structure Defaults
</Text>
<ButtonGroup size="sm" orientation="vertical" variant="outline" w="full">
<Button
colorScheme={creativity === -5 && structure === 5 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onConservativeClick}
leftIcon={<PiShieldCheckBold />}
>
Conservative
</Button>
<Button
colorScheme={creativity === 0 && structure === 0 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onBalancedClick}
leftIcon={<PiScalesBold />}
>
Balanced
</Button>
<Button
colorScheme={creativity === 5 && structure === -2 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onCreativeClick}
leftIcon={<PiPaletteBold />}
>
Creative
</Button>
<Button
colorScheme={creativity === 8 && structure === -5 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onArtisticClick}
leftIcon={<PiSparkleBold />}
>
Artistic
</Button>
</ButtonGroup>
</Box>
{/* Right Column: Description/help text */}
<Box>
<Text variant="subtext" fontSize="sm" lineHeight="1.6">
{t('ui.launchpad.upscaling.helpText.promptAdvice')}
</Text>
<Text variant="subtext" fontSize="sm" lineHeight="1.6" mt={3}>
{t('ui.launchpad.upscaling.helpText.styleAdvice')}
</Text>
</Box>
</Grid>
</LaunchpadContainer>
{/* Controls */}
<Grid gridTemplateColumns="1fr 1fr" gap={8} alignItems="start">
{/* Left Column: Creativity and Structural Defaults */}
<Box>
<Text fontWeight="semibold" fontSize="sm" mb={3}>
Creativity & Structure Defaults
</Text>
<ButtonGroup size="sm" orientation="vertical" variant="outline" w="full">
<Button
colorScheme={creativity === -5 && structure === 5 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onConservativeClick}
leftIcon={<PiShieldCheckBold />}
>
Conservative
</Button>
<Button
colorScheme={creativity === 0 && structure === 0 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onBalancedClick}
leftIcon={<PiScalesBold />}
>
Balanced
</Button>
<Button
colorScheme={creativity === 5 && structure === -2 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onCreativeClick}
leftIcon={<PiPaletteBold />}
>
Creative
</Button>
<Button
colorScheme={creativity === 8 && structure === -5 ? 'invokeBlue' : undefined}
justifyContent="center"
onClick={onArtisticClick}
leftIcon={<PiSparkleBold />}
>
Artistic
</Button>
</ButtonGroup>
</Box>
{/* Right Column: Description/help text */}
<Box>
<Text variant="subtext" fontSize="sm" lineHeight="1.6">
{t('ui.launchpad.upscaling.helpText.promptAdvice')}
</Text>
<Text variant="subtext" fontSize="sm" lineHeight="1.6" mt={3}>
{t('ui.launchpad.upscaling.helpText.styleAdvice')}
</Text>
</Box>
</Grid>
</Flex>
</Flex>
);
});

View File

@@ -8,7 +8,6 @@ import { useTranslation } from 'react-i18next';
import { PiFilePlusBold, PiFolderOpenBold, PiUploadBold } from 'react-icons/pi';
import { LaunchpadButton } from './LaunchpadButton';
import { LaunchpadContainer } from './LaunchpadContainer';
export const WorkflowsLaunchpadPanel = memo(() => {
const { t } = useTranslation();
@@ -46,59 +45,63 @@ export const WorkflowsLaunchpadPanel = memo(() => {
});
return (
<LaunchpadContainer heading={t('ui.launchpad.workflowsTitle')}>
{/* Description */}
<Text variant="subtext" fontSize="md" lineHeight="1.6">
{t('ui.launchpad.workflows.description')}
</Text>
<Flex flexDir="column" h="full" w="full" alignItems="center" gap={2}>
<Flex flexDir="column" w="full" gap={4} px={14} maxW={768} pt="20vh">
<Heading>{t('ui.launchpad.workflowsTitle')}</Heading>
<Text>
<Button
as="a"
variant="link"
href="https://support.invoke.ai/support/solutions/articles/151000189610-getting-started-with-workflows-denoise-latents"
target="_blank"
rel="noopener noreferrer"
size="sm"
>
{t('ui.launchpad.workflows.learnMoreLink')}
</Button>
</Text>
{/* Description */}
<Text variant="subtext" fontSize="md" lineHeight="1.6">
{t('ui.launchpad.workflows.description')}
</Text>
{/* Action Buttons */}
<Flex flexDir="column" gap={8}>
{/* Browse Workflow Templates */}
<LaunchpadButton onClick={handleBrowseTemplates} position="relative" gap={8}>
<Icon as={PiFolderOpenBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.workflows.browseTemplates.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.workflows.browseTemplates.description')}</Text>
</Flex>
</LaunchpadButton>
<Text>
<Button
as="a"
variant="link"
href="https://support.invoke.ai/support/solutions/articles/151000189610-getting-started-with-workflows-denoise-latents"
target="_blank"
rel="noopener noreferrer"
size="sm"
>
{t('ui.launchpad.workflows.learnMoreLink')}
</Button>
</Text>
{/* Create a new Workflow */}
<LaunchpadButton onClick={handleCreateNew} position="relative" gap={8}>
<Icon as={PiFilePlusBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.workflows.createNew.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.workflows.createNew.description')}</Text>
</Flex>
</LaunchpadButton>
{/* Action Buttons */}
<Flex flexDir="column" gap={8}>
{/* Browse Workflow Templates */}
<LaunchpadButton onClick={handleBrowseTemplates} position="relative" gap={8}>
<Icon as={PiFolderOpenBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.workflows.browseTemplates.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.workflows.browseTemplates.description')}</Text>
</Flex>
</LaunchpadButton>
{/* Load workflow from existing image or file */}
<LaunchpadButton {...uploadApi.getRootProps()} position="relative" gap={8}>
<Icon as={PiUploadBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.workflows.loadFromFile.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.workflows.loadFromFile.description')}</Text>
</Flex>
<Flex position="absolute" right={3} bottom={3}>
<PiUploadBold />
<input {...uploadApi.getInputProps()} />
</Flex>
</LaunchpadButton>
{/* Create a new Workflow */}
<LaunchpadButton onClick={handleCreateNew} position="relative" gap={8}>
<Icon as={PiFilePlusBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.workflows.createNew.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.workflows.createNew.description')}</Text>
</Flex>
</LaunchpadButton>
{/* Load workflow from existing image or file */}
<LaunchpadButton {...uploadApi.getRootProps()} position="relative" gap={8}>
<Icon as={PiUploadBold} boxSize={8} color="base.500" />
<Flex flexDir="column" alignItems="flex-start" gap={2}>
<Heading size="sm">{t('ui.launchpad.workflows.loadFromFile.title')}</Heading>
<Text color="base.300">{t('ui.launchpad.workflows.loadFromFile.description')}</Text>
</Flex>
<Flex position="absolute" right={3} bottom={3}>
<PiUploadBold />
<input {...uploadApi.getInputProps()} />
</Flex>
</LaunchpadButton>
</Flex>
</Flex>
</LaunchpadContainer>
</Flex>
);
});

View File

@@ -1,12 +1,9 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppStore } from 'app/store/storeHooks';
import { buildZodTypeGuard } from 'common/util/zodUtils';
import { getOutputImageName } from 'features/controlLayers/components/SimpleSession/shared';
import { selectStagingAreaAutoSwitch } from 'features/controlLayers/store/canvasSettingsSlice';
import {
buildSelectSessionQueueItems,
canvasQueueItemDiscarded,
canvasSessionReset,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import type { ProgressImage } from 'features/nodes/types/common';
import type { Atom, MapStore, StoreValue, WritableAtom } from 'nanostores';
import { atom, computed, effect, map, subscribeKeys } from 'nanostores';
@@ -17,6 +14,11 @@ import { queueApi } from 'services/api/endpoints/queue';
import type { ImageDTO, S } from 'services/api/types';
import { $socket } from 'services/events/stores';
import { assert, objectEntries } from 'tsafe';
import { z } from 'zod/v4';
const zAutoSwitchMode = z.enum(['off', 'switch_on_start', 'switch_on_finish']);
export const isAutoSwitchMode = buildZodTypeGuard(zAutoSwitchMode);
export type AutoSwitchMode = z.infer<typeof zAutoSwitchMode>;
export type ProgressData = {
itemId: number;
@@ -90,19 +92,17 @@ type CanvasSessionContextValue = {
$items: Atom<S['SessionQueueItem'][]>;
$itemCount: Atom<number>;
$hasItems: Atom<boolean>;
$isPending: Atom<boolean>;
$progressData: ProgressDataMap;
$selectedItemId: WritableAtom<number | null>;
$selectedItem: Atom<S['SessionQueueItem'] | null>;
$selectedItemIndex: Atom<number | null>;
$selectedItemOutputImageDTO: Atom<ImageDTO | null>;
$autoSwitch: WritableAtom<AutoSwitchMode>;
selectNext: () => void;
selectPrev: () => void;
selectFirst: () => void;
selectLast: () => void;
onImageLoad: (itemId: number) => void;
discard: (itemId: number) => void;
discardAll: () => void;
};
const CanvasSessionContext = createContext<CanvasSessionContextValue | null>(null);
@@ -139,6 +139,11 @@ export const CanvasSessionContextProvider = memo(
*/
const $items = useState(() => atom<S['SessionQueueItem'][]>([]))[0];
/**
* Whether auto-switch is enabled.
*/
const $autoSwitch = useState(() => atom<AutoSwitchMode>('switch_on_start'))[0];
/**
* An internal flag used to work around race conditions with auto-switch switching to queue items before their
* output images have fully loaded.
@@ -165,13 +170,6 @@ export const CanvasSessionContextProvider = memo(
*/
const $hasItems = useState(() => computed([$items], (items) => items.length > 0))[0];
/**
* Whether there are any pending or in-progress items. Computed from the queue items array.
*/
const $isPending = useState(() =>
computed([$items], (items) => items.some((item) => item.status === 'pending' || item.status === 'in_progress'))
)[0];
/**
* The currently selected queue item, or null if one is not selected.
*/
@@ -220,21 +218,19 @@ export const CanvasSessionContextProvider = memo(
)[0];
/**
* A redux selector to select all queue items from the RTK Query cache.
* A redux selector to select all queue items from the RTK Query cache. It's important that this returns stable
* references if possible to reduce re-renders. All derivations of the queue items (e.g. filtering out canceled
* items) should be done in a nanostores computed.
*/
const selectQueueItems = useMemo(() => buildSelectSessionQueueItems(session.id), [session.id]);
const discard = useCallback(
(itemId: number) => {
store.dispatch(canvasQueueItemDiscarded({ itemId }));
},
[store]
const selectQueueItems = useMemo(
() =>
createSelector(
queueApi.endpoints.listAllQueueItems.select({ destination: session.id }),
({ data }) => data ?? EMPTY_ARRAY
),
[session.id]
);
const discardAll = useCallback(() => {
store.dispatch(canvasSessionReset());
}, [store]);
const selectNext = useCallback(() => {
const selectedItemId = $selectedItemId.get();
if (selectedItemId === null) {
@@ -296,15 +292,12 @@ export const CanvasSessionContextProvider = memo(
imageLoaded: true,
});
}
if (
$lastCompletedItemId.get() === itemId &&
selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_finish'
) {
if ($lastCompletedItemId.get() === itemId && $autoSwitch.get() === 'switch_on_finish') {
$selectedItemId.set(itemId);
$lastCompletedItemId.set(null);
}
},
[$lastCompletedItemId, $progressData, $selectedItemId, store]
[$autoSwitch, $lastCompletedItemId, $progressData, $selectedItemId]
);
// Set up socket listeners
@@ -339,7 +332,7 @@ export const CanvasSessionContextProvider = memo(
socket.off('invocation_progress', onProgress);
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
};
}, [$lastCompletedItemId, $lastStartedItemId, $progressData, $selectedItemId, session.id, socket]);
}, [$autoSwitch, $lastCompletedItemId, $lastStartedItemId, $progressData, $selectedItemId, session.id, socket]);
// Set up state subscriptions and effects
useEffect(() => {
@@ -361,32 +354,33 @@ export const CanvasSessionContextProvider = memo(
const unsubEnsureSelectedItemIdExists = effect(
[$items, $selectedItemId, $lastStartedItemId],
(items, selectedItemId, lastStartedItemId) => {
// If there are no items, cannot have a selected item.
if (items.length === 0) {
// If there are no items, cannot have a selected item.
$selectedItemId.set(null);
} else if (selectedItemId === null && items.length > 0) {
// If there is no selected item but there are items, select the first one.
return;
}
// If there is no selected item but there are items, select the first one.
if (selectedItemId === null && items.length > 0) {
$selectedItemId.set(items[0]?.item_id ?? null);
return;
} else if (
selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_start' &&
}
if (
$autoSwitch.get() === 'switch_on_start' &&
items.findIndex(({ item_id }) => item_id === lastStartedItemId) !== -1
) {
$selectedItemId.set(lastStartedItemId);
$lastStartedItemId.set(null);
} else if (selectedItemId !== null && items.findIndex(({ item_id }) => item_id === selectedItemId) === -1) {
// If an item is selected and it is not in the list of items, un-set it. This effect will run again and we'll
// the above case, selecting the first item if there are any.
}
// If an item is selected and it is not in the list of items, un-set it. This effect will run again and we'll
// the above case, selecting the first item if there are any.
if (selectedItemId !== null && items.findIndex(({ item_id }) => item_id === selectedItemId) === -1) {
let prevIndex = _prevItems.findIndex(({ item_id }) => item_id === selectedItemId);
if (prevIndex >= items.length) {
prevIndex = items.length - 1;
}
const nextItem = items[prevIndex];
$selectedItemId.set(nextItem?.item_id ?? null);
}
if (items !== _prevItems) {
_prevItems = items;
return;
}
}
);
@@ -407,12 +401,12 @@ export const CanvasSessionContextProvider = memo(
if (!item) {
toDelete.push(datum.itemId);
} else if (item.status === 'canceled' || item.status === 'failed') {
toUpdate.push({
toUpdate[datum.itemId] = {
...datum,
progressEvent: null,
progressImage: null,
imageDTO: null,
});
};
}
}
@@ -472,7 +466,7 @@ export const CanvasSessionContextProvider = memo(
if (lastLoadedItemId === null) {
return;
}
if (selectStagingAreaAutoSwitch(store.getState()) === 'switch_on_finish') {
if ($autoSwitch.get() === 'switch_on_finish') {
$selectedItemId.set(lastLoadedItemId);
}
$lastLoadedItemId.set(null);
@@ -484,22 +478,6 @@ export const CanvasSessionContextProvider = memo(
queueApi.endpoints.listAllQueueItems.initiate({ destination: session.id })
);
// const unsubListener = store.dispatch(
// addAppListener({
// matcher: queueApi.endpoints.cancelQueueItem.matchFulfilled,
// effect: ({ payload }, { getState }) => {
// const { item_id } = payload;
// const items = selectQueueItems(getState());
// if (items.length === 0) {
// $selectedItemId.set(null);
// } else if ($selectedItemId.get() === null) {
// $selectedItemId.set(items[0].item_id);
// }
// },
// })
// );
// Clean up all subscriptions and top-level (i.e. non-computed/derived state)
return () => {
unsubHandleAutoSwitch();
@@ -512,6 +490,7 @@ export const CanvasSessionContextProvider = memo(
$selectedItemId.set(null);
};
}, [
$autoSwitch,
$items,
$lastLoadedItemId,
$lastStartedItemId,
@@ -527,9 +506,9 @@ export const CanvasSessionContextProvider = memo(
session,
$items,
$hasItems,
$isPending,
$progressData,
$selectedItemId,
$autoSwitch,
$selectedItem,
$selectedItemIndex,
$selectedItemOutputImageDTO,
@@ -539,13 +518,11 @@ export const CanvasSessionContextProvider = memo(
selectFirst,
selectLast,
onImageLoad,
discard,
discardAll,
}),
[
$autoSwitch,
$items,
$hasItems,
$isPending,
$progressData,
$selectedItem,
$selectedItemId,
@@ -558,8 +535,6 @@ export const CanvasSessionContextProvider = memo(
selectFirst,
selectLast,
onImageLoad,
discard,
discardAll,
]
);

View File

@@ -13,7 +13,7 @@ export const getProgressMessage = (data?: S['InvocationProgressEvent'] | null) =
export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
export const getQueueItemElementId = (index: number) => `queue-item-preview-${index}`;
export const getQueueItemElementId = (itemId: number) => `queue-item-status-card-${itemId}`;
export const getOutputImageName = (item: S['SessionQueueItem']) => {
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>

View File

@@ -0,0 +1,11 @@
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useHotkeys } from 'react-hotkeys-hook';
export const useStagingAreaKeyboardNav = () => {
const ctx = useCanvasSessionContext();
useHotkeys('left', ctx.selectPrev, { preventDefault: true });
useHotkeys('right', ctx.selectNext, { preventDefault: true });
useHotkeys('meta+left', ctx.selectFirst, { preventDefault: true });
useHotkeys('meta+right', ctx.selectLast, { preventDefault: true });
};

View File

@@ -0,0 +1,27 @@
import { ButtonGroup } from '@invoke-ai/ui-library';
import { SimpleStagingAreaToolbarMenu } from 'features/controlLayers/components/StagingArea/SimpleStagingAreaToolbarMenu';
import { StagingAreaToolbarDiscardAllButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardAllButton';
import { StagingAreaToolbarDiscardSelectedButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardSelectedButton';
import { StagingAreaToolbarImageCountButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarImageCountButton';
import { StagingAreaToolbarNextButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarNextButton';
import { StagingAreaToolbarPrevButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarPrevButton';
import { memo } from 'react';
export const SimpleStagingAreaToolbar = memo(() => {
return (
<>
<ButtonGroup borderRadius="base" shadow="dark-lg">
<StagingAreaToolbarPrevButton />
<StagingAreaToolbarImageCountButton />
<StagingAreaToolbarNextButton />
</ButtonGroup>
<ButtonGroup borderRadius="base" shadow="dark-lg">
<StagingAreaToolbarDiscardSelectedButton />
<SimpleStagingAreaToolbarMenu />
<StagingAreaToolbarDiscardAllButton />
</ButtonGroup>
</>
);
});
SimpleStagingAreaToolbar.displayName = 'SimpleStagingAreaToolbar';

View File

@@ -0,0 +1,17 @@
import { IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
import { StagingAreaToolbarMenuAutoSwitch } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarMenuAutoSwitch';
import { memo } from 'react';
import { PiDotsThreeBold } from 'react-icons/pi';
export const SimpleStagingAreaToolbarMenu = memo(() => {
return (
<Menu>
<MenuButton as={IconButton} icon={<PiDotsThreeBold />} colorScheme="invokeBlue" />
<MenuList>
<StagingAreaToolbarMenuAutoSwitch />
</MenuList>
</Menu>
);
});
SimpleStagingAreaToolbarMenu.displayName = 'SimpleStagingAreaToolbarMenu';

View File

@@ -1,50 +0,0 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectStagingAreaAutoSwitch,
settingsStagingAreaAutoSwitchChanged,
} from 'features/controlLayers/store/canvasSettingsSlice';
import { memo, useCallback } from 'react';
import { PiCaretLineRightBold, PiCaretRightBold, PiMoonBold } from 'react-icons/pi';
export const StagingAreaAutoSwitchButtons = memo(() => {
const autoSwitch = useAppSelector(selectStagingAreaAutoSwitch);
const dispatch = useAppDispatch();
const onClickOff = useCallback(() => {
dispatch(settingsStagingAreaAutoSwitchChanged('off'));
}, [dispatch]);
const onClickSwitchOnStart = useCallback(() => {
dispatch(settingsStagingAreaAutoSwitchChanged('switch_on_start'));
}, [dispatch]);
const onClickSwitchOnFinished = useCallback(() => {
dispatch(settingsStagingAreaAutoSwitchChanged('switch_on_finish'));
}, [dispatch]);
return (
<>
<IconButton
aria-label="Do not auto-switch"
tooltip="Do not auto-switch"
icon={<PiMoonBold />}
colorScheme={autoSwitch === 'off' ? 'invokeBlue' : 'base'}
onClick={onClickOff}
/>
<IconButton
aria-label="Switch on start"
tooltip="Switch on start"
icon={<PiCaretRightBold />}
colorScheme={autoSwitch === 'switch_on_start' ? 'invokeBlue' : 'base'}
onClick={onClickSwitchOnStart}
/>
<IconButton
aria-label="Switch on finish"
tooltip="Switch on finish"
icon={<PiCaretLineRightBold />}
colorScheme={autoSwitch === 'switch_on_finish' ? 'invokeBlue' : 'base'}
onClick={onClickSwitchOnFinished}
/>
</>
);
});
StagingAreaAutoSwitchButtons.displayName = 'StagingAreaAutoSwitchButtons';

View File

@@ -1,6 +1,7 @@
import { ButtonGroup, Flex } from '@invoke-ai/ui-library';
import { ButtonGroup } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
import { StagingAreaToolbarAcceptButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarAcceptButton';
import { StagingAreaToolbarDiscardAllButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardAllButton';
import { StagingAreaToolbarDiscardSelectedButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardSelectedButton';
@@ -11,22 +12,27 @@ import { StagingAreaToolbarPrevButton } from 'features/controlLayers/components/
import { StagingAreaToolbarSaveSelectedToGalleryButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarSaveSelectedToGalleryButton';
import { StagingAreaToolbarToggleShowResultsButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarToggleShowResultsButton';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { memo } from 'react';
import { memo, useEffect } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { StagingAreaAutoSwitchButtons } from './StagingAreaAutoSwitchButtons';
export const StagingAreaToolbar = memo(() => {
const canvasManager = useCanvasManager();
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
const ctx = useCanvasSessionContext();
useEffect(() => {
return ctx.$selectedItemId.listen((id) => {
if (id !== null) {
document.getElementById(getQueueItemElementId(id))?.scrollIntoView();
}
});
}, [ctx.$selectedItemId]);
useHotkeys('meta+left', ctx.selectFirst, { preventDefault: true });
useHotkeys('meta+right', ctx.selectLast, { preventDefault: true });
return (
<Flex gap={2}>
<>
<ButtonGroup borderRadius="base" shadow="dark-lg">
<StagingAreaToolbarPrevButton isDisabled={!shouldShowStagedImage} />
<StagingAreaToolbarImageCountButton />
@@ -38,14 +44,9 @@ export const StagingAreaToolbar = memo(() => {
<StagingAreaToolbarSaveSelectedToGalleryButton />
<StagingAreaToolbarMenu />
<StagingAreaToolbarDiscardSelectedButton isDisabled={!shouldShowStagedImage} />
</ButtonGroup>
<ButtonGroup borderRadius="base" shadow="dark-lg">
<StagingAreaAutoSwitchButtons />
</ButtonGroup>
<ButtonGroup borderRadius="base" shadow="dark-lg">
<StagingAreaToolbarDiscardAllButton isDisabled={!shouldShowStagedImage} />
</ButtonGroup>
</Flex>
</>
);
});

View File

@@ -9,7 +9,7 @@ import { canvasSessionReset } from 'features/controlLayers/store/canvasStagingAr
import { selectBboxRect, selectSelectedEntityIdentifier } from 'features/controlLayers/store/selectors';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageNameToImageObject } from 'features/controlLayers/store/util';
import { useCancelQueueItemsByDestination } from 'features/queue/hooks/useCancelQueueItemsByDestination';
import { useDeleteQueueItemsByDestination } from 'features/queue/hooks/useDeleteQueueItemsByDestination';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
@@ -24,7 +24,7 @@ export const StagingAreaToolbarAcceptButton = memo(() => {
const shouldShowStagedImage = useStore(canvasManager.stagingArea.$shouldShowStagedImage);
const isCanvasFocused = useIsRegionFocused('canvas');
const selectedItemImageDTO = useStore(ctx.$selectedItemOutputImageDTO);
const cancelQueueItemsByDestination = useCancelQueueItemsByDestination();
const deleteQueueItemsByDestination = useDeleteQueueItemsByDestination();
const { t } = useTranslation();
@@ -41,13 +41,13 @@ export const StagingAreaToolbarAcceptButton = memo(() => {
dispatch(rasterLayerAdded({ overrides, isSelected: selectedEntityIdentifier?.type === 'raster_layer' }));
dispatch(canvasSessionReset());
cancelQueueItemsByDestination.trigger(ctx.session.id, { withToast: false });
deleteQueueItemsByDestination.trigger(ctx.session.id);
}, [
selectedItemImageDTO,
bboxRect,
dispatch,
selectedEntityIdentifier?.type,
cancelQueueItemsByDestination,
deleteQueueItemsByDestination,
ctx.session.id,
]);
@@ -68,8 +68,8 @@ export const StagingAreaToolbarAcceptButton = memo(() => {
icon={<PiCheckBold />}
onClick={acceptSelected}
colorScheme="invokeBlue"
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage || cancelQueueItemsByDestination.isDisabled}
isLoading={cancelQueueItemsByDestination.isLoading}
isDisabled={!selectedItemImageDTO || !shouldShowStagedImage || deleteQueueItemsByDestination.isDisabled}
isLoading={deleteQueueItemsByDestination.isLoading}
/>
);
});

View File

@@ -1,19 +1,27 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useCancelQueueItemsByDestination } from 'features/queue/hooks/useCancelQueueItemsByDestination';
import { canvasSessionReset, generateSessionReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useDeleteQueueItemsByDestination } from 'features/queue/hooks/useDeleteQueueItemsByDestination';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
export const StagingAreaToolbarDiscardAllButton = memo(({ isDisabled }: { isDisabled?: boolean }) => {
const ctx = useCanvasSessionContext();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const cancelQueueItemsByDestination = useCancelQueueItemsByDestination();
const deleteQueueItemsByDestination = useDeleteQueueItemsByDestination();
const discardAll = useCallback(() => {
ctx.discardAll();
cancelQueueItemsByDestination.trigger(ctx.session.id, { withToast: false });
}, [cancelQueueItemsByDestination, ctx]);
deleteQueueItemsByDestination.trigger(ctx.session.id);
if (ctx.session.type === 'advanced') {
dispatch(canvasSessionReset());
} else {
// ctx.session.type === 'simple'
dispatch(generateSessionReset());
}
}, [deleteQueueItemsByDestination, ctx.session.id, ctx.session.type, dispatch]);
return (
<IconButton
@@ -22,8 +30,9 @@ export const StagingAreaToolbarDiscardAllButton = memo(({ isDisabled }: { isDisa
icon={<PiTrashSimpleBold />}
onClick={discardAll}
colorScheme="error"
isDisabled={isDisabled || cancelQueueItemsByDestination.isDisabled}
isLoading={cancelQueueItemsByDestination.isLoading}
fontSize={16}
isDisabled={isDisabled || deleteQueueItemsByDestination.isDisabled}
isLoading={deleteQueueItemsByDestination.isLoading}
/>
);
});

View File

@@ -1,14 +1,17 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch } from 'app/store/storeHooks';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { useCancelQueueItem } from 'features/queue/hooks/useCancelQueueItem';
import { canvasSessionReset, generateSessionReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useDeleteQueueItem } from 'features/queue/hooks/useDeleteQueueItem';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
export const StagingAreaToolbarDiscardSelectedButton = memo(({ isDisabled }: { isDisabled?: boolean }) => {
const dispatch = useAppDispatch();
const ctx = useCanvasSessionContext();
const cancelQueueItem = useCancelQueueItem();
const deleteQueueItem = useDeleteQueueItem();
const selectedItemId = useStore(ctx.$selectedItemId);
const { t } = useTranslation();
@@ -17,9 +20,17 @@ export const StagingAreaToolbarDiscardSelectedButton = memo(({ isDisabled }: { i
if (selectedItemId === null) {
return;
}
ctx.discard(selectedItemId);
await cancelQueueItem.trigger(selectedItemId, { withToast: false });
}, [selectedItemId, ctx, cancelQueueItem]);
await deleteQueueItem.trigger(selectedItemId);
const itemCount = ctx.$itemCount.get();
if (itemCount <= 1) {
if (ctx.session.type === 'advanced') {
dispatch(canvasSessionReset());
} else {
// ctx.session.type === 'simple'
dispatch(generateSessionReset());
}
}
}, [selectedItemId, deleteQueueItem, ctx.$itemCount, ctx.session.type, dispatch]);
return (
<IconButton
@@ -28,8 +39,9 @@ export const StagingAreaToolbarDiscardSelectedButton = memo(({ isDisabled }: { i
icon={<PiXBold />}
onClick={discardSelected}
colorScheme="invokeBlue"
isDisabled={selectedItemId === null || cancelQueueItem.isDisabled || isDisabled}
isLoading={cancelQueueItem.isLoading}
fontSize={16}
isDisabled={selectedItemId === null || deleteQueueItem.isDisabled || isDisabled}
isLoading={deleteQueueItem.isLoading}
/>
);
});

View File

@@ -1,13 +1,16 @@
import { IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library';
import { IconButton, Menu, MenuButton, MenuDivider, MenuList } from '@invoke-ai/ui-library';
import { StagingAreaToolbarMenuAutoSwitch } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarMenuAutoSwitch';
import { StagingAreaToolbarNewLayerFromImageMenuItems } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarMenuNewLayerFromImage';
import { memo } from 'react';
import { PiDotsThreeVerticalBold } from 'react-icons/pi';
import { PiDotsThreeBold } from 'react-icons/pi';
export const StagingAreaToolbarMenu = memo(() => {
return (
<Menu>
<MenuButton as={IconButton} icon={<PiDotsThreeVerticalBold />} colorScheme="invokeBlue" />
<MenuButton as={IconButton} icon={<PiDotsThreeBold />} colorScheme="invokeBlue" />
<MenuList>
<StagingAreaToolbarMenuAutoSwitch />
<MenuDivider />
<StagingAreaToolbarNewLayerFromImageMenuItems />
</MenuList>
</Menu>

View File

@@ -0,0 +1,34 @@
import { MenuItemOption, MenuOptionGroup } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { isAutoSwitchMode, useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { memo, useCallback } from 'react';
import { assert } from 'tsafe';
export const StagingAreaToolbarMenuAutoSwitch = memo(() => {
const ctx = useCanvasSessionContext();
const autoSwitch = useStore(ctx.$autoSwitch);
const onChange = useCallback(
(val: string | string[]) => {
assert(isAutoSwitchMode(val));
ctx.$autoSwitch.set(val);
},
[ctx.$autoSwitch]
);
return (
<MenuOptionGroup value={autoSwitch} onChange={onChange} title="Auto-Switch" type="radio">
<MenuItemOption value="off" closeOnSelect={false}>
Off
</MenuItemOption>
<MenuItemOption value="switch_on_start" closeOnSelect={false}>
Switch on Start
</MenuItemOption>
<MenuItemOption value="switch_on_finish" closeOnSelect={false}>
Switch on Finish
</MenuItemOption>
</MenuOptionGroup>
);
});
StagingAreaToolbarMenuAutoSwitch.displayName = 'StagingAreaToolbarMenuAutoSwitch';

View File

@@ -16,6 +16,7 @@ import {
rgRefImageAdded,
} from 'features/controlLayers/store/canvasSlice';
import { selectBase, selectMainModelConfig } from 'features/controlLayers/store/paramsSlice';
import { refImageAdded } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice, selectEntity } from 'features/controlLayers/store/selectors';
import type {
CanvasEntityIdentifier,
@@ -87,7 +88,7 @@ export const getDefaultRefImageConfig = (
return config;
}
if (base === 'flux-kontext' || (base === 'flux' && mainModelConfig?.name?.toLowerCase().includes('kontext'))) {
if (base === 'flux-kontext') {
const config = deepClone(initialFluxKontextReferenceImage);
config.model = zModelIdentifierField.parse(mainModelConfig);
return config;
@@ -185,6 +186,17 @@ export const useAddNewRegionalGuidanceWithARefImage = () => {
return func;
};
export const useAddGlobalReferenceImage = () => {
const { dispatch, getState } = useAppStore();
const func = useCallback(() => {
const config = getDefaultRefImageConfig(getState);
const overrides = { config };
dispatch(refImageAdded({ overrides }));
}, [dispatch, getState]);
return func;
};
export const useAddRefImageToExistingRegionalGuidance = (
entityIdentifier: CanvasEntityIdentifier<'regional_guidance'>
) => {

View File

@@ -71,29 +71,15 @@ export const useExportCanvasToPSD = () => {
const psdLayers: Layer[] = await Promise.all(
adapters.map((adapter, index) => {
const layer = adapter.state;
// Get the actual content bounds for this layer (excluding transparent regions)
const canvas = adapter.getCanvas();
const layerPosition = adapter.state.position;
const pixelRect = adapter.transformer.$pixelRect.get();
// Calculate the layer's content bounds in stage coordinates
const layerContentBounds = {
x: layerPosition.x + pixelRect.x,
y: layerPosition.y + pixelRect.y,
width: pixelRect.width,
height: pixelRect.height,
};
// Get the canvas cropped to the layer's actual content bounds
const canvas = adapter.getCanvas(layerContentBounds);
const layerDataPSD: Layer = {
name: layer.name || `Layer ${index + 1}`,
// Position relative to the visible rect, using the actual content bounds
left: Math.floor(layerContentBounds.x - visibleRect.x),
top: Math.floor(layerContentBounds.y - visibleRect.y),
right: Math.floor(layerContentBounds.x - visibleRect.x + canvas.width),
bottom: Math.floor(layerContentBounds.y - visibleRect.y + canvas.height),
left: Math.floor(layerPosition.x - visibleRect.x),
top: Math.floor(layerPosition.y - visibleRect.y),
right: Math.floor(layerPosition.x - visibleRect.x + canvas.width),
bottom: Math.floor(layerPosition.y - visibleRect.y + canvas.height),
opacity: Math.floor(layer.opacity * 255),
hidden: false,
blendMode: 'normal',

View File

@@ -424,15 +424,9 @@ export class CanvasEntityFilterer extends CanvasModuleBase {
// the user has applied the filter and the image has been adopted by the parent entity.
if (this.imageModule && this.imageModule.konva.group.parent === this.konva.group) {
this.imageModule.destroy();
this.imageModule = null;
}
// When a filter is applied, the image module is adopted by the parent entity as a "permanent" module.
// Null this reference to prevent the filter module from accidentally trying to destroy a module that the
// parent entity is now responsible for.
this.imageModule = null;
const initialFilterConfig = deepClone(this.$initialFilterConfig.get() ?? this.createInitialFilterConfig());
this.$filterConfig.set(initialFilterConfig);
this.$imageState.set(null);
this.$lastProcessedHash.set('');

View File

@@ -59,7 +59,6 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
$shouldShowStagedImage = atom<boolean>(true);
$isStaging = atom<boolean>(false);
$isPending = atom<boolean>(false);
constructor(manager: CanvasManager) {
super();
@@ -153,11 +152,7 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
this.$isStaging.set(this.manager.stateApi.runSelector(selectIsStaging));
};
connectToSession = (
$selectedItemId: Atom<number | null>,
$progressData: ProgressDataMap,
$isPending: Atom<boolean>
) => {
connectToSession = ($selectedItemId: Atom<number | null>, $progressData: ProgressDataMap) => {
const cb = (selectedItemId: number | null, progressData: Record<number, ProgressData | undefined>) => {
if (!selectedItemId) {
this.$imageSrc.set(null);
@@ -181,17 +176,7 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
cb($selectedItemId.get(), $progressData.get());
this.render();
// Sync the $isPending flag with the computed
const unsubIsPending = effect([$isPending], (isPending) => {
this.$isPending.set(isPending);
});
const unsubImageSrc = effect([$selectedItemId, $progressData], cb);
return () => {
unsubIsPending();
unsubImageSrc();
};
return effect([$selectedItemId, $progressData], cb);
};
private _getImageFromSrc = (
@@ -221,7 +206,6 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
const { x, y, width, height } = this.manager.stateApi.getBbox().rect;
const shouldShowStagedImage = this.$shouldShowStagedImage.get();
const isPending = this.$isPending.get();
this.konva.group.position({ x, y });
@@ -242,8 +226,7 @@ export class CanvasStagingAreaModule extends CanvasModuleBase {
} else {
this.image?.destroy();
this.image = null;
// Only show placeholder if there are pending items, otherwise show nothing
this.konva.placeholder.group.visible(isPending);
this.konva.placeholder.group.visible(true);
}
this.konva.group.visible(shouldShowStagedImage && this.$isStaging.get());

View File

@@ -270,7 +270,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
outputNodeId: string;
options?: RunGraphOptions;
}): Promise<ImageDTO> => {
const dependencies = buildRunGraphDependencies(this.store.dispatch, this.manager.socket);
const dependencies = buildRunGraphDependencies(this.store, this.manager.socket);
const { output } = await runGraph({
dependencies,

View File

@@ -1,41 +1,38 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { zRgbaColor } from 'features/controlLayers/store/types';
import { z } from 'zod/v4';
import type { RgbaColor } from 'features/controlLayers/store/types';
const zAutoSwitchMode = z.enum(['off', 'switch_on_start', 'switch_on_finish']);
const zCanvasSettingsState = z.object({
type CanvasSettingsState = {
/**
* Whether to show HUD (Heads-Up Display) on the canvas.
*/
showHUD: z.boolean().default(true),
showHUD: boolean;
/**
* Whether to clip lines and shapes to the generation bounding box. If disabled, lines and shapes will be clipped to
* the canvas bounds.
*/
clipToBbox: z.boolean().default(false),
clipToBbox: boolean;
/**
* Whether to show a dynamic grid on the canvas. If disabled, a checkerboard pattern will be shown instead.
*/
dynamicGrid: z.boolean().default(false),
dynamicGrid: boolean;
/**
* Whether to invert the scroll direction when adjusting the brush or eraser width with the scroll wheel.
*/
invertScrollForToolWidth: z.boolean().default(false),
invertScrollForToolWidth: boolean;
/**
* The width of the brush tool.
*/
brushWidth: z.int().gt(0).default(50),
brushWidth: number;
/**
* The width of the eraser tool.
*/
eraserWidth: z.int().gt(0).default(50),
eraserWidth: number;
/**
* The color to use when drawing lines or filling shapes.
*/
color: zRgbaColor.default({ r: 31, g: 160, b: 224, a: 1 }), // invokeBlue.500
color: RgbaColor;
/**
* Whether to composite inpainted/outpainted regions back onto the source image when saving canvas generations.
*
@@ -43,61 +40,70 @@ const zCanvasSettingsState = z.object({
*
* When `sendToCanvas` is disabled, this setting is ignored, masked regions will always be composited.
*/
outputOnlyMaskedRegions: z.boolean().default(true),
outputOnlyMaskedRegions: boolean;
/**
* Whether to automatically process the operations like filtering and auto-masking.
*/
autoProcess: z.boolean().default(true),
autoProcess: boolean;
/**
* The snap-to-grid setting for the canvas.
*/
snapToGrid: z.boolean().default(true),
snapToGrid: boolean;
/**
* Whether to show progress on the canvas when generating images.
*/
showProgressOnCanvas: z.boolean().default(true),
showProgressOnCanvas: boolean;
/**
* Whether to show the bounding box overlay on the canvas.
*/
bboxOverlay: z.boolean().default(false),
bboxOverlay: boolean;
/**
* Whether to preserve the masked region instead of inpainting it.
*/
preserveMask: z.boolean().default(false),
preserveMask: boolean;
/**
* Whether to show only raster layers while staging.
*/
isolatedStagingPreview: z.boolean().default(true),
isolatedStagingPreview: boolean;
/**
* Whether to show only the selected layer while filtering, transforming, or doing other operations.
*/
isolatedLayerPreview: z.boolean().default(true),
isolatedLayerPreview: boolean;
/**
* Whether to use pressure sensitivity for the brush and eraser tool when a pen device is used.
*/
pressureSensitivity: z.boolean().default(true),
pressureSensitivity: boolean;
/**
* Whether to show the rule of thirds composition guide overlay on the canvas.
*/
ruleOfThirds: z.boolean().default(false),
/**
* Whether to save all staging images to the gallery instead of keeping them as intermediate images.
*/
saveAllImagesToGallery: z.boolean().default(false),
/**
* The auto-switch mode for the canvas staging area.
*/
stagingAreaAutoSwitch: zAutoSwitchMode.default('switch_on_start'),
});
ruleOfThirds: boolean;
};
type CanvasSettingsState = z.infer<typeof zCanvasSettingsState>;
const getInitialState = () => zCanvasSettingsState.parse({});
const initialState: CanvasSettingsState = {
showHUD: true,
clipToBbox: false,
dynamicGrid: false,
brushWidth: 50,
eraserWidth: 50,
invertScrollForToolWidth: false,
color: { r: 31, g: 160, b: 224, a: 1 }, // invokeBlue.500
outputOnlyMaskedRegions: true,
autoProcess: true,
snapToGrid: true,
showProgressOnCanvas: true,
bboxOverlay: false,
preserveMask: false,
isolatedStagingPreview: true,
isolatedLayerPreview: true,
pressureSensitivity: true,
ruleOfThirds: false,
};
export const canvasSettingsSlice = createSlice({
name: 'canvasSettings',
initialState: getInitialState(),
initialState,
reducers: {
settingsClipToBboxChanged: (state, action: PayloadAction<CanvasSettingsState['clipToBbox']>) => {
settingsClipToBboxChanged: (state, action: PayloadAction<boolean>) => {
state.clipToBbox = action.payload;
},
settingsDynamicGridToggled: (state) => {
@@ -106,19 +112,16 @@ export const canvasSettingsSlice = createSlice({
settingsShowHUDToggled: (state) => {
state.showHUD = !state.showHUD;
},
settingsBrushWidthChanged: (state, action: PayloadAction<CanvasSettingsState['brushWidth']>) => {
settingsBrushWidthChanged: (state, action: PayloadAction<number>) => {
state.brushWidth = Math.round(action.payload);
},
settingsEraserWidthChanged: (state, action: PayloadAction<CanvasSettingsState['eraserWidth']>) => {
settingsEraserWidthChanged: (state, action: PayloadAction<number>) => {
state.eraserWidth = Math.round(action.payload);
},
settingsColorChanged: (state, action: PayloadAction<CanvasSettingsState['color']>) => {
settingsColorChanged: (state, action: PayloadAction<RgbaColor>) => {
state.color = action.payload;
},
settingsInvertScrollForToolWidthChanged: (
state,
action: PayloadAction<CanvasSettingsState['invertScrollForToolWidth']>
) => {
settingsInvertScrollForToolWidthChanged: (state, action: PayloadAction<boolean>) => {
state.invertScrollForToolWidth = action.payload;
},
settingsOutputOnlyMaskedRegionsToggled: (state) => {
@@ -151,15 +154,6 @@ export const canvasSettingsSlice = createSlice({
settingsRuleOfThirdsToggled: (state) => {
state.ruleOfThirds = !state.ruleOfThirds;
},
settingsSaveAllImagesToGalleryToggled: (state) => {
state.saveAllImagesToGallery = !state.saveAllImagesToGallery;
},
settingsStagingAreaAutoSwitchChanged: (
state,
action: PayloadAction<CanvasSettingsState['stagingAreaAutoSwitch']>
) => {
state.stagingAreaAutoSwitch = action.payload;
},
},
});
@@ -181,8 +175,6 @@ export const {
settingsIsolatedLayerPreviewToggled,
settingsPressureSensitivityToggled,
settingsRuleOfThirdsToggled,
settingsSaveAllImagesToGalleryToggled,
settingsStagingAreaAutoSwitchChanged,
} = canvasSettingsSlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
@@ -192,7 +184,7 @@ const migrate = (state: any): any => {
export const canvasSettingsPersistConfig: PersistConfig<CanvasSettingsState> = {
name: canvasSettingsSlice.name,
initialState: getInitialState(),
initialState,
migrate,
persistDenylist: [],
};
@@ -217,5 +209,3 @@ export const selectIsolatedStagingPreview = createCanvasSettingsSelector((settin
export const selectIsolatedLayerPreview = createCanvasSettingsSelector((settings) => settings.isolatedLayerPreview);
export const selectPressureSensitivity = createCanvasSettingsSelector((settings) => settings.pressureSensitivity);
export const selectRuleOfThirds = createCanvasSettingsSelector((settings) => settings.ruleOfThirds);
export const selectSaveAllImagesToGallery = createCanvasSettingsSelector((settings) => settings.saveAllImagesToGallery);
export const selectStagingAreaAutoSwitch = createCanvasSettingsSelector((settings) => settings.stagingAreaAutoSwitch);

View File

@@ -32,6 +32,7 @@ import {
} from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify';
import { isMainModelBase, zModelIdentifierField } from 'features/nodes/types/common';
import { ASPECT_RATIO_MAP } from 'features/parameters/components/Bbox/constants';
import { API_BASE_MODELS } from 'features/parameters/types/constants';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import type { IRect } from 'konva/lib/types';
@@ -68,13 +69,9 @@ import type {
T2IAdapterConfig,
} from './types';
import {
ASPECT_RATIO_MAP,
CHATGPT_ASPECT_RATIOS,
DEFAULT_ASPECT_RATIO_CONFIG,
FLUX_KONTEXT_ASPECT_RATIOS,
getEntityIdentifier,
getInitialCanvasState,
IMAGEN_ASPECT_RATIOS,
isChatGPT4oAspectRatioID,
isFluxKontextAspectRatioID,
isFLUXReduxConfig,
@@ -1103,21 +1100,62 @@ export const canvasSlice = createSlice({
(state.bbox.modelBase === 'imagen3' || state.bbox.modelBase === 'imagen4') &&
isImagenAspectRatioID(id)
) {
const { width, height } = IMAGEN_ASPECT_RATIOS[id];
state.bbox.rect.width = width;
state.bbox.rect.height = height;
// Imagen3 has specific output sizes that are not exactly the same as the aspect ratio. Need special handling.
if (id === '16:9') {
state.bbox.rect.width = 1408;
state.bbox.rect.height = 768;
} else if (id === '4:3') {
state.bbox.rect.width = 1280;
state.bbox.rect.height = 896;
} else if (id === '1:1') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1024;
} else if (id === '3:4') {
state.bbox.rect.width = 896;
state.bbox.rect.height = 1280;
} else if (id === '9:16') {
state.bbox.rect.width = 768;
state.bbox.rect.height = 1408;
}
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else if (state.bbox.modelBase === 'chatgpt-4o' && isChatGPT4oAspectRatioID(id)) {
const { width, height } = CHATGPT_ASPECT_RATIOS[id];
state.bbox.rect.width = width;
state.bbox.rect.height = height;
// gpt-image has specific output sizes that are not exactly the same as the aspect ratio. Need special handling.
if (id === '3:2') {
state.bbox.rect.width = 1536;
state.bbox.rect.height = 1024;
} else if (id === '1:1') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1024;
} else if (id === '2:3') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1536;
}
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else if (state.bbox.modelBase === 'flux-kontext' && isFluxKontextAspectRatioID(id)) {
const { width, height } = FLUX_KONTEXT_ASPECT_RATIOS[id];
state.bbox.rect.width = width;
state.bbox.rect.height = height;
if (id === '3:4') {
state.bbox.rect.width = 880;
state.bbox.rect.height = 1184;
} else if (id === '4:3') {
state.bbox.rect.width = 1184;
state.bbox.rect.height = 880;
} else if (id === '9:16') {
state.bbox.rect.width = 752;
state.bbox.rect.height = 1392;
} else if (id === '16:9') {
state.bbox.rect.width = 1392;
state.bbox.rect.height = 752;
} else if (id === '21:9') {
state.bbox.rect.width = 1568;
state.bbox.rect.height = 672;
} else if (id === '9:21') {
state.bbox.rect.width = 672;
state.bbox.rect.height = 1568;
} else if (id === '1:1') {
state.bbox.rect.width = 1024;
state.bbox.rect.height = 1024;
}
state.bbox.aspectRatio.value = state.bbox.rect.width / state.bbox.rect.height;
state.bbox.aspectRatio.isLocked = true;
} else {

View File

@@ -1,20 +1,16 @@
import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit';
import { EMPTY_ARRAY } from 'app/store/constants';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { canvasReset } from 'features/controlLayers/store/actions';
import { queueApi } from 'services/api/endpoints/queue';
type CanvasStagingAreaState = {
generateSessionId: string | null;
canvasSessionId: string | null;
canvasDiscardedQueueItems: number[];
};
const INITIAL_STATE: CanvasStagingAreaState = {
generateSessionId: null,
canvasSessionId: null,
canvasDiscardedQueueItems: [],
};
const getInitialState = (): CanvasStagingAreaState => deepClone(INITIAL_STATE);
@@ -30,20 +26,12 @@ export const canvasSessionSlice = createSlice({
generateSessionReset: (state) => {
state.generateSessionId = null;
},
canvasQueueItemDiscarded: (state, action: PayloadAction<{ itemId: number }>) => {
const { itemId } = action.payload;
if (!state.canvasDiscardedQueueItems.includes(itemId)) {
state.canvasDiscardedQueueItems.push(itemId);
}
},
canvasSessionIdChanged: (state, action: PayloadAction<{ id: string }>) => {
const { id } = action.payload;
state.canvasSessionId = id;
state.canvasDiscardedQueueItems = [];
},
canvasSessionReset: (state) => {
state.canvasSessionId = null;
state.canvasDiscardedQueueItems = [];
},
},
extraReducers(builder) {
@@ -53,13 +41,8 @@ export const canvasSessionSlice = createSlice({
},
});
export const {
generateSessionIdChanged,
generateSessionReset,
canvasSessionIdChanged,
canvasSessionReset,
canvasQueueItemDiscarded,
} = canvasSessionSlice.actions;
export const { generateSessionIdChanged, generateSessionReset, canvasSessionIdChanged, canvasSessionReset } =
canvasSessionSlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
const migrate = (state: any): any => {
@@ -80,34 +63,4 @@ export const selectGenerateSessionId = createSelector(
selectCanvasSessionSlice,
({ generateSessionId }) => generateSessionId
);
export const buildSelectSessionQueueItems = (sessionId: string) =>
createSelector(
[queueApi.endpoints.listAllQueueItems.select({ destination: sessionId }), selectDiscardedItems],
({ data }, discardedItems) => {
if (!data) {
return EMPTY_ARRAY;
}
return data.filter(
({ status, item_id }) => status !== 'canceled' && status !== 'failed' && !discardedItems.includes(item_id)
);
}
);
export const selectIsStaging = (state: RootState) => {
const sessionId = selectCanvasSessionId(state);
if (!sessionId) {
return false;
}
const { data } = queueApi.endpoints.listAllQueueItems.select({ destination: sessionId })(state);
if (!data) {
return false;
}
const discardedItems = selectDiscardedItems(state);
return data.some(
({ status, item_id }) => status !== 'canceled' && status !== 'failed' && !discardedItems.includes(item_id)
);
};
const selectDiscardedItems = createSelector(
selectCanvasSessionSlice,
({ canvasDiscardedQueueItems }) => canvasDiscardedQueueItems
);
export const selectIsStaging = createSelector(selectCanvasSessionId, (canvasSessionId) => canvasSessionId !== null);

View File

@@ -11,7 +11,7 @@ type LoRAsState = {
loras: LoRA[];
};
const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
export const defaultLoRAConfig: Pick<LoRA, 'weight' | 'isEnabled'> = {
weight: 0.75,
isEnabled: true,
};

View File

@@ -1,22 +1,9 @@
import type { PayloadAction, Selector } from '@reduxjs/toolkit';
import { createSelector, createSlice } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { roundDownToMultiple, roundToMultiple } from 'common/util/roundDownToMultiple';
import { clamp } from 'es-toolkit/compat';
import type { AspectRatioID, ParamsState, RgbaColor } from 'features/controlLayers/store/types';
import {
ASPECT_RATIO_MAP,
CHATGPT_ASPECT_RATIOS,
DEFAULT_ASPECT_RATIO_CONFIG,
FLUX_KONTEXT_ASPECT_RATIOS,
getInitialParamsState,
IMAGEN_ASPECT_RATIOS,
isChatGPT4oAspectRatioID,
isFluxKontextAspectRatioID,
isImagenAspectRatioID,
} from 'features/controlLayers/store/types';
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import type { ParamsState, RgbaColor } from 'features/controlLayers/store/types';
import { getInitialParamsState } from 'features/controlLayers/store/types';
import { CLIP_SKIP_MAP } from 'features/parameters/types/constants';
import type {
ParameterCanvasCoherenceMode,
@@ -36,7 +23,6 @@ import type {
ParameterT5EncoderModel,
ParameterVAEModel,
} from 'features/parameters/types/parameterSchemas';
import { getGridSize, getIsSizeOptimal, getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { modelConfigsAdapterSelectors, selectModelConfigsQuery } from 'services/api/endpoints/models';
import { isNonRefinerMainModelConfig } from 'services/api/types';
@@ -200,129 +186,6 @@ export const paramsSlice = createSlice({
setCanvasCoherenceMinDenoise: (state, action: PayloadAction<number>) => {
state.canvasCoherenceMinDenoise = action.payload;
},
//#region Dimensions
widthChanged: (state, action: PayloadAction<{ width: number; updateAspectRatio?: boolean; clamp?: boolean }>) => {
const { width, updateAspectRatio, clamp } = action.payload;
const gridSize = getGridSize(state.model?.base);
state.dimensions.rect.width = clamp ? Math.max(roundDownToMultiple(width, gridSize), 64) : width;
if (state.dimensions.aspectRatio.isLocked) {
state.dimensions.rect.height = roundToMultiple(
state.dimensions.rect.width / state.dimensions.aspectRatio.value,
gridSize
);
}
if (updateAspectRatio || !state.dimensions.aspectRatio.isLocked) {
state.dimensions.aspectRatio.value = state.dimensions.rect.width / state.dimensions.rect.height;
state.dimensions.aspectRatio.id = 'Free';
state.dimensions.aspectRatio.isLocked = false;
}
},
heightChanged: (state, action: PayloadAction<{ height: number; updateAspectRatio?: boolean; clamp?: boolean }>) => {
const { height, updateAspectRatio, clamp } = action.payload;
const gridSize = getGridSize(state.model?.base);
state.dimensions.rect.height = clamp ? Math.max(roundDownToMultiple(height, gridSize), 64) : height;
if (state.dimensions.aspectRatio.isLocked) {
state.dimensions.rect.width = roundToMultiple(
state.dimensions.rect.height * state.dimensions.aspectRatio.value,
gridSize
);
}
if (updateAspectRatio || !state.dimensions.aspectRatio.isLocked) {
state.dimensions.aspectRatio.value = state.dimensions.rect.width / state.dimensions.rect.height;
state.dimensions.aspectRatio.id = 'Free';
state.dimensions.aspectRatio.isLocked = false;
}
},
aspectRatioLockToggled: (state) => {
state.dimensions.aspectRatio.isLocked = !state.dimensions.aspectRatio.isLocked;
},
aspectRatioIdChanged: (state, action: PayloadAction<{ id: AspectRatioID }>) => {
const { id } = action.payload;
state.dimensions.aspectRatio.id = id;
if (id === 'Free') {
state.dimensions.aspectRatio.isLocked = false;
} else if ((state.model?.base === 'imagen3' || state.model?.base === 'imagen4') && isImagenAspectRatioID(id)) {
const { width, height } = IMAGEN_ASPECT_RATIOS[id];
state.dimensions.rect.width = width;
state.dimensions.rect.height = height;
state.dimensions.aspectRatio.value = state.dimensions.rect.width / state.dimensions.rect.height;
state.dimensions.aspectRatio.isLocked = true;
} else if (state.model?.base === 'chatgpt-4o' && isChatGPT4oAspectRatioID(id)) {
const { width, height } = CHATGPT_ASPECT_RATIOS[id];
state.dimensions.rect.width = width;
state.dimensions.rect.height = height;
state.dimensions.aspectRatio.value = state.dimensions.rect.width / state.dimensions.rect.height;
state.dimensions.aspectRatio.isLocked = true;
} else if (state.model?.base === 'flux-kontext' && isFluxKontextAspectRatioID(id)) {
const { width, height } = FLUX_KONTEXT_ASPECT_RATIOS[id];
state.dimensions.rect.width = width;
state.dimensions.rect.height = height;
state.dimensions.aspectRatio.value = state.dimensions.rect.width / state.dimensions.rect.height;
state.dimensions.aspectRatio.isLocked = true;
} else {
state.dimensions.aspectRatio.isLocked = true;
state.dimensions.aspectRatio.value = ASPECT_RATIO_MAP[id].ratio;
const { width, height } = calculateNewSize(
state.dimensions.aspectRatio.value,
state.dimensions.rect.width * state.dimensions.rect.height,
state.model?.base
);
state.dimensions.rect.width = width;
state.dimensions.rect.height = height;
}
},
dimensionsSwapped: (state) => {
state.dimensions.aspectRatio.value = 1 / state.dimensions.aspectRatio.value;
if (state.dimensions.aspectRatio.id === 'Free') {
const newWidth = state.dimensions.rect.height;
const newHeight = state.dimensions.rect.width;
state.dimensions.rect.width = newWidth;
state.dimensions.rect.height = newHeight;
} else {
const { width, height } = calculateNewSize(
state.dimensions.aspectRatio.value,
state.dimensions.rect.width * state.dimensions.rect.height,
state.model?.base
);
state.dimensions.rect.width = width;
state.dimensions.rect.height = height;
state.dimensions.aspectRatio.id = ASPECT_RATIO_MAP[state.dimensions.aspectRatio.id].inverseID;
}
},
sizeOptimized: (state) => {
const optimalDimension = getOptimalDimension(state.model?.base);
if (state.dimensions.aspectRatio.isLocked) {
const { width, height } = calculateNewSize(
state.dimensions.aspectRatio.value,
optimalDimension * optimalDimension,
state.model?.base
);
state.dimensions.rect.width = width;
state.dimensions.rect.height = height;
} else {
state.dimensions.aspectRatio = deepClone(DEFAULT_ASPECT_RATIO_CONFIG);
state.dimensions.rect.width = optimalDimension;
state.dimensions.rect.height = optimalDimension;
}
},
syncedToOptimalDimension: (state) => {
const optimalDimension = getOptimalDimension(state.model?.base);
if (!getIsSizeOptimal(state.dimensions.rect.width, state.dimensions.rect.height, state.model?.base)) {
const bboxDims = calculateNewSize(
state.dimensions.aspectRatio.value,
optimalDimension * optimalDimension,
state.model?.base
);
state.dimensions.rect.width = bboxDims.width;
state.dimensions.rect.height = bboxDims.height;
}
},
paramsReset: (state) => resetState(state),
},
});
@@ -386,16 +249,6 @@ export const {
setRefinerNegativeAestheticScore,
setRefinerStart,
modelChanged,
// Dimensions
widthChanged,
heightChanged,
aspectRatioLockToggled,
aspectRatioIdChanged,
dimensionsSwapped,
sizeOptimized,
syncedToOptimalDimension,
paramsReset,
} = paramsSlice.actions;
@@ -422,16 +275,7 @@ export const selectIsSD3 = createParamsSelector((params) => params.model?.base =
export const selectIsCogView4 = createParamsSelector((params) => params.model?.base === 'cogview4');
export const selectIsImagen3 = createParamsSelector((params) => params.model?.base === 'imagen3');
export const selectIsImagen4 = createParamsSelector((params) => params.model?.base === 'imagen4');
export const selectIsFluxKontextApi = createParamsSelector((params) => params.model?.base === 'flux-kontext');
export const selectIsFluxKontext = createParamsSelector((params) => {
if (params.model?.base === 'flux-kontext') {
return true;
}
if (params.model?.base === 'flux' && params.model?.name.toLowerCase().includes('kontext')) {
return true;
}
return false;
});
export const selectIsFluxKontext = createParamsSelector((params) => params.model?.base === 'flux-kontext');
export const selectIsChatGPT4o = createParamsSelector((params) => params.model?.base === 'chatgpt-4o');
export const selectModel = createParamsSelector((params) => params.model);
@@ -467,8 +311,8 @@ export const selectNegativePrompt = createParamsSelector((params) => params.nega
export const selectNegativePromptWithFallback = createParamsSelector((params) => params.negativePrompt ?? '');
export const selectHasNegativePrompt = createParamsSelector((params) => params.negativePrompt !== null);
export const selectModelSupportsNegativePrompt = createSelector(
[selectIsFLUX, selectIsChatGPT4o, selectIsFluxKontext],
(isFLUX, isChatGPT4o, isFluxKontext) => !isFLUX && !isChatGPT4o && !isFluxKontext
[selectIsFLUX, selectIsChatGPT4o],
(isFLUX, isChatGPT4o) => !isFLUX && !isChatGPT4o
);
export const selectPositivePrompt2 = createParamsSelector((params) => params.positivePrompt2);
export const selectNegativePrompt2 = createParamsSelector((params) => params.negativePrompt2);
@@ -498,12 +342,6 @@ export const selectRefinerScheduler = createParamsSelector((params) => params.re
export const selectRefinerStart = createParamsSelector((params) => params.refinerStart);
export const selectRefinerSteps = createParamsSelector((params) => params.refinerSteps);
export const selectWidth = createParamsSelector((params) => params.dimensions.rect.width);
export const selectHeight = createParamsSelector((params) => params.dimensions.rect.height);
export const selectAspectRatioID = createParamsSelector((params) => params.dimensions.aspectRatio.id);
export const selectAspectRatioValue = createParamsSelector((params) => params.dimensions.aspectRatio.value);
export const selectAspectRatioIsLocked = createParamsSelector((params) => params.dimensions.aspectRatio.isLocked);
export const selectMainModelConfig = createSelector(
selectModelConfigsQuery,
selectParamsSlice,

View File

@@ -5,15 +5,10 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { PersistConfig, RootState } from 'app/store/store';
import { clamp } from 'es-toolkit/compat';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import { canvasMetadataRecalled } from 'features/controlLayers/store/canvasSlice';
import type { FLUXReduxImageInfluence, RefImagesState } from 'features/controlLayers/store/types';
import { zModelIdentifierField } from 'features/nodes/types/common';
import type {
ChatGPT4oModelConfig,
FLUXKontextModelConfig,
FLUXReduxModelConfig,
ImageDTO,
IPAdapterModelConfig,
} from 'services/api/types';
import type { ApiModelConfig, FLUXReduxModelConfig, ImageDTO, IPAdapterModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import type { PartialDeep } from 'type-fest';
@@ -23,7 +18,6 @@ import {
getReferenceImageState,
imageDTOToImageWithDims,
initialChatGPT4oReferenceImage,
initialFluxKontextReferenceImage,
initialFLUXRedux,
initialIPAdapter,
} from './util';
@@ -53,15 +47,9 @@ export const refImagesSlice = createSlice({
payload: { ...payload, id: getPrefixedId('reference_image') },
}),
},
refImagesRecalled: (state, action: PayloadAction<{ entities: RefImageState[]; replace: boolean }>) => {
const { entities, replace } = action.payload;
if (replace) {
state.entities = entities;
state.isPanelOpen = false;
state.selectedEntityId = null;
} else {
state.entities.push(...entities);
}
refImageRecalled: (state, action: PayloadAction<{ data: RefImageState }>) => {
const { data } = action.payload;
state.entities.push(data);
},
refImageImageChanged: (state, action: PayloadActionWithId<{ imageDTO: ImageDTO | null }>) => {
const { id, imageDTO } = action.payload;
@@ -98,9 +86,7 @@ export const refImagesSlice = createSlice({
},
refImageModelChanged: (
state,
action: PayloadActionWithId<{
modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ChatGPT4oModelConfig | FLUXKontextModelConfig | null;
}>
action: PayloadActionWithId<{ modelConfig: IPAdapterModelConfig | FLUXReduxModelConfig | ApiModelConfig | null }>
) => {
const { id, modelConfig } = action.payload;
const entity = selectRefImageEntity(state, id);
@@ -135,19 +121,6 @@ export const refImagesSlice = createSlice({
return;
}
if (
entity.config.model.base === 'flux-kontext' ||
(entity.config.model.base === 'flux' && entity.config.model.name?.toLowerCase().includes('kontext'))
) {
// Switching to flux-kontext ref image
entity.config = {
...initialFluxKontextReferenceImage,
image: entity.config.image,
model: entity.config.model,
};
return;
}
if (entity.config.model.type === 'flux_redux') {
// Switching to flux_redux
entity.config = {
@@ -238,16 +211,14 @@ export const refImagesSlice = createSlice({
}
state.selectedEntityId = id;
},
refImageIsEnabledToggled: (state, action: PayloadActionWithId) => {
const { id } = action.payload;
const entity = selectRefImageEntity(state, id);
if (!entity) {
return;
}
entity.isEnabled = !entity.isEnabled;
},
refImagesReset: () => getInitialRefImagesState(),
},
extraReducers(builder) {
builder.addCase(canvasMetadataRecalled, (state, action) => {
const { referenceImages } = action.payload;
state.entities = referenceImages;
});
},
});
export const {
@@ -261,8 +232,6 @@ export const {
refImageIPAdapterWeightChanged,
refImageIPAdapterBeginEndStepPctChanged,
refImageFLUXReduxImageInfluenceChanged,
refImageIsEnabledToggled,
refImagesRecalled,
} = refImagesSlice.actions;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
@@ -287,6 +256,12 @@ export const selectRefImageEntityIds = createMemoizedSelector(selectReferenceIma
);
export const selectRefImageEntity = (state: RefImagesState, id: string) =>
state.entities.find((entity) => entity.id === id) ?? null;
export const selectSelectedRefEntity = createSelector(selectRefImagesSlice, (state) => {
if (!state.selectedEntityId) {
return null;
}
return selectRefImageEntity(state, state.selectedEntityId);
});
export function selectRefImageEntityOrThrow(state: RefImagesState, id: string, caller: string): RefImageState {
const entity = selectRefImageEntity(state, id);

View File

@@ -2,6 +2,7 @@ import type { Selector } from '@reduxjs/toolkit';
import { createSelector } from '@reduxjs/toolkit';
import type { RootState } from 'app/store/store';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectReferenceImageEntities } from 'features/controlLayers/store/refImagesSlice';
import type {
CanvasControlLayerState,
CanvasEntityIdentifier,
@@ -22,7 +23,8 @@ import { assert } from 'tsafe';
*/
export const selectCanvasSlice = (state: RootState) => state.canvas.present;
const createCanvasSelector = <T>(selector: Selector<CanvasState, T>) => createSelector(selectCanvasSlice, selector);
export const createCanvasSelector = <T>(selector: Selector<CanvasState, T>) =>
createSelector(selectCanvasSlice, selector);
/**
* Selects the total canvas entity count:
@@ -65,6 +67,36 @@ export const selectActiveRegionalGuidanceEntities = createSelector(selectRegiona
entities.filter(isVisibleEntity)
);
/**
* Selects the total _active_ canvas entity count:
* - Regions
* - IP adapters
* - Raster layers
* - Control layers
* - Inpaint masks
*
* Active entities are those that are enabled and have at least one object.
*/
export const selectEntityCountActive = createSelector(
selectActiveRasterLayerEntities,
selectActiveControlLayerEntities,
selectActiveInpaintMaskEntities,
selectActiveRegionalGuidanceEntities,
(
activeRasterLayerEntities,
activeControlLayerEntities,
activeInpaintMaskEntities,
activeRegionalGuidanceEntities
) => {
return (
activeRasterLayerEntities.length +
activeControlLayerEntities.length +
activeInpaintMaskEntities.length +
activeRegionalGuidanceEntities.length
);
}
);
/**
* Selects if the canvas has any entities.
*/
@@ -345,8 +377,10 @@ export const selectBboxModelBase = createSelector(selectBbox, (bbox) => bbox.mod
export const selectCanvasMetadata = createSelector(
selectCanvasSlice,
(canvas): { canvas_v2_metadata: CanvasMetadata } => {
selectReferenceImageEntities,
(canvas, refImageEntities): { canvas_v2_metadata: CanvasMetadata } => {
const canvas_v2_metadata: CanvasMetadata = {
referenceImages: refImageEntities,
controlLayers: selectAllEntitiesOfType(canvas, 'control_layer'),
inpaintMasks: selectAllEntitiesOfType(canvas, 'inpaint_mask'),
rasterLayers: selectAllEntitiesOfType(canvas, 'raster_layer'),
@@ -356,6 +390,23 @@ export const selectCanvasMetadata = createSelector(
}
);
export const selectIsCanvasEmpty = createCanvasSelector(
({ controlLayers, inpaintMasks, rasterLayers, regionalGuidance }) => {
// Check it all manually - could use lodash isEqual, but this selector will be called very often!
// Also note - we do not care about ref images, as they are technically not part of canvas
return (
controlLayers.entities.length === 0 &&
controlLayers.isHidden === false &&
inpaintMasks.entities.length === 0 &&
inpaintMasks.isHidden === false &&
rasterLayers.entities.length === 0 &&
rasterLayers.isHidden === false &&
regionalGuidance.entities.length === 0 &&
regionalGuidance.isHidden === false
);
}
);
/**
* Selects whether all non-raster layer categories (control layers, inpaint masks, regional guidance) are hidden.
* This is used to determine the state of the toggle button that shows/hides all non-raster layers.

Some files were not shown because too many files have changed in this diff Show More