Compare commits


5 Commits

Mary Hipp Rogers
649596cec5 fix 1:1 ratio (#8127)
Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-06-25 19:39:56 -04:00
psychedelicious
45aa84c01a feat: add user_label to FieldIdentifier (#8126)
Co-authored-by: Mary Hipp Rogers <maryhipp@gmail.com>
2025-06-25 09:48:15 -04:00
Mary Hipp Rogers
064d5787c9 Flux Kontext UI support (#8111)
* add support for flux-kontext models in nodes

* flux kontext in canvas

* add aspect ratio support

* lint

* restore aspect ratio logic

* more linting

* typegen

* fix typegen

---------

Co-authored-by: Mary Hipp <maryhipp@Marys-Air.lan>
2025-06-25 09:46:58 -04:00
psychedelicious
d81b23adff fix(nodes): ensure each invocation overrides _original_model_fields with own field data 2025-06-19 09:57:11 -04:00
psychedelicious
c72480fd1b fix: opencv dependency conflict (#8095)
* build: prevent `opencv-python` from being installed

Fixes this error: `AttributeError: module 'cv2.ximgproc' has no attribute 'thinning'`

`opencv-contrib-python` supersedes `opencv-python`, providing the same API + additional features. The two packages should not be installed at the same time to avoid conflicts and/or errors.

The `invisible-watermark` package requires `opencv-python`, but we require the contrib variant.

This change updates `pyproject.toml` to prevent `opencv-python` from ever being installed, using a `uv` feature called dependency overrides.

* feat(ui): data viewer supports disabling wrap

* feat(api): list _all_ pkgs in app deps endpoint

* chore(ui): typegen

* feat(ui): update about modal to display new full deps list

* chore: uv lock
2025-06-10 08:34:00 -04:00
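For reference, the dependency-override approach described in the commit message above would look roughly like the following in `pyproject.toml`. This is a hypothetical sketch rather than the exact entry in the InvokeAI repo; in particular, the never-matching environment marker used to exclude `opencv-python` is an assumption.

```toml
[tool.uv]
# Override any transitive requirement on opencv-python (e.g. from
# invisible-watermark) with a marker that can never be satisfied, so the
# package is never installed and opencv-contrib-python remains the only
# provider of the cv2 module.
override-dependencies = [
    "opencv-python; sys_platform == 'never'",
]
```

After editing the override, running `uv lock` regenerates the lockfile, which matches the final bullet above.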
736 changed files with 23227 additions and 35143 deletions


@@ -3,15 +3,15 @@ description: Installs frontend dependencies with pnpm, with caching
runs:
using: 'composite'
steps:
- name: setup node 20
- name: setup node 18
uses: actions/setup-node@v4
with:
node-version: '20'
node-version: '18'
- name: setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10
version: 8.15.6
run_install: false
- name: get pnpm store directory

.gitignore vendored

@@ -180,7 +180,6 @@ cython_debug/
# Scratch folder
.scratch/
.vscode/
.zed/
# source installer files
installer/*zip


@@ -297,7 +297,7 @@ Migration logic is in [migrations.ts].
<!-- links -->
[pydantic]: https://github.com/pydantic/pydantic 'pydantic'
[zod]: https://github.com/colinhacks/zod 'zod/v4'
[zod]: https://github.com/colinhacks/zod 'zod'
[openapi-types]: https://github.com/kogosoftwarellc/open-api/tree/main/packages/openapi-types 'openapi-types'
[reactflow]: https://github.com/xyflow/xyflow 'reactflow'
[reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions


@@ -35,7 +35,7 @@ More detail on system requirements can be found [here](./requirements.md).
## Step 2: Download
Download the most recent launcher for your operating system:
Download the most launcher for your operating system:
- [Download for Windows](https://download.invoke.ai/Invoke%20Community%20Edition.exe)
- [Download for macOS](https://download.invoke.ai/Invoke%20Community%20Edition.dmg)


@@ -1,12 +1,21 @@
from fastapi import Body, HTTPException
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
class AddImagesToBoardResult(BaseModel):
board_id: str = Field(description="The id of the board the images were added to")
added_image_names: list[str] = Field(description="The image names that were added to the board")
class RemoveImagesFromBoardResult(BaseModel):
removed_image_names: list[str] = Field(description="The image names that were removed from their board")
@board_images_router.post(
"/",
operation_id="add_image_to_board",
@@ -14,26 +23,17 @@ board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
201: {"description": "The image was added to a board successfully"},
},
status_code=201,
response_model=AddImagesToBoardResult,
)
async def add_image_to_board(
board_id: str = Body(description="The id of the board to add to"),
image_name: str = Body(description="The name of the image to add"),
) -> AddImagesToBoardResult:
):
"""Creates a board_image"""
try:
added_images: set[str] = set()
affected_boards: set[str] = set()
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
added_images.add(image_name)
affected_boards.add(board_id)
affected_boards.add(old_board_id)
return AddImagesToBoardResult(
added_images=list(added_images),
affected_boards=list(affected_boards),
result = ApiDependencies.invoker.services.board_images.add_image_to_board(
board_id=board_id, image_name=image_name
)
return result
except Exception:
raise HTTPException(status_code=500, detail="Failed to add image to board")
@@ -45,25 +45,14 @@ async def add_image_to_board(
201: {"description": "The image was removed from the board successfully"},
},
status_code=201,
response_model=RemoveImagesFromBoardResult,
)
async def remove_image_from_board(
image_name: str = Body(description="The name of the image to remove", embed=True),
) -> RemoveImagesFromBoardResult:
):
"""Removes an image from its board, if it had one"""
try:
removed_images: set[str] = set()
affected_boards: set[str] = set()
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_images.add(image_name)
affected_boards.add("none")
affected_boards.add(old_board_id)
return RemoveImagesFromBoardResult(
removed_images=list(removed_images),
affected_boards=list(affected_boards),
)
result = ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
return result
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove image from board")
@@ -83,25 +72,16 @@ async def add_images_to_board(
) -> AddImagesToBoardResult:
"""Adds a list of images to a board"""
try:
added_images: set[str] = set()
affected_boards: set[str] = set()
added_image_names: list[str] = []
for image_name in image_names:
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.add_image_to_board(
board_id=board_id,
image_name=image_name,
board_id=board_id, image_name=image_name
)
added_images.add(image_name)
affected_boards.add(board_id)
affected_boards.add(old_board_id)
added_image_names.append(image_name)
except Exception:
pass
return AddImagesToBoardResult(
added_images=list(added_images),
affected_boards=list(affected_boards),
)
return AddImagesToBoardResult(board_id=board_id, added_image_names=added_image_names)
except Exception:
raise HTTPException(status_code=500, detail="Failed to add images to board")
@@ -120,20 +100,13 @@ async def remove_images_from_board(
) -> RemoveImagesFromBoardResult:
"""Removes a list of images from their board, if they had one"""
try:
removed_images: set[str] = set()
affected_boards: set[str] = set()
removed_image_names: list[str] = []
for image_name in image_names:
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_images.add(image_name)
affected_boards.add("none")
affected_boards.add(old_board_id)
removed_image_names.append(image_name)
except Exception:
pass
return RemoveImagesFromBoardResult(
removed_images=list(removed_images),
affected_boards=list(affected_boards),
)
return RemoveImagesFromBoardResult(removed_image_names=removed_image_names)
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove images from board")


@@ -14,17 +14,10 @@ from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_i
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecordChanges,
ResourceOrigin,
)
from invokeai.app.services.images.images_common import (
DeleteImagesResult,
ImageDTO,
ImageUrlsDTO,
StarredImagesResult,
UnstarredImagesResult,
)
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.util.controlnet_utils import heuristic_resize_fast
@@ -72,7 +65,7 @@ async def upload_image(
resize_to: Optional[str] = Body(
default=None,
description=f"Dimensions to resize the image to, must be stringified tuple of 2 integers. Max total pixel count: {ResizeToDimensions.MAX_SIZE}",
examples=['"[1024,1024]"'],
example='"[1024,1024]"',
),
metadata: Optional[str] = Body(
default=None,
@@ -106,9 +99,7 @@ async def upload_image(
raise HTTPException(status_code=400, detail="Invalid resize_to format or size")
try:
# heuristic_resize_fast expects an RGB or RGBA image
pil_rgba = pil_image.convert("RGBA")
np_image = pil_to_np(pil_rgba)
np_image = pil_to_np(pil_image)
np_image = heuristic_resize_fast(np_image, (resize_dims.width, resize_dims.height))
pil_image = np_to_pil(np_image)
except Exception:
@@ -160,30 +151,18 @@ async def create_image_upload_entry(
raise HTTPException(status_code=501, detail="Not implemented")
@images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult)
@images_router.delete("/i/{image_name}", operation_id="delete_image")
async def delete_image(
image_name: str = Path(description="The name of the image to delete"),
) -> DeleteImagesResult:
) -> None:
"""Deletes an image"""
deleted_images: set[str] = set()
affected_boards: set[str] = set()
try:
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
board_id = image_dto.board_id or "none"
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add(board_id)
except Exception:
# TODO: Does this need any exception handling at all?
pass
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
@images_router.delete("/intermediates", operation_id="clear_intermediates")
async def clear_intermediates() -> int:
@@ -395,32 +374,31 @@ async def list_image_dtos(
return image_dtos
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult)
class DeleteImagesFromListResult(BaseModel):
deleted_images: list[str]
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesFromListResult)
async def delete_images_from_list(
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
) -> DeleteImagesResult:
) -> DeleteImagesFromListResult:
try:
deleted_images: set[str] = set()
affected_boards: set[str] = set()
deleted_images: list[str] = []
for image_name in image_names:
try:
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
board_id = image_dto.board_id or "none"
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add(board_id)
deleted_images.append(image_name)
except Exception:
pass
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
return DeleteImagesFromListResult(deleted_images=deleted_images)
except Exception:
raise HTTPException(status_code=500, detail="Failed to delete images")
@images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult)
async def delete_uncategorized_images() -> DeleteImagesResult:
@images_router.delete(
"/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesFromListResult
)
async def delete_uncategorized_images() -> DeleteImagesFromListResult:
"""Deletes all images that are uncategorized"""
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
@@ -428,19 +406,14 @@ async def delete_uncategorized_images() -> DeleteImagesResult:
)
try:
deleted_images: set[str] = set()
affected_boards: set[str] = set()
deleted_images: list[str] = []
for image_name in image_names:
try:
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add("none")
deleted_images.append(image_name)
except Exception:
pass
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
return DeleteImagesFromListResult(deleted_images=deleted_images)
except Exception:
raise HTTPException(status_code=500, detail="Failed to delete images")
@@ -449,50 +422,36 @@ class ImagesUpdatedFromListResult(BaseModel):
updated_image_names: list[str] = Field(description="The image names that were updated")
@images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult)
@images_router.post("/star", operation_id="star_images_in_list", response_model=ImagesUpdatedFromListResult)
async def star_images_in_list(
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
) -> StarredImagesResult:
) -> ImagesUpdatedFromListResult:
try:
starred_images: set[str] = set()
affected_boards: set[str] = set()
updated_image_names: list[str] = []
for image_name in image_names:
try:
updated_image_dto = ApiDependencies.invoker.services.images.update(
image_name, changes=ImageRecordChanges(starred=True)
)
starred_images.add(image_name)
affected_boards.add(updated_image_dto.board_id or "none")
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=True))
updated_image_names.append(image_name)
except Exception:
pass
return StarredImagesResult(
starred_images=list(starred_images),
affected_boards=list(affected_boards),
)
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
except Exception:
raise HTTPException(status_code=500, detail="Failed to star images")
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult)
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=ImagesUpdatedFromListResult)
async def unstar_images_in_list(
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
) -> UnstarredImagesResult:
) -> ImagesUpdatedFromListResult:
try:
unstarred_images: set[str] = set()
affected_boards: set[str] = set()
updated_image_names: list[str] = []
for image_name in image_names:
try:
updated_image_dto = ApiDependencies.invoker.services.images.update(
image_name, changes=ImageRecordChanges(starred=False)
)
unstarred_images.add(image_name)
affected_boards.add(updated_image_dto.board_id or "none")
ApiDependencies.invoker.services.images.update(image_name, changes=ImageRecordChanges(starred=False))
updated_image_names.append(image_name)
except Exception:
pass
return UnstarredImagesResult(
unstarred_images=list(unstarred_images),
affected_boards=list(affected_boards),
)
return ImagesUpdatedFromListResult(updated_image_names=updated_image_names)
except Exception:
raise HTTPException(status_code=500, detail="Failed to unstar images")
@@ -563,61 +522,3 @@ async def get_bulk_download_item(
return response
except Exception:
raise HTTPException(status_code=404)
@images_router.get("/names", operation_id="get_image_names")
async def get_image_names(
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
board_id: Optional[str] = Query(
default=None,
description="The board id to filter by. Use 'none' to find images without a board.",
),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates"""
try:
result = ApiDependencies.invoker.services.images.get_image_names(
starred_first=starred_first,
order_dir=order_dir,
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
return result
except Exception:
raise HTTPException(status_code=500, detail="Failed to get image names")
@images_router.post(
"/images_by_names",
operation_id="get_images_by_names",
responses={200: {"model": list[ImageDTO]}},
)
async def get_images_by_names(
image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"),
) -> list[ImageDTO]:
"""Gets image DTOs for the specified image names. Maintains order of input names."""
try:
image_service = ApiDependencies.invoker.services.images
# Fetch DTOs preserving the order of requested names
image_dtos: list[ImageDTO] = []
for name in image_names:
try:
dto = image_service.get_dto(name)
image_dtos.append(dto)
except Exception:
# Skip missing images - they may have been deleted between name fetch and DTO fetch
continue
return image_dtos
except Exception:
raise HTTPException(status_code=500, detail="Failed to get image DTOs")


@@ -41,7 +41,6 @@ from invokeai.backend.model_manager.starter_models import (
STARTER_BUNDLES,
STARTER_MODELS,
StarterModel,
StarterModelBundle,
StarterModelWithoutDependencies,
)
@@ -292,7 +291,7 @@ async def get_hugging_face_models(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
changes: Annotated[ModelRecordChanges, Body(description="Model config", examples=[example_model_input])],
changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
) -> AnyModelConfig:
"""Update a model's config."""
logger = ApiDependencies.invoker.services.logger
@@ -450,7 +449,7 @@ async def install_model(
access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
config: ModelRecordChanges = Body(
description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type ",
examples=[{"name": "string", "description": "string"}],
example={"name": "string", "description": "string"},
),
) -> ModelInstallJob:
"""Install a model using a string identifier.
@@ -800,7 +799,7 @@ async def convert_model(
class StarterModelResponse(BaseModel):
starter_models: list[StarterModel]
starter_bundles: dict[str, StarterModelBundle]
starter_bundles: dict[str, list[StarterModel]]
def get_is_installed(
@@ -834,7 +833,7 @@ async def get_starter_models() -> StarterModelResponse:
model.dependencies = missing_deps
for bundle in starter_bundles.values():
for model in bundle.models:
for model in bundle:
model.is_installed = get_is_installed(model, installed_models)
# Remove already-installed dependencies
missing_deps: list[StarterModelWithoutDependencies] = []


@@ -1,6 +1,6 @@
from typing import Optional
from fastapi import Body, HTTPException, Path, Query
from fastapi import Body, Path, Query
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field
@@ -14,15 +14,13 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByBatchIDsResult,
CancelByDestinationResult,
ClearResult,
DeleteAllExceptCurrentResult,
DeleteByDestinationResult,
EnqueueBatchResult,
FieldIdentifier,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemNotFoundError,
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.pagination import CursorPaginatedResults
@@ -60,19 +58,17 @@ async def enqueue_batch(
),
) -> EnqueueBatchResult:
"""Processes a batch and enqueues the output graphs for execution."""
try:
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while enqueuing batch: {e}")
return await ApiDependencies.invoker.services.session_queue.enqueue_batch(
queue_id=queue_id, batch=batch, prepend=prepend
)
@session_queue_router.get(
"/{queue_id}/list",
operation_id="list_queue_items",
responses={
200: {"model": CursorPaginatedResults[SessionQueueItem]},
200: {"model": CursorPaginatedResults[SessionQueueItemDTO]},
},
)
async def list_queue_items(
@@ -81,42 +77,12 @@ async def list_queue_items(
status: Optional[QUEUE_ITEM_STATUS] = Query(default=None, description="The status of items to fetch"),
cursor: Optional[int] = Query(default=None, description="The pagination cursor"),
priority: int = Query(default=0, description="The pagination cursor priority"),
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> CursorPaginatedResults[SessionQueueItem]:
"""Gets cursor-paginated queue items"""
) -> CursorPaginatedResults[SessionQueueItemDTO]:
"""Gets all queue items (without graphs)"""
try:
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id,
limit=limit,
status=status,
cursor=cursor,
priority=priority,
destination=destination,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all items: {e}")
@session_queue_router.get(
"/{queue_id}/list_all",
operation_id="list_all_queue_items",
responses={
200: {"model": list[SessionQueueItem]},
},
)
async def list_all_queue_items(
queue_id: str = Path(description="The queue id to perform this operation on"),
destination: Optional[str] = Query(default=None, description="The destination of queue items to fetch"),
) -> list[SessionQueueItem]:
"""Gets all queue items"""
try:
return ApiDependencies.invoker.services.session_queue.list_all_queue_items(
queue_id=queue_id,
destination=destination,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue items: {e}")
return ApiDependencies.invoker.services.session_queue.list_queue_items(
queue_id=queue_id, limit=limit, status=status, cursor=cursor, priority=priority
)
@session_queue_router.put(
@@ -128,10 +94,7 @@ async def resume(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Resumes session processor"""
try:
return ApiDependencies.invoker.services.session_processor.resume()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while resuming queue: {e}")
return ApiDependencies.invoker.services.session_processor.resume()
@session_queue_router.put(
@@ -143,10 +106,7 @@ async def Pause(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatus:
"""Pauses session processor"""
try:
return ApiDependencies.invoker.services.session_processor.pause()
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while pausing queue: {e}")
return ApiDependencies.invoker.services.session_processor.pause()
@session_queue_router.put(
@@ -158,25 +118,7 @@ async def cancel_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> CancelAllExceptCurrentResult:
"""Immediately cancels all queue items except in-processing items"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling all except current: {e}")
@session_queue_router.put(
"/{queue_id}/delete_all_except_current",
operation_id="delete_all_except_current",
responses={200: {"model": DeleteAllExceptCurrentResult}},
)
async def delete_all_except_current(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> DeleteAllExceptCurrentResult:
"""Immediately deletes all queue items except in-processing items"""
try:
return ApiDependencies.invoker.services.session_queue.delete_all_except_current(queue_id=queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting all except current: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_all_except_current(queue_id=queue_id)
@session_queue_router.put(
@@ -189,12 +131,7 @@ async def cancel_by_batch_ids(
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
) -> CancelByBatchIDsResult:
"""Immediately cancels all queue items from the given batch ids"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(
queue_id=queue_id, batch_ids=batch_ids
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by batch id: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
@session_queue_router.put(
@@ -207,12 +144,9 @@ async def cancel_by_destination(
destination: str = Query(description="The destination to cancel all queue items for"),
) -> CancelByDestinationResult:
"""Immediately cancels all queue items with the given origin"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling by destination: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_by_destination(
queue_id=queue_id, destination=destination
)
@session_queue_router.put(
@@ -225,10 +159,7 @@ async def retry_items_by_id(
item_ids: list[int] = Body(description="The queue item ids to retry"),
) -> RetryItemsResult:
"""Immediately cancels all queue items with the given origin"""
try:
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while retrying queue items: {e}")
return ApiDependencies.invoker.services.session_queue.retry_items_by_id(queue_id=queue_id, item_ids=item_ids)
@session_queue_router.put(
@@ -242,14 +173,11 @@ async def clear(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> ClearResult:
"""Clears the queue entirely, immediately canceling the currently-executing session"""
try:
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while clearing queue: {e}")
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
@session_queue_router.put(
@@ -263,10 +191,7 @@ async def prune(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> PruneResult:
"""Prunes all completed or errored queue items"""
try:
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while pruning queue: {e}")
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
@session_queue_router.get(
@@ -280,10 +205,7 @@ async def get_current_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the currently execution queue item"""
try:
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
@session_queue_router.get(
@@ -297,10 +219,7 @@ async def get_next_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the next queue item, without executing it"""
try:
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}")
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
@session_queue_router.get(
@@ -314,12 +233,9 @@ async def get_queue_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndProcessorStatus:
"""Gets the status of the session queue"""
try:
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting queue status: {e}")
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
@session_queue_router.get(
@@ -334,10 +250,7 @@ async def get_batch_status(
batch_id: str = Path(description="The batch to get the status of"),
) -> BatchStatus:
"""Gets the status of the session queue"""
try:
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}")
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
@session_queue_router.get(
@@ -353,27 +266,7 @@ async def get_queue_item(
item_id: int = Path(description="The queue item to get"),
) -> SessionQueueItem:
"""Gets a queue item"""
try:
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching queue item: {e}")
@session_queue_router.delete(
"/{queue_id}/i/{item_id}",
operation_id="delete_queue_item",
)
async def delete_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
item_id: int = Path(description="The queue item to delete"),
) -> None:
"""Deletes a queue item"""
try:
ApiDependencies.invoker.services.session_queue.delete_queue_item(item_id)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting queue item: {e}")
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
@session_queue_router.put(
@@ -388,12 +281,8 @@ async def cancel_queue_item(
item_id: int = Path(description="The queue item to cancel"),
) -> SessionQueueItem:
"""Deletes a queue item"""
try:
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
raise HTTPException(status_code=404, detail=f"Queue item with id {item_id} not found in queue {queue_id}")
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while canceling queue item: {e}")
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
@session_queue_router.get(
@@ -406,27 +295,6 @@ async def counts_by_destination(
destination: str = Query(description="The destination to query"),
) -> SessionQueueCountsByDestination:
"""Gets the counts of queue items by destination"""
try:
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}")
@session_queue_router.delete(
"/{queue_id}/d/{destination}",
operation_id="delete_by_destination",
responses={200: {"model": DeleteByDestinationResult}},
)
async def delete_by_destination(
queue_id: str = Path(description="The queue id to query"),
destination: str = Path(description="The destination to query"),
) -> DeleteByDestinationResult:
"""Deletes all items with the given destination"""
try:
return ApiDependencies.invoker.services.session_queue.delete_by_destination(
queue_id=queue_id, destination=destination
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while deleting by destination: {e}")
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
)


@@ -158,7 +158,7 @@ web_root_path = Path(list(web_dir.__path__)[0])
try:
app.mount("/", NoCacheStaticFiles(directory=Path(web_root_path, "dist"), html=True), name="ui")
except RuntimeError:
logger.warning(f"No UI found at {web_root_path}/dist, skipping UI mount")
logger.warn(f"No UI found at {web_root_path}/dist, skipping UI mount")
app.mount(
"/static", NoCacheStaticFiles(directory=Path(web_root_path, "static/")), name="static"
) # docs favicon is in here


@@ -499,7 +499,7 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None
ui_type = field.json_schema_extra.get("ui_type", None)
if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"):
logger.warning(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
logger.warn(f'"UIType.{ui_type.split("_")[-1]}" is deprecated, ignoring')
field.json_schema_extra.pop("ui_type")
return None
@@ -615,7 +615,7 @@ def invocation(
raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e
uiconfig["version"] = version
else:
logger.warning(f'No version specified for node "{invocation_type}", using "1.0.0"')
logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"')
uiconfig["version"] = "1.0.0"
cls.UIConfig = UIConfigBase(**uiconfig)


@@ -114,13 +114,6 @@ class CompelInvocation(BaseInvocation):
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
del compel
del patched_tokenizer
del tokenizer
del ti_manager
del text_encoder
del text_encoder_info
c = c.detach().to("cpu")
conditioning_data = ConditioningFieldData(conditionings=[BasicConditioningInfo(embeds=c)])
@@ -229,10 +222,7 @@ class SDXLPromptInvocationBase:
else:
c_pooled = None
del compel
del patched_tokenizer
del tokenizer
del ti_manager
del text_encoder
del text_encoder_info


@@ -215,7 +215,6 @@ class FieldDescriptions:
flux_redux_conditioning = "FLUX Redux conditioning tensor"
vllm_model = "The VLLM model to use"
flux_fill_conditioning = "FLUX Fill conditioning tensor"
flux_kontext_conditioning = "FLUX Kontext conditioning (reference image)"
class ImageField(BaseModel):
@@ -292,12 +291,6 @@ class FluxFillConditioningField(BaseModel):
mask: TensorField = Field(description="The FLUX Fill inpaint mask.")
class FluxKontextConditioningField(BaseModel):
"""A conditioning field for FLUX Kontext (reference image)."""
image: ImageField = Field(description="The Kontext reference image.")
class SD3ConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
@@ -445,7 +438,7 @@ class WithWorkflow:
workflow = None
def __init_subclass__(cls) -> None:
logger.warning(
logger.warn(
f"{cls.__module__.split('.')[0]}.{cls.__name__}: WithWorkflow is deprecated. Use `context.workflow` to access the workflow."
)
super().__init_subclass__()
@@ -586,7 +579,7 @@ def InputField(
if default_factory is not _Unset and default_factory is not None:
default = default_factory()
logger.warning('"default_factory" is not supported, calling it now to set "default"')
logger.warn('"default_factory" is not supported, calling it now to set "default"')
# These are the args we may wish pass to the pydantic `Field()` function
field_args = {


@@ -16,12 +16,13 @@ from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxConditioningField,
FluxFillConditioningField,
FluxKontextConditioningField,
FluxReduxConditioningField,
ImageField,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
@@ -33,7 +34,6 @@ from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXCo
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.kontext_extension import KontextExtension
from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
@@ -63,9 +63,9 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="4.0.0",
version="3.3.0",
)
class FluxDenoiseInvocation(BaseInvocation):
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run denoising process with a FLUX transformer model."""
# If latents is provided, this means we are doing image-to-image.
@@ -145,20 +145,11 @@ class FluxDenoiseInvocation(BaseInvocation):
description=FieldDescriptions.vae,
input=Input.Connection,
)
# This node accepts a images for features like FLUX Fill, ControlNet, and Kontext, but needs to operate on them in
# latent space. We'll run the VAE to encode them in this node instead of requiring the user to run the VAE in
# upstream nodes.
ip_adapter: IPAdapterField | list[IPAdapterField] | None = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection
)
kontext_conditioning: Optional[FluxKontextConditioningField] = InputField(
default=None,
description="FLUX Kontext conditioning (reference image).",
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = self._run_diffusion(context)
@@ -385,27 +376,6 @@ class FluxDenoiseInvocation(BaseInvocation):
dtype=inference_dtype,
)
kontext_extension = None
if self.kontext_conditioning is not None:
if not self.controlnet_vae:
raise ValueError("A VAE (e.g., controlnet_vae) must be provided to use Kontext conditioning.")
kontext_extension = KontextExtension(
context=context,
kontext_conditioning=self.kontext_conditioning,
vae_field=self.controlnet_vae,
device=TorchDevice.choose_torch_device(),
dtype=inference_dtype,
)
# Prepare Kontext conditioning if provided
img_cond_seq = None
img_cond_seq_ids = None
if kontext_extension is not None:
# Ensure batch sizes match
kontext_extension.ensure_batch_size(x.shape[0])
img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids
x = denoise(
model=transformer,
img=x,
@@ -421,8 +391,6 @@ class FluxDenoiseInvocation(BaseInvocation):
pos_ip_adapter_extensions=pos_ip_adapter_extensions,
neg_ip_adapter_extensions=neg_ip_adapter_extensions,
img_cond=img_cond,
img_cond_seq=img_cond_seq,
img_cond_seq_ids=img_cond_seq_ids,
)
x = unpack(x.float(), self.height, self.width)
@@ -897,10 +865,7 @@ class FluxDenoiseInvocation(BaseInvocation):
def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
def step_callback(state: PipelineIntermediateState) -> None:
# The denoise function now handles Kontext conditioning correctly,
# so we don't need to slice the latents here
latents = state.latents.float()
state.latents = unpack(latents, self.height, self.width).squeeze()
state.latents = unpack(state.latents.float(), self.height, self.width).squeeze()
context.util.flux_step_callback(state)
return step_callback


@@ -1,40 +0,0 @@
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
FluxKontextConditioningField,
InputField,
OutputField,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.shared.invocation_context import InvocationContext
@invocation_output("flux_kontext_output")
class FluxKontextOutput(BaseInvocationOutput):
"""The conditioning output of a FLUX Kontext invocation."""
kontext_cond: FluxKontextConditioningField = OutputField(
description=FieldDescriptions.flux_kontext_conditioning, title="Kontext Conditioning"
)
@invocation(
"flux_kontext",
title="Kontext Conditioning - FLUX",
tags=["conditioning", "kontext", "flux"],
category="conditioning",
version="1.0.0",
)
class FluxKontextInvocation(BaseInvocation):
"""Prepares a reference image for FLUX Kontext conditioning."""
image: ImageField = InputField(description="The Kontext reference image.")
def invoke(self, context: InvocationContext) -> FluxKontextOutput:
"""Packages the provided image into a Kontext conditioning field."""
return FluxKontextOutput(kontext_cond=FluxKontextConditioningField(image=self.image))


@@ -1,5 +1,5 @@
from contextlib import ExitStack
from typing import Iterator, Literal, Optional, Tuple, Union
from typing import Iterator, Literal, Optional, Tuple
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer, T5TokenizerFast
@@ -111,9 +111,6 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)
if context.config.get().log_tokenization:
self._log_t5_tokenization(context, t5_tokenizer)
context.util.signal_progress("Running T5 encoder")
prompt_embeds = t5_encoder(prompt)
@@ -154,9 +151,6 @@ class FluxTextEncoderInvocation(BaseInvocation):
clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)
if context.config.get().log_tokenization:
self._log_clip_tokenization(context, clip_tokenizer)
context.util.signal_progress("Running CLIP encoder")
pooled_prompt_embeds = clip_encoder(prompt)
@@ -176,88 +170,3 @@ class FluxTextEncoderInvocation(BaseInvocation):
assert isinstance(lora_info.model, ModelPatchRaw)
yield (lora_info.model, lora.weight)
del lora_info
def _log_t5_tokenization(
self,
context: InvocationContext,
tokenizer: Union[T5Tokenizer, T5TokenizerFast],
) -> None:
"""Logs the tokenization of a prompt for a T5-based model like FLUX."""
# Tokenize the prompt using the same parameters as the model's text encoder.
# T5 tokenizers add an EOS token (</s>) and then pad to max_length.
tokenized_output = tokenizer(
self.prompt,
padding="max_length",
max_length=self.t5_max_seq_len,
truncation=True,
add_special_tokens=True, # This is important for T5 to add the EOS token.
return_tensors="pt",
)
input_ids = tokenized_output.input_ids[0]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
# The T5 tokenizer uses a space-like character ' ' (U+2581) to denote spaces.
# We'll replace it with a regular space for readability.
tokens = [t.replace("\u2581", " ") for t in tokens]
tokenized_str = ""
used_tokens = 0
for token in tokens:
if token == tokenizer.eos_token:
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
used_tokens += 1
elif token == tokenizer.pad_token:
# tokenized_str += f"\x1b[0;34m{token}\x1b[0m" # Blue for PAD
continue
else:
color = (used_tokens % 6) + 1 # Cycle through 6 colors
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
used_tokens += 1
context.logger.info(f">> [T5 TOKENLOG] Tokens ({used_tokens}/{self.t5_max_seq_len}):")
context.logger.info(f"{tokenized_str}\x1b[0m")
def _log_clip_tokenization(
self,
context: InvocationContext,
tokenizer: CLIPTokenizer,
) -> None:
"""Logs the tokenization of a prompt for a CLIP-based model."""
max_length = tokenizer.model_max_length
tokenized_output = tokenizer(
self.prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
input_ids = tokenized_output.input_ids[0]
attention_mask = tokenized_output.attention_mask[0]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
# The CLIP tokenizer uses '</w>' to denote spaces.
# We'll replace it with a regular space for readability.
tokens = [t.replace("</w>", " ") for t in tokens]
tokenized_str = ""
used_tokens = 0
for i, token in enumerate(tokens):
if attention_mask[i] == 0:
# Do not log padding tokens.
continue
if token == tokenizer.bos_token:
tokenized_str += f"\x1b[0;32m{token}\x1b[0m" # Green for BOS
elif token == tokenizer.eos_token:
tokenized_str += f"\x1b[0;31m{token}\x1b[0m" # Red for EOS
else:
color = (used_tokens % 6) + 1 # Cycle through 6 colors
tokenized_str += f"\x1b[0;3{color}m{token}\x1b[0m"
used_tokens += 1
context.logger.info(f">> [CLIP TOKENLOG] Tokens ({used_tokens}/{max_length}):")
context.logger.info(f"{tokenized_str}\x1b[0m")


@@ -14,14 +14,15 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._conn = db.conn
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
@@ -30,12 +31,17 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
""",
(board_id, image_name, board_id),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def remove_image_from_board(
self,
image_name: str,
) -> None:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM board_images
@@ -43,6 +49,10 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
""",
(image_name,),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def get_images_for_board(
self,
@@ -50,26 +60,27 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# TODO: this isn't paginated yet?
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE board_images.board_id = ?
ORDER BY board_images.updated_at DESC;
""",
(board_id,),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images WHERE 1=1;
"""
)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
@@ -79,55 +90,56 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
with self._db.transaction() as cursor:
params: list[str | bool] = []
params: list[str | bool] = []
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Base query is a join between images and board_images
stmt = """
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Handle board_id filter
if board_id == "none":
stmt += """--sql
AND board_images.board_id IS NULL
"""
else:
stmt += """--sql
AND board_images.board_id = ?
"""
params.append(board_id)
# Handle board_id filter
if board_id == "none":
stmt += """--sql
AND board_images.board_id IS NULL
"""
else:
stmt += """--sql
AND board_images.board_id = ?
"""
params.append(board_id)
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Add the category filter
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
stmt += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Unpack the included categories into the query params
for c in category_strings:
params.append(c)
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
# Add the is_intermediate filter
if is_intermediate is not None:
stmt += """--sql
AND images.is_intermediate = ?
"""
params.append(is_intermediate)
# Put a ring on it
stmt += ";"
# Put a ring on it
stmt += ";"
cursor.execute(stmt, params)
# Execute the query
cursor = self._conn.cursor()
cursor.execute(stmt, params)
result = cast(list[sqlite3.Row], cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
return image_names
@@ -135,31 +147,31 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self,
image_name: str,
) -> Optional[str]:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = cursor.fetchone()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT board_id
FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
result = cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
def get_image_count_for_board(self, board_id: str) -> int:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM board_images
INNER JOIN images ON board_images.image_name = images.image_name
WHERE images.is_intermediate = FALSE
AND board_images.board_id = ?;
""",
(board_id,),
)
count = cast(int, cursor.fetchone()[0])
return count


@@ -20,57 +20,61 @@ from invokeai.app.util.misc import uuid_string
class SqliteBoardRecordStorage(BoardRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._conn = db.conn
def delete(self, board_id: str) -> None:
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
except Exception as e:
raise BoardRecordDeleteException from e
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
self._conn.commit()
except Exception as e:
self._conn.rollback()
raise BoardRecordDeleteException from e
def save(
self,
board_name: str,
) -> BoardRecord:
with self._db.transaction() as cursor:
try:
board_id = uuid_string()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
except sqlite3.Error as e:
raise BoardRecordSaveException from e
try:
board_id = uuid_string()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
VALUES (?, ?);
""",
(board_id, board_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
return self.get(board_id)
def get(
self,
board_id: str,
) -> BoardRecord:
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM boards
WHERE board_id = ?;
""",
(board_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
except sqlite3.Error as e:
raise BoardRecordNotFoundException from e
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
except sqlite3.Error as e:
raise BoardRecordNotFoundException from e
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
@@ -80,43 +84,45 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
with self._db.transaction() as cursor:
try:
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
try:
cursor = self._conn.cursor()
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET board_name = ?
WHERE board_id = ?;
""",
(changes.board_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
# Change the cover image of a board
if changes.cover_image_name is not None:
cursor.execute(
"""--sql
UPDATE boards
SET cover_image_name = ?
WHERE board_id = ?;
""",
(changes.cover_image_name, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
# Change the archived status of a board
if changes.archived is not None:
cursor.execute(
"""--sql
UPDATE boards
SET archived = ?
WHERE board_id = ?;
""",
(changes.archived, board_id),
)
except sqlite3.Error as e:
raise BoardRecordSaveException from e
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise BoardRecordSaveException from e
return self.get(board_id)
def get_many(
@@ -127,77 +133,78 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
with self._db.transaction() as cursor:
# Build base query
base_query = """
SELECT *
cursor = self._conn.cursor()
# Build base query
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
"""
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
WHERE archived = 0;
"""
# Determine archived filter condition
archived_filter = "" if include_archived else "WHERE archived = 0"
# Execute count query
cursor.execute(count_query)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
# Execute query to fetch boards
cursor.execute(final_query, (limit, offset))
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
# Determine count query
if include_archived:
count_query = """
SELECT COUNT(*)
FROM boards;
"""
else:
count_query = """
SELECT COUNT(*)
FROM boards
WHERE archived = 0;
"""
# Execute count query
cursor.execute(count_query)
count = cast(int, cursor.fetchone()[0])
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
with self._db.transaction() as cursor:
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
cursor = self._conn.cursor()
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY LOWER(board_name) {direction}
"""
else:
base_query = """
SELECT *
FROM boards
{archived_filter}
ORDER BY {order_by} {direction}
"""
archived_filter = "" if include_archived else "WHERE archived = 0"
archived_filter = "" if include_archived else "WHERE archived = 0"
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
final_query = base_query.format(
archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
)
cursor.execute(final_query)
cursor.execute(final_query)
result = cast(list[sqlite3.Row], cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
return boards
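
For readers following the cursor-handling changes in this file, here is a minimal, hypothetical sketch of the kind of `transaction()` context manager that the `with self._db.transaction() as cursor:` blocks assume. It is illustrative only, not necessarily the project's actual implementation:

import sqlite3
from contextlib import contextmanager
from typing import Iterator

class SqliteDatabaseSketch:
    """Illustrative stand-in for the database wrapper used above."""

    def __init__(self, path: str) -> None:
        self.conn = sqlite3.connect(path)
        self.conn.row_factory = sqlite3.Row

    @contextmanager
    def transaction(self) -> Iterator[sqlite3.Cursor]:
        cursor = self.conn.cursor()
        try:
            yield cursor
            self.conn.commit()    # commit once the block completes without error
        except Exception:
            self.conn.rollback()  # any failure rolls the whole block back
            raise
        finally:
            cursor.close()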

View File

@@ -24,6 +24,7 @@ from invokeai.frontend.cli.arg_parser import InvokeAIArgs
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
DEVICE = Literal["auto", "cpu", "cuda", "cuda:1", "mps"]
PRECISION = Literal["auto", "float16", "bfloat16", "float32"]
ATTENTION_TYPE = Literal["auto", "normal", "xformers", "sliced", "torch-sdp"]
ATTENTION_SLICE_SIZE = Literal["auto", "balanced", "max", 1, 2, 3, 4, 5, 6, 7, 8]
@@ -92,7 +93,7 @@ class InvokeAIAppConfig(BaseSettings):
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
pytorch_cuda_alloc_conf: Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to "backend:cudaMallocAsync" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
attention_type: Attention type.<br>Valid values: `auto`, `normal`, `xformers`, `sliced`, `torch-sdp`
@@ -175,7 +176,7 @@ class InvokeAIAppConfig(BaseSettings):
pytorch_cuda_alloc_conf: Optional[str] = Field(default=None, description="Configure the Torch CUDA memory allocator. This will impact peak reserved VRAM usage and performance. Setting to \"backend:cudaMallocAsync\" works well on many systems. The optimal configuration is highly dependent on the system configuration (device type, VRAM, CUDA driver version, etc.), so must be tuned experimentally.")
# DEVICE
device: str = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `mps`, `cuda:N` (where N is a device number)", pattern=r"^(auto|cpu|mps|cuda(:\d+)?)$")
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
precision: PRECISION = Field(default="auto", description="Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.")
# GENERATION
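
For illustration, a small hypothetical check mirroring the regex form of the `device` field shown above (the `DEVICE` Literal variant simply enumerates the allowed strings instead of matching a pattern):

import re

# `auto`, `cpu`, `mps`, and `cuda:N` (any device index) are accepted; anything else is rejected.
DEVICE_PATTERN = re.compile(r"^(auto|cpu|mps|cuda(:\d+)?)$")

for candidate in ("auto", "cpu", "mps", "cuda", "cuda:1", "cuda:7", "rocm"):
    print(candidate, bool(DEVICE_PATTERN.fullmatch(candidate)))
# only "rocm" fails to match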

View File

@@ -8,7 +8,6 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from shutil import disk_usage
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
import requests
@@ -336,14 +335,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
assert job.download_path
free_space = disk_usage(job.download_path.parent).free
GB = 2**30
self._logger.debug(f"Download is {job.total_bytes / GB:.2f} GB of {free_space / GB:.2f} GB free.")
if free_space < job.total_bytes:
raise RuntimeError(
f"Free disk space {free_space / GB:.2f} GB is not enough for download of {job.total_bytes / GB:.2f} GB."
)
# Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
# for code that instead resumes an interrupted download.
if job.download_path.exists():

View File

@@ -5,7 +5,6 @@ from typing import Optional
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
@@ -98,17 +97,3 @@ class ImageRecordStorageBase(ABC):
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
"""Gets the most recent image for a board."""
pass
@abstractmethod
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates."""
pass

View File

@@ -3,7 +3,7 @@ import datetime
from enum import Enum
from typing import Optional, Union
from pydantic import BaseModel, Field, StrictBool, StrictStr
from pydantic import Field, StrictBool, StrictStr
from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import get_iso_timestamp
@@ -207,16 +207,3 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
starred=starred,
has_workflow=has_workflow,
)
class ImageCollectionCounts(BaseModel):
starred_count: int = Field(description="The number of starred images in the collection.")
unstarred_count: int = Field(description="The number of unstarred images in the collection.")
class ImageNamesResult(BaseModel):
"""Response containing ordered image names with metadata for optimistic updates."""
image_names: list[str] = Field(description="Ordered list of image names")
starred_count: int = Field(description="Number of starred images (when starred_first=True)")
total_count: int = Field(description="Total number of images matching the query")

View File

@@ -7,7 +7,6 @@ from invokeai.app.services.image_records.image_records_base import ImageRecordSt
from invokeai.app.services.image_records.image_records_common import (
IMAGE_DTO_COLS,
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
@@ -24,22 +23,22 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteImageRecordStorage(ImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._conn = db.conn
def get(self, image_name: str) -> ImageRecord:
with self._db.transaction() as cursor:
try:
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
""",
(image_name,),
)
try:
cursor = self._conn.cursor()
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
if not result:
raise ImageRecordNotFoundException
@@ -47,20 +46,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result))
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT metadata FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
result = cast(Optional[sqlite3.Row], cursor.fetchone())
if not result:
raise ImageRecordNotFoundException
@@ -68,60 +64,64 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
as_dict = dict(result)
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
except sqlite3.Error as e:
raise ImageRecordNotFoundException from e
def update(
self,
image_name: str,
changes: ImageRecordChanges,
) -> None:
with self._db.transaction() as cursor:
try:
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
try:
cursor = self._conn.cursor()
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
cursor.execute(
"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
# Change the image's `is_intermediate`` flag
if changes.is_intermediate is not None:
cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
# Change the image's `is_intermediate`` flag
if changes.is_intermediate is not None:
cursor.execute(
"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
# Change the image's `starred`` state
if changes.starred is not None:
cursor.execute(
"""--sql
UPDATE images
SET starred = ?
WHERE image_name = ?;
""",
(changes.starred, image_name),
)
# Change the image's `starred`` state
if changes.starred is not None:
cursor.execute(
"""--sql
UPDATE images
SET starred = ?
WHERE image_name = ?;
""",
(changes.starred, image_name),
)
except sqlite3.Error as e:
raise ImageRecordSaveException from e
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
def get_many(
self,
@@ -135,162 +135,166 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
with self._db.transaction() as cursor:
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
cursor = self._conn.cursor()
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
query_params.append(is_intermediate)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if categories is not None:
# Convert the enum values to unique list of strings
category_strings = [c.value for c in set(categories)]
# Create the correct length of placeholders
placeholders = ",".join("?" * len(category_strings))
# Search term condition
if search_term:
query_conditions += """--sql
AND images.metadata LIKE ?
"""
query_params.append(f"%{search_term.lower()}%")
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Unpack the included categories into the query params
for c in category_strings:
query_params.append(c)
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
query_params.append(is_intermediate)
# board_id of "none" is reserved for images without a board
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
# Search term condition
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
if starred_first:
query_pagination = f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
else:
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Final images query with pagination
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
# Add the pagination parameters
images_params.extend([limit, offset])
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
cursor.execute(count_query, count_params)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def delete(self, image_name: str) -> None:
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
def delete_many(self, image_names: list[str]) -> None:
with self._db.transaction() as cursor:
try:
placeholders = ",".join("?" for _ in image_names)
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Execute the query with the list of IDs as parameters
cursor.execute(query, image_names)
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
def get_intermediates_count(self) -> int:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
DELETE FROM images
WHERE image_name = ?;
""",
(image_name,),
)
count = cast(int, cursor.fetchone()[0])
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
def delete_many(self, image_names: list[str]) -> None:
try:
cursor = self._conn.cursor()
placeholders = ",".join("?" for _ in image_names)
# Construct the SQLite query with the placeholders
query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
# Execute the query with the list of IDs as parameters
cursor.execute(query, image_names)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
def get_intermediates_count(self) -> int:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
count = cast(int, cursor.fetchone()[0])
self._conn.commit()
return count
def delete_intermediates(self) -> list[str]:
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
"""
)
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
return image_names
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT image_name FROM images
WHERE is_intermediate = TRUE;
"""
)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
cursor.execute(
"""--sql
DELETE FROM images
WHERE is_intermediate = TRUE;
"""
)
self._conn.commit()
return image_names
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordDeleteException from e
def save(
self,
@@ -306,165 +310,75 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id: Optional[str] = None,
metadata: Optional[str] = None,
) -> datetime:
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
image_name,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow,
),
)
cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
created_at = datetime.fromisoformat(cursor.fetchone()[0])
except sqlite3.Error as e:
raise ImageRecordSaveException from e
return created_at
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
INSERT OR IGNORE INTO images (
image_name,
image_origin,
image_category,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(board_id,),
(
image_name,
image_origin.value,
image_category.value,
width,
height,
node_id,
session_id,
metadata,
is_intermediate,
starred,
has_workflow,
),
)
self._conn.commit()
cursor.execute(
"""--sql
SELECT created_at
FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
created_at = datetime.fromisoformat(cursor.fetchone()[0])
return created_at
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT images.*
FROM images
JOIN board_images ON images.image_name = board_images.image_name
WHERE board_images.board_id = ?
AND images.is_intermediate = FALSE
ORDER BY images.starred DESC, images.created_at DESC
LIMIT 1;
""",
(board_id,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
if result is None:
return None
return deserialize_image_record(dict(result))
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
with self._db.transaction() as cursor:
# Build query conditions (reused for both starred count and image names queries)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Get starred count if starred_first is enabled
starred_count = 0
if starred_first:
starred_count_query = f"""--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE images.starred = TRUE AND (1=1{query_conditions})
"""
cursor.execute(starred_count_query, query_params)
starred_count = cast(int, cursor.fetchone()[0])
# Get all image names with proper ordering
if starred_first:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.starred DESC, images.created_at {order_dir.value}
"""
else:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.created_at {order_dir.value}
"""
cursor.execute(names_query, query_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [row[0] for row in result]
return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names))
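
A rough, hypothetical consumer-side sketch (not the app's actual UI code): with `starred_first=True`, the first `starred_count` names in `ImageNamesResult` are the starred images, so a client can split the list and apply star toggles optimistically without refetching.

def split_starred(image_names: list[str], starred_count: int) -> tuple[list[str], list[str]]:
    # starred images occupy the head of the ordered list when starred_first=True
    starred = image_names[:starred_count]
    unstarred = image_names[starred_count:]
    return starred, unstarred

starred, unstarred = split_starred(["a.png", "b.png", "c.png"], starred_count=1)
assert starred == ["a.png"] and unstarred == ["b.png", "c.png"]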

View File

@@ -6,7 +6,6 @@ from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
@@ -126,7 +125,7 @@ class ImageServiceABC(ABC):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs with starred images first when starred_first=True."""
"""Gets a paginated list of image DTOs."""
pass
@abstractmethod
@@ -148,17 +147,3 @@ class ImageServiceABC(ABC):
def delete_images_on_board(self, board_id: str):
"""Deletes all images on a board."""
pass
@abstractmethod
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates."""
pass

View File

@@ -1,6 +1,6 @@
from typing import Optional
from pydantic import BaseModel, Field
from pydantic import Field
from invokeai.app.services.image_records.image_records_common import ImageRecord
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
@@ -39,27 +39,3 @@ def image_record_to_dto(
thumbnail_url=thumbnail_url,
board_id=board_id,
)
class ResultWithAffectedBoards(BaseModel):
affected_boards: list[str] = Field(description="The ids of boards affected by the delete operation")
class DeleteImagesResult(ResultWithAffectedBoards):
deleted_images: list[str] = Field(description="The names of the images that were deleted")
class StarredImagesResult(ResultWithAffectedBoards):
starred_images: list[str] = Field(description="The names of the images that were starred")
class UnstarredImagesResult(ResultWithAffectedBoards):
unstarred_images: list[str] = Field(description="The names of the images that were unstarred")
class AddImagesToBoardResult(ResultWithAffectedBoards):
added_images: list[str] = Field(description="The image names that were added to the board")
class RemoveImagesFromBoardResult(ResultWithAffectedBoards):
removed_images: list[str] = Field(description="The image names that were removed from their board")

View File

@@ -10,7 +10,6 @@ from invokeai.app.services.image_files.image_files_common import (
)
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
@@ -79,7 +78,7 @@ class ImageService(ImageServiceABC):
board_id=board_id, image_name=image_name
)
except Exception as e:
self.__invoker.services.logger.warning(f"Failed to add image to board {board_id}: {str(e)}")
self.__invoker.services.logger.warn(f"Failed to add image to board {board_id}: {str(e)}")
self.__invoker.services.image_files.save(
image_name=image_name, image=image, metadata=metadata, workflow=workflow, graph=graph
)
@@ -310,27 +309,3 @@ class ImageService(ImageServiceABC):
except Exception as e:
self.__invoker.services.logger.error("Problem getting intermediates count")
raise e
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
try:
return self.__invoker.services.image_records.get_image_names(
starred_first=starred_first,
order_dir=order_dir,
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
except Exception as e:
self.__invoker.services.logger.error("Problem getting image names")
raise e

View File

@@ -148,7 +148,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _clear_pending_jobs(self) -> None:
for job in self.list_jobs():
if not job.in_terminal_state:
self._logger.warning(f"Cancelling job {job.id}")
self._logger.warning("Cancelling job {job.id}")
self.cancel_job(job)
while True:
try:

View File

@@ -78,6 +78,11 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db = db
self._logger = logger
@property
def db(self) -> SqliteDatabase:
"""Return the underlying database."""
return self._db
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
"""
Add a model to the database.
@@ -88,33 +93,38 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
with self._db.transaction() as cursor:
try:
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
INSERT INTO models (
id,
config
)
VALUES (?,?);
""",
(
config.key,
config.model_dump_json(),
),
)
self._db.conn.commit()
except sqlite3.IntegrityError as e:
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
except sqlite3.IntegrityError as e:
self._db.conn.rollback()
if "UNIQUE constraint failed" in str(e):
if "models.path" in str(e):
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in str(e):
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
raise e
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
else:
raise e
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(config.key)
@@ -126,7 +136,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException
"""
with self._db.transaction() as cursor:
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
DELETE FROM models
@@ -136,17 +147,22 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
with self._db.transaction() as cursor:
record = self.get_model(key)
record = self.get_model(key)
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
json_serialized = record.model_dump_json()
try:
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
UPDATE models
@@ -158,6 +174,10 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
self._db.conn.commit()
except sqlite3.Error as e:
self._db.conn.rollback()
raise e
return self.get_model(key)
@@ -169,30 +189,30 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException
"""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE id=?;
""",
(key,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
@@ -204,15 +224,15 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted
"""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
select count(*) FROM models
WHERE id=?;
""",
(key,),
)
count = cursor.fetchone()[0]
return count > 0
def search_by_attr(
@@ -235,42 +255,43 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
If none of the optional filters are passed, will return all
models in the database.
"""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
where_clause: list[str] = []
bindings: list[str] = []
if model_name:
where_clause.append("name=?")
bindings.append(model_name)
if base_model:
where_clause.append("base=?")
bindings.append(base_model)
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
cursor = self._db.conn.cursor()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
FROM models
{where}
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
""",
tuple(bindings),
)
result = cursor.fetchall()
# Parse the model configs.
results: list[AnyModelConfig] = []
@@ -292,68 +313,69 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash."""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
cursor = self._db.conn.cursor()
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
WHERE hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
cursor = self._db.conn.cursor()
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
"""--sql
select count(*) from models;
""",
(),
)
total = int(cursor.fetchone()[0])
# query2: fetch key fields
cursor.execute(
f"""--sql
SELECT config
FROM models
ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
LIMIT ?
OFFSET ?;
""",
(
per_page,
page * per_page,
),
)
rows = cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)
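
A standalone, illustrative sketch of the filter-building pattern used by `search_by_attr` above, with parameterized bindings kept separate from the assembled WHERE clause. Table and column names are assumptions for the example only:

import sqlite3
from typing import Optional

def search_models(conn: sqlite3.Connection, name: Optional[str] = None, base: Optional[str] = None) -> list[tuple]:
    clauses: list[str] = []
    bindings: list[str] = []
    if name:
        clauses.append("name=?")
        bindings.append(name)
    if base:
        clauses.append("base=?")
        bindings.append(base)
    # only the column filters are interpolated; values always go through bound parameters
    where = f"WHERE {' AND '.join(clauses)}" if clauses else ""
    cursor = conn.execute(f"SELECT config FROM models {where}", tuple(bindings))
    return cursor.fetchall()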

View File

@@ -1,3 +1,5 @@
import sqlite3
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
ModelRelationshipRecordStorageBase,
)
@@ -7,49 +9,58 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._conn = db.conn
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
with self._db.transaction() as cursor:
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
try:
cursor = self._conn.cursor()
cursor.execute(
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
(a, b),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
with self._db.transaction() as cursor:
a, b = sorted([model_key_1, model_key_2])
a, b = sorted([model_key_1, model_key_2])
try:
cursor = self._conn.cursor()
cursor.execute(
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
(a, b),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise e
def get_related_model_keys(self, model_key: str) -> list[str]:
with self._db.transaction() as cursor:
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
""",
(model_key, model_key),
)
result = [row[0] for row in cursor.fetchall()]
return result
cursor = self._conn.cursor()
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
""",
(model_key, model_key),
)
return [row[0] for row in cursor.fetchall()]
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
with self._db.transaction() as cursor:
key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
""",
model_keys + model_keys,
)
result = [row[0] for row in cursor.fetchall()]
return result
cursor = self._conn.cursor()
key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
UNION
SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
""",
model_keys + model_keys,
)
return [row[0] for row in cursor.fetchall()]
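
Why the key pair is sorted before insert (illustrative note): storing the relationship in canonical (min, max) order makes it effectively undirected, so (A, B) and (B, A) map to the same row, and the UNION query above finds the partner key regardless of which column it landed in.

a, b = sorted(["model-xyz", "model-abc"])
assert (a, b) == ("model-abc", "model-xyz")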

View File

@@ -1,4 +1,3 @@
import gc
import traceback
from contextlib import suppress
from threading import BoundedSemaphore, Thread
@@ -440,12 +439,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.wait(self._polling_interval)
continue
# GC-ing here can reduce peak memory usage of the invoke process by freeing allocated memory blocks.
# Most queue items take seconds to execute, so the relative cost of a GC is very small.
# Python will never cede allocated memory back to the OS, so anything we can do to reduce the peak
# allocation is well worth it.
gc.collect()
self._invoker.services.logger.info(
f"Executing queue item {self._queue_item.item_id}, session {self._queue_item.session_id}"
)
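
A hedged sketch of the pattern the removed comment describes: collecting explicitly between long-running queue items trims peak allocation, and the cost of a collection is small relative to work that takes seconds. Function names here are illustrative only.

import gc

def run_queue(items, handler):
    for item in items:
        handler(item)
        # CPython does not return freed blocks to the OS, so trimming the peak here
        # is cheap insurance between multi-second work items.
        gc.collect()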

View File

@@ -10,8 +10,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByDestinationResult,
CancelByQueueIDResult,
ClearResult,
DeleteAllExceptCurrentResult,
DeleteByDestinationResult,
EnqueueBatchResult,
IsEmptyResult,
IsFullResult,
@@ -19,6 +17,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import GraphExecutionState
@@ -93,11 +92,6 @@ class SessionQueueBase(ABC):
"""Cancels a session queue item"""
pass
@abstractmethod
def delete_queue_item(self, item_id: int) -> None:
"""Deletes a session queue item"""
pass
@abstractmethod
def fail_queue_item(
self, item_id: int, error_type: str, error_message: str, error_traceback: str
@@ -115,11 +109,6 @@ class SessionQueueBase(ABC):
"""Cancels all queue items with the given batch destination"""
pass
@abstractmethod
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
"""Deletes all queue items with the given batch destination"""
pass
@abstractmethod
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
"""Cancels all queue items with matching queue ID"""
@@ -130,11 +119,6 @@ class SessionQueueBase(ABC):
"""Cancels all queue items except in-progress items"""
pass
@abstractmethod
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
"""Deletes all queue items except in-progress items"""
pass
@abstractmethod
def list_queue_items(
self,
@@ -143,20 +127,10 @@ class SessionQueueBase(ABC):
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
) -> CursorPaginatedResults[SessionQueueItemDTO]:
"""Gets a page of session queue items"""
pass
@abstractmethod
def list_all_queue_items(
self,
queue_id: str,
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
pass
@abstractmethod
def get_queue_item(self, item_id: int) -> SessionQueueItem:
"""Gets a session queue item by ID"""

View File

@@ -208,7 +208,7 @@ class FieldIdentifier(BaseModel):
user_label: str | None = Field(description="The user label of the field, if any")
class SessionQueueItem(BaseModel):
class SessionQueueItemWithoutGraph(BaseModel):
"""Session queue item without the full graph. Used for serialization."""
item_id: int = Field(description="The identifier of the session queue item")
@@ -252,7 +252,42 @@ class SessionQueueItem(BaseModel):
default=None,
description="The ID of the published workflow associated with this queue item",
)
api_input_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The fields that were used as input to the API"
)
api_output_fields: Optional[list[FieldIdentifier]] = Field(
default=None, description="The nodes that were used as output from the API"
)
credits: Optional[float] = Field(default=None, description="The total credits used for this queue item")
@classmethod
def queue_item_dto_from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":
# must parse these manually
queue_item_dict["field_values"] = get_field_values(queue_item_dict)
return SessionQueueItemDTO(**queue_item_dict)
model_config = ConfigDict(
json_schema_extra={
"required": [
"item_id",
"status",
"batch_id",
"queue_id",
"session_id",
"priority",
"session_id",
"created_at",
"updated_at",
]
}
)
class SessionQueueItemDTO(SessionQueueItemWithoutGraph):
pass
class SessionQueueItem(SessionQueueItemWithoutGraph):
session: GraphExecutionState = Field(description="The fully-populated session to be executed")
workflow: Optional[WorkflowWithoutID] = Field(
default=None, description="The workflow associated with this queue item"
@@ -332,7 +367,6 @@ class EnqueueBatchResult(BaseModel):
requested: int = Field(description="The total number of queue items requested to be enqueued")
batch: Batch = Field(description="The batch that was enqueued")
priority: int = Field(description="The priority of the enqueued batch")
item_ids: list[int] = Field(description="The IDs of the queue items that were enqueued")
class RetryItemsResult(BaseModel):
@@ -364,18 +398,6 @@ class CancelByDestinationResult(CancelByBatchIDsResult):
pass
class DeleteByDestinationResult(BaseModel):
"""Result of deleting by a destination"""
deleted: int = Field(..., description="Number of queue items deleted")
class DeleteAllExceptCurrentResult(DeleteByDestinationResult):
"""Result of deleting all except current"""
pass
class CancelByQueueIDResult(CancelByBatchIDsResult):
"""Result of canceling by queue id"""

View File

@@ -17,8 +17,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
CancelByDestinationResult,
CancelByQueueIDResult,
ClearResult,
DeleteAllExceptCurrentResult,
DeleteByDestinationResult,
EnqueueBatchResult,
IsEmptyResult,
IsFullResult,
@@ -26,6 +24,7 @@ from invokeai.app.services.session_queue.session_queue_common import (
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueItemNotFoundError,
SessionQueueStatus,
ValueToInsertTuple,
@@ -47,17 +46,22 @@ class SqliteSessionQueue(SessionQueueBase):
clear_result = self.clear(DEFAULT_QUEUE_ID)
if clear_result.deleted > 0:
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
else:
prune_result = self.prune(DEFAULT_QUEUE_ID)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._conn = db.conn
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
This is necessary because the invoker may have been killed while processing a queue item.
"""
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
UPDATE session_queue
@@ -65,104 +69,102 @@ class SqliteSessionQueue(SessionQueueBase):
WHERE status = 'in_progress';
"""
)
except Exception:
self._conn.rollback()
raise
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
count = cast(int, cursor.fetchone()[0])
return count
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(int, cursor.fetchone()[0])
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
priority = cast(Union[int, None], cursor.fetchone()[0]) or 0
return priority
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
""",
(queue_id,),
)
return cast(Union[int, None], cursor.fetchone()[0]) or 0
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
return await asyncio.to_thread(self._enqueue_batch, queue_id, batch, prepend)
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id) + 1
def _enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
cursor = self._conn.cursor()
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
requested_count = await asyncio.to_thread(
calc_session_count,
batch=batch,
)
values_to_insert = await asyncio.to_thread(
prepare_values_to_insert,
queue_id=queue_id,
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
)
enqueued_count = len(values_to_insert)
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id) + 1
requested_count = calc_session_count(batch)
values_to_insert = prepare_values_to_insert(
queue_id=queue_id,
batch=batch,
priority=priority,
max_new_queue_items=max_new_queue_items,
)
enqueued_count = len(values_to_insert)
if requested_count > enqueued_count:
values_to_insert = values_to_insert[:max_new_queue_items]
with self._db.transaction() as cursor:
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
cursor.execute(
"""--sql
SELECT item_id
FROM session_queue
WHERE batch_id = ?
ORDER BY item_id DESC;
""",
(batch.batch_id,),
)
item_ids = [row[0] for row in cursor.fetchall()]
self._conn.commit()
except Exception:
self._conn.rollback()
raise
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
requested=requested_count,
enqueued=enqueued_count,
batch=batch,
priority=priority,
item_ids=item_ids,
)
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE status = 'pending'
ORDER BY
priority DESC,
item_id ASC
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
@@ -170,40 +172,40 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'pending'
ORDER BY
priority DESC,
created_at ASC
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
@@ -216,23 +218,8 @@ class SqliteSessionQueue(SessionQueueBase):
error_message: Optional[str] = None,
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status FROM session_queue WHERE item_id = ?
""",
(item_id,),
)
row = cursor.fetchone()
if row is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
current_status = row[0]
# Only update if not already finished (completed, failed or canceled)
if current_status in ("completed", "failed", "canceled"):
return self.get_queue_item(item_id)
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
UPDATE session_queue
@@ -241,7 +228,10 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(status, error_type, error_message, error_traceback, item_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
@@ -249,34 +239,35 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
is_empty = cast(int, cursor.fetchone()[0]) == 0
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
WHERE queue_id = ?
""",
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
return IsFullResult(is_full=is_full)
def clear(self, queue_id: str) -> ClearResult:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*)
@@ -294,19 +285,24 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
where = """--sql
WHERE
queue_id = ?
AND (
queue_id = ?
AND (
status = 'completed'
OR status = 'failed'
OR status = 'canceled'
)
)
"""
cursor.execute(
f"""--sql
@@ -325,28 +321,16 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
return queue_item
def delete_queue_item(self, item_id: int) -> None:
"""Deletes a session queue item"""
try:
self.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
pass
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE
FROM session_queue
WHERE item_id = ?
""",
(item_id,),
)
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
return queue_item
@@ -368,7 +352,8 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
@@ -378,8 +363,6 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = [queue_id] + batch_ids
cursor.execute(
@@ -399,14 +382,17 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
self._conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
raise
return CancelByBatchIDsResult(canceled=count)
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -415,8 +401,6 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = (queue_id, destination)
cursor.execute(
@@ -436,67 +420,17 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
self._conn.commit()
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self._conn.rollback()
raise
return CancelByDestinationResult(canceled=count)
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
if current_queue_item is not None and current_queue_item.destination == destination:
self.cancel_queue_item(current_queue_item.item_id)
params = (queue_id, destination)
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM session_queue
WHERE
queue_id = ?
AND destination = ?;
""",
params,
)
count = cursor.fetchone()[0]
cursor.execute(
"""--sql
DELETE
FROM session_queue
WHERE
queue_id = ?
AND destination = ?;
""",
params,
)
return DeleteByDestinationResult(deleted=count)
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id == ?
AND status == 'pending'
"""
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
(queue_id,),
)
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
DELETE
FROM session_queue
{where};
""",
(queue_id,),
)
return DeleteAllExceptCurrentResult(deleted=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -504,8 +438,6 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
-- We will cancel the current item separately below - skip it here
AND status != 'in_progress'
"""
params = [queue_id]
cursor.execute(
@@ -525,13 +457,21 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
self._conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
)
except Exception:
self._conn.rollback()
raise
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
where = """--sql
WHERE
queue_id == ?
@@ -554,25 +494,30 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return CancelAllExceptCurrentResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
item_id = ?
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
@@ -585,6 +530,10 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(session_json, item_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get_queue_item(item_id)
def list_queue_items(
@@ -594,45 +543,53 @@ class SqliteSessionQueue(SessionQueueBase):
priority: int,
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
with self._db.transaction() as cursor_:
item_id = cursor
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
AND status = ?
"""
params.append(status)
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
) -> CursorPaginatedResults[SessionQueueItemDTO]:
cursor_ = self._conn.cursor()
item_id = cursor
query = """--sql
SELECT item_id,
status,
priority,
field_values,
error_type,
error_message,
error_traceback,
created_at,
updated_at,
completed_at,
started_at,
session_id,
batch_id,
queue_id,
origin,
destination
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if status is not None:
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
AND status = ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
params.append(status)
if item_id is not None:
query += """--sql
AND (priority < ?) OR (priority = ? AND item_id > ?)
"""
params.extend([priority, priority, item_id])
query += """--sql
ORDER BY
priority DESC,
item_id ASC
LIMIT ?
"""
params.append(limit + 1)
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
# remove the extra item
@@ -640,52 +597,21 @@ class SqliteSessionQueue(SessionQueueBase):
has_more = True
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
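For readers unfamiliar with the cursor-pagination idiom above (querying `limit + 1` rows and popping the extra one), here is a minimal, hypothetical sketch of the same trick on a plain Python list; the names are illustrative and not part of this change:

```python
# Sketch of the limit+1 pagination idiom used by list_queue_items.
# `rows` stands in for the SQL result set ordered by priority/item_id.
def paginate(rows: list[dict], limit: int) -> tuple[list[dict], bool]:
    page = rows[: limit + 1]       # analogous to "LIMIT ?" with limit + 1
    has_more = len(page) > limit   # the extra row means another page exists
    if has_more:
        page.pop()                 # drop the sentinel row before returning
    return page, has_more
```

The extra row is never returned to the caller; it only signals whether a further request using the last item as the cursor would yield results.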
def list_all_queue_items(
self,
queue_id: str,
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
with self._db.transaction() as cursor:
query = """--sql
SELECT *
FROM session_queue
WHERE queue_id = ?
"""
params: list[Union[str, int]] = [queue_id]
if destination is not None:
query += """---sql
AND destination = ?
"""
params.append(destination)
query += """--sql
ORDER BY
priority DESC,
item_id ASC
;
"""
cursor.execute(query, params)
results = cast(list[sqlite3.Row], cursor.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
return items
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] or 0 for row in counts_result)
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueStatus(
queue_id=queue_id,
@@ -701,20 +627,20 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in result)
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
destination = result[0]["destination"] if result else None
@@ -733,20 +659,20 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in counts_result)
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
return SessionQueueCountsByDestination(
@@ -762,7 +688,8 @@ class SqliteSessionQueue(SessionQueueBase):
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
values_to_insert: list[ValueToInsertTuple] = []
retried_item_ids: list[int] = []
@@ -813,6 +740,10 @@ class SqliteSessionQueue(SessionQueueBase):
values_to_insert,
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
retry_result = RetryItemsResult(
queue_id=queue_id,
retried_item_ids=retried_item_ids,
View File
@@ -2,12 +2,11 @@
import copy
import itertools
from typing import Any, Optional, TypeVar, Union, get_args, get_origin
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import (
BaseModel,
ConfigDict,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
ValidationError,
@@ -58,32 +57,17 @@ class Edge(BaseModel):
def get_output_field_type(node: BaseInvocation, field: str) -> Any:
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined in the output, which
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
# would require some fairly significant changes and I don't want to risk breaking anything.
try:
invocation_class = type(node)
invocation_output_class = invocation_class.get_output_annotation()
field_info = invocation_output_class.model_fields.get(field)
assert field_info is not None, f"Output field '{field}' not found in {invocation_output_class.get_type()}"
output_field_type = field_info.annotation
return output_field_type
except Exception:
return None
node_type = type(node)
node_outputs = get_type_hints(node_type.get_output_annotation())
node_output_field = node_outputs.get(field) or None
return node_output_field
def get_input_field_type(node: BaseInvocation, field: str) -> Any:
# TODO(psyche): This is awkward - if field_info is None, it means the field is not defined on the node, which
# really should raise. The consumers of this utility expect it to never raise, and return None instead. Fixing this
# would require some fairly significant changes and I don't want to risk breaking anything.
try:
invocation_class = type(node)
field_info = invocation_class.model_fields.get(field)
assert field_info is not None, f"Input field '{field}' not found in {invocation_class.get_type()}"
input_field_type = field_info.annotation
return input_field_type
except Exception:
return None
node_type = type(node)
node_inputs = get_type_hints(node_type)
node_input_field = node_inputs.get(field) or None
return node_input_field
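As a quick illustration of what the `get_type_hints` based lookups above return, here is a hedged sketch using a made-up model (not an actual invocation class from this repo):

```python
from typing import Optional, get_type_hints

from pydantic import BaseModel


class ExampleNode(BaseModel):
    # Hypothetical stand-in for an invocation's declared fields.
    image_name: str
    seed: Optional[int] = None


hints = get_type_hints(ExampleNode)  # resolves annotations across the class and its bases
print(hints["image_name"])  # <class 'str'>
print(hints["seed"])        # typing.Optional[int]
```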
def is_union_subtype(t1, t2):
@@ -803,22 +787,6 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
model_config = ConfigDict(
json_schema_extra={
"required": [
"id",
"graph",
"execution_graph",
"executed",
"executed_history",
"results",
"errors",
"prepared_source_mapping",
"source_prepared_mapping",
]
}
)
@field_validator("graph")
def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid"""
@@ -1007,11 +975,10 @@ class GraphExecutionState(BaseModel):
new_node_ids = []
if isinstance(next_node, CollectInvocation):
# Collapse all iterator input mappings and create a single execution node for the collect invocation
all_iteration_mappings = []
for source_node_id in next_node_parents:
prepared_nodes = self.source_prepared_mapping[source_node_id]
all_iteration_mappings.extend([(source_node_id, p) for p in prepared_nodes])
all_iteration_mappings = list(
itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))
)
# all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
if create_results is not None:
new_node_ids.extend(create_results)
View File
@@ -1,7 +1,4 @@
import sqlite3
import threading
from collections.abc import Generator
from contextlib import contextmanager
from logging import Logger
from pathlib import Path
@@ -29,65 +26,46 @@ class SqliteDatabase:
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
"""Initializes the database. This is used internally by the class constructor."""
self._logger = logger
self._db_path = db_path
self._verbose = verbose
self._lock = threading.RLock()
self.logger = logger
self.db_path = db_path
self.verbose = verbose
if not self._db_path:
if not self.db_path:
logger.info("Initializing in-memory database")
else:
self._db_path.parent.mkdir(parents=True, exist_ok=True)
self._logger.info(f"Initializing database at {self._db_path}")
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self.logger.info(f"Initializing database at {self.db_path}")
self._conn = sqlite3.connect(database=self._db_path or sqlite_memory, check_same_thread=False)
self._conn.row_factory = sqlite3.Row
self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
self.conn.row_factory = sqlite3.Row
if self._verbose:
self._conn.set_trace_callback(self._logger.debug)
if self.verbose:
self.conn.set_trace_callback(self.logger.debug)
# Enable foreign key constraints
self._conn.execute("PRAGMA foreign_keys = ON;")
self.conn.execute("PRAGMA foreign_keys = ON;")
# Enable Write-Ahead Logging (WAL) mode for better concurrency
self._conn.execute("PRAGMA journal_mode = WAL;")
self.conn.execute("PRAGMA journal_mode = WAL;")
# Set a busy timeout to prevent database lockups during writes
self._conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.
"""
# No need to clean in-memory database
if not self._db_path:
if not self.db_path:
return
try:
with self._conn as conn:
initial_db_size = Path(self._db_path).stat().st_size
conn.execute("VACUUM;")
conn.commit()
final_db_size = Path(self._db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
initial_db_size = Path(self.db_path).stat().st_size
self.conn.execute("VACUUM;")
self.conn.commit()
final_db_size = Path(self.db_path).stat().st_size
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
if freed_space_in_mb > 0:
self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
self._logger.error(f"Error cleaning database: {e}")
self.logger.error(f"Error cleaning database: {e}")
raise
@contextmanager
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
"""
Thread-safe context manager for DB work.
Acquires the RLock, yields a Cursor, then commits or rolls back.
"""
with self._lock:
cursor = self._conn.cursor()
try:
yield cursor
self._conn.commit()
except:
self._conn.rollback()
raise
finally:
cursor.close()
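For context, a minimal sketch of the calling pattern this context manager supports, mirroring the `with self._db.transaction() as cursor:` blocks used by the storage classes above (`db` here is an assumed `SqliteDatabase` instance, and the query is illustrative):

```python
# Sketch only: commit/rollback and cursor cleanup are handled by transaction().
def count_queue_items(db, queue_id: str) -> int:
    with db.transaction() as cursor:
        cursor.execute(
            "SELECT COUNT(*) FROM session_queue WHERE queue_id = ?;",
            (queue_id,),
        )
        return int(cursor.fetchone()[0])
```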
View File
@@ -32,7 +32,7 @@ class SqliteMigrator:
def __init__(self, db: SqliteDatabase) -> None:
self._db = db
self._logger = db._logger
self._logger = db.logger
self._migration_set = MigrationSet()
self._backup_path: Optional[Path] = None
@@ -45,7 +45,7 @@ class SqliteMigrator:
"""Migrates the database to the latest version."""
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
cursor = self._db._conn.cursor()
cursor = self._db.conn.cursor()
self._create_migrations_table(cursor=cursor)
if self._migration_set.count == 0:
@@ -59,13 +59,13 @@ class SqliteMigrator:
self._logger.info("Database update needed")
# Make a backup of the db if it needs to be updated and is a file db
if self._db._db_path is not None:
if self._db.db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
self._backup_path = self._db._db_path.parent / f"{self._db._db_path.stem}_backup_{timestamp}.db"
self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
self._db._conn.backup(backup_conn)
self._db.conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
@@ -81,7 +81,7 @@ class SqliteMigrator:
try:
# Using sqlite3.Connection as a context manager commits the transaction on exit, or rolls it back if an
# exception is raised.
with self._db._conn as conn:
with self._db.conn as conn:
cursor = conn.cursor()
if self._get_current_version(cursor) != migration.from_version:
raise MigrationError(
View File
@@ -17,7 +17,7 @@ from invokeai.app.util.misc import uuid_string
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._conn = db.conn
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -25,23 +25,24 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM style_presets
WHERE id = ?;
""",
(style_preset_id,),
)
row = cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
@@ -59,11 +60,16 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
@@ -84,11 +90,16 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
style_preset.type,
),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
# Change the name of a style preset
if changes.name is not None:
cursor.execute(
@@ -111,10 +122,15 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
(changes.preset_data.model_dump_json(), style_preset_id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE from style_presets
@@ -122,41 +138,51 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
""",
(style_preset_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
with self._db.transaction() as cursor:
main_query = """
SELECT
*
FROM style_presets
"""
main_query = """
SELECT
*
FROM style_presets
"""
if type is not None:
main_query += "WHERE type = ? "
if type is not None:
main_query += "WHERE type = ? "
main_query += "ORDER BY LOWER(name) ASC"
main_query += "ORDER BY LOWER(name) ASC"
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
cursor = self._conn.cursor()
if type is not None:
cursor.execute(main_query, (type,))
else:
cursor.execute(main_query)
rows = cursor.fetchall()
rows = cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
return style_presets
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
with self._db.transaction() as cursor:
# First delete all existing default style presets
# First delete all existing default style presets
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE FROM style_presets
WHERE type = "default";
"""
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
# Next, parse and create the default style presets
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)
View File
@@ -25,7 +25,7 @@ SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%f"
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
self._conn = db.conn
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -33,16 +33,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Gets a workflow by ID. Updates the opened_at column."""
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
FROM workflow_library
WHERE workflow_id = ?;
""",
(workflow_id,),
)
row = cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
@@ -51,8 +51,9 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be created via this method")
with self._db.transaction() as cursor:
try:
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO workflow_library (
@@ -63,13 +64,18 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_with_id.id, workflow_with_id.model_dump_json()),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(workflow_with_id.id)
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be updated")
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
UPDATE workflow_library
@@ -78,13 +84,18 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow.model_dump_json(), workflow.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return self.get(workflow.id)
def delete(self, workflow_id: str) -> None:
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be deleted")
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
"""--sql
DELETE from workflow_library
@@ -92,6 +103,10 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
return None
def get_many(
@@ -106,108 +121,108 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
with self._db.transaction() as cursor:
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
# We will construct the query dynamically based on the query params
# We will construct the query dynamically based on the query params
# The main query to get the workflows / counts
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at,
tags
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
# The main query to get the workflows / counts
main_query = """
SELECT
workflow_id,
category,
name,
description,
created_at,
updated_at,
opened_at,
tags
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
# Start with an empty list of conditions and params
conditions: list[str] = []
params: list[str | int] = []
# Start with an empty list of conditions and params
conditions: list[str] = []
params: list[str | int] = []
if categories:
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
if categories:
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
# Ensure all categories are valid (is this necessary?)
assert all(c in WorkflowCategory for c in categories)
# Ensure all categories are valid (is this necessary?)
assert all(c in WorkflowCategory for c in categories)
# Construct a placeholder string for the number of categories
placeholders = ", ".join("?" for _ in categories)
# Construct a placeholder string for the number of categories
placeholders = ", ".join("?" for _ in categories)
# Construct the condition string & params
category_condition = f"category IN ({placeholders})"
category_params = [category.value for category in categories]
# Construct the condition string & params
category_condition = f"category IN ({placeholders})"
category_params = [category.value for category in categories]
conditions.append(category_condition)
params.extend(category_params)
conditions.append(category_condition)
params.extend(category_params)
if tags:
# Tags is a list of strings, and a single string in the DB
# The string in the DB has no guaranteed format
if tags:
# Tags is a list of strings, and a single string in the DB
# The string in the DB has no guaranteed format
# Construct a list of conditions for each tag
tags_conditions = ["tags LIKE ?" for _ in tags]
tags_conditions_joined = " OR ".join(tags_conditions)
tags_condition = f"({tags_conditions_joined})"
# Construct a list of conditions for each tag
tags_conditions = ["tags LIKE ?" for _ in tags]
tags_conditions_joined = " OR ".join(tags_conditions)
tags_condition = f"({tags_conditions_joined})"
# And the params for the tags, case-insensitive
tags_params = [f"%{t.strip()}%" for t in tags]
# And the params for the tags, case-insensitive
tags_params = [f"%{t.strip()}%" for t in tags]
conditions.append(tags_condition)
params.extend(tags_params)
conditions.append(tags_condition)
params.extend(tags_params)
if has_been_opened:
conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
conditions.append("opened_at IS NULL")
if has_been_opened:
conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
conditions.append("opened_at IS NULL")
# Ignore whitespace in the query
stripped_query = query.strip() if query else None
if stripped_query:
# Construct a wildcard query for the name, description, and tags
wildcard_query = "%" + stripped_query + "%"
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
# Ignore whitespace in the query
stripped_query = query.strip() if query else None
if stripped_query:
# Construct a wildcard query for the name, description, and tags
wildcard_query = "%" + stripped_query + "%"
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
count_query += " WHERE "
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
count_query += " WHERE "
all_conditions = " AND ".join(conditions)
main_query += all_conditions
count_query += all_conditions
all_conditions = " AND ".join(conditions)
main_query += all_conditions
count_query += all_conditions
# After this point, the query and params differ for the main query and the count query
main_params = params.copy()
count_params = params.copy()
# After this point, the query and params differ for the main query and the count query
main_params = params.copy()
count_params = params.copy()
# Main query also gets ORDER BY and LIMIT/OFFSET
main_query += f" ORDER BY {order_by.value} {direction.value}"
# Main query also gets ORDER BY and LIMIT/OFFSET
main_query += f" ORDER BY {order_by.value} {direction.value}"
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
if per_page:
main_query += " LIMIT ? OFFSET ?"
main_params.extend([per_page, page * per_page])
# Put a ring on it
main_query += ";"
count_query += ";"
# Put a ring on it
main_query += ";"
count_query += ";"
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor = self._conn.cursor()
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
cursor.execute(count_query, count_params)
total = cursor.fetchone()[0]
if per_page:
pages = total // per_page + (total % per_page > 0)
@@ -232,46 +247,46 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if not tags:
return {}
with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")
# Add this specific tag condition
conditions.append("tags LIKE ?")
params.append(f"%{tag.strip()}%")
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[tag] = count
return result
@@ -281,51 +296,52 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
base_params: list[str | int] = []
cursor = self._conn.cursor()
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
base_params: list[str | int] = []
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
# Add category conditions
if categories:
assert all(c in WorkflowCategory for c in categories)
placeholders = ", ".join("?" for _ in categories)
base_conditions.append(f"category IN ({placeholders})")
base_params.extend([category.value for category in categories])
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
if has_been_opened:
base_conditions.append("opened_at IS NOT NULL")
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
conditions = base_conditions.copy()
params = base_params.copy()
# Add this specific category condition
conditions.append("category = ?")
params.append(category.value)
# Add this specific category condition
conditions.append("category = ?")
params.append(category.value)
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
# Construct the full query
stmt = """--sql
SELECT COUNT(*)
FROM workflow_library
"""
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
if conditions:
stmt += " WHERE " + " AND ".join(conditions)
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[category.value] = count
cursor.execute(stmt, params)
count = cursor.fetchone()[0]
result[category.value] = count
return result
def update_opened_at(self, workflow_id: str) -> None:
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
cursor.execute(
f"""--sql
UPDATE workflow_library
@@ -334,6 +350,10 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database. Internal use only."""
@@ -348,7 +368,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
meaningless, as they are overwritten every time the server starts.
"""
with self._db.transaction() as cursor:
try:
cursor = self._conn.cursor()
workflows_from_file: list[Workflow] = []
workflows_to_update: list[Workflow] = []
workflows_to_add: list[Workflow] = []
@@ -428,3 +449,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(w.model_dump_json(), w.id),
)
self._conn.commit()
except Exception:
self._conn.rollback()
raise
View File
@@ -123,11 +123,7 @@ def calc_percentage(intermediate_state: PipelineIntermediateState) -> float:
if total_steps == 0:
return 0.0
if order == 2:
# Prevent division by zero when total_steps is 1 or 2
denominator = floor(total_steps / 2)
if denominator == 0:
return 0.0
return floor(step / 2) / denominator
return floor(step / 2) / floor(total_steps / 2)
# order == 1
return step / total_steps
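A small worked example (a sketch, not part of the diff) of why the guarded denominator matters for second-order schedulers: with `total_steps == 1`, `floor(total_steps / 2)` is 0, so the unguarded expression raises `ZeroDivisionError`, while the guarded version returns 0.0.

```python
from math import floor

def percent(step: int, total_steps: int, order: int) -> float:
    # Mirrors the guarded logic above.
    if total_steps == 0:
        return 0.0
    if order == 2:
        denominator = floor(total_steps / 2)
        if denominator == 0:  # total_steps == 1 would otherwise divide by zero
            return 0.0
        return floor(step / 2) / denominator
    return step / total_steps

print(percent(0, 1, 2))   # 0.0 instead of ZeroDivisionError
print(percent(4, 10, 2))  # 0.4
```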
View File
@@ -30,11 +30,8 @@ def denoise(
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
pos_ip_adapter_extensions: list[XLabsIPAdapterExtension],
neg_ip_adapter_extensions: list[XLabsIPAdapterExtension],
# extra img tokens (channel-wise)
# extra img tokens
img_cond: torch.Tensor | None,
# extra img tokens (sequence-wise) - for Kontext conditioning
img_cond_seq: torch.Tensor | None = None,
img_cond_seq_ids: torch.Tensor | None = None,
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -49,10 +46,6 @@ def denoise(
)
# guidance_vec is ignored for schnell.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
# Store original sequence length for slicing predictions
original_seq_len = img.shape[1]
for step_index, (t_curr, t_prev) in tqdm(list(enumerate(zip(timesteps[:-1], timesteps[1:], strict=True)))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
@@ -78,26 +71,10 @@ def denoise(
# controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
# Prepare input for model - concatenate fresh each step
img_input = img
img_input_ids = img_ids
# Add channel-wise conditioning (for ControlNet, FLUX Fill, etc.)
if img_cond is not None:
img_input = torch.cat((img_input, img_cond), dim=-1)
# Add sequence-wise conditioning (for Kontext)
if img_cond_seq is not None:
assert img_cond_seq_ids is not None, (
"You need to provide either both or neither of the sequence conditioning"
)
img_input = torch.cat((img_input, img_cond_seq), dim=1)
img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
pred_img = torch.cat((img, img_cond), dim=-1) if img_cond is not None else img
pred = model(
img=img_input,
img_ids=img_input_ids,
img=pred_img,
img_ids=img_ids,
txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
@@ -111,10 +88,6 @@ def denoise(
regional_prompting_extension=pos_regional_prompting_extension,
)
# Slice prediction to only include the main image tokens
if img_input_ids is not None:
pred = pred[:, :original_seq_len]
step_cfg_scale = cfg_scale[step_index]
# If step_cfg_scale, is 1.0, then we don't need to run the negative prediction.
View File
@@ -1,149 +0,0 @@
import einops
import numpy as np
import torch
from einops import repeat
from PIL import Image
from invokeai.app.invocations.fields import FluxKontextConditioningField
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.invocations.model import VAEField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.sampling_utils import pack
from invokeai.backend.flux.util import PREFERED_KONTEXT_RESOLUTIONS
def generate_img_ids_with_offset(
latent_height: int,
latent_width: int,
batch_size: int,
device: torch.device,
dtype: torch.dtype,
idx_offset: int = 0,
) -> torch.Tensor:
"""Generate tensor of image position ids with an optional offset.
Args:
latent_height (int): Height of image in latent space (after packing, this becomes h//2).
latent_width (int): Width of image in latent space (after packing, this becomes w//2).
batch_size (int): Number of images in the batch.
device (torch.device): Device to create tensors on.
dtype (torch.dtype): Data type for the tensors.
idx_offset (int): Offset to add to the first dimension of the image ids.
Returns:
torch.Tensor: Image position ids with shape [batch_size, (latent_height//2 * latent_width//2), 3].
"""
if device.type == "mps":
orig_dtype = dtype
dtype = torch.float16
# After packing, the spatial dimensions are halved due to the 2x2 patch structure
packed_height = latent_height // 2
packed_width = latent_width // 2
# Create base tensor for position IDs with shape [packed_height, packed_width, 3]
# The 3 channels represent: [batch_offset, y_position, x_position]
img_ids = torch.zeros(packed_height, packed_width, 3, device=device, dtype=dtype)
# Set the batch offset for all positions
img_ids[..., 0] = idx_offset
# Create y-coordinate indices (vertical positions)
y_indices = torch.arange(packed_height, device=device, dtype=dtype)
# Broadcast y_indices to match the spatial dimensions [packed_height, 1]
img_ids[..., 1] = y_indices[:, None]
# Create x-coordinate indices (horizontal positions)
x_indices = torch.arange(packed_width, device=device, dtype=dtype)
# Broadcast x_indices to match the spatial dimensions [1, packed_width]
img_ids[..., 2] = x_indices[None, :]
# Expand to include batch dimension: [batch_size, (packed_height * packed_width), 3]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
if device.type == "mps":
img_ids = img_ids.to(orig_dtype)
return img_ids
class KontextExtension:
"""Applies FLUX Kontext (reference image) conditioning."""
def __init__(
self,
kontext_conditioning: FluxKontextConditioningField,
context: InvocationContext,
vae_field: VAEField,
device: torch.device,
dtype: torch.dtype,
):
"""
Initializes the KontextExtension, pre-processing the reference image
into latents and positional IDs.
"""
self._context = context
self._device = device
self._dtype = dtype
self._vae_field = vae_field
self.kontext_conditioning = kontext_conditioning
# Pre-process and cache the kontext latents and ids upon initialization.
self.kontext_latents, self.kontext_ids = self._prepare_kontext()
def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Encodes the reference image and prepares its latents and IDs."""
image = self._context.images.get_pil(self.kontext_conditioning.image.image_name)
# Calculate aspect ratio of input image
width, height = image.size
aspect_ratio = width / height
# Find the closest preferred resolution by aspect ratio
_, target_width, target_height = min(
((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS), key=lambda x: x[0]
)
# Apply BFL's scaling formula
# This ensures compatibility with the model's training
scaled_width = 2 * int(target_width / 16)
scaled_height = 2 * int(target_height / 16)
# Resize to the exact resolution used during training
image = image.convert("RGB")
final_width = 8 * scaled_width
final_height = 8 * scaled_height
image = image.resize((final_width, final_height), Image.Resampling.LANCZOS)
# Convert to tensor with same normalization as BFL
image_np = np.array(image)
image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0
image_tensor = einops.rearrange(image_tensor, "h w c -> 1 c h w")
image_tensor = image_tensor.to(self._device)
# Continue with VAE encoding
vae_info = self._context.models.load(self._vae_field.vae)
kontext_latents_unpacked = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
# Extract tensor dimensions
batch_size, _, latent_height, latent_width = kontext_latents_unpacked.shape
# Pack the latents and generate IDs
kontext_latents_packed = pack(kontext_latents_unpacked).to(self._device, self._dtype)
kontext_ids = generate_img_ids_with_offset(
latent_height=latent_height,
latent_width=latent_width,
batch_size=batch_size,
device=self._device,
dtype=self._dtype,
idx_offset=1,
)
return kontext_latents_packed, kontext_ids
def ensure_batch_size(self, target_batch_size: int) -> None:
"""Ensures the kontext latents and IDs match the target batch size by repeating if necessary."""
if self.kontext_latents.shape[0] != target_batch_size:
self.kontext_latents = self.kontext_latents.repeat(target_batch_size, 1, 1)
self.kontext_ids = self.kontext_ids.repeat(target_batch_size, 1, 1)
View File
@@ -174,13 +174,11 @@ def generate_img_ids(h: int, w: int, batch_size: int, device: torch.device, dtyp
dtype = torch.float16
img_ids = torch.zeros(h // 2, w // 2, 3, device=device, dtype=dtype)
# Set batch offset to 0 for main image tokens
img_ids[..., 0] = 0
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2, device=device, dtype=dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2, device=device, dtype=dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
if device.type == "mps":
img_ids = img_ids.to(orig_dtype)
img_ids.to(orig_dtype)
return img_ids
View File
@@ -18,29 +18,6 @@ class ModelSpec:
repo_ae: str | None
# Preferred resolutions for Kontext models to avoid tiling artifacts
# These are the specific resolutions the model was trained on
PREFERED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
(752, 1392),
(800, 1328),
(832, 1248),
(880, 1184),
(944, 1104),
(1024, 1024),
(1104, 944),
(1184, 880),
(1248, 832),
(1328, 800),
(1392, 752),
(1456, 720),
(1504, 688),
(1568, 672),
]
max_seq_lengths: Dict[str, Literal[256, 512]] = {
"flux-dev": 512,
"flux-dev-fill": 512,
View File
@@ -42,5 +42,4 @@ IP-Adapters:
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
- [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)
- [InvokeAI/ip-adapter-plus_sdxl_vit-h](https://huggingface.co/InvokeAI/ip-adapter-plus_sdxl_vit-h)
- [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)
View File
@@ -37,7 +37,6 @@ from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora
from invokeai.backend.model_manager.taxonomy import (
AnyVariant,
BaseModelType,
@@ -187,7 +186,7 @@ class ModelConfigBase(ABC, BaseModel):
else:
return config_cls.from_model_on_disk(mod, **overrides)
raise InvalidModelConfigException("Unable to determine model type")
raise InvalidModelConfigException("No valid config found")
@classmethod
def get_tag(cls) -> Tag:
@@ -297,7 +296,7 @@ class LoRAConfigBase(ABC, BaseModel):
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
sd = mod.load_state_dict(mod.path)
value = flux_format_from_state_dict(sd, mod.metadata())
value = flux_format_from_state_dict(sd)
mod.cache[key] = value
return value
@@ -335,36 +334,6 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin,
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
format: Literal[ModelFormat.OMI] = ModelFormat.OMI
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_dir():
return False
metadata = mod.metadata()
return (
metadata.get("modelspec.sai_model_spec")
and metadata.get("ot_branch") == "omi_format"
and metadata["modelspec.architecture"].split("/")[1].lower() == "lora"
)
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
metadata = mod.metadata()
architecture = metadata["modelspec.architecture"]
if architecture == stable_diffusion_xl_1_lora:
base = BaseModelType.StableDiffusionXL
elif architecture == flux_dev_1_lora:
base = BaseModelType.Flux
else:
raise InvalidModelConfigException(f"Unrecognised/unsupported architecture for OMI LoRA: {architecture}")
return {"base": base}
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
@@ -381,7 +350,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
state_dict = mod.load_state_dict()
for key in state_dict.keys():
if isinstance(key, int):
if type(key) is int:
continue
if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
@@ -699,7 +668,6 @@ AnyModelConfig = Annotated[
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[LoRAOmiConfig, LoRAOmiConfig.get_tag()],
Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
View File
@@ -7,14 +7,7 @@ from typing import Optional
import accelerate
import torch
from safetensors.torch import load_file
from transformers import (
AutoConfig,
AutoModelForTextEncoding,
CLIPTextModel,
CLIPTokenizer,
T5EncoderModel,
T5TokenizerFast,
)
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
@@ -146,7 +139,7 @@ class BnbQuantizedLlmInt8bCheckpointModel(ModelLoader):
)
match submodel_type:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
te2_model_path = Path(config.path) / "text_encoder_2"
model_config = AutoConfig.from_pretrained(te2_model_path)
@@ -190,7 +183,7 @@ class T5EncoderCheckpointModel(ModelLoader):
match submodel_type:
case SubModelType.Tokenizer2 | SubModelType.Tokenizer3:
return T5TokenizerFast.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
return T5Tokenizer.from_pretrained(Path(config.path) / "tokenizer_2", max_length=512)
case SubModelType.TextEncoder2 | SubModelType.TextEncoder3:
return T5EncoderModel.from_pretrained(
Path(config.path) / "text_encoder_2", torch_dtype="auto", low_cpu_mem_usage=True
View File
@@ -13,7 +13,6 @@ from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.omi.omi import convert_from_omi
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
@@ -21,10 +20,6 @@ from invokeai.backend.model_manager.taxonomy import (
ModelType,
SubModelType,
)
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
is_state_dict_likely_in_flux_aitoolkit_format,
lora_model_from_flux_aitoolkit_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
is_state_dict_likely_flux_control,
lora_model_from_flux_control_state_dict,
@@ -44,8 +39,6 @@ from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import l
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.LoRA, format=ModelFormat.OMI)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.LoRA, format=ModelFormat.OMI)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlLoRa, format=ModelFormat.LyCORIS)
@@ -80,23 +73,12 @@ class LoRALoader(ModelLoader):
else:
state_dict = torch.load(model_path, map_location="cpu")
# Strip 'bundle_emb' keys - these are unused and currently cause downstream errors.
# To revisit later to determine if they're needed/useful.
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")}
# At the time of writing, we support the OMI standard for base models Flux and SDXL
if config.format == ModelFormat.OMI and self._model_base in [
BaseModelType.StableDiffusionXL,
BaseModelType.Flux,
]:
state_dict = convert_from_omi(state_dict, config.base) # type: ignore
# Apply state_dict key conversions, if necessary.
if self._model_base == BaseModelType.StableDiffusionXL:
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
model = lora_model_from_sd_state_dict(state_dict=state_dict)
elif self._model_base == BaseModelType.Flux:
if config.format in [ModelFormat.Diffusers, ModelFormat.OMI]:
if config.format == ModelFormat.Diffusers:
# HACK(ryand): We set alpha=None for diffusers PEFT format models. These models are typically
# distributed as a single file without the associated metadata containing the alpha value. We chose
# alpha=None, because this is treated as alpha=rank internally in `LoRALayerBase.scale()`. alpha=rank
@@ -110,10 +92,8 @@ class LoRALoader(ModelLoader):
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
elif is_state_dict_likely_flux_control(state_dict=state_dict):
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
else:
raise ValueError("LoRA model is in unsupported FLUX format")
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
elif self._model_base in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]:

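A small, self-contained illustration of the `bundle_emb` stripping step described in the loader comment above; the state dict is a toy stand-in, not a real LoRA checkpoint:

import torch

state_dict = {
    "bundle_emb.embedding.weight": torch.zeros(4),
    "lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight": torch.zeros(4, 4),
}

# Drop the unused 'bundle_emb' entries before any key conversion is applied.
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")}

assert list(state_dict) == ["lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight"]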
View File

@@ -1,7 +0,0 @@
from invokeai.backend.model_manager.omi.omi import convert_from_omi
from invokeai.backend.model_manager.omi.vendor.model_spec.architecture import (
flux_dev_1_lora,
stable_diffusion_xl_1_lora,
)
__all__ = ["flux_dev_1_lora", "stable_diffusion_xl_1_lora", "convert_from_omi"]

View File

@@ -1,21 +0,0 @@
from invokeai.backend.model_manager.model_on_disk import StateDict
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
convert_flux_lora as omi_flux,
)
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
convert_lora_util as lora_util,
)
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
convert_sdxl_lora as omi_sdxl,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType
def convert_from_omi(weights_sd: StateDict, base: BaseModelType):
keyset = {
BaseModelType.Flux: omi_flux.convert_flux_lora_key_sets(),
BaseModelType.StableDiffusionXL: omi_sdxl.convert_sdxl_lora_key_sets(),
}[base]
source = "omi"
target = "legacy_diffusers"
return lora_util.__convert(weights_sd, keyset, source, target) # type: ignore
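A hypothetical call site for the `convert_from_omi` helper above, assuming the module exists as defined in this file; the single-entry `weights_sd` is a placeholder for tensors loaded from a safetensors LoRA file:

import torch

from invokeai.backend.model_manager.omi.omi import convert_from_omi
from invokeai.backend.model_manager.taxonomy import BaseModelType

# Toy OMI-style state dict; real inputs come from a LoRA safetensors file.
weights_sd = {"clip_l.text_projection.lora_down.weight": torch.zeros(8, 8)}

# Rewrites OMI key prefixes into the legacy-diffusers naming used downstream.
legacy_sd = convert_from_omi(weights_sd, BaseModelType.Flux)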

View File

@@ -1,20 +0,0 @@
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
LoraConversionKeySet,
map_prefix_range,
)
def map_clip(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("text_projection", "text_projection", parent=key_prefix)]
for k in map_prefix_range("text_model.encoder.layers", "text_model.encoder.layers", parent=key_prefix):
keys += [LoraConversionKeySet("mlp.fc1", "mlp.fc1", parent=k)]
keys += [LoraConversionKeySet("mlp.fc2", "mlp.fc2", parent=k)]
keys += [LoraConversionKeySet("self_attn.k_proj", "self_attn.k_proj", parent=k)]
keys += [LoraConversionKeySet("self_attn.out_proj", "self_attn.out_proj", parent=k)]
keys += [LoraConversionKeySet("self_attn.q_proj", "self_attn.q_proj", parent=k)]
keys += [LoraConversionKeySet("self_attn.v_proj", "self_attn.v_proj", parent=k)]
return keys

View File

@@ -1,84 +0,0 @@
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_clip import map_clip
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
LoraConversionKeySet,
map_prefix_range,
)
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_t5 import map_t5
def __map_double_transformer_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("img_attn.qkv.0", "attn.to_q", parent=key_prefix)]
keys += [LoraConversionKeySet("img_attn.qkv.1", "attn.to_k", parent=key_prefix)]
keys += [LoraConversionKeySet("img_attn.qkv.2", "attn.to_v", parent=key_prefix)]
keys += [LoraConversionKeySet("txt_attn.qkv.0", "attn.add_q_proj", parent=key_prefix)]
keys += [LoraConversionKeySet("txt_attn.qkv.1", "attn.add_k_proj", parent=key_prefix)]
keys += [LoraConversionKeySet("txt_attn.qkv.2", "attn.add_v_proj", parent=key_prefix)]
keys += [LoraConversionKeySet("img_attn.proj", "attn.to_out.0", parent=key_prefix)]
keys += [LoraConversionKeySet("img_mlp.0", "ff.net.0.proj", parent=key_prefix)]
keys += [LoraConversionKeySet("img_mlp.2", "ff.net.2", parent=key_prefix)]
keys += [LoraConversionKeySet("img_mod.lin", "norm1.linear", parent=key_prefix)]
keys += [LoraConversionKeySet("txt_attn.proj", "attn.to_add_out", parent=key_prefix)]
keys += [LoraConversionKeySet("txt_mlp.0", "ff_context.net.0.proj", parent=key_prefix)]
keys += [LoraConversionKeySet("txt_mlp.2", "ff_context.net.2", parent=key_prefix)]
keys += [LoraConversionKeySet("txt_mod.lin", "norm1_context.linear", parent=key_prefix)]
return keys
def __map_single_transformer_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("linear1.0", "attn.to_q", parent=key_prefix)]
keys += [LoraConversionKeySet("linear1.1", "attn.to_k", parent=key_prefix)]
keys += [LoraConversionKeySet("linear1.2", "attn.to_v", parent=key_prefix)]
keys += [LoraConversionKeySet("linear1.3", "proj_mlp", parent=key_prefix)]
keys += [LoraConversionKeySet("linear2", "proj_out", parent=key_prefix)]
keys += [LoraConversionKeySet("modulation.lin", "norm.linear", parent=key_prefix)]
return keys
def __map_transformer(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("txt_in", "context_embedder", parent=key_prefix)]
keys += [
LoraConversionKeySet("final_layer.adaLN_modulation.1", "norm_out.linear", parent=key_prefix, swap_chunks=True)
]
keys += [LoraConversionKeySet("final_layer.linear", "proj_out", parent=key_prefix)]
keys += [
LoraConversionKeySet("guidance_in.in_layer", "time_text_embed.guidance_embedder.linear_1", parent=key_prefix)
]
keys += [
LoraConversionKeySet("guidance_in.out_layer", "time_text_embed.guidance_embedder.linear_2", parent=key_prefix)
]
keys += [LoraConversionKeySet("vector_in.in_layer", "time_text_embed.text_embedder.linear_1", parent=key_prefix)]
keys += [LoraConversionKeySet("vector_in.out_layer", "time_text_embed.text_embedder.linear_2", parent=key_prefix)]
keys += [LoraConversionKeySet("time_in.in_layer", "time_text_embed.timestep_embedder.linear_1", parent=key_prefix)]
keys += [LoraConversionKeySet("time_in.out_layer", "time_text_embed.timestep_embedder.linear_2", parent=key_prefix)]
keys += [LoraConversionKeySet("img_in.proj", "x_embedder", parent=key_prefix)]
for k in map_prefix_range("double_blocks", "transformer_blocks", parent=key_prefix):
keys += __map_double_transformer_block(k)
for k in map_prefix_range("single_blocks", "single_transformer_blocks", parent=key_prefix):
keys += __map_single_transformer_block(k)
return keys
def convert_flux_lora_key_sets() -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("bundle_emb", "bundle_emb")]
keys += __map_transformer(LoraConversionKeySet("transformer", "lora_transformer"))
keys += map_clip(LoraConversionKeySet("clip_l", "lora_te1"))
keys += map_t5(LoraConversionKeySet("t5", "lora_te2"))
return keys

View File

@@ -1,217 +0,0 @@
import torch
from torch import Tensor
from typing_extensions import Self
class LoraConversionKeySet:
def __init__(
self,
omi_prefix: str,
diffusers_prefix: str,
legacy_diffusers_prefix: str | None = None,
parent: Self | None = None,
swap_chunks: bool = False,
filter_is_last: bool | None = None,
next_omi_prefix: str | None = None,
next_diffusers_prefix: str | None = None,
):
if parent is not None:
self.omi_prefix = combine(parent.omi_prefix, omi_prefix)
self.diffusers_prefix = combine(parent.diffusers_prefix, diffusers_prefix)
else:
self.omi_prefix = omi_prefix
self.diffusers_prefix = diffusers_prefix
if legacy_diffusers_prefix is None:
self.legacy_diffusers_prefix = self.diffusers_prefix.replace(".", "_")
elif parent is not None:
self.legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, legacy_diffusers_prefix).replace(
".", "_"
)
else:
self.legacy_diffusers_prefix = legacy_diffusers_prefix
self.parent = parent
self.swap_chunks = swap_chunks
self.filter_is_last = filter_is_last
self.prefix = parent
if next_omi_prefix is None and parent is not None:
self.next_omi_prefix = parent.next_omi_prefix
self.next_diffusers_prefix = parent.next_diffusers_prefix
self.next_legacy_diffusers_prefix = parent.next_legacy_diffusers_prefix
elif next_omi_prefix is not None and parent is not None:
self.next_omi_prefix = combine(parent.omi_prefix, next_omi_prefix)
self.next_diffusers_prefix = combine(parent.diffusers_prefix, next_diffusers_prefix)
self.next_legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, next_diffusers_prefix).replace(
".", "_"
)
elif next_omi_prefix is not None and parent is None:
self.next_omi_prefix = next_omi_prefix
self.next_diffusers_prefix = next_diffusers_prefix
self.next_legacy_diffusers_prefix = next_diffusers_prefix.replace(".", "_")
else:
self.next_omi_prefix = None
self.next_diffusers_prefix = None
self.next_legacy_diffusers_prefix = None
def __get_omi(self, in_prefix: str, key: str) -> str:
return self.omi_prefix + key.removeprefix(in_prefix)
def __get_diffusers(self, in_prefix: str, key: str) -> str:
return self.diffusers_prefix + key.removeprefix(in_prefix)
def __get_legacy_diffusers(self, in_prefix: str, key: str) -> str:
key = self.legacy_diffusers_prefix + key.removeprefix(in_prefix)
suffix = key[key.rfind(".") :]
if suffix not in [".alpha", ".dora_scale"]: # some keys only have a single . in the suffix
suffix = key[key.removesuffix(suffix).rfind(".") :]
key = key.removesuffix(suffix)
return key.replace(".", "_") + suffix
def get_key(self, in_prefix: str, key: str, target: str) -> str:
if target == "omi":
return self.__get_omi(in_prefix, key)
elif target == "diffusers":
return self.__get_diffusers(in_prefix, key)
elif target == "legacy_diffusers":
return self.__get_legacy_diffusers(in_prefix, key)
return key
def __str__(self) -> str:
return f"omi: {self.omi_prefix}, diffusers: {self.diffusers_prefix}, legacy: {self.legacy_diffusers_prefix}"
def combine(left: str, right: str) -> str:
left = left.rstrip(".")
right = right.lstrip(".")
if left == "" or left is None:
return right
elif right == "" or right is None:
return left
else:
return left + "." + right
def map_prefix_range(
omi_prefix: str,
diffusers_prefix: str,
parent: LoraConversionKeySet,
) -> list[LoraConversionKeySet]:
# 100 should be a safe upper bound. increase if it's not enough in the future
return [
LoraConversionKeySet(
omi_prefix=f"{omi_prefix}.{i}",
diffusers_prefix=f"{diffusers_prefix}.{i}",
parent=parent,
next_omi_prefix=f"{omi_prefix}.{i + 1}",
next_diffusers_prefix=f"{diffusers_prefix}.{i + 1}",
)
for i in range(100)
]
def __convert(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
source: str,
target: str,
) -> dict[str, Tensor]:
out_states = {}
if source == target:
return dict(state_dict)
# TODO: maybe replace with a non O(n^2) algorithm
for key, tensor in state_dict.items():
for key_set in key_sets:
in_prefix = ""
if source == "omi":
in_prefix = key_set.omi_prefix
elif source == "diffusers":
in_prefix = key_set.diffusers_prefix
elif source == "legacy_diffusers":
in_prefix = key_set.legacy_diffusers_prefix
if not key.startswith(in_prefix):
continue
if key_set.filter_is_last is not None:
next_prefix = None
if source == "omi":
next_prefix = key_set.next_omi_prefix
elif source == "diffusers":
next_prefix = key_set.next_diffusers_prefix
elif source == "legacy_diffusers":
next_prefix = key_set.next_legacy_diffusers_prefix
is_last = not any(k.startswith(next_prefix) for k in state_dict)
if key_set.filter_is_last != is_last:
continue
name = key_set.get_key(in_prefix, key, target)
can_swap_chunks = target == "omi" or source == "omi"
if key_set.swap_chunks and name.endswith(".lora_up.weight") and can_swap_chunks:
chunk_0, chunk_1 = tensor.chunk(2, dim=0)
tensor = torch.cat([chunk_1, chunk_0], dim=0)
out_states[name] = tensor
break # only map the first matching key set
return out_states
def __detect_source(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
) -> str:
omi_count = 0
diffusers_count = 0
legacy_diffusers_count = 0
for key in state_dict:
for key_set in key_sets:
if key.startswith(key_set.omi_prefix):
omi_count += 1
if key.startswith(key_set.diffusers_prefix):
diffusers_count += 1
if key.startswith(key_set.legacy_diffusers_prefix):
legacy_diffusers_count += 1
if omi_count > diffusers_count and omi_count > legacy_diffusers_count:
return "omi"
if diffusers_count > omi_count and diffusers_count > legacy_diffusers_count:
return "diffusers"
if legacy_diffusers_count > omi_count and legacy_diffusers_count > diffusers_count:
return "legacy_diffusers"
return ""
def convert_to_omi(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
) -> dict[str, Tensor]:
source = __detect_source(state_dict, key_sets)
return __convert(state_dict, key_sets, source, "omi")
def convert_to_diffusers(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
) -> dict[str, Tensor]:
source = __detect_source(state_dict, key_sets)
return __convert(state_dict, key_sets, source, "diffusers")
def convert_to_legacy_diffusers(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
) -> dict[str, Tensor]:
source = __detect_source(state_dict, key_sets)
return __convert(state_dict, key_sets, source, "legacy_diffusers")
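The heart of the conversion above is a prefix rewrite followed by the dot-to-underscore flattening performed in `__get_legacy_diffusers`. A minimal standalone sketch of that flattening rule (independent of `LoraConversionKeySet`), included only to make the key-naming convention concrete:

def to_legacy_diffusers_key(key: str, in_prefix: str, out_prefix: str) -> str:
    # Swap the prefix, then flatten dots to underscores while preserving the
    # trailing parameter suffix (e.g. ".lora_down.weight" or ".alpha"),
    # following the same steps as __get_legacy_diffusers above.
    key = out_prefix + key.removeprefix(in_prefix)
    suffix = key[key.rfind(".") :]
    if suffix not in (".alpha", ".dora_scale"):
        suffix = key[key.removesuffix(suffix).rfind(".") :]
    key = key.removesuffix(suffix)
    return key.replace(".", "_") + suffix


# Example: an OMI-style transformer key mapped to the legacy naming.
print(to_legacy_diffusers_key("transformer.img_in.proj.lora_down.weight", "transformer", "lora_transformer"))
# -> lora_transformer_img_in_proj.lora_down.weight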

View File

@@ -1,125 +0,0 @@
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_clip import map_clip
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
LoraConversionKeySet,
map_prefix_range,
)
def __map_unet_resnet_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("emb_layers.1", "time_emb_proj", parent=key_prefix)]
keys += [LoraConversionKeySet("in_layers.2", "conv1", parent=key_prefix)]
keys += [LoraConversionKeySet("out_layers.3", "conv2", parent=key_prefix)]
keys += [LoraConversionKeySet("skip_connection", "conv_shortcut", parent=key_prefix)]
return keys
def __map_unet_attention_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("proj_in", "proj_in", parent=key_prefix)]
keys += [LoraConversionKeySet("proj_out", "proj_out", parent=key_prefix)]
for k in map_prefix_range("transformer_blocks", "transformer_blocks", parent=key_prefix):
keys += [LoraConversionKeySet("attn1.to_q", "attn1.to_q", parent=k)]
keys += [LoraConversionKeySet("attn1.to_k", "attn1.to_k", parent=k)]
keys += [LoraConversionKeySet("attn1.to_v", "attn1.to_v", parent=k)]
keys += [LoraConversionKeySet("attn1.to_out.0", "attn1.to_out.0", parent=k)]
keys += [LoraConversionKeySet("attn2.to_q", "attn2.to_q", parent=k)]
keys += [LoraConversionKeySet("attn2.to_k", "attn2.to_k", parent=k)]
keys += [LoraConversionKeySet("attn2.to_v", "attn2.to_v", parent=k)]
keys += [LoraConversionKeySet("attn2.to_out.0", "attn2.to_out.0", parent=k)]
keys += [LoraConversionKeySet("ff.net.0.proj", "ff.net.0.proj", parent=k)]
keys += [LoraConversionKeySet("ff.net.2", "ff.net.2", parent=k)]
return keys
def __map_unet_down_blocks(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += __map_unet_resnet_block(LoraConversionKeySet("1.0", "0.resnets.0", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("2.0", "0.resnets.1", parent=key_prefix))
keys += [LoraConversionKeySet("3.0.op", "0.downsamplers.0.conv", parent=key_prefix)]
keys += __map_unet_resnet_block(LoraConversionKeySet("4.0", "1.resnets.0", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("4.1", "1.attentions.0", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("5.0", "1.resnets.1", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("5.1", "1.attentions.1", parent=key_prefix))
keys += [LoraConversionKeySet("6.0.op", "1.downsamplers.0.conv", parent=key_prefix)]
keys += __map_unet_resnet_block(LoraConversionKeySet("7.0", "2.resnets.0", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("7.1", "2.attentions.0", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("8.0", "2.resnets.1", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("8.1", "2.attentions.1", parent=key_prefix))
return keys
def __map_unet_mid_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += __map_unet_resnet_block(LoraConversionKeySet("0", "resnets.0", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("1", "attentions.0", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("2", "resnets.1", parent=key_prefix))
return keys
def __map_unet_up_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += __map_unet_resnet_block(LoraConversionKeySet("0.0", "0.resnets.0", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("0.1", "0.attentions.0", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("1.0", "0.resnets.1", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("1.1", "0.attentions.1", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("2.0", "0.resnets.2", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("2.1", "0.attentions.2", parent=key_prefix))
keys += [LoraConversionKeySet("2.2.conv", "0.upsamplers.0.conv", parent=key_prefix)]
keys += __map_unet_resnet_block(LoraConversionKeySet("3.0", "1.resnets.0", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("3.1", "1.attentions.0", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("4.0", "1.resnets.1", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("4.1", "1.attentions.1", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("5.0", "1.resnets.2", parent=key_prefix))
keys += __map_unet_attention_block(LoraConversionKeySet("5.1", "1.attentions.2", parent=key_prefix))
keys += [LoraConversionKeySet("5.2.conv", "1.upsamplers.0.conv", parent=key_prefix)]
keys += __map_unet_resnet_block(LoraConversionKeySet("6.0", "2.resnets.0", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("7.0", "2.resnets.1", parent=key_prefix))
keys += __map_unet_resnet_block(LoraConversionKeySet("8.0", "2.resnets.2", parent=key_prefix))
return keys
def __map_unet(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("input_blocks.0.0", "conv_in", parent=key_prefix)]
keys += [LoraConversionKeySet("time_embed.0", "time_embedding.linear_1", parent=key_prefix)]
keys += [LoraConversionKeySet("time_embed.2", "time_embedding.linear_2", parent=key_prefix)]
keys += [LoraConversionKeySet("label_emb.0.0", "add_embedding.linear_1", parent=key_prefix)]
keys += [LoraConversionKeySet("label_emb.0.2", "add_embedding.linear_2", parent=key_prefix)]
keys += __map_unet_down_blocks(LoraConversionKeySet("input_blocks", "down_blocks", parent=key_prefix))
keys += __map_unet_mid_block(LoraConversionKeySet("middle_block", "mid_block", parent=key_prefix))
keys += __map_unet_up_block(LoraConversionKeySet("output_blocks", "up_blocks", parent=key_prefix))
keys += [LoraConversionKeySet("out.0", "conv_norm_out", parent=key_prefix)]
keys += [LoraConversionKeySet("out.2", "conv_out", parent=key_prefix)]
return keys
def convert_sdxl_lora_key_sets() -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("bundle_emb", "bundle_emb")]
keys += __map_unet(LoraConversionKeySet("unet", "lora_unet"))
keys += map_clip(LoraConversionKeySet("clip_l", "lora_te1"))
keys += map_clip(LoraConversionKeySet("clip_g", "lora_te2"))
return keys

View File

@@ -1,19 +0,0 @@
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
LoraConversionKeySet,
map_prefix_range,
)
def map_t5(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
keys = []
for k in map_prefix_range("encoder.block", "encoder.block", parent=key_prefix):
keys += [LoraConversionKeySet("layer.0.SelfAttention.k", "layer.0.SelfAttention.k", parent=k)]
keys += [LoraConversionKeySet("layer.0.SelfAttention.o", "layer.0.SelfAttention.o", parent=k)]
keys += [LoraConversionKeySet("layer.0.SelfAttention.q", "layer.0.SelfAttention.q", parent=k)]
keys += [LoraConversionKeySet("layer.0.SelfAttention.v", "layer.0.SelfAttention.v", parent=k)]
keys += [LoraConversionKeySet("layer.1.DenseReluDense.wi_0", "layer.1.DenseReluDense.wi_0", parent=k)]
keys += [LoraConversionKeySet("layer.1.DenseReluDense.wi_1", "layer.1.DenseReluDense.wi_1", parent=k)]
keys += [LoraConversionKeySet("layer.1.DenseReluDense.wo", "layer.1.DenseReluDense.wo", parent=k)]
return keys

View File

@@ -1,31 +0,0 @@
stable_diffusion_1_lora = "stable-diffusion-v1/lora"
stable_diffusion_1_inpainting_lora = "stable-diffusion-v1-inpainting/lora"
stable_diffusion_2_512_lora = "stable-diffusion-v2-512/lora"
stable_diffusion_2_768_v_lora = "stable-diffusion-v2-768-v/lora"
stable_diffusion_2_depth_lora = "stable-diffusion-v2-depth/lora"
stable_diffusion_2_inpainting_lora = "stable-diffusion-v2-inpainting/lora"
stable_diffusion_3_medium_lora = "stable-diffusion-v3-medium/lora"
stable_diffusion_35_medium_lora = "stable-diffusion-v3.5-medium/lora"
stable_diffusion_35_large_lora = "stable-diffusion-v3.5-large/lora"
stable_diffusion_xl_1_lora = "stable-diffusion-xl-v1-base/lora"
stable_diffusion_xl_1_inpainting_lora = "stable-diffusion-xl-v1-base-inpainting/lora"
wuerstchen_2_lora = "wuerstchen-v2-prior/lora"
stable_cascade_1_stage_a_lora = "stable-cascade-v1-stage-a/lora"
stable_cascade_1_stage_b_lora = "stable-cascade-v1-stage-b/lora"
stable_cascade_1_stage_c_lora = "stable-cascade-v1-stage-c/lora"
pixart_alpha_lora = "pixart-alpha/lora"
pixart_sigma_lora = "pixart-sigma/lora"
flux_dev_1_lora = "Flux.1-dev/lora"
flux_fill_dev_1_lora = "Flux.1-fill-dev/lora"
sana_lora = "sana/lora"
hunyuan_video_lora = "hunyuan-video/lora"
hi_dream_i1_lora = "hidream-i1/lora"

View File

@@ -23,7 +23,7 @@ class StarterModel(StarterModelWithoutDependencies):
dependencies: Optional[list[StarterModelWithoutDependencies]] = None
class StarterModelBundle(BaseModel):
class StarterModelBundles(BaseModel):
name: str
models: list[StarterModel]
@@ -109,7 +109,7 @@ flux_vae = StarterModel(
# region: Main
flux_schnell_quantized = StarterModel(
name="FLUX.1 schnell (quantized)",
name="FLUX Schnell (Quantized)",
base=BaseModelType.Flux,
source="InvokeAI/flux_schnell::transformer/bnb_nf4/flux1-schnell-bnb_nf4.safetensors",
description="FLUX schnell transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
@@ -117,7 +117,7 @@ flux_schnell_quantized = StarterModel(
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
flux_dev_quantized = StarterModel(
name="FLUX.1 dev (quantized)",
name="FLUX Dev (Quantized)",
base=BaseModelType.Flux,
source="InvokeAI/flux_dev::transformer/bnb_nf4/flux1-dev-bnb_nf4.safetensors",
description="FLUX dev transformer quantized to bitsandbytes NF4 format. Total size with dependencies: ~12GB",
@@ -125,7 +125,7 @@ flux_dev_quantized = StarterModel(
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
flux_schnell = StarterModel(
name="FLUX.1 schnell",
name="FLUX Schnell",
base=BaseModelType.Flux,
source="InvokeAI/flux_schnell::transformer/base/flux1-schnell.safetensors",
description="FLUX schnell transformer in bfloat16. Total size with dependencies: ~33GB",
@@ -133,29 +133,13 @@ flux_schnell = StarterModel(
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_dev = StarterModel(
name="FLUX.1 dev",
name="FLUX Dev",
base=BaseModelType.Flux,
source="InvokeAI/flux_dev::transformer/base/flux1-dev.safetensors",
description="FLUX dev transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_kontext = StarterModel(
name="FLUX.1 Kontext dev",
base=BaseModelType.Flux,
source="https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/resolve/main/flux1-kontext-dev.safetensors",
description="FLUX.1 Kontext dev transformer in bfloat16. Total size with dependencies: ~33GB",
type=ModelType.Main,
dependencies=[t5_base_encoder, flux_vae, clip_l_encoder],
)
flux_kontext_quantized = StarterModel(
name="FLUX.1 Kontext dev (Quantized)",
base=BaseModelType.Flux,
source="https://huggingface.co/unsloth/FLUX.1-Kontext-dev-GGUF/resolve/main/flux1-kontext-dev-Q4_K_M.gguf",
description="FLUX.1 Kontext dev quantized (q4_k_m). Total size with dependencies: ~14GB",
type=ModelType.Main,
dependencies=[t5_8b_quantized_encoder, flux_vae, clip_l_encoder],
)
sd35_medium = StarterModel(
name="SD3.5 Medium",
base=BaseModelType.StableDiffusion3,
@@ -313,15 +297,6 @@ ip_adapter_sdxl = StarterModel(
dependencies=[ip_adapter_sdxl_image_encoder],
previous_names=["IP Adapter SDXL"],
)
ip_adapter_plus_sdxl = StarterModel(
name="Precise Reference (IP Adapter Plus ViT-H)",
base=BaseModelType.StableDiffusionXL,
source="https://huggingface.co/InvokeAI/ip-adapter-plus_sdxl_vit-h/resolve/main/ip-adapter-plus_sdxl_vit-h.safetensors",
description="References images with a higher degree of precision.",
type=ModelType.IPAdapter,
dependencies=[ip_adapter_sdxl_image_encoder],
previous_names=["IP Adapter Plus SDXL"],
)
ip_adapter_flux = StarterModel(
name="Standard Reference (XLabs FLUX IP-Adapter v2)",
base=BaseModelType.Flux,
@@ -672,7 +647,6 @@ flux_fill = StarterModel(
# List of starter models, displayed on the frontend.
# The order/sort of this list is not changed by the frontend - set it how you want it here.
STARTER_MODELS: list[StarterModel] = [
flux_kontext_quantized,
flux_schnell_quantized,
flux_dev_quantized,
flux_schnell,
@@ -698,7 +672,6 @@ STARTER_MODELS: list[StarterModel] = [
ip_adapter_plus_sd1,
ip_adapter_plus_face_sd1,
ip_adapter_sdxl,
ip_adapter_plus_sdxl,
ip_adapter_flux,
qr_code_cnet_sd1,
qr_code_cnet_sdxl,
@@ -771,7 +744,6 @@ sdxl_bundle: list[StarterModel] = [
juggernaut_sdxl,
sdxl_fp16_vae_fix,
ip_adapter_sdxl,
ip_adapter_plus_sdxl,
canny_sdxl,
depth_sdxl,
softedge_sdxl,
@@ -793,13 +765,12 @@ flux_bundle: list[StarterModel] = [
flux_depth_control_lora,
flux_redux,
flux_fill,
flux_kontext_quantized,
]
STARTER_BUNDLES: dict[str, StarterModelBundle] = {
BaseModelType.StableDiffusion1: StarterModelBundle(name="Stable Diffusion 1.5", models=sd1_bundle),
BaseModelType.StableDiffusionXL: StarterModelBundle(name="SDXL", models=sdxl_bundle),
BaseModelType.Flux: StarterModelBundle(name="FLUX.1 dev", models=flux_bundle),
STARTER_BUNDLES: dict[str, list[StarterModel]] = {
BaseModelType.StableDiffusion1: sd1_bundle,
BaseModelType.StableDiffusionXL: sdxl_bundle,
BaseModelType.Flux: flux_bundle,
}
assert len(STARTER_MODELS) == len({m.source for m in STARTER_MODELS}), "Duplicate starter models"

View File

@@ -89,7 +89,6 @@ class ModelVariantType(str, Enum):
class ModelFormat(str, Enum):
"""Storage format of model."""
OMI = "omi"
Diffusers = "diffusers"
Checkpoint = "checkpoint"
LyCORIS = "lycoris"
@@ -139,7 +138,6 @@ class FluxLoRAFormat(str, Enum):
Kohya = "flux.kohya"
OneTrainer = "flux.onetrainer"
Control = "flux.control"
AIToolkit = "flux.aitoolkit"
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]

View File

@@ -46,10 +46,6 @@ class ModelPatcher:
text_encoder: Union[CLIPTextModel, CLIPTextModelWithProjection],
ti_list: List[Tuple[str, TextualInversionModelRaw]],
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
if len(ti_list) == 0:
yield tokenizer, TextualInversionManager(tokenizer)
return
init_tokens_count = None
new_tokens_added = None

View File

@@ -1,63 +0,0 @@
import json
from dataclasses import dataclass, field
from typing import Any
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import _group_by_layer
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.util import InvokeAILogger
def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
if metadata:
try:
software = json.loads(metadata.get("software", "{}"))
except json.JSONDecodeError:
return False
return software.get("name") == "ai-toolkit"
# metadata got lost somewhere
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
@dataclass
class GroupedStateDict:
transformer: dict[str, Any] = field(default_factory=dict)
# might also grow CLIP and T5 submodels
def _group_state_by_submodel(state_dict: dict[str, Any]) -> GroupedStateDict:
logger = InvokeAILogger.get_logger()
grouped = GroupedStateDict()
for key, value in state_dict.items():
submodel_name, param_name = key.split(".", 1)
match submodel_name:
case "diffusion_model":
grouped.transformer[param_name] = value
case _:
logger.warning(f"Unexpected submodel name: {submodel_name}")
return grouped
def _rename_peft_lora_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Renames keys from the PEFT LoRA format to the InvokeAI format."""
renamed_state_dict = {}
for key, value in state_dict.items():
renamed_key = key.replace(".lora_A.", ".lora_down.").replace(".lora_B.", ".lora_up.")
renamed_state_dict[renamed_key] = value
return renamed_state_dict
def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
state_dict = _rename_peft_lora_keys(state_dict)
by_layer = _group_by_layer(state_dict)
by_model = _group_state_by_submodel(by_layer)
layers: dict[str, BaseLayerPatch] = {}
for layer_key, layer_state_dict in by_model.transformer.items():
layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
return ModelPatchRaw(layers=layers)
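A toy illustration of the PEFT-to-InvokeAI key renaming performed by `_rename_peft_lora_keys` above; the tensor values are placeholders:

import torch

peft_sd = {
    "diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight": torch.zeros(2, 2),
    "diffusion_model.double_blocks.0.img_attn.qkv.lora_B.weight": torch.zeros(2, 2),
}

# PEFT's lora_A/lora_B naming becomes the lora_down/lora_up naming used internally.
renamed = {
    k.replace(".lora_A.", ".lora_down.").replace(".lora_B.", ".lora_up."): v
    for k, v in peft_sd.items()
}

assert all(".lora_down." in k or ".lora_up." in k for k in renamed)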

View File

@@ -1,7 +1,4 @@
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
is_state_dict_likely_in_flux_aitoolkit_format,
)
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
@@ -14,7 +11,7 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
)
def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None:
def flux_format_from_state_dict(state_dict):
if is_state_dict_likely_in_flux_kohya_format(state_dict):
return FluxLoRAFormat.Kohya
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
@@ -23,7 +20,5 @@ def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None)
return FluxLoRAFormat.Diffusers
elif is_state_dict_likely_flux_control(state_dict):
return FluxLoRAFormat.Control
elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
return FluxLoRAFormat.AIToolkit
else:
return None

View File

@@ -9,25 +9,13 @@ module.exports = {
// https://github.com/qdanik/eslint-plugin-path
'path/no-relative-imports': ['error', { maxDepth: 0 }],
// https://github.com/edvardchen/eslint-plugin-i18next/blob/HEAD/docs/rules/no-literal-string.md
// TODO: ENABLE THIS RULE BEFORE v6.0.0
// 'i18next/no-literal-string': 'error',
'i18next/no-literal-string': 'error',
// https://eslint.org/docs/latest/rules/no-console
'no-console': 'warn',
'no-console': 'error',
// https://eslint.org/docs/latest/rules/no-promise-executor-return
'no-promise-executor-return': 'error',
// https://eslint.org/docs/latest/rules/require-await
'require-await': 'error',
// Restrict setActiveTab calls to only use-navigation-api.tsx
'no-restricted-syntax': [
'error',
{
selector: 'CallExpression[callee.name="setActiveTab"]',
message:
'setActiveTab() can only be called from use-navigation-api.tsx. Use navigationApi.switchToTab() instead.',
},
],
// TODO: ENABLE THIS RULE BEFORE v6.0.0
'react/display-name': 'off',
'no-restricted-properties': [
'error',
{
@@ -42,38 +30,8 @@ module.exports = {
'The Clipboard API is not available by default in Firefox. Use the `useClipboard` hook instead, which wraps clipboard access to prevent errors.',
},
],
'no-restricted-imports': [
'error',
{
paths: [
{
name: 'lodash-es',
importNames: ['isEqual'],
message: 'Please use objectEquals from @observ33r/object-equals instead.',
},
{
name: 'lodash-es',
message: 'Please use es-toolkit instead.',
},
{
name: 'es-toolkit',
importNames: ['isEqual'],
message: 'Please use objectEquals from @observ33r/object-equals instead.',
},
],
},
],
},
overrides: [
/**
* Allow setActiveTab calls only in use-navigation-api.tsx
*/
{
files: ['**/use-navigation-api.tsx'],
rules: {
'no-restricted-syntax': 'off',
},
},
/**
* Overrides for stories
*/

View File

@@ -12,8 +12,10 @@ const config: KnipConfig = {
'src/features/parameters/types/parameterSchemas.ts',
// TODO(psyche): maybe we can clean up these utils after canvas v2 release
'src/features/controlLayers/konva/util.ts',
// Will be using this
'src/common/hooks/useAsyncState.ts',
// TODO(psyche): restore HRF functionality?
'src/features/hrf/**',
// This feature is (temporarily?) disabled
'src/features/controlLayers/components/InpaintMask/InpaintMaskAddButtons.tsx',
],
ignoreBinaries: ['only-allow'],
paths: {

View File

@@ -38,60 +38,70 @@
"test:ui": "vitest --coverage --ui",
"test:no-watch": "vitest --no-watch"
},
"madge": {
"excludeRegExp": [
"^index.ts$"
],
"detectiveOptions": {
"ts": {
"skipTypeImports": true
},
"tsx": {
"skipTypeImports": true
}
}
},
"dependencies": {
"@atlaskit/pragmatic-drag-and-drop": "^1.7.4",
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^2.1.1",
"@atlaskit/pragmatic-drag-and-drop-hitbox": "^1.1.0",
"@dagrejs/dagre": "^1.1.5",
"@atlaskit/pragmatic-drag-and-drop": "^1.5.3",
"@atlaskit/pragmatic-drag-and-drop-auto-scroll": "^2.1.0",
"@atlaskit/pragmatic-drag-and-drop-hitbox": "^1.0.3",
"@dagrejs/dagre": "^1.1.4",
"@dagrejs/graphlib": "^2.2.4",
"@fontsource-variable/inter": "^5.2.6",
"@fontsource-variable/inter": "^5.2.5",
"@invoke-ai/ui-library": "^0.0.46",
"@nanostores/react": "^1.0.0",
"@observ33r/object-equals": "^1.1.4",
"@reduxjs/toolkit": "2.8.2",
"@reduxjs/toolkit": "2.7.0",
"@roarr/browser-log-writer": "^1.3.0",
"@xyflow/react": "^12.7.1",
"ag-psd": "^28.2.1",
"@xyflow/react": "^12.6.0",
"async-mutex": "^0.5.0",
"chakra-react-select": "^4.9.2",
"cmdk": "^1.1.1",
"compare-versions": "^6.1.1",
"dockview": "^4.4.0",
"es-toolkit": "^1.39.5",
"filesize": "^10.1.6",
"fracturedjsonjs": "^4.1.0",
"fracturedjsonjs": "^4.0.2",
"framer-motion": "^11.10.0",
"i18next": "^25.2.1",
"i18next": "^25.0.1",
"i18next-http-backend": "^3.0.2",
"idb-keyval": "6.2.1",
"idb-keyval": "^6.2.1",
"jsondiffpatch": "^0.7.3",
"konva": "^9.3.20",
"linkify-react": "^4.3.1",
"linkifyjs": "^4.3.1",
"linkify-react": "^4.2.0",
"linkifyjs": "^4.2.0",
"lodash-es": "^4.17.21",
"lru-cache": "^11.1.0",
"mtwist": "^1.0.2",
"nanoid": "^5.1.5",
"nanostores": "^1.0.1",
"new-github-issue-url": "^1.1.0",
"overlayscrollbars": "^2.11.4",
"overlayscrollbars": "^2.11.1",
"overlayscrollbars-react": "^0.5.6",
"perfect-freehand": "^1.2.2",
"query-string": "^9.2.1",
"query-string": "^9.1.1",
"raf-throttle": "^2.0.6",
"react": "^18.3.1",
"react-colorful": "^5.6.1",
"react-dom": "^18.3.1",
"react-dropzone": "^14.3.8",
"react-error-boundary": "^5.0.0",
"react-hook-form": "^7.58.1",
"react-hook-form": "^7.56.1",
"react-hotkeys-hook": "4.5.0",
"react-i18next": "^15.5.3",
"react-i18next": "^15.5.1",
"react-icons": "^5.5.0",
"react-redux": "9.2.0",
"react-resizable-panels": "^3.0.3",
"react-resizable-panels": "^2.1.8",
"react-textarea-autosize": "^8.5.9",
"react-use": "^17.6.0",
"react-virtuoso": "^4.13.0",
"react-virtuoso": "^4.12.6",
"redux-dynamic-middlewares": "^2.2.0",
"redux-remember": "^5.2.0",
"redux-undo": "^1.1.0",
@@ -99,12 +109,12 @@
"roarr": "^7.21.1",
"serialize-error": "^12.0.0",
"socket.io-client": "^4.8.1",
"stable-hash": "^0.0.6",
"use-debounce": "^10.0.5",
"stable-hash": "^0.0.5",
"use-debounce": "^10.0.4",
"use-device-pixel-ratio": "^1.1.2",
"uuid": "^11.1.0",
"zod": "^3.25.67",
"zod-validation-error": "^3.5.2"
"zod": "^3.24.3",
"zod-validation-error": "^3.4.0"
},
"peerDependencies": {
"react": "^18.2.0",
@@ -121,6 +131,7 @@
"@storybook/react": "^8.6.12",
"@storybook/react-vite": "^8.6.12",
"@storybook/theming": "^8.6.12",
"@types/lodash-es": "^4.17.12",
"@types/node": "^22.15.1",
"@types/react": "^18.3.11",
"@types/react-dom": "^18.3.0",
@@ -134,7 +145,7 @@
"eslint": "^8.57.1",
"eslint-plugin-i18next": "^6.1.1",
"eslint-plugin-path": "^1.3.0",
"knip": "^5.61.3",
"knip": "^5.50.5",
"openapi-types": "^12.1.3",
"openapi-typescript": "^7.6.1",
"prettier": "^3.5.3",
@@ -143,7 +154,7 @@
"tsafe": "^1.8.5",
"type-fest": "^4.40.0",
"typescript": "^5.8.3",
"vite": "^7.0.2",
"vite": "^6.3.3",
"vite-plugin-css-injected-by-js": "^3.5.2",
"vite-plugin-dts": "^4.5.3",
"vite-plugin-eslint": "^1.8.1",
@@ -151,7 +162,7 @@
"vitest": "^3.1.2"
},
"engines": {
"pnpm": "10"
"pnpm": "8"
},
"packageManager": "pnpm@10.12.4"
"packageManager": "pnpm@8.15.9+sha512.499434c9d8fdd1a2794ebf4552b3b25c0a633abcee5bb15e7b5de90f32f47b513aca98cd5cfd001c31f0db454bc3804edccd578501e4ca293a6816166bbd9f81"
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +0,0 @@
onlyBuiltDependencies:
- '@swc/core'
- esbuild

View File

@@ -225,16 +225,7 @@
"prompt": {
"addPromptTrigger": "Add Prompt Trigger",
"compatibleEmbeddings": "Compatible Embeddings",
"noMatchingTriggers": "No matching triggers",
"generateFromImage": "Generate prompt from image",
"expandCurrentPrompt": "Expand Current Prompt",
"uploadImageForPromptGeneration": "Upload Image for Prompt Generation",
"expandingPrompt": "Expanding prompt...",
"resultTitle": "Prompt Expansion Complete",
"resultSubtitle": "Choose how to handle the expanded prompt:",
"replace": "Replace",
"insert": "Insert",
"discard": "Discard"
"noMatchingTriggers": "No matching triggers"
},
"queue": {
"queue": "Queue",
@@ -344,14 +335,14 @@
"images": "Images",
"assets": "Assets",
"alwaysShowImageSizeBadge": "Always Show Image Size Badge",
"assetsTab": "Files you've uploaded for use in your projects.",
"assetsTab": "Files youve uploaded for use in your projects.",
"autoAssignBoardOnClick": "Auto-Assign Board on Click",
"autoSwitchNewImages": "Auto-Switch to New Images",
"boardsSettings": "Boards Settings",
"copy": "Copy",
"currentlyInUse": "This image is currently in use in the following features:",
"drop": "Drop",
"dropOrUpload": "Drop or Upload",
"dropOrUpload": "$t(gallery.drop) or Upload",
"dropToUpload": "$t(gallery.drop) to Upload",
"deleteImage_one": "Delete Image",
"deleteImage_other": "Delete {{count}} Images",
@@ -366,7 +357,7 @@
"gallerySettings": "Gallery Settings",
"go": "Go",
"image": "image",
"imagesTab": "Images you've created and saved within Invoke.",
"imagesTab": "Images youve created and saved within Invoke.",
"imagesSettings": "Gallery Images Settings",
"jump": "Jump",
"loading": "Loading",
@@ -405,8 +396,7 @@
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit.",
"openViewer": "Open Viewer",
"closeViewer": "Close Viewer",
"move": "Move",
"useForPromptGeneration": "Use for Prompt Generation"
"move": "Move"
},
"hotkeys": {
"hotkeys": "Hotkeys",
@@ -589,16 +579,6 @@
"cancelTransform": {
"title": "Cancel Transform",
"desc": "Cancel the pending transform."
},
"settings": {
"behavior": "Behavior",
"display": "Display",
"grid": "Grid",
"debug": "Debug"
},
"toggleNonRasterLayers": {
"title": "Toggle Non-Raster Layers",
"desc": "Show or hide all non-raster layer categories (Control Layers, Inpaint Masks, Regional Guidance)."
}
},
"workflows": {
@@ -762,7 +742,7 @@
"vae": "VAE",
"width": "Width",
"workflow": "Workflow",
"canvasV2Metadata": "Canvas Layers"
"canvasV2Metadata": "Canvas"
},
"modelManager": {
"active": "active",
@@ -783,7 +763,7 @@
"convertToDiffusers": "Convert To Diffusers",
"convertToDiffusersHelpText1": "This model will be converted to the 🧨 Diffusers format.",
"convertToDiffusersHelpText2": "This process will replace your Model Manager entry with the Diffusers version of the same model.",
"convertToDiffusersHelpText3": "Your checkpoint file on disk WILL be deleted if it is in the InvokeAI root folder. If it is in a custom location, then it WILL NOT be deleted.",
"convertToDiffusersHelpText3": "Your checkpoint file on disk WILL be deleted if it is in InvokeAI root folder. If it is in a custom location, then it WILL NOT be deleted.",
"convertToDiffusersHelpText4": "This is a one time process only. It might take around 30s-60s depending on the specifications of your computer.",
"convertToDiffusersHelpText5": "Please make sure you have enough disk space. Models generally vary between 2GB-7GB in size.",
"convertToDiffusersHelpText6": "Do you wish to convert this model?",
@@ -826,11 +806,7 @@
"urlUnauthorizedErrorMessage": "You may need to configure an API token to access this model.",
"urlUnauthorizedErrorMessage2": "Learn how here.",
"imageEncoderModelId": "Image Encoder Model ID",
"installedModelsCount": "{{installed}} of {{total}} models installed.",
"includesNModels": "Includes {{n}} models and their dependencies.",
"allNModelsInstalled": "All {{count}} models installed",
"nToInstall": "{{count}} to install",
"nAlreadyInstalled": "{{count}} already installed",
"includesNModels": "Includes {{n}} models and their dependencies",
"installQueue": "Install Queue",
"inplaceInstall": "In-place install",
"inplaceInstallDesc": "Install models without copying the files. When using the model, it will be loaded from its this location. If disabled, the model file(s) will be copied into the Invoke-managed models directory during installation.",
@@ -893,25 +869,6 @@
"starterBundleHelpText": "Easily install all models needed to get started with a base model, including a main model, controlnets, IP adapters, and more. Selecting a bundle will skip any models that you already have installed.",
"starterModels": "Starter Models",
"starterModelsInModelManager": "Starter Models can be found in Model Manager",
"bundleAlreadyInstalled": "Bundle already installed",
"bundleAlreadyInstalledDesc": "All models in the {{bundleName}} bundle are already installed.",
"launchpadTab": "Launchpad",
"launchpad": {
"welcome": "Welcome to Model Management",
"description": "Invoke requires models to be installed to utilize most features of the platform. Choose from manual installation options or explore curated starter models.",
"manualInstall": "Manual Installation",
"urlDescription": "Install models from a URL or local file path. Perfect for specific models you want to add.",
"huggingFaceDescription": "Browse and install models directly from HuggingFace repositories.",
"scanFolderDescription": "Scan a local folder to automatically detect and install models.",
"recommendedModels": "Recommended Models",
"exploreStarter": "Or browse all available starter models",
"quickStart": "Quick Start Bundles",
"bundleDescription": "Each bundle includes essential models for each model family and curated base models to get started.",
"browseAll": "Or browse all available models:",
"stableDiffusion15": "Stable Diffusion 1.5",
"sdxl": "SDXL",
"fluxDev": "FLUX.1 dev"
},
"controlLora": "Control LoRA",
"llavaOnevision": "LLaVA OneVision",
"syncModels": "Sync Models",
@@ -948,8 +905,7 @@
"selectModel": "Select a Model",
"noLoRAsInstalled": "No LoRAs installed",
"noRefinerModelsInstalled": "No SDXL Refiner models installed",
"defaultVAE": "Default VAE",
"noCompatibleLoRAs": "No Compatible LoRAs"
"defaultVAE": "Default VAE"
},
"nodes": {
"arithmeticSequence": "Arithmetic Sequence",
@@ -1199,9 +1155,7 @@
"canvasIsSelectingObject": "Canvas is busy (selecting object)",
"noPrompts": "No prompts generated",
"noNodesInGraph": "No nodes in graph",
"systemDisconnected": "System disconnected",
"promptExpansionPending": "Prompt expansion in progress",
"promptExpansionResultPending": "Please accept or discard your prompt expansion result"
"systemDisconnected": "System disconnected"
},
"maskBlur": "Mask Blur",
"negativePromptPlaceholder": "Negative Prompt",
@@ -1359,21 +1313,6 @@
"problemCopyingLayer": "Unable to Copy Layer",
"problemSavingLayer": "Unable to Save Layer",
"problemDownloadingImage": "Unable to Download Image",
"noRasterLayers": "No Raster Layers Found",
"noRasterLayersDesc": "Create at least one raster layer to export to PSD",
"noActiveRasterLayers": "No Active Raster Layers",
"noActiveRasterLayersDesc": "Enable at least one raster layer to export to PSD",
"noVisibleRasterLayers": "No Visible Raster Layers",
"noVisibleRasterLayersDesc": "Enable at least one raster layer to export to PSD",
"invalidCanvasDimensions": "Invalid Canvas Dimensions",
"canvasTooLarge": "Canvas Too Large",
"canvasTooLargeDesc": "Canvas dimensions exceed the maximum allowed size for PSD export. Reduce the total width and height of the canvas of the canvas and try again.",
"failedToProcessLayers": "Failed to Process Layers",
"psdExportSuccess": "PSD Export Complete",
"psdExportSuccessDesc": "Successfully exported {{count}} layers to PSD file",
"problemExportingPSD": "Problem Exporting PSD",
"canvasManagerNotAvailable": "Canvas Manager Not Available",
"noValidLayerAdapters": "No Valid Layer Adapters Found",
"pasteSuccess": "Pasted to {{destination}}",
"pasteFailed": "Paste Failed",
"prunedQueue": "Pruned Queue",
@@ -1399,15 +1338,10 @@
"fluxFillIncompatibleWithT2IAndI2I": "FLUX Fill is not compatible with Text to Image or Image to Image. Use other FLUX models for these tasks.",
"imagenIncompatibleGenerationMode": "Google {{model}} supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"chatGPT4oIncompatibleGenerationMode": "ChatGPT 4o supports Text to Image and Image to Image only. Use other models Inpainting and Outpainting tasks.",
"fluxKontextIncompatibleGenerationMode": "FLUX Kontext does not support generation from images placed on the canvas. Re-try using the Reference Image section and disable any Raster Layers.",
"fluxKontextIncompatibleGenerationMode": "Flux Kontext supports Text to Image only. Use other models for Image to Image, Inpainting and Outpainting tasks.",
"problemUnpublishingWorkflow": "Problem Unpublishing Workflow",
"problemUnpublishingWorkflowDescription": "There was a problem unpublishing the workflow. Please try again.",
"workflowUnpublished": "Workflow Unpublished",
"sentToCanvas": "Sent to Canvas",
"sentToUpscale": "Sent to Upscale",
"promptGenerationStarted": "Prompt generation started",
"uploadAndPromptGenerationFailed": "Failed to upload image and generate prompt",
"promptExpansionFailed": "We ran into an issue. Please try prompt expansion again."
"workflowUnpublished": "Workflow Unpublished"
},
"popovers": {
"clipSkip": {
@@ -1930,7 +1864,6 @@
"saveCanvasToGallery": "Save Canvas to Gallery",
"saveBboxToGallery": "Save Bbox to Gallery",
"saveLayerToAssets": "Save Layer to Assets",
"exportCanvasToPSD": "Export Canvas to PSD",
"cropLayerToBbox": "Crop Layer to Bbox",
"savedToGalleryOk": "Saved to Gallery",
"savedToGalleryError": "Error saving to gallery",
@@ -1956,13 +1889,11 @@
"mergingLayers": "Merging layers",
"clearHistory": "Clear History",
"bboxOverlay": "Show Bbox Overlay",
"ruleOfThirds": "Show Rule of Thirds",
"newSession": "New Session",
"clearCaches": "Clear Caches",
"recalculateRects": "Recalculate Rects",
"clipToBbox": "Clip Strokes to Bbox",
"outputOnlyMaskedRegions": "Output Only Generated Regions",
"saveAllImagesToGallery": "Save All Images to Gallery",
"addLayer": "Add Layer",
"duplicate": "Duplicate",
"moveToFront": "Move to Front",
@@ -2063,8 +1994,6 @@
"disableTransparencyEffect": "Disable Transparency Effect",
"hidingType": "Hiding {{type}}",
"showingType": "Showing {{type}}",
"showNonRasterLayers": "Show Non-Raster Layers (Shift+H)",
"hideNonRasterLayers": "Hide Non-Raster Layers (Shift+H)",
"dynamicGrid": "Dynamic Grid",
"logDebugInfo": "Log Debug Info",
"locked": "Locked",
@@ -2088,9 +2017,7 @@
"resetGenerationSettings": "Reset Generation Settings",
"replaceCurrent": "Replace Current",
"controlLayerEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, <PullBboxButton>pull the bounding box into this layer</PullBboxButton>, or draw on the canvas to get started.",
"referenceImageEmptyStateWithCanvasOptions": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this Reference Image or <PullBboxButton>pull the bounding box into this Reference Image</PullBboxButton> to get started.",
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton> or drag an image from the <GalleryButton>gallery</GalleryButton> onto this Reference Image to get started.",
"uploadOrDragAnImage": "Drag an image from the gallery or <UploadButton>upload an image</UploadButton>.",
"referenceImageEmptyState": "<UploadButton>Upload an image</UploadButton>, drag an image from the <GalleryButton>gallery</GalleryButton> onto this layer, or <PullBboxButton>pull the bounding box into this layer</PullBboxButton> to get started.",
"imageNoise": "Image Noise",
"denoiseLimit": "Denoise Limit",
"warnings": {
@@ -2331,9 +2258,6 @@
"label": "Preserve Masked Region",
"alert": "Preserving Masked Region"
},
"saveAllImagesToGallery": {
"alert": "Saving All Images to Gallery"
},
"isolatedStagingPreview": "Isolated Staging Preview",
"isolatedPreview": "Isolated Preview",
"isolatedLayerPreview": "Isolated Layer Preview",
@@ -2362,7 +2286,6 @@
"newGlobalReferenceImage": "New Global Reference Image",
"newRegionalReferenceImage": "New Regional Reference Image",
"newControlLayer": "New Control Layer",
"newResizedControlLayer": "New Resized Control Layer",
"newRasterLayer": "New Raster Layer",
"newInpaintMask": "New Inpaint Mask",
"newRegionalGuidance": "New Regional Guidance",
@@ -2380,11 +2303,6 @@
"saveToGallery": "Save To Gallery",
"showResultsOn": "Showing Results",
"showResultsOff": "Hiding Results"
},
"autoSwitch": {
"off": "Off",
"switchOnStart": "On Start",
"switchOnFinish": "On Finish"
}
},
"upscaling": {
@@ -2451,8 +2369,7 @@
"uploadImage": "Upload Image",
"useForTemplate": "Use For Prompt Template",
"viewList": "View Template List",
"viewModeTooltip": "This is how your prompt will look with your currently selected template. To edit your prompt, click anywhere in the text box.",
"togglePromptPreviews": "Toggle Prompt Previews"
"viewModeTooltip": "This is how your prompt will look with your currently selected template. To edit your prompt, click anywhere in the text box."
},
"upsell": {
"inviteTeammates": "Invite Teammates",
@@ -2472,55 +2389,6 @@
"upscaling": "Upscaling",
"upscalingTab": "$t(ui.tabs.upscaling) $t(common.tab)",
"gallery": "Gallery"
},
"launchpad": {
"workflowsTitle": "Go deep with Workflows.",
"upscalingTitle": "Upscale and add detail.",
"canvasTitle": "Edit and refine on Canvas.",
"generateTitle": "Generate images from text prompts.",
"modelGuideText": "Want to learn what prompts work best for each model?",
"modelGuideLink": "Check out our Model Guide.",
"workflows": {
"description": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results.",
"learnMoreLink": "Learn more about creating workflows",
"browseTemplates": {
"title": "Browse Workflow Templates",
"description": "Choose from pre-built workflows for common tasks"
},
"createNew": {
"title": "Create a new Workflow",
"description": "Start a new workflow from scratch"
},
"loadFromFile": {
"title": "Load workflow from file",
"description": "Upload a workflow to start with an existing setup"
}
},
"upscaling": {
"uploadImage": {
"title": "Upload Image to Upscale",
"description": "Click or drag an image to upscale (JPG, PNG, WebP up to 100MB)"
},
"replaceImage": {
"title": "Replace Current Image",
"description": "Click or drag a new image to replace the current one"
},
"imageReady": {
"title": "Image Ready",
"description": "Press Invoke to begin upscaling"
},
"readyToUpscale": {
"title": "Ready to upscale!",
"description": "Configure your settings below, then click the Invoke button to begin upscaling your image."
},
"upscaleModel": "Upscale Model",
"model": "Model",
"scale": "Scale",
"helpText": {
"promptAdvice": "When upscaling, use a prompt that describes the medium and style. Avoid describing specific content details in the image.",
"styleAdvice": "Upscaling works best with the general style of your image."
}
}
}
},
"system": {
@@ -2560,9 +2428,8 @@
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",
"items": [
"Generate images faster with new Launchpads and a simplified Generate tab.",
"Edit with prompts using Flux Kontext Dev.",
"Export to PSD, bulk-hide overlays, organize models & images — all in a reimagined interface built for control."
"Inpainting: Per-mask noise levels and denoise limits.",
"Canvas: Smarter aspect ratios for SDXL and improved scroll-to-zoom."
],
"readReleaseNotes": "Read Release Notes",
"watchRecentReleaseVideos": "Watch Recent Release Videos",
@@ -2571,16 +2438,62 @@
"supportVideos": {
"supportVideos": "Support Videos",
"gettingStarted": "Getting Started",
"controlCanvas": "Control Canvas",
"watch": "Watch",
"studioSessionsDesc": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
"studioSessionsDesc1": "Check out the <StudioSessionsPlaylistLink /> for Invoke deep dives.",
"studioSessionsDesc2": "Join our <DiscordLink /> to participate in the live sessions and ask questions. Sessions are uploaded to the playlist the following week.",
"videos": {
"gettingStarted": {
"title": "Getting Started with Invoke",
"description": "Complete video series covering everything you need to know to get started with Invoke, from creating your first image to advanced techniques."
"creatingYourFirstImage": {
"title": "Creating Your First Image",
"description": "Introduction to creating an image from scratch using Invoke's tools."
},
"studioSessions": {
"title": "Studio Sessions",
"description": "Deep dive sessions exploring advanced Invoke features, creative workflows, and community discussions."
"usingControlLayersAndReferenceGuides": {
"title": "Using Control Layers and Reference Guides",
"description": "Learn how to guide your image creation with control layers and reference images."
},
"understandingImageToImageAndDenoising": {
"title": "Understanding Image-to-Image and Denoising",
"description": "Overview of image-to-image transformations and denoising in Invoke."
},
"exploringAIModelsAndConceptAdapters": {
"title": "Exploring AI Models and Concept Adapters",
"description": "Dive into AI models and how to use concept adapters for creative control."
},
"creatingAndComposingOnInvokesControlCanvas": {
"title": "Creating and Composing on Invoke's Control Canvas",
"description": "Learn to compose images using Invoke's control canvas."
},
"upscaling": {
"title": "Upscaling",
"description": "How to upscale images with Invoke's tools to enhance resolution."
},
"howDoIGenerateAndSaveToTheGallery": {
"title": "How Do I Generate and Save to the Gallery?",
"description": "Steps to generate and save images to the gallery."
},
"howDoIEditOnTheCanvas": {
"title": "How Do I Edit on the Canvas?",
"description": "Guide to editing images directly on the canvas."
},
"howDoIDoImageToImageTransformation": {
"title": "How Do I Do Image-to-Image Transformation?",
"description": "Tutorial on performing image-to-image transformations in Invoke."
},
"howDoIUseControlNetsAndControlLayers": {
"title": "How Do I Use Control Nets and Control Layers?",
"description": "Learn to apply control layers and controlnets to your images."
},
"howDoIUseGlobalIPAdaptersAndReferenceImages": {
"title": "How Do I Use Global IP Adapters and Reference Images?",
"description": "Introduction to adding reference images and global IP adapters."
},
"howDoIUseInpaintMasks": {
"title": "How Do I Use Inpaint Masks?",
"description": "How to apply inpaint masks for image correction and variation."
},
"howDoIOutpaint": {
"title": "How Do I Outpaint?",
"description": "Guide to outpainting beyond the original image borders."
}
}
}

View File

@@ -2,7 +2,8 @@ import { Box } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { GlobalHookIsolator } from 'app/components/GlobalHookIsolator';
import { GlobalModalIsolator } from 'app/components/GlobalModalIsolator';
import { $didStudioInit, type StudioInitAction } from 'app/hooks/useStudioInitAction';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { $didStudioInit } from 'app/hooks/useStudioInitAction';
import type { PartialAppConfig } from 'app/types/invokeai';
import Loading from 'common/components/Loading/Loading';
import { useClearStorage } from 'common/hooks/useClearStorage';
@@ -11,7 +12,6 @@ import { memo, useCallback } from 'react';
import { ErrorBoundary } from 'react-error-boundary';
import AppErrorBoundaryFallback from './AppErrorBoundaryFallback';
import ThemeLocaleProvider from './ThemeLocaleProvider';
const DEFAULT_CONFIG = {};
interface Props {
@@ -31,14 +31,12 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
return (
<ErrorBoundary onReset={handleReset} FallbackComponent={AppErrorBoundaryFallback}>
<ThemeLocaleProvider>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{!didStudioInit && <Loading />}
</Box>
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
<GlobalModalIsolator />
</ThemeLocaleProvider>
<Box id="invoke-app-wrapper" w="100dvw" h="100dvh" position="relative" overflow="hidden">
<AppContent />
{!didStudioInit && <Loading />}
</Box>
<GlobalHookIsolator config={config} studioInitAction={studioInitAction} />
<GlobalModalIsolator />
</ErrorBoundary>
);
};

View File

@@ -1,5 +1,4 @@
import { useGlobalModifiersInit } from '@invoke-ai/ui-library';
import { setupListeners } from '@reduxjs/toolkit/query';
import type { StudioInitAction } from 'app/hooks/useStudioInitAction';
import { useStudioInitAction } from 'app/hooks/useStudioInitAction';
import { useSyncQueueStatus } from 'app/hooks/useSyncQueueStatus';
@@ -9,24 +8,19 @@ import { appStarted } from 'app/store/middleware/listenerMiddleware/listeners/ap
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { PartialAppConfig } from 'app/types/invokeai';
import { useFocusRegionWatcher } from 'common/hooks/focus';
import { useCloseChakraTooltipsOnDragFix } from 'common/hooks/useCloseChakraTooltipsOnDragFix';
import { useGlobalHotkeys } from 'common/hooks/useGlobalHotkeys';
import { useDndMonitor } from 'features/dnd/useDndMonitor';
import { useDynamicPromptsWatcher } from 'features/dynamicPrompts/hooks/useDynamicPromptsWatcher';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { useWorkflowBuilderWatcher } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
import { useReadinessWatcher } from 'features/queue/store/readiness';
import { configChanged } from 'features/system/store/configSlice';
import { selectLanguage } from 'features/system/store/systemSelectors';
import { useNavigationApi } from 'features/ui/layouts/use-navigation-api';
import i18n from 'i18n';
import { size } from 'lodash-es';
import { memo, useEffect } from 'react';
import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo';
import { useGetQueueCountsByDestinationQuery } from 'services/api/endpoints/queue';
import { useSocketIO } from 'services/events/useSocketIO';
const queueCountArg = { destination: 'canvas' };
/**
* GlobalHookIsolator is a logical component that runs global hooks in an isolated component, so that they do not
* cause needless re-renders of any other components.
@@ -44,31 +38,22 @@ export const GlobalHookIsolator = memo(
useGlobalHotkeys();
useGetOpenAPISchemaQuery();
useSyncLoggingConfig();
useCloseChakraTooltipsOnDragFix();
useNavigationApi();
useDndMonitor();
// Persistent subscription to the queue counts query - canvas relies on this to know if there are pending
// and/or in progress canvas sessions.
useGetQueueCountsByDestinationQuery(queueCountArg);
useEffect(() => {
i18n.changeLanguage(language);
}, [language]);
useEffect(() => {
logger.info({ config }, 'Received config');
dispatch(configChanged(config));
if (size(config)) {
logger.info({ config }, 'Received config');
dispatch(configChanged(config));
}
}, [dispatch, config, logger]);
useEffect(() => {
dispatch(appStarted());
}, [dispatch]);
useEffect(() => {
return setupListeners(dispatch);
}, [dispatch]);
useStudioInitAction(studioInitAction);
useStarterModelsToast();
useSyncQueueStatus();
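
GlobalHookIsolator's doc comment describes a hook-isolator pattern: run app-wide hooks inside a component that renders nothing, so state updates triggered by those hooks never cascade into sibling components. A minimal sketch of that pattern, using illustrative hooks rather than Invoke's actual ones:

import { memo, useEffect } from 'react';

// Renders nothing; exists only to run global side effects in isolation so
// that re-renders caused by these hooks stay contained to this component.
const HookIsolator = memo(() => {
  useEffect(() => {
    // e.g. open a socket connection, register hotkeys, sync config
  }, []);
  return null;
});
HookIsolator.displayName = 'HookIsolator';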

View File

@@ -1,22 +1,17 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useIsRegionFocused } from 'common/hooks/focus';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { useLoadWorkflow } from 'features/gallery/hooks/useLoadWorkflow';
import { useRecallAll } from 'features/gallery/hooks/useRecallAll';
import { useRecallDimensions } from 'features/gallery/hooks/useRecallDimensions';
import { useRecallPrompts } from 'features/gallery/hooks/useRecallPrompts';
import { useRecallRemix } from 'features/gallery/hooks/useRecallRemix';
import { useRecallSeed } from 'features/gallery/hooks/useRecallSeed';
import { selectIsStaging } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useImageDTO } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
export const GlobalImageHotkeys = memo(() => {
useAssertSingleton('GlobalImageHotkeys');
const imageName = useAppSelector(selectLastSelectedImage);
const imageDTO = useImageDTO(imageName);
const imageDTO = useAppSelector(selectLastSelectedImage);
if (!imageDTO) {
return null;
@@ -30,64 +25,59 @@ GlobalImageHotkeys.displayName = 'GlobalImageHotkeys';
const GlobalImageHotkeysInternal = memo(({ imageDTO }: { imageDTO: ImageDTO }) => {
const isGalleryFocused = useIsRegionFocused('gallery');
const isViewerFocused = useIsRegionFocused('viewer');
const isFocusOK = isGalleryFocused || isViewerFocused;
const recallAll = useRecallAll(imageDTO);
const recallRemix = useRecallRemix(imageDTO);
const recallPrompts = useRecallPrompts(imageDTO);
const recallSeed = useRecallSeed(imageDTO);
const recallDimensions = useRecallDimensions(imageDTO);
const loadWorkflow = useLoadWorkflow(imageDTO);
const imageActions = useImageActions(imageDTO);
const isStaging = useAppSelector(selectIsStaging);
const isUpscalingEnabled = useFeatureStatus('upscaling');
useRegisteredHotkeys({
id: 'loadWorkflow',
category: 'viewer',
callback: loadWorkflow.load,
options: { enabled: loadWorkflow.isEnabled && isFocusOK },
dependencies: [loadWorkflow, isFocusOK],
callback: imageActions.loadWorkflow,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.loadWorkflow, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'recallAll',
category: 'viewer',
callback: recallAll.recall,
options: { enabled: recallAll.isEnabled && isFocusOK },
dependencies: [recallAll, isFocusOK],
callback: imageActions.recallAll,
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
dependencies: [imageActions.recallAll, isStaging, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'recallSeed',
category: 'viewer',
callback: recallSeed.recall,
options: { enabled: recallSeed.isEnabled && isFocusOK },
dependencies: [recallSeed, isFocusOK],
callback: imageActions.recallSeed,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.recallSeed, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'recallPrompts',
category: 'viewer',
callback: recallPrompts.recall,
options: { enabled: recallPrompts.isEnabled && isFocusOK },
dependencies: [recallPrompts, isFocusOK],
callback: imageActions.recallPrompts,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.recallPrompts, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'remix',
category: 'viewer',
callback: recallRemix.recall,
options: { enabled: recallRemix.isEnabled && isFocusOK },
dependencies: [recallRemix, isFocusOK],
callback: imageActions.remix,
options: { enabled: isGalleryFocused || isViewerFocused },
dependencies: [imageActions.remix, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'useSize',
category: 'viewer',
callback: recallDimensions.recall,
options: { enabled: recallDimensions.isEnabled && isFocusOK },
dependencies: [recallDimensions, isFocusOK],
callback: imageActions.recallSize,
options: { enabled: !isStaging && (isGalleryFocused || isViewerFocused) },
dependencies: [imageActions.recallSize, isStaging, isGalleryFocused, isViewerFocused],
});
useRegisteredHotkeys({
id: 'runPostprocessing',
category: 'viewer',
callback: imageActions.upscale,
options: { enabled: isUpscalingEnabled && isViewerFocused },
dependencies: [isUpscalingEnabled, imageDTO, isViewerFocused],
});
return null;
});

View File

@@ -6,7 +6,7 @@ import {
NewGallerySessionDialog,
} from 'features/controlLayers/components/NewSessionConfirmationAlertDialog';
import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal';
import DeleteImageModal from 'features/deleteImageModal/components/DeleteImageModal';
import { FullscreenDropzone } from 'features/dnd/FullscreenDropzone';
import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicPromptsPreviewModal';
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
@@ -15,7 +15,6 @@ import { ShareWorkflowModal } from 'features/nodes/components/sidePanel/workflow
import { WorkflowLibraryModal } from 'features/nodes/components/sidePanel/workflow/WorkflowLibrary/WorkflowLibraryModal';
import { CancelAllExceptCurrentQueueItemConfirmationAlertDialog } from 'features/queue/components/CancelAllExceptCurrentQueueItemConfirmationAlertDialog';
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { DeleteAllExceptCurrentQueueItemConfirmationAlertDialog } from 'features/queue/components/DeleteAllExceptCurrentQueueItemConfirmationAlertDialog';
import { DeleteStylePresetDialog } from 'features/stylePresets/components/DeleteStylePresetDialog';
import { StylePresetModal } from 'features/stylePresets/components/StylePresetForm/StylePresetModal';
import RefreshAfterResetModal from 'features/system/components/SettingsModal/RefreshAfterResetModal';
@@ -40,7 +39,6 @@ export const GlobalModalIsolator = memo(() => {
<StylePresetModal />
<WorkflowLibraryModal />
<CancelAllExceptCurrentQueueItemConfirmationAlertDialog />
<DeleteAllExceptCurrentQueueItemConfirmationAlertDialog />
<ClearQueueConfirmationsAlertDialog />
<NewWorkflowConfirmationAlertDialog />
<LoadWorkflowConfirmationAlertDialog />

View File

@@ -42,6 +42,7 @@ import { $socketOptions } from 'services/events/stores';
import type { ManagerOptions, SocketOptions } from 'socket.io-client';
const App = lazy(() => import('./App'));
const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider'));
interface Props extends PropsWithChildren {
apiUrl?: string;
@@ -329,7 +330,9 @@ const InvokeAIUI = ({
<React.StrictMode>
<Provider store={store}>
<React.Suspense fallback={<Loading />}>
<App config={config} studioInitAction={studioInitAction} />
<ThemeLocaleProvider>
<App config={config} studioInitAction={studioInitAction} />
</ThemeLocaleProvider>
</React.Suspense>
</Provider>
</React.StrictMode>

View File

@@ -1,7 +1,6 @@
import '@fontsource-variable/inter';
import 'overlayscrollbars/overlayscrollbars.css';
import '@xyflow/react/dist/base.css';
import 'common/components/OverlayScrollbars/overlayscrollbars.css';
import { ChakraProvider, DarkMode, extendTheme, theme as _theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
import type { ReactNode } from 'react';

View File

@@ -3,12 +3,13 @@ import { useAppStore } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { withResultAsync } from 'common/util/result';
import { canvasReset } from 'features/controlLayers/store/actions';
import { settingsSendToCanvasChanged } from 'features/controlLayers/store/canvasSettingsSlice';
import { rasterLayerAdded } from 'features/controlLayers/store/canvasSlice';
import { paramsReset } from 'features/controlLayers/store/paramsSlice';
import type { CanvasRasterLayerState } from 'features/controlLayers/store/types';
import { imageDTOToImageObject } from 'features/controlLayers/store/util';
import { $imageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { sentImageToCanvas } from 'features/gallery/store/actions';
import { MetadataUtils } from 'features/metadata/parsing';
import { parseAndRecallAllMetadata } from 'features/metadata/util/handlers';
import { $hasTemplates } from 'features/nodes/store/nodesSlice';
import { $isWorkflowLibraryModalOpen } from 'features/nodes/store/workflowLibraryModal';
import {
@@ -19,9 +20,7 @@ import {
} from 'features/nodes/store/workflowLibrarySlice';
import { $isStylePresetsMenuOpen, activeStylePresetIdChanged } from 'features/stylePresets/store/stylePresetSlice';
import { toast } from 'features/toast/toast';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { LAUNCHPAD_PANEL_ID, WORKSPACE_PANEL_ID } from 'features/ui/layouts/shared';
import { activeTabCanvasRightPanelChanged } from 'features/ui/store/uiSlice';
import { activeTabCanvasRightPanelChanged, setActiveTab } from 'features/ui/store/uiSlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import { atom } from 'nanostores';
import { useCallback, useEffect } from 'react';
@@ -92,10 +91,12 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
const overrides: Partial<CanvasRasterLayerState> = {
objects: [imageObject],
};
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
store.dispatch(canvasReset());
store.dispatch(rasterLayerAdded({ overrides, isSelected: true }));
store.dispatch(settingsSendToCanvasChanged(true));
store.dispatch(setActiveTab('canvas'));
store.dispatch(sentImageToCanvas());
$imageViewer.set(false);
toast({
title: t('toast.sentToCanvas'),
status: 'info',
@@ -117,25 +118,25 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
return;
}
const metadata = getImageMetadataResult.value;
store.dispatch(canvasReset());
// This shows a toast
await MetadataUtils.recallAll(metadata, store);
await parseAndRecallAllMetadata(metadata, true);
store.dispatch(setActiveTab('canvas'));
},
[store, t]
);
const handleLoadWorkflow = useCallback(
(workflowId: string) => {
async (workflowId: string) => {
// This shows a toast
loadWorkflowWithDialog({
await loadWorkflowWithDialog({
type: 'library',
data: workflowId,
onSuccess: () => {
navigationApi.switchToTab('workflows');
store.dispatch(setActiveTab('workflows'));
},
});
},
[loadWorkflowWithDialog]
[loadWorkflowWithDialog, store]
);
const handleSelectStylePreset = useCallback(
@@ -149,7 +150,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
return;
}
store.dispatch(activeStylePresetIdChanged(stylePresetId));
navigationApi.switchToTab('canvas');
store.dispatch(setActiveTab('canvas'));
toast({
title: t('toast.stylePresetLoaded'),
status: 'info',
@@ -159,34 +160,37 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
);
const handleGoToDestination = useCallback(
async (destination: StudioDestinationAction['data']['destination']) => {
(destination: StudioDestinationAction['data']['destination']) => {
switch (destination) {
case 'generation':
// Go to the generate tab, open the launchpad
await navigationApi.focusPanel('generate', LAUNCHPAD_PANEL_ID);
store.dispatch(paramsReset());
// Go to the canvas tab, open the image viewer, and enable send-to-gallery mode
store.dispatch(setActiveTab('canvas'));
store.dispatch(activeTabCanvasRightPanelChanged('gallery'));
store.dispatch(settingsSendToCanvasChanged(false));
$imageViewer.set(true);
break;
case 'canvas':
// Go to the canvas tab, open the launchpad
await navigationApi.focusPanel('canvas', WORKSPACE_PANEL_ID);
// Go to the canvas tab, close the image viewer, and disable send-to-gallery mode
store.dispatch(setActiveTab('canvas'));
store.dispatch(settingsSendToCanvasChanged(true));
$imageViewer.set(false);
break;
case 'workflows':
// Go to the workflows tab
navigationApi.switchToTab('workflows');
store.dispatch(setActiveTab('workflows'));
break;
case 'upscaling':
// Go to the upscaling tab
navigationApi.switchToTab('upscaling');
store.dispatch(setActiveTab('upscaling'));
break;
case 'viewAllWorkflows':
// Go to the workflows tab and open the workflow library modal
navigationApi.switchToTab('workflows');
store.dispatch(setActiveTab('workflows'));
$isWorkflowLibraryModalOpen.set(true);
break;
case 'viewAllWorkflowsRecommended':
// Go to the workflows tab and open the workflow library modal with the recommended workflows view
navigationApi.switchToTab('workflows');
store.dispatch(setActiveTab('workflows'));
$isWorkflowLibraryModalOpen.set(true);
store.dispatch(workflowLibraryViewChanged('defaults'));
store.dispatch(workflowLibraryTagsReset());
@@ -198,7 +202,7 @@ export const useStudioInitAction = (action?: StudioInitAction) => {
break;
case 'viewAllStylePresets':
// Go to the canvas tab and open the style presets menu
navigationApi.switchToTab('canvas');
store.dispatch(setActiveTab('canvas'));
$isStylePresetsMenuOpen.set(true);
break;
}

View File

@@ -2,7 +2,7 @@ import { createLogWriter } from '@roarr/browser-log-writer';
import { atom } from 'nanostores';
import type { Logger, MessageSerializer } from 'roarr';
import { ROARR, Roarr } from 'roarr';
import { z } from 'zod/v4';
import { z } from 'zod';
const serializeMessage: MessageSerializer = (message) => {
return JSON.stringify(message);

View File

@@ -1,13 +1,13 @@
import { objectEquals } from '@observ33r/object-equals';
import { createDraftSafeSelectorCreator, createSelectorCreator, lruMemoize } from '@reduxjs/toolkit';
import { isEqual } from 'lodash-es';
/**
* A memoized selector creator that uses LRU cache and @observ33r/object-equals's objectEquals for equality check.
* A memoized selector creator that uses LRU cache and lodash's isEqual for equality check.
*/
export const createMemoizedSelector = createSelectorCreator({
memoize: lruMemoize,
memoizeOptions: {
resultEqualityCheck: objectEquals,
resultEqualityCheck: isEqual,
},
argsMemoize: lruMemoize,
});
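
A sketch of how a selector creator like this is typically consumed; the import path and store shape are assumptions, not taken from this diff:

import { createMemoizedSelector } from 'app/store/createMemoizedSelector';

// Hypothetical input selector - the real gallery state has more fields.
const selectGallery = (state: { gallery: { selection: { image_name: string }[] } }) => state.gallery;

// The result function builds a new array on every call; the deep result
// equality check keeps the memoized value referentially stable whenever the
// contents are unchanged, so downstream components can skip re-renders.
export const selectSelectedImageNames = createMemoizedSelector(selectGallery, (gallery) =>
  gallery.selection.map((i) => i.image_name)
);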

View File

@@ -8,13 +8,10 @@ import { diff } from 'jsondiffpatch';
* Super simple logger middleware. Useful for debugging when the redux devtools are awkward.
*/
export const getDebugLoggerMiddleware =
(options?: { filter?: (action: unknown) => boolean; withDiff?: boolean; withNextState?: boolean }): Middleware =>
(options?: { withDiff?: boolean; withNextState?: boolean }): Middleware =>
(api: MiddlewareAPI) =>
(next) =>
(action) => {
if (options?.filter?.(action)) {
return next(action);
}
const originalState = api.getState();
console.log('REDUX: dispatching', action);
const result = next(action);
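
One variant of this middleware (shown in the hunk above) accepts an optional filter predicate; when the predicate returns true for an action, the middleware passes it straight through without logging. A hedged wiring sketch, assuming that variant and an illustrative import path:

import { configureStore } from '@reduxjs/toolkit';
import { getDebugLoggerMiddleware } from 'app/store/middleware/debugLoggerMiddleware';

const store = configureStore({
  reducer: { counter: (state: number = 0) => state },
  middleware: (getDefaultMiddleware) =>
    getDefaultMiddleware().concat(
      getDebugLoggerMiddleware({
        // Returning true here skips logging for noisy actions.
        filter: (action) =>
          typeof action === 'object' && action !== null && 'type' in action && String(action.type).startsWith('socket/'),
        withDiff: true,
      })
    ),
});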

View File

@@ -1,6 +1,7 @@
import type { TypedStartListening } from '@reduxjs/toolkit';
import { addListener, createListenerMiddleware } from '@reduxjs/toolkit';
import { addAdHocPostProcessingRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/addAdHocPostProcessingRequestedListener';
import { addStagingListeners } from 'app/store/middleware/listenerMiddleware/listeners/addCommitStagingAreaImageListener';
import { addAnyEnqueuedListener } from 'app/store/middleware/listenerMiddleware/listeners/anyEnqueued';
import { addAppConfigReceivedListener } from 'app/store/middleware/listenerMiddleware/listeners/appConfigReceived';
import { addAppStartedListener } from 'app/store/middleware/listenerMiddleware/listeners/appStarted';
@@ -8,9 +9,16 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
import { addGalleryOffsetChangedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryOffsetChanged';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageDeletionListeners } from 'app/store/middleware/listenerMiddleware/listeners/imageDeletionListeners';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
import { addImagesStarredListener } from 'app/store/middleware/listenerMiddleware/listeners/imagesStarred';
import { addImagesUnstarredListener } from 'app/store/middleware/listenerMiddleware/listeners/imagesUnstarred';
import { addImageToDeleteSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/imageToDeleteSelected';
import { addImageUploadedFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageUploaded';
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
@@ -19,6 +27,7 @@ import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMi
import type { AppDispatch, RootState } from 'app/store/store';
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
import { addEnqueueRequestedUpscale } from './listeners/enqueueRequestedUpscale';
export const listenerMiddleware = createListenerMiddleware();
@@ -38,12 +47,27 @@ export const addAppListener = addListener.withTypes<RootState, AppDispatch>();
addImageUploadedFulfilledListener(startAppListening);
// Image deleted
addImageDeletionListeners(startAppListening);
addDeleteBoardAndImagesFulfilledListener(startAppListening);
addImageToDeleteSelectedListener(startAppListening);
// Image starred
addImagesStarredListener(startAppListening);
addImagesUnstarredListener(startAppListening);
// Gallery
addGalleryImageClickedListener(startAppListening);
addGalleryOffsetChangedListener(startAppListening);
// User Invoked
addEnqueueRequestedLinear(startAppListening);
addEnqueueRequestedUpscale(startAppListening);
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);
// Canvas actions
addStagingListeners(startAppListening);
// Socket.IO
addSocketConnectedEventListener(startAppListening);

View File

@@ -25,7 +25,7 @@ export const addArchivedOrDeletedBoardListener = (startAppListening: AppStartLis
matcher: matchAnyBoardDeleted,
effect: (action, { dispatch, getState }) => {
const state = getState();
const deletedBoardId = action.meta.arg.originalArgs.board_id;
const deletedBoardId = action.meta.arg.originalArgs;
const { autoAddBoardId, selectedBoardId } = state.gallery;
// If the deleted board was currently selected, we should reset the selected board to uncategorized

View File

@@ -0,0 +1,46 @@
import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { canvasReset, newSessionRequested } from 'features/controlLayers/store/actions';
import { stagingAreaReset } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { queueApi } from 'services/api/endpoints/queue';
const log = logger('canvas');
const matchCanvasOrStagingAreaReset = isAnyOf(stagingAreaReset, canvasReset, newSessionRequested);
export const addStagingListeners = (startAppListening: AppStartListening) => {
startAppListening({
matcher: matchCanvasOrStagingAreaReset,
effect: async (_, { dispatch }) => {
try {
const req = dispatch(
queueApi.endpoints.cancelByBatchDestination.initiate(
{ destination: 'canvas' },
{ fixedCacheKey: 'cancelByBatchOrigin' }
)
);
const { canceled } = await req.unwrap();
req.reset();
if (canceled > 0) {
log.debug(`Canceled ${canceled} canvas batches`);
toast({
id: 'CANCEL_BATCH_SUCCEEDED',
title: t('queue.cancelBatchSucceeded'),
status: 'success',
});
}
} catch {
log.error('Failed to cancel canvas batches');
toast({
id: 'CANCEL_BATCH_FAILED',
title: t('queue.cancelBatchFailed'),
status: 'error',
});
}
},
});
};

View File

@@ -1,29 +1,15 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
export const appStarted = createAction('app/appStarted');
export const addAppStartedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: appStarted,
effect: async (action, { unsubscribe, cancelActiveListeners, take, getState, dispatch }) => {
effect: (action, { unsubscribe, cancelActiveListeners }) => {
// this should only run once
cancelActiveListeners();
unsubscribe();
// ensure an image is selected when we load the first board
const firstImageLoad = await take(imagesApi.endpoints.getImageNames.matchFulfilled);
if (firstImageLoad !== null) {
const [{ payload }] = firstImageLoad;
const selectedImage = selectLastSelectedImage(getState());
if (selectedImage) {
return;
}
dispatch(imageSelected(payload.image_names.at(0) ?? null));
}
},
});
};

View File

@@ -1,9 +1,9 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { truncate } from 'es-toolkit/compat';
import { zPydanticValidationError } from 'features/system/store/zodSchemas';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { truncate } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { queueApi } from 'services/api/endpoints/queue';
import type { JsonObject } from 'type-fest';

View File

@@ -1,7 +1,6 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getImageUsage } from 'features/deleteImageModal/store/state';
import { getImageUsage } from 'features/deleteImageModal/store/selectors';
import { nodeEditorReset } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
@@ -21,10 +20,9 @@ export const addDeleteBoardAndImagesFulfilledListener = (startAppListening: AppS
const nodes = selectNodesSlice(state);
const canvas = selectCanvasSlice(state);
const upscale = selectUpscaleSlice(state);
const refImages = selectRefImagesSlice(state);
deleted_images.forEach((image_name) => {
const imageUsage = getImageUsage(nodes, canvas, upscale, refImages, image_name);
const imageUsage = getImageUsage(nodes, canvas, upscale, image_name);
if (imageUsage.isNodesImage && !wasNodeEditorReset) {
dispatch(nodeEditorReset());

View File

@@ -1,6 +1,6 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectGetImageNamesQueryArgs, selectSelectedBoardId } from 'features/gallery/store/gallerySelectors';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
@@ -11,35 +11,36 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
// Cancel any in-progress instances of this listener - we don't want to select an image from a previous board
cancelActiveListeners();
if (boardIdSelected.match(action) && action.payload.selectedImageName) {
// This action already has a selected image name, we trust it is valid
return;
}
const state = getState();
const board_id = selectSelectedBoardId(state);
const queryArgs = { ...selectGetImageNamesQueryArgs(state), board_id };
const queryArgs = selectListImagesQueryArgs(state);
// wait until the board has some images - maybe it already has some from a previous fetch
// must use getState() to ensure we do not have stale state
const isSuccess = await condition(
() => imagesApi.endpoints.getImageNames.select(queryArgs)(getState()).isSuccess,
() => imagesApi.endpoints.listImages.select(queryArgs)(getState()).isSuccess,
5000
);
if (!isSuccess) {
if (isSuccess) {
// the board was just changed - we can select the first image
const { data: boardImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
if (boardImagesData && boardIdSelected.match(action) && action.payload.selectedImageName) {
const selectedImage = boardImagesData.items.find(
(item) => item.image_name === action.payload.selectedImageName
);
dispatch(imageSelected(selectedImage || null));
} else if (boardImagesData) {
dispatch(imageSelected(boardImagesData.items[0] || null));
} else {
// board has no images - deselect
dispatch(imageSelected(null));
}
} else {
// fallback - deselect
dispatch(imageSelected(null));
return;
}
// the board was just changed - we can select the first image
const imageNames = imagesApi.endpoints.getImageNames.select(queryArgs)(getState()).data?.image_names;
const imageToSelect = imageNames?.at(0) ?? null;
dispatch(imageSelected(imageToSelect));
},
});
};

View File

@@ -0,0 +1,126 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
const log = logger('generation');
export const enqueueRequestedCanvas = createAction<{ prepend: boolean }>('app/enqueueRequestedCanvas');
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: enqueueRequestedCanvas,
effect: async (action, { getState, dispatch }) => {
log.debug('Enqueue requested');
const state = getState();
const { prepend } = action.payload;
const manager = $canvasManager.get();
assert(manager, 'No canvas manager');
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
switch (base) {
case 'sdxl':
return await buildSDXLGraph(state, manager);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(state, manager);
case `sd-3`:
return await buildSD3Graph(state, manager);
case `flux`:
return await buildFLUXGraph(state, manager);
case 'cogview4':
return await buildCogView4Graph(state, manager);
case 'imagen3':
return await buildImagen3Graph(state, manager);
case 'imagen4':
return await buildImagen4Graph(state, manager);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, manager);
case 'flux-kontext':
return await buildFluxKontextGraph(state, manager);
default:
assert(false, `No graph builders for base ${base}`);
}
});
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
const destination = state.canvasSettings.sendToCanvas ? 'canvas' : 'gallery';
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: 'canvas',
destination,
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(prepareBatchResult.value, enqueueMutationFixedCacheKeyOptions)
);
try {
await req.unwrap();
log.debug(parseify({ batchConfig: prepareBatchResult.value }), 'Enqueued batch');
} catch (error) {
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
} finally {
req.reset();
}
},
});
};
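
The listener above is driven by the enqueueRequestedCanvas action. A hedged sketch of how UI code might dispatch it, e.g. from an invoke button handler; the hook name is illustrative:

import { useCallback } from 'react';
import { useAppDispatch } from 'app/store/storeHooks';
import { enqueueRequestedCanvas } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';

// Hypothetical hook - the real invoke button may wire this differently.
export const useInvokeCanvas = () => {
  const dispatch = useAppDispatch();
  return useCallback(
    (prepend: boolean) => {
      // The listener builds the graph, prepares the linear UI batch, and enqueues it.
      dispatch(enqueueRequestedCanvas({ prepend }));
    },
    [dispatch]
  );
};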

View File

@@ -0,0 +1,44 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
const log = logger('generation');
export const enqueueRequestedUpscaling = createAction<{ prepend: boolean }>('app/enqueueRequestedUpscaling');
export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: enqueueRequestedUpscaling,
effect: async (action, { getState, dispatch }) => {
const state = getState();
const { prepend } = action.payload;
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: 'upscaling',
destination: 'gallery',
});
const req = dispatch(queueApi.endpoints.enqueueBatch.initiate(batchConfig, enqueueMutationFixedCacheKeyOptions));
try {
await req.unwrap();
log.debug(parseify({ batchConfig }), 'Enqueued batch');
} catch (error) {
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
} finally {
req.reset();
}
},
});
};

View File

@@ -0,0 +1,73 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
export const galleryImageClicked = createAction<{
imageDTO: ImageDTO;
shiftKey: boolean;
ctrlKey: boolean;
metaKey: boolean;
altKey: boolean;
}>('gallery/imageClicked');
/**
* This listener handles the logic for selecting images in the gallery.
*
* Previously, this logic was in a `useCallback` with the whole gallery selection as a dependency. Every time
* the selection changed, the callback got recreated and all images rerendered. This could easily block for
* hundreds of ms, more for lower end devices.
*
* Moving this logic into a listener means we don't need to recalculate anything dynamically and the gallery
* is much more responsive.
*/
export const addGalleryImageClickedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: galleryImageClicked,
effect: (action, { dispatch, getState }) => {
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);
const queryResult = imagesApi.endpoints.listImages.select(queryArgs)(state);
if (!queryResult.data) {
// Should never happen if we have clicked a gallery image
return;
}
const imageDTOs = queryResult.data.items;
const selection = state.gallery.selection;
if (altKey) {
if (state.gallery.imageToCompare?.image_name === imageDTO.image_name) {
dispatch(imageToCompareChanged(null));
} else {
dispatch(imageToCompareChanged(imageDTO));
}
} else if (shiftKey) {
const rangeEndImageName = imageDTO.image_name;
const lastSelectedImage = selection[selection.length - 1]?.image_name;
const lastClickedIndex = imageDTOs.findIndex((n) => n.image_name === lastSelectedImage);
const currentClickedIndex = imageDTOs.findIndex((n) => n.image_name === rangeEndImageName);
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
// We have a valid range!
const start = Math.min(lastClickedIndex, currentClickedIndex);
const end = Math.max(lastClickedIndex, currentClickedIndex);
const imagesToSelect = imageDTOs.slice(start, end + 1);
dispatch(selectionChanged(selection.concat(imagesToSelect)));
}
} else if (ctrlKey || metaKey) {
if (selection.some((i) => i.image_name === imageDTO.image_name) && selection.length > 1) {
dispatch(selectionChanged(selection.filter((n) => n.image_name !== imageDTO.image_name)));
} else {
dispatch(selectionChanged(selection.concat(imageDTO)));
}
} else {
dispatch(selectionChanged([imageDTO]));
}
},
});
};
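
As the doc comment above notes, the component only needs to forward the click and its modifier keys; the selection math happens in the listener. A hedged sketch of the dispatching side, with assumed import paths:

import type { MouseEvent } from 'react';
import { useCallback } from 'react';
import { useAppDispatch } from 'app/store/storeHooks';
import { galleryImageClicked } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
import type { ImageDTO } from 'services/api/types';

// Hypothetical handler hook - the real gallery image component may differ.
export const useGalleryImageClick = (imageDTO: ImageDTO) => {
  const dispatch = useAppDispatch();
  return useCallback(
    (e: MouseEvent) => {
      dispatch(
        galleryImageClicked({
          imageDTO,
          shiftKey: e.shiftKey,
          ctrlKey: e.ctrlKey,
          metaKey: e.metaKey,
          altKey: e.altKey,
        })
      );
    },
    [dispatch, imageDTO]
  );
};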

View File

@@ -0,0 +1,119 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, offsetChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
export const addGalleryOffsetChangedListener = (startAppListening: AppStartListening) => {
/**
* When the user changes pages in the gallery, we need to wait until the next page of images is loaded, then maybe
* update the selection.
*
* There are three scenarios:
*
* 1. The page is changed by clicking the pagination buttons. No changes to selection are needed.
*
* 2. The page is changed by using the arrow keys (without alt).
* - When going backwards, select the last image.
* - When going forwards, select the first image.
*
* 3. The page is changed by using the arrow keys with alt. This means the user is changing the comparison image.
* - When going backwards, select the last image _as the comparison image_.
* - When going forwards, select the first image _as the comparison image_.
*/
startAppListening({
actionCreator: offsetChanged,
effect: async (action, { dispatch, getState, getOriginalState, take, cancelActiveListeners }) => {
// Cancel any active listeners to prevent the selection from changing without user input
cancelActiveListeners();
const { withHotkey } = action.payload;
if (!withHotkey) {
// User changed pages by clicking the pagination buttons - no changes to selection
return;
}
const originalState = getOriginalState();
const prevOffset = originalState.gallery.offset;
const offset = getState().gallery.offset;
if (offset === prevOffset) {
// The page didn't change - bail
return;
}
/**
* We need to wait until the next page of images is loaded before updating the selection, so we use the correct
* page of images.
*
* The simplest way to do it would be to use `take` to wait for the next fulfilled action, but RTK-Q doesn't
* dispatch an action on cache hits. This means the `take` will only return if the cache is empty. If the user
* changes to a cached page - a common situation - the `take` will never resolve.
*
* So we need to take a two-step approach. First, check if we have data in the cache for the page of images. If
* we have data cached, use it to update the selection. If we don't have data cached, wait for the next fulfilled
* action, which updates the cache, then use the cache to update the selection.
*/
// Check if we have data in the cache for the page of images
const queryArgs = selectListImagesQueryArgs(getState());
let { data } = imagesApi.endpoints.listImages.select(queryArgs)(getState());
// No data yet - wait for the network request to complete
if (!data) {
const takeResult = await take(imagesApi.endpoints.listImages.matchFulfilled, 5000);
if (!takeResult) {
// The request didn't complete in time - bail
return;
}
data = takeResult[0].payload;
}
// We awaited a network request - state could have changed, get fresh state
const state = getState();
const { selection, imageToCompare } = state.gallery;
const imageDTOs = data?.items;
if (!imageDTOs) {
// The page didn't load - bail
return;
}
if (withHotkey === 'arrow') {
// User changed pages by using the arrow keys - selection changes to first or last image depending on the direction
if (offset < prevOffset) {
// We've gone backwards
const lastImage = imageDTOs[imageDTOs.length - 1];
if (!selection.some((selectedImage) => selectedImage.image_name === lastImage?.image_name)) {
dispatch(selectionChanged(lastImage ? [lastImage] : []));
}
} else {
// We've gone forwards
const firstImage = imageDTOs[0];
if (!selection.some((selectedImage) => selectedImage.image_name === firstImage?.image_name)) {
dispatch(selectionChanged(firstImage ? [firstImage] : []));
}
}
return;
}
if (withHotkey === 'alt+arrow') {
// User changed pages by using the arrow keys with alt - comparison image changes to first or last depending on the direction
if (offset < prevOffset) {
// We've gone backwards
const lastImage = imageDTOs[imageDTOs.length - 1];
if (lastImage && imageToCompare?.image_name !== lastImage.image_name) {
dispatch(imageToCompareChanged(lastImage));
}
} else {
// We've gone forwards
const firstImage = imageDTOs[0];
if (firstImage && imageToCompare?.image_name !== firstImage.image_name) {
dispatch(imageToCompareChanged(firstImage));
}
}
return;
}
},
});
};

View File

@@ -1,9 +1,9 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { size } from 'es-toolkit/compat';
import { $templates } from 'features/nodes/store/nodesSlice';
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { size } from 'lodash-es';
import { serializeError } from 'serialize-error';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { JsonObject } from 'type-fest';

View File

@@ -8,16 +8,16 @@ export const addImageAddedToBoardFulfilledListener = (startAppListening: AppStar
startAppListening({
matcher: imagesApi.endpoints.addImageToBoard.matchFulfilled,
effect: (action) => {
const { board_id, image_name } = action.meta.arg.originalArgs;
log.debug({ board_id, image_name }, 'Image added to board');
const { board_id, imageDTO } = action.meta.arg.originalArgs;
log.debug({ board_id, imageDTO }, 'Image added to board');
},
});
startAppListening({
matcher: imagesApi.endpoints.addImageToBoard.matchRejected,
effect: (action) => {
const { board_id, image_name } = action.meta.arg.originalArgs;
log.debug({ board_id, image_name }, 'Problem adding image to board');
const { board_id, imageDTO } = action.meta.arg.originalArgs;
log.debug({ board_id, imageDTO }, 'Problem adding image to board');
},
});
};

View File

@@ -0,0 +1,221 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, RootState } from 'app/store/store';
import { entityDeleted, referenceImageIPAdapterImageChanged } from 'features/controlLayers/store/canvasSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import { getEntityIdentifier } from 'features/controlLayers/store/types';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { isModalOpenChanged } from 'features/deleteImageModal/store/slice';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { isImageFieldCollectionInputInstance, isImageFieldInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { forEach, intersectionBy } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import type { Param0 } from 'tsafe';
const log = logger('gallery');
//TODO(psyche): handle image deletion (canvas staging area?)
// Some utils to delete images from different parts of the app
const deleteNodesImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
const actions: Param0<typeof dispatch>[] = [];
state.nodes.present.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
forEach(node.data.inputs, (input) => {
if (isImageFieldInputInstance(input) && input.value?.image_name === imageDTO.image_name) {
actions.push(
fieldImageValueChanged({
nodeId: node.data.id,
fieldName: input.name,
value: undefined,
})
);
return;
}
if (isImageFieldCollectionInputInstance(input)) {
actions.push(
fieldImageCollectionValueChanged({
nodeId: node.data.id,
fieldName: input.name,
value: input.value?.filter((value) => value?.image_name !== imageDTO.image_name),
})
);
}
});
});
actions.forEach(dispatch);
};
const deleteControlLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
selectCanvasSlice(state).controlLayers.entities.forEach(({ id, objects }) => {
let shouldDelete = false;
for (const obj of objects) {
if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
shouldDelete = true;
break;
}
}
if (shouldDelete) {
dispatch(entityDeleted({ entityIdentifier: { id, type: 'control_layer' } }));
}
});
};
const deleteReferenceImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
if (entity.ipAdapter.image?.image_name === imageDTO.image_name) {
dispatch(referenceImageIPAdapterImageChanged({ entityIdentifier: getEntityIdentifier(entity), imageDTO: null }));
}
});
};
const deleteRasterLayerImages = (state: RootState, dispatch: AppDispatch, imageDTO: ImageDTO) => {
selectCanvasSlice(state).rasterLayers.entities.forEach(({ id, objects }) => {
let shouldDelete = false;
for (const obj of objects) {
if (obj.type === 'image' && obj.image.image_name === imageDTO.image_name) {
shouldDelete = true;
break;
}
}
if (shouldDelete) {
dispatch(entityDeleted({ entityIdentifier: { id, type: 'raster_layer' } }));
}
});
};
export const addImageDeletionListeners = (startAppListening: AppStartListening) => {
// Handle single image deletion
startAppListening({
actionCreator: imageDeletionConfirmed,
effect: async (action, { dispatch, getState }) => {
const { imageDTOs, imagesUsage } = action.payload;
if (imageDTOs.length !== 1 || imagesUsage.length !== 1) {
// handle multiples in separate listener
return;
}
const imageDTO = imageDTOs[0];
const imageUsage = imagesUsage[0];
if (!imageDTO || !imageUsage) {
// satisfy noUncheckedIndexedAccess
return;
}
try {
const state = getState();
await dispatch(imagesApi.endpoints.deleteImage.initiate(imageDTO)).unwrap();
if (state.gallery.selection.some((i) => i.image_name === imageDTO.image_name)) {
// The deleted image was a selected image, we need to select the next image
const newSelection = state.gallery.selection.filter((i) => i.image_name !== imageDTO.image_name);
if (newSelection.length > 0) {
return;
}
// Get the current list of images and select the same index
const baseQueryArgs = selectListImagesQueryArgs(state);
const data = imagesApi.endpoints.listImages.select(baseQueryArgs)(state).data;
if (data) {
const deletedImageIndex = data.items.findIndex((i) => i.image_name === imageDTO.image_name);
const nextImage = data.items[deletedImageIndex + 1] ?? data.items[0] ?? null;
if (nextImage?.image_name === imageDTO.image_name) {
// If the next image is the same as the deleted one, it means it was the last image, reset selection
dispatch(imageSelected(null));
} else {
dispatch(imageSelected(nextImage));
}
}
}
deleteNodesImages(state, dispatch, imageDTO);
deleteReferenceImages(state, dispatch, imageDTO);
deleteRasterLayerImages(state, dispatch, imageDTO);
deleteControlLayerImages(state, dispatch, imageDTO);
} catch {
// no-op
} finally {
dispatch(isModalOpenChanged(false));
}
},
});
// Handle multiple image deletion
startAppListening({
actionCreator: imageDeletionConfirmed,
effect: async (action, { dispatch, getState }) => {
const { imageDTOs, imagesUsage } = action.payload;
if (imageDTOs.length <= 1 || imagesUsage.length <= 1) {
// handle singles in separate listener
return;
}
try {
const state = getState();
await dispatch(imagesApi.endpoints.deleteImages.initiate({ imageDTOs })).unwrap();
if (intersectionBy(state.gallery.selection, imageDTOs, 'image_name').length > 0) {
// Some selected images were deleted, need to select the next image
const queryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(queryArgs)(state);
if (data) {
// When we delete multiple images, we clear the selection. Then, the next time we load images, we will
// select the first one. This is handled below in the listener for `imagesApi.endpoints.listImages.matchFulfilled`.
dispatch(imageSelected(null));
}
}
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
imageDTOs.forEach((imageDTO) => {
deleteNodesImages(state, dispatch, imageDTO);
deleteControlLayerImages(state, dispatch, imageDTO);
deleteReferenceImages(state, dispatch, imageDTO);
deleteRasterLayerImages(state, dispatch, imageDTO);
});
} catch {
// no-op
} finally {
dispatch(isModalOpenChanged(false));
}
},
});
// When we list images, if no image is selected, select the first one.
startAppListening({
matcher: imagesApi.endpoints.listImages.matchFulfilled,
effect: (action, { dispatch, getState }) => {
const selection = getState().gallery.selection;
if (selection.length === 0) {
dispatch(imageSelected(action.payload.items[0] ?? null));
}
},
});
startAppListening({
matcher: imagesApi.endpoints.deleteImage.matchFulfilled,
effect: (action) => {
log.debug({ imageDTO: action.meta.arg.originalArgs }, 'Image deleted');
},
});
startAppListening({
matcher: imagesApi.endpoints.deleteImage.matchRejected,
effect: (action) => {
log.debug({ imageDTO: action.meta.arg.originalArgs }, 'Unable to delete image');
},
});
};

View File

@@ -0,0 +1,32 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions';
import { selectImageUsage } from 'features/deleteImageModal/store/selectors';
import { imagesToDeleteSelected, isModalOpenChanged } from 'features/deleteImageModal/store/slice';
export const addImageToDeleteSelectedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: imagesToDeleteSelected,
effect: (action, { dispatch, getState }) => {
const imageDTOs = action.payload;
const state = getState();
const { shouldConfirmOnDelete } = state.system;
const imagesUsage = selectImageUsage(getState());
const isImageInUse =
imagesUsage.some((i) => i.isRasterLayerImage) ||
imagesUsage.some((i) => i.isControlLayerImage) ||
imagesUsage.some((i) => i.isReferenceImage) ||
imagesUsage.some((i) => i.isInpaintMaskImage) ||
imagesUsage.some((i) => i.isUpscaleImage) ||
imagesUsage.some((i) => i.isNodesImage) ||
imagesUsage.some((i) => i.isRegionalGuidanceImage);
if (shouldConfirmOnDelete || isImageInUse) {
dispatch(isModalOpenChanged(true));
return;
}
dispatch(imageDeletionConfirmed({ imageDTOs, imagesUsage }));
},
});
};

View File

@@ -2,12 +2,12 @@ import { isAnyOf } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import { omit } from 'es-toolkit/compat';
import { imageUploadedClientSide } from 'features/gallery/store/actions';
import { selectListBoardsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged } from 'features/gallery/store/gallerySlice';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { omit } from 'lodash-es';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';

View File

@@ -0,0 +1,30 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
export const addImagesStarredListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.starImages.matchFulfilled,
effect: (action, { dispatch, getState }) => {
const { updated_image_names: starredImages } = action.payload;
const state = getState();
const { selection } = state.gallery;
const updatedSelection: ImageDTO[] = [];
selection.forEach((selectedImageDTO) => {
if (starredImages.includes(selectedImageDTO.image_name)) {
updatedSelection.push({
...selectedImageDTO,
starred: true,
});
} else {
updatedSelection.push(selectedImageDTO);
}
});
dispatch(selectionChanged(updatedSelection));
},
});
};

View File

@@ -0,0 +1,30 @@
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
export const addImagesUnstarredListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: imagesApi.endpoints.unstarImages.matchFulfilled,
effect: (action, { dispatch, getState }) => {
const { updated_image_names: unstarredImages } = action.payload;
const state = getState();
const { selection } = state.gallery;
const updatedSelection: ImageDTO[] = [];
selection.forEach((selectedImageDTO) => {
if (unstarredImages.includes(selectedImageDTO.image_name)) {
updatedSelection.push({
...selectedImageDTO,
starred: false,
});
} else {
updatedSelection.push(selectedImageDTO);
}
});
dispatch(selectionChanged(updatedSelection));
},
});
};

Some files were not shown because too many files have changed in this diff.