diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index d183c614c4..a9bcc9f768 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -1,7 +1,7 @@ import io import json import traceback -from typing import ClassVar, Literal, Optional +from typing import ClassVar, Optional from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi.responses import FileResponse @@ -14,7 +14,6 @@ from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_i from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.image_records.image_records_common import ( ImageCategory, - ImageCollectionCounts, ImageRecordChanges, ResourceOrigin, ) @@ -565,67 +564,6 @@ async def get_bulk_download_item( raise HTTPException(status_code=404) -@images_router.get( - "/collections/counts", operation_id="get_image_collection_counts", response_model=ImageCollectionCounts -) -async def get_image_collection_counts( - image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to count."), - categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."), - is_intermediate: Optional[bool] = Query(default=None, description="Whether to include intermediate images."), - board_id: Optional[str] = Query( - default=None, - description="The board id to filter by. Use 'none' to find images without a board.", - ), - search_term: Optional[str] = Query(default=None, description="The term to search for"), -) -> ImageCollectionCounts: - """Gets counts for starred and unstarred image collections""" - - try: - return ApiDependencies.invoker.services.images.get_collection_counts( - image_origin=image_origin, - categories=categories, - is_intermediate=is_intermediate, - board_id=board_id, - search_term=search_term, - ) - except Exception: - raise HTTPException(status_code=500, detail="Failed to get collection counts") - - -@images_router.get("/collections/{collection}", operation_id="get_image_collection") -async def get_image_collection( - collection: Literal["starred", "unstarred"] = Path(..., description="The collection to retrieve from"), - image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."), - categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."), - is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."), - board_id: Optional[str] = Query( - default=None, - description="The board id to filter by. Use 'none' to find images without a board.", - ), - offset: int = Query(default=0, description="The offset within the collection"), - limit: int = Query(default=50, description="The number of images to return"), - order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"), - search_term: Optional[str] = Query(default=None, description="The term to search for"), -) -> OffsetPaginatedResults[ImageDTO]: - """Gets images from a specific collection (starred or unstarred)""" - - try: - image_dtos = ApiDependencies.invoker.services.images.get_collection_images( - collection=collection, - offset=offset, - limit=limit, - order_dir=order_dir, - image_origin=image_origin, - categories=categories, - is_intermediate=is_intermediate, - board_id=board_id, - search_term=search_term, - ) - return image_dtos - except Exception: - raise HTTPException(status_code=500, detail="Failed to get collection images") - - @images_router.get("/names", operation_id="get_image_names") async def get_image_names( image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."), @@ -636,12 +574,14 @@ async def get_image_names( description="The board id to filter by. Use 'none' to find images without a board.", ), order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"), + starred_first: bool = Query(default=True, description="Whether to sort by starred images first"), search_term: Optional[str] = Query(default=None, description="The term to search for"), ) -> list[str]: """Gets ordered list of all image names (starred first, then unstarred)""" try: image_names = ApiDependencies.invoker.services.images.get_image_names( + starred_first=starred_first, order_dir=order_dir, image_origin=image_origin, categories=categories, diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 622d8ea60f..4f2341fbcd 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -587,9 +587,9 @@ def invocation( for field_name, field_info in cls.model_fields.items(): annotation = field_info.annotation assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation." - assert isinstance(field_info.json_schema_extra, dict), ( - f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?" - ) + assert isinstance( + field_info.json_schema_extra, dict + ), f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?" original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info) @@ -712,9 +712,9 @@ def invocation_output( for field_name, field_info in cls.model_fields.items(): annotation = field_info.annotation assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation." - assert isinstance(field_info.json_schema_extra, dict), ( - f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?" - ) + assert isinstance( + field_info.json_schema_extra, dict + ), f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?" cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info) diff --git a/invokeai/app/invocations/segment_anything.py b/invokeai/app/invocations/segment_anything.py index 4d82624edf..6b2decff18 100644 --- a/invokeai/app/invocations/segment_anything.py +++ b/invokeai/app/invocations/segment_anything.py @@ -184,9 +184,9 @@ class SegmentAnythingInvocation(BaseInvocation): # Find the largest mask. return [max(masks, key=lambda x: float(x.sum()))] elif self.mask_filter == "highest_box_score": - assert bounding_boxes is not None, ( - "Bounding boxes must be provided to use the 'highest_box_score' mask filter." - ) + assert ( + bounding_boxes is not None + ), "Bounding boxes must be provided to use the 'highest_box_score' mask filter." assert len(masks) == len(bounding_boxes) # Find the index of the bounding box with the highest score. # Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 4dabac964b..18fd5c70db 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -482,9 +482,9 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: try: # Meta is not included in the model fields, so we need to validate it separately config = InvokeAIAppConfig.model_validate(loaded_config_dict) - assert config.schema_version == CONFIG_SCHEMA_VERSION, ( - f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}" - ) + assert ( + config.schema_version == CONFIG_SCHEMA_VERSION + ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}" return config except Exception as e: raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e diff --git a/invokeai/app/services/image_records/image_records_base.py b/invokeai/app/services/image_records/image_records_base.py index e640e3facb..128ced7b09 100644 --- a/invokeai/app/services/image_records/image_records_base.py +++ b/invokeai/app/services/image_records/image_records_base.py @@ -1,11 +1,10 @@ from abc import ABC, abstractmethod from datetime import datetime -from typing import Literal, Optional +from typing import Optional from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.image_records.image_records_common import ( ImageCategory, - ImageCollectionCounts, ImageRecord, ImageRecordChanges, ResourceOrigin, @@ -99,37 +98,10 @@ class ImageRecordStorageBase(ABC): """Gets the most recent image for a board.""" pass - @abstractmethod - def get_collection_counts( - self, - image_origin: Optional[ResourceOrigin] = None, - categories: Optional[list[ImageCategory]] = None, - is_intermediate: Optional[bool] = None, - board_id: Optional[str] = None, - search_term: Optional[str] = None, - ) -> ImageCollectionCounts: - """Gets counts for starred and unstarred image collections.""" - pass - - @abstractmethod - def get_collection_images( - self, - collection: Literal["starred", "unstarred"], - offset: int = 0, - limit: int = 10, - order_dir: SQLiteDirection = SQLiteDirection.Descending, - image_origin: Optional[ResourceOrigin] = None, - categories: Optional[list[ImageCategory]] = None, - is_intermediate: Optional[bool] = None, - board_id: Optional[str] = None, - search_term: Optional[str] = None, - ) -> OffsetPaginatedResults[ImageRecord]: - """Gets images from a specific collection (starred or unstarred).""" - pass - @abstractmethod def get_image_names( self, + starred_first: bool = True, order_dir: SQLiteDirection = SQLiteDirection.Descending, image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index 704b99bd77..3086880560 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -1,13 +1,12 @@ import sqlite3 from datetime import datetime -from typing import Literal, Optional, Union, cast +from typing import Optional, Union, cast from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase from invokeai.app.services.image_records.image_records_common import ( IMAGE_DTO_COLS, ImageCategory, - ImageCollectionCounts, ImageRecord, ImageRecordChanges, ImageRecordDeleteException, @@ -388,182 +387,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): return deserialize_image_record(dict(result)) - def get_collection_counts( - self, - image_origin: Optional[ResourceOrigin] = None, - categories: Optional[list[ImageCategory]] = None, - is_intermediate: Optional[bool] = None, - board_id: Optional[str] = None, - search_term: Optional[str] = None, - ) -> ImageCollectionCounts: - cursor = self._conn.cursor() - - # Build the base query conditions (same as get_many) - base_query = """--sql - FROM images - LEFT JOIN board_images ON board_images.image_name = images.image_name - WHERE 1=1 - """ - - query_conditions = "" - query_params: list[Union[int, str, bool]] = [] - - if image_origin is not None: - query_conditions += """--sql - AND images.image_origin = ? - """ - query_params.append(image_origin.value) - - if categories is not None: - category_strings = [c.value for c in set(categories)] - placeholders = ",".join("?" * len(category_strings)) - query_conditions += f"""--sql - AND images.image_category IN ( {placeholders} ) - """ - for c in category_strings: - query_params.append(c) - - if is_intermediate is not None: - query_conditions += """--sql - AND images.is_intermediate = ? - """ - query_params.append(is_intermediate) - - if board_id == "none": - query_conditions += """--sql - AND board_images.board_id IS NULL - """ - elif board_id is not None: - query_conditions += """--sql - AND board_images.board_id = ? - """ - query_params.append(board_id) - - if search_term: - query_conditions += """--sql - AND ( - images.metadata LIKE ? - OR images.created_at LIKE ? - ) - """ - query_params.append(f"%{search_term.lower()}%") - query_params.append(f"%{search_term.lower()}%") - - # Get starred count - starred_query = f"SELECT COUNT(*) {base_query} {query_conditions} AND images.starred = TRUE;" - cursor.execute(starred_query, query_params) - starred_count = cast(int, cursor.fetchone()[0]) - - # Get unstarred count - unstarred_query = f"SELECT COUNT(*) {base_query} {query_conditions} AND images.starred = FALSE;" - cursor.execute(unstarred_query, query_params) - unstarred_count = cast(int, cursor.fetchone()[0]) - - return ImageCollectionCounts(starred_count=starred_count, unstarred_count=unstarred_count) - - def get_collection_images( - self, - collection: Literal["starred", "unstarred"], - offset: int = 0, - limit: int = 10, - order_dir: SQLiteDirection = SQLiteDirection.Descending, - image_origin: Optional[ResourceOrigin] = None, - categories: Optional[list[ImageCategory]] = None, - is_intermediate: Optional[bool] = None, - board_id: Optional[str] = None, - search_term: Optional[str] = None, - ) -> OffsetPaginatedResults[ImageRecord]: - cursor = self._conn.cursor() - - # Base queries - count_query = """--sql - SELECT COUNT(*) - FROM images - LEFT JOIN board_images ON board_images.image_name = images.image_name - WHERE 1=1 - """ - - images_query = f"""--sql - SELECT {IMAGE_DTO_COLS} - FROM images - LEFT JOIN board_images ON board_images.image_name = images.image_name - WHERE 1=1 - """ - - query_conditions = "" - query_params: list[Union[int, str, bool]] = [] - - # Add starred/unstarred filter - is_starred = collection == "starred" - query_conditions += """--sql - AND images.starred = ? - """ - query_params.append(is_starred) - - if image_origin is not None: - query_conditions += """--sql - AND images.image_origin = ? - """ - query_params.append(image_origin.value) - - if categories is not None: - category_strings = [c.value for c in set(categories)] - placeholders = ",".join("?" * len(category_strings)) - query_conditions += f"""--sql - AND images.image_category IN ( {placeholders} ) - """ - for c in category_strings: - query_params.append(c) - - if is_intermediate is not None: - query_conditions += """--sql - AND images.is_intermediate = ? - """ - query_params.append(is_intermediate) - - if board_id == "none": - query_conditions += """--sql - AND board_images.board_id IS NULL - """ - elif board_id is not None: - query_conditions += """--sql - AND board_images.board_id = ? - """ - query_params.append(board_id) - - if search_term: - query_conditions += """--sql - AND ( - images.metadata LIKE ? - OR images.created_at LIKE ? - ) - """ - query_params.append(f"%{search_term.lower()}%") - query_params.append(f"%{search_term.lower()}%") - - # Add ordering and pagination - query_pagination = f"""--sql - ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ? - """ - - # Execute images query - images_query += query_conditions + query_pagination + ";" - images_params = query_params.copy() - images_params.extend([limit, offset]) - - cursor.execute(images_query, images_params) - result = cast(list[sqlite3.Row], cursor.fetchall()) - images = [deserialize_image_record(dict(r)) for r in result] - - # Execute count query - count_query += query_conditions + ";" - cursor.execute(count_query, query_params) - count = cast(int, cursor.fetchone()[0]) - - return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count) - def get_image_names( self, + starred_first: bool = True, order_dir: SQLiteDirection = SQLiteDirection.Descending, image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, @@ -625,13 +451,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): query_params.append(f"%{search_term.lower()}%") query_params.append(f"%{search_term.lower()}%") - # Order by starred first, then by created_at - query += ( - query_conditions - + f"""--sql - ORDER BY images.starred DESC, images.created_at {order_dir.value} - """ - ) + if starred_first: + query += ( + query_conditions + + f"""--sql + ORDER BY images.starred DESC, images.created_at {order_dir.value} + """ + ) + else: + query += ( + query_conditions + + f"""--sql + ORDER BY images.created_at {order_dir.value} + """ + ) cursor.execute(query, query_params) result = cast(list[sqlite3.Row], cursor.fetchall()) diff --git a/invokeai/app/services/images/images_base.py b/invokeai/app/services/images/images_base.py index 037d463e33..3bf832cc71 100644 --- a/invokeai/app/services/images/images_base.py +++ b/invokeai/app/services/images/images_base.py @@ -1,12 +1,11 @@ from abc import ABC, abstractmethod -from typing import Callable, Literal, Optional +from typing import Callable, Optional from PIL.Image import Image as PILImageType from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.image_records.image_records_common import ( ImageCategory, - ImageCollectionCounts, ImageRecord, ImageRecordChanges, ResourceOrigin, @@ -149,37 +148,10 @@ class ImageServiceABC(ABC): """Deletes all images on a board.""" pass - @abstractmethod - def get_collection_counts( - self, - image_origin: Optional[ResourceOrigin] = None, - categories: Optional[list[ImageCategory]] = None, - is_intermediate: Optional[bool] = None, - board_id: Optional[str] = None, - search_term: Optional[str] = None, - ) -> ImageCollectionCounts: - """Gets counts for starred and unstarred image collections.""" - pass - - @abstractmethod - def get_collection_images( - self, - collection: Literal["starred", "unstarred"], - offset: int = 0, - limit: int = 10, - order_dir: SQLiteDirection = SQLiteDirection.Descending, - image_origin: Optional[ResourceOrigin] = None, - categories: Optional[list[ImageCategory]] = None, - is_intermediate: Optional[bool] = None, - board_id: Optional[str] = None, - search_term: Optional[str] = None, - ) -> OffsetPaginatedResults[ImageDTO]: - """Gets images from a specific collection (starred or unstarred).""" - pass - @abstractmethod def get_image_names( self, + starred_first: bool = True, order_dir: SQLiteDirection = SQLiteDirection.Descending, image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, @@ -187,5 +159,5 @@ class ImageServiceABC(ABC): board_id: Optional[str] = None, search_term: Optional[str] = None, ) -> list[str]: - """Gets ordered list of all image names (starred first, then unstarred).""" + """Gets ordered list of all image names.""" pass diff --git a/invokeai/app/services/images/images_default.py b/invokeai/app/services/images/images_default.py index 6585e7ca05..4547d46c04 100644 --- a/invokeai/app/services/images/images_default.py +++ b/invokeai/app/services/images/images_default.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Optional from PIL.Image import Image as PILImageType @@ -10,7 +10,6 @@ from invokeai.app.services.image_files.image_files_common import ( ) from invokeai.app.services.image_records.image_records_common import ( ImageCategory, - ImageCollectionCounts, ImageRecord, ImageRecordChanges, ImageRecordDeleteException, @@ -311,73 +310,9 @@ class ImageService(ImageServiceABC): self.__invoker.services.logger.error("Problem getting intermediates count") raise e - def get_collection_counts( - self, - image_origin: Optional[ResourceOrigin] = None, - categories: Optional[list[ImageCategory]] = None, - is_intermediate: Optional[bool] = None, - board_id: Optional[str] = None, - search_term: Optional[str] = None, - ) -> ImageCollectionCounts: - try: - return self.__invoker.services.image_records.get_collection_counts( - image_origin=image_origin, - categories=categories, - is_intermediate=is_intermediate, - board_id=board_id, - search_term=search_term, - ) - except Exception as e: - self.__invoker.services.logger.error("Problem getting collection counts") - raise e - - def get_collection_images( - self, - collection: Literal["starred", "unstarred"], - offset: int = 0, - limit: int = 10, - order_dir: SQLiteDirection = SQLiteDirection.Descending, - image_origin: Optional[ResourceOrigin] = None, - categories: Optional[list[ImageCategory]] = None, - is_intermediate: Optional[bool] = None, - board_id: Optional[str] = None, - search_term: Optional[str] = None, - ) -> OffsetPaginatedResults[ImageDTO]: - try: - results = self.__invoker.services.image_records.get_collection_images( - collection=collection, - offset=offset, - limit=limit, - order_dir=order_dir, - image_origin=image_origin, - categories=categories, - is_intermediate=is_intermediate, - board_id=board_id, - search_term=search_term, - ) - - image_dtos = [ - image_record_to_dto( - image_record=r, - image_url=self.__invoker.services.urls.get_image_url(r.image_name), - thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True), - board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name), - ) - for r in results.items - ] - - return OffsetPaginatedResults[ImageDTO]( - items=image_dtos, - offset=results.offset, - limit=results.limit, - total=results.total, - ) - except Exception as e: - self.__invoker.services.logger.error("Problem getting collection images") - raise e - def get_image_names( self, + starred_first: bool = True, order_dir: SQLiteDirection = SQLiteDirection.Descending, image_origin: Optional[ResourceOrigin] = None, categories: Optional[list[ImageCategory]] = None, @@ -387,6 +322,7 @@ class ImageService(ImageServiceABC): ) -> list[str]: try: return self.__invoker.services.image_records.get_image_names( + starred_first=starred_first, order_dir=order_dir, image_origin=image_origin, categories=categories, diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py index b84b226d9f..367c00b503 100644 --- a/invokeai/app/services/workflow_records/workflow_records_sqlite.py +++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py @@ -379,13 +379,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): bytes_ = path.read_bytes() workflow_from_file = WorkflowValidator.validate_json(bytes_) - assert workflow_from_file.id.startswith("default_"), ( - f'Invalid default workflow ID (must start with "default_"): {workflow_from_file.id}' - ) + assert workflow_from_file.id.startswith( + "default_" + ), f'Invalid default workflow ID (must start with "default_"): {workflow_from_file.id}' - assert workflow_from_file.meta.category is WorkflowCategory.Default, ( - f"Invalid default workflow category: {workflow_from_file.meta.category}" - ) + assert ( + workflow_from_file.meta.category is WorkflowCategory.Default + ), f"Invalid default workflow category: {workflow_from_file.meta.category}" workflows_from_file.append(workflow_from_file) diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 0c07e6c53e..7521f2c512 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -381,7 +381,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase): state_dict = mod.load_state_dict() for key in state_dict.keys(): - if type(key) is int: + if isinstance(key, int): continue if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")): diff --git a/invokeai/backend/model_manager/merge.py b/invokeai/backend/model_manager/merge.py index 03056b10f5..b00bc99f3e 100644 --- a/invokeai/backend/model_manager/merge.py +++ b/invokeai/backend/model_manager/merge.py @@ -115,19 +115,19 @@ class ModelMerger(object): base_models: Set[BaseModelType] = set() variant = None if self._installer.app_config.precision == "float32" else "fp16" - assert len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference, ( - "When merging three models, only the 'add_difference' merge method is supported" - ) + assert ( + len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference + ), "When merging three models, only the 'add_difference' merge method is supported" for key in model_keys: info = store.get_model(key) model_names.append(info.name) - assert isinstance(info, MainDiffusersConfig), ( - f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging" - ) - assert info.variant == ModelVariantType("normal"), ( - f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged" - ) + assert isinstance( + info, MainDiffusersConfig + ), f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging" + assert info.variant == ModelVariantType( + "normal" + ), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged" # tally base models used base_models.add(info.base) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts index 899e88a85c..52f9e5cdeb 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/boardIdSelected.ts @@ -1,6 +1,6 @@ import { isAnyOf } from '@reduxjs/toolkit'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors'; +import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors'; import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice'; import { imagesApi } from 'services/api/endpoints/images'; @@ -13,7 +13,7 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening) const state = getState(); - const queryArgs = selectListImagesQueryArgs(state); + const queryArgs = { ...selectListImagesBaseQueryArgs(state), offset: 0 }; // wait until the board has some images - maybe it already has some from a previous fetch // must use getState() to ensure we do not have stale state diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts index f95b7502f0..efa927ce64 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts @@ -1,29 +1,9 @@ import { createAction } from '@reduxjs/toolkit'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import type { RootState } from 'app/store/store'; -import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors'; +import { selectListImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors'; import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice'; import { uniq } from 'lodash-es'; import { imagesApi } from 'services/api/endpoints/images'; -import type { ImageCategory, SQLiteDirection } from 'services/api/types'; - -// Type for image collection query arguments -type ImageCollectionQueryArgs = { - board_id?: string; - categories?: ImageCategory[]; - search_term?: string; - order_dir?: SQLiteDirection; - is_intermediate: boolean; -}; - -/** - * Helper function to get cached image names list for selection operations - * Returns an ordered array of image names (starred first, then unstarred) - */ -const getCachedImageNames = (state: RootState, queryArgs: ImageCollectionQueryArgs): string[] => { - const queryResult = imagesApi.endpoints.getImageNames.select(queryArgs)(state); - return queryResult.data || []; -}; export const galleryImageClicked = createAction<{ imageName: string; @@ -50,10 +30,8 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen effect: (action, { dispatch, getState }) => { const { imageName, shiftKey, ctrlKey, metaKey, altKey } = action.payload; const state = getState(); - const queryArgs = selectListImagesQueryArgs(state); - - // Get cached image names for selection operations - const imageNames = getCachedImageNames(state, queryArgs); + const queryArgs = selectListImageNamesQueryArgs(state); + const imageNames = imagesApi.endpoints.getImageNames.select(queryArgs)(state).data ?? []; // If we don't have the image names cached, we can't perform selection operations // This can happen if the user clicks on an image before the names are loaded diff --git a/invokeai/frontend/web/src/features/deleteImageModal/store/state.ts b/invokeai/frontend/web/src/features/deleteImageModal/store/state.ts index c04f7dd294..a380ec02d6 100644 --- a/invokeai/frontend/web/src/features/deleteImageModal/store/state.ts +++ b/invokeai/frontend/web/src/features/deleteImageModal/store/state.ts @@ -10,7 +10,6 @@ import { import { selectCanvasSlice } from 'features/controlLayers/store/selectors'; import type { CanvasState, RefImagesState } from 'features/controlLayers/store/types'; import type { ImageUsage } from 'features/deleteImageModal/store/types'; -import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors'; import { imageSelected } from 'features/gallery/store/gallerySlice'; import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; import { selectNodesSlice } from 'features/nodes/store/selectors'; @@ -81,14 +80,8 @@ const handleDeletions = async (image_names: string[], dispatch: AppDispatch, get await dispatch(imagesApi.endpoints.deleteImages.initiate({ image_names }, { track: false })).unwrap(); if (intersection(state.gallery.selection, image_names).length > 0) { - // Some selected images were deleted, need to select the next image - const queryArgs = selectListImagesQueryArgs(state); - const { data } = imagesApi.endpoints.listImages.select(queryArgs)(state); - if (data) { - // When we delete multiple images, we clear the selection. Then, the the next time we load images, we will - // select the first one. This is handled below in the listener for `imagesApi.endpoints.listImages.matchFulfilled`. - dispatch(imageSelected(null)); - } + // Some selected images were deleted, clear selection + dispatch(imageSelected(null)); } // We need to reset the features where the image is in use - none of these work if their image(s) don't exist diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx index faa0bd91ad..1b4ce009f2 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImagePreview.tsx @@ -5,6 +5,7 @@ import { CanvasAlertsInvocationProgress } from 'features/controlLayers/component import { DndImage } from 'features/dnd/DndImage'; import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer'; import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons'; +import { selectAutoSwitch } from 'features/gallery/store/gallerySelectors'; import type { ProgressImage as ProgressImageType } from 'features/nodes/types/common'; import { selectShouldShowImageDetails, selectShouldShowProgressInViewer } from 'features/ui/store/uiSelectors'; import type { AnimationProps } from 'framer-motion'; @@ -21,6 +22,7 @@ import { ProgressIndicator } from './ProgressIndicator2'; export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | null }) => { const shouldShowImageDetails = useAppSelector(selectShouldShowImageDetails); const shouldShowProgressInViewer = useAppSelector(selectShouldShowProgressInViewer); + const autoSwitch = useAppSelector(selectAutoSwitch); const socket = useStore($socket); const [progressEvent, setProgressEvent] = useState(null); @@ -58,6 +60,29 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu }; }, [socket]); + useEffect(() => { + if (!socket) { + return; + } + + if (autoSwitch) { + return; + } + // When auto-switch is enabled, we will get a load event as we switch to the new image. This in turn clears the progress image, + // creating the illusion of the progress image turning into the new image. + // But when auto-switch is disabled, we won't get that load event, so we need to clear the progress image manually. + const onQueueItemStatusChanged = () => { + setProgressEvent(null); + setProgressImage(null); + }; + + socket.on('queue_item_status_changed', onQueueItemStatusChanged); + + return () => { + socket.off('queue_item_status_changed', onQueueItemStatusChanged); + }; + }, [autoSwitch, socket]); + const onLoadImage = useCallback(() => { if (!progressEvent || !imageDTO) { return; diff --git a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx index eea3395dd8..2eafcefe33 100644 --- a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx @@ -4,10 +4,11 @@ import { logger } from 'app/logging/logger'; import { EMPTY_ARRAY } from 'app/store/constants'; import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { + LIMIT, selectGalleryImageMinimumWidth, selectImageToCompare, selectLastSelectedImage, - selectListImagesQueryArgs, + selectListImageNamesQueryArgs, } from 'features/gallery/store/gallerySelectors'; import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice'; import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData'; @@ -37,31 +38,26 @@ const SCROLL_SEEK_VELOCITY_THRESHOLD = 4096; const DEBOUNCE_DELAY = 500; const SPINNER_OPACITY = 0.3; -type ListImagesQueryArgs = ReturnType; +type ListImageNamesQueryArgs = ReturnType; type GridContext = { - queryArgs: ListImagesQueryArgs; + queryArgs: ListImageNamesQueryArgs; imageNames: string[]; }; -export const useDebouncedListImagesQueryArgs = () => { - const _galleryQueryArgs = useAppSelector(selectListImagesQueryArgs); - const [queryArgs] = useDebounce(_galleryQueryArgs, DEBOUNCE_DELAY); - return queryArgs; -}; - // Hook to get an image DTO from cache or trigger loading const useImageDTOFromListQuery = ( index: number, imageName: string, - queryArgs: ListImagesQueryArgs + queryArgs: ListImageNamesQueryArgs ): ImageDTO | null => { const { arg, options } = useMemo(() => { - const pageOffset = Math.floor(index / queryArgs.limit) * queryArgs.limit; + const pageOffset = Math.floor(index / LIMIT) * LIMIT; return { arg: { ...queryArgs, offset: pageOffset, + limit: LIMIT, } satisfies Parameters[0], options: { selectFromResult: ({ data }) => { @@ -82,7 +78,7 @@ const useImageDTOFromListQuery = ( // Individual image component that gets its data from RTK Query cache const ImageAtPosition = memo( - ({ index, queryArgs, imageName }: { index: number; imageName: string; queryArgs: ListImagesQueryArgs }) => { + ({ index, queryArgs, imageName }: { index: number; imageName: string; queryArgs: ListImageNamesQueryArgs }) => { const imageDTO = useImageDTOFromListQuery(index, imageName, queryArgs); if (!imageDTO) { @@ -408,7 +404,8 @@ const getImageNamesQueryOptions = { } satisfies Parameters[1]; export const useGalleryImageNames = () => { - const queryArgs = useDebouncedListImagesQueryArgs(); + const _queryArgs = useAppSelector(selectListImageNamesQueryArgs); + const [queryArgs] = useDebounce(_queryArgs, DEBOUNCE_DELAY); const { imageNames, isLoading, isFetching } = useGetImageNamesQuery(queryArgs, getImageNamesQueryOptions); return { imageNames, isLoading, isFetching, queryArgs }; }; diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts index f5355a7ba0..e2062d23db 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts @@ -2,8 +2,7 @@ import { createSelector } from '@reduxjs/toolkit'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { selectGallerySlice } from 'features/gallery/store/gallerySlice'; import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types'; -import type { ListBoardsArgs, ListImagesArgs } from 'services/api/types'; -import type { SetNonNullable } from 'type-fest'; +import type { ListBoardsArgs } from 'services/api/types'; export const selectFirstSelectedImage = createSelector(selectGallerySlice, (gallery) => gallery.selection.at(0)); export const selectLastSelectedImage = createSelector(selectGallerySlice, (gallery) => gallery.selection.at(-1)); @@ -28,7 +27,7 @@ export const selectGallerySearchTerm = createSelector(selectGallerySlice, (galle export const selectGalleryOrderDir = createSelector(selectGallerySlice, (gallery) => gallery.orderDir); export const selectGalleryStarredFirst = createSelector(selectGallerySlice, (gallery) => gallery.starredFirst); -export const selectListImagesQueryArgs = createMemoizedSelector( +export const selectListImageNamesQueryArgs = createMemoizedSelector( [ selectSelectedBoardId, selectGalleryQueryCategories, @@ -36,17 +35,20 @@ export const selectListImagesQueryArgs = createMemoizedSelector( selectGalleryOrderDir, selectGalleryStarredFirst, ], - (board_id, categories, search_term, order_dir, starred_first) => - ({ - board_id, - categories, - search_term, - order_dir, - starred_first, - is_intermediate: false, // We don't show intermediate images in the gallery - limit: 100, // Page size is _always_ 100 - }) satisfies SetNonNullable + (board_id, categories, search_term, order_dir, starred_first) => ({ + board_id, + categories, + search_term, + order_dir, + starred_first, + is_intermediate: false, + }) ); +export const LIMIT = 100; +export const selectListImagesBaseQueryArgs = createMemoizedSelector(selectListImageNamesQueryArgs, (baseQueryArgs) => ({ + ...baseQueryArgs, + limit: LIMIT, +})); export const selectAutoAssignBoardOnClick = createSelector( selectGallerySlice, (gallery) => gallery.autoAssignBoardOnClick diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index cd7303e069..2d3184627e 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -427,61 +427,12 @@ export const imagesApi = api.injectEndpoints({ }, }), }), - /** - * Get counts for starred and unstarred image collections - */ - getImageCollectionCounts: build.query< - paths['/api/v1/images/collections/counts']['get']['responses']['200']['content']['application/json'], - paths['/api/v1/images/collections/counts']['get']['parameters']['query'] - >({ - query: (queryArgs) => ({ - url: buildImagesUrl('collections/counts', queryArgs), - method: 'GET', - }), - providesTags: ['ImageCollectionCounts', 'FetchOnReconnect'], - }), - /** - * Get images from a specific collection (starred or unstarred) - */ - getImageCollection: build.query< - paths['/api/v1/images/collections/{collection}']['get']['responses']['200']['content']['application/json'], - paths['/api/v1/images/collections/{collection}']['get']['parameters']['path'] & - paths['/api/v1/images/collections/{collection}']['get']['parameters']['query'] - >({ - query: ({ collection, ...queryArgs }) => ({ - url: buildImagesUrl(`collections/${collection}`, queryArgs), - method: 'GET', - }), - providesTags: (result, error, { collection, board_id, categories }) => { - const cacheKey = `${collection}-${board_id || 'all'}-${categories?.join(',') || 'all'}`; - return [ - { type: 'ImageCollection', id: collection }, - { type: 'ImageCollection', id: cacheKey }, - 'FetchOnReconnect', - ]; - }, - async onQueryStarted(_, { dispatch, queryFulfilled }) { - // Populate the getImageDTO cache with these images, similar to listImages - const res = await queryFulfilled; - const imageDTOs = res.data.items; - const updates: Param0 = []; - for (const imageDTO of imageDTOs) { - updates.push({ - endpointName: 'getImageDTO', - arg: imageDTO.image_name, - value: imageDTO, - }); - } - dispatch(imagesApi.util.upsertQueryEntries(updates)); - }, - }), /** * Get ordered list of image names for selection operations */ getImageNames: build.query< string[], { - image_origin?: 'internal' | 'external' | null; categories?: ImageCategory[] | null; is_intermediate?: boolean | null; board_id?: string | null; @@ -493,46 +444,11 @@ export const imagesApi = api.injectEndpoints({ url: buildImagesUrl('names', queryArgs), method: 'GET', }), - providesTags: ['ImageNameList', 'FetchOnReconnect'], - }), - /** - * Get paginated images with starred first (unified list) - */ - getUnifiedImageList: build.query< - ListImagesResponse, - { - offset?: number; - limit?: number; - image_origin?: 'internal' | 'external' | null; - categories?: ImageCategory[] | null; - is_intermediate?: boolean | null; - board_id?: string | null; - search_term?: string | null; - order_dir?: SQLiteDirection; - } - >({ - query: (queryArgs) => ({ - url: getListImagesUrl({ ...queryArgs, starred_first: true }), - method: 'GET', - }), - providesTags: (result, error, { board_id, categories }) => [ - { type: 'ImageList', id: getListImagesUrl({ board_id, categories }) }, + providesTags: (result, error, queryArgs) => [ + 'ImageNameList', 'FetchOnReconnect', + { type: 'ImageNameList', id: stableHash(queryArgs) }, ], - async onQueryStarted(_, { dispatch, queryFulfilled }) { - // Populate the getImageDTO cache with these images - const res = await queryFulfilled; - const imageDTOs = res.data.items; - const updates: Param0 = []; - for (const imageDTO of imageDTOs) { - updates.push({ - endpointName: 'getImageDTO', - arg: imageDTO.image_name, - value: imageDTO, - }); - } - dispatch(imagesApi.util.upsertQueryEntries(updates)); - }, }), }), }); @@ -555,11 +471,7 @@ export const { useStarImagesMutation, useUnstarImagesMutation, useBulkDownloadImagesMutation, - useGetImageCollectionCountsQuery, - useGetImageCollectionQuery, - useLazyGetImageCollectionQuery, useGetImageNamesQuery, - useGetUnifiedImageListQuery, } = imagesApi; /** diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 0c01dacfe0..ae843df433 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -752,7 +752,7 @@ export type paths = { patch?: never; trace?: never; }; - "/api/v1/images/collections/counts": { + "/api/v1/images/names": { parameters: { query?: never; header?: never; @@ -760,30 +760,10 @@ export type paths = { cookie?: never; }; /** - * Get Image Collection Counts - * @description Gets counts for starred and unstarred image collections + * Get Image Names + * @description Gets ordered list of all image names (starred first, then unstarred) */ - get: operations["get_image_collection_counts"]; - put?: never; - post?: never; - delete?: never; - options?: never; - head?: never; - patch?: never; - trace?: never; - }; - "/api/v1/images/collections/{collection}": { - parameters: { - query?: never; - header?: never; - path?: never; - cookie?: never; - }; - /** - * Get Image Collection - * @description Gets images from a specific collection (starred or unstarred) - */ - get: operations["get_image_collection"]; + get: operations["get_image_names"]; put?: never; post?: never; delete?: never; @@ -9844,19 +9824,6 @@ export type components = { */ type: "img_channel_offset"; }; - /** ImageCollectionCounts */ - ImageCollectionCounts: { - /** - * Starred Count - * @description The number of starred images in the collection. - */ - starred_count: number; - /** - * Unstarred Count - * @description The number of unstarred images in the collection. - */ - unstarred_count: number; - }; /** * Image Collection Primitive * @description A collection of image primitive values @@ -23728,17 +23695,21 @@ export interface operations { }; }; }; - get_image_collection_counts: { + get_image_names: { parameters: { query?: { - /** @description The origin of images to count. */ + /** @description The origin of images to list. */ image_origin?: components["schemas"]["ResourceOrigin"] | null; /** @description The categories of image to include. */ categories?: components["schemas"]["ImageCategory"][] | null; - /** @description Whether to include intermediate images. */ + /** @description Whether to list intermediate images. */ is_intermediate?: boolean | null; /** @description The board id to filter by. Use 'none' to find images without a board. */ board_id?: string | null; + /** @description The order of sort */ + order_dir?: components["schemas"]["SQLiteDirection"]; + /** @description Whether to sort by starred images first */ + starred_first?: boolean; /** @description The term to search for */ search_term?: string | null; }; @@ -23754,56 +23725,7 @@ export interface operations { [name: string]: unknown; }; content: { - "application/json": components["schemas"]["ImageCollectionCounts"]; - }; - }; - /** @description Validation Error */ - 422: { - headers: { - [name: string]: unknown; - }; - content: { - "application/json": components["schemas"]["HTTPValidationError"]; - }; - }; - }; - }; - get_image_collection: { - parameters: { - query?: { - /** @description The origin of images to list. */ - image_origin?: components["schemas"]["ResourceOrigin"] | null; - /** @description The categories of image to include. */ - categories?: components["schemas"]["ImageCategory"][] | null; - /** @description Whether to list intermediate images. */ - is_intermediate?: boolean | null; - /** @description The board id to filter by. Use 'none' to find images without a board. */ - board_id?: string | null; - /** @description The offset within the collection */ - offset?: number; - /** @description The number of images to return */ - limit?: number; - /** @description The order of sort */ - order_dir?: components["schemas"]["SQLiteDirection"]; - /** @description The term to search for */ - search_term?: string | null; - }; - header?: never; - path: { - /** @description The collection to retrieve from */ - collection: "starred" | "unstarred"; - }; - cookie?: never; - }; - requestBody?: never; - responses: { - /** @description Successful Response */ - 200: { - headers: { - [name: string]: unknown; - }; - content: { - "application/json": components["schemas"]["OffsetPaginatedResults_ImageDTO_"]; + "application/json": string[]; }; }; /** @description Validation Error */ diff --git a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx index 78322f2408..07ecd537d7 100644 --- a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx +++ b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx @@ -5,7 +5,7 @@ import { deepClone } from 'common/util/deepClone'; import { selectAutoSwitch, selectGalleryView, - selectListImagesQueryArgs, + selectListImagesBaseQueryArgs, selectSelectedBoardId, } from 'features/gallery/store/gallerySelectors'; import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice'; @@ -44,7 +44,7 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi const boardTotalAdditions: Record = {}; const boardTagIdsToInvalidate: Set = new Set(); const imageListTagIdsToInvalidate: Set = new Set(); - const listImagesArg = selectListImagesQueryArgs(getState()); + const listImagesArg = selectListImagesBaseQueryArgs(getState()); for (const imageDTO of imageDTOs) { if (imageDTO.is_intermediate) { @@ -94,7 +94,7 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi type: 'ImageList' as const, id: imageListId, })); - dispatch(imagesApi.util.invalidateTags([...boardTags, ...imageListTags])); + dispatch(imagesApi.util.invalidateTags(['ImageNameList', ...boardTags, ...imageListTags])); const autoSwitch = selectAutoSwitch(getState()); diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index edf3c115ac..8feb49f999 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -211,12 +211,12 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None: assert job.bytes > 0, "expected download bytes to be positive" assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes" assert job.download_path == tmp_path / "sdxl-turbo" - assert Path(tmp_path, "sdxl-turbo/model_index.json").exists(), ( - f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist" - ) - assert Path(tmp_path, "sdxl-turbo/text_encoder/config.json").exists(), ( - f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist" - ) + assert Path( + tmp_path, "sdxl-turbo/model_index.json" + ).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist" + assert Path( + tmp_path, "sdxl-turbo/text_encoder/config.json" + ).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist" assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED} queue.stop() diff --git a/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py b/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py index ed3e05a9b2..1ad408861e 100644 --- a/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py +++ b/tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py @@ -48,9 +48,9 @@ def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format(): model_keys = set(model.state_dict().keys()) for converted_key_prefix in converted_key_prefixes: - assert any(model_key.startswith(converted_key_prefix) for model_key in model_keys), ( - f"'{converted_key_prefix}' did not match any model keys." - ) + assert any( + model_key.startswith(converted_key_prefix) for model_key in model_keys + ), f"'{converted_key_prefix}' did not match any model keys." def test_lora_model_from_flux_aitoolkit_state_dict():