diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index ae8e6e05a6..ff55749f6b 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -54,6 +54,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( CogView4ConditioningInfo, ConditioningFieldData, FLUXConditioningInfo, + QwenImageConditioningInfo, SD3ConditioningInfo, SDXLConditioningInfo, ZImageConditioningInfo, @@ -144,6 +145,7 @@ class ApiDependencies: SD3ConditioningInfo, CogView4ConditioningInfo, ZImageConditioningInfo, + QwenImageConditioningInfo, AnimaConditioningInfo, ], ephemeral=True, diff --git a/invokeai/app/api/routers/auth.py b/invokeai/app/api/routers/auth.py index 36aeabda82..e0b0c885cd 100644 --- a/invokeai/app/api/routers/auth.py +++ b/invokeai/app/api/routers/auth.py @@ -80,6 +80,7 @@ class SetupStatusResponse(BaseModel): setup_required: bool = Field(description="Whether initial setup is required") multiuser_enabled: bool = Field(description="Whether multiuser mode is enabled") strict_password_checking: bool = Field(description="Whether strict password requirements are enforced") + admin_email: str | None = Field(default=None, description="Email of the first active admin user, if any") @auth_router.get("/status", response_model=SetupStatusResponse) @@ -94,15 +95,25 @@ async def get_setup_status() -> SetupStatusResponse: # If multiuser is disabled, setup is never required if not config.multiuser: return SetupStatusResponse( - setup_required=False, multiuser_enabled=False, strict_password_checking=config.strict_password_checking + setup_required=False, + multiuser_enabled=False, + strict_password_checking=config.strict_password_checking, + admin_email=None, ) # In multiuser mode, check if an admin exists user_service = ApiDependencies.invoker.services.users setup_required = not user_service.has_admin() + # Only expose admin_email during initial setup to avoid leaking + # administrator identity on public 
deployments. + admin_email = user_service.get_admin_email() if setup_required else None + return SetupStatusResponse( - setup_required=setup_required, multiuser_enabled=True, strict_password_checking=config.strict_password_checking + setup_required=setup_required, + multiuser_enabled=True, + strict_password_checking=config.strict_password_checking, + admin_email=admin_email, ) diff --git a/invokeai/app/api/routers/board_images.py b/invokeai/app/api/routers/board_images.py index cb5e0ab51a..f94e4f2437 100644 --- a/invokeai/app/api/routers/board_images.py +++ b/invokeai/app/api/routers/board_images.py @@ -1,12 +1,53 @@ from fastapi import Body, HTTPException from fastapi.routing import APIRouter +from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"]) +def _assert_board_write_access(board_id: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user may not mutate the given board. + + Write access is granted when ANY of these hold: + - The user is an admin. + - The user owns the board. + - The board visibility is Public (public boards accept contributions from any user). 
+ """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + except Exception: + raise HTTPException(status_code=404, detail="Board not found") + if current_user.is_admin: + return + if board.user_id == current_user.user_id: + return + if board.board_visibility == BoardVisibility.Public: + return + raise HTTPException(status_code=403, detail="Not authorized to modify this board") + + +def _assert_image_direct_owner(image_name: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user is not the direct owner of the image. + + This is intentionally stricter than _assert_image_owner in images.py: + board ownership is NOT sufficient here. Allowing a user to add someone + else's image to their own board would grant them mutation rights via the + board-ownership fallback in _assert_image_owner, escalating read access + into write access. + """ + if current_user.is_admin: + return + owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name) + if owner is not None and owner == current_user.user_id: + return + raise HTTPException(status_code=403, detail="Not authorized to move this image") + + @board_images_router.post( "/", operation_id="add_image_to_board", @@ -17,14 +58,17 @@ board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"]) response_model=AddImagesToBoardResult, ) async def add_image_to_board( + current_user: CurrentUserOrDefault, board_id: str = Body(description="The id of the board to add to"), image_name: str = Body(description="The name of the image to add"), ) -> AddImagesToBoardResult: """Creates a board_image""" + _assert_board_write_access(board_id, current_user) + _assert_image_direct_owner(image_name, current_user) try: added_images: set[str] = set() affected_boards: set[str] = set() - old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or 
"none" + old_board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none" ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name) added_images.add(image_name) affected_boards.add(board_id) @@ -48,13 +92,16 @@ async def add_image_to_board( response_model=RemoveImagesFromBoardResult, ) async def remove_image_from_board( + current_user: CurrentUserOrDefault, image_name: str = Body(description="The name of the image to remove", embed=True), ) -> RemoveImagesFromBoardResult: """Removes an image from its board, if it had one""" try: + old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none" + if old_board_id != "none": + _assert_board_write_access(old_board_id, current_user) removed_images: set[str] = set() affected_boards: set[str] = set() - old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none" ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name) removed_images.add(image_name) affected_boards.add("none") @@ -64,6 +111,8 @@ async def remove_image_from_board( affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to remove image from board") @@ -78,16 +127,21 @@ async def remove_image_from_board( response_model=AddImagesToBoardResult, ) async def add_images_to_board( + current_user: CurrentUserOrDefault, board_id: str = Body(description="The id of the board to add to"), image_names: list[str] = Body(description="The names of the images to add", embed=True), ) -> AddImagesToBoardResult: """Adds a list of images to a board""" + _assert_board_write_access(board_id, current_user) try: added_images: set[str] = set() affected_boards: set[str] = set() for image_name in image_names: try: - old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none" + 
_assert_image_direct_owner(image_name, current_user) + old_board_id = ( + ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none" + ) ApiDependencies.invoker.services.board_images.add_image_to_board( board_id=board_id, image_name=image_name, @@ -96,12 +150,16 @@ async def add_images_to_board( affected_boards.add(board_id) affected_boards.add(old_board_id) + except HTTPException: + raise except Exception: pass return AddImagesToBoardResult( added_images=list(added_images), affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to add images to board") @@ -116,6 +174,7 @@ async def add_images_to_board( response_model=RemoveImagesFromBoardResult, ) async def remove_images_from_board( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(description="The names of the images to remove", embed=True), ) -> RemoveImagesFromBoardResult: """Removes a list of images from their board, if they had one""" @@ -125,15 +184,21 @@ async def remove_images_from_board( for image_name in image_names: try: old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none" + if old_board_id != "none": + _assert_board_write_access(old_board_id, current_user) ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name) removed_images.add(image_name) affected_boards.add("none") affected_boards.add(old_board_id) + except HTTPException: + raise except Exception: pass return RemoveImagesFromBoardResult( removed_images=list(removed_images), affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to remove images from board") diff --git a/invokeai/app/api/routers/boards.py b/invokeai/app/api/routers/boards.py index e93bb8b2a9..6897e90aff 100644 --- a/invokeai/app/api/routers/boards.py +++ 
b/invokeai/app/api/routers/boards.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies -from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy +from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy, BoardVisibility from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.image_records.image_records_common import ImageCategory from invokeai.app.services.shared.pagination import OffsetPaginatedResults @@ -56,7 +56,14 @@ async def get_board( except Exception: raise HTTPException(status_code=404, detail="Board not found") - if not current_user.is_admin and result.user_id != current_user.user_id: + # Admins can access any board. + # Owners can access their own boards. + # Shared and public boards are visible to all authenticated users. + if ( + not current_user.is_admin + and result.user_id != current_user.user_id + and result.board_visibility == BoardVisibility.Private + ): raise HTTPException(status_code=403, detail="Not authorized to access this board") return result @@ -188,7 +195,11 @@ async def list_all_board_image_names( except Exception: raise HTTPException(status_code=404, detail="Board not found") - if not current_user.is_admin and board.user_id != current_user.user_id: + if ( + not current_user.is_admin + and board.user_id != current_user.user_id + and board.board_visibility == BoardVisibility.Private + ): raise HTTPException(status_code=403, detail="Not authorized to access this board") image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board( @@ -196,4 +207,15 @@ async def list_all_board_image_names( categories, is_intermediate, ) + + # For uncategorized images (board_id="none"), filter to only the caller's + # images so that one user cannot enumerate another's 
uncategorized images. + # Admin users can see all uncategorized images. + if board_id == "none" and not current_user.is_admin: + image_names = [ + name + for name in image_names + if ApiDependencies.invoker.services.image_records.get_user_id(name) == current_user.user_id + ] + return image_names diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 6b11762c9e..a3ae6fce82 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -38,6 +38,96 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"]) IMAGE_MAX_AGE = 31536000 +def _assert_image_owner(image_name: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user does not own the image and is not an admin. + + Ownership is satisfied when ANY of these hold: + - The user is an admin. + - The user is the image's direct owner (image_records.user_id). + - The user owns the board the image sits on. + - The image sits on a Public board (public boards grant mutation rights). + """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + if current_user.is_admin: + return + owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name) + if owner is not None and owner == current_user.user_id: + return + + # Check whether the user owns the board the image belongs to, + # or the board is Public (public boards grant mutation rights). 
+ board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) + if board_id is not None: + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + if board.user_id == current_user.user_id: + return + if board.board_visibility == BoardVisibility.Public: + return + except Exception: + pass + + raise HTTPException(status_code=403, detail="Not authorized to modify this image") + + +def _assert_image_read_access(image_name: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user may not view the image. + + Access is granted when ANY of these hold: + - The user is an admin. + - The user owns the image. + - The image sits on a shared or public board. + """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + if current_user.is_admin: + return + + owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name) + if owner is not None and owner == current_user.user_id: + return + + # Check whether the image's board makes it visible to other users. + board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) + if board_id is not None: + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public): + return + except Exception: + pass + + raise HTTPException(status_code=403, detail="Not authorized to access this image") + + +def _assert_board_read_access(board_id: str, current_user: CurrentUserOrDefault) -> None: + """Raise 403 if the current user may not read images from this board. + + Access is granted when ANY of these hold: + - The user is an admin. + - The user owns the board. + - The board visibility is Shared or Public. 
+ """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + if current_user.is_admin: + return + + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + except Exception: + raise HTTPException(status_code=404, detail="Board not found") + + if board.user_id == current_user.user_id: + return + + if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public): + return + + raise HTTPException(status_code=403, detail="Not authorized to access this board") + + class ResizeToDimensions(BaseModel): width: int = Field(..., gt=0) height: int = Field(..., gt=0) @@ -83,6 +173,22 @@ async def upload_image( ), ) -> ImageDTO: """Uploads an image for the current user""" + # If uploading into a board, verify the user has write access. + # Public boards allow uploads from any authenticated user. + if board_id is not None: + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + except Exception: + raise HTTPException(status_code=404, detail="Board not found") + if ( + not current_user.is_admin + and board.user_id != current_user.user_id + and board.board_visibility != BoardVisibility.Public + ): + raise HTTPException(status_code=403, detail="Not authorized to upload to this board") + if not file.content_type or not file.content_type.startswith("image"): raise HTTPException(status_code=415, detail="Not an image") @@ -165,9 +271,11 @@ async def create_image_upload_entry( @images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult) async def delete_image( + current_user: CurrentUserOrDefault, image_name: str = Path(description="The name of the image to delete"), ) -> DeleteImagesResult: """Deletes an image""" + _assert_image_owner(image_name, current_user) deleted_images: set[str] = set() affected_boards: set[str] = set() @@ -189,26 +297,31 @@ 
async def delete_image( @images_router.delete("/intermediates", operation_id="clear_intermediates") -async def clear_intermediates() -> int: - """Clears all intermediates""" +async def clear_intermediates( + current_user: CurrentUserOrDefault, +) -> int: + """Clears all intermediates. Requires admin.""" + if not current_user.is_admin: + raise HTTPException(status_code=403, detail="Only admins can clear all intermediates") try: count_deleted = ApiDependencies.invoker.services.images.delete_intermediates() return count_deleted except Exception: raise HTTPException(status_code=500, detail="Failed to clear intermediates") - pass @images_router.get("/intermediates", operation_id="get_intermediates_count") -async def get_intermediates_count() -> int: - """Gets the count of intermediate images""" +async def get_intermediates_count( + current_user: CurrentUserOrDefault, +) -> int: + """Gets the count of intermediate images. Non-admin users only see their own intermediates.""" try: - return ApiDependencies.invoker.services.images.get_intermediates_count() + user_id = None if current_user.is_admin else current_user.user_id + return ApiDependencies.invoker.services.images.get_intermediates_count(user_id=user_id) except Exception: raise HTTPException(status_code=500, detail="Failed to get intermediates") - pass @images_router.patch( @@ -217,10 +330,12 @@ async def get_intermediates_count() -> int: response_model=ImageDTO, ) async def update_image( + current_user: CurrentUserOrDefault, image_name: str = Path(description="The name of the image to update"), image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"), ) -> ImageDTO: """Updates an image""" + _assert_image_owner(image_name, current_user) try: return ApiDependencies.invoker.services.images.update(image_name, image_changes) @@ -234,9 +349,11 @@ async def update_image( response_model=ImageDTO, ) async def get_image_dto( + current_user: CurrentUserOrDefault, image_name: str = 
Path(description="The name of image to get"), ) -> ImageDTO: """Gets an image's DTO""" + _assert_image_read_access(image_name, current_user) try: return ApiDependencies.invoker.services.images.get_dto(image_name) @@ -250,9 +367,11 @@ async def get_image_dto( response_model=Optional[MetadataField], ) async def get_image_metadata( + current_user: CurrentUserOrDefault, image_name: str = Path(description="The name of image to get"), ) -> Optional[MetadataField]: """Gets an image's metadata""" + _assert_image_read_access(image_name, current_user) try: return ApiDependencies.invoker.services.images.get_metadata(image_name) @@ -269,8 +388,11 @@ class WorkflowAndGraphResponse(BaseModel): "/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse ) async def get_image_workflow( + current_user: CurrentUserOrDefault, image_name: str = Path(description="The name of image whose workflow to get"), ) -> WorkflowAndGraphResponse: + _assert_image_read_access(image_name, current_user) + try: workflow = ApiDependencies.invoker.services.images.get_workflow(image_name) graph = ApiDependencies.invoker.services.images.get_graph(image_name) @@ -306,8 +428,12 @@ async def get_image_workflow( async def get_image_full( image_name: str = Path(description="The name of full-resolution image file to get"), ) -> Response: - """Gets a full-resolution image file""" + """Gets a full-resolution image file. + This endpoint is intentionally unauthenticated because browsers load images + via <img> tags which cannot send Bearer tokens. Image names are UUIDs, + providing security through unguessability. + """ try: path = ApiDependencies.invoker.services.images.get_path(image_name) with open(path, "rb") as f: @@ -335,8 +461,12 @@ async def get_image_full( async def get_image_thumbnail( image_name: str = Path(description="The name of thumbnail image file to get"), ) -> Response: - """Gets a thumbnail image file""" + """Gets a thumbnail image file. 
+ This endpoint is intentionally unauthenticated because browsers load images + via <img> tags which cannot send Bearer tokens. Image names are UUIDs, + providing security through unguessability. + """ try: path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True) with open(path, "rb") as f: @@ -354,9 +484,11 @@ async def get_image_thumbnail( response_model=ImageUrlsDTO, ) async def get_image_urls( + current_user: CurrentUserOrDefault, image_name: str = Path(description="The name of the image whose URL to get"), ) -> ImageUrlsDTO: """Gets an image and thumbnail URL""" + _assert_image_read_access(image_name, current_user) try: image_url = ApiDependencies.invoker.services.images.get_url(image_name) @@ -392,6 +524,11 @@ async def list_image_dtos( ) -> OffsetPaginatedResults[ImageDTO]: """Gets a list of image DTOs for the current user""" + # Validate that the caller can read from this board before listing its images. + # "none" is a sentinel for uncategorized images and is handled by the SQL layer. 
+ if board_id is not None and board_id != "none": + _assert_board_read_access(board_id, current_user) + image_dtos = ApiDependencies.invoker.services.images.get_many( offset, limit, @@ -410,6 +547,7 @@ async def list_image_dtos( @images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult) async def delete_images_from_list( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(description="The list of names of images to delete", embed=True), ) -> DeleteImagesResult: try: @@ -417,24 +555,31 @@ async def delete_images_from_list( affected_boards: set[str] = set() for image_name in image_names: try: + _assert_image_owner(image_name, current_user) image_dto = ApiDependencies.invoker.services.images.get_dto(image_name) board_id = image_dto.board_id or "none" ApiDependencies.invoker.services.images.delete(image_name) deleted_images.add(image_name) affected_boards.add(board_id) + except HTTPException: + raise except Exception: pass return DeleteImagesResult( deleted_images=list(deleted_images), affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to delete images") @images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult) -async def delete_uncategorized_images() -> DeleteImagesResult: - """Deletes all images that are uncategorized""" +async def delete_uncategorized_images( + current_user: CurrentUserOrDefault, +) -> DeleteImagesResult: + """Deletes all uncategorized images owned by the current user (or all if admin)""" image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board( board_id="none", categories=None, is_intermediate=None @@ -445,9 +590,13 @@ async def delete_uncategorized_images() -> DeleteImagesResult: affected_boards: set[str] = set() for image_name in image_names: try: + _assert_image_owner(image_name, current_user) 
ApiDependencies.invoker.services.images.delete(image_name) deleted_images.add(image_name) affected_boards.add("none") + except HTTPException: + # Skip images not owned by the current user + pass except Exception: pass return DeleteImagesResult( @@ -464,6 +613,7 @@ class ImagesUpdatedFromListResult(BaseModel): @images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult) async def star_images_in_list( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(description="The list of names of images to star", embed=True), ) -> StarredImagesResult: try: @@ -471,23 +621,29 @@ async def star_images_in_list( affected_boards: set[str] = set() for image_name in image_names: try: + _assert_image_owner(image_name, current_user) updated_image_dto = ApiDependencies.invoker.services.images.update( image_name, changes=ImageRecordChanges(starred=True) ) starred_images.add(image_name) affected_boards.add(updated_image_dto.board_id or "none") + except HTTPException: + raise except Exception: pass return StarredImagesResult( starred_images=list(starred_images), affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to star images") @images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult) async def unstar_images_in_list( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(description="The list of names of images to unstar", embed=True), ) -> UnstarredImagesResult: try: @@ -495,17 +651,22 @@ async def unstar_images_in_list( affected_boards: set[str] = set() for image_name in image_names: try: + _assert_image_owner(image_name, current_user) updated_image_dto = ApiDependencies.invoker.services.images.update( image_name, changes=ImageRecordChanges(starred=False) ) unstarred_images.add(image_name) affected_boards.add(updated_image_dto.board_id or "none") + except HTTPException: + raise 
except Exception: pass return UnstarredImagesResult( unstarred_images=list(unstarred_images), affected_boards=list(affected_boards), ) + except HTTPException: + raise except Exception: raise HTTPException(status_code=500, detail="Failed to unstar images") @@ -523,6 +684,7 @@ class ImagesDownloaded(BaseModel): "/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202 ) async def download_images_from_list( + current_user: CurrentUserOrDefault, background_tasks: BackgroundTasks, image_names: Optional[list[str]] = Body( default=None, description="The list of names of images to download", embed=True @@ -533,6 +695,16 @@ async def download_images_from_list( ) -> ImagesDownloaded: if (image_names is None or len(image_names) == 0) and board_id is None: raise HTTPException(status_code=400, detail="No images or board id specified.") + + # Validate that the caller can read every image they are requesting. + # For a board_id request, check board visibility; for explicit image names, + # check each image individually. + if board_id: + _assert_board_read_access(board_id, current_user) + if image_names: + for name in image_names: + _assert_image_read_access(name, current_user) + bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id) background_tasks.add_task( @@ -540,6 +712,7 @@ async def download_images_from_list( image_names, board_id, bulk_download_item_id, + current_user.user_id, ) return ImagesDownloaded(bulk_download_item_name=bulk_download_item_id + ".zip") @@ -558,11 +731,21 @@ async def download_images_from_list( }, ) async def get_bulk_download_item( + current_user: CurrentUserOrDefault, background_tasks: BackgroundTasks, bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"), ) -> FileResponse: - """Gets a bulk download zip file""" + """Gets a bulk download zip file. + + Requires authentication. 
The caller must be the user who initiated the + download (tracked by the bulk download service) or an admin. + """ try: + # Verify the caller owns this download (or is an admin) + owner = ApiDependencies.invoker.services.bulk_download.get_owner(bulk_download_item_name) + if owner is not None and owner != current_user.user_id and not current_user.is_admin: + raise HTTPException(status_code=403, detail="Not authorized to access this download") + path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name) response = FileResponse( @@ -574,6 +757,8 @@ async def get_bulk_download_item( response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}" background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name) return response + except HTTPException: + raise except Exception: raise HTTPException(status_code=404) @@ -594,6 +779,10 @@ async def get_image_names( ) -> ImageNamesResult: """Gets ordered list of image names with metadata for optimistic updates""" + # Validate that the caller can read from this board before listing its images. + if board_id is not None and board_id != "none": + _assert_board_read_access(board_id, current_user) + try: result = ApiDependencies.invoker.services.images.get_image_names( starred_first=starred_first, @@ -617,6 +806,7 @@ async def get_image_names( responses={200: {"model": list[ImageDTO]}}, ) async def get_images_by_names( + current_user: CurrentUserOrDefault, image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"), ) -> list[ImageDTO]: """Gets image DTOs for the specified image names. 
Maintains order of input names.""" @@ -628,8 +818,12 @@ async def get_images_by_names( image_dtos: list[ImageDTO] = [] for name in image_names: try: + _assert_image_read_access(name, current_user) dto = image_service.get_dto(name) image_dtos.append(dto) + except HTTPException: + # Skip images the user is not authorized to view + continue except Exception: # Skip missing images - they may have been deleted between name fetch and DTO fetch continue diff --git a/invokeai/app/api/routers/model_manager.py b/invokeai/app/api/routers/model_manager.py index a1f55a3b04..f351be11ad 100644 --- a/invokeai/app/api/routers/model_manager.py +++ b/invokeai/app/api/routers/model_manager.py @@ -889,7 +889,7 @@ async def install_hugging_face_model( "/install", operation_id="list_model_installs", ) -async def list_model_installs() -> List[ModelInstallJob]: +async def list_model_installs(current_admin: AdminUserOrDefault) -> List[ModelInstallJob]: """Return the list of model install jobs. Install jobs have a numeric `id`, a `status`, and other fields that provide information on @@ -921,7 +921,9 @@ async def list_model_installs() -> List[ModelInstallJob]: 404: {"description": "No such job"}, }, ) -async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob: +async def get_model_install_job( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install id") +) -> ModelInstallJob: """ Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs' for information on the format of the return value. 
@@ -964,7 +966,9 @@ async def cancel_model_install_job( }, status_code=201, ) -async def pause_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob: +async def pause_model_install_job( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID") +) -> ModelInstallJob: """Pause the model install job corresponding to the given job ID.""" installer = ApiDependencies.invoker.services.model_manager.install try: @@ -984,7 +988,9 @@ async def pause_model_install_job(id: int = Path(description="Model install job }, status_code=201, ) -async def resume_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob: +async def resume_model_install_job( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID") +) -> ModelInstallJob: """Resume a paused model install job corresponding to the given job ID.""" installer = ApiDependencies.invoker.services.model_manager.install try: @@ -1004,7 +1010,9 @@ async def resume_model_install_job(id: int = Path(description="Model install job }, status_code=201, ) -async def restart_failed_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob: +async def restart_failed_model_install_job( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID") +) -> ModelInstallJob: """Restart failed or non-resumable file downloads for the given job.""" installer = ApiDependencies.invoker.services.model_manager.install try: @@ -1025,6 +1033,7 @@ async def restart_failed_model_install_job(id: int = Path(description="Model ins status_code=201, ) async def restart_model_install_file( + current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID"), file_source: AnyHttpUrl = Body(description="File download URL to restart"), ) -> ModelInstallJob: @@ -1336,7 +1345,7 @@ class DeleteOrphanedModelsResponse(BaseModel): operation_id="get_orphaned_models", 
response_model=list[OrphanedModelInfo], ) -async def get_orphaned_models() -> list[OrphanedModelInfo]: +async def get_orphaned_models(_: AdminUserOrDefault) -> list[OrphanedModelInfo]: """Find orphaned model directories. Orphaned models are directories in the models folder that contain model files @@ -1363,7 +1372,9 @@ async def get_orphaned_models() -> list[OrphanedModelInfo]: operation_id="delete_orphaned_models", response_model=DeleteOrphanedModelsResponse, ) -async def delete_orphaned_models(request: DeleteOrphanedModelsRequest) -> DeleteOrphanedModelsResponse: +async def delete_orphaned_models( + request: DeleteOrphanedModelsRequest, _: AdminUserOrDefault +) -> DeleteOrphanedModelsResponse: """Delete specified orphaned model directories. Args: diff --git a/invokeai/app/api/routers/recall_parameters.py b/invokeai/app/api/routers/recall_parameters.py index 0af3fd29b0..ec08adba2e 100644 --- a/invokeai/app/api/routers/recall_parameters.py +++ b/invokeai/app/api/routers/recall_parameters.py @@ -7,6 +7,7 @@ from fastapi import Body, HTTPException, Path from fastapi.routing import APIRouter from pydantic import BaseModel, ConfigDict, Field +from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies from invokeai.backend.image_util.controlnet_processor import process_controlnet_image from invokeai.backend.model_manager.taxonomy import ModelType @@ -291,12 +292,58 @@ def resolve_ip_adapter_models(ip_adapters: list[IPAdapterRecallParameter]) -> li return resolved_adapters +def _assert_recall_image_access(parameters: "RecallParameter", current_user: CurrentUserOrDefault) -> None: + """Validate that the caller can read every image referenced in the recall parameters. + + Control layers and IP adapters may reference image_name fields. 
Without this + check an attacker who knows another user's image UUID could use the recall + endpoint to extract image dimensions and — for ControlNet preprocessors — mint + a derived processed image they can then fetch. + """ + from invokeai.app.services.board_records.board_records_common import BoardVisibility + + image_names: list[str] = [] + if parameters.control_layers: + for layer in parameters.control_layers: + if layer.image_name is not None: + image_names.append(layer.image_name) + if parameters.ip_adapters: + for adapter in parameters.ip_adapters: + if adapter.image_name is not None: + image_names.append(adapter.image_name) + + if not image_names: + return + + # Admin can access all images + if current_user.is_admin: + return + + for image_name in image_names: + owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name) + if owner is not None and owner == current_user.user_id: + continue + + # Check board visibility + board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) + if board_id is not None: + try: + board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id) + if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public): + continue + except Exception: + pass + + raise HTTPException(status_code=403, detail=f"Not authorized to access image {image_name}") + + @recall_parameters_router.post( "/{queue_id}", operation_id="update_recall_parameters", response_model=dict[str, Any], ) async def update_recall_parameters( + current_user: CurrentUserOrDefault, queue_id: str = Path(..., description="The queue id to perform this operation on"), parameters: RecallParameter = Body(..., description="Recall parameters to update"), ) -> dict[str, Any]: @@ -328,6 +375,10 @@ async def update_recall_parameters( """ logger = ApiDependencies.invoker.services.logger + # Validate image access before processing — prevents information leakage + # (dimensions) and derived-image minting 
via ControlNet preprocessors. + _assert_recall_image_access(parameters, current_user) + try: # Get only the parameters that were actually provided (non-None values) provided_params = {k: v for k, v in parameters.model_dump().items() if v is not None} @@ -335,14 +386,14 @@ async def update_recall_parameters( if not provided_params: return {"status": "no_parameters_provided", "updated_count": 0} - # Store each parameter in client state using a consistent key format + # Store each parameter in client state scoped to the current user updated_count = 0 for param_key, param_value in provided_params.items(): # Convert parameter values to JSON strings for storage value_str = json.dumps(param_value) try: ApiDependencies.invoker.services.client_state_persistence.set_by_key( - queue_id, f"recall_{param_key}", value_str + current_user.user_id, f"recall_{param_key}", value_str ) updated_count += 1 except Exception as e: @@ -396,7 +447,9 @@ async def update_recall_parameters( logger.info( f"Emitting recall_parameters_updated event for queue {queue_id} with {len(provided_params)} parameters" ) - ApiDependencies.invoker.services.events.emit_recall_parameters_updated(queue_id, provided_params) + ApiDependencies.invoker.services.events.emit_recall_parameters_updated( + queue_id, current_user.user_id, provided_params + ) logger.info("Successfully emitted recall_parameters_updated event") except Exception as e: logger.error(f"Error emitting recall parameters event: {e}", exc_info=True) @@ -425,6 +478,7 @@ async def update_recall_parameters( response_model=dict[str, Any], ) async def get_recall_parameters( + current_user: CurrentUserOrDefault, queue_id: str = Path(..., description="The queue id to retrieve parameters for"), ) -> dict[str, Any]: """ diff --git a/invokeai/app/api/routers/session_queue.py b/invokeai/app/api/routers/session_queue.py index 403e7727cb..41a5a411c7 100644 --- a/invokeai/app/api/routers/session_queue.py +++ b/invokeai/app/api/routers/session_queue.py @@ -44,7 
+44,8 @@ def sanitize_queue_item_for_user( """Sanitize queue item for non-admin users viewing other users' items. For non-admin users viewing queue items belonging to other users, - the field_values, session graph, and workflow should be hidden/cleared to protect privacy. + only timestamps, status, and error information are exposed. All other + fields (user identity, generation parameters, graphs, workflows) are stripped. Args: queue_item: The queue item to sanitize @@ -58,15 +59,25 @@ def sanitize_queue_item_for_user( if is_admin or queue_item.user_id == current_user_id: return queue_item - # For non-admins viewing other users' items, clear sensitive fields - # Create a shallow copy to avoid mutating the original + # For non-admins viewing other users' items, strip everything except + # item_id, queue_id, status, and timestamps sanitized_item = queue_item.model_copy(deep=False) + sanitized_item.user_id = "redacted" + sanitized_item.user_display_name = None + sanitized_item.user_email = None + sanitized_item.batch_id = "redacted" + sanitized_item.session_id = "redacted" + sanitized_item.origin = None + sanitized_item.destination = None + sanitized_item.priority = 0 sanitized_item.field_values = None + sanitized_item.retried_from_item_id = None sanitized_item.workflow = None - # Clear the session graph by replacing it with an empty graph execution state - # This prevents information leakage through the generation graph + sanitized_item.error_type = None + sanitized_item.error_message = None + sanitized_item.error_traceback = None sanitized_item.session = GraphExecutionState( - id=queue_item.session.id, + id="redacted", graph=Graph(), ) return sanitized_item @@ -126,12 +137,16 @@ async def list_all_queue_items( }, ) async def get_queue_item_ids( + current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"), ) 
-> ItemIdsResult: - """Gets all queue item ids that match the given parameters""" + """Gets all queue item ids that match the given parameters. Non-admin users only see their own items.""" try: - return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir) + user_id = None if current_user.is_admin else current_user.user_id + return ApiDependencies.invoker.services.session_queue.get_queue_item_ids( + queue_id=queue_id, order_dir=order_dir, user_id=user_id + ) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}") @@ -376,11 +391,15 @@ async def prune( }, ) async def get_current_queue_item( + current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), ) -> Optional[SessionQueueItem]: """Gets the currently execution queue item""" try: - return ApiDependencies.invoker.services.session_queue.get_current(queue_id) + item = ApiDependencies.invoker.services.session_queue.get_current(queue_id) + if item is not None: + item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin) + return item except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}") @@ -393,11 +412,15 @@ async def get_current_queue_item( }, ) async def get_next_queue_item( + current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), ) -> Optional[SessionQueueItem]: """Gets the next queue item, without executing it""" try: - return ApiDependencies.invoker.services.session_queue.get_next(queue_id) + item = ApiDependencies.invoker.services.session_queue.get_next(queue_id) + if item is not None: + item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin) + return item except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while 
getting next queue item: {e}") @@ -413,9 +436,10 @@ async def get_queue_status( current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), ) -> SessionQueueAndProcessorStatus: - """Gets the status of the session queue""" + """Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it.""" try: - queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=current_user.user_id) + user_id = None if current_user.is_admin else current_user.user_id + queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=user_id) processor = ApiDependencies.invoker.services.session_processor.get_status() return SessionQueueAndProcessorStatus(queue=queue, processor=processor) except Exception as e: @@ -430,12 +454,16 @@ async def get_queue_status( }, ) async def get_batch_status( + current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to perform this operation on"), batch_id: str = Path(description="The batch to get the status of"), ) -> BatchStatus: - """Gets the status of the session queue""" + """Gets the status of a batch. 
Non-admin users only see their own batches.""" try: - return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id) + user_id = None if current_user.is_admin else current_user.user_id + return ApiDependencies.invoker.services.session_queue.get_batch_status( + queue_id=queue_id, batch_id=batch_id, user_id=user_id + ) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}") @@ -529,13 +557,15 @@ async def cancel_queue_item( responses={200: {"model": SessionQueueCountsByDestination}}, ) async def counts_by_destination( + current_user: CurrentUserOrDefault, queue_id: str = Path(description="The queue id to query"), destination: str = Query(description="The destination to query"), ) -> SessionQueueCountsByDestination: - """Gets the counts of queue items by destination""" + """Gets the counts of queue items by destination. Non-admin users only see their own items.""" try: + user_id = None if current_user.is_admin else current_user.user_id return ApiDependencies.invoker.services.session_queue.get_counts_by_destination( - queue_id=queue_id, destination=destination + queue_id=queue_id, destination=destination, user_id=user_id ) except Exception as e: raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}") diff --git a/invokeai/app/api/routers/workflows.py b/invokeai/app/api/routers/workflows.py index 72d50a416b..1c88a77a3f 100644 --- a/invokeai/app/api/routers/workflows.py +++ b/invokeai/app/api/routers/workflows.py @@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, File, HTTPException, Path, Query, UploadFil from fastapi.responses import FileResponse from PIL import Image +from invokeai.app.api.auth_dependencies import CurrentUserOrDefault from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.services.shared.pagination import PaginatedResults from 
invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection @@ -33,16 +34,25 @@ workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"]) }, ) async def get_workflow( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to get"), ) -> WorkflowRecordWithThumbnailDTO: """Gets a workflow""" try: - thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id) workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id) - return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump()) except WorkflowNotFoundError: raise HTTPException(status_code=404, detail="Workflow not found") + config = ApiDependencies.invoker.services.configuration + if config.multiuser: + is_default = workflow.workflow.meta.category is WorkflowCategory.Default + is_owner = workflow.user_id == current_user.user_id + if not (is_default or is_owner or workflow.is_public or current_user.is_admin): + raise HTTPException(status_code=403, detail="Not authorized to access this workflow") + + thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id) + return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump()) + @workflows_router.patch( "/i/{workflow_id}", @@ -52,10 +62,21 @@ async def get_workflow( }, ) async def update_workflow( + current_user: CurrentUserOrDefault, workflow: Workflow = Body(description="The updated workflow", embed=True), ) -> WorkflowRecordDTO: """Updates a workflow""" - return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow) + config = ApiDependencies.invoker.services.configuration + if config.multiuser: + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow.id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + if not current_user.is_admin and existing.user_id != current_user.user_id: 
+ raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + # Pass user_id for defense-in-depth SQL scoping; admins pass None to allow any. + user_id = None if current_user.is_admin else current_user.user_id + return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow, user_id=user_id) @workflows_router.delete( @@ -63,15 +84,25 @@ async def update_workflow( operation_id="delete_workflow", ) async def delete_workflow( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to delete"), ) -> None: """Deletes a workflow""" + config = ApiDependencies.invoker.services.configuration + if config.multiuser: + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + if not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to delete this workflow") try: ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id) except WorkflowThumbnailFileNotFoundException: # It's OK if the workflow has no thumbnail file. We can still delete the workflow. 
pass - ApiDependencies.invoker.services.workflow_records.delete(workflow_id) + user_id = None if current_user.is_admin else current_user.user_id + ApiDependencies.invoker.services.workflow_records.delete(workflow_id, user_id=user_id) @workflows_router.post( @@ -82,10 +113,11 @@ async def delete_workflow( }, ) async def create_workflow( + current_user: CurrentUserOrDefault, workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True), ) -> WorkflowRecordDTO: """Creates a workflow""" - return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow) + return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow, user_id=current_user.user_id) @workflows_router.get( @@ -96,6 +128,7 @@ async def create_workflow( }, ) async def list_workflows( + current_user: CurrentUserOrDefault, page: int = Query(default=0, description="The page to get"), per_page: Optional[int] = Query(default=None, description="The number of workflows per page"), order_by: WorkflowRecordOrderBy = Query( @@ -106,8 +139,19 @@ async def list_workflows( tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"), query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"), has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]: """Gets a page of workflows""" + config = ApiDependencies.invoker.services.configuration + + # In multiuser mode, scope user-category workflows to the current user unless fetching shared workflows + user_id_filter: Optional[str] = None + if config.multiuser: + # Only filter 'user' category results by user_id when not explicitly listing public workflows + has_user_category = not categories or WorkflowCategory.User in 
categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id + workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = [] workflows = ApiDependencies.invoker.services.workflow_records.get_many( order_by=order_by, @@ -118,6 +162,8 @@ async def list_workflows( categories=categories, tags=tags, has_been_opened=has_been_opened, + user_id=user_id_filter, + is_public=is_public, ) for workflow in workflows.items: workflows_with_thumbnails.append( @@ -143,15 +189,20 @@ async def list_workflows( }, ) async def set_workflow_thumbnail( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to update"), image: UploadFile = File(description="The image file to upload"), ): """Sets a workflow's thumbnail image""" try: - ApiDependencies.invoker.services.workflow_records.get(workflow_id) + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) except WorkflowNotFoundError: raise HTTPException(status_code=404, detail="Workflow not found") + config = ApiDependencies.invoker.services.configuration + if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + if not image.content_type or not image.content_type.startswith("image"): raise HTTPException(status_code=415, detail="Not an image") @@ -177,14 +228,19 @@ async def set_workflow_thumbnail( }, ) async def delete_workflow_thumbnail( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to update"), ): """Removes a workflow's thumbnail image""" try: - ApiDependencies.invoker.services.workflow_records.get(workflow_id) + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) except WorkflowNotFoundError: raise HTTPException(status_code=404, detail="Workflow not found") + config = ApiDependencies.invoker.services.configuration + if 
config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + try: ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id) except ValueError as e: @@ -206,8 +262,12 @@ async def delete_workflow_thumbnail( async def get_workflow_thumbnail( workflow_id: str = Path(description="The id of the workflow thumbnail to get"), ) -> FileResponse: - """Gets a workflow's thumbnail image""" + """Gets a workflow's thumbnail image. + This endpoint is intentionally unauthenticated because browsers load images + via HTML image tags which cannot send Bearer tokens. Workflow IDs are UUIDs, + providing security through unguessability. + """ try: path = ApiDependencies.invoker.services.workflow_thumbnails.get_path(workflow_id) @@ -223,37 +283,91 @@ raise HTTPException(status_code=404) +@workflows_router.patch( + "/i/{workflow_id}/is_public", + operation_id="update_workflow_is_public", + responses={ + 200: {"model": WorkflowRecordDTO}, + }, +) +async def update_workflow_is_public( + current_user: CurrentUserOrDefault, + workflow_id: str = Path(description="The workflow to update"), + is_public: bool = Body(description="Whether the workflow should be shared publicly", embed=True), +) -> WorkflowRecordDTO: + """Updates whether a workflow is shared publicly""" + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + + config = ApiDependencies.invoker.services.configuration + if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + + user_id = None if current_user.is_admin else current_user.user_id + return ApiDependencies.invoker.services.workflow_records.update_is_public( + 
workflow_id=workflow_id, is_public=is_public, user_id=user_id + ) + + @workflows_router.get("/tags", operation_id="get_all_tags") async def get_all_tags( + current_user: CurrentUserOrDefault, categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> list[str]: """Gets all unique tags from workflows""" + config = ApiDependencies.invoker.services.configuration + user_id_filter: Optional[str] = None + if config.multiuser: + has_user_category = not categories or WorkflowCategory.User in categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id - return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories) + return ApiDependencies.invoker.services.workflow_records.get_all_tags( + categories=categories, user_id=user_id_filter, is_public=is_public + ) @workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag") async def get_counts_by_tag( + current_user: CurrentUserOrDefault, tags: list[str] = Query(description="The tags to get counts for"), categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"), has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> dict[str, int]: """Counts workflows by tag""" + config = ApiDependencies.invoker.services.configuration + user_id_filter: Optional[str] = None + if config.multiuser: + has_user_category = not categories or WorkflowCategory.User in categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id return ApiDependencies.invoker.services.workflow_records.counts_by_tag( - tags=tags, categories=categories, has_been_opened=has_been_opened + 
tags=tags, categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public ) @workflows_router.get("/counts_by_category", operation_id="counts_by_category") async def counts_by_category( + current_user: CurrentUserOrDefault, categories: list[WorkflowCategory] = Query(description="The categories to include"), has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"), + is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"), ) -> dict[str, int]: """Counts workflows by category""" + config = ApiDependencies.invoker.services.configuration + user_id_filter: Optional[str] = None + if config.multiuser: + has_user_category = WorkflowCategory.User in categories + if has_user_category and is_public is not True: + user_id_filter = current_user.user_id return ApiDependencies.invoker.services.workflow_records.counts_by_category( - categories=categories, has_been_opened=has_been_opened + categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public ) @@ -262,7 +376,18 @@ async def counts_by_category( operation_id="update_opened_at", ) async def update_opened_at( + current_user: CurrentUserOrDefault, workflow_id: str = Path(description="The workflow to update"), ) -> None: """Updates the opened_at field of a workflow""" - ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id) + try: + existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id) + except WorkflowNotFoundError: + raise HTTPException(status_code=404, detail="Workflow not found") + + config = ApiDependencies.invoker.services.configuration + if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id: + raise HTTPException(status_code=403, detail="Not authorized to update this workflow") + + user_id = None if current_user.is_admin else current_user.user_id + 
ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id, user_id=user_id) diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index fcead54eb1..5783b804c0 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -121,6 +121,11 @@ class SocketIO: Returns True to accept the connection, False to reject it. Stores user_id in the internal socket users dict for later use. + + In multiuser mode, connections without a valid token are rejected outright + so that anonymous clients cannot subscribe to queue rooms and observe + queue activity belonging to other users. In single-user mode, unauthenticated + connections are accepted as the system admin user. """ # Extract token from auth data or headers token = None @@ -137,6 +142,23 @@ class SocketIO: if token: token_data = verify_token(token) if token_data: + # In multiuser mode, also verify the backing user record still + # exists and is active — mirrors the REST auth check in + # auth_dependencies.py. A deleted or deactivated user whose + # JWT has not yet expired must not be allowed to open a socket. + if self._is_multiuser_enabled(): + try: + from invokeai.app.api.dependencies import ApiDependencies + + user = ApiDependencies.invoker.services.users.get(token_data.user_id) + if user is None or not user.is_active: + logger.warning(f"Rejecting socket {sid}: user {token_data.user_id} not found or inactive") + return False + except Exception: + # If user service is unavailable, fail closed + logger.warning(f"Rejecting socket {sid}: unable to verify user record") + return False + # Store user_id and is_admin in socket users dict self._socket_users[sid] = { "user_id": token_data.user_id, @@ -147,14 +169,37 @@ class SocketIO: ) return True - # If no valid token, store system user for backward compatibility + # No valid token provided. In multiuser mode this is not allowed — reject + # the connection so anonymous clients cannot subscribe to queue rooms. 
+ # In single-user mode, fall through and accept the socket as system admin. + if self._is_multiuser_enabled(): + logger.warning( + f"Rejecting socket {sid} connection: multiuser mode is enabled and no valid auth token was provided" + ) + return False + self._socket_users[sid] = { "user_id": "system", - "is_admin": False, + "is_admin": True, } - logger.debug(f"Socket {sid} connected as system user (no valid token)") + logger.debug(f"Socket {sid} connected as system admin (single-user mode)") return True + @staticmethod + def _is_multiuser_enabled() -> bool: + """Check whether multiuser mode is enabled. Fails closed if configuration + is not yet initialized, which should not happen in practice but prevents + accidentally opening the socket during startup races.""" + try: + # Imported here to avoid a circular import at module load time. + from invokeai.app.api.dependencies import ApiDependencies + + return bool(ApiDependencies.invoker.services.configuration.multiuser) + except Exception: + # If dependencies are not initialized, fail closed (treat as multiuser) + # so we never accidentally admit an anonymous socket. + return True + async def _handle_disconnect(self, sid: str) -> None: """Handle socket disconnection and cleanup user info.""" if sid in self._socket_users: @@ -165,15 +210,20 @@ class SocketIO: """Handle queue subscription and add socket to both queue and user-specific rooms.""" queue_id = QueueSubscriptionEvent(**data).queue_id - # Check if we have user info for this socket + # Check if we have user info for this socket. In multiuser mode _handle_connect + # will have already rejected any socket without a valid token, so missing user + # info here is a bug — refuse the subscription rather than silently falling back + # to an anonymous system user who could then receive queue item events. 
if sid not in self._socket_users: - logger.warning( - f"Socket {sid} subscribing to queue {queue_id} but has no user info - need to authenticate via connect event" - ) - # Store as system user temporarily - real auth should happen in connect + if self._is_multiuser_enabled(): + logger.warning( + f"Refusing queue subscription for socket {sid}: no user info (socket not authenticated via connect event)" + ) + return + # Single-user mode: safe to fall back to the system admin user. self._socket_users[sid] = { "user_id": "system", - "is_admin": False, + "is_admin": True, } user_id = self._socket_users[sid]["user_id"] @@ -198,6 +248,13 @@ class SocketIO: await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id) async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None: + # In multiuser mode, only allow authenticated sockets to subscribe. + # Bulk download events are routed to user-specific rooms, so the + # bulk_download_id room subscription is only kept for single-user + # backward compatibility. + if self._is_multiuser_enabled() and sid not in self._socket_users: + logger.warning(f"Refusing bulk download subscription for unknown socket {sid} in multiuser mode") + return await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id) async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None: @@ -206,9 +263,17 @@ class SocketIO: async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): """Handle queue events with user isolation. - Invocation events (progress, started, complete) are private - only emit to owner and admins. - Queue item status events are public - emit to all users (field values hidden via API). - Other queue events emit to all subscribers. + All queue item events (invocation events AND QueueItemStatusChangedEvent) are + private to the owning user and admins. 
They carry unsanitized user_id, batch_id, + session_id, origin, destination and error metadata, and must never be broadcast + to the whole queue room — otherwise any other authenticated subscriber could + observe cross-user queue activity. + + RecallParametersUpdatedEvent is also private to the owner + admins. + + BatchEnqueuedEvent carries the enqueuing user's batch_id/origin/counts and + is also routed privately. QueueClearedEvent is the only queue event that + is still broadcast to the whole queue room. IMPORTANT: Check InvocationEventBase BEFORE QueueItemEventBase since InvocationEventBase inherits from QueueItemEventBase. The order of isinstance checks matters! @@ -237,24 +302,40 @@ class SocketIO: logger.debug(f"Emitted private invocation event {event_name} to user room {user_room} and admin room") - # Queue item status events are visible to all users (field values masked via API) - # This catches QueueItemStatusChangedEvent but NOT InvocationEvents (already handled above) + # Other queue item events (QueueItemStatusChangedEvent) carry unsanitized + # user_id, batch_id, session_id, origin, destination and error metadata. + # They are private to the owning user + admins — never broadcast to the + # full queue room. 
elif isinstance(event_data, QueueItemEventBase) and hasattr(event_data, "user_id"): - # Emit to all subscribers in the queue - await self._sio.emit( - event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id - ) + user_room = f"user:{event_data.user_id}" + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") - logger.info( - f"Emitted public queue item event {event_name} to all subscribers in queue {event_data.queue_id}" - ) + logger.debug(f"Emitted private queue item event {event_name} to user room {user_room} and admin room") + + # RecallParametersUpdatedEvent is private - only emit to owner + admins + elif isinstance(event_data, RecallParametersUpdatedEvent): + user_room = f"user:{event_data.user_id}" + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") + logger.debug(f"Emitted private recall_parameters_updated event to user room {user_room} and admin room") + + # BatchEnqueuedEvent carries the enqueuing user's batch_id, origin, and + # enqueued counts. Route it privately to the owner + admins so other + # users do not observe cross-user batch activity. + elif isinstance(event_data, BatchEnqueuedEvent): + user_room = f"user:{event_data.user_id}" + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") + logger.debug(f"Emitted private batch_enqueued event to user room {user_room} and admin room") else: - # For other queue events (like QueueClearedEvent, BatchEnqueuedEvent), emit to all subscribers + # For remaining queue events (e.g. 
QueueClearedEvent) that do not + # carry user identity, emit to all subscribers in the queue room. await self._sio.emit( event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id ) - logger.info( + logger.debug( f"Emitted general queue event {event_name} to all subscribers in queue {event_data.queue_id}" ) except Exception as e: @@ -265,4 +346,17 @@ class SocketIO: await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json")) async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None: - await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id) + event_name, event_data = event + # Route to user-specific + admin rooms so that other authenticated + # users cannot learn the bulk_download_item_name (the capability token + # needed to fetch the zip from the unauthenticated GET endpoint). + # In single-user mode (user_id="system"), fall back to the shared + # bulk_download_id room for backward compatibility. 
+ if hasattr(event_data, "user_id") and event_data.user_id != "system": + user_room = f"user:{event_data.user_id}" + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room) + await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin") + else: + await self._sio.emit( + event=event_name, data=event_data.model_dump(mode="json"), room=event_data.bulk_download_id + ) diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 71b99d6687..fbe0e9a615 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -171,6 +171,8 @@ class FieldDescriptions: sd3_model = "SD3 model (MMDiTX) to load" cogview4_model = "CogView4 model (Transformer) to load" z_image_model = "Z-Image model (Transformer) to load" + qwen_image_model = "Qwen Image Edit model (Transformer) to load" + qwen_vl_encoder = "Qwen2.5-VL tokenizer, processor and text/vision encoder" sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" @@ -340,6 +342,12 @@ class ZImageConditioningField(BaseModel): ) +class QwenImageConditioningField(BaseModel): + """A Qwen Image Edit conditioning tensor primitive value""" + + conditioning_name: str = Field(description="The name of conditioning tensor") + + class AnimaConditioningField(BaseModel): """An Anima conditioning tensor primitive value. 
diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index 29e8b3d69b..da24d8802b 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -166,6 +166,10 @@ GENERATION_MODES = Literal[ "z_image_img2img", "z_image_inpaint", "z_image_outpaint", + "qwen_image_txt2img", + "qwen_image_img2img", + "qwen_image_inpaint", + "qwen_image_outpaint", "anima_txt2img", "anima_img2img", "anima_inpaint", diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 29fbe5100c..6b5afb5529 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -72,6 +72,13 @@ class GlmEncoderField(BaseModel): text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel") +class QwenVLEncoderField(BaseModel): + """Field for Qwen2.5-VL encoder used by Qwen Image Edit models.""" + + tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel") + text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel") + + class Qwen3EncoderField(BaseModel): """Field for Qwen3 text encoder used by Z-Image models.""" diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index 2f404d16ba..7ec6c3dc14 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -25,6 +25,7 @@ from invokeai.app.invocations.fields import ( InputField, LatentsField, OutputField, + QwenImageConditioningField, SD3ConditioningField, TensorField, UIComponent, @@ -474,6 +475,17 @@ class ZImageConditioningOutput(BaseInvocationOutput): return cls(conditioning=ZImageConditioningField(conditioning_name=conditioning_name)) +@invocation_output("qwen_image_conditioning_output") +class QwenImageConditioningOutput(BaseInvocationOutput): + """Base class for nodes that output a Qwen Image Edit conditioning tensor.""" + + conditioning: 
QwenImageConditioningField = OutputField(description=FieldDescriptions.cond) + + @classmethod + def build(cls, conditioning_name: str) -> "QwenImageConditioningOutput": + return cls(conditioning=QwenImageConditioningField(conditioning_name=conditioning_name)) + + @invocation_output("anima_conditioning_output") class AnimaConditioningOutput(BaseInvocationOutput): """Base class for nodes that output an Anima text conditioning tensor.""" diff --git a/invokeai/app/invocations/qwen_image_denoise.py b/invokeai/app/invocations/qwen_image_denoise.py new file mode 100644 index 0000000000..04e21a26c3 --- /dev/null +++ b/invokeai/app/invocations/qwen_image_denoise.py @@ -0,0 +1,490 @@ +from contextlib import ExitStack +from typing import Callable, Iterator, Optional, Tuple + +import torch +import torchvision.transforms as tv_transforms +from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel +from torchvision.transforms.functional import resize as tv_resize +from tqdm import tqdm + +from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.invocations.fields import ( + DenoiseMaskField, + FieldDescriptions, + Input, + InputField, + LatentsField, + QwenImageConditioningField, + WithBoard, + WithMetadata, +) +from invokeai.app.invocations.model import TransformerField +from invokeai.app.invocations.primitives import LatentsOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat +from invokeai.backend.patches.layer_patcher import LayerPatcher +from invokeai.backend.patches.lora_conversions.qwen_image_lora_constants import ( + QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX, +) +from invokeai.backend.patches.model_patch_raw import ModelPatchRaw +from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import 
RectifiedFlowInpaintExtension +from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import QwenImageConditioningInfo +from invokeai.backend.util.devices import TorchDevice + + +@invocation( + "qwen_image_denoise", + title="Denoise - Qwen Image", + tags=["image", "qwen_image"], + category="image", + version="1.0.0", + classification=Classification.Prototype, +) +class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): + """Run the denoising process with a Qwen Image model.""" + + # If latents is provided, this means we are doing image-to-image. + latents: Optional[LatentsField] = InputField( + default=None, description=FieldDescriptions.latents, input=Input.Connection + ) + # Reference image latents (encoded through VAE) to concatenate with noisy latents. + reference_latents: Optional[LatentsField] = InputField( + default=None, + description="Reference image latents to guide generation. Encoded through the VAE.", + input=Input.Connection, + ) + # denoise_mask is used for image-to-image inpainting. Only the masked region is modified. 
+ denoise_mask: Optional[DenoiseMaskField] = InputField( + default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection + ) + denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start) + denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) + transformer: TransformerField = InputField( + description=FieldDescriptions.qwen_image_model, input=Input.Connection, title="Transformer" + ) + positive_conditioning: QwenImageConditioningField = InputField( + description=FieldDescriptions.positive_cond, input=Input.Connection + ) + negative_conditioning: Optional[QwenImageConditioningField] = InputField( + default=None, description=FieldDescriptions.negative_cond, input=Input.Connection + ) + cfg_scale: float | list[float] = InputField(default=4.0, description=FieldDescriptions.cfg_scale, title="CFG Scale") + width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.") + height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.") + steps: int = InputField(default=40, gt=0, description=FieldDescriptions.steps) + seed: int = InputField(default=0, description="Randomness seed for reproducibility.") + shift: Optional[float] = InputField( + default=None, + description="Override the sigma schedule shift. " + "When set, uses a fixed shift (e.g. 3.0 for Lightning LoRAs) instead of the default dynamic shifting. 
" + "Leave unset for the base model's default schedule.", + ) + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> LatentsOutput: + latents = self._run_diffusion(context) + latents = latents.detach().to("cpu") + + name = context.tensors.save(tensor=latents) + return LatentsOutput.build(latents_name=name, latents=latents, seed=None) + + def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None: + if self.denoise_mask is None: + return None + mask = context.tensors.load(self.denoise_mask.mask_name) + mask = 1.0 - mask + + _, _, latent_height, latent_width = latents.shape + mask = tv_resize( + img=mask, + size=[latent_height, latent_width], + interpolation=tv_transforms.InterpolationMode.BILINEAR, + antialias=False, + ) + + mask = mask.to(device=latents.device, dtype=latents.dtype) + return mask + + def _load_text_conditioning( + self, + context: InvocationContext, + conditioning_name: str, + dtype: torch.dtype, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + cond_data = context.conditioning.load(conditioning_name) + assert len(cond_data.conditionings) == 1 + conditioning = cond_data.conditionings[0] + assert isinstance(conditioning, QwenImageConditioningInfo) + conditioning = conditioning.to(dtype=dtype, device=device) + return conditioning.prompt_embeds, conditioning.prompt_embeds_mask + + def _get_noise( + self, + batch_size: int, + num_channels_latents: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + seed: int, + ) -> torch.Tensor: + rand_device = "cpu" + rand_dtype = torch.float32 + + return torch.randn( + batch_size, + num_channels_latents, + int(height) // LATENT_SCALE_FACTOR, + int(width) // LATENT_SCALE_FACTOR, + device=rand_device, + dtype=rand_dtype, + generator=torch.Generator(device=rand_device).manual_seed(seed), + ).to(device=device, dtype=dtype) + + def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]: + if 
isinstance(self.cfg_scale, float): + cfg_scale = [self.cfg_scale] * num_timesteps + elif isinstance(self.cfg_scale, list): + assert len(self.cfg_scale) == num_timesteps + cfg_scale = self.cfg_scale + else: + raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}") + return cfg_scale + + @staticmethod + def _pack_latents( + latents: torch.Tensor, batch_size: int, num_channels: int, height: int, width: int + ) -> torch.Tensor: + """Pack 4D latents (B, C, H, W) into 2x2-patched 3D (B, H/2*W/2, C*4).""" + latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4) + return latents + + @staticmethod + def _unpack_latents(latents: torch.Tensor, height: int, width: int) -> torch.Tensor: + """Unpack 3D patched latents (B, seq, C*4) back to 4D (B, C, H, W).""" + batch_size, _num_patches, channels = latents.shape + # height/width are in latent space; they must be divisible by 2 for packing + h = 2 * (height // 2) + w = 2 * (width // 2) + latents = latents.view(batch_size, h // 2, w // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + latents = latents.reshape(batch_size, channels // 4, h, w) + return latents + + def _run_diffusion(self, context: InvocationContext): + inference_dtype = torch.bfloat16 + device = TorchDevice.choose_torch_device() + + transformer_info = context.models.load(self.transformer.transformer) + assert isinstance(transformer_info.model, QwenImageTransformer2DModel) + + # Load conditioning + pos_prompt_embeds, pos_prompt_mask = self._load_text_conditioning( + context=context, + conditioning_name=self.positive_conditioning.conditioning_name, + dtype=inference_dtype, + device=device, + ) + + neg_prompt_embeds = None + neg_prompt_mask = None + # Match the diffusers pipeline: only enable CFG when cfg_scale > 1 AND negative conditioning is provided. 
+ # With cfg_scale <= 1, the negative prediction is unused, so skip it entirely. + # For per-step arrays, enable CFG if any step has scale > 1. + if isinstance(self.cfg_scale, list): + any_cfg_above_one = any(v > 1.0 for v in self.cfg_scale) + else: + any_cfg_above_one = self.cfg_scale > 1.0 + do_classifier_free_guidance = self.negative_conditioning is not None and any_cfg_above_one + if do_classifier_free_guidance: + neg_prompt_embeds, neg_prompt_mask = self._load_text_conditioning( + context=context, + conditioning_name=self.negative_conditioning.conditioning_name, + dtype=inference_dtype, + device=device, + ) + + # Prepare the timestep / sigma schedule + patch_size = transformer_info.model.config.patch_size + assert isinstance(patch_size, int) + # Output channels is 16 (the actual latent channels) + out_channels = transformer_info.model.config.out_channels + assert isinstance(out_channels, int) + + latent_height = self.height // LATENT_SCALE_FACTOR + latent_width = self.width // LATENT_SCALE_FACTOR + image_seq_len = (latent_height * latent_width) // (patch_size**2) + + # Use the actual FlowMatchEulerDiscreteScheduler to compute sigmas/timesteps, + # exactly matching the diffusers pipeline. + import math + + import numpy as np + from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler + + # Try to load the scheduler config from the model's directory (Diffusers models + # have a scheduler/ subdir). For GGUF models this path doesn't exist, so fall + # back to instantiating the scheduler with the known Qwen Image defaults. 
+ model_path = context.models.get_absolute_path(context.models.get_config(self.transformer.transformer)) + scheduler_path = model_path / "scheduler" + if scheduler_path.is_dir() and (scheduler_path / "scheduler_config.json").exists(): + scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(str(scheduler_path), local_files_only=True) + else: + scheduler = FlowMatchEulerDiscreteScheduler( + use_dynamic_shifting=True, + base_shift=0.5, + max_shift=0.9, + base_image_seq_len=256, + max_image_seq_len=8192, + shift_terminal=0.02, + num_train_timesteps=1000, + time_shift_type="exponential", + ) + + if self.shift is not None: + # Lightning LoRA: fixed shift + mu = math.log(self.shift) + else: + # Default dynamic shifting + # Linear interpolation matching diffusers' calculate_shift + base_shift = scheduler.config.get("base_shift", 0.5) + max_shift = scheduler.config.get("max_shift", 0.9) + base_seq = scheduler.config.get("base_image_seq_len", 256) + max_seq = scheduler.config.get("max_image_seq_len", 4096) + m = (max_shift - base_shift) / (max_seq - base_seq) + b = base_shift - m * base_seq + mu = image_seq_len * m + b + + init_sigmas = np.linspace(1.0, 1.0 / self.steps, self.steps).tolist() + scheduler.set_timesteps(sigmas=init_sigmas, mu=mu, device=device) + + # Clip the schedule based on denoising_start/denoising_end to support img2img strength. + # The scheduler's sigmas go from high (noisy) to 0 (clean). We clip to the fractional range. 
+ sigmas_sched = scheduler.sigmas # (N+1,) including terminal 0 + if self.denoising_start > 0 or self.denoising_end < 1: + total_sigmas = len(sigmas_sched) - 1 # exclude terminal + start_idx = int(round(self.denoising_start * total_sigmas)) + end_idx = int(round(self.denoising_end * total_sigmas)) + sigmas_sched = sigmas_sched[start_idx : end_idx + 1] # +1 to include the next sigma for dt + # Rebuild timesteps from clipped sigmas (exclude terminal 0) + timesteps_sched = sigmas_sched[:-1] * scheduler.config.num_train_timesteps + else: + timesteps_sched = scheduler.timesteps + + total_steps = len(timesteps_sched) + + cfg_scale = self._prepare_cfg_scale(total_steps) + + # Load initial latents if provided (for img2img) + init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None + if init_latents is not None: + init_latents = init_latents.to(device=device, dtype=inference_dtype) + if init_latents.dim() == 5: + init_latents = init_latents.squeeze(2) + + # Load reference image latents if provided + ref_latents = None + if self.reference_latents is not None: + ref_latents = context.tensors.load(self.reference_latents.latents_name) + ref_latents = ref_latents.to(device=device, dtype=inference_dtype) + # The VAE encoder produces 5D latents (B, C, 1, H, W); squeeze the frame dim + # so we have 4D (B, C, H, W) for packing. 
+ if ref_latents.dim() == 5: + ref_latents = ref_latents.squeeze(2) + + # Generate noise (16 channels - the output latent channels) + noise = self._get_noise( + batch_size=1, + num_channels_latents=out_channels, + height=self.height, + width=self.width, + dtype=inference_dtype, + device=device, + seed=self.seed, + ) + + # Prepare input latent image + if init_latents is not None: + s_0 = sigmas_sched[0].item() + latents = s_0 * noise + (1.0 - s_0) * init_latents + else: + if self.denoising_start > 1e-5: + raise ValueError("denoising_start should be 0 when initial latents are not provided.") + latents = noise + + if total_steps <= 0: + return latents + + # Pack latents into 2x2 patches: (B, C, H, W) -> (B, H/2*W/2, C*4) + latents = self._pack_latents(latents, 1, out_channels, latent_height, latent_width) + + # Determine whether the model uses reference latent conditioning (zero_cond_t). + # Edit models (zero_cond_t=True) expect [noisy_patches ; ref_patches] in the sequence. + # Txt2img models (zero_cond_t=False) only take noisy patches. + has_zero_cond_t = getattr(transformer_info.model, "zero_cond_t", False) or getattr( + transformer_info.model.config, "zero_cond_t", False + ) + use_ref_latents = has_zero_cond_t + + ref_latents_packed = None + if use_ref_latents: + if ref_latents is not None: + _, ref_ch, rh, rw = ref_latents.shape + if rh != latent_height or rw != latent_width: + ref_latents = torch.nn.functional.interpolate( + ref_latents, size=(latent_height, latent_width), mode="bilinear" + ) + else: + # No reference image provided — use zeros so the model still gets the + # expected sequence layout. + ref_latents = torch.zeros( + 1, out_channels, latent_height, latent_width, device=device, dtype=inference_dtype + ) + ref_latents_packed = self._pack_latents(ref_latents, 1, out_channels, latent_height, latent_width) + + # img_shapes tells the transformer the spatial layout of patches. 
+ if use_ref_latents: + img_shapes = [ + [ + (1, latent_height // 2, latent_width // 2), + (1, latent_height // 2, latent_width // 2), + ] + ] + else: + img_shapes = [ + [ + (1, latent_height // 2, latent_width // 2), + ] + ] + + # Prepare inpaint extension (operates in 4D space, so unpack/repack around it) + inpaint_mask = self._prep_inpaint_mask(context, noise) # noise has the right 4D shape + inpaint_extension: RectifiedFlowInpaintExtension | None = None + if inpaint_mask is not None: + assert init_latents is not None + inpaint_extension = RectifiedFlowInpaintExtension( + init_latents=init_latents, + inpaint_mask=inpaint_mask, + noise=noise, + ) + + step_callback = self._build_step_callback(context) + + step_callback( + PipelineIntermediateState( + step=0, + order=1, + total_steps=total_steps, + timestep=int(timesteps_sched[0].item()) if len(timesteps_sched) > 0 else 0, + latents=self._unpack_latents(latents, latent_height, latent_width), + ), + ) + + noisy_seq_len = latents.shape[1] + + # Determine if the model is quantized — GGUF models need sidecar patching for LoRAs + transformer_config = context.models.get_config(self.transformer.transformer) + model_is_quantized = transformer_config.format in (ModelFormat.GGUFQuantized,) + + with ExitStack() as exit_stack: + (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device()) + assert isinstance(transformer, QwenImageTransformer2DModel) + + # Apply LoRA patches to the transformer + exit_stack.enter_context( + LayerPatcher.apply_smart_model_patches( + model=transformer, + patches=self._lora_iterator(context), + prefix=QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX, + dtype=inference_dtype, + cached_weights=cached_weights, + force_sidecar_patching=model_is_quantized, + ) + ) + + for step_idx, t in enumerate(tqdm(timesteps_sched)): + # The pipeline passes timestep / 1000 to the transformer + timestep = t.expand(latents.shape[0]).to(inference_dtype) + + # For edit models: concatenate noisy and 
reference patches along the sequence dim + # For txt2img models: just use noisy patches + if ref_latents_packed is not None: + model_input = torch.cat([latents, ref_latents_packed], dim=1) + else: + model_input = latents + + noise_pred_cond = transformer( + hidden_states=model_input, + encoder_hidden_states=pos_prompt_embeds, + encoder_hidden_states_mask=pos_prompt_mask, + timestep=timestep / 1000, + img_shapes=img_shapes, + return_dict=False, + )[0] + # Only keep the noisy-latent portion of the output + noise_pred_cond = noise_pred_cond[:, :noisy_seq_len] + + if do_classifier_free_guidance and neg_prompt_embeds is not None: + noise_pred_uncond = transformer( + hidden_states=model_input, + encoder_hidden_states=neg_prompt_embeds, + encoder_hidden_states_mask=neg_prompt_mask, + timestep=timestep / 1000, + img_shapes=img_shapes, + return_dict=False, + )[0] + noise_pred_uncond = noise_pred_uncond[:, :noisy_seq_len] + + noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond) + else: + noise_pred = noise_pred_cond + + # Euler step using the (possibly clipped) sigma schedule + sigma_curr = sigmas_sched[step_idx] + sigma_next = sigmas_sched[step_idx + 1] + dt = sigma_next - sigma_curr + latents = latents.to(torch.float32) + dt * noise_pred.to(torch.float32) + latents = latents.to(inference_dtype) + + if inpaint_extension is not None: + sigma_next = sigmas_sched[step_idx + 1].item() + latents_4d = self._unpack_latents(latents, latent_height, latent_width) + latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents(latents_4d, sigma_next) + latents = self._pack_latents(latents_4d, 1, out_channels, latent_height, latent_width) + + step_callback( + PipelineIntermediateState( + step=step_idx + 1, + order=1, + total_steps=total_steps, + timestep=int(t.item()), + latents=self._unpack_latents(latents, latent_height, latent_width), + ), + ) + + # Unpack back to 4D then add frame dim for the video-style VAE: (B, C, 1, H, W) + 
latents = self._unpack_latents(latents, latent_height, latent_width) + latents = latents.unsqueeze(2) + return latents + + def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]: + def step_callback(state: PipelineIntermediateState) -> None: + context.util.sd_step_callback(state, BaseModelType.QwenImage) + + return step_callback + + def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]: + """Iterate over LoRA models to apply to the transformer.""" + for lora in self.transformer.loras: + lora_info = context.models.load(lora.lora) + if not isinstance(lora_info.model, ModelPatchRaw): + raise TypeError( + f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}." + ) + yield (lora_info.model, lora.weight) + del lora_info diff --git a/invokeai/app/invocations/qwen_image_image_to_latents.py b/invokeai/app/invocations/qwen_image_image_to_latents.py new file mode 100644 index 0000000000..c5fe1b5d5c --- /dev/null +++ b/invokeai/app/invocations/qwen_image_image_to_latents.py @@ -0,0 +1,96 @@ +import einops +import torch +from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage +from PIL import Image as PILImage + +from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation +from invokeai.app.invocations.fields import ( + FieldDescriptions, + ImageField, + Input, + InputField, + WithBoard, + WithMetadata, +) +from invokeai.app.invocations.model import VAEField +from invokeai.app.invocations.primitives import LatentsOutput +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.load.load_base import LoadedModel +from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor +from invokeai.backend.util.devices import TorchDevice + + +@invocation( + "qwen_image_i2l", + title="Image to Latents - 
@invocation(
    "qwen_image_i2l",
    title="Image to Latents - Qwen Image",
    tags=["image", "latents", "vae", "i2l", "qwen_image"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class QwenImageImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates latents from an image using the Qwen Image VAE."""

    image: ImageField = InputField(description="The image to encode.")
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
    width: int | None = InputField(
        default=None,
        description="Resize the image to this width before encoding. If not set, encodes at the image's original size.",
    )
    height: int | None = InputField(
        default=None,
        description="Resize the image to this height before encoding. If not set, encodes at the image's original size.",
    )

    @staticmethod
    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
        """Encode an image tensor into normalized Qwen Image latents.

        Uses the deterministic mode of the posterior (matching diffusers) and
        normalizes with the VAE's per-channel latents_mean / latents_std.
        """
        with vae_info.model_on_device() as (_, vae):
            assert isinstance(vae, AutoencoderKLQwenImage)
            vae.disable_tiling()

            pixels = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
            # The Qwen Image VAE expects 5D input (B, C, num_frames, H, W);
            # insert a singleton frame axis for still images.
            if pixels.dim() == 4:
                pixels = pixels.unsqueeze(2)

            with torch.inference_mode():
                posterior = vae.encode(pixels).latent_dist
                latents: torch.Tensor = posterior.mode().to(dtype=vae.dtype)

                stat_shape = (1, vae.config.z_dim, 1, 1, 1)
                mean = torch.tensor(vae.config.latents_mean).view(stat_shape).to(latents.device, latents.dtype)
                std = torch.tensor(vae.config.latents_std).view(stat_shape).to(latents.device, latents.dtype)
                latents = (latents - mean) / std

            return latents

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Load the image, optionally resize it in pixel space, encode, and save latents."""
        pil_image = context.images.get_pil(self.image.image_name)

        # Resize BEFORE encoding when target dimensions are given, matching the
        # diffusers pipeline which resizes in pixel space rather than latent space.
        if self.width is not None and self.height is not None:
            pil_image = pil_image.convert("RGB").resize((self.width, self.height), resample=PILImage.LANCZOS)

        tensor = image_resized_to_grid_as_tensor(pil_image.convert("RGB"))
        if tensor.dim() == 3:
            tensor = einops.rearrange(tensor, "c h w -> 1 c h w")

        vae_info = context.models.load(self.vae.vae)
        latents = self.vae_encode(vae_info=vae_info, image_tensor=tensor).to("cpu")

        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
@invocation(
    "qwen_image_l2i",
    title="Latents to Image - Qwen Image",
    tags=["latents", "image", "vae", "l2i", "qwen_image"],
    category="latents",
    version="1.0.0",
    classification=Classification.Prototype,
)
class QwenImageLatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an image from latents using the Qwen Image VAE."""

    latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Decode Qwen Image latents into a PIL image.

        The Qwen Image VAE uses per-channel latents_mean / latents_std instead of a
        single scaling_factor, so latents are denormalized as ``z * std + mean``
        before decoding. Latents are 5D (B, C, num_frames, H, W) — the unpack from
        the denoise step already produces this shape — and the singleton temporal
        axis is dropped after decoding.
        """
        latents = context.tensors.load(self.latents.latents_name)

        vae_info = context.models.load(self.vae.vae)
        assert isinstance(vae_info.model, AutoencoderKLQwenImage)
        with (
            SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
            vae_info.model_on_device() as (_, vae),
        ):
            context.util.signal_progress("Running VAE")
            assert isinstance(vae, AutoencoderKLQwenImage)
            latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)

            vae.disable_tiling()

            TorchDevice.empty_cache()

            with torch.inference_mode():
                # Denormalize directly with the per-channel statistics.
                # (Previously this divided by a stored reciprocal of the std,
                # which added a needless extra rounding step and hid the intent.)
                stat_shape = (1, vae.config.z_dim, 1, 1, 1)
                latents_mean = (
                    torch.tensor(vae.config.latents_mean).view(stat_shape).to(latents.device, latents.dtype)
                )
                latents_std = (
                    torch.tensor(vae.config.latents_std).view(stat_shape).to(latents.device, latents.dtype)
                )
                latents = latents * latents_std + latents_mean

                img = vae.decode(latents, return_dict=False)[0]
                # Drop the temporal frame dimension: (B, C, 1, H, W) -> (B, C, H, W)
                img = img[:, :, 0]

            img = img.clamp(-1, 1)
            img = rearrange(img[0], "c h w -> h w c")
            # Map [-1, 1] -> [0, 255] uint8.
            img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

            TorchDevice.empty_cache()

            image_dto = context.images.save(image=img_pil)

        return ImageOutput.build(image_dto)
classification=Classification.Prototype, +) +class QwenImageLoRALoaderInvocation(BaseInvocation): + """Apply a LoRA model to a Qwen Image transformer.""" + + lora: ModelIdentifierField = InputField( + description=FieldDescriptions.lora_model, + title="LoRA", + ui_model_base=BaseModelType.QwenImage, + ui_model_type=ModelType.LoRA, + ) + weight: float = InputField(default=1.0, description=FieldDescriptions.lora_weight) + transformer: TransformerField | None = InputField( + default=None, + description=FieldDescriptions.transformer, + input=Input.Connection, + title="Transformer", + ) + + def invoke(self, context: InvocationContext) -> QwenImageLoRALoaderOutput: + lora_key = self.lora.key + + if not context.models.exists(lora_key): + raise ValueError(f"Unknown lora: {lora_key}!") + + if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras): + raise ValueError(f'LoRA "{lora_key}" already applied to transformer.') + + output = QwenImageLoRALoaderOutput() + + if self.transformer is not None: + output.transformer = self.transformer.model_copy(deep=True) + output.transformer.loras.append( + LoRAField( + lora=self.lora, + weight=self.weight, + ) + ) + + return output + + +@invocation( + "qwen_image_lora_collection_loader", + title="Apply LoRA Collection - Qwen Image", + tags=["lora", "model", "qwen_image"], + category="model", + version="1.0.0", + classification=Classification.Prototype, +) +class QwenImageLoRACollectionLoader(BaseInvocation): + """Applies a collection of LoRAs to a Qwen Image transformer.""" + + loras: Optional[LoRAField | list[LoRAField]] = InputField( + default=None, description="LoRA models and weights. 
May be a single LoRA or collection.", title="LoRAs" + ) + transformer: Optional[TransformerField] = InputField( + default=None, + description=FieldDescriptions.transformer, + input=Input.Connection, + title="Transformer", + ) + + def invoke(self, context: InvocationContext) -> QwenImageLoRALoaderOutput: + output = QwenImageLoRALoaderOutput() + loras = self.loras if isinstance(self.loras, list) else [self.loras] + added_loras: list[str] = [] + + if self.transformer is not None: + output.transformer = self.transformer.model_copy(deep=True) + + for lora in loras: + if lora is None: + continue + if lora.lora.key in added_loras: + continue + if not context.models.exists(lora.lora.key): + raise Exception(f"Unknown lora: {lora.lora.key}!") + + added_loras.append(lora.lora.key) + + if self.transformer is not None and output.transformer is not None: + output.transformer.loras.append(lora) + + return output diff --git a/invokeai/app/invocations/qwen_image_model_loader.py b/invokeai/app/invocations/qwen_image_model_loader.py new file mode 100644 index 0000000000..fd96067f56 --- /dev/null +++ b/invokeai/app/invocations/qwen_image_model_loader.py @@ -0,0 +1,107 @@ +from typing import Optional + +from invokeai.app.invocations.baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + Classification, + invocation, + invocation_output, +) +from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField +from invokeai.app.invocations.model import ( + ModelIdentifierField, + QwenVLEncoderField, + TransformerField, + VAEField, +) +from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType + + +@invocation_output("qwen_image_model_loader_output") +class QwenImageModelLoaderOutput(BaseInvocationOutput): + """Qwen Image model loader output.""" + + transformer: TransformerField = 
@invocation_output("qwen_image_model_loader_output")
class QwenImageModelLoaderOutput(BaseInvocationOutput):
    """Qwen Image model loader output."""

    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    qwen_vl_encoder: QwenVLEncoderField = OutputField(
        description=FieldDescriptions.qwen_vl_encoder, title="Qwen VL Encoder"
    )
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")


@invocation(
    "qwen_image_model_loader",
    title="Main Model - Qwen Image",
    tags=["model", "qwen_image"],
    category="model",
    version="1.1.0",
    classification=Classification.Prototype,
)
class QwenImageModelLoaderInvocation(BaseInvocation):
    """Loads a Qwen Image model, outputting its submodels.

    The transformer always comes from the main model (Diffusers or GGUF). When
    the main model is a GGUF quantized transformer, the VAE and Qwen VL encoder
    must be sourced from a separate Diffusers model via "Component Source". When
    the main model is already Diffusers, all components are extracted from it
    and "Component Source" is ignored.
    """

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.qwen_image_model,
        input=Input.Direct,
        ui_model_base=BaseModelType.QwenImage,
        ui_model_type=ModelType.Main,
        title="Transformer",
    )

    component_source: Optional[ModelIdentifierField] = InputField(
        default=None,
        description="Diffusers Qwen Image model to extract the VAE and Qwen VL encoder from. "
        "Required when using a GGUF quantized transformer. "
        "Ignored when the main model is already in Diffusers format.",
        input=Input.Direct,
        ui_model_base=BaseModelType.QwenImage,
        ui_model_type=ModelType.Main,
        ui_model_format=ModelFormat.Diffusers,
        title="Component Source (Diffusers)",
    )

    def invoke(self, context: InvocationContext) -> QwenImageModelLoaderOutput:
        """Resolve submodel identifiers for the transformer, VAE, and text encoder."""
        main_config = context.models.get_config(self.model)

        # The transformer is always drawn from the main model, whatever its format.
        transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer})

        # Decide where the VAE / tokenizer / text encoder come from.
        if main_config.format == ModelFormat.Diffusers:
            component_model = self.model
        else:
            if self.component_source is None:
                raise ValueError(
                    "No source for VAE and Qwen VL encoder. "
                    "GGUF quantized models only contain the transformer — "
                    "please set 'Component Source' to a Diffusers Qwen Image model "
                    "to provide the VAE and text encoder."
                )
            source_config = context.models.get_config(self.component_source)
            if source_config.format != ModelFormat.Diffusers:
                raise ValueError(
                    f"The Component Source model must be in Diffusers format. "
                    f"The selected model '{source_config.name}' is in {source_config.format.value} format."
                )
            component_model = self.component_source

        vae = component_model.model_copy(update={"submodel_type": SubModelType.VAE})
        tokenizer = component_model.model_copy(update={"submodel_type": SubModelType.Tokenizer})
        text_encoder = component_model.model_copy(update={"submodel_type": SubModelType.TextEncoder})

        return QwenImageModelLoaderOutput(
            transformer=TransformerField(transformer=transformer, loras=[]),
            qwen_vl_encoder=QwenVLEncoderField(tokenizer=tokenizer, text_encoder=text_encoder),
            vae=VAEField(vae=vae),
        )
+) +_EDIT_DROP_IDX = 64 + +# Text-to-image mode (QwenImagePipeline) +_GENERATE_SYSTEM_PROMPT = ( + "Describe the image by detailing the color, shape, size, texture, quantity, " + "text, spatial relationships of the objects and background:" +) +_GENERATE_DROP_IDX = 34 + +_IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>" + + +def _build_prompt(user_prompt: str, num_images: int) -> str: + """Build the full prompt with the appropriate template based on whether reference images are provided.""" + if num_images > 0: + # Edit mode: include vision placeholders for reference images + image_tokens = _IMAGE_PLACEHOLDER * num_images + return ( + f"<|im_start|>system\n{_EDIT_SYSTEM_PROMPT}<|im_end|>\n" + f"<|im_start|>user\n{image_tokens}{user_prompt}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + else: + # Generate mode: text-only prompt + return ( + f"<|im_start|>system\n{_GENERATE_SYSTEM_PROMPT}<|im_end|>\n" + f"<|im_start|>user\n{user_prompt}<|im_end|>\n" + "<|im_start|>assistant\n" + ) + + +@invocation( + "qwen_image_text_encoder", + title="Prompt - Qwen Image", + tags=["prompt", "conditioning", "qwen_image"], + category="conditioning", + version="1.2.0", + classification=Classification.Prototype, +) +class QwenImageTextEncoderInvocation(BaseInvocation): + """Encodes text and reference images for Qwen Image using Qwen2.5-VL.""" + + prompt: str = InputField(description="Text prompt describing the desired edit.", ui_component=UIComponent.Textarea) + reference_images: list[ImageField] = InputField( + default=[], + description="Reference images to guide the edit. The model can use multiple reference images.", + ) + qwen_vl_encoder: QwenVLEncoderField = InputField( + title="Qwen VL Encoder", + description=FieldDescriptions.qwen_vl_encoder, + input=Input.Connection, + ) + quantization: Literal["none", "int8", "nf4"] = InputField( + default="none", + description="Quantize the Qwen VL encoder to reduce VRAM usage. 
" + "'nf4' (4-bit) saves the most memory, 'int8' (8-bit) is a middle ground.", + ) + + @staticmethod + def _resize_for_vl_encoder(image: PILImage.Image, target_pixels: int = 512 * 512) -> PILImage.Image: + """Resize image to fit within target_pixels while preserving aspect ratio. + + Matches the diffusers pipeline's calculate_dimensions logic: the image is resized + so its total pixel count is approximately target_pixels, with dimensions rounded to + multiples of 32. This prevents large images from producing too many vision tokens + which can overwhelm the text prompt. + """ + w, h = image.size + aspect = w / h + # Compute dimensions that preserve aspect ratio at ~target_pixels total + new_w = int((target_pixels * aspect) ** 0.5) + new_h = int(target_pixels / new_w) + # Round to multiples of 32 + new_w = max(32, (new_w // 32) * 32) + new_h = max(32, (new_h // 32) * 32) + if new_w != w or new_h != h: + image = image.resize((new_w, new_h), resample=PILImage.LANCZOS) + return image + + @torch.no_grad() + def invoke(self, context: InvocationContext) -> QwenImageConditioningOutput: + # Load and resize reference images to ~1M pixels (matching diffusers pipeline) + pil_images: list[PILImage.Image] = [] + for img_field in self.reference_images: + pil_img = context.images.get_pil(img_field.image_name) + pil_img = self._resize_for_vl_encoder(pil_img.convert("RGB")) + pil_images.append(pil_img) + + prompt_embeds, prompt_mask = self._encode(context, pil_images) + prompt_embeds = prompt_embeds.detach().to("cpu") + prompt_mask = prompt_mask.detach().to("cpu") if prompt_mask is not None else None + + conditioning_data = ConditioningFieldData( + conditionings=[QwenImageConditioningInfo(prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_mask)] + ) + conditioning_name = context.conditioning.save(conditioning_data) + return QwenImageConditioningOutput.build(conditioning_name) + + def _encode( + self, context: InvocationContext, images: list[PILImage.Image] + ) -> 
tuple[torch.Tensor, torch.Tensor | None]: + """Encode text prompt and reference images using Qwen2.5-VL. + + Matches the diffusers QwenImagePipeline._get_qwen_prompt_embeds logic: + 1. Format prompt with the edit-specific system template + 2. Run through Qwen2.5-VL to get hidden states + 3. Extract valid (non-padding) tokens and drop the system prefix + 4. Return padded embeddings + attention mask + """ + from transformers import AutoTokenizer, Qwen2_5_VLProcessor + + try: + from transformers import Qwen2_5_VLImageProcessor as _ImageProcessorCls + except ImportError: + from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( # type: ignore[no-redef] + Qwen2VLImageProcessor as _ImageProcessorCls, + ) + + try: + from transformers import Qwen2_5_VLVideoProcessor as _VideoProcessorCls + except ImportError: + from transformers.models.qwen2_vl.video_processing_qwen2_vl import ( # type: ignore[no-redef] + Qwen2VLVideoProcessor as _VideoProcessorCls, + ) + + # Format the prompt with one vision placeholder per reference image + text = _build_prompt(self.prompt, len(images)) + + # Build the processor + tokenizer_config = context.models.get_config(self.qwen_vl_encoder.tokenizer) + model_root = context.models.get_absolute_path(tokenizer_config) + tokenizer_dir = model_root / "tokenizer" + + tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_dir), local_files_only=True) + + image_processor = None + for search_dir in [model_root / "processor", tokenizer_dir, model_root, model_root / "image_processor"]: + if (search_dir / "preprocessor_config.json").exists(): + image_processor = _ImageProcessorCls.from_pretrained(str(search_dir), local_files_only=True) + break + if image_processor is None: + image_processor = _ImageProcessorCls() + + processor = Qwen2_5_VLProcessor( + tokenizer=tokenizer, + image_processor=image_processor, + video_processor=_VideoProcessorCls(), + ) + + context.util.signal_progress("Running Qwen2.5-VL text/vision encoder") + + if 
self.quantization != "none": + text_encoder, device, cleanup = self._load_quantized_encoder(context) + else: + text_encoder, device, cleanup = self._load_cached_encoder(context) + + try: + model_inputs = processor( + text=[text], + images=images if images else None, + padding=True, + return_tensors="pt", + ).to(device=device) + + outputs = text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=getattr(model_inputs, "pixel_values", None), + image_grid_thw=getattr(model_inputs, "image_grid_thw", None), + output_hidden_states=True, + ) + + # Use last hidden state (matching diffusers pipeline) + hidden_states = outputs.hidden_states[-1] + + # Extract valid (non-padding) tokens using the attention mask, + # then drop the system prompt prefix tokens. + # The drop index differs between edit mode (64) and generate mode (34). + drop_idx = _EDIT_DROP_IDX if images else _GENERATE_DROP_IDX + + attn_mask = model_inputs.attention_mask + bool_mask = attn_mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_hidden = torch.split(selected, valid_lengths.tolist(), dim=0) + + # Drop system prefix tokens and build padded output + trimmed = [h[drop_idx:] for h in split_hidden] + attn_mask_list = [torch.ones(h.size(0), dtype=torch.long, device=device) for h in trimmed] + max_seq_len = max(h.size(0) for h in trimmed) + + prompt_embeds = torch.stack( + [torch.cat([h, h.new_zeros(max_seq_len - h.size(0), h.size(1))]) for h in trimmed] + ) + encoder_attention_mask = torch.stack( + [torch.cat([m, m.new_zeros(max_seq_len - m.size(0))]) for m in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=torch.bfloat16) + finally: + if cleanup is not None: + cleanup() + + # If all tokens are valid (no padding), mask is not needed + if encoder_attention_mask.all(): + encoder_attention_mask = None + + return prompt_embeds, encoder_attention_mask + + def _load_cached_encoder(self, context: 
InvocationContext): + """Load the text encoder through the model cache (no quantization).""" + from transformers import Qwen2_5_VLForConditionalGeneration + + text_encoder_info = context.models.load(self.qwen_vl_encoder.text_encoder) + ctx = text_encoder_info.model_on_device() + _, text_encoder = ctx.__enter__() + device = get_effective_device(text_encoder) + assert isinstance(text_encoder, Qwen2_5_VLForConditionalGeneration) + return text_encoder, device, lambda: ctx.__exit__(None, None, None) + + def _load_quantized_encoder(self, context: InvocationContext): + """Load the text encoder with BitsAndBytes quantization, bypassing the model cache. + + BnB-quantized models are pinned to GPU and can't be moved between devices, + so they can't go through the standard model cache. The model is loaded fresh + each time and freed after use via the cleanup callback. + """ + import gc + import warnings + + from transformers import BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration + + encoder_config = context.models.get_config(self.qwen_vl_encoder.text_encoder) + model_root = context.models.get_absolute_path(encoder_config) + encoder_path = model_root / "text_encoder" + + if self.quantization == "nf4": + bnb_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_quant_type="nf4", + ) + else: # int8 + bnb_config = BitsAndBytesConfig(load_in_8bit=True) + + context.util.signal_progress("Loading Qwen2.5-VL encoder (quantized)") + with warnings.catch_warnings(): + # BnB int8 internally casts bfloat16→float16; the warning is harmless + warnings.filterwarnings("ignore", message="MatMul8bitLt.*cast.*float16") + text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained( + str(encoder_path), + quantization_config=bnb_config, + device_map="auto", + torch_dtype=torch.bfloat16, + local_files_only=True, + ) + + device = next(text_encoder.parameters()).device + + def cleanup(): + nonlocal text_encoder + del text_encoder + 
gc.collect() + torch.cuda.empty_cache() + + return text_encoder, device, cleanup diff --git a/invokeai/app/services/board_records/board_records_common.py b/invokeai/app/services/board_records/board_records_common.py index ab6355a393..b263f264cb 100644 --- a/invokeai/app/services/board_records/board_records_common.py +++ b/invokeai/app/services/board_records/board_records_common.py @@ -9,6 +9,17 @@ from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.model_exclude_null import BaseModelExcludeNull +class BoardVisibility(str, Enum, metaclass=MetaEnum): + """The visibility options for a board.""" + + Private = "private" + """Only the board owner (and admins) can see and modify this board.""" + Shared = "shared" + """All users can view this board, but only the owner (and admins) can modify it.""" + Public = "public" + """All users can view this board; only the owner (and admins) can modify its structure.""" + + class BoardRecord(BaseModelExcludeNull): """Deserialized board record.""" @@ -28,6 +39,10 @@ class BoardRecord(BaseModelExcludeNull): """The name of the cover image of the board.""" archived: bool = Field(description="Whether or not the board is archived.") """Whether or not the board is archived.""" + board_visibility: BoardVisibility = Field( + default=BoardVisibility.Private, description="The visibility of the board." 
+ ) + """The visibility of the board (private, shared, or public).""" def deserialize_board_record(board_dict: dict) -> BoardRecord: @@ -44,6 +59,11 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord: updated_at = board_dict.get("updated_at", get_iso_timestamp()) deleted_at = board_dict.get("deleted_at", get_iso_timestamp()) archived = board_dict.get("archived", False) + board_visibility_raw = board_dict.get("board_visibility", BoardVisibility.Private.value) + try: + board_visibility = BoardVisibility(board_visibility_raw) + except ValueError: + board_visibility = BoardVisibility.Private return BoardRecord( board_id=board_id, @@ -54,6 +74,7 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord: updated_at=updated_at, deleted_at=deleted_at, archived=archived, + board_visibility=board_visibility, ) @@ -61,6 +82,7 @@ class BoardChanges(BaseModel, extra="forbid"): board_name: Optional[str] = Field(default=None, description="The board's new name.", max_length=300) cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.") archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived") + board_visibility: Optional[BoardVisibility] = Field(default=None, description="The visibility of the board.") class BoardRecordOrderBy(str, Enum, metaclass=MetaEnum): diff --git a/invokeai/app/services/board_records/board_records_sqlite.py b/invokeai/app/services/board_records/board_records_sqlite.py index a54f65686f..1e3e11c8a3 100644 --- a/invokeai/app/services/board_records/board_records_sqlite.py +++ b/invokeai/app/services/board_records/board_records_sqlite.py @@ -116,6 +116,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): (changes.archived, board_id), ) + # Change the visibility of a board + if changes.board_visibility is not None: + cursor.execute( + """--sql + UPDATE boards + SET board_visibility = ? 
+ WHERE board_id = ?; + """, + (changes.board_visibility.value, board_id), + ) + except sqlite3.Error as e: raise BoardRecordSaveException from e return self.get(board_id) @@ -155,7 +166,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): SELECT DISTINCT boards.* FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1) + WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public')) {archived_filter} ORDER BY {order_by} {direction} LIMIT ? OFFSET ?; @@ -194,14 +205,14 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): SELECT COUNT(DISTINCT boards.board_id) FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1); + WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public')); """ else: count_query = """ SELECT COUNT(DISTINCT boards.board_id) FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1) + WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public')) AND boards.archived = 0; """ @@ -251,7 +262,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): SELECT DISTINCT boards.* FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1) + WHERE (boards.user_id = ? OR shared_boards.user_id = ? 
OR boards.board_visibility IN ('shared', 'public')) {archived_filter} ORDER BY LOWER(boards.board_name) {direction} """ @@ -260,7 +271,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase): SELECT DISTINCT boards.* FROM boards LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id - WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1) + WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public')) {archived_filter} ORDER BY {order_by} {direction} """ diff --git a/invokeai/app/services/bulk_download/bulk_download_base.py b/invokeai/app/services/bulk_download/bulk_download_base.py index 617b611f56..6cd4ed0cba 100644 --- a/invokeai/app/services/bulk_download/bulk_download_base.py +++ b/invokeai/app/services/bulk_download/bulk_download_base.py @@ -7,7 +7,11 @@ class BulkDownloadBase(ABC): @abstractmethod def handler( - self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str] + self, + image_names: Optional[list[str]], + board_id: Optional[str], + bulk_download_item_id: Optional[str], + user_id: str = "system", ) -> None: """ Create a zip file containing the images specified by the given image names or board id. @@ -15,6 +19,7 @@ class BulkDownloadBase(ABC): :param image_names: A list of image names to include in the zip file. :param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file. :param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated. + :param user_id: The ID of the user who initiated the download. """ @abstractmethod @@ -42,3 +47,12 @@ class BulkDownloadBase(ABC): :param bulk_download_item_name: The name of the bulk download item. 
""" + + @abstractmethod + def get_owner(self, bulk_download_item_name: str) -> Optional[str]: + """ + Get the user_id of the user who initiated the download. + + :param bulk_download_item_name: The name of the bulk download item. + :return: The user_id of the owner, or None if not tracked. + """ diff --git a/invokeai/app/services/bulk_download/bulk_download_default.py b/invokeai/app/services/bulk_download/bulk_download_default.py index dc4f8b1d81..c037e9c5c1 100644 --- a/invokeai/app/services/bulk_download/bulk_download_default.py +++ b/invokeai/app/services/bulk_download/bulk_download_default.py @@ -25,15 +25,24 @@ class BulkDownloadService(BulkDownloadBase): self._temp_directory = TemporaryDirectory() self._bulk_downloads_folder = Path(self._temp_directory.name) / "bulk_downloads" self._bulk_downloads_folder.mkdir(parents=True, exist_ok=True) + # Track which user owns each download so the fetch endpoint can enforce ownership + self._download_owners: dict[str, str] = {} def handler( - self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str] + self, + image_names: Optional[list[str]], + board_id: Optional[str], + bulk_download_item_id: Optional[str], + user_id: str = "system", ) -> None: bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID bulk_download_item_id = bulk_download_item_id or uuid_string() bulk_download_item_name = bulk_download_item_id + ".zip" - self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name) + # Record ownership so the fetch endpoint can verify the caller + self._download_owners[bulk_download_item_name] = user_id + + self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id) try: image_dtos: list[ImageDTO] = [] @@ -46,16 +55,16 @@ class BulkDownloadService(BulkDownloadBase): raise BulkDownloadParametersException() bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id) - 
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name) + self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id) except ( ImageRecordNotFoundException, BoardRecordNotFoundException, BulkDownloadException, BulkDownloadParametersException, ) as e: - self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e) + self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e, user_id) except Exception as e: - self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e) + self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e, user_id) self._invoker.services.logger.error("Problem bulk downloading images.") raise e @@ -103,43 +112,60 @@ class BulkDownloadService(BulkDownloadBase): return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip() def _signal_job_started( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> None: """Signal that a bulk download job has started.""" if self._invoker: assert bulk_download_id is not None self._invoker.services.events.emit_bulk_download_started( - bulk_download_id, bulk_download_item_id, bulk_download_item_name + bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id=user_id ) def _signal_job_completed( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> None: """Signal that a bulk download job has completed.""" if self._invoker: assert bulk_download_id is not None assert bulk_download_item_name is not None 
self._invoker.services.events.emit_bulk_download_complete( - bulk_download_id, bulk_download_item_id, bulk_download_item_name + bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id=user_id ) def _signal_job_failed( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + exception: Exception, + user_id: str = "system", ) -> None: """Signal that a bulk download job has failed.""" if self._invoker: assert bulk_download_id is not None assert exception is not None self._invoker.services.events.emit_bulk_download_error( - bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception) + bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception), user_id=user_id ) def stop(self, *args, **kwargs): self._temp_directory.cleanup() + def get_owner(self, bulk_download_item_name: str) -> Optional[str]: + return self._download_owners.get(bulk_download_item_name) + def delete(self, bulk_download_item_name: str) -> None: path = self.get_path(bulk_download_item_name) Path(path).unlink() + self._download_owners.pop(bulk_download_item_name, None) def get_path(self, bulk_download_item_name: str) -> str: path = str(self._bulk_downloads_folder / bulk_download_item_name) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index aa1cbb5e0e..935b422a73 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -100,9 +100,9 @@ class EventServiceBase: """Emitted when a queue item's status changes""" self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status)) - def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None: + def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", user_id: str = "system") -> None: """Emitted when a 
batch is enqueued""" - self.dispatch(BatchEnqueuedEvent.build(enqueue_result)) + self.dispatch(BatchEnqueuedEvent.build(enqueue_result, user_id)) def emit_queue_items_retried(self, retry_result: "RetryItemsResult") -> None: """Emitted when a list of queue items are retried""" @@ -112,9 +112,9 @@ class EventServiceBase: """Emitted when a queue is cleared""" self.dispatch(QueueClearedEvent.build(queue_id)) - def emit_recall_parameters_updated(self, queue_id: str, parameters: dict) -> None: + def emit_recall_parameters_updated(self, queue_id: str, user_id: str, parameters: dict) -> None: """Emitted when recall parameters are updated""" - self.dispatch(RecallParametersUpdatedEvent.build(queue_id, parameters)) + self.dispatch(RecallParametersUpdatedEvent.build(queue_id, user_id, parameters)) # endregion @@ -194,23 +194,42 @@ class EventServiceBase: # region Bulk image download def emit_bulk_download_started( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> None: """Emitted when a bulk image download is started""" - self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) + self.dispatch( + BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id) + ) def emit_bulk_download_complete( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> None: """Emitted when a bulk image download is complete""" - self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name)) + self.dispatch( + BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id) + ) def 
emit_bulk_download_error( - self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str + self, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + error: str, + user_id: str = "system", ) -> None: """Emitted when a bulk image download has an error""" self.dispatch( - BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error) + BulkDownloadErrorEvent.build( + bulk_download_id, bulk_download_item_id, bulk_download_item_name, error, user_id + ) ) # endregion diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index bfb44eb48e..998fe4f530 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -281,9 +281,10 @@ class BatchEnqueuedEvent(QueueEventBase): ) priority: int = Field(description="The priority of the batch") origin: str | None = Field(default=None, description="The origin of the batch") + user_id: str = Field(default="system", description="The ID of the user who enqueued the batch") @classmethod - def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent": + def build(cls, enqueue_result: EnqueueBatchResult, user_id: str = "system") -> "BatchEnqueuedEvent": return cls( queue_id=enqueue_result.queue_id, batch_id=enqueue_result.batch.batch_id, @@ -291,6 +292,7 @@ class BatchEnqueuedEvent(QueueEventBase): enqueued=enqueue_result.enqueued, requested=enqueue_result.requested, priority=enqueue_result.priority, + user_id=user_id, ) @@ -609,6 +611,7 @@ class BulkDownloadEventBase(EventBase): bulk_download_id: str = Field(description="The ID of the bulk image download") bulk_download_item_id: str = Field(description="The ID of the bulk image download item") bulk_download_item_name: str = Field(description="The name of the bulk image download item") + user_id: str = Field(default="system", description="The ID of the user who 
initiated the download") @payload_schema.register @@ -619,12 +622,17 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase): @classmethod def build( - cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + cls, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> "BulkDownloadStartedEvent": return cls( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, + user_id=user_id, ) @@ -636,12 +644,17 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase): @classmethod def build( - cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str + cls, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + user_id: str = "system", ) -> "BulkDownloadCompleteEvent": return cls( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, + user_id=user_id, ) @@ -655,13 +668,19 @@ class BulkDownloadErrorEvent(BulkDownloadEventBase): @classmethod def build( - cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str + cls, + bulk_download_id: str, + bulk_download_item_id: str, + bulk_download_item_name: str, + error: str, + user_id: str = "system", ) -> "BulkDownloadErrorEvent": return cls( bulk_download_id=bulk_download_id, bulk_download_item_id=bulk_download_item_id, bulk_download_item_name=bulk_download_item_name, error=error, + user_id=user_id, ) @@ -671,8 +690,9 @@ class RecallParametersUpdatedEvent(QueueEventBase): __event_name__ = "recall_parameters_updated" + user_id: str = Field(description="The ID of the user whose recall parameters were updated") parameters: dict[str, Any] = Field(description="The recall parameters that were updated") @classmethod - def build(cls, queue_id: str, parameters: dict[str, Any]) -> 
"RecallParametersUpdatedEvent": - return cls(queue_id=queue_id, parameters=parameters) + def build(cls, queue_id: str, user_id: str, parameters: dict[str, Any]) -> "RecallParametersUpdatedEvent": + return cls(queue_id=queue_id, user_id=user_id, parameters=parameters) diff --git a/invokeai/app/services/events/events_fastapievents.py b/invokeai/app/services/events/events_fastapievents.py index f44eecc555..90e1402773 100644 --- a/invokeai/app/services/events/events_fastapievents.py +++ b/invokeai/app/services/events/events_fastapievents.py @@ -46,3 +46,9 @@ class FastAPIEventService(EventServiceBase): except asyncio.CancelledError as e: raise e # Raise a proper error + except Exception: + import logging + + logging.getLogger("InvokeAI").error( + f"Error dispatching event {getattr(event, '__event_name__', event)}", exc_info=True + ) diff --git a/invokeai/app/services/image_records/image_records_base.py b/invokeai/app/services/image_records/image_records_base.py index 16405c5270..457cf2f468 100644 --- a/invokeai/app/services/image_records/image_records_base.py +++ b/invokeai/app/services/image_records/image_records_base.py @@ -74,8 +74,8 @@ class ImageRecordStorageBase(ABC): pass @abstractmethod - def get_intermediates_count(self) -> int: - """Gets a count of all intermediate images.""" + def get_intermediates_count(self, user_id: Optional[str] = None) -> int: + """Gets a count of intermediate images. If user_id is provided, only counts that user's intermediates.""" pass @abstractmethod @@ -97,6 +97,11 @@ class ImageRecordStorageBase(ABC): """Saves an image record.""" pass + @abstractmethod + def get_user_id(self, image_name: str) -> Optional[str]: + """Gets the user_id of the image owner. 
Returns None if image not found.""" + pass + @abstractmethod def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]: """Gets the most recent image for a board.""" diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index c6c237fc1e..07126d53a9 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -46,6 +46,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): return deserialize_image_record(dict(result)) + def get_user_id(self, image_name: str) -> Optional[str]: + with self._db.transaction() as cursor: + cursor.execute( + """--sql + SELECT user_id FROM images + WHERE image_name = ?; + """, + (image_name,), + ) + result = cast(Optional[sqlite3.Row], cursor.fetchone()) + if not result: + return None + return cast(Optional[str], dict(result).get("user_id")) + def get_metadata(self, image_name: str) -> Optional[MetadataField]: with self._db.transaction() as cursor: try: @@ -269,14 +283,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): except sqlite3.Error as e: raise ImageRecordDeleteException from e - def get_intermediates_count(self) -> int: + def get_intermediates_count(self, user_id: Optional[str] = None) -> int: with self._db.transaction() as cursor: - cursor.execute( - """--sql - SELECT COUNT(*) FROM images - WHERE is_intermediate = TRUE; - """ - ) + query = "SELECT COUNT(*) FROM images WHERE is_intermediate = TRUE" + params: list[str] = [] + if user_id is not None: + query += " AND user_id = ?" 
+ params.append(user_id) + cursor.execute(query, params) count = cast(int, cursor.fetchone()[0]) return count diff --git a/invokeai/app/services/images/images_base.py b/invokeai/app/services/images/images_base.py index d11d75b3c1..aebbead2f3 100644 --- a/invokeai/app/services/images/images_base.py +++ b/invokeai/app/services/images/images_base.py @@ -143,8 +143,8 @@ class ImageServiceABC(ABC): pass @abstractmethod - def get_intermediates_count(self) -> int: - """Gets the number of intermediate images.""" + def get_intermediates_count(self, user_id: Optional[str] = None) -> int: + """Gets the number of intermediate images. If user_id is provided, only counts that user's intermediates.""" pass @abstractmethod diff --git a/invokeai/app/services/images/images_default.py b/invokeai/app/services/images/images_default.py index e82bd7f4de..0f03f7c400 100644 --- a/invokeai/app/services/images/images_default.py +++ b/invokeai/app/services/images/images_default.py @@ -310,9 +310,9 @@ class ImageService(ImageServiceABC): self.__invoker.services.logger.error("Problem deleting image records and files") raise e - def get_intermediates_count(self) -> int: + def get_intermediates_count(self, user_id: Optional[str] = None) -> int: try: - return self.__invoker.services.image_records.get_intermediates_count() + return self.__invoker.services.image_records.get_intermediates_count(user_id=user_id) except Exception as e: self.__invoker.services.logger.error("Problem getting intermediates count") raise e diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 361c2e4811..49d3cfdf7f 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union import torch import yaml -from huggingface_hub import HfFolder +from huggingface_hub 
import get_token as hf_get_token from pydantic.networks import AnyHttpUrl from pydantic_core import Url from requests import Session @@ -1115,7 +1115,7 @@ class ModelInstallService(ModelInstallServiceBase): ) -> ModelInstallJob: # Add user's cached access token to HuggingFace requests if source.access_token is None: - source.access_token = HfFolder.get_token() + source.access_token = hf_get_token() remote_files, metadata = self._remote_files_from_source(source) return self._import_remote_model( source=source, diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 318ebb000e..6420949c29 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -30,6 +30,7 @@ from invokeai.backend.model_manager.taxonomy import ( ModelType, ModelVariantType, Qwen3VariantType, + QwenImageVariantType, SchedulerPredictionType, ZImageVariantType, ) @@ -109,7 +110,13 @@ class ModelRecordChanges(BaseModelExcludeNull): # Checkpoint-specific changes # TODO(MM2): Should we expose these? Feels footgun-y... 
variant: Optional[ - ModelVariantType | ClipVariantType | FluxVariantType | Flux2VariantType | ZImageVariantType | Qwen3VariantType + ModelVariantType + | ClipVariantType + | FluxVariantType + | Flux2VariantType + | ZImageVariantType + | QwenImageVariantType + | Qwen3VariantType ] = Field(description="The variant of the model.", default=None) prediction_type: Optional[SchedulerPredictionType] = Field( description="The prediction type of the model.", default=None diff --git a/invokeai/app/services/session_queue/session_queue_base.py b/invokeai/app/services/session_queue/session_queue_base.py index 3c037dc77a..14b93d97fc 100644 --- a/invokeai/app/services/session_queue/session_queue_base.py +++ b/invokeai/app/services/session_queue/session_queue_base.py @@ -78,13 +78,15 @@ class SessionQueueBase(ABC): pass @abstractmethod - def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination: - """Gets the counts of queue items by destination""" + def get_counts_by_destination( + self, queue_id: str, destination: str, user_id: Optional[str] = None + ) -> SessionQueueCountsByDestination: + """Gets the counts of queue items by destination. If user_id is provided, only counts that user's items.""" pass @abstractmethod - def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus: - """Gets the status of a batch""" + def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str] = None) -> BatchStatus: + """Gets the status of a batch. If user_id is provided, only counts that user's items.""" pass @abstractmethod @@ -172,8 +174,9 @@ class SessionQueueBase(ABC): self, queue_id: str, order_dir: SQLiteDirection = SQLiteDirection.Descending, + user_id: Optional[str] = None, ) -> ItemIdsResult: - """Gets all queue item ids that match the given parameters""" + """Gets all queue item ids that match the given parameters. 
If user_id is provided, only returns items for that user.""" pass @abstractmethod diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 5854442211..09820fe621 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -304,12 +304,6 @@ class SessionQueueStatus(BaseModel): failed: int = Field(..., description="Number of queue items with status 'error'") canceled: int = Field(..., description="Number of queue items with status 'canceled'") total: int = Field(..., description="Total number of queue items") - user_pending: Optional[int] = Field( - default=None, description="Number of queue items with status 'pending' for the current user" - ) - user_in_progress: Optional[int] = Field( - default=None, description="Number of queue items with status 'in_progress' for the current user" - ) class SessionQueueCountsByDestination(BaseModel): diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 4f46136fd7..070a7cef29 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -151,7 +151,7 @@ class SqliteSessionQueue(SessionQueueBase): priority=priority, item_ids=item_ids, ) - self.__invoker.services.events.emit_batch_enqueued(enqueue_result) + self.__invoker.services.events.emit_batch_enqueued(enqueue_result, user_id=user_id) return enqueue_result def dequeue(self) -> Optional[SessionQueueItem]: @@ -765,15 +765,21 @@ class SqliteSessionQueue(SessionQueueBase): self, queue_id: str, order_dir: SQLiteDirection = SQLiteDirection.Descending, + user_id: Optional[str] = None, ) -> ItemIdsResult: with self._db.transaction() as cursor_: - query = f"""--sql + query = """--sql SELECT item_id FROM session_queue WHERE queue_id = ? 
- ORDER BY created_at {order_dir.value} """ - query_params = [queue_id] + query_params: list[str] = [queue_id] + + if user_id is not None: + query += " AND user_id = ?" + query_params.append(user_id) + + query += f" ORDER BY created_at {order_dir.value}" cursor_.execute(query, query_params) result = cast(list[sqlite3.Row], cursor_.fetchall()) @@ -783,20 +789,7 @@ class SqliteSessionQueue(SessionQueueBase): def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus: with self._db.transaction() as cursor: - # Get total counts - cursor.execute( - """--sql - SELECT status, count(*) - FROM session_queue - WHERE queue_id = ? - GROUP BY status - """, - (queue_id,), - ) - counts_result = cast(list[sqlite3.Row], cursor.fetchall()) - - # Get user-specific counts if user_id is provided (using a single query with CASE) - user_counts_result = [] + # When user_id is provided (non-admin), only count that user's items if user_id is not None: cursor.execute( """--sql @@ -807,48 +800,51 @@ class SqliteSessionQueue(SessionQueueBase): """, (queue_id, user_id), ) - user_counts_result = cast(list[sqlite3.Row], cursor.fetchall()) + else: + cursor.execute( + """--sql + SELECT status, count(*) + FROM session_queue + WHERE queue_id = ? 
+ GROUP BY status + """, + (queue_id,), + ) + counts_result = cast(list[sqlite3.Row], cursor.fetchall()) current_item = self.get_current(queue_id=queue_id) total = sum(row[1] or 0 for row in counts_result) counts: dict[str, int] = {row[0]: row[1] for row in counts_result} - # Process user-specific counts if available - user_pending = None - user_in_progress = None - if user_id is not None: - user_counts: dict[str, int] = {row[0]: row[1] for row in user_counts_result} - user_pending = user_counts.get("pending", 0) - user_in_progress = user_counts.get("in_progress", 0) + # For non-admin users, hide current item details if they don't own it + show_current_item = current_item is not None and (user_id is None or current_item.user_id == user_id) return SessionQueueStatus( queue_id=queue_id, - item_id=current_item.item_id if current_item else None, - session_id=current_item.session_id if current_item else None, - batch_id=current_item.batch_id if current_item else None, + item_id=current_item.item_id if show_current_item else None, + session_id=current_item.session_id if show_current_item else None, + batch_id=current_item.batch_id if show_current_item else None, pending=counts.get("pending", 0), in_progress=counts.get("in_progress", 0), completed=counts.get("completed", 0), failed=counts.get("failed", 0), canceled=counts.get("canceled", 0), total=total, - user_pending=user_pending, - user_in_progress=user_in_progress, ) - def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus: + def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str] = None) -> BatchStatus: with self._db.transaction() as cursor: - cursor.execute( - """--sql + query = """--sql SELECT status, count(*), origin, destination FROM session_queue - WHERE - queue_id = ? - AND batch_id = ? - GROUP BY status - """, - (queue_id, batch_id), - ) + WHERE queue_id = ? AND batch_id = ? 
+ """ + params: list[str] = [queue_id, batch_id] + if user_id is not None: + query += " AND user_id = ?" + params.append(user_id) + query += " GROUP BY status" + cursor.execute(query, params) result = cast(list[sqlite3.Row], cursor.fetchall()) total = sum(row[1] or 0 for row in result) counts: dict[str, int] = {row[0]: row[1] for row in result} @@ -868,18 +864,21 @@ class SqliteSessionQueue(SessionQueueBase): total=total, ) - def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination: + def get_counts_by_destination( + self, queue_id: str, destination: str, user_id: Optional[str] = None + ) -> SessionQueueCountsByDestination: with self._db.transaction() as cursor: - cursor.execute( - """--sql + query = """--sql SELECT status, count(*) FROM session_queue - WHERE queue_id = ? - AND destination = ? - GROUP BY status - """, - (queue_id, destination), - ) + WHERE queue_id = ? AND destination = ? + """ + params: list[str] = [queue_id, destination] + if user_id is not None: + query += " AND user_id = ?" 
+ params.append(user_id) + query += " GROUP BY status" + cursor.execute(query, params) counts_result = cast(list[sqlite3.Row], cursor.fetchall()) total = sum(row[1] or 0 for row in counts_result) diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 645509f1dd..fb8ca9fca3 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -30,6 +30,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_24 import from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_26 import build_migration_26 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_27 import build_migration_27 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_28 import build_migration_28 +from invokeai.app.services.shared.sqlite_migrator.migrations.migration_29 import build_migration_29 from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator @@ -77,6 +79,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto migrator.register_migration(build_migration_25(app_config=config, logger=logger)) migrator.register_migration(build_migration_26(app_config=config, logger=logger)) migrator.register_migration(build_migration_27()) + migrator.register_migration(build_migration_28()) + migrator.register_migration(build_migration_29()) migrator.run_migrations() return db diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py new file mode 100644 index 0000000000..0cbd683ab5 --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_28.py @@ -0,0 +1,45 @@ +"""Migration 28: Add per-user workflow isolation columns to 
workflow_library. + +This migration adds the database columns required for multiuser workflow isolation +to the workflow_library table: +- user_id: the owner of the workflow (defaults to 'system' for existing workflows) +- is_public: whether the workflow is shared with all users +""" + +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration28Callback: + """Migration to add user_id and is_public to the workflow_library table.""" + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._update_workflow_library_table(cursor) + + def _update_workflow_library_table(self, cursor: sqlite3.Cursor) -> None: + """Add user_id and is_public columns to workflow_library table.""" + cursor.execute("PRAGMA table_info(workflow_library);") + columns = [row[1] for row in cursor.fetchall()] + + if "user_id" not in columns: + cursor.execute("ALTER TABLE workflow_library ADD COLUMN user_id TEXT DEFAULT 'system';") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflow_library_user_id ON workflow_library(user_id);") + + if "is_public" not in columns: + cursor.execute("ALTER TABLE workflow_library ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflow_library_is_public ON workflow_library(is_public);") + + +def build_migration_28() -> Migration: + """Builds the migration object for migrating from version 27 to version 28. 
+ + This migration adds per-user workflow isolation to the workflow_library table: + - user_id column: identifies the owner of each workflow + - is_public column: controls whether a workflow is shared with all users + """ + return Migration( + from_version=27, + to_version=28, + callback=Migration28Callback(), + ) diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_29.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_29.py new file mode 100644 index 0000000000..c9eb7c901b --- /dev/null +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_29.py @@ -0,0 +1,53 @@ +"""Migration 29: Add board_visibility column to boards table. + +This migration adds a board_visibility column to the boards table to support +three visibility levels: + - 'private': only the board owner (and admins) can view/modify + - 'shared': all users can view, but only the owner (and admins) can modify + - 'public': all users can view; only the owner (and admins) can modify the + board structure (rename/archive/delete) + +Existing boards with is_public = 1 are migrated to 'public'. +All other existing boards default to 'private'. 
+""" + +import sqlite3 + +from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration + + +class Migration29Callback: + """Migration to add board_visibility column to the boards table.""" + + def __call__(self, cursor: sqlite3.Cursor) -> None: + self._update_boards_table(cursor) + + def _update_boards_table(self, cursor: sqlite3.Cursor) -> None: + """Add board_visibility column to boards table.""" + # Check if boards table exists + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='boards';") + if cursor.fetchone() is None: + return + + cursor.execute("PRAGMA table_info(boards);") + columns = [row[1] for row in cursor.fetchall()] + + if "board_visibility" not in columns: + cursor.execute("ALTER TABLE boards ADD COLUMN board_visibility TEXT NOT NULL DEFAULT 'private';") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_boards_board_visibility ON boards(board_visibility);") + # Migrate existing is_public = 1 boards to 'public' + if "is_public" in columns: + cursor.execute("UPDATE boards SET board_visibility = 'public' WHERE is_public = 1;") + + +def build_migration_29() -> Migration: + """Builds the migration object for migrating from version 28 to version 29. + + This migration adds the board_visibility column to the boards table, + supporting 'private', 'shared', and 'public' visibility levels. + """ + return Migration( + from_version=28, + to_version=29, + callback=Migration29Callback(), + ) diff --git a/invokeai/app/services/users/users_base.py b/invokeai/app/services/users/users_base.py index 728a0adfa3..dd789b561e 100644 --- a/invokeai/app/services/users/users_base.py +++ b/invokeai/app/services/users/users_base.py @@ -131,6 +131,15 @@ class UserServiceBase(ABC): """ pass + @abstractmethod + def get_admin_email(self) -> str | None: + """Get the email address of the first active admin user. 
+ + Returns: + Email address of the first active admin, or None if no admin exists + """ + pass + @abstractmethod def count_admins(self) -> int: """Count active admin users. diff --git a/invokeai/app/services/users/users_default.py b/invokeai/app/services/users/users_default.py index 709e4cb82c..6e47288212 100644 --- a/invokeai/app/services/users/users_default.py +++ b/invokeai/app/services/users/users_default.py @@ -256,6 +256,20 @@ class UserService(UserServiceBase): for row in rows ] + def get_admin_email(self) -> str | None: + """Get the email address of the first active admin user.""" + with self._db.transaction() as cursor: + cursor.execute( + """ + SELECT email FROM users + WHERE is_admin = TRUE AND is_active = TRUE + ORDER BY created_at ASC + LIMIT 1 + """, + ) + row = cursor.fetchone() + return row[0] if row else None + def count_admins(self) -> int: """Count active admin users.""" with self._db.transaction() as cursor: diff --git a/invokeai/app/services/workflow_records/workflow_records_base.py b/invokeai/app/services/workflow_records/workflow_records_base.py index d5cf319594..856a6c6d49 100644 --- a/invokeai/app/services/workflow_records/workflow_records_base.py +++ b/invokeai/app/services/workflow_records/workflow_records_base.py @@ -4,6 +4,7 @@ from typing import Optional from invokeai.app.services.shared.pagination import PaginatedResults from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.services.workflow_records.workflow_records_common import ( + WORKFLOW_LIBRARY_DEFAULT_USER_ID, Workflow, WorkflowCategory, WorkflowRecordDTO, @@ -22,18 +23,18 @@ class WorkflowRecordsStorageBase(ABC): pass @abstractmethod - def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO: + def create(self, workflow: WorkflowWithoutID, user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID) -> WorkflowRecordDTO: """Creates a workflow.""" pass @abstractmethod - def update(self, workflow: Workflow) -> WorkflowRecordDTO: - 
"""Updates a workflow.""" + def update(self, workflow: Workflow, user_id: Optional[str] = None) -> WorkflowRecordDTO: + """Updates a workflow. When user_id is provided, the UPDATE is scoped to that user.""" pass @abstractmethod - def delete(self, workflow_id: str) -> None: - """Deletes a workflow.""" + def delete(self, workflow_id: str, user_id: Optional[str] = None) -> None: + """Deletes a workflow. When user_id is provided, the DELETE is scoped to that user.""" pass @abstractmethod @@ -47,6 +48,8 @@ class WorkflowRecordsStorageBase(ABC): query: Optional[str], tags: Optional[list[str]], has_been_opened: Optional[bool], + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> PaginatedResults[WorkflowRecordListItemDTO]: """Gets many workflows.""" pass @@ -56,6 +59,8 @@ class WorkflowRecordsStorageBase(ABC): self, categories: list[WorkflowCategory], has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: """Gets a dictionary of counts for each of the provided categories.""" pass @@ -66,19 +71,28 @@ class WorkflowRecordsStorageBase(ABC): tags: list[str], categories: Optional[list[WorkflowCategory]] = None, has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: """Gets a dictionary of counts for each of the provided tags.""" pass @abstractmethod - def update_opened_at(self, workflow_id: str) -> None: - """Open a workflow.""" + def update_opened_at(self, workflow_id: str, user_id: Optional[str] = None) -> None: + """Open a workflow. 
When user_id is provided, the UPDATE is scoped to that user.""" pass @abstractmethod def get_all_tags( self, categories: Optional[list[WorkflowCategory]] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> list[str]: """Gets all unique tags from workflows.""" pass + + @abstractmethod + def update_is_public(self, workflow_id: str, is_public: bool, user_id: Optional[str] = None) -> WorkflowRecordDTO: + """Updates the is_public field of a workflow. When user_id is provided, the UPDATE is scoped to that user.""" + pass diff --git a/invokeai/app/services/workflow_records/workflow_records_common.py b/invokeai/app/services/workflow_records/workflow_records_common.py index e0cea37468..9c505530c9 100644 --- a/invokeai/app/services/workflow_records/workflow_records_common.py +++ b/invokeai/app/services/workflow_records/workflow_records_common.py @@ -9,6 +9,9 @@ from invokeai.app.util.metaenum import MetaEnum __workflow_meta_version__ = semver.Version.parse("1.0.0") +WORKFLOW_LIBRARY_DEFAULT_USER_ID = "system" +"""Default user_id for workflows created in single-user mode or migrated from pre-multiuser databases.""" + class ExposedField(BaseModel): nodeId: str @@ -26,6 +29,7 @@ class WorkflowRecordOrderBy(str, Enum, metaclass=MetaEnum): UpdatedAt = "updated_at" OpenedAt = "opened_at" Name = "name" + IsPublic = "is_public" class WorkflowCategory(str, Enum, metaclass=MetaEnum): @@ -100,6 +104,8 @@ class WorkflowRecordDTOBase(BaseModel): opened_at: Optional[Union[datetime.datetime, str]] = Field( default=None, description="The opened timestamp of the workflow." 
) + user_id: str = Field(description="The id of the user who owns this workflow.") + is_public: bool = Field(description="Whether this workflow is shared with all users.") class WorkflowRecordDTO(WorkflowRecordDTOBase): diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py index 0f72f7cd92..c83d87eff6 100644 --- a/invokeai/app/services/workflow_records/workflow_records_sqlite.py +++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py @@ -7,6 +7,7 @@ from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase from invokeai.app.services.workflow_records.workflow_records_common import ( + WORKFLOW_LIBRARY_DEFAULT_USER_ID, Workflow, WorkflowCategory, WorkflowNotFoundError, @@ -36,7 +37,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): with self._db.transaction() as cursor: cursor.execute( """--sql - SELECT workflow_id, workflow, name, created_at, updated_at, opened_at + SELECT workflow_id, workflow, name, created_at, updated_at, opened_at, user_id, is_public FROM workflow_library WHERE workflow_id = ?; """, @@ -47,7 +48,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found") return WorkflowRecordDTO.from_dict(dict(row)) - def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO: + def create(self, workflow: WorkflowWithoutID, user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID) -> WorkflowRecordDTO: if workflow.meta.category is WorkflowCategory.Default: raise ValueError("Default workflows cannot be created via this method") @@ -57,43 +58,98 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): """--sql INSERT OR IGNORE INTO workflow_library ( 
workflow_id, - workflow + workflow, + user_id ) - VALUES (?, ?); + VALUES (?, ?, ?); """, - (workflow_with_id.id, workflow_with_id.model_dump_json()), + (workflow_with_id.id, workflow_with_id.model_dump_json(), user_id), ) return self.get(workflow_with_id.id) - def update(self, workflow: Workflow) -> WorkflowRecordDTO: + def update(self, workflow: Workflow, user_id: Optional[str] = None) -> WorkflowRecordDTO: if workflow.meta.category is WorkflowCategory.Default: raise ValueError("Default workflows cannot be updated") with self._db.transaction() as cursor: - cursor.execute( - """--sql - UPDATE workflow_library - SET workflow = ? - WHERE workflow_id = ? AND category = 'user'; - """, - (workflow.model_dump_json(), workflow.id), - ) + if user_id is not None: + cursor.execute( + """--sql + UPDATE workflow_library + SET workflow = ? + WHERE workflow_id = ? AND category = 'user' AND user_id = ?; + """, + (workflow.model_dump_json(), workflow.id, user_id), + ) + else: + cursor.execute( + """--sql + UPDATE workflow_library + SET workflow = ? + WHERE workflow_id = ? AND category = 'user'; + """, + (workflow.model_dump_json(), workflow.id), + ) return self.get(workflow.id) - def delete(self, workflow_id: str) -> None: + def delete(self, workflow_id: str, user_id: Optional[str] = None) -> None: if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default: raise ValueError("Default workflows cannot be deleted") with self._db.transaction() as cursor: - cursor.execute( - """--sql - DELETE from workflow_library - WHERE workflow_id = ? AND category = 'user'; - """, - (workflow_id,), - ) + if user_id is not None: + cursor.execute( + """--sql + DELETE from workflow_library + WHERE workflow_id = ? AND category = 'user' AND user_id = ?; + """, + (workflow_id, user_id), + ) + else: + cursor.execute( + """--sql + DELETE from workflow_library + WHERE workflow_id = ? 
AND category = 'user'; + """, + (workflow_id,), + ) return None + def update_is_public(self, workflow_id: str, is_public: bool, user_id: Optional[str] = None) -> WorkflowRecordDTO: + """Updates the is_public field of a workflow and manages the 'shared' tag automatically.""" + record = self.get(workflow_id) + workflow = record.workflow + + # Manage "shared" tag: add when public, remove when private + tags_list = [t.strip() for t in workflow.tags.split(",") if t.strip()] if workflow.tags else [] + if is_public and "shared" not in tags_list: + tags_list.append("shared") + elif not is_public and "shared" in tags_list: + tags_list.remove("shared") + updated_tags = ", ".join(tags_list) + updated_workflow = workflow.model_copy(update={"tags": updated_tags}) + + with self._db.transaction() as cursor: + if user_id is not None: + cursor.execute( + """--sql + UPDATE workflow_library + SET workflow = ?, is_public = ? + WHERE workflow_id = ? AND category = 'user' AND user_id = ?; + """, + (updated_workflow.model_dump_json(), is_public, workflow_id, user_id), + ) + else: + cursor.execute( + """--sql + UPDATE workflow_library + SET workflow = ?, is_public = ? + WHERE workflow_id = ? AND category = 'user'; + """, + (updated_workflow.model_dump_json(), is_public, workflow_id), + ) + return self.get(workflow_id) + def get_many( self, order_by: WorkflowRecordOrderBy, @@ -104,6 +160,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): query: Optional[str] = None, tags: Optional[list[str]] = None, has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> PaginatedResults[WorkflowRecordListItemDTO]: with self._db.transaction() as cursor: # sanitize! 
@@ -122,7 +180,9 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): created_at, updated_at, opened_at, - tags + tags, + user_id, + is_public FROM workflow_library """ count_query = "SELECT COUNT(*) FROM workflow_library" @@ -177,6 +237,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): conditions.append(query_condition) params.extend([wildcard_query, wildcard_query, wildcard_query]) + if user_id is not None: + # Scope to the given user but always include default workflows + conditions.append("(user_id = ? OR category = 'default')") + params.append(user_id) + + if is_public is True: + conditions.append("is_public = TRUE") + elif is_public is False: + conditions.append("is_public = FALSE") + if conditions: # If there are conditions, add a WHERE clause and then join the conditions main_query += " WHERE " @@ -226,6 +296,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): tags: list[str], categories: Optional[list[WorkflowCategory]] = None, has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: if not tags: return {} @@ -248,6 +320,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): elif has_been_opened is False: base_conditions.append("opened_at IS NULL") + if user_id is not None: + # Scope to the given user but always include default workflows + base_conditions.append("(user_id = ? 
OR category = 'default')") + base_params.append(user_id) + + if is_public is True: + base_conditions.append("is_public = TRUE") + elif is_public is False: + base_conditions.append("is_public = FALSE") + # For each tag to count, run a separate query for tag in tags: # Start with the base conditions @@ -277,6 +359,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): self, categories: list[WorkflowCategory], has_been_opened: Optional[bool] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> dict[str, int]: with self._db.transaction() as cursor: result: dict[str, int] = {} @@ -296,6 +380,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): elif has_been_opened is False: base_conditions.append("opened_at IS NULL") + if user_id is not None: + # Scope to the given user but always include default workflows + base_conditions.append("(user_id = ? OR category = 'default')") + base_params.append(user_id) + + if is_public is True: + base_conditions.append("is_public = TRUE") + elif is_public is False: + base_conditions.append("is_public = FALSE") + # For each category to count, run a separate query for category in categories: # Start with the base conditions @@ -321,20 +415,32 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): return result - def update_opened_at(self, workflow_id: str) -> None: + def update_opened_at(self, workflow_id: str, user_id: Optional[str] = None) -> None: with self._db.transaction() as cursor: - cursor.execute( - f"""--sql - UPDATE workflow_library - SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW') - WHERE workflow_id = ?; - """, - (workflow_id,), - ) + if user_id is not None: + cursor.execute( + f"""--sql + UPDATE workflow_library + SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW') + WHERE workflow_id = ? 
AND user_id = ?; + """, + (workflow_id, user_id), + ) + else: + cursor.execute( + f"""--sql + UPDATE workflow_library + SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW') + WHERE workflow_id = ?; + """, + (workflow_id,), + ) def get_all_tags( self, categories: Optional[list[WorkflowCategory]] = None, + user_id: Optional[str] = None, + is_public: Optional[bool] = None, ) -> list[str]: with self._db.transaction() as cursor: conditions: list[str] = [] @@ -349,6 +455,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase): conditions.append(f"category IN ({placeholders})") params.extend([category.value for category in categories]) + if user_id is not None: + # Scope to the given user but always include default workflows + conditions.append("(user_id = ? OR category = 'default')") + params.append(user_id) + + if is_public is True: + conditions.append("is_public = TRUE") + elif is_public is False: + conditions.append("is_public = FALSE") + stmt = """--sql SELECT DISTINCT tags FROM workflow_library diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 0e2faeca39..08dc9a2265 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -93,6 +93,29 @@ COGVIEW4_LATENT_RGB_FACTORS = [ [-0.00955853, -0.00980067, -0.00977842], ] +# Qwen Image uses the same VAE as Wan 2.1 (16-channel). 
+# Factors from ComfyUI: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py +QWEN_IMAGE_LATENT_RGB_FACTORS = [ + [-0.1299, -0.1692, 0.2932], + [0.0671, 0.0406, 0.0442], + [0.3568, 0.2548, 0.1747], + [0.0372, 0.2344, 0.1420], + [0.0313, 0.0189, -0.0328], + [0.0296, -0.0956, -0.0665], + [-0.3477, -0.4059, -0.2925], + [0.0166, 0.1902, 0.1975], + [-0.0412, 0.0267, -0.1364], + [-0.1293, 0.0740, 0.1636], + [0.0680, 0.3019, 0.1128], + [0.0032, 0.0581, 0.0639], + [-0.1251, 0.0927, 0.1699], + [0.0060, -0.0633, 0.0005], + [0.3477, 0.2275, 0.2950], + [0.1984, 0.0913, 0.1861], +] + +QWEN_IMAGE_LATENT_RGB_BIAS = [-0.1835, -0.0868, -0.3360] + # FLUX.2 uses 32 latent channels. # Factors from ComfyUI: https://github.com/Comfy-Org/ComfyUI/blob/main/comfy/latent_formats.py FLUX2_LATENT_RGB_FACTORS = [ @@ -232,6 +255,9 @@ def diffusion_step_callback( latent_rgb_factors = SD3_5_LATENT_RGB_FACTORS elif base_model == BaseModelType.CogView4: latent_rgb_factors = COGVIEW4_LATENT_RGB_FACTORS + elif base_model == BaseModelType.QwenImage: + latent_rgb_factors = QWEN_IMAGE_LATENT_RGB_FACTORS + latent_rgb_bias = QWEN_IMAGE_LATENT_RGB_BIAS elif base_model == BaseModelType.Flux: latent_rgb_factors = FLUX_LATENT_RGB_FACTORS elif base_model == BaseModelType.Flux2: diff --git a/invokeai/backend/model_manager/configs/factory.py b/invokeai/backend/model_manager/configs/factory.py index 3865ea562a..4d26b4c334 100644 --- a/invokeai/backend/model_manager/configs/factory.py +++ b/invokeai/backend/model_manager/configs/factory.py @@ -50,6 +50,7 @@ from invokeai.backend.model_manager.configs.lora import ( LoRA_LyCORIS_Anima_Config, LoRA_LyCORIS_Flux2_Config, LoRA_LyCORIS_FLUX_Config, + LoRA_LyCORIS_QwenImage_Config, LoRA_LyCORIS_SD1_Config, LoRA_LyCORIS_SD2_Config, LoRA_LyCORIS_SDXL_Config, @@ -71,6 +72,7 @@ from invokeai.backend.model_manager.configs.main import ( Main_Diffusers_CogView4_Config, Main_Diffusers_Flux2_Config, Main_Diffusers_FLUX_Config, + 
Main_Diffusers_QwenImage_Config, Main_Diffusers_SD1_Config, Main_Diffusers_SD2_Config, Main_Diffusers_SD3_Config, @@ -79,6 +81,7 @@ from invokeai.backend.model_manager.configs.main import ( Main_Diffusers_ZImage_Config, Main_GGUF_Flux2_Config, Main_GGUF_FLUX_Config, + Main_GGUF_QwenImage_Config, Main_GGUF_ZImage_Config, MainModelDefaultSettings, ) @@ -163,6 +166,7 @@ AnyModelConfig = Annotated[ Annotated[Main_Diffusers_FLUX_Config, Main_Diffusers_FLUX_Config.get_tag()], Annotated[Main_Diffusers_Flux2_Config, Main_Diffusers_Flux2_Config.get_tag()], Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()], + Annotated[Main_Diffusers_QwenImage_Config, Main_Diffusers_QwenImage_Config.get_tag()], Annotated[Main_Diffusers_ZImage_Config, Main_Diffusers_ZImage_Config.get_tag()], # Main (Pipeline) - checkpoint format # IMPORTANT: FLUX.2 must be checked BEFORE FLUX.1 because FLUX.2 has specific validation @@ -181,6 +185,7 @@ AnyModelConfig = Annotated[ Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()], Annotated[Main_GGUF_Flux2_Config, Main_GGUF_Flux2_Config.get_tag()], Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()], + Annotated[Main_GGUF_QwenImage_Config, Main_GGUF_QwenImage_Config.get_tag()], Annotated[Main_GGUF_ZImage_Config, Main_GGUF_ZImage_Config.get_tag()], # VAE - checkpoint format Annotated[VAE_Checkpoint_SD1_Config, VAE_Checkpoint_SD1_Config.get_tag()], @@ -213,6 +218,7 @@ AnyModelConfig = Annotated[ Annotated[LoRA_LyCORIS_Flux2_Config, LoRA_LyCORIS_Flux2_Config.get_tag()], Annotated[LoRA_LyCORIS_FLUX_Config, LoRA_LyCORIS_FLUX_Config.get_tag()], Annotated[LoRA_LyCORIS_ZImage_Config, LoRA_LyCORIS_ZImage_Config.get_tag()], + Annotated[LoRA_LyCORIS_QwenImage_Config, LoRA_LyCORIS_QwenImage_Config.get_tag()], Annotated[LoRA_LyCORIS_Anima_Config, LoRA_LyCORIS_Anima_Config.get_tag()], # LoRA - OMI format Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()], diff --git 
a/invokeai/backend/model_manager/configs/lora.py b/invokeai/backend/model_manager/configs/lora.py index 65f2d1c08c..88f917d0d3 100644 --- a/invokeai/backend/model_manager/configs/lora.py +++ b/invokeai/backend/model_manager/configs/lora.py @@ -772,6 +772,85 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): raise NotAMatchError("model does not look like a Z-Image LoRA") +class LoRA_LyCORIS_QwenImage_Config(LoRA_LyCORIS_Config_Base, Config_Base): + """Model config for Qwen Image Edit LoRA models in LyCORIS format.""" + + base: Literal[BaseModelType.QwenImage] = Field(default=BaseModelType.QwenImage) + + @classmethod + def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None: + """Qwen Image Edit LoRAs have keys like transformer_blocks.X.attn.to_k.lora_down.weight.""" + state_dict = mod.load_state_dict() + + has_qwen_ie_keys = state_dict_has_any_keys_starting_with( + state_dict, + { + "transformer_blocks.", + "transformer.transformer_blocks.", + "lora_unet_transformer_blocks_", # Kohya format + }, + ) + has_lora_suffix = state_dict_has_any_keys_ending_with( + state_dict, + { + "lora_A.weight", + "lora_B.weight", + "lora_down.weight", + "lora_up.weight", + "dora_scale", + "lokr_w1", + "lokr_w2", # LoKR format + }, + ) + # Must NOT have diffusion_model.layers (Z-Image) or Flux-style keys. + # Flux LoRAs can have transformer.single_transformer_blocks or transformer.transformer_blocks + # (with the "transformer." prefix and "single_" variant) which would falsely match our check. + # Flux Kohya LoRAs use lora_unet_double_blocks or lora_unet_single_blocks. 
+ has_z_image_keys = state_dict_has_any_keys_starting_with(state_dict, {"diffusion_model.layers."}) + has_flux_keys = state_dict_has_any_keys_starting_with( + state_dict, + { + "double_blocks.", + "single_blocks.", + "single_transformer_blocks.", + "transformer.single_transformer_blocks.", + "lora_unet_double_blocks_", + "lora_unet_single_blocks_", + "lora_unet_single_transformer_blocks_", + }, + ) + + if has_qwen_ie_keys and has_lora_suffix and not has_z_image_keys and not has_flux_keys: + return + + raise NotAMatchError("model does not match Qwen Image LoRA heuristics") + + @classmethod + def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType: + state_dict = mod.load_state_dict() + has_qwen_ie_keys = state_dict_has_any_keys_starting_with( + state_dict, + {"transformer_blocks.", "transformer.transformer_blocks.", "lora_unet_transformer_blocks_"}, + ) + has_z_image_keys = state_dict_has_any_keys_starting_with(state_dict, {"diffusion_model.layers."}) + has_flux_keys = state_dict_has_any_keys_starting_with( + state_dict, + { + "double_blocks.", + "single_blocks.", + "single_transformer_blocks.", + "transformer.single_transformer_blocks.", + "lora_unet_double_blocks_", + "lora_unet_single_blocks_", + "lora_unet_single_transformer_blocks_", + }, + ) + + if has_qwen_ie_keys and not has_z_image_keys and not has_flux_keys: + return BaseModelType.QwenImage + raise NotAMatchError("model does not look like a Qwen Image Edit LoRA") + + class LoRA_LyCORIS_Anima_Config(LoRA_LyCORIS_Config_Base, Config_Base): """Model config for Anima LoRA models in LyCORIS format.""" diff --git a/invokeai/backend/model_manager/configs/main.py b/invokeai/backend/model_manager/configs/main.py index ee9abe54fa..1be349f394 100644 --- a/invokeai/backend/model_manager/configs/main.py +++ b/invokeai/backend/model_manager/configs/main.py @@ -28,6 +28,7 @@ from invokeai.backend.model_manager.taxonomy import ( ModelFormat, ModelType, ModelVariantType, + QwenImageVariantType, 
SchedulerPredictionType, SubModelType, ZImageVariantType, @@ -86,6 +87,8 @@ class MainModelDefaultSettings(BaseModel): else: # Distilled models (Klein 4B, Klein 9B) use fewer steps return cls(steps=4, cfg_scale=1.0, width=1024, height=1024) + case BaseModelType.QwenImage: + return cls(steps=40, cfg_scale=4.0, width=1024, height=1024) case _: # TODO(psyche): Do we want defaults for other base types? return None @@ -196,9 +199,11 @@ class Main_SD_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base): cls._validate_base(mod) - prediction_type = override_fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod) + prediction_type = override_fields.pop("prediction_type", None) or cls._get_scheduler_prediction_type_or_raise( + mod + ) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) return cls(**override_fields, prediction_type=prediction_type, variant=variant) @@ -471,7 +476,7 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf cls._validate_does_not_look_like_gguf_quantized(mod) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) return cls(**override_fields, variant=variant) @@ -546,7 +551,7 @@ class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Con cls._validate_does_not_look_like_gguf_quantized(mod) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) return cls(**override_fields, variant=variant) @@ -609,7 +614,7 @@ class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_B cls._validate_model_looks_like_bnb_quantized(mod) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = 
override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) return cls(**override_fields, variant=variant) @@ -660,7 +665,7 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Bas cls._validate_is_not_flux2(mod) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) return cls(**override_fields, variant=variant) @@ -718,7 +723,7 @@ class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Ba cls._validate_is_flux2(mod) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) return cls(**override_fields, variant=variant) @@ -779,9 +784,9 @@ class Main_Diffusers_FLUX_Config(Diffusers_Config_Base, Main_Config_Base, Config }, ) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) - repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod) return cls( **override_fields, @@ -833,9 +838,9 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi }, ) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) - repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod) return cls( **override_fields, @@ -904,11 +909,13 @@ class Main_SD_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base): cls._validate_base(mod) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = 
override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) - prediction_type = override_fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod) + prediction_type = override_fields.pop("prediction_type", None) or cls._get_scheduler_prediction_type_or_raise( + mod + ) - repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod) return cls( **override_fields, @@ -1014,9 +1021,9 @@ class Main_Diffusers_SD3_Config(Diffusers_Config_Base, Main_Config_Base, Config_ }, ) - submodels = override_fields.get("submodels") or cls._get_submodels_or_raise(mod) + submodels = override_fields.pop("submodels", None) or cls._get_submodels_or_raise(mod) - repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod) return cls( **override_fields, @@ -1089,7 +1096,7 @@ class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Co }, ) - repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod) return cls( **override_fields, @@ -1155,9 +1162,9 @@ class Main_Diffusers_ZImage_Config(Diffusers_Config_Base, Main_Config_Base, Conf }, ) - variant = override_fields.get("variant") or cls._get_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod) - repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod) + repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod) return cls( **override_fields, @@ -1201,7 +1208,7 @@ class Main_Checkpoint_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Co 
cls._validate_does_not_look_like_gguf_quantized(mod) - variant = override_fields.get("variant", ZImageVariantType.Turbo) + variant = override_fields.pop("variant", None) or ZImageVariantType.Turbo return cls(**override_fields, variant=variant) @@ -1235,7 +1242,7 @@ class Main_GGUF_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_B cls._validate_looks_like_gguf_quantized(mod) - variant = override_fields.get("variant", ZImageVariantType.Turbo) + variant = override_fields.pop("variant", None) or ZImageVariantType.Turbo return cls(**override_fields, variant=variant) @@ -1252,6 +1259,106 @@ class Main_GGUF_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_B raise NotAMatchError("state dict does not look like GGUF quantized") +class Main_Diffusers_QwenImage_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base): + """Model config for Qwen Image diffusers models (both txt2img and edit).""" + + base: Literal[BaseModelType.QwenImage] = Field(BaseModelType.QwenImage) + variant: QwenImageVariantType | None = Field(default=None) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_dir(mod) + + raise_for_override_fields(cls, override_fields) + + # This check implies the base type - no further validation needed. 
+ raise_for_class_name( + common_config_paths(mod.path), + { + "QwenImagePlusPipeline", + "QwenImageEditPlusPipeline", + "QwenImagePipeline", + }, + ) + + repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod) + variant = override_fields.pop("variant", None) or cls._get_qwen_image_variant(mod) + + return cls( + **override_fields, + repo_variant=repo_variant, + variant=variant, + ) + + @classmethod + def _get_qwen_image_variant(cls, mod: ModelOnDisk) -> QwenImageVariantType: + """Detect whether this is an edit or txt2img model from the pipeline class name.""" + import json + + model_index = mod.path / "model_index.json" + if model_index.exists(): + with open(model_index) as f: + config = json.load(f) + class_name = config.get("_class_name", "") + if "Edit" in class_name: + return QwenImageVariantType.Edit + return QwenImageVariantType.Generate + + +def _has_qwen_image_keys(state_dict: dict[str | int, Any]) -> bool: + """Check if state dict contains Qwen Image Edit transformer keys. + + Qwen Image Edit uses 'txt_in' and 'txt_norm' instead of 'context_embedder' (FLUX). + This distinguishes it from FLUX and other architectures. 
+ """ + has_txt_in = any(isinstance(k, str) and k.startswith("txt_in.") for k in state_dict.keys()) + has_txt_norm = any(isinstance(k, str) and k.startswith("txt_norm.") for k in state_dict.keys()) + has_img_in = any(isinstance(k, str) and k.startswith("img_in.") for k in state_dict.keys()) + # Must NOT have context_embedder (which would indicate FLUX) + has_context_embedder = any(isinstance(k, str) and "context_embedder" in k for k in state_dict.keys()) + return has_txt_in and has_txt_norm and has_img_in and not has_context_embedder + + +class Main_GGUF_QwenImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): + """Model config for GGUF-quantized Qwen Image transformer models.""" + + base: Literal[BaseModelType.QwenImage] = Field(default=BaseModelType.QwenImage) + format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized) + variant: QwenImageVariantType | None = Field(default=None) + + @classmethod + def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self: + raise_if_not_file(mod) + + raise_for_override_fields(cls, override_fields) + + sd = mod.load_state_dict() + + if not _has_qwen_image_keys(sd): + raise NotAMatchError("state dict does not look like a Qwen Image Edit model") + + if not _has_ggml_tensors(sd): + raise NotAMatchError("state dict does not look like GGUF quantized") + + # Infer variant from the state dict if not explicitly provided. + # The Edit variant includes an extra tensor `__index_timestep_zero__` (used by the + # `zero_cond_t` dual-modulation path in diffusers' QwenImageTransformer2DModel). + # If the marker tensor is missing, fall back to the filename heuristic since older + # or alternate GGUF converters may not emit it. 
+ explicit_variant = override_fields.pop("variant", None) + if explicit_variant is None: + if "__index_timestep_zero__" in sd: + explicit_variant = QwenImageVariantType.Edit + else: + filename = mod.path.stem.lower() + if "edit" in filename: + explicit_variant = QwenImageVariantType.Edit + else: + explicit_variant = QwenImageVariantType.Generate + + return cls(**override_fields, variant=explicit_variant) + + class Main_Checkpoint_Anima_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base): """Model config for Anima single-file checkpoint models (safetensors). diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index 38d5aebeaa..6cf06d4807 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -57,6 +57,9 @@ from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils is_state_dict_likely_in_flux_xlabs_format, lora_model_from_flux_xlabs_state_dict, ) +from invokeai.backend.patches.lora_conversions.qwen_image_lora_conversion_utils import ( + lora_model_from_qwen_image_state_dict, +) from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import lora_model_from_z_image_state_dict @@ -162,6 +165,8 @@ class LoRALoader(ModelLoader): # Z-Image LoRAs use diffusers PEFT format with transformer and/or Qwen3 encoder layers. # We set alpha=None to use rank as alpha (common default). 
model = lora_model_from_z_image_state_dict(state_dict=state_dict, alpha=None) + elif self._model_base == BaseModelType.QwenImage: + model = lora_model_from_qwen_image_state_dict(state_dict=state_dict, alpha=None) elif self._model_base == BaseModelType.Anima: # Anima LoRAs use Kohya-style or diffusers PEFT format targeting Cosmos DiT blocks. model = lora_model_from_anima_state_dict(state_dict=state_dict, alpha=None) diff --git a/invokeai/backend/model_manager/load/model_loaders/qwen_image.py b/invokeai/backend/model_manager/load/model_loaders/qwen_image.py new file mode 100644 index 0000000000..a025e72794 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_loaders/qwen_image.py @@ -0,0 +1,177 @@ +from pathlib import Path +from typing import Optional + +import accelerate +import torch + +from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base +from invokeai.backend.model_manager.configs.factory import AnyModelConfig +from invokeai.backend.model_manager.configs.main import Main_GGUF_QwenImage_Config +from invokeai.backend.model_manager.load.load_default import ModelLoader +from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry +from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader +from invokeai.backend.model_manager.taxonomy import ( + AnyModel, + BaseModelType, + ModelFormat, + ModelType, + QwenImageVariantType, + SubModelType, +) +from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor +from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader +from invokeai.backend.util.devices import TorchDevice + + +@ModelLoaderRegistry.register(base=BaseModelType.QwenImage, type=ModelType.Main, format=ModelFormat.Diffusers) +class QwenImageDiffusersModel(GenericDiffusersLoader): + """Class to load Qwen Image Edit main models.""" + + def _load_model( + self, + config: AnyModelConfig, + submodel_type: 
Optional[SubModelType] = None, + ) -> AnyModel: + if isinstance(config, Checkpoint_Config_Base): + raise NotImplementedError("CheckpointConfigBase is not implemented for Qwen Image Edit models.") + + if submodel_type is None: + raise Exception("A submodel type must be provided when loading main pipelines.") + + model_path = Path(config.path) + load_class = self.get_hf_load_class(model_path, submodel_type) + repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None + variant = repo_variant.value if repo_variant else None + model_path = model_path / submodel_type.value + + # We force bfloat16 for Qwen Image Edit models. + # Use `dtype` (newer) with fallback to `torch_dtype` (older diffusers). + dtype_kwarg = {"dtype": torch.bfloat16} + try: + result: AnyModel = load_class.from_pretrained( + model_path, + **dtype_kwarg, + variant=variant, + local_files_only=True, + ) + except TypeError: + # Older diffusers uses torch_dtype instead of dtype + dtype_kwarg = {"torch_dtype": torch.bfloat16} + result = load_class.from_pretrained( + model_path, + **dtype_kwarg, + variant=variant, + local_files_only=True, + ) + except OSError as e: + if variant and "no file named" in str(e): + result = load_class.from_pretrained(model_path, **dtype_kwarg, local_files_only=True) + else: + raise e + + return result + + +@ModelLoaderRegistry.register(base=BaseModelType.QwenImage, type=ModelType.Main, format=ModelFormat.GGUFQuantized) +class QwenImageGGUFCheckpointModel(ModelLoader): + """Class to load GGUF-quantized Qwen Image Edit transformer models.""" + + def _load_model( + self, + config: AnyModelConfig, + submodel_type: Optional[SubModelType] = None, + ) -> AnyModel: + if not isinstance(config, Checkpoint_Config_Base): + raise ValueError("Only CheckpointConfigBase models are currently supported here.") + + match submodel_type: + case SubModelType.Transformer: + return self._load_from_singlefile(config) + + raise ValueError( + f"Only Transformer submodels are 
currently supported. Received: {submodel_type.value if submodel_type else 'None'}" + ) + + def _load_from_singlefile(self, config: AnyModelConfig) -> AnyModel: + from diffusers import QwenImageTransformer2DModel + + if not isinstance(config, Main_GGUF_QwenImage_Config): + raise TypeError(f"Expected Main_GGUF_QwenImage_Config, got {type(config).__name__}.") + model_path = Path(config.path) + + target_device = TorchDevice.choose_torch_device() + compute_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device) + + sd = gguf_sd_loader(model_path, compute_dtype=compute_dtype) + + # Strip ComfyUI-style prefixes if present + prefix_to_strip = None + for prefix in ["model.diffusion_model.", "diffusion_model."]: + if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)): + prefix_to_strip = prefix + break + + if prefix_to_strip: + stripped_sd = {} + for key, value in sd.items(): + if isinstance(key, str) and key.startswith(prefix_to_strip): + stripped_sd[key[len(prefix_to_strip) :]] = value + else: + stripped_sd[key] = value + sd = stripped_sd + + # Auto-detect architecture from state dict + num_layers = 0 + for key in sd.keys(): + if isinstance(key, str) and key.startswith("transformer_blocks."): + parts = key.split(".") + if len(parts) >= 2: + try: + layer_idx = int(parts[1]) + num_layers = max(num_layers, layer_idx + 1) + except ValueError: + pass + + # Detect dimensions from weights + num_attention_heads = 24 # default + attention_head_dim = 128 # default + + if "img_in.weight" in sd: + w = sd["img_in.weight"] + shape = w.tensor_shape if isinstance(w, GGMLTensor) else w.shape + hidden_dim = shape[0] + in_channels = shape[1] + num_attention_heads = hidden_dim // attention_head_dim + + joint_attention_dim = 3584 # default + if "txt_in.weight" in sd: + w = sd["txt_in.weight"] + shape = w.tensor_shape if isinstance(w, GGMLTensor) else w.shape + joint_attention_dim = shape[1] + + model_config: dict = { + "patch_size": 2, + "in_channels": in_channels if 
"img_in.weight" in sd else 64, + "out_channels": 16, + "num_layers": num_layers if num_layers > 0 else 60, + "attention_head_dim": attention_head_dim, + "num_attention_heads": num_attention_heads, + "joint_attention_dim": joint_attention_dim, + "guidance_embeds": False, + "axes_dims_rope": (16, 56, 56), + } + + # zero_cond_t is only used by edit-variant models. It enables dual modulation + # for noisy vs reference patches. Setting it on txt2img models produces garbage. + # Also requires diffusers 0.37+ (the parameter doesn't exist in older versions). + import inspect + + is_edit = getattr(config, "variant", None) == QwenImageVariantType.Edit + if is_edit and "zero_cond_t" in inspect.signature(QwenImageTransformer2DModel.__init__).parameters: + model_config["zero_cond_t"] = True + + with accelerate.init_empty_weights(): + model = QwenImageTransformer2DModel(**model_config) + + model.load_state_dict(sd, strict=False, assign=True) + return model diff --git a/invokeai/backend/model_manager/metadata/fetch/huggingface.py b/invokeai/backend/model_manager/metadata/fetch/huggingface.py index 1b2b6c3674..30fe418fe1 100644 --- a/invokeai/backend/model_manager/metadata/fetch/huggingface.py +++ b/invokeai/backend/model_manager/metadata/fetch/huggingface.py @@ -19,8 +19,7 @@ from pathlib import Path from typing import Optional import requests -from huggingface_hub import HfApi, configure_http_backend, hf_hub_url -from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundError +from huggingface_hub import hf_hub_url from pydantic.networks import AnyHttpUrl from requests.sessions import Session @@ -47,7 +46,6 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): this module without an internet connection. 
""" self._requests = session or requests.Session() - configure_http_backend(backend_factory=lambda: self._requests) @classmethod def from_json(cls, json: str) -> HuggingFaceMetadata: @@ -55,6 +53,22 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): metadata = HuggingFaceMetadata.model_validate_json(json) return metadata + def _fetch_model_info(self, repo_id: str, variant: Optional[ModelRepoVariant] = None) -> dict: + """Fetch model info from HuggingFace API using self._requests session. + + This allows the session to be mocked in tests via requests_testadapter. + """ + url = f"https://huggingface.co/api/models/{repo_id}" + params: dict[str, str] = {"blobs": "True"} + if variant is not None: + params["revision"] = str(variant) + + response = self._requests.get(url, params=params) + if response.status_code == 404: + raise UnknownMetadataException(f"'{repo_id}' not found.") + response.raise_for_status() + return response.json() + def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata: """Return a HuggingFaceMetadata object given the model's repo_id.""" # Little loop which tries fetching a revision corresponding to the selected variant. @@ -67,10 +81,10 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): repo_id = id.split("::")[0] or id while not model_info: try: - model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True, revision=variant) - except RepositoryNotFoundError as excp: - raise UnknownMetadataException(f"'{repo_id}' not found. 
See trace for details.") from excp - except RevisionNotFoundError: + model_info = self._fetch_model_info(repo_id, variant) + except UnknownMetadataException: + raise + except requests.HTTPError: if variant is None: raise else: @@ -80,15 +94,18 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): _, name = repo_id.split("/") - for s in model_info.siblings or []: - assert s.rfilename is not None - assert s.size is not None + for s in model_info.get("siblings") or []: + rfilename = s.get("rfilename") + size = s.get("size") + assert rfilename is not None + assert size is not None + lfs = s.get("lfs") files.append( RemoteModelFile( - url=hf_hub_url(repo_id, s.rfilename, revision=variant or "main"), - path=Path(name, s.rfilename), - size=s.size, - sha256=s.lfs.get("sha256") if s.lfs else None, + url=hf_hub_url(repo_id, rfilename, revision=variant or "main"), + path=Path(name, rfilename), + size=size, + sha256=lfs.get("sha256") if lfs else None, ) ) @@ -115,10 +132,10 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase): ) return HuggingFaceMetadata( - id=model_info.id, + id=model_info["id"], name=name, files=files, - api_response=json.dumps(model_info.__dict__, default=str), + api_response=json.dumps(model_info, default=str), is_diffusers=is_diffusers, ckpt_urls=ckpt_urls, ) diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index e16ad4cbc4..b048144e54 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -17,7 +17,7 @@ remote repo. 
from pathlib import Path from typing import List, Literal, Optional, Union -from huggingface_hub import configure_http_backend, hf_hub_url +from huggingface_hub import hf_hub_url from pydantic import BaseModel, Field, TypeAdapter from pydantic.networks import AnyHttpUrl from requests.sessions import Session @@ -111,7 +111,6 @@ class HuggingFaceMetadata(ModelMetadataWithFiles): full-precision model is returned. """ session = session or Session() - configure_http_backend(backend_factory=lambda: session) # used in testing paths = filter_files([x.path for x in self.files], variant, subfolder, subfolders) # all files in the model diff --git a/invokeai/backend/model_manager/starter_models.py b/invokeai/backend/model_manager/starter_models.py index edcac321f1..c93a606aa8 100644 --- a/invokeai/backend/model_manager/starter_models.py +++ b/invokeai/backend/model_manager/starter_models.py @@ -9,7 +9,13 @@ from invokeai.backend.model_manager.configs.external_api import ( ExternalModelPanelSchema, ExternalResolutionPreset, ) -from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType +from invokeai.backend.model_manager.taxonomy import ( + AnyVariant, + BaseModelType, + ModelFormat, + ModelType, + QwenImageVariantType, +) class StarterModelWithoutDependencies(BaseModel): @@ -19,6 +25,7 @@ class StarterModelWithoutDependencies(BaseModel): base: BaseModelType type: ModelType format: Optional[ModelFormat] = None + variant: Optional[AnyVariant] = None is_installed: bool = False capabilities: ExternalModelCapabilities | None = None default_settings: ExternalApiModelDefaultSettings | None = None @@ -659,6 +666,138 @@ cogview4 = StarterModel( ) # endregion +# region Qwen Image Edit +qwen_image_edit = StarterModel( + name="Qwen Image Edit 2511", + base=BaseModelType.QwenImage, + source="Qwen/Qwen-Image-Edit-2511", + description="Qwen Image Edit 2511 full diffusers model. Supports text-guided image editing with multiple reference images. 
(~40GB)", + type=ModelType.Main, + variant=QwenImageVariantType.Edit, +) + +qwen_image_edit_gguf_q4_k_m = StarterModel( + name="Qwen Image Edit 2511 (Q4_K_M)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-Edit-2511-GGUF/resolve/main/qwen-image-edit-2511-Q4_K_M.gguf", + description="Qwen Image Edit 2511 - Q4_K_M quantized transformer. Good quality/size balance. (~13GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, + variant=QwenImageVariantType.Edit, +) + +qwen_image_edit_gguf_q2_k = StarterModel( + name="Qwen Image Edit 2511 (Q2_K)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-Edit-2511-GGUF/resolve/main/qwen-image-edit-2511-Q2_K.gguf", + description="Qwen Image Edit 2511 - Q2_K heavily quantized transformer. Smallest size, lower quality. (~7.5GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, + variant=QwenImageVariantType.Edit, +) + +qwen_image_edit_gguf_q6_k = StarterModel( + name="Qwen Image Edit 2511 (Q6_K)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-Edit-2511-GGUF/resolve/main/qwen-image-edit-2511-Q6_K.gguf", + description="Qwen Image Edit 2511 - Q6_K quantized transformer. Near-lossless quality. (~17GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, + variant=QwenImageVariantType.Edit, +) + +qwen_image_edit_gguf_q8_0 = StarterModel( + name="Qwen Image Edit 2511 (Q8_0)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-Edit-2511-GGUF/resolve/main/qwen-image-edit-2511-Q8_0.gguf", + description="Qwen Image Edit 2511 - Q8_0 quantized transformer. Highest quality quantization. 
(~22GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, + variant=QwenImageVariantType.Edit, +) + +qwen_image_edit_lightning_4step = StarterModel( + name="Qwen Image Edit Lightning (4-step, bf16)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/lightx2v/Qwen-Image-Edit-2511-Lightning/resolve/main/Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors", + description="Lightning distillation LoRA for Qwen Image Edit — enables generation in just 4 steps. " + "Settings: Steps=4, CFG=1, Shift Override=3.", + type=ModelType.LoRA, +) + +qwen_image_edit_lightning_8step = StarterModel( + name="Qwen Image Edit Lightning (8-step, bf16)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/lightx2v/Qwen-Image-Edit-2511-Lightning/resolve/main/Qwen-Image-Edit-2511-Lightning-8steps-V1.0-bf16.safetensors", + description="Lightning distillation LoRA for Qwen Image Edit — enables generation in 8 steps with better quality. " + "Settings: Steps=8, CFG=1, Shift Override=3.", + type=ModelType.LoRA, +) + +# Qwen Image (txt2img) +qwen_image = StarterModel( + name="Qwen Image 2512", + base=BaseModelType.QwenImage, + source="Qwen/Qwen-Image-2512", + description="Qwen Image 2512 full diffusers model. High-quality text-to-image generation. (~40GB)", + type=ModelType.Main, +) + +qwen_image_gguf_q4_k_m = StarterModel( + name="Qwen Image 2512 (Q4_K_M)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-2512-GGUF/resolve/main/qwen-image-2512-Q4_K_M.gguf", + description="Qwen Image 2512 - Q4_K_M quantized transformer. Good quality/size balance. (~13GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, +) + +qwen_image_gguf_q2_k = StarterModel( + name="Qwen Image 2512 (Q2_K)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-2512-GGUF/resolve/main/qwen-image-2512-Q2_K.gguf", + description="Qwen Image 2512 - Q2_K heavily quantized transformer. Smallest size, lower quality. 
(~7.5GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, +) + +qwen_image_gguf_q6_k = StarterModel( + name="Qwen Image 2512 (Q6_K)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-2512-GGUF/resolve/main/qwen-image-2512-Q6_K.gguf", + description="Qwen Image 2512 - Q6_K quantized transformer. Near-lossless quality. (~17GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, +) + +qwen_image_gguf_q8_0 = StarterModel( + name="Qwen Image 2512 (Q8_0)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/unsloth/Qwen-Image-2512-GGUF/resolve/main/qwen-image-2512-Q8_0.gguf", + description="Qwen Image 2512 - Q8_0 quantized transformer. Highest quality quantization. (~22GB)", + type=ModelType.Main, + format=ModelFormat.GGUFQuantized, +) + +qwen_image_lightning_4step = StarterModel( + name="Qwen Image Lightning (4-step, V2.0, bf16)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors", + description="Lightning distillation LoRA for Qwen Image — enables generation in just 4 steps. " + "Settings: Steps=4, CFG=1, Shift Override=3.", + type=ModelType.LoRA, +) + +qwen_image_lightning_8step = StarterModel( + name="Qwen Image Lightning (8-step, V2.0, bf16)", + base=BaseModelType.QwenImage, + source="https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V2.0-bf16.safetensors", + description="Lightning distillation LoRA for Qwen Image — enables generation in 8 steps with better quality. 
" + "Settings: Steps=8, CFG=1, Shift Override=3.", + type=ModelType.LoRA, +) +# endregion + # region SigLIP siglip = StarterModel( name="SigLIP - google/siglip-so400m-patch14-384", @@ -1225,6 +1364,20 @@ STARTER_MODELS: list[StarterModel] = [ flux2_klein_qwen3_4b_encoder, flux2_klein_qwen3_8b_encoder, cogview4, + qwen_image_edit, + qwen_image_edit_gguf_q2_k, + qwen_image_edit_gguf_q4_k_m, + qwen_image_edit_gguf_q6_k, + qwen_image_edit_gguf_q8_0, + qwen_image_edit_lightning_4step, + qwen_image_edit_lightning_8step, + qwen_image, + qwen_image_gguf_q2_k, + qwen_image_gguf_q4_k_m, + qwen_image_gguf_q6_k, + qwen_image_gguf_q8_0, + qwen_image_lightning_4step, + qwen_image_lightning_8step, flux_krea, flux_krea_quantized, z_image_turbo, @@ -1313,6 +1466,19 @@ flux2_klein_bundle: list[StarterModel] = [ flux2_klein_qwen3_4b_encoder, ] +qwen_image_bundle: list[StarterModel] = [ + qwen_image_edit, + qwen_image_edit_gguf_q4_k_m, + qwen_image_edit_gguf_q8_0, + qwen_image_edit_lightning_4step, + qwen_image_edit_lightning_8step, + qwen_image, + qwen_image_gguf_q4_k_m, + qwen_image_gguf_q8_0, + qwen_image_lightning_4step, + qwen_image_lightning_8step, +] + anima_bundle: list[StarterModel] = [ anima_preview3, anima_qwen3_encoder, @@ -1326,6 +1492,7 @@ STARTER_BUNDLES: dict[str, StarterModelBundle] = { BaseModelType.Flux: StarterModelBundle(name="FLUX.1 dev", models=flux_bundle), BaseModelType.Flux2: StarterModelBundle(name="FLUX.2 Klein", models=flux2_klein_bundle), BaseModelType.ZImage: StarterModelBundle(name="Z-Image Turbo", models=zimage_bundle), + BaseModelType.QwenImage: StarterModelBundle(name="Qwen Image", models=qwen_image_bundle), BaseModelType.Anima: StarterModelBundle(name="Anima", models=anima_bundle), } diff --git a/invokeai/backend/model_manager/taxonomy.py b/invokeai/backend/model_manager/taxonomy.py index 4a20665094..b2b55ebd3f 100644 --- a/invokeai/backend/model_manager/taxonomy.py +++ b/invokeai/backend/model_manager/taxonomy.py @@ -54,6 +54,8 @@ class 
BaseModelType(str, Enum): """Indicates the model is associated with Z-Image model architecture, including Z-Image-Turbo.""" External = "external" """Indicates the model is hosted by an external provider.""" + QwenImage = "qwen-image" + """Indicates the model is associated with Qwen Image Edit 2511 model architecture.""" Anima = "anima" """Indicates the model is associated with Anima model architecture (Cosmos Predict2 DiT + LLM Adapter).""" Unknown = "unknown" @@ -148,6 +150,16 @@ class ZImageVariantType(str, Enum): """Z-Image Base - undistilled foundation model with full CFG and negative prompt support.""" +class QwenImageVariantType(str, Enum): + """Qwen Image model variants.""" + + Generate = "generate" + """Qwen Image - text-to-image generation model.""" + + Edit = "edit" + """Qwen Image Edit - image editing model with reference image support.""" + + class Qwen3VariantType(str, Enum): """Qwen3 text encoder variants based on model size.""" @@ -224,8 +236,28 @@ class FluxLoRAFormat(str, Enum): AnyVariant: TypeAlias = Union[ - ModelVariantType, ClipVariantType, FluxVariantType, Flux2VariantType, ZImageVariantType, Qwen3VariantType + ModelVariantType, + ClipVariantType, + FluxVariantType, + Flux2VariantType, + ZImageVariantType, + QwenImageVariantType, + Qwen3VariantType, ] variant_type_adapter = TypeAdapter[ - ModelVariantType | ClipVariantType | FluxVariantType | Flux2VariantType | ZImageVariantType | Qwen3VariantType -](ModelVariantType | ClipVariantType | FluxVariantType | Flux2VariantType | ZImageVariantType | Qwen3VariantType) + ModelVariantType + | ClipVariantType + | FluxVariantType + | Flux2VariantType + | ZImageVariantType + | QwenImageVariantType + | Qwen3VariantType +]( + ModelVariantType + | ClipVariantType + | FluxVariantType + | Flux2VariantType + | ZImageVariantType + | QwenImageVariantType + | Qwen3VariantType +) diff --git a/invokeai/backend/patches/lora_conversions/qwen_image_lora_constants.py 
b/invokeai/backend/patches/lora_conversions/qwen_image_lora_constants.py new file mode 100644 index 0000000000..727ee5a428 --- /dev/null +++ b/invokeai/backend/patches/lora_conversions/qwen_image_lora_constants.py @@ -0,0 +1,5 @@ +# Qwen Image Edit LoRA prefix constants +# These prefixes are used for key mapping when applying LoRA patches to Qwen Image Edit models + +# Prefix for Qwen Image Edit transformer LoRA layers +QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX = "lora_transformer-" diff --git a/invokeai/backend/patches/lora_conversions/qwen_image_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/qwen_image_lora_conversion_utils.py new file mode 100644 index 0000000000..7fc01f7231 --- /dev/null +++ b/invokeai/backend/patches/lora_conversions/qwen_image_lora_conversion_utils.py @@ -0,0 +1,197 @@ +"""Qwen Image LoRA conversion utilities. + +Qwen Image uses QwenImageTransformer2DModel architecture. +Supports multiple LoRA formats: +- Diffusers/PEFT: transformer_blocks.0.attn.to_k.lora_down.weight +- With prefix: transformer.transformer_blocks.0.attn.to_k.lora_down.weight +- Kohya: lora_unet_transformer_blocks_0_attn_to_k.lora_down.weight (underscores instead of dots) +""" + +import re +from typing import Dict + +import torch + +from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch +from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict +from invokeai.backend.patches.lora_conversions.qwen_image_lora_constants import ( + QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX, +) +from invokeai.backend.patches.model_patch_raw import ModelPatchRaw + +# Regex for Kohya-format Qwen Image LoRA keys. +# Example: lora_unet_transformer_blocks_0_attn_to_k +# Groups: (block_idx, sub_module_with_underscores) +_KOHYA_KEY_REGEX = re.compile(r"lora_unet_transformer_blocks_(\d+)_(.*)") + +# Mapping from Kohya underscore-separated sub-module names to dot-separated model paths. 
+# The Kohya format uses underscores everywhere, but some underscores are part of the +# module name (e.g., add_k_proj, to_out). We match the longest prefix first. +_KOHYA_MODULE_MAP: list[tuple[str, str]] = [ + # Attention projections + ("attn_add_k_proj", "attn.add_k_proj"), + ("attn_add_q_proj", "attn.add_q_proj"), + ("attn_add_v_proj", "attn.add_v_proj"), + ("attn_to_add_out", "attn.to_add_out"), + ("attn_to_out_0", "attn.to_out.0"), + ("attn_to_k", "attn.to_k"), + ("attn_to_q", "attn.to_q"), + ("attn_to_v", "attn.to_v"), + # Image stream MLP and modulation + ("img_mlp_net_0_proj", "img_mlp.net.0.proj"), + ("img_mlp_net_2", "img_mlp.net.2"), + ("img_mod_1", "img_mod.1"), + # Text stream MLP and modulation + ("txt_mlp_net_0_proj", "txt_mlp.net.0.proj"), + ("txt_mlp_net_2", "txt_mlp.net.2"), + ("txt_mod_1", "txt_mod.1"), +] + + +def is_state_dict_likely_kohya_qwen_image(state_dict: dict[str | int, torch.Tensor]) -> bool: + """Check if the state dict uses Kohya-format Qwen Image LoRA keys.""" + str_keys = [k for k in state_dict.keys() if isinstance(k, str)] + if not str_keys: + return False + # Check if any key matches the Kohya pattern + return any(k.startswith("lora_unet_transformer_blocks_") for k in str_keys) + + +def _convert_kohya_key(kohya_layer: str) -> str | None: + """Convert a Kohya-format layer name to a dot-separated model module path. 
+ + Example: lora_unet_transformer_blocks_0_attn_to_k -> transformer_blocks.0.attn.to_k + """ + m = _KOHYA_KEY_REGEX.match(kohya_layer) + if not m: + return None + + block_idx = m.group(1) + sub_module = m.group(2) + + for kohya_name, model_path in _KOHYA_MODULE_MAP: + if sub_module == kohya_name: + return f"transformer_blocks.{block_idx}.{model_path}" + + # Fallback: unknown sub-module, return None so caller can warn/skip + return None + + +def lora_model_from_qwen_image_state_dict( + state_dict: Dict[str, torch.Tensor], alpha: float | None = None +) -> ModelPatchRaw: + """Convert a Qwen Image LoRA state dict to a ModelPatchRaw. + + Handles three key formats: + - Diffusers/PEFT: transformer_blocks.0.attn.to_k.lora_down.weight + - With prefix: transformer.transformer_blocks.0.attn.to_k.lora_down.weight + - Kohya: lora_unet_transformer_blocks_0_attn_to_k.lora_down.weight + """ + is_kohya = is_state_dict_likely_kohya_qwen_image(state_dict) + + if is_kohya: + return _convert_kohya_format(state_dict, alpha) + else: + return _convert_diffusers_format(state_dict, alpha) + + +def _convert_kohya_format(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> ModelPatchRaw: + """Convert Kohya-format state dict. 
Keys are like lora_unet_transformer_blocks_0_attn_to_k.lokr_w1""" + layers: dict[str, BaseLayerPatch] = {} + + # Group by layer (split at first dot: layer_name.param_name) + grouped: dict[str, dict[str, torch.Tensor]] = {} + for key, value in state_dict.items(): + if not isinstance(key, str): + continue + layer_name, param_name = key.split(".", 1) + if layer_name not in grouped: + grouped[layer_name] = {} + grouped[layer_name][param_name] = value + + for kohya_layer, layer_dict in grouped.items(): + model_path = _convert_kohya_key(kohya_layer) + if model_path is None: + continue # Skip unrecognized layers + + layer = any_lora_layer_from_state_dict(layer_dict) + final_key = f"{QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX}{model_path}" + layers[final_key] = layer + + return ModelPatchRaw(layers=layers) + + +def _convert_diffusers_format(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> ModelPatchRaw: + """Convert Diffusers/PEFT format state dict.""" + layers: dict[str, BaseLayerPatch] = {} + + # Some LoRAs use a "transformer." 
prefix on keys + strip_prefixes = ["transformer."] + + grouped = _group_by_layer(state_dict) + + for layer_key, layer_dict in grouped.items(): + values = _normalize_lora_keys(layer_dict, alpha) + layer = any_lora_layer_from_state_dict(values) + clean_key = layer_key + for prefix in strip_prefixes: + if clean_key.startswith(prefix): + clean_key = clean_key[len(prefix) :] + break + final_key = f"{QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX}{clean_key}" + layers[final_key] = layer + + return ModelPatchRaw(layers=layers) + + +def _normalize_lora_keys(layer_dict: dict[str, torch.Tensor], alpha: float | None) -> dict[str, torch.Tensor]: + """Normalize LoRA key names to internal format.""" + if "lora_A.weight" in layer_dict: + values: dict[str, torch.Tensor] = { + "lora_down.weight": layer_dict["lora_A.weight"], + "lora_up.weight": layer_dict["lora_B.weight"], + } + if alpha is not None: + values["alpha"] = torch.tensor(alpha) + return values + elif "lora_down.weight" in layer_dict: + return layer_dict + else: + return layer_dict + + +def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]: + """Group state dict keys by layer path.""" + layer_dict: dict[str, dict[str, torch.Tensor]] = {} + + known_suffixes = [ + ".lora_A.weight", + ".lora_B.weight", + ".lora_down.weight", + ".lora_up.weight", + ".dora_scale", + ".alpha", + ] + + for key in state_dict: + if not isinstance(key, str): + continue + + layer_name = None + key_name = None + for suffix in known_suffixes: + if key.endswith(suffix): + layer_name = key[: -len(suffix)] + key_name = suffix[1:] + break + + if layer_name is None: + parts = key.rsplit(".", maxsplit=2) + layer_name = parts[0] + key_name = ".".join(parts[1:]) + + if layer_name not in layer_dict: + layer_dict[layer_name] = {} + layer_dict[layer_name][key_name] = state_dict[key] + + return layer_dict diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py 
b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index de5253f073..054e04dcb2 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -17,7 +17,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from diffusers.utils.import_utils import is_xformers_available from pydantic import Field -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from invokeai.app.services.config.config_default import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData @@ -139,7 +139,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): + feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" @@ -151,7 +151,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: Optional[StableDiffusionSafetyChecker], - feature_extractor: Optional[CLIPFeatureExtractor], + feature_extractor: Optional[CLIPImageProcessor], requires_safety_checker: bool = False, ): super().__init__( diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index e6ca9aa18e..6a9959f1e8 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -88,6 +88,23 @@ class ZImageConditioningInfo: return self +@dataclass +class QwenImageConditioningInfo: + """Qwen Image Edit conditioning information from Qwen2.5-VL encoder.""" + + prompt_embeds: torch.Tensor + """Text/image embeddings from Qwen2.5-VL encoder. Shape: (batch_size, seq_len, hidden_size).""" + + prompt_embeds_mask: torch.Tensor | None = None + """Attention mask for prompt_embeds. Shape: (batch_size, seq_len). 1 for valid, 0 for padding.""" + + def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None): + self.prompt_embeds = self.prompt_embeds.to(device=device, dtype=dtype) + if self.prompt_embeds_mask is not None: + self.prompt_embeds_mask = self.prompt_embeds_mask.to(device=device) + return self + + @dataclass class AnimaConditioningInfo: """Anima text conditioning information from Qwen3 0.6B encoder + T5-XXL tokenizer. 
@@ -125,6 +142,7 @@ class ConditioningFieldData: | List[SD3ConditioningInfo] | List[CogView4ConditioningInfo] | List[ZImageConditioningInfo] + | List[QwenImageConditioningInfo] | List[AnimaConditioningInfo] ) diff --git a/invokeai/frontend/web/openapi.json b/invokeai/frontend/web/openapi.json index af8476528d..19e5a3a68e 100644 --- a/invokeai/frontend/web/openapi.json +++ b/invokeai/frontend/web/openapi.json @@ -6463,6 +6463,23 @@ "title": "Has Been Opened" }, "description": "Whether to include/exclude recent workflows" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -6655,6 +6672,23 @@ "title": "Categories" }, "description": "The categories to include" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -6744,6 +6778,23 @@ "title": "Has Been Opened" }, "description": "Whether to include/exclude recent workflows" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -6812,6 +6863,23 @@ "title": "Has Been Opened" }, "description": "Whether to include/exclude recent workflows" + }, + { + "name": "is_public", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "title": "Is Public" + }, + "description": "Filter by public/shared status" } ], "responses": { @@ -7352,6 +7420,67 @@ } } } + }, + "/api/v1/workflows/i/{workflow_id}/is_public": { + "patch": { + "tags": ["workflows"], + "summary": 
"Update Workflow Is Public", + "description": "Updates whether a workflow is shared publicly", + "operationId": "update_workflow_is_public", + "parameters": [ + { + "name": "workflow_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Workflow Id" + }, + "description": "The workflow to update" + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "properties": { + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether the workflow should be shared publicly" + } + }, + "type": "object", + "required": ["is_public"], + "title": "Body_update_workflow_is_public" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/WorkflowRecordDTO" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } } }, "components": { @@ -59137,10 +59266,20 @@ "workflow": { "$ref": "#/components/schemas/Workflow", "description": "The workflow." + }, + "user_id": { + "type": "string", + "title": "User Id", + "description": "The id of the user who owns this workflow." + }, + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether this workflow is shared with all users." } }, "type": "object", - "required": ["workflow_id", "name", "created_at", "updated_at", "workflow"], + "required": ["workflow_id", "name", "created_at", "updated_at", "workflow", "user_id", "is_public"], "title": "WorkflowRecordDTO" }, "WorkflowRecordListItemWithThumbnailDTO": { @@ -59222,15 +59361,35 @@ ], "title": "Thumbnail Url", "description": "The URL of the workflow thumbnail." + }, + "user_id": { + "type": "string", + "title": "User Id", + "description": "The id of the user who owns this workflow." 
+ }, + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether this workflow is shared with all users." } }, "type": "object", - "required": ["workflow_id", "name", "created_at", "updated_at", "description", "category", "tags"], + "required": [ + "workflow_id", + "name", + "created_at", + "updated_at", + "description", + "category", + "tags", + "user_id", + "is_public" + ], "title": "WorkflowRecordListItemWithThumbnailDTO" }, "WorkflowRecordOrderBy": { "type": "string", - "enum": ["created_at", "updated_at", "opened_at", "name"], + "enum": ["created_at", "updated_at", "opened_at", "name", "is_public"], "title": "WorkflowRecordOrderBy", "description": "The order by options for workflow records" }, @@ -59303,10 +59462,20 @@ ], "title": "Thumbnail Url", "description": "The URL of the workflow thumbnail." + }, + "user_id": { + "type": "string", + "title": "User Id", + "description": "The id of the user who owns this workflow." + }, + "is_public": { + "type": "boolean", + "title": "Is Public", + "description": "Whether this workflow is shared with all users." 
} }, "type": "object", - "required": ["workflow_id", "name", "created_at", "updated_at", "workflow"], + "required": ["workflow_id", "name", "created_at", "updated_at", "workflow", "user_id", "is_public"], "title": "WorkflowRecordWithThumbnailDTO" }, "WorkflowWithoutID": { diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index cb00d2a767..2e54f250f9 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -161,7 +161,17 @@ "imagesWithCount_other": "{{count}} images", "assetsWithCount_one": "{{count}} asset", "assetsWithCount_other": "{{count}} assets", - "updateBoardError": "Error updating board" + "updateBoardError": "Error updating board", + "setBoardVisibility": "Set Board Visibility", + "setVisibilityPrivate": "Set Private", + "setVisibilityShared": "Set Shared", + "setVisibilityPublic": "Set Public", + "visibilityPrivate": "Private", + "visibilityShared": "Shared", + "visibilityPublic": "Public", + "visibilityBadgeShared": "Shared board", + "visibilityBadgePublic": "Public board", + "updateBoardVisibilityError": "Error updating board visibility" }, "accordions": { "generation": { @@ -1199,7 +1209,9 @@ "numImages": "Num Images", "modelPickerFallbackNoModelsInstalled": "No models installed.", "modelPickerFallbackNoModelsInstalled2": "Visit the Model Manager to install models.", + "modelPickerFallbackNoModelsInstalledNonAdmin": "No models installed. 
Ask your InvokeAI administrator () to install some models.", "noModelsInstalledDesc1": "Install models with the", + "noModelsInstalledAskAdmin": "Ask your administrator to install some.", "noModelSelected": "No Model Selected", "noMatchingModels": "No matching models", "noModelsInstalled": "No models installed", @@ -1295,6 +1307,12 @@ "flux2KleinVaePlaceholder": "From main model", "flux2KleinQwen3Encoder": "Qwen3 Encoder (optional)", "flux2KleinQwen3EncoderPlaceholder": "From main model", + "qwenImageComponentSource": "VAE/Encoder Source (Diffusers)", + "qwenImageComponentSourcePlaceholder": "Required for GGUF models", + "qwenImageQuantization": "Encoder Quantization", + "qwenImageQuantizationNone": "None (bf16)", + "qwenImageQuantizationInt8": "8-bit (int8)", + "qwenImageQuantizationNf4": "4-bit (nf4)", "upcastAttention": "Upcast Attention", "uploadImage": "Upload Image", "urlOrLocalPath": "URL or Local Path", @@ -1562,6 +1580,7 @@ "info": "Info", "invoke": { "addingImagesTo": "Adding images to", + "boardNotWritable": "You do not have write access to board \"{{boardName}}\". Select a board you own or switch to Uncategorized.", "modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. 
Visit your account settings to upgrade.", "invoke": "Invoke", "missingFieldTemplate": "Missing field template", @@ -1588,6 +1607,7 @@ "noFLUXVAEModelSelected": "No VAE model selected for FLUX generation", "noCLIPEmbedModelSelected": "No CLIP Embed model selected for FLUX generation", "noQwen3EncoderModelSelected": "No Qwen3 Encoder model selected for FLUX2 Klein generation", + "noQwenImageComponentSourceSelected": "GGUF Qwen Image models require a Diffusers Component Source for VAE/encoder", "noZImageVaeSourceSelected": "No VAE source: Select VAE (FLUX) or Qwen3 Source model", "noZImageQwen3EncoderSourceSelected": "No Qwen3 Encoder source: Select Qwen3 Encoder or Qwen3 Source model", "noAnimaVaeModelSelected": "No Anima VAE model selected", @@ -1641,6 +1661,7 @@ "sendToCanvas": "Send To Canvas", "sendToUpscale": "Send To Upscale", "showOptionsPanel": "Show Side Panel (O or T)", + "shift": "Shift", "shuffle": "Shuffle Seed", "steps": "Steps", "strength": "Strength", @@ -2317,6 +2338,8 @@ "tags": "Tags", "yourWorkflows": "Your Workflows", "recentlyOpened": "Recently Opened", + "sharedWorkflows": "Shared Workflows", + "shareWorkflow": "Shared workflow", "noRecentWorkflows": "No Recent Workflows", "private": "Private", "shared": "Shared", @@ -3051,6 +3074,7 @@ "tileOverlap": "Tile Overlap", "postProcessingMissingModelWarning": "Visit the Model Manager to install a post-processing (image to image) model.", "missingModelsWarning": "Visit the Model Manager to install the required models:", + "missingModelsWarningNonAdmin": "Ask your InvokeAI administrator () to install the required models:", "mainModelDesc": "Main model (SD1.5 or SDXL architecture)", "tileControlNetModelDesc": "Tile ControlNet model for the chosen main model architecture", "upscaleModelDesc": "Upscale (image to image) model", @@ -3159,6 +3183,7 @@ }, "workflows": { "description": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and 
get consistent results.", + "descriptionMultiuser": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results. You may share your workflows with other users of the system by selecting 'Shared workflow' when you create or edit it.", "learnMoreLink": "Learn more about creating workflows", "browseTemplates": { "title": "Browse Workflow Templates", @@ -3237,9 +3262,11 @@ "toGetStartedLocal": "To get started, make sure to download or import models needed to run Invoke. Then, enter a prompt in the box and click Invoke to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the Gallery or edit them to the Canvas.", "toGetStarted": "To get started, enter a prompt in the box and click Invoke to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the Gallery or edit them to the Canvas.", "toGetStartedWorkflow": "To get started, fill in the fields on the left and press Invoke to generate your image. Want to explore more workflows? Click the folder icon next to the workflow title to see a list of other templates you can try.", + "toGetStartedNonAdmin": "To get started, ask your InvokeAI administrator () to install the AI models needed to run Invoke. Then, enter a prompt in the box and click Invoke to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the Gallery or edit them to the Canvas.", "gettingStartedSeries": "Want more guidance? Check out our Getting Started Series for tips on unlocking the full potential of the Invoke Studio.", "lowVRAMMode": "For best performance, follow our Low VRAM guide.", - "noModelsInstalled": "It looks like you don't have any models installed! You can download a starter model bundle or import models." 
+ "noModelsInstalled": "It looks like you don't have any models installed! You can download a starter model bundle or import models.", + "noModelsInstalledAskAdmin": "Ask your administrator to install some." }, "whatsNew": { "whatsNewInInvoke": "What's New in Invoke", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx index e0e72d12ff..fa4c29b8f4 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/bulkDownload.tsx @@ -12,10 +12,14 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) = effect: (action) => { log.debug(action.payload, 'Bulk download requested'); - // If we have an item name, we are processing the bulk download locally and should use it as the toast id to - // prevent multiple toasts for the same item. + // Use a "preparing:" prefix so this toast cannot collide with the + // "ready to download" toast that arrives via the bulk_download_complete + // socket event. The background task can complete in under 20ms, so the + // socket event may arrive *before* this Redux middleware runs — without + // distinct IDs the "preparing" toast would overwrite the "ready" toast. + const itemName = action.payload.bulk_download_item_name; toast({ - id: action.payload.bulk_download_item_name ?? undefined, + id: itemName ? 
`preparing:${itemName}` : undefined, title: t('gallery.bulkDownloadRequested'), status: 'success', // Show the response message if it exists, otherwise show the default message diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index 251403ed04..1c7941106b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -11,6 +11,7 @@ import { kleinQwen3EncoderModelSelected, kleinVaeModelSelected, modelChanged, + qwenImageComponentSourceSelected, resolutionPresetSelected, setZImageScheduler, syncedToOptimalDimension, @@ -29,12 +30,18 @@ import { selectBboxModelBase, selectCanvasSlice, } from 'features/controlLayers/store/selectors'; -import { getEntityIdentifier, isAspectRatioID, isFlux2ReferenceImageConfig } from 'features/controlLayers/store/types'; +import { + getEntityIdentifier, + isAspectRatioID, + isFlux2ReferenceImageConfig, + isQwenImageReferenceImageConfig, +} from 'features/controlLayers/store/types'; import { initialFlux2ReferenceImage, initialFluxKontextReferenceImage, initialFLUXRedux, initialIPAdapter, + initialQwenImageReferenceImage, } from 'features/controlLayers/store/util'; import { SUPPORTS_REF_IMAGES_BASE_MODELS } from 'features/modelManagerV2/models'; import { zModelIdentifierField } from 'features/nodes/types/common'; @@ -49,6 +56,7 @@ import { selectFluxVAEModels, selectGlobalRefImageModels, selectQwen3EncoderModels, + selectQwenImageDiffusersModels, selectRegionalRefImageModels, selectT5EncoderModels, selectZImageDiffusersModels, @@ -238,6 +246,44 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = } } + // handle incompatible Qwen Image Edit component source - clear if switching away + const { qwenImageComponentSource } 
= state.params; + if (newBase !== 'qwen-image') { + if (qwenImageComponentSource) { + dispatch(qwenImageComponentSourceSelected(null)); + modelsUpdatedDisabledOrCleared += 1; + } + } else { + // Switching to Qwen Image - auto-default component source to a matching diffusers model + if (!qwenImageComponentSource) { + const availableQwenImageDiffusers = selectQwenImageDiffusersModels(state); + + // Look up the new model's variant to match generate vs edit + const modelConfigsResult = selectModelConfigsQuery(state); + let selectedVariant: string | null = null; + if (modelConfigsResult.data) { + const newModelConfig = modelConfigsAdapterSelectors.selectById(modelConfigsResult.data, newModel.key); + if (newModelConfig && 'variant' in newModelConfig && typeof newModelConfig.variant === 'string') { + selectedVariant = newModelConfig.variant; + } + } + + // Find a diffusers model matching the variant; if no variant on denoiser, prefer "generate" then "edit" + const variantToMatch = selectedVariant ?? 'generate'; + const matchingModel = availableQwenImageDiffusers.find( + (m) => 'variant' in m && m.variant === variantToMatch + ); + const fallbackModel = availableQwenImageDiffusers.find( + (m) => 'variant' in m && m.variant !== variantToMatch + ); + const diffusersModel = matchingModel ?? fallbackModel ?? availableQwenImageDiffusers[0]; + + if (diffusersModel) { + dispatch(qwenImageComponentSourceSelected(zModelIdentifierField.parse(diffusersModel))); + } + } + } + if (newModel.base !== 'external' && SUPPORTS_REF_IMAGES_BASE_MODELS.includes(newModel.base)) { // Handle incompatible reference image models - switch to first compatible model, with some smart logic // to choose the best available model based on the new main model. 
@@ -280,6 +326,20 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = continue; } + if (newBase === 'qwen-image') { + // Switching TO Qwen Image Edit - convert any non-qwen configs to qwen_image_reference_image + if (!isQwenImageReferenceImageConfig(entity.config)) { + dispatch( + refImageConfigChanged({ + id: entity.id, + config: { ...initialQwenImageReferenceImage }, + }) + ); + modelsUpdatedDisabledOrCleared += 1; + } + continue; + } + if (isFlux2ReferenceImageConfig(entity.config)) { // Switching AWAY from FLUX.2 - convert flux2_reference_image to the appropriate config type let newConfig; @@ -304,6 +364,30 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = continue; } + if (isQwenImageReferenceImageConfig(entity.config)) { + // Switching AWAY from Qwen Image Edit - convert to the appropriate config type + let newConfig; + if (newGlobalRefImageModel) { + const parsedModel = zModelIdentifierField.parse(newGlobalRefImageModel); + if (newModel.base === 'flux' && newModel.name.toLowerCase().includes('kontext')) { + newConfig = { ...initialFluxKontextReferenceImage, model: parsedModel }; + } else if (newGlobalRefImageModel.type === 'flux_redux') { + newConfig = { ...initialFLUXRedux, model: parsedModel }; + } else { + newConfig = { ...initialIPAdapter, model: parsedModel }; + if (parsedModel.base === 'flux') { + newConfig.clipVisionModel = 'ViT-L'; + } + } + } else { + // No compatible model found - fall back to an empty IP adapter config + newConfig = { ...initialIPAdapter }; + } + dispatch(refImageConfigChanged({ id: entity.id, config: newConfig })); + modelsUpdatedDisabledOrCleared += 1; + continue; + } + // Standard handling for non-flux2 configs const shouldUpdateModel = (entity.config.model && entity.config.model.base !== newBase) || @@ -391,6 +475,32 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) = } } + // Handle Qwen Image model changes within the same base 
(variant may change between generate/edit) + // Auto-update the component source diffusers model to match the new variant + if ( + newBase === 'qwen-image' && + state.params.model?.base === 'qwen-image' && + newModel.key !== state.params.model?.key + ) { + const modelConfigsResult = selectModelConfigsQuery(state); + if (modelConfigsResult.data) { + const newModelConfig = modelConfigsAdapterSelectors.selectById(modelConfigsResult.data, newModel.key); + const newVariant = + newModelConfig && 'variant' in newModelConfig && typeof newModelConfig.variant === 'string' + ? newModelConfig.variant + : 'generate'; + + const availableQwenImageDiffusers = selectQwenImageDiffusersModels(state); + const matchingModel = availableQwenImageDiffusers.find((m) => 'variant' in m && m.variant === newVariant); + const fallbackModel = availableQwenImageDiffusers.find((m) => 'variant' in m && m.variant !== newVariant); + const diffusersModel = matchingModel ?? fallbackModel ?? availableQwenImageDiffusers[0]; + + if (diffusersModel) { + dispatch(qwenImageComponentSourceSelected(zModelIdentifierField.parse(diffusersModel))); + } + } + } + // Handle Z-Image scheduler when switching to Z-Image Base (zbase) model // LCM is not supported for undistilled models, so reset to euler if (newBase === 'z-image' && state.params.zImageScheduler === 'lcm') { diff --git a/invokeai/frontend/web/src/common/components/Picker/Picker.tsx b/invokeai/frontend/web/src/common/components/Picker/Picker.tsx index ffd0b30242..b70e44dd64 100644 --- a/invokeai/frontend/web/src/common/components/Picker/Picker.tsx +++ b/invokeai/frontend/web/src/common/components/Picker/Picker.tsx @@ -867,7 +867,7 @@ const GroupToggleButtons = typedMemo(() => { } return ( - + {groups.map((group) => ( ))} @@ -927,6 +927,7 @@ const GroupToggleButton = typedMemo(({ group }: { group: Group size="xs" variant="solid" userSelect="none" + flexShrink={0} bg={bg} color={color} borderColor={groupColor} diff --git 
a/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx b/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx index 00217eb796..5ac6ffcb7c 100644 --- a/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx +++ b/invokeai/frontend/web/src/features/changeBoardModal/components/ChangeBoardModal.tsx @@ -3,6 +3,7 @@ import { Combobox, ConfirmationAlertDialog, Flex, FormControl, Text } from '@inv import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAssertSingleton } from 'common/hooks/useAssertSingleton'; +import { selectCurrentUser } from 'features/auth/store/authSlice'; import { changeBoardReset, isModalOpenChanged, @@ -13,6 +14,7 @@ import { memo, useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; import { useAddImagesToBoardMutation, useRemoveImagesFromBoardMutation } from 'services/api/endpoints/images'; +import type { BoardDTO } from 'services/api/types'; const selectImagesToChange = createSelector( selectChangeBoardModalSlice, @@ -28,6 +30,7 @@ const ChangeBoardModal = () => { useAssertSingleton('ChangeBoardModal'); const dispatch = useAppDispatch(); const currentBoardId = useAppSelector(selectSelectedBoardId); + const currentUser = useAppSelector(selectCurrentUser); const [selectedBoardId, setSelectedBoardId] = useState(); const { data: boards, isFetching } = useListAllBoardsQuery({ include_archived: true }); const isModalOpen = useAppSelector(selectIsModalOpen); @@ -36,10 +39,20 @@ const ChangeBoardModal = () => { const [removeImagesFromBoard] = useRemoveImagesFromBoardMutation(); const { t } = useTranslation(); + // Returns true if the current user can write images to the given board. 
+ const canWriteToBoard = useCallback( + (board: BoardDTO): boolean => { + const isOwnerOrAdmin = !currentUser || currentUser.is_admin || board.user_id === currentUser.user_id; + return isOwnerOrAdmin || board.board_visibility === 'public'; + }, + [currentUser] + ); + const options = useMemo(() => { return [{ label: t('boards.uncategorized'), value: 'none' }] .concat( (boards ?? []) + .filter(canWriteToBoard) .map((board) => ({ label: board.board_name, value: board.board_id, @@ -47,7 +60,7 @@ const ChangeBoardModal = () => { .sort((a, b) => a.label.localeCompare(b.label)) ) .filter((board) => board.value !== currentBoardId); - }, [boards, currentBoardId, t]); + }, [boards, canWriteToBoard, currentBoardId, t]); const value = useMemo(() => options.find((o) => o.value === selectedBoardId), [options, selectedBoardId]); diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageSettings.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageSettings.tsx index 1a5d7bbebd..54b345361d 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageSettings.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RefImage/RefImageSettings.tsx @@ -34,7 +34,12 @@ import type { FLUXReduxImageInfluence as FLUXReduxImageInfluenceType, IPMethodV2, } from 'features/controlLayers/store/types'; -import { isFlux2ReferenceImageConfig, isFLUXReduxConfig, isIPAdapterConfig } from 'features/controlLayers/store/types'; +import { + isFlux2ReferenceImageConfig, + isFLUXReduxConfig, + isIPAdapterConfig, + isQwenImageReferenceImageConfig, +} from 'features/controlLayers/store/types'; import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd'; import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; @@ -124,8 +129,9 @@ const RefImageSettingsContent = memo(() => { const isFLUX = 
useAppSelector(selectIsFLUX); const isExternalModel = !!mainModelConfig && isExternalApiModelConfig(mainModelConfig); - // FLUX.2 Klein and external API models do not require a ref image model selection. - const showModelSelector = !isFlux2ReferenceImageConfig(config) && !isExternalModel; + // FLUX.2 Klein, Qwen Image Edit and external API models do not require a ref image model selection. + const showModelSelector = + !isFlux2ReferenceImageConfig(config) && !isQwenImageReferenceImageConfig(config) && !isExternalModel; return ( diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/addLayerHooks.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/addLayerHooks.ts index 038af19603..2027ff4174 100644 --- a/invokeai/frontend/web/src/features/controlLayers/hooks/addLayerHooks.ts +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/addLayerHooks.ts @@ -29,6 +29,7 @@ import type { Flux2ReferenceImageConfig, FluxKontextReferenceImageConfig, IPAdapterConfig, + QwenImageReferenceImageConfig, RegionalGuidanceIPAdapterConfig, T2IAdapterConfig, } from 'features/controlLayers/store/types'; @@ -37,6 +38,7 @@ import { initialFlux2ReferenceImage, initialFluxKontextReferenceImage, initialIPAdapter, + initialQwenImageReferenceImage, initialRegionalGuidanceIPAdapter, initialT2IAdapter, } from 'features/controlLayers/store/util'; @@ -78,7 +80,7 @@ export const selectDefaultControlAdapter = createSelector( export const getDefaultRefImageConfig = ( getState: AppGetState -): IPAdapterConfig | FluxKontextReferenceImageConfig | Flux2ReferenceImageConfig => { +): IPAdapterConfig | FluxKontextReferenceImageConfig | Flux2ReferenceImageConfig | QwenImageReferenceImageConfig => { const state = getState(); const mainModelConfig = selectMainModelConfig(state); @@ -91,6 +93,11 @@ export const getDefaultRefImageConfig = ( return deepClone(initialFlux2ReferenceImage); } + // Qwen Image Edit has built-in reference image support - no model needed + if (base === 
'qwen-image') { + return deepClone(initialQwenImageReferenceImage); + } + if (base === 'flux' && mainModelConfig?.name?.toLowerCase().includes('kontext')) { const config = deepClone(initialFluxKontextReferenceImage); config.model = zModelIdentifierField.parse(mainModelConfig); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 15a53cd037..07cee8211c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -261,6 +261,19 @@ const slice = createSlice({ } state.kleinQwen3EncoderModel = result.data; }, + qwenImageComponentSourceSelected: (state, action: PayloadAction) => { + const result = zParamsState.shape.qwenImageComponentSource.safeParse(action.payload); + if (!result.success) { + return; + } + state.qwenImageComponentSource = result.data; + }, + qwenImageQuantizationChanged: (state, action: PayloadAction<'none' | 'int8' | 'nf4'>) => { + state.qwenImageQuantization = action.payload; + }, + qwenImageShiftChanged: (state, action: PayloadAction) => { + state.qwenImageShift = action.payload; + }, vaePrecisionChanged: (state, action: PayloadAction) => { state.vaePrecision = action.payload; }, @@ -566,6 +579,9 @@ const resetState = (state: ParamsState): ParamsState => { newState.animaT5EncoderModel = oldState.animaT5EncoderModel; newState.kleinVaeModel = oldState.kleinVaeModel; newState.kleinQwen3EncoderModel = oldState.kleinQwen3EncoderModel; + newState.qwenImageComponentSource = oldState.qwenImageComponentSource; + newState.qwenImageQuantization = oldState.qwenImageQuantization; + newState.qwenImageShift = oldState.qwenImageShift; return newState; }; @@ -613,6 +629,9 @@ export const { zImageQwen3SourceModelSelected, kleinVaeModelSelected, kleinQwen3EncoderModelSelected, + qwenImageComponentSourceSelected, + qwenImageQuantizationChanged, + 
qwenImageShiftChanged, setClipSkip, shouldUseCpuNoiseChanged, setColorCompensation, @@ -691,6 +710,7 @@ export const selectIsZImage = createParamsSelector((params) => params.model?.bas export const selectIsAnima = createParamsSelector((params) => params.model?.base === 'anima'); export const selectIsFlux2 = createParamsSelector((params) => params.model?.base === 'flux2'); export const selectIsExternal = createParamsSelector((params) => params.model?.base === 'external'); +export const selectIsQwenImage = createParamsSelector((params) => params.model?.base === 'qwen-image'); export const selectIsFluxKontext = createParamsSelector((params) => { if (params.model?.base === 'flux' && params.model?.name.toLowerCase().includes('kontext')) { return true; @@ -717,6 +737,9 @@ export const selectAnimaT5EncoderModel = createParamsSelector((params) => params export const selectAnimaScheduler = createParamsSelector((params) => params.animaScheduler); export const selectKleinVaeModel = createParamsSelector((params) => params.kleinVaeModel); export const selectKleinQwen3EncoderModel = createParamsSelector((params) => params.kleinQwen3EncoderModel); +export const selectQwenImageComponentSource = createParamsSelector((params) => params.qwenImageComponentSource); +export const selectQwenImageQuantization = createParamsSelector((params) => params.qwenImageQuantization); +export const selectQwenImageShift = createParamsSelector((params) => params.qwenImageShift); export const selectCFGScale = createParamsSelector((params) => params.cfgScale); export const selectGuidance = createParamsSelector((params) => params.guidance); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts index ab21db3fec..2b7c0f7d17 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/refImagesSlice.ts @@ -22,6 +22,7 
@@ import { isFlux2ReferenceImageConfig, isFLUXReduxConfig, isIPAdapterConfig, + isQwenImageReferenceImageConfig, zRefImagesState, } from './types'; import { getReferenceImageState, initialFluxKontextReferenceImage, initialFLUXRedux, initialIPAdapter } from './util'; @@ -106,8 +107,8 @@ const slice = createSlice({ return; } - // FLUX.2 reference images don't have a model field - they use built-in support - if (isFlux2ReferenceImageConfig(entity.config)) { + // FLUX.2 and Qwen Image Edit reference images don't have a model field - they use built-in support + if (isFlux2ReferenceImageConfig(entity.config) || isQwenImageReferenceImageConfig(entity.config)) { return; } diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index dad1893911..eb5329dc10 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -370,6 +370,13 @@ const zFlux2ReferenceImageConfig = z.object({ }); export type Flux2ReferenceImageConfig = z.infer; +// Qwen Image Edit has built-in reference image support - no separate model needed +const zQwenImageReferenceImageConfig = z.object({ + type: z.literal('qwen_image_reference_image'), + image: zCroppableImageWithDims.nullable(), +}); +export type QwenImageReferenceImageConfig = z.infer; + const zCanvasEntityBase = z.object({ id: zId, name: zName, @@ -385,6 +392,7 @@ export const zRefImageState = z.object({ zFLUXReduxConfig, zFluxKontextReferenceImageConfig, zFlux2ReferenceImageConfig, + zQwenImageReferenceImageConfig, ]), }); export type RefImageState = z.infer; @@ -402,6 +410,10 @@ export const isFluxKontextReferenceImageConfig = ( export const isFlux2ReferenceImageConfig = (config: RefImageState['config']): config is Flux2ReferenceImageConfig => config.type === 'flux2_reference_image'; +export const isQwenImageReferenceImageConfig = ( + config: RefImageState['config'] +): 
config is QwenImageReferenceImageConfig => config.type === 'qwen_image_reference_image'; + const zFillStyle = z.enum(['solid', 'grid', 'crosshatch', 'diagonal', 'horizontal', 'vertical']); export type FillStyle = z.infer; export const isFillStyle = (v: unknown): v is FillStyle => zFillStyle.safeParse(v).success; @@ -782,6 +794,10 @@ export const zParamsState = z.object({ // Flux2 Klein model components - uses Qwen3 instead of CLIP+T5 kleinVaeModel: zParameterVAEModel.nullable(), // Optional: Separate FLUX.2 VAE for Klein kleinQwen3EncoderModel: zModelIdentifierField.nullable(), // Optional: Separate Qwen3 Encoder for Klein + // Qwen Image Edit model components - GGUF transformer needs a Diffusers source for VAE/encoder + qwenImageComponentSource: zParameterModel.nullable(), // Diffusers model providing VAE + text encoder + qwenImageQuantization: z.enum(['none', 'int8', 'nf4']), // BitsAndBytes quantization for Qwen VL encoder + qwenImageShift: z.number().nullable(), // Sigma schedule shift override (e.g. 
3.0 for Lightning LoRAs) // Z-Image Seed Variance Enhancer settings zImageSeedVarianceEnabled: z.boolean(), zImageSeedVarianceStrength: z.number().min(0).max(2), @@ -859,6 +875,9 @@ export const getInitialParamsState = (): ParamsState => ({ animaScheduler: 'euler', kleinVaeModel: null, kleinQwen3EncoderModel: null, + qwenImageComponentSource: null, + qwenImageQuantization: 'none' as const, + qwenImageShift: null, zImageSeedVarianceEnabled: false, zImageSeedVarianceStrength: 0.1, zImageSeedVarianceRandomizePercent: 50, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/util.ts b/invokeai/frontend/web/src/features/controlLayers/store/util.ts index f14af4feee..2aae90e72a 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/util.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/util.ts @@ -15,6 +15,7 @@ import type { FLUXReduxConfig, ImageWithDims, IPAdapterConfig, + QwenImageReferenceImageConfig, RasterLayerAdjustments, RefImageState, RegionalGuidanceIPAdapterConfig, @@ -117,6 +118,10 @@ export const initialFlux2ReferenceImage: Flux2ReferenceImageConfig = { type: 'flux2_reference_image', image: null, }; +export const initialQwenImageReferenceImage: QwenImageReferenceImageConfig = { + type: 'qwen_image_reference_image', + image: null, +}; export const initialT2IAdapter: T2IAdapterConfig = { type: 't2i_adapter', model: null, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts index f3aa68d588..db5ad4f766 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts @@ -147,8 +147,8 @@ export const getGlobalReferenceImageWarnings = ( const { config } = entity; - // FLUX.2 reference images don't require a model - it's built-in - if (config.type !== 'flux2_reference_image') { + // FLUX.2 and Qwen Image Edit reference images don't 
require a model - it's built-in + if (config.type !== 'flux2_reference_image' && config.type !== 'qwen_image_reference_image') { if (!('model' in config) || !config.model) { // No model selected warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED); @@ -159,8 +159,10 @@ export const getGlobalReferenceImageWarnings = ( } if (!entity.config.image) { - // No image selected - warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED); + // No image selected - for Qwen Image Edit, an image is optional (txt2img works without one) + if (config.type !== 'qwen_image_reference_image') { + warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED); + } } } diff --git a/invokeai/frontend/web/src/features/dnd/dnd.ts b/invokeai/frontend/web/src/features/dnd/dnd.ts index f5e38d4b94..ee648e82ef 100644 --- a/invokeai/frontend/web/src/features/dnd/dnd.ts +++ b/invokeai/frontend/web/src/features/dnd/dnd.ts @@ -434,6 +434,49 @@ export const replaceCanvasEntityObjectsWithImageDndTarget: DndTarget< //#endregion //#region Add To Board +/** + * Check whether the current user can move images out of their source board. + * Returns false if the source board is a shared board not owned by the current user + * (and the user is not an admin). In that case, images can be viewed/used but not moved. 
+ */ +const canMoveFromSourceBoard = (sourceBoardId: BoardId, getState: AppGetState): boolean => { + const state = getState(); + // In single-user mode (no auth), always allow + const currentUser = state.auth?.user; + if (!currentUser) { + return true; + } + // Admins can always move + if (currentUser.is_admin) { + return true; + } + // "Uncategorized" (none) — user's own uncategorized images, allow + if (sourceBoardId === 'none') { + return true; + } + // Look up the board from the RTK Query cache + const boardsQueryState = state.api?.queries; + if (boardsQueryState) { + for (const query of Object.values(boardsQueryState)) { + if (query?.data && Array.isArray(query.data)) { + const board = (query.data as Array<{ board_id: string; user_id?: string; board_visibility?: string }>).find( + (b) => b.board_id === sourceBoardId + ); + if (board) { + // Owner can always move + if (board.user_id === currentUser.user_id) { + return true; + } + // Non-owner can only move from public boards + return board.board_visibility === 'public'; + } + } + } + } + // Board not found in cache — allow by default to avoid blocking legitimate operations + return true; +}; + const _addToBoard = buildTypeAndKey('add-to-board'); export type AddImageToBoardDndTargetData = DndData< typeof _addToBoard.type, @@ -447,16 +490,23 @@ export const addImageToBoardDndTarget: DndTarget< ..._addToBoard, typeGuard: buildTypeGuard(_addToBoard.key), getData: buildGetData(_addToBoard.key, _addToBoard.type), - isValid: ({ sourceData, targetData }) => { + isValid: ({ sourceData, targetData, getState }) => { if (singleImageDndSource.typeGuard(sourceData)) { const currentBoard = sourceData.payload.imageDTO.board_id ?? 
'none'; const destinationBoard = targetData.payload.boardId; - return currentBoard !== destinationBoard; + if (currentBoard === destinationBoard) { + return false; + } + // Don't allow moving images from shared boards the user doesn't own + return canMoveFromSourceBoard(currentBoard, getState); } if (multipleImageDndSource.typeGuard(sourceData)) { const currentBoard = sourceData.payload.board_id; const destinationBoard = targetData.payload.boardId; - return currentBoard !== destinationBoard; + if (currentBoard === destinationBoard) { + return false; + } + return canMoveFromSourceBoard(currentBoard, getState); } return false; }, @@ -491,15 +541,22 @@ export const removeImageFromBoardDndTarget: DndTarget< ..._removeFromBoard, typeGuard: buildTypeGuard(_removeFromBoard.key), getData: buildGetData(_removeFromBoard.key, _removeFromBoard.type), - isValid: ({ sourceData }) => { + isValid: ({ sourceData, getState }) => { if (singleImageDndSource.typeGuard(sourceData)) { const currentBoard = sourceData.payload.imageDTO.board_id ?? 
'none'; - return currentBoard !== 'none'; + if (currentBoard === 'none') { + return false; + } + // Don't allow removing images from shared boards the user doesn't own + return canMoveFromSourceBoard(currentBoard, getState); } if (multipleImageDndSource.typeGuard(sourceData)) { const currentBoard = sourceData.payload.board_id; - return currentBoard !== 'none'; + if (currentBoard === 'none') { + return false; + } + return canMoveFromSourceBoard(currentBoard, getState); } return false; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx index 5cc25f6c03..d10dde6ee4 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardContextMenu.tsx @@ -2,15 +2,26 @@ import type { ContextMenuProps } from '@invoke-ai/ui-library'; import { ContextMenu, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library'; import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { selectCurrentUser } from 'features/auth/store/authSlice'; import { $boardToDelete } from 'features/gallery/components/Boards/DeleteBoardModal'; import { selectAutoAddBoardId, selectAutoAssignBoardOnClick } from 'features/gallery/store/gallerySelectors'; import { autoAddBoardIdChanged } from 'features/gallery/store/gallerySlice'; import { toast } from 'features/toast/toast'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { PiArchiveBold, PiArchiveFill, PiDownloadBold, PiPlusBold, PiTrashSimpleBold } from 'react-icons/pi'; +import { + PiArchiveBold, + PiArchiveFill, + PiDownloadBold, + PiGlobeBold, + PiLockBold, + PiPlusBold, + PiShareNetworkBold, + PiTrashSimpleBold, +} from 'react-icons/pi'; import { useUpdateBoardMutation } from 
'services/api/endpoints/boards'; import { useBulkDownloadImagesMutation } from 'services/api/endpoints/images'; +import { useBoardAccess } from 'services/api/hooks/useBoardAccess'; import { useBoardName } from 'services/api/hooks/useBoardName'; import type { BoardDTO } from 'services/api/types'; @@ -23,6 +34,7 @@ const BoardContextMenu = ({ board, children }: Props) => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const autoAssignBoardOnClick = useAppSelector(selectAutoAssignBoardOnClick); + const currentUser = useAppSelector(selectCurrentUser); const selectIsSelectedForAutoAdd = useMemo( () => createSelector(selectAutoAddBoardId, (autoAddBoardId) => board.board_id === autoAddBoardId), [board.board_id] @@ -35,6 +47,11 @@ const BoardContextMenu = ({ board, children }: Props) => { const [bulkDownload] = useBulkDownloadImagesMutation(); + // Only the board owner or admin can modify visibility + const canChangeVisibility = currentUser !== null && (currentUser.is_admin || board.user_id === currentUser.user_id); + + const { canDeleteBoard } = useBoardAccess(board); + const handleSetAutoAdd = useCallback(() => { dispatch(autoAddBoardIdChanged(board.board_id)); }, [board.board_id, dispatch]); @@ -64,6 +81,26 @@ const BoardContextMenu = ({ board, children }: Props) => { }); }, [board.board_id, updateBoard]); + const handleSetVisibility = useCallback( + async (visibility: 'private' | 'shared' | 'public') => { + try { + await updateBoard({ + board_id: board.board_id, + changes: { board_visibility: visibility }, + }).unwrap(); + } catch { + toast({ status: 'error', title: t('boards.updateBoardVisibilityError') }); + } + }, + [board.board_id, t, updateBoard] + ); + + const handleSetVisibilityPrivate = useCallback(() => handleSetVisibility('private'), [handleSetVisibility]); + + const handleSetVisibilityShared = useCallback(() => handleSetVisibility('shared'), [handleSetVisibility]); + + const handleSetVisibilityPublic = useCallback(() => 
handleSetVisibility('public'), [handleSetVisibility]); + const setAsBoardToDelete = useCallback(() => { $boardToDelete.set(board); }, [board]); @@ -83,18 +120,50 @@ const BoardContextMenu = ({ board, children }: Props) => { {board.archived && ( - } onClick={handleUnarchive}> + } onClick={handleUnarchive} isDisabled={!canDeleteBoard}> {t('boards.unarchiveBoard')} )} {!board.archived && ( - } onClick={handleArchive}> + } onClick={handleArchive} isDisabled={!canDeleteBoard}> {t('boards.archiveBoard')} )} - } onClick={setAsBoardToDelete} isDestructive> + {canChangeVisibility && ( + <> + } + onClick={handleSetVisibilityPrivate} + isDisabled={board.board_visibility === 'private'} + > + {t('boards.setVisibilityPrivate')} + + } + onClick={handleSetVisibilityShared} + isDisabled={board.board_visibility === 'shared'} + > + {t('boards.setVisibilityShared')} + + } + onClick={handleSetVisibilityPublic} + isDisabled={board.board_visibility === 'public'} + > + {t('boards.setVisibilityPublic')} + + + )} + + } + onClick={setAsBoardToDelete} + isDestructive + isDisabled={!canDeleteBoard} + > {t('boards.deleteBoard')} @@ -108,8 +177,14 @@ const BoardContextMenu = ({ board, children }: Props) => { t, handleBulkDownload, board.archived, + board.board_visibility, handleUnarchive, handleArchive, + canChangeVisibility, + handleSetVisibilityPrivate, + handleSetVisibilityShared, + handleSetVisibilityPublic, + canDeleteBoard, setAsBoardToDelete, ] ); diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardEditableTitle.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardEditableTitle.tsx index 67c7dad6ed..cf2749e340 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardEditableTitle.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/BoardEditableTitle.tsx @@ -7,6 +7,7 @@ import { memo, useCallback, useRef } from 'react'; import { useTranslation } from 
'react-i18next'; import { PiPencilBold } from 'react-icons/pi'; import { useUpdateBoardMutation } from 'services/api/endpoints/boards'; +import { useBoardAccess } from 'services/api/hooks/useBoardAccess'; import type { BoardDTO } from 'services/api/types'; type Props = { @@ -19,6 +20,7 @@ export const BoardEditableTitle = memo(({ board, isSelected }: Props) => { const isHovering = useBoolean(false); const inputRef = useRef(null); const [updateBoard, updateBoardResult] = useUpdateBoardMutation(); + const { canRenameBoard } = useBoardAccess(board); const onChange = useCallback( async (board_name: string) => { @@ -51,13 +53,13 @@ export const BoardEditableTitle = memo(({ board, isSelected }: Props) => { fontWeight="semibold" userSelect="none" color={isSelected ? 'base.100' : 'base.300'} - onDoubleClick={editable.startEditing} - cursor="text" + onDoubleClick={canRenameBoard ? editable.startEditing : undefined} + cursor={canRenameBoard ? 'text' : 'default'} noOfLines={1} > {editable.value} - {isHovering.isTrue && ( + {canRenameBoard && isHovering.isTrue && ( } diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx index 4d821f819c..10fbe61832 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx @@ -18,8 +18,9 @@ import { import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { PiArchiveBold, PiImageSquare } from 'react-icons/pi'; +import { PiArchiveBold, PiGlobeBold, PiImageSquare, PiShareNetworkBold } from 'react-icons/pi'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { useBoardAccess } from 
'services/api/hooks/useBoardAccess'; import type { BoardDTO } from 'services/api/types'; const _hover: SystemStyleObject = { @@ -62,6 +63,8 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => { const showOwner = currentUser?.is_admin && board.owner_username; + const { canWriteImages } = useBoardAccess(board); + return ( @@ -99,6 +102,20 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => { {autoAddBoardId === board.board_id && } {board.archived && } + {board.board_visibility === 'shared' && ( + + + + + + )} + {board.board_visibility === 'public' && ( + + + + + + )} {board.image_count} | {board.asset_count} @@ -108,7 +125,12 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => { )} - + ); }; diff --git a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemChangeBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemChangeBoard.tsx index 7176487015..f5c044132e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemChangeBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemChangeBoard.tsx @@ -5,11 +5,15 @@ import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiFoldersBold } from 'react-icons/pi'; +import { useBoardAccess } from 'services/api/hooks/useBoardAccess'; +import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard'; export const ContextMenuItemChangeBoard = memo(() => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const imageDTO = useImageDTOContext(); + const selectedBoard = useSelectedBoard(); + const { canWriteImages } = useBoardAccess(selectedBoard); const onClick = useCallback(() => { dispatch(imagesToChangeSelected([imageDTO.image_name])); @@ -17,7 +21,7 @@ 
export const ContextMenuItemChangeBoard = memo(() => { }, [dispatch, imageDTO]); return ( - } onClickCapture={onClick}> + } onClickCapture={onClick} isDisabled={!canWriteImages}> {t('boards.changeBoard')} ); diff --git a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemDeleteImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemDeleteImage.tsx index e20221f342..5dfa7116b1 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemDeleteImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MenuItems/ContextMenuItemDeleteImage.tsx @@ -4,11 +4,15 @@ import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiTrashSimpleBold } from 'react-icons/pi'; +import { useBoardAccess } from 'services/api/hooks/useBoardAccess'; +import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard'; export const ContextMenuItemDeleteImage = memo(() => { const { t } = useTranslation(); const deleteImageModal = useDeleteImageModalApi(); const imageDTO = useImageDTOContext(); + const selectedBoard = useSelectedBoard(); + const { canWriteImages } = useBoardAccess(selectedBoard); const onClick = useCallback(async () => { try { @@ -18,6 +22,10 @@ export const ContextMenuItemDeleteImage = memo(() => { } }, [deleteImageModal, imageDTO]); + if (!canWriteImages) { + return null; + } + return ( } diff --git a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MultipleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MultipleSelectionMenuItems.tsx index d148332943..ee3c8e4e98 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MultipleSelectionMenuItems.tsx +++ 
b/invokeai/frontend/web/src/features/gallery/components/ContextMenu/MultipleSelectionMenuItems.tsx @@ -10,12 +10,16 @@ import { useStarImagesMutation, useUnstarImagesMutation, } from 'services/api/endpoints/images'; +import { useBoardAccess } from 'services/api/hooks/useBoardAccess'; +import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard'; const MultipleSelectionMenuItems = () => { const { t } = useTranslation(); const dispatch = useAppDispatch(); const selection = useAppSelector((s) => s.gallery.selection); const deleteImageModal = useDeleteImageModalApi(); + const selectedBoard = useSelectedBoard(); + const { canWriteImages } = useBoardAccess(selectedBoard); const [starImages] = useStarImagesMutation(); const [unstarImages] = useUnstarImagesMutation(); @@ -53,11 +57,16 @@ const MultipleSelectionMenuItems = () => { } onClickCapture={handleBulkDownload}> {t('gallery.downloadSelection')} - } onClickCapture={handleChangeBoard}> + } onClickCapture={handleChangeBoard} isDisabled={!canWriteImages}> {t('boards.changeBoard')} - } onClickCapture={handleDeleteSelection}> + } + onClickCapture={handleDeleteSelection} + isDisabled={!canWriteImages} + > {t('gallery.deleteSelection')} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx index ccd58992ef..af1d376887 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx @@ -108,6 +108,25 @@ export const GalleryImage = memo(({ imageDTO }: Props) => { if (!element) { return; } + + const monitorBinding = monitorForElements({ + // This is a "global" drag start event, meaning that it is called for all drag events. + onDragStart: ({ source }) => { + // When we start dragging multiple images, set the dragging state to true if the dragged image is part of the + // selection. 
This is called for all drag events. + if ( + multipleImageDndSource.typeGuard(source.data) && + source.data.payload.image_names.includes(imageDTO.image_name) + ) { + setIsDragging(true); + } + }, + onDrop: () => { + // Always set the dragging state to false when a drop event occurs. + setIsDragging(false); + }, + }); + return combine( firefoxDndFix(element), draggable({ @@ -153,23 +172,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => { } }, }), - monitorForElements({ - // This is a "global" drag start event, meaning that it is called for all drag events. - onDragStart: ({ source }) => { - // When we start dragging multiple images, set the dragging state to true if the dragged image is part of the - // selection. This is called for all drag events. - if ( - multipleImageDndSource.typeGuard(source.data) && - source.data.payload.image_names.includes(imageDTO.image_name) - ) { - setIsDragging(true); - } - }, - onDrop: () => { - // Always set the dragging state to false when a drop event occurs. 
- setIsDragging(false); - }, - }) + monitorBinding ); }, [imageDTO, store]); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryItemDeleteIconButton.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryItemDeleteIconButton.tsx index 0a97bf819d..612e6361b1 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryItemDeleteIconButton.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryItemDeleteIconButton.tsx @@ -5,6 +5,8 @@ import type { MouseEvent } from 'react'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { PiTrashSimpleFill } from 'react-icons/pi'; +import { useBoardAccess } from 'services/api/hooks/useBoardAccess'; +import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard'; import type { ImageDTO } from 'services/api/types'; type Props = { @@ -15,6 +17,8 @@ export const GalleryItemDeleteIconButton = memo(({ imageDTO }: Props) => { const shift = useShiftModifier(); const { t } = useTranslation(); const deleteImageModal = useDeleteImageModalApi(); + const selectedBoard = useSelectedBoard(); + const { canWriteImages } = useBoardAccess(selectedBoard); const onClick = useCallback( (e: MouseEvent) => { @@ -24,7 +28,7 @@ export const GalleryItemDeleteIconButton = memo(({ imageDTO }: Props) => { [deleteImageModal, imageDTO] ); - if (!shift) { + if (!shift || !canWriteImages) { return null; } diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.test.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.test.tsx new file mode 100644 index 0000000000..c743315382 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.test.tsx @@ -0,0 +1,24 @@ +import { ImageMetadataHandlers } from 'features/metadata/parsing'; +import { describe, expect, 
it } from 'vitest'; + +import { ImageMetadataActions } from './ImageMetadataActions'; + +describe('ImageMetadataActions', () => { + it('includes Qwen metadata handlers in the recall parameters UI', () => { + const element = (ImageMetadataActions as unknown as { type: (props: { metadata: unknown }) => unknown }).type({ + metadata: { model: { key: 'test' } }, + }) as { + props: { + children: Array<{ props?: { handler?: unknown } }>; + }; + }; + + const handlers = element.props.children + .map((child) => child.props?.handler) + .filter((handler): handler is unknown => handler !== undefined); + + expect(handlers).toContain(ImageMetadataHandlers.QwenImageComponentSource); + expect(handlers).toContain(ImageMetadataHandlers.QwenImageQuantization); + expect(handlers).toContain(ImageMetadataHandlers.QwenImageShift); + }); +}); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index 8123db4b0b..e123d0ebd0 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -58,6 +58,9 @@ export const ImageMetadataActions = memo((props: Props) => { + + + diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/NoContentForViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/NoContentForViewer.tsx index b8a522c3a6..c301922df9 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/NoContentForViewer.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/NoContentForViewer.tsx @@ -1,7 +1,9 @@ import type { ButtonProps } from '@invoke-ai/ui-library'; import { Alert, AlertDescription, AlertIcon, Button, Divider, Flex, Link, Spinner, Text } from '@invoke-ai/ui-library'; +import { 
useAppSelector } from 'app/store/storeHooks'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import { InvokeLogoIcon } from 'common/components/InvokeLogoIcon'; +import { selectCurrentUser } from 'features/auth/store/authSlice'; import { LOADING_SYMBOL, useHasImages } from 'features/gallery/hooks/useHasImages'; import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore'; import { navigationApi } from 'features/ui/layouts/navigation-api'; @@ -9,16 +11,26 @@ import type { PropsWithChildren } from 'react'; import { memo, useCallback, useMemo } from 'react'; import { Trans, useTranslation } from 'react-i18next'; import { PiArrowSquareOutBold, PiImageBold } from 'react-icons/pi'; +import { useGetSetupStatusQuery } from 'services/api/endpoints/auth'; import { useMainModels } from 'services/api/hooks/modelsByType'; export const NoContentForViewer = memo(() => { const hasImages = useHasImages(); const [mainModels, { data }] = useMainModels(); + const { data: setupStatus } = useGetSetupStatusQuery(); + const user = useAppSelector(selectCurrentUser); const { t } = useTranslation(); + const isMultiuser = setupStatus?.multiuser_enabled ?? false; + const isAdmin = !isMultiuser || (user?.is_admin ?? false); + const adminEmail = setupStatus?.admin_email ?? null; + + const modelsLoaded = data !== undefined; + const hasModels = mainModels.length > 0; + const showStarterBundles = useMemo(() => { - return data && mainModels.length === 0; - }, [mainModels.length, data]); + return modelsLoaded && !hasModels && isAdmin; + }, [modelsLoaded, hasModels, isAdmin]); if (hasImages === LOADING_SYMBOL) { // Blank bg w/ a spinner. The new user experience components below have an invoke logo, but it's not centered. @@ -36,10 +48,18 @@ export const NoContentForViewer = memo(() => { - - {showStarterBundles && } - - + {isAdmin ? ( + // Admin / single-user mode + <> + {modelsLoaded && hasModels ? 
: } + {showStarterBundles && } + + + + ) : ( + // Non-admin user in multiuser mode + <>{modelsLoaded && hasModels ? : } + )} ); @@ -99,6 +119,32 @@ const GetStartedLocal = () => { ); }; +const GetStartedWithModels = () => { + return ( + + + + ); +}; + +const GetStartedNonAdmin = ({ adminEmail }: { adminEmail: string | null }) => { + const AdminEmailLink = adminEmail ? ( + + {adminEmail} + + ) : ( + + your administrator + + ); + + return ( + + + + ); +}; + const StarterBundlesCallout = () => { const handleClickDownloadStarterModels = useCallback(() => { navigationApi.switchToTab('models'); diff --git a/invokeai/frontend/web/src/features/metadata/parsing.test.ts b/invokeai/frontend/web/src/features/metadata/parsing.test.ts new file mode 100644 index 0000000000..a8eb2cb8af --- /dev/null +++ b/invokeai/frontend/web/src/features/metadata/parsing.test.ts @@ -0,0 +1,94 @@ +import { describe, expect, it, vi } from 'vitest'; + +import { ImageMetadataHandlers, MetadataUtils } from './parsing'; + +const createMockStore = () => ({ + dispatch: vi.fn(), + getState: vi.fn(() => ({ + params: { model: null }, + })), +}); + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +const createStore = () => createMockStore() as any; + +describe('Qwen metadata parsing', () => { + it('does not report missing Qwen metadata keys as available', async () => { + const store = createStore(); + + const hasMetadata = await MetadataUtils.hasMetadataByHandlers({ + metadata: {}, + handlers: [ + ImageMetadataHandlers.QwenImageComponentSource, + ImageMetadataHandlers.QwenImageQuantization, + ImageMetadataHandlers.QwenImageShift, + ], + store, + require: 'all', + }); + + // Handlers reject when keys are absent, so hasMetadata should be false + expect(hasMetadata).toBe(false); + }); + + it('does not recall Qwen values when metadata keys are absent', async () => { + const store = createStore(); + + const recalled = await MetadataUtils.recallByHandlers({ + metadata: {}, + handlers: [ + 
ImageMetadataHandlers.QwenImageComponentSource, + ImageMetadataHandlers.QwenImageQuantization, + ImageMetadataHandlers.QwenImageShift, + ], + store, + silent: true, + }); + + // No keys present → handlers reject → 0 recalls, no dispatches + expect(recalled.size).toBe(0); + const mockStore = store as ReturnType; + expect(mockStore.dispatch).not.toHaveBeenCalled(); + }); + + it('recalls Qwen handlers with actual values when metadata keys are present', async () => { + const store = createStore(); + + const recalled = await MetadataUtils.recallByHandlers({ + metadata: { + qwen_image_component_source: { key: 'test-key', hash: 'test', name: 'Test', base: 'qwen-image', type: 'main' }, + qwen_image_quantization: 'nf4', + qwen_image_shift: 3.0, + }, + handlers: [ + ImageMetadataHandlers.QwenImageComponentSource, + ImageMetadataHandlers.QwenImageQuantization, + ImageMetadataHandlers.QwenImageShift, + ], + store, + silent: true, + }); + + expect(recalled.size).toBe(3); + const mockStore = store as ReturnType; + expect(mockStore.dispatch).toHaveBeenCalledTimes(3); + }); + + it('recalls Qwen component source as null when key is present but value is null', async () => { + const store = createStore(); + + const recalled = await MetadataUtils.recallByHandlers({ + metadata: { + qwen_image_component_source: null, + }, + handlers: [ImageMetadataHandlers.QwenImageComponentSource], + store, + silent: true, + }); + + // Key is present with null value → handler resolves with null → 1 recall + expect(recalled.size).toBe(1); + const mockStore = store as ReturnType; + expect(mockStore.dispatch).toHaveBeenCalledTimes(1); + }); +}); diff --git a/invokeai/frontend/web/src/features/metadata/parsing.tsx b/invokeai/frontend/web/src/features/metadata/parsing.tsx index edf6270e13..4f643123be 100644 --- a/invokeai/frontend/web/src/features/metadata/parsing.tsx +++ b/invokeai/frontend/web/src/features/metadata/parsing.tsx @@ -16,6 +16,9 @@ import { kleinVaeModelSelected, negativePromptChanged, 
positivePromptChanged, + qwenImageComponentSourceSelected, + qwenImageQuantizationChanged, + qwenImageShiftChanged, refinerModelChanged, selectBase, setAnimaScheduler, @@ -687,6 +690,83 @@ const ZImageSeedVarianceRandomizePercent: SingleMetadataHandler = { }; //#endregion ZImageSeedVarianceRandomizePercent +//#region QwenImageComponentSource +const QwenImageComponentSource: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'QwenImageComponentSource', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'qwen_image_component_source'); + // Reject when the key is absent so the handler is not rendered for non-Qwen images + if (raw === undefined) { + return Promise.reject(); + } + if (raw === null) { + return Promise.resolve(null); + } + return Promise.resolve(zModelIdentifierField.parse(raw)); + }, + recall: (value, store) => { + store.dispatch(qwenImageComponentSourceSelected(value)); + }, + i18nKey: 'modelManager.qwenImageComponentSource', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => ( + + ), +}; +//#endregion QwenImageComponentSource + +//#region QwenImageQuantization +const QwenImageQuantization: SingleMetadataHandler<'none' | 'int8' | 'nf4'> = { + [SingleMetadataKey]: true, + type: 'QwenImageQuantization', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'qwen_image_quantization'); + // Reject when the key is absent so the handler is not rendered for non-Qwen images + if (raw === undefined) { + return Promise.reject(); + } + const parsed = z.enum(['none', 'int8', 'nf4']).parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(qwenImageQuantizationChanged(value)); + }, + i18nKey: 'modelManager.qwenImageQuantization', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps<'none' | 'int8' | 'nf4'>) => ( + + ), +}; +//#endregion QwenImageQuantization + +//#region QwenImageShift +const 
QwenImageShift: SingleMetadataHandler = { + [SingleMetadataKey]: true, + type: 'QwenImageShift', + parse: (metadata, _store) => { + const raw = getProperty(metadata, 'qwen_image_shift'); + // Reject when the key is absent so the handler is not rendered for non-Qwen images + if (raw === undefined) { + return Promise.reject(); + } + if (raw === null) { + return Promise.resolve(null); + } + const parsed = z.number().parse(raw); + return Promise.resolve(parsed); + }, + recall: (value, store) => { + store.dispatch(qwenImageShiftChanged(value)); + }, + i18nKey: 'modelManager.qwenImageShift', + LabelComponent: MetadataLabel, + ValueComponent: ({ value }: SingleMetadataValueProps) => ( + + ), +}; +//#endregion QwenImageShift + //#region ZImageShift const ZImageShift: SingleMetadataHandler = { [SingleMetadataKey]: true, @@ -1334,6 +1414,9 @@ export const ImageMetadataHandlers = { ZImageSeedVarianceEnabled, ZImageSeedVarianceStrength, ZImageSeedVarianceRandomizePercent, + QwenImageComponentSource, + QwenImageQuantization, + QwenImageShift, ZImageShift, LoRAs, CanvasLayers, diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useBuildModelsToInstall.ts b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useBuildModelsToInstall.ts index 457d48ce19..85e24a3d07 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useBuildModelsToInstall.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useBuildModelsToInstall.ts @@ -4,7 +4,7 @@ import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/ import type { StarterModel } from 'services/api/types'; type ModelInstallArg = { - config: Pick; + config: Pick; source: string; }; @@ -32,7 +32,7 @@ export const useBuildModelInstallArg = () => { ); const buildModelInstallArg = useCallback((starterModel: StarterModel): ModelInstallArg => { - const { name, base, type, source, description, format } = starterModel; + const { name, base, type, source, description, format, 
variant } = starterModel; return { config: { @@ -41,6 +41,7 @@ export const useBuildModelInstallArg = () => { type, description, format, + variant, }, source, }; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx index d1774f9ded..9b76fbbde6 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/hooks/useStarterModelsToast.tsx @@ -1,10 +1,11 @@ import { Button, Text, useToast } from '@invoke-ai/ui-library'; import { useAppSelector } from 'app/store/storeHooks'; -import { selectIsAuthenticated } from 'features/auth/store/authSlice'; +import { selectCurrentUser, selectIsAuthenticated } from 'features/auth/store/authSlice'; import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore'; import { navigationApi } from 'features/ui/layouts/navigation-api'; import { useCallback, useEffect, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { useGetSetupStatusQuery } from 'services/api/endpoints/auth'; import { useMainModels } from 'services/api/hooks/modelsByType'; const TOAST_ID = 'starterModels'; @@ -15,6 +16,11 @@ export const useStarterModelsToast = () => { const [mainModels, { data }] = useMainModels(); const toast = useToast(); const isAuthenticated = useAppSelector(selectIsAuthenticated); + const { data: setupStatus } = useGetSetupStatusQuery(); + const user = useAppSelector(selectCurrentUser); + + const isMultiuser = setupStatus?.multiuser_enabled ?? false; + const isAdmin = !isMultiuser || (user?.is_admin ?? false); useEffect(() => { // Only show the toast if the user is authenticated @@ -33,17 +39,17 @@ export const useStarterModelsToast = () => { toast({ id: TOAST_ID, title: t('modelManager.noModelsInstalled'), - description: , + description: isAdmin ? 
: , status: 'info', isClosable: true, duration: null, onCloseComplete: () => setDidToast(true), }); } - }, [data, didToast, isAuthenticated, mainModels.length, t, toast]); + }, [data, didToast, isAuthenticated, isAdmin, mainModels.length, t, toast]); }; -const ToastDescription = () => { +const AdminToastDescription = () => { const { t } = useTranslation(); const toast = useToast(); @@ -62,3 +68,9 @@ const ToastDescription = () => { ); }; + +const NonAdminToastDescription = () => { + const { t } = useTranslation(); + + return {t('modelManager.noModelsInstalledAskAdmin')}; +}; diff --git a/invokeai/frontend/web/src/features/modelManagerV2/models.ts b/invokeai/frontend/web/src/features/modelManagerV2/models.ts index a63f38cafd..9cc4ed24d9 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/models.ts +++ b/invokeai/frontend/web/src/features/modelManagerV2/models.ts @@ -148,6 +148,7 @@ export const MODEL_BASE_TO_COLOR: Record = { flux: 'gold', flux2: 'gold', cogview4: 'red', + 'qwen-image': 'orange', 'z-image': 'cyan', external: 'orange', anima: 'invokePurple', @@ -192,6 +193,7 @@ export const MODEL_BASE_TO_LONG_NAME: Record = { flux: 'FLUX', flux2: 'FLUX.2', cogview4: 'CogView4', + 'qwen-image': 'Qwen Image', 'z-image': 'Z-Image', external: 'External', anima: 'Anima', @@ -211,6 +213,7 @@ export const MODEL_BASE_TO_SHORT_NAME: Record = { flux: 'FLUX', flux2: 'FLUX.2', cogview4: 'CogView4', + 'qwen-image': 'QwenImg', 'z-image': 'Z-Image', external: 'External', anima: 'Anima', @@ -231,6 +234,8 @@ export const MODEL_VARIANT_TO_LONG_NAME: Record = { zbase: 'Z-Image Base', large: 'CLIP L', gigantic: 'CLIP G', + generate: 'Qwen Image', + edit: 'Qwen Image Edit', qwen3_4b: 'Qwen3 4B', qwen3_8b: 'Qwen3 8B', qwen3_06b: 'Qwen3 0.6B', @@ -257,13 +262,14 @@ export const MODEL_FORMAT_TO_LONG_NAME: Record = { export const SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS: BaseModelType[] = ['flux', 'sd-3', 'z-image']; -export const SUPPORTS_REF_IMAGES_BASE_MODELS: BaseModelType[] = 
['sd-1', 'sdxl', 'flux', 'flux2']; +export const SUPPORTS_REF_IMAGES_BASE_MODELS: BaseModelType[] = ['sd-1', 'sdxl', 'flux', 'flux2', 'qwen-image']; export const SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS: BaseModelType[] = [ 'sd-1', 'sd-2', 'sdxl', 'cogview4', + 'qwen-image', 'sd-3', 'z-image', 'anima', diff --git a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx index f6e1a18f6f..60200c8801 100644 --- a/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx +++ b/invokeai/frontend/web/src/features/modelManagerV2/subpanels/ModelManager.tsx @@ -37,7 +37,7 @@ export const ModelManager = memo(() => { {t('common.modelManager')} - + {canManageModels && } {!!selectedModelKey && canManageModels && (