Merge remote-tracking branch 'upstream/main' into external-models

This commit is contained in:
Alexander Eichhorn
2026-04-14 00:58:09 +02:00
165 changed files with 10290 additions and 504 deletions

View File

@@ -54,6 +54,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
CogView4ConditioningInfo,
ConditioningFieldData,
FLUXConditioningInfo,
QwenImageConditioningInfo,
SD3ConditioningInfo,
SDXLConditioningInfo,
ZImageConditioningInfo,
@@ -144,6 +145,7 @@ class ApiDependencies:
SD3ConditioningInfo,
CogView4ConditioningInfo,
ZImageConditioningInfo,
QwenImageConditioningInfo,
AnimaConditioningInfo,
],
ephemeral=True,

View File

@@ -80,6 +80,7 @@ class SetupStatusResponse(BaseModel):
setup_required: bool = Field(description="Whether initial setup is required")
multiuser_enabled: bool = Field(description="Whether multiuser mode is enabled")
strict_password_checking: bool = Field(description="Whether strict password requirements are enforced")
admin_email: str | None = Field(default=None, description="Email of the first active admin user, if any")
@auth_router.get("/status", response_model=SetupStatusResponse)
@@ -94,15 +95,25 @@ async def get_setup_status() -> SetupStatusResponse:
# If multiuser is disabled, setup is never required
if not config.multiuser:
return SetupStatusResponse(
setup_required=False, multiuser_enabled=False, strict_password_checking=config.strict_password_checking
setup_required=False,
multiuser_enabled=False,
strict_password_checking=config.strict_password_checking,
admin_email=None,
)
# In multiuser mode, check if an admin exists
user_service = ApiDependencies.invoker.services.users
setup_required = not user_service.has_admin()
# Only expose admin_email during initial setup to avoid leaking
# administrator identity on public deployments.
admin_email = user_service.get_admin_email() if setup_required else None
return SetupStatusResponse(
setup_required=setup_required, multiuser_enabled=True, strict_password_checking=config.strict_password_checking
setup_required=setup_required,
multiuser_enabled=True,
strict_password_checking=config.strict_password_checking,
admin_email=admin_email,
)

View File

@@ -1,12 +1,53 @@
from fastapi import Body, HTTPException
from fastapi.routing import APIRouter
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.images.images_common import AddImagesToBoardResult, RemoveImagesFromBoardResult
board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
def _assert_board_write_access(board_id: str, current_user: CurrentUserOrDefault) -> None:
    """Reject (HTTP 403) any caller who lacks write rights on the given board.

    Mutation is permitted when at least one of the following is true:
    - the caller is an admin,
    - the caller owns the board,
    - the board is Public (public boards accept contributions from any user).

    Raises HTTP 404 when the board cannot be loaded at all.
    """
    from invokeai.app.services.board_records.board_records_common import BoardVisibility

    try:
        board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
    except Exception:
        raise HTTPException(status_code=404, detail="Board not found")

    may_write = (
        current_user.is_admin
        or board.user_id == current_user.user_id
        or board.board_visibility == BoardVisibility.Public
    )
    if not may_write:
        raise HTTPException(status_code=403, detail="Not authorized to modify this board")
def _assert_image_direct_owner(image_name: str, current_user: CurrentUserOrDefault) -> None:
    """Raise 403 unless the caller is the image's direct owner (or an admin).

    Deliberately stricter than _assert_image_owner in images.py: owning the
    board an image sits on is NOT sufficient here. If it were, a user could
    add someone else's image to a board they own and then gain mutation
    rights via the board-ownership fallback in _assert_image_owner — turning
    read access into write access.
    """
    if current_user.is_admin:
        return
    owner = ApiDependencies.invoker.services.image_records.get_user_id(image_name)
    if owner is None or owner != current_user.user_id:
        raise HTTPException(status_code=403, detail="Not authorized to move this image")
@board_images_router.post(
"/",
operation_id="add_image_to_board",
@@ -17,14 +58,17 @@ board_images_router = APIRouter(prefix="/v1/board_images", tags=["boards"])
response_model=AddImagesToBoardResult,
)
async def add_image_to_board(
current_user: CurrentUserOrDefault,
board_id: str = Body(description="The id of the board to add to"),
image_name: str = Body(description="The name of the image to add"),
) -> AddImagesToBoardResult:
"""Creates a board_image"""
_assert_board_write_access(board_id, current_user)
_assert_image_direct_owner(image_name, current_user)
try:
added_images: set[str] = set()
affected_boards: set[str] = set()
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
old_board_id = ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none"
ApiDependencies.invoker.services.board_images.add_image_to_board(board_id=board_id, image_name=image_name)
added_images.add(image_name)
affected_boards.add(board_id)
@@ -48,13 +92,16 @@ async def add_image_to_board(
response_model=RemoveImagesFromBoardResult,
)
async def remove_image_from_board(
current_user: CurrentUserOrDefault,
image_name: str = Body(description="The name of the image to remove", embed=True),
) -> RemoveImagesFromBoardResult:
"""Removes an image from its board, if it had one"""
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
if old_board_id != "none":
_assert_board_write_access(old_board_id, current_user)
removed_images: set[str] = set()
affected_boards: set[str] = set()
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_images.add(image_name)
affected_boards.add("none")
@@ -64,6 +111,8 @@ async def remove_image_from_board(
affected_boards=list(affected_boards),
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove image from board")
@@ -78,16 +127,21 @@ async def remove_image_from_board(
response_model=AddImagesToBoardResult,
)
async def add_images_to_board(
current_user: CurrentUserOrDefault,
board_id: str = Body(description="The id of the board to add to"),
image_names: list[str] = Body(description="The names of the images to add", embed=True),
) -> AddImagesToBoardResult:
"""Adds a list of images to a board"""
_assert_board_write_access(board_id, current_user)
try:
added_images: set[str] = set()
affected_boards: set[str] = set()
for image_name in image_names:
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
_assert_image_direct_owner(image_name, current_user)
old_board_id = (
ApiDependencies.invoker.services.board_image_records.get_board_for_image(image_name) or "none"
)
ApiDependencies.invoker.services.board_images.add_image_to_board(
board_id=board_id,
image_name=image_name,
@@ -96,12 +150,16 @@ async def add_images_to_board(
affected_boards.add(board_id)
affected_boards.add(old_board_id)
except HTTPException:
raise
except Exception:
pass
return AddImagesToBoardResult(
added_images=list(added_images),
affected_boards=list(affected_boards),
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to add images to board")
@@ -116,6 +174,7 @@ async def add_images_to_board(
response_model=RemoveImagesFromBoardResult,
)
async def remove_images_from_board(
current_user: CurrentUserOrDefault,
image_names: list[str] = Body(description="The names of the images to remove", embed=True),
) -> RemoveImagesFromBoardResult:
"""Removes a list of images from their board, if they had one"""
@@ -125,15 +184,21 @@ async def remove_images_from_board(
for image_name in image_names:
try:
old_board_id = ApiDependencies.invoker.services.images.get_dto(image_name).board_id or "none"
if old_board_id != "none":
_assert_board_write_access(old_board_id, current_user)
ApiDependencies.invoker.services.board_images.remove_image_from_board(image_name=image_name)
removed_images.add(image_name)
affected_boards.add("none")
affected_boards.add(old_board_id)
except HTTPException:
raise
except Exception:
pass
return RemoveImagesFromBoardResult(
removed_images=list(removed_images),
affected_boards=list(affected_boards),
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to remove images from board")

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy, BoardVisibility
from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
@@ -56,7 +56,14 @@ async def get_board(
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
if not current_user.is_admin and result.user_id != current_user.user_id:
# Admins can access any board.
# Owners can access their own boards.
# Shared and public boards are visible to all authenticated users.
if (
not current_user.is_admin
and result.user_id != current_user.user_id
and result.board_visibility == BoardVisibility.Private
):
raise HTTPException(status_code=403, detail="Not authorized to access this board")
return result
@@ -188,7 +195,11 @@ async def list_all_board_image_names(
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
if not current_user.is_admin and board.user_id != current_user.user_id:
if (
not current_user.is_admin
and board.user_id != current_user.user_id
and board.board_visibility == BoardVisibility.Private
):
raise HTTPException(status_code=403, detail="Not authorized to access this board")
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
@@ -196,4 +207,15 @@ async def list_all_board_image_names(
categories,
is_intermediate,
)
# For uncategorized images (board_id="none"), filter to only the caller's
# images so that one user cannot enumerate another's uncategorized images.
# Admin users can see all uncategorized images.
if board_id == "none" and not current_user.is_admin:
image_names = [
name
for name in image_names
if ApiDependencies.invoker.services.image_records.get_user_id(name) == current_user.user_id
]
return image_names

View File

@@ -38,6 +38,96 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
IMAGE_MAX_AGE = 31536000
def _assert_image_owner(image_name: str, current_user: CurrentUserOrDefault) -> None:
    """Raise 403 unless the caller may mutate the image.

    Mutation is permitted when ANY of the following hold:
    - the caller is an admin,
    - the caller is the image's direct owner (image_records.user_id),
    - the caller owns the board the image sits on,
    - the image sits on a Public board (public boards grant mutation rights).
    """
    from invokeai.app.services.board_records.board_records_common import BoardVisibility

    if current_user.is_admin:
        return

    services = ApiDependencies.invoker.services
    owner_id = services.image_records.get_user_id(image_name)
    if owner_id is not None and owner_id == current_user.user_id:
        return

    # Fall back to board-level rights: board owner, or a Public board.
    # Any failure while loading/inspecting the board is treated as "no access".
    board_id = services.board_image_records.get_board_for_image(image_name)
    if board_id is not None:
        try:
            board = services.boards.get_dto(board_id=board_id)
            if board.user_id == current_user.user_id or board.board_visibility == BoardVisibility.Public:
                return
        except Exception:
            pass

    raise HTTPException(status_code=403, detail="Not authorized to modify this image")
def _assert_image_read_access(image_name: str, current_user: CurrentUserOrDefault) -> None:
    """Raise 403 unless the caller may view the image.

    Viewing is permitted when ANY of the following hold:
    - the caller is an admin,
    - the caller owns the image,
    - the image sits on a Shared or Public board.
    """
    from invokeai.app.services.board_records.board_records_common import BoardVisibility

    if current_user.is_admin:
        return

    services = ApiDependencies.invoker.services
    owner_id = services.image_records.get_user_id(image_name)
    if owner_id is not None and owner_id == current_user.user_id:
        return

    # A Shared or Public board makes its images visible to every user.
    # Any failure while loading/inspecting the board is treated as "no access".
    board_id = services.board_image_records.get_board_for_image(image_name)
    if board_id is not None:
        try:
            board = services.boards.get_dto(board_id=board_id)
            if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public):
                return
        except Exception:
            pass

    raise HTTPException(status_code=403, detail="Not authorized to access this image")
def _assert_board_read_access(board_id: str, current_user: CurrentUserOrDefault) -> None:
    """Raise 403 unless the caller may read images from this board.

    Read access is granted when ANY of the following hold:
    - the caller is an admin (checked first, before the board is even fetched),
    - the caller owns the board,
    - the board is Shared or Public.

    Raises HTTP 404 when the board cannot be loaded (non-admin callers only).
    """
    from invokeai.app.services.board_records.board_records_common import BoardVisibility

    if current_user.is_admin:
        return

    try:
        board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
    except Exception:
        raise HTTPException(status_code=404, detail="Board not found")

    may_read = (
        board.user_id == current_user.user_id
        or board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public)
    )
    if not may_read:
        raise HTTPException(status_code=403, detail="Not authorized to access this board")
class ResizeToDimensions(BaseModel):
    """Target pixel dimensions for an image resize request; both must be strictly positive."""

    width: int = Field(..., gt=0)
    height: int = Field(..., gt=0)
@@ -83,6 +173,22 @@ async def upload_image(
),
) -> ImageDTO:
"""Uploads an image for the current user"""
# If uploading into a board, verify the user has write access.
# Public boards allow uploads from any authenticated user.
if board_id is not None:
from invokeai.app.services.board_records.board_records_common import BoardVisibility
try:
board = ApiDependencies.invoker.services.boards.get_dto(board_id=board_id)
except Exception:
raise HTTPException(status_code=404, detail="Board not found")
if (
not current_user.is_admin
and board.user_id != current_user.user_id
and board.board_visibility != BoardVisibility.Public
):
raise HTTPException(status_code=403, detail="Not authorized to upload to this board")
if not file.content_type or not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
@@ -165,9 +271,11 @@ async def create_image_upload_entry(
@images_router.delete("/i/{image_name}", operation_id="delete_image", response_model=DeleteImagesResult)
async def delete_image(
current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of the image to delete"),
) -> DeleteImagesResult:
"""Deletes an image"""
_assert_image_owner(image_name, current_user)
deleted_images: set[str] = set()
affected_boards: set[str] = set()
@@ -189,26 +297,31 @@ async def delete_image(
@images_router.delete("/intermediates", operation_id="clear_intermediates")
async def clear_intermediates() -> int:
"""Clears all intermediates"""
async def clear_intermediates(
current_user: CurrentUserOrDefault,
) -> int:
"""Clears all intermediates. Requires admin."""
if not current_user.is_admin:
raise HTTPException(status_code=403, detail="Only admins can clear all intermediates")
try:
count_deleted = ApiDependencies.invoker.services.images.delete_intermediates()
return count_deleted
except Exception:
raise HTTPException(status_code=500, detail="Failed to clear intermediates")
pass
@images_router.get("/intermediates", operation_id="get_intermediates_count")
async def get_intermediates_count() -> int:
"""Gets the count of intermediate images"""
async def get_intermediates_count(
current_user: CurrentUserOrDefault,
) -> int:
"""Gets the count of intermediate images. Non-admin users only see their own intermediates."""
try:
return ApiDependencies.invoker.services.images.get_intermediates_count()
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.images.get_intermediates_count(user_id=user_id)
except Exception:
raise HTTPException(status_code=500, detail="Failed to get intermediates")
pass
@images_router.patch(
@@ -217,10 +330,12 @@ async def get_intermediates_count() -> int:
response_model=ImageDTO,
)
async def update_image(
current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body(description="The changes to apply to the image"),
) -> ImageDTO:
"""Updates an image"""
_assert_image_owner(image_name, current_user)
try:
return ApiDependencies.invoker.services.images.update(image_name, image_changes)
@@ -234,9 +349,11 @@ async def update_image(
response_model=ImageDTO,
)
async def get_image_dto(
current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of image to get"),
) -> ImageDTO:
"""Gets an image's DTO"""
_assert_image_read_access(image_name, current_user)
try:
return ApiDependencies.invoker.services.images.get_dto(image_name)
@@ -250,9 +367,11 @@ async def get_image_dto(
response_model=Optional[MetadataField],
)
async def get_image_metadata(
current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of image to get"),
) -> Optional[MetadataField]:
"""Gets an image's metadata"""
_assert_image_read_access(image_name, current_user)
try:
return ApiDependencies.invoker.services.images.get_metadata(image_name)
@@ -269,8 +388,11 @@ class WorkflowAndGraphResponse(BaseModel):
"/i/{image_name}/workflow", operation_id="get_image_workflow", response_model=WorkflowAndGraphResponse
)
async def get_image_workflow(
current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of image whose workflow to get"),
) -> WorkflowAndGraphResponse:
_assert_image_read_access(image_name, current_user)
try:
workflow = ApiDependencies.invoker.services.images.get_workflow(image_name)
graph = ApiDependencies.invoker.services.images.get_graph(image_name)
@@ -306,8 +428,12 @@ async def get_image_workflow(
async def get_image_full(
image_name: str = Path(description="The name of full-resolution image file to get"),
) -> Response:
"""Gets a full-resolution image file"""
"""Gets a full-resolution image file.
This endpoint is intentionally unauthenticated because browsers load images
via <img src> tags which cannot send Bearer tokens. Image names are UUIDs,
providing security through unguessability.
"""
try:
path = ApiDependencies.invoker.services.images.get_path(image_name)
with open(path, "rb") as f:
@@ -335,8 +461,12 @@ async def get_image_full(
async def get_image_thumbnail(
image_name: str = Path(description="The name of thumbnail image file to get"),
) -> Response:
"""Gets a thumbnail image file"""
"""Gets a thumbnail image file.
This endpoint is intentionally unauthenticated because browsers load images
via <img src> tags which cannot send Bearer tokens. Image names are UUIDs,
providing security through unguessability.
"""
try:
path = ApiDependencies.invoker.services.images.get_path(image_name, thumbnail=True)
with open(path, "rb") as f:
@@ -354,9 +484,11 @@ async def get_image_thumbnail(
response_model=ImageUrlsDTO,
)
async def get_image_urls(
current_user: CurrentUserOrDefault,
image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL"""
_assert_image_read_access(image_name, current_user)
try:
image_url = ApiDependencies.invoker.services.images.get_url(image_name)
@@ -392,6 +524,11 @@ async def list_image_dtos(
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of image DTOs for the current user"""
# Validate that the caller can read from this board before listing its images.
# "none" is a sentinel for uncategorized images and is handled by the SQL layer.
if board_id is not None and board_id != "none":
_assert_board_read_access(board_id, current_user)
image_dtos = ApiDependencies.invoker.services.images.get_many(
offset,
limit,
@@ -410,6 +547,7 @@ async def list_image_dtos(
@images_router.post("/delete", operation_id="delete_images_from_list", response_model=DeleteImagesResult)
async def delete_images_from_list(
current_user: CurrentUserOrDefault,
image_names: list[str] = Body(description="The list of names of images to delete", embed=True),
) -> DeleteImagesResult:
try:
@@ -417,24 +555,31 @@ async def delete_images_from_list(
affected_boards: set[str] = set()
for image_name in image_names:
try:
_assert_image_owner(image_name, current_user)
image_dto = ApiDependencies.invoker.services.images.get_dto(image_name)
board_id = image_dto.board_id or "none"
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add(board_id)
except HTTPException:
raise
except Exception:
pass
return DeleteImagesResult(
deleted_images=list(deleted_images),
affected_boards=list(affected_boards),
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to delete images")
@images_router.delete("/uncategorized", operation_id="delete_uncategorized_images", response_model=DeleteImagesResult)
async def delete_uncategorized_images() -> DeleteImagesResult:
"""Deletes all images that are uncategorized"""
async def delete_uncategorized_images(
current_user: CurrentUserOrDefault,
) -> DeleteImagesResult:
"""Deletes all uncategorized images owned by the current user (or all if admin)"""
image_names = ApiDependencies.invoker.services.board_images.get_all_board_image_names_for_board(
board_id="none", categories=None, is_intermediate=None
@@ -445,9 +590,13 @@ async def delete_uncategorized_images() -> DeleteImagesResult:
affected_boards: set[str] = set()
for image_name in image_names:
try:
_assert_image_owner(image_name, current_user)
ApiDependencies.invoker.services.images.delete(image_name)
deleted_images.add(image_name)
affected_boards.add("none")
except HTTPException:
# Skip images not owned by the current user
pass
except Exception:
pass
return DeleteImagesResult(
@@ -464,6 +613,7 @@ class ImagesUpdatedFromListResult(BaseModel):
@images_router.post("/star", operation_id="star_images_in_list", response_model=StarredImagesResult)
async def star_images_in_list(
current_user: CurrentUserOrDefault,
image_names: list[str] = Body(description="The list of names of images to star", embed=True),
) -> StarredImagesResult:
try:
@@ -471,23 +621,29 @@ async def star_images_in_list(
affected_boards: set[str] = set()
for image_name in image_names:
try:
_assert_image_owner(image_name, current_user)
updated_image_dto = ApiDependencies.invoker.services.images.update(
image_name, changes=ImageRecordChanges(starred=True)
)
starred_images.add(image_name)
affected_boards.add(updated_image_dto.board_id or "none")
except HTTPException:
raise
except Exception:
pass
return StarredImagesResult(
starred_images=list(starred_images),
affected_boards=list(affected_boards),
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to star images")
@images_router.post("/unstar", operation_id="unstar_images_in_list", response_model=UnstarredImagesResult)
async def unstar_images_in_list(
current_user: CurrentUserOrDefault,
image_names: list[str] = Body(description="The list of names of images to unstar", embed=True),
) -> UnstarredImagesResult:
try:
@@ -495,17 +651,22 @@ async def unstar_images_in_list(
affected_boards: set[str] = set()
for image_name in image_names:
try:
_assert_image_owner(image_name, current_user)
updated_image_dto = ApiDependencies.invoker.services.images.update(
image_name, changes=ImageRecordChanges(starred=False)
)
unstarred_images.add(image_name)
affected_boards.add(updated_image_dto.board_id or "none")
except HTTPException:
raise
except Exception:
pass
return UnstarredImagesResult(
unstarred_images=list(unstarred_images),
affected_boards=list(affected_boards),
)
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=500, detail="Failed to unstar images")
@@ -523,6 +684,7 @@ class ImagesDownloaded(BaseModel):
"/download", operation_id="download_images_from_list", response_model=ImagesDownloaded, status_code=202
)
async def download_images_from_list(
current_user: CurrentUserOrDefault,
background_tasks: BackgroundTasks,
image_names: Optional[list[str]] = Body(
default=None, description="The list of names of images to download", embed=True
@@ -533,6 +695,16 @@ async def download_images_from_list(
) -> ImagesDownloaded:
if (image_names is None or len(image_names) == 0) and board_id is None:
raise HTTPException(status_code=400, detail="No images or board id specified.")
# Validate that the caller can read every image they are requesting.
# For a board_id request, check board visibility; for explicit image names,
# check each image individually.
if board_id:
_assert_board_read_access(board_id, current_user)
if image_names:
for name in image_names:
_assert_image_read_access(name, current_user)
bulk_download_item_id: str = ApiDependencies.invoker.services.bulk_download.generate_item_id(board_id)
background_tasks.add_task(
@@ -540,6 +712,7 @@ async def download_images_from_list(
image_names,
board_id,
bulk_download_item_id,
current_user.user_id,
)
return ImagesDownloaded(bulk_download_item_name=bulk_download_item_id + ".zip")
@@ -558,11 +731,21 @@ async def download_images_from_list(
},
)
async def get_bulk_download_item(
current_user: CurrentUserOrDefault,
background_tasks: BackgroundTasks,
bulk_download_item_name: str = Path(description="The bulk_download_item_name of the bulk download item to get"),
) -> FileResponse:
"""Gets a bulk download zip file"""
"""Gets a bulk download zip file.
Requires authentication. The caller must be the user who initiated the
download (tracked by the bulk download service) or an admin.
"""
try:
# Verify the caller owns this download (or is an admin)
owner = ApiDependencies.invoker.services.bulk_download.get_owner(bulk_download_item_name)
if owner is not None and owner != current_user.user_id and not current_user.is_admin:
raise HTTPException(status_code=403, detail="Not authorized to access this download")
path = ApiDependencies.invoker.services.bulk_download.get_path(bulk_download_item_name)
response = FileResponse(
@@ -574,6 +757,8 @@ async def get_bulk_download_item(
response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
background_tasks.add_task(ApiDependencies.invoker.services.bulk_download.delete, bulk_download_item_name)
return response
except HTTPException:
raise
except Exception:
raise HTTPException(status_code=404)
@@ -594,6 +779,10 @@ async def get_image_names(
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates"""
# Validate that the caller can read from this board before listing its images.
if board_id is not None and board_id != "none":
_assert_board_read_access(board_id, current_user)
try:
result = ApiDependencies.invoker.services.images.get_image_names(
starred_first=starred_first,
@@ -617,6 +806,7 @@ async def get_image_names(
responses={200: {"model": list[ImageDTO]}},
)
async def get_images_by_names(
current_user: CurrentUserOrDefault,
image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"),
) -> list[ImageDTO]:
"""Gets image DTOs for the specified image names. Maintains order of input names."""
@@ -628,8 +818,12 @@ async def get_images_by_names(
image_dtos: list[ImageDTO] = []
for name in image_names:
try:
_assert_image_read_access(name, current_user)
dto = image_service.get_dto(name)
image_dtos.append(dto)
except HTTPException:
# Skip images the user is not authorized to view
continue
except Exception:
# Skip missing images - they may have been deleted between name fetch and DTO fetch
continue

View File

@@ -889,7 +889,7 @@ async def install_hugging_face_model(
"/install",
operation_id="list_model_installs",
)
async def list_model_installs() -> List[ModelInstallJob]:
async def list_model_installs(current_admin: AdminUserOrDefault) -> List[ModelInstallJob]:
"""Return the list of model install jobs.
Install jobs have a numeric `id`, a `status`, and other fields that provide information on
@@ -921,7 +921,9 @@ async def list_model_installs() -> List[ModelInstallJob]:
404: {"description": "No such job"},
},
)
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
async def get_model_install_job(
current_admin: AdminUserOrDefault, id: int = Path(description="Model install id")
) -> ModelInstallJob:
"""
Return model install job corresponding to the given source. See the documentation for 'List Model Install Jobs'
for information on the format of the return value.
@@ -964,7 +966,9 @@ async def cancel_model_install_job(
},
status_code=201,
)
async def pause_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
async def pause_model_install_job(
current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
) -> ModelInstallJob:
"""Pause the model install job corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
@@ -984,7 +988,9 @@ async def pause_model_install_job(id: int = Path(description="Model install job
},
status_code=201,
)
async def resume_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
async def resume_model_install_job(
current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
) -> ModelInstallJob:
"""Resume a paused model install job corresponding to the given job ID."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
@@ -1004,7 +1010,9 @@ async def resume_model_install_job(id: int = Path(description="Model install job
},
status_code=201,
)
async def restart_failed_model_install_job(id: int = Path(description="Model install job ID")) -> ModelInstallJob:
async def restart_failed_model_install_job(
current_admin: AdminUserOrDefault, id: int = Path(description="Model install job ID")
) -> ModelInstallJob:
"""Restart failed or non-resumable file downloads for the given job."""
installer = ApiDependencies.invoker.services.model_manager.install
try:
@@ -1025,6 +1033,7 @@ async def restart_failed_model_install_job(id: int = Path(description="Model ins
status_code=201,
)
async def restart_model_install_file(
current_admin: AdminUserOrDefault,
id: int = Path(description="Model install job ID"),
file_source: AnyHttpUrl = Body(description="File download URL to restart"),
) -> ModelInstallJob:
@@ -1336,7 +1345,7 @@ class DeleteOrphanedModelsResponse(BaseModel):
operation_id="get_orphaned_models",
response_model=list[OrphanedModelInfo],
)
async def get_orphaned_models() -> list[OrphanedModelInfo]:
async def get_orphaned_models(_: AdminUserOrDefault) -> list[OrphanedModelInfo]:
"""Find orphaned model directories.
Orphaned models are directories in the models folder that contain model files
@@ -1363,7 +1372,9 @@ async def get_orphaned_models() -> list[OrphanedModelInfo]:
operation_id="delete_orphaned_models",
response_model=DeleteOrphanedModelsResponse,
)
async def delete_orphaned_models(request: DeleteOrphanedModelsRequest) -> DeleteOrphanedModelsResponse:
async def delete_orphaned_models(
request: DeleteOrphanedModelsRequest, _: AdminUserOrDefault
) -> DeleteOrphanedModelsResponse:
"""Delete specified orphaned model directories.
Args:

View File

@@ -7,6 +7,7 @@ from fastapi import Body, HTTPException, Path
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.backend.image_util.controlnet_processor import process_controlnet_image
from invokeai.backend.model_manager.taxonomy import ModelType
@@ -291,12 +292,58 @@ def resolve_ip_adapter_models(ip_adapters: list[IPAdapterRecallParameter]) -> li
return resolved_adapters
def _assert_recall_image_access(parameters: "RecallParameter", current_user: CurrentUserOrDefault) -> None:
    """Ensure the caller may read every image referenced by the recall parameters.

    Control layers and IP adapters may carry ``image_name`` references. Without this
    check, a caller who knows another user's image UUID could use the recall endpoint
    to learn image dimensions and — via ControlNet preprocessors — produce a derived
    processed image they could then fetch.

    Raises:
        HTTPException: 403 when any referenced image is neither owned by the caller
            nor on a shared/public board. Admins bypass all checks.
    """
    from invokeai.app.services.board_records.board_records_common import BoardVisibility

    referenced = [
        layer.image_name for layer in (parameters.control_layers or []) if layer.image_name is not None
    ]
    referenced += [
        adapter.image_name for adapter in (parameters.ip_adapters or []) if adapter.image_name is not None
    ]

    # Nothing referenced, or caller is an admin: nothing to verify.
    if not referenced or current_user.is_admin:
        return

    services = ApiDependencies.invoker.services
    for name in referenced:
        # Owners may always access their own images.
        owner = services.image_records.get_user_id(name)
        if owner is not None and owner == current_user.user_id:
            continue
        # Images on shared or public boards are readable by any authenticated user.
        board_id = services.board_image_records.get_board_for_image(name)
        if board_id is not None:
            try:
                board = services.boards.get_dto(board_id=board_id)
                if board.board_visibility in (BoardVisibility.Shared, BoardVisibility.Public):
                    continue
            except Exception:
                # Board lookup failed — fall through and deny (fail closed).
                pass
        raise HTTPException(status_code=403, detail=f"Not authorized to access image {name}")
@recall_parameters_router.post(
"/{queue_id}",
operation_id="update_recall_parameters",
response_model=dict[str, Any],
)
async def update_recall_parameters(
current_user: CurrentUserOrDefault,
queue_id: str = Path(..., description="The queue id to perform this operation on"),
parameters: RecallParameter = Body(..., description="Recall parameters to update"),
) -> dict[str, Any]:
@@ -328,6 +375,10 @@ async def update_recall_parameters(
"""
logger = ApiDependencies.invoker.services.logger
# Validate image access before processing — prevents information leakage
# (dimensions) and derived-image minting via ControlNet preprocessors.
_assert_recall_image_access(parameters, current_user)
try:
# Get only the parameters that were actually provided (non-None values)
provided_params = {k: v for k, v in parameters.model_dump().items() if v is not None}
@@ -335,14 +386,14 @@ async def update_recall_parameters(
if not provided_params:
return {"status": "no_parameters_provided", "updated_count": 0}
# Store each parameter in client state using a consistent key format
# Store each parameter in client state scoped to the current user
updated_count = 0
for param_key, param_value in provided_params.items():
# Convert parameter values to JSON strings for storage
value_str = json.dumps(param_value)
try:
ApiDependencies.invoker.services.client_state_persistence.set_by_key(
queue_id, f"recall_{param_key}", value_str
current_user.user_id, f"recall_{param_key}", value_str
)
updated_count += 1
except Exception as e:
@@ -396,7 +447,9 @@ async def update_recall_parameters(
logger.info(
f"Emitting recall_parameters_updated event for queue {queue_id} with {len(provided_params)} parameters"
)
ApiDependencies.invoker.services.events.emit_recall_parameters_updated(queue_id, provided_params)
ApiDependencies.invoker.services.events.emit_recall_parameters_updated(
queue_id, current_user.user_id, provided_params
)
logger.info("Successfully emitted recall_parameters_updated event")
except Exception as e:
logger.error(f"Error emitting recall parameters event: {e}", exc_info=True)
@@ -425,6 +478,7 @@ async def update_recall_parameters(
response_model=dict[str, Any],
)
async def get_recall_parameters(
current_user: CurrentUserOrDefault,
queue_id: str = Path(..., description="The queue id to retrieve parameters for"),
) -> dict[str, Any]:
"""

View File

@@ -44,7 +44,8 @@ def sanitize_queue_item_for_user(
"""Sanitize queue item for non-admin users viewing other users' items.
For non-admin users viewing queue items belonging to other users,
the field_values, session graph, and workflow should be hidden/cleared to protect privacy.
only timestamps, status, and error information are exposed. All other
fields (user identity, generation parameters, graphs, workflows) are stripped.
Args:
queue_item: The queue item to sanitize
@@ -58,15 +59,25 @@ def sanitize_queue_item_for_user(
if is_admin or queue_item.user_id == current_user_id:
return queue_item
# For non-admins viewing other users' items, clear sensitive fields
# Create a shallow copy to avoid mutating the original
# For non-admins viewing other users' items, strip everything except
# item_id, queue_id, status, and timestamps
sanitized_item = queue_item.model_copy(deep=False)
sanitized_item.user_id = "redacted"
sanitized_item.user_display_name = None
sanitized_item.user_email = None
sanitized_item.batch_id = "redacted"
sanitized_item.session_id = "redacted"
sanitized_item.origin = None
sanitized_item.destination = None
sanitized_item.priority = 0
sanitized_item.field_values = None
sanitized_item.retried_from_item_id = None
sanitized_item.workflow = None
# Clear the session graph by replacing it with an empty graph execution state
# This prevents information leakage through the generation graph
sanitized_item.error_type = None
sanitized_item.error_message = None
sanitized_item.error_traceback = None
sanitized_item.session = GraphExecutionState(
id=queue_item.session.id,
id="redacted",
graph=Graph(),
)
return sanitized_item
@@ -126,12 +137,16 @@ async def list_all_queue_items(
},
)
async def get_queue_item_ids(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
) -> ItemIdsResult:
"""Gets all queue item ids that match the given parameters"""
"""Gets all queue item ids that match the given parameters. Non-admin users only see their own items."""
try:
return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(queue_id=queue_id, order_dir=order_dir)
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.session_queue.get_queue_item_ids(
queue_id=queue_id, order_dir=order_dir, user_id=user_id
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while listing all queue item ids: {e}")
@@ -376,11 +391,15 @@ async def prune(
},
)
async def get_current_queue_item(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the currently execution queue item"""
try:
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if item is not None:
item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin)
return item
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting current queue item: {e}")
@@ -393,11 +412,15 @@ async def get_current_queue_item(
},
)
async def get_next_queue_item(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the next queue item, without executing it"""
try:
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
item = ApiDependencies.invoker.services.session_queue.get_next(queue_id)
if item is not None:
item = sanitize_queue_item_for_user(item, current_user.user_id, current_user.is_admin)
return item
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting next queue item: {e}")
@@ -413,9 +436,10 @@ async def get_queue_status(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndProcessorStatus:
"""Gets the status of the session queue"""
"""Gets the status of the session queue. Non-admin users see only their own counts and cannot see current item details unless they own it."""
try:
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=current_user.user_id)
user_id = None if current_user.is_admin else current_user.user_id
queue = ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id, user_id=user_id)
processor = ApiDependencies.invoker.services.session_processor.get_status()
return SessionQueueAndProcessorStatus(queue=queue, processor=processor)
except Exception as e:
@@ -430,12 +454,16 @@ async def get_queue_status(
},
)
async def get_batch_status(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to perform this operation on"),
batch_id: str = Path(description="The batch to get the status of"),
) -> BatchStatus:
"""Gets the status of the session queue"""
"""Gets the status of a batch. Non-admin users only see their own batches."""
try:
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.session_queue.get_batch_status(
queue_id=queue_id, batch_id=batch_id, user_id=user_id
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while getting batch status: {e}")
@@ -529,13 +557,15 @@ async def cancel_queue_item(
responses={200: {"model": SessionQueueCountsByDestination}},
)
async def counts_by_destination(
current_user: CurrentUserOrDefault,
queue_id: str = Path(description="The queue id to query"),
destination: str = Query(description="The destination to query"),
) -> SessionQueueCountsByDestination:
"""Gets the counts of queue items by destination"""
"""Gets the counts of queue items by destination. Non-admin users only see their own items."""
try:
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.session_queue.get_counts_by_destination(
queue_id=queue_id, destination=destination
queue_id=queue_id, destination=destination, user_id=user_id
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Unexpected error while fetching counts by destination: {e}")

View File

@@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, File, HTTPException, Path, Query, UploadFil
from fastapi.responses import FileResponse
from PIL import Image
from invokeai.app.api.auth_dependencies import CurrentUserOrDefault
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
@@ -33,16 +34,25 @@ workflows_router = APIRouter(prefix="/v1/workflows", tags=["workflows"])
},
)
async def get_workflow(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to get"),
) -> WorkflowRecordWithThumbnailDTO:
"""Gets a workflow"""
try:
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
workflow = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
config = ApiDependencies.invoker.services.configuration
if config.multiuser:
is_default = workflow.workflow.meta.category is WorkflowCategory.Default
is_owner = workflow.user_id == current_user.user_id
if not (is_default or is_owner or workflow.is_public or current_user.is_admin):
raise HTTPException(status_code=403, detail="Not authorized to access this workflow")
thumbnail_url = ApiDependencies.invoker.services.workflow_thumbnails.get_url(workflow_id)
return WorkflowRecordWithThumbnailDTO(thumbnail_url=thumbnail_url, **workflow.model_dump())
@workflows_router.patch(
"/i/{workflow_id}",
@@ -52,10 +62,21 @@ async def get_workflow(
},
)
async def update_workflow(
current_user: CurrentUserOrDefault,
workflow: Workflow = Body(description="The updated workflow", embed=True),
) -> WorkflowRecordDTO:
"""Updates a workflow"""
return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow)
config = ApiDependencies.invoker.services.configuration
if config.multiuser:
try:
existing = ApiDependencies.invoker.services.workflow_records.get(workflow.id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
if not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
# Pass user_id for defense-in-depth SQL scoping; admins pass None to allow any.
user_id = None if current_user.is_admin else current_user.user_id
return ApiDependencies.invoker.services.workflow_records.update(workflow=workflow, user_id=user_id)
@workflows_router.delete(
@@ -63,15 +84,25 @@ async def update_workflow(
operation_id="delete_workflow",
)
async def delete_workflow(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to delete"),
) -> None:
"""Deletes a workflow"""
config = ApiDependencies.invoker.services.configuration
if config.multiuser:
try:
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
if not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to delete this workflow")
try:
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
except WorkflowThumbnailFileNotFoundException:
# It's OK if the workflow has no thumbnail file. We can still delete the workflow.
pass
ApiDependencies.invoker.services.workflow_records.delete(workflow_id)
user_id = None if current_user.is_admin else current_user.user_id
ApiDependencies.invoker.services.workflow_records.delete(workflow_id, user_id=user_id)
@workflows_router.post(
@@ -82,10 +113,11 @@ async def delete_workflow(
},
)
async def create_workflow(
current_user: CurrentUserOrDefault,
workflow: WorkflowWithoutID = Body(description="The workflow to create", embed=True),
) -> WorkflowRecordDTO:
"""Creates a workflow"""
return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow)
return ApiDependencies.invoker.services.workflow_records.create(workflow=workflow, user_id=current_user.user_id)
@workflows_router.get(
@@ -96,6 +128,7 @@ async def create_workflow(
},
)
async def list_workflows(
current_user: CurrentUserOrDefault,
page: int = Query(default=0, description="The page to get"),
per_page: Optional[int] = Query(default=None, description="The number of workflows per page"),
order_by: WorkflowRecordOrderBy = Query(
@@ -106,8 +139,19 @@ async def list_workflows(
tags: Optional[list[str]] = Query(default=None, description="The tags of workflow to get"),
query: Optional[str] = Query(default=None, description="The text to query by (matches name and description)"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> PaginatedResults[WorkflowRecordListItemWithThumbnailDTO]:
"""Gets a page of workflows"""
config = ApiDependencies.invoker.services.configuration
# In multiuser mode, scope user-category workflows to the current user unless fetching shared workflows
user_id_filter: Optional[str] = None
if config.multiuser:
# Only filter 'user' category results by user_id when not explicitly listing public workflows
has_user_category = not categories or WorkflowCategory.User in categories
if has_user_category and is_public is not True:
user_id_filter = current_user.user_id
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
order_by=order_by,
@@ -118,6 +162,8 @@ async def list_workflows(
categories=categories,
tags=tags,
has_been_opened=has_been_opened,
user_id=user_id_filter,
is_public=is_public,
)
for workflow in workflows.items:
workflows_with_thumbnails.append(
@@ -143,15 +189,20 @@ async def list_workflows(
},
)
async def set_workflow_thumbnail(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to update"),
image: UploadFile = File(description="The image file to upload"),
):
"""Sets a workflow's thumbnail image"""
try:
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
config = ApiDependencies.invoker.services.configuration
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
if not image.content_type or not image.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image")
@@ -177,14 +228,19 @@ async def set_workflow_thumbnail(
},
)
async def delete_workflow_thumbnail(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to update"),
):
"""Removes a workflow's thumbnail image"""
try:
ApiDependencies.invoker.services.workflow_records.get(workflow_id)
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
config = ApiDependencies.invoker.services.configuration
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
try:
ApiDependencies.invoker.services.workflow_thumbnails.delete(workflow_id)
except ValueError as e:
@@ -206,8 +262,12 @@ async def delete_workflow_thumbnail(
async def get_workflow_thumbnail(
workflow_id: str = Path(description="The id of the workflow thumbnail to get"),
) -> FileResponse:
"""Gets a workflow's thumbnail image"""
"""Gets a workflow's thumbnail image.
This endpoint is intentionally unauthenticated because browsers load images
via <img src> tags which cannot send Bearer tokens. Workflow IDs are UUIDs,
providing security through unguessability.
"""
try:
path = ApiDependencies.invoker.services.workflow_thumbnails.get_path(workflow_id)
@@ -223,37 +283,91 @@ async def get_workflow_thumbnail(
raise HTTPException(status_code=404)
@workflows_router.patch(
"/i/{workflow_id}/is_public",
operation_id="update_workflow_is_public",
responses={
200: {"model": WorkflowRecordDTO},
},
)
async def update_workflow_is_public(
    current_user: CurrentUserOrDefault,
    workflow_id: str = Path(description="The workflow to update"),
    is_public: bool = Body(description="Whether the workflow should be shared publicly", embed=True),
) -> WorkflowRecordDTO:
    """Toggle public sharing for a workflow.

    In multiuser mode only the workflow's owner or an admin may change its
    visibility. Admins pass ``user_id=None`` so the record layer does not
    scope the update; everyone else is scoped to their own user id for
    defense in depth.
    """
    services = ApiDependencies.invoker.services
    try:
        record = services.workflow_records.get(workflow_id)
    except WorkflowNotFoundError:
        raise HTTPException(status_code=404, detail="Workflow not found")

    if services.configuration.multiuser:
        is_owner = record.user_id == current_user.user_id
        if not (current_user.is_admin or is_owner):
            raise HTTPException(status_code=403, detail="Not authorized to update this workflow")

    scoped_user_id = None if current_user.is_admin else current_user.user_id
    return services.workflow_records.update_is_public(
        workflow_id=workflow_id, is_public=is_public, user_id=scoped_user_id
    )
@workflows_router.get("/tags", operation_id="get_all_tags")
async def get_all_tags(
current_user: CurrentUserOrDefault,
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> list[str]:
"""Gets all unique tags from workflows"""
config = ApiDependencies.invoker.services.configuration
user_id_filter: Optional[str] = None
if config.multiuser:
has_user_category = not categories or WorkflowCategory.User in categories
if has_user_category and is_public is not True:
user_id_filter = current_user.user_id
return ApiDependencies.invoker.services.workflow_records.get_all_tags(categories=categories)
return ApiDependencies.invoker.services.workflow_records.get_all_tags(
categories=categories, user_id=user_id_filter, is_public=is_public
)
@workflows_router.get("/counts_by_tag", operation_id="get_counts_by_tag")
async def get_counts_by_tag(
current_user: CurrentUserOrDefault,
tags: list[str] = Query(description="The tags to get counts for"),
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> dict[str, int]:
"""Counts workflows by tag"""
config = ApiDependencies.invoker.services.configuration
user_id_filter: Optional[str] = None
if config.multiuser:
has_user_category = not categories or WorkflowCategory.User in categories
if has_user_category and is_public is not True:
user_id_filter = current_user.user_id
return ApiDependencies.invoker.services.workflow_records.counts_by_tag(
tags=tags, categories=categories, has_been_opened=has_been_opened
tags=tags, categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public
)
@workflows_router.get("/counts_by_category", operation_id="counts_by_category")
async def counts_by_category(
current_user: CurrentUserOrDefault,
categories: list[WorkflowCategory] = Query(description="The categories to include"),
has_been_opened: Optional[bool] = Query(default=None, description="Whether to include/exclude recent workflows"),
is_public: Optional[bool] = Query(default=None, description="Filter by public/shared status"),
) -> dict[str, int]:
"""Counts workflows by category"""
config = ApiDependencies.invoker.services.configuration
user_id_filter: Optional[str] = None
if config.multiuser:
has_user_category = WorkflowCategory.User in categories
if has_user_category and is_public is not True:
user_id_filter = current_user.user_id
return ApiDependencies.invoker.services.workflow_records.counts_by_category(
categories=categories, has_been_opened=has_been_opened
categories=categories, has_been_opened=has_been_opened, user_id=user_id_filter, is_public=is_public
)
@@ -262,7 +376,18 @@ async def counts_by_category(
operation_id="update_opened_at",
)
async def update_opened_at(
current_user: CurrentUserOrDefault,
workflow_id: str = Path(description="The workflow to update"),
) -> None:
"""Updates the opened_at field of a workflow"""
ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id)
try:
existing = ApiDependencies.invoker.services.workflow_records.get(workflow_id)
except WorkflowNotFoundError:
raise HTTPException(status_code=404, detail="Workflow not found")
config = ApiDependencies.invoker.services.configuration
if config.multiuser and not current_user.is_admin and existing.user_id != current_user.user_id:
raise HTTPException(status_code=403, detail="Not authorized to update this workflow")
user_id = None if current_user.is_admin else current_user.user_id
ApiDependencies.invoker.services.workflow_records.update_opened_at(workflow_id, user_id=user_id)

View File

@@ -121,6 +121,11 @@ class SocketIO:
Returns True to accept the connection, False to reject it.
Stores user_id in the internal socket users dict for later use.
In multiuser mode, connections without a valid token are rejected outright
so that anonymous clients cannot subscribe to queue rooms and observe
queue activity belonging to other users. In single-user mode, unauthenticated
connections are accepted as the system admin user.
"""
# Extract token from auth data or headers
token = None
@@ -137,6 +142,23 @@ class SocketIO:
if token:
token_data = verify_token(token)
if token_data:
# In multiuser mode, also verify the backing user record still
# exists and is active — mirrors the REST auth check in
# auth_dependencies.py. A deleted or deactivated user whose
# JWT has not yet expired must not be allowed to open a socket.
if self._is_multiuser_enabled():
try:
from invokeai.app.api.dependencies import ApiDependencies
user = ApiDependencies.invoker.services.users.get(token_data.user_id)
if user is None or not user.is_active:
logger.warning(f"Rejecting socket {sid}: user {token_data.user_id} not found or inactive")
return False
except Exception:
# If user service is unavailable, fail closed
logger.warning(f"Rejecting socket {sid}: unable to verify user record")
return False
# Store user_id and is_admin in socket users dict
self._socket_users[sid] = {
"user_id": token_data.user_id,
@@ -147,14 +169,37 @@ class SocketIO:
)
return True
# If no valid token, store system user for backward compatibility
# No valid token provided. In multiuser mode this is not allowed — reject
# the connection so anonymous clients cannot subscribe to queue rooms.
# In single-user mode, fall through and accept the socket as system admin.
if self._is_multiuser_enabled():
logger.warning(
f"Rejecting socket {sid} connection: multiuser mode is enabled and no valid auth token was provided"
)
return False
self._socket_users[sid] = {
"user_id": "system",
"is_admin": False,
"is_admin": True,
}
logger.debug(f"Socket {sid} connected as system user (no valid token)")
logger.debug(f"Socket {sid} connected as system admin (single-user mode)")
return True
@staticmethod
def _is_multiuser_enabled() -> bool:
"""Check whether multiuser mode is enabled. Fails closed if configuration
is not yet initialized, which should not happen in practice but prevents
accidentally opening the socket during startup races."""
try:
# Imported here to avoid a circular import at module load time.
from invokeai.app.api.dependencies import ApiDependencies
return bool(ApiDependencies.invoker.services.configuration.multiuser)
except Exception:
# If dependencies are not initialized, fail closed (treat as multiuser)
# so we never accidentally admit an anonymous socket.
return True
async def _handle_disconnect(self, sid: str) -> None:
"""Handle socket disconnection and cleanup user info."""
if sid in self._socket_users:
@@ -165,15 +210,20 @@ class SocketIO:
"""Handle queue subscription and add socket to both queue and user-specific rooms."""
queue_id = QueueSubscriptionEvent(**data).queue_id
# Check if we have user info for this socket
# Check if we have user info for this socket. In multiuser mode _handle_connect
# will have already rejected any socket without a valid token, so missing user
# info here is a bug — refuse the subscription rather than silently falling back
# to an anonymous system user who could then receive queue item events.
if sid not in self._socket_users:
logger.warning(
f"Socket {sid} subscribing to queue {queue_id} but has no user info - need to authenticate via connect event"
)
# Store as system user temporarily - real auth should happen in connect
if self._is_multiuser_enabled():
logger.warning(
f"Refusing queue subscription for socket {sid}: no user info (socket not authenticated via connect event)"
)
return
# Single-user mode: safe to fall back to the system admin user.
self._socket_users[sid] = {
"user_id": "system",
"is_admin": False,
"is_admin": True,
}
user_id = self._socket_users[sid]["user_id"]
@@ -198,6 +248,13 @@ class SocketIO:
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
# In multiuser mode, only allow authenticated sockets to subscribe.
# Bulk download events are routed to user-specific rooms, so the
# bulk_download_id room subscription is only kept for single-user
# backward compatibility.
if self._is_multiuser_enabled() and sid not in self._socket_users:
logger.warning(f"Refusing bulk download subscription for unknown socket {sid} in multiuser mode")
return
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
@@ -206,9 +263,17 @@ class SocketIO:
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
"""Handle queue events with user isolation.
Invocation events (progress, started, complete) are private - only emit to owner and admins.
Queue item status events are public - emit to all users (field values hidden via API).
Other queue events emit to all subscribers.
All queue item events (invocation events AND QueueItemStatusChangedEvent) are
private to the owning user and admins. They carry unsanitized user_id, batch_id,
session_id, origin, destination and error metadata, and must never be broadcast
to the whole queue room — otherwise any other authenticated subscriber could
observe cross-user queue activity.
RecallParametersUpdatedEvent is also private to the owner + admins.
BatchEnqueuedEvent carries the enqueuing user's batch_id/origin/counts and
is also routed privately. QueueClearedEvent is the only queue event that
is still broadcast to the whole queue room.
IMPORTANT: Check InvocationEventBase BEFORE QueueItemEventBase since InvocationEventBase
inherits from QueueItemEventBase. The order of isinstance checks matters!
@@ -237,24 +302,40 @@ class SocketIO:
logger.debug(f"Emitted private invocation event {event_name} to user room {user_room} and admin room")
# Queue item status events are visible to all users (field values masked via API)
# This catches QueueItemStatusChangedEvent but NOT InvocationEvents (already handled above)
# Other queue item events (QueueItemStatusChangedEvent) carry unsanitized
# user_id, batch_id, session_id, origin, destination and error metadata.
# They are private to the owning user + admins — never broadcast to the
# full queue room.
elif isinstance(event_data, QueueItemEventBase) and hasattr(event_data, "user_id"):
# Emit to all subscribers in the queue
await self._sio.emit(
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id
)
user_room = f"user:{event_data.user_id}"
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
logger.info(
f"Emitted public queue item event {event_name} to all subscribers in queue {event_data.queue_id}"
)
logger.debug(f"Emitted private queue item event {event_name} to user room {user_room} and admin room")
# RecallParametersUpdatedEvent is private - only emit to owner + admins
elif isinstance(event_data, RecallParametersUpdatedEvent):
user_room = f"user:{event_data.user_id}"
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
logger.debug(f"Emitted private recall_parameters_updated event to user room {user_room} and admin room")
# BatchEnqueuedEvent carries the enqueuing user's batch_id, origin, and
# enqueued counts. Route it privately to the owner + admins so other
# users do not observe cross-user batch activity.
elif isinstance(event_data, BatchEnqueuedEvent):
user_room = f"user:{event_data.user_id}"
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room=user_room)
await self._sio.emit(event=event_name, data=event_data.model_dump(mode="json"), room="admin")
logger.debug(f"Emitted private batch_enqueued event to user room {user_room} and admin room")
else:
# For other queue events (like QueueClearedEvent, BatchEnqueuedEvent), emit to all subscribers
# For remaining queue events (e.g. QueueClearedEvent) that do not
# carry user identity, emit to all subscribers in the queue room.
await self._sio.emit(
event=event_name, data=event_data.model_dump(mode="json"), room=event_data.queue_id
)
logger.info(
logger.debug(
f"Emitted general queue event {event_name} to all subscribers in queue {event_data.queue_id}"
)
except Exception as e:
@@ -265,4 +346,17 @@ class SocketIO:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
    """Route a bulk-image-download event to the socket rooms allowed to see it.

    Args:
        event: Tuple of (event name, event payload) from the FastAPI event bus.

    The payload's ``bulk_download_item_name`` acts as a capability token for the
    unauthenticated zip GET endpoint, so the event must never be broadcast to a
    room that other authenticated users can join.
    """
    event_name, event_data = event
    # Route to user-specific + admin rooms so that other authenticated
    # users cannot learn the bulk_download_item_name (the capability token
    # needed to fetch the zip from the unauthenticated GET endpoint).
    # In single-user mode (user_id="system"), fall back to the shared
    # bulk_download_id room for backward compatibility.
    # BUGFIX: removed the old unconditional emit to the bulk_download_id room
    # that ran before this check — it broadcast the token to everyone and
    # defeated the private routing below.
    if hasattr(event_data, "user_id") and event_data.user_id != "system":
        user_room = f"user:{event_data.user_id}"
        payload = event_data.model_dump(mode="json")
        await self._sio.emit(event=event_name, data=payload, room=user_room)
        await self._sio.emit(event=event_name, data=payload, room="admin")
    else:
        await self._sio.emit(
            event=event_name, data=event_data.model_dump(mode="json"), room=event_data.bulk_download_id
        )

View File

@@ -171,6 +171,8 @@ class FieldDescriptions:
sd3_model = "SD3 model (MMDiTX) to load"
cogview4_model = "CogView4 model (Transformer) to load"
z_image_model = "Z-Image model (Transformer) to load"
qwen_image_model = "Qwen Image Edit model (Transformer) to load"
qwen_vl_encoder = "Qwen2.5-VL tokenizer, processor and text/vision encoder"
sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load"
sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load"
onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load"
@@ -340,6 +342,12 @@ class ZImageConditioningField(BaseModel):
)
class QwenImageConditioningField(BaseModel):
    """A Qwen Image Edit conditioning tensor primitive value"""

    # Name used to look the tensor up in conditioning storage
    # (mirrors the other *ConditioningField primitives in this module).
    conditioning_name: str = Field(description="The name of conditioning tensor")
class AnimaConditioningField(BaseModel):
"""An Anima conditioning tensor primitive value.

View File

@@ -166,6 +166,10 @@ GENERATION_MODES = Literal[
"z_image_img2img",
"z_image_inpaint",
"z_image_outpaint",
"qwen_image_txt2img",
"qwen_image_img2img",
"qwen_image_inpaint",
"qwen_image_outpaint",
"anima_txt2img",
"anima_img2img",
"anima_inpaint",

View File

@@ -72,6 +72,13 @@ class GlmEncoderField(BaseModel):
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
class QwenVLEncoderField(BaseModel):
    """Field for Qwen2.5-VL encoder used by Qwen Image Edit models."""

    # Identifier for loading the tokenizer submodel.
    tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
    # Identifier for loading the text/vision encoder submodel.
    text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")
class Qwen3EncoderField(BaseModel):
"""Field for Qwen3 text encoder used by Z-Image models."""

View File

@@ -25,6 +25,7 @@ from invokeai.app.invocations.fields import (
InputField,
LatentsField,
OutputField,
QwenImageConditioningField,
SD3ConditioningField,
TensorField,
UIComponent,
@@ -474,6 +475,17 @@ class ZImageConditioningOutput(BaseInvocationOutput):
return cls(conditioning=ZImageConditioningField(conditioning_name=conditioning_name))
@invocation_output("qwen_image_conditioning_output")
class QwenImageConditioningOutput(BaseInvocationOutput):
    """Base class for nodes that output a Qwen Image Edit conditioning tensor."""

    conditioning: QwenImageConditioningField = OutputField(description=FieldDescriptions.cond)

    @classmethod
    def build(cls, conditioning_name: str) -> "QwenImageConditioningOutput":
        """Wrap a stored conditioning-tensor name in this output type."""
        return cls(conditioning=QwenImageConditioningField(conditioning_name=conditioning_name))
@invocation_output("anima_conditioning_output")
class AnimaConditioningOutput(BaseInvocationOutput):
"""Base class for nodes that output an Anima text conditioning tensor."""

View File

@@ -0,0 +1,490 @@
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple
import torch
import torchvision.transforms as tv_transforms
from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformer2DModel
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.fields import (
DenoiseMaskField,
FieldDescriptions,
Input,
InputField,
LatentsField,
QwenImageConditioningField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.qwen_image_lora_constants import (
QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX,
)
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.rectified_flow.rectified_flow_inpaint_extension import RectifiedFlowInpaintExtension
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import QwenImageConditioningInfo
from invokeai.backend.util.devices import TorchDevice
@invocation(
    "qwen_image_denoise",
    title="Denoise - Qwen Image",
    tags=["image", "qwen_image"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class QwenImageDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Run the denoising process with a Qwen Image model.

    Implements a rectified-flow (Euler) sampling loop over a
    QwenImageTransformer2DModel, with optional classifier-free guidance,
    img2img (via ``latents`` + ``denoising_start``), inpainting (via
    ``denoise_mask``), and reference-image conditioning for edit models.
    """

    # If latents is provided, this means we are doing image-to-image.
    latents: Optional[LatentsField] = InputField(
        default=None, description=FieldDescriptions.latents, input=Input.Connection
    )
    # Reference image latents (encoded through VAE) to concatenate with noisy latents.
    reference_latents: Optional[LatentsField] = InputField(
        default=None,
        description="Reference image latents to guide generation. Encoded through the VAE.",
        input=Input.Connection,
    )
    # denoise_mask is used for image-to-image inpainting. Only the masked region is modified.
    denoise_mask: Optional[DenoiseMaskField] = InputField(
        default=None, description=FieldDescriptions.denoise_mask, input=Input.Connection
    )
    denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start)
    denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
    transformer: TransformerField = InputField(
        description=FieldDescriptions.qwen_image_model, input=Input.Connection, title="Transformer"
    )
    positive_conditioning: QwenImageConditioningField = InputField(
        description=FieldDescriptions.positive_cond, input=Input.Connection
    )
    negative_conditioning: Optional[QwenImageConditioningField] = InputField(
        default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
    )
    cfg_scale: float | list[float] = InputField(default=4.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
    width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
    height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
    steps: int = InputField(default=40, gt=0, description=FieldDescriptions.steps)
    seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
    shift: Optional[float] = InputField(
        default=None,
        description="Override the sigma schedule shift. "
        "When set, uses a fixed shift (e.g. 3.0 for Lightning LoRAs) instead of the default dynamic shifting. "
        "Leave unset for the base model's default schedule.",
    )

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Run the sampling loop and persist the resulting latents tensor."""
        latents = self._run_diffusion(context)
        latents = latents.detach().to("cpu")
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

    def _prep_inpaint_mask(self, context: InvocationContext, latents: torch.Tensor) -> torch.Tensor | None:
        """Load and prepare the inpaint mask, or return None when not inpainting.

        The stored mask is inverted (1.0 - mask) before being resized to the
        latent resolution — presumably so that 1.0 marks the region to denoise,
        matching RectifiedFlowInpaintExtension's convention (TODO confirm).
        """
        if self.denoise_mask is None:
            return None
        mask = context.tensors.load(self.denoise_mask.mask_name)
        mask = 1.0 - mask
        _, _, latent_height, latent_width = latents.shape
        mask = tv_resize(
            img=mask,
            size=[latent_height, latent_width],
            interpolation=tv_transforms.InterpolationMode.BILINEAR,
            antialias=False,
        )
        mask = mask.to(device=latents.device, dtype=latents.dtype)
        return mask

    def _load_text_conditioning(
        self,
        context: InvocationContext,
        conditioning_name: str,
        dtype: torch.dtype,
        device: torch.device,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Load a stored QwenImageConditioningInfo and return (embeds, mask) on the target device/dtype."""
        cond_data = context.conditioning.load(conditioning_name)
        # Qwen Image conditioning is stored as exactly one conditioning entry.
        assert len(cond_data.conditionings) == 1
        conditioning = cond_data.conditionings[0]
        assert isinstance(conditioning, QwenImageConditioningInfo)
        conditioning = conditioning.to(dtype=dtype, device=device)
        return conditioning.prompt_embeds, conditioning.prompt_embeds_mask

    def _get_noise(
        self,
        batch_size: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.Tensor:
        """Generate seeded Gaussian noise at the latent resolution.

        Noise is drawn on CPU in float32 so results are reproducible for a
        given seed regardless of the inference device, then moved/cast.
        """
        rand_device = "cpu"
        rand_dtype = torch.float32
        return torch.randn(
            batch_size,
            num_channels_latents,
            int(height) // LATENT_SCALE_FACTOR,
            int(width) // LATENT_SCALE_FACTOR,
            device=rand_device,
            dtype=rand_dtype,
            generator=torch.Generator(device=rand_device).manual_seed(seed=self.seed) if False else torch.Generator(device=rand_device).manual_seed(seed),
        ).to(device=device, dtype=dtype)

    def _prepare_cfg_scale(self, num_timesteps: int) -> list[float]:
        """Expand cfg_scale to a per-step list of length num_timesteps.

        Raises:
            ValueError: if cfg_scale is neither a float nor a list.
        """
        if isinstance(self.cfg_scale, float):
            cfg_scale = [self.cfg_scale] * num_timesteps
        elif isinstance(self.cfg_scale, list):
            # A list must already match the (possibly clipped) schedule length.
            assert len(self.cfg_scale) == num_timesteps
            cfg_scale = self.cfg_scale
        else:
            raise ValueError(f"Invalid CFG scale type: {type(self.cfg_scale)}")
        return cfg_scale

    @staticmethod
    def _pack_latents(
        latents: torch.Tensor, batch_size: int, num_channels: int, height: int, width: int
    ) -> torch.Tensor:
        """Pack 4D latents (B, C, H, W) into 2x2-patched 3D (B, H/2*W/2, C*4)."""
        latents = latents.view(batch_size, num_channels, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 2, 4, 1, 3, 5)
        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels * 4)
        return latents

    @staticmethod
    def _unpack_latents(latents: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Unpack 3D patched latents (B, seq, C*4) back to 4D (B, C, H, W)."""
        batch_size, _num_patches, channels = latents.shape
        # height/width are in latent space; they must be divisible by 2 for packing
        h = 2 * (height // 2)
        w = 2 * (width // 2)
        latents = latents.view(batch_size, h // 2, w // 2, channels // 4, 2, 2)
        latents = latents.permute(0, 3, 1, 4, 2, 5)
        latents = latents.reshape(batch_size, channels // 4, h, w)
        return latents

    def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
        """Core sampling loop; returns decoded-ready latents of shape (B, C, 1, H, W)."""
        inference_dtype = torch.bfloat16
        device = TorchDevice.choose_torch_device()
        transformer_info = context.models.load(self.transformer.transformer)
        assert isinstance(transformer_info.model, QwenImageTransformer2DModel)

        # Load conditioning
        pos_prompt_embeds, pos_prompt_mask = self._load_text_conditioning(
            context=context,
            conditioning_name=self.positive_conditioning.conditioning_name,
            dtype=inference_dtype,
            device=device,
        )
        neg_prompt_embeds = None
        neg_prompt_mask = None
        # Match the diffusers pipeline: only enable CFG when cfg_scale > 1 AND negative conditioning is provided.
        # With cfg_scale <= 1, the negative prediction is unused, so skip it entirely.
        # For per-step arrays, enable CFG if any step has scale > 1.
        if isinstance(self.cfg_scale, list):
            any_cfg_above_one = any(v > 1.0 for v in self.cfg_scale)
        else:
            any_cfg_above_one = self.cfg_scale > 1.0
        do_classifier_free_guidance = self.negative_conditioning is not None and any_cfg_above_one
        if do_classifier_free_guidance:
            neg_prompt_embeds, neg_prompt_mask = self._load_text_conditioning(
                context=context,
                conditioning_name=self.negative_conditioning.conditioning_name,
                dtype=inference_dtype,
                device=device,
            )

        # Prepare the timestep / sigma schedule
        patch_size = transformer_info.model.config.patch_size
        assert isinstance(patch_size, int)
        # Output channels is 16 (the actual latent channels)
        out_channels = transformer_info.model.config.out_channels
        assert isinstance(out_channels, int)
        latent_height = self.height // LATENT_SCALE_FACTOR
        latent_width = self.width // LATENT_SCALE_FACTOR
        image_seq_len = (latent_height * latent_width) // (patch_size**2)

        # Use the actual FlowMatchEulerDiscreteScheduler to compute sigmas/timesteps,
        # exactly matching the diffusers pipeline.
        import math

        import numpy as np
        from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler

        # Try to load the scheduler config from the model's directory (Diffusers models
        # have a scheduler/ subdir). For GGUF models this path doesn't exist, so fall
        # back to instantiating the scheduler with the known Qwen Image defaults.
        model_path = context.models.get_absolute_path(context.models.get_config(self.transformer.transformer))
        scheduler_path = model_path / "scheduler"
        if scheduler_path.is_dir() and (scheduler_path / "scheduler_config.json").exists():
            scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(str(scheduler_path), local_files_only=True)
        else:
            scheduler = FlowMatchEulerDiscreteScheduler(
                use_dynamic_shifting=True,
                base_shift=0.5,
                max_shift=0.9,
                base_image_seq_len=256,
                max_image_seq_len=8192,
                shift_terminal=0.02,
                num_train_timesteps=1000,
                time_shift_type="exponential",
            )
        if self.shift is not None:
            # Lightning LoRA: fixed shift
            mu = math.log(self.shift)
        else:
            # Default dynamic shifting
            # Linear interpolation matching diffusers' calculate_shift
            base_shift = scheduler.config.get("base_shift", 0.5)
            max_shift = scheduler.config.get("max_shift", 0.9)
            base_seq = scheduler.config.get("base_image_seq_len", 256)
            # NOTE(review): this fallback default (4096) differs from the 8192 used when
            # constructing the scheduler above. Harmless when the config has the key, but
            # confirm which value is intended for the fallback path.
            max_seq = scheduler.config.get("max_image_seq_len", 4096)
            m = (max_shift - base_shift) / (max_seq - base_seq)
            b = base_shift - m * base_seq
            mu = image_seq_len * m + b
        init_sigmas = np.linspace(1.0, 1.0 / self.steps, self.steps).tolist()
        scheduler.set_timesteps(sigmas=init_sigmas, mu=mu, device=device)

        # Clip the schedule based on denoising_start/denoising_end to support img2img strength.
        # The scheduler's sigmas go from high (noisy) to 0 (clean). We clip to the fractional range.
        sigmas_sched = scheduler.sigmas  # (N+1,) including terminal 0
        if self.denoising_start > 0 or self.denoising_end < 1:
            total_sigmas = len(sigmas_sched) - 1  # exclude terminal
            start_idx = int(round(self.denoising_start * total_sigmas))
            end_idx = int(round(self.denoising_end * total_sigmas))
            sigmas_sched = sigmas_sched[start_idx : end_idx + 1]  # +1 to include the next sigma for dt
            # Rebuild timesteps from clipped sigmas (exclude terminal 0)
            timesteps_sched = sigmas_sched[:-1] * scheduler.config.num_train_timesteps
        else:
            timesteps_sched = scheduler.timesteps
        total_steps = len(timesteps_sched)
        cfg_scale = self._prepare_cfg_scale(total_steps)

        # Load initial latents if provided (for img2img)
        init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
        if init_latents is not None:
            init_latents = init_latents.to(device=device, dtype=inference_dtype)
            # VAE-encoded latents may be 5D (B, C, 1, H, W); drop the frame dim.
            if init_latents.dim() == 5:
                init_latents = init_latents.squeeze(2)

        # Load reference image latents if provided
        ref_latents = None
        if self.reference_latents is not None:
            ref_latents = context.tensors.load(self.reference_latents.latents_name)
            ref_latents = ref_latents.to(device=device, dtype=inference_dtype)
            # The VAE encoder produces 5D latents (B, C, 1, H, W); squeeze the frame dim
            # so we have 4D (B, C, H, W) for packing.
            if ref_latents.dim() == 5:
                ref_latents = ref_latents.squeeze(2)

        # Generate noise (16 channels - the output latent channels)
        noise = self._get_noise(
            batch_size=1,
            num_channels_latents=out_channels,
            height=self.height,
            width=self.width,
            dtype=inference_dtype,
            device=device,
            seed=self.seed,
        )

        # Prepare input latent image
        if init_latents is not None:
            # img2img: blend noise with the init latents at the starting sigma.
            s_0 = sigmas_sched[0].item()
            latents = s_0 * noise + (1.0 - s_0) * init_latents
        else:
            if self.denoising_start > 1e-5:
                raise ValueError("denoising_start should be 0 when initial latents are not provided.")
            latents = noise
        if total_steps <= 0:
            # NOTE(review): this early return yields 4D latents, while the normal exit
            # below returns 5D (B, C, 1, H, W) — confirm downstream nodes tolerate both.
            return latents

        # Pack latents into 2x2 patches: (B, C, H, W) -> (B, H/2*W/2, C*4)
        latents = self._pack_latents(latents, 1, out_channels, latent_height, latent_width)

        # Determine whether the model uses reference latent conditioning (zero_cond_t).
        # Edit models (zero_cond_t=True) expect [noisy_patches ; ref_patches] in the sequence.
        # Txt2img models (zero_cond_t=False) only take noisy patches.
        has_zero_cond_t = getattr(transformer_info.model, "zero_cond_t", False) or getattr(
            transformer_info.model.config, "zero_cond_t", False
        )
        use_ref_latents = has_zero_cond_t
        ref_latents_packed = None
        if use_ref_latents:
            if ref_latents is not None:
                _, ref_ch, rh, rw = ref_latents.shape
                # Resize reference latents to the generation resolution if they differ.
                if rh != latent_height or rw != latent_width:
                    ref_latents = torch.nn.functional.interpolate(
                        ref_latents, size=(latent_height, latent_width), mode="bilinear"
                    )
            else:
                # No reference image provided — use zeros so the model still gets the
                # expected sequence layout.
                ref_latents = torch.zeros(
                    1, out_channels, latent_height, latent_width, device=device, dtype=inference_dtype
                )
            ref_latents_packed = self._pack_latents(ref_latents, 1, out_channels, latent_height, latent_width)

        # img_shapes tells the transformer the spatial layout of patches.
        if use_ref_latents:
            img_shapes = [
                [
                    (1, latent_height // 2, latent_width // 2),
                    (1, latent_height // 2, latent_width // 2),
                ]
            ]
        else:
            img_shapes = [
                [
                    (1, latent_height // 2, latent_width // 2),
                ]
            ]

        # Prepare inpaint extension (operates in 4D space, so unpack/repack around it)
        inpaint_mask = self._prep_inpaint_mask(context, noise)  # noise has the right 4D shape
        inpaint_extension: RectifiedFlowInpaintExtension | None = None
        if inpaint_mask is not None:
            assert init_latents is not None
            inpaint_extension = RectifiedFlowInpaintExtension(
                init_latents=init_latents,
                inpaint_mask=inpaint_mask,
                noise=noise,
            )

        step_callback = self._build_step_callback(context)
        # Report the initial (step 0) state before any denoising.
        step_callback(
            PipelineIntermediateState(
                step=0,
                order=1,
                total_steps=total_steps,
                timestep=int(timesteps_sched[0].item()) if len(timesteps_sched) > 0 else 0,
                latents=self._unpack_latents(latents, latent_height, latent_width),
            ),
        )

        # Sequence length of the noisy portion; used to slice off ref patches from outputs.
        noisy_seq_len = latents.shape[1]
        # Determine if the model is quantized — GGUF models need sidecar patching for LoRAs
        transformer_config = context.models.get_config(self.transformer.transformer)
        model_is_quantized = transformer_config.format in (ModelFormat.GGUFQuantized,)
        with ExitStack() as exit_stack:
            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
            assert isinstance(transformer, QwenImageTransformer2DModel)
            # Apply LoRA patches to the transformer
            exit_stack.enter_context(
                LayerPatcher.apply_smart_model_patches(
                    model=transformer,
                    patches=self._lora_iterator(context),
                    prefix=QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX,
                    dtype=inference_dtype,
                    cached_weights=cached_weights,
                    force_sidecar_patching=model_is_quantized,
                )
            )
            for step_idx, t in enumerate(tqdm(timesteps_sched)):
                # The pipeline passes timestep / 1000 to the transformer
                timestep = t.expand(latents.shape[0]).to(inference_dtype)
                # For edit models: concatenate noisy and reference patches along the sequence dim
                # For txt2img models: just use noisy patches
                if ref_latents_packed is not None:
                    model_input = torch.cat([latents, ref_latents_packed], dim=1)
                else:
                    model_input = latents
                noise_pred_cond = transformer(
                    hidden_states=model_input,
                    encoder_hidden_states=pos_prompt_embeds,
                    encoder_hidden_states_mask=pos_prompt_mask,
                    timestep=timestep / 1000,
                    img_shapes=img_shapes,
                    return_dict=False,
                )[0]
                # Only keep the noisy-latent portion of the output
                noise_pred_cond = noise_pred_cond[:, :noisy_seq_len]
                if do_classifier_free_guidance and neg_prompt_embeds is not None:
                    noise_pred_uncond = transformer(
                        hidden_states=model_input,
                        encoder_hidden_states=neg_prompt_embeds,
                        encoder_hidden_states_mask=neg_prompt_mask,
                        timestep=timestep / 1000,
                        img_shapes=img_shapes,
                        return_dict=False,
                    )[0]
                    noise_pred_uncond = noise_pred_uncond[:, :noisy_seq_len]
                    noise_pred = noise_pred_uncond + cfg_scale[step_idx] * (noise_pred_cond - noise_pred_uncond)
                else:
                    noise_pred = noise_pred_cond
                # Euler step using the (possibly clipped) sigma schedule
                sigma_curr = sigmas_sched[step_idx]
                sigma_next = sigmas_sched[step_idx + 1]
                dt = sigma_next - sigma_curr
                # Integrate in float32 for accuracy, then cast back to bfloat16.
                latents = latents.to(torch.float32) + dt * noise_pred.to(torch.float32)
                latents = latents.to(inference_dtype)
                if inpaint_extension is not None:
                    # The inpaint extension works on 4D latents, so unpack, merge, repack.
                    sigma_next = sigmas_sched[step_idx + 1].item()
                    latents_4d = self._unpack_latents(latents, latent_height, latent_width)
                    latents_4d = inpaint_extension.merge_intermediate_latents_with_init_latents(latents_4d, sigma_next)
                    latents = self._pack_latents(latents_4d, 1, out_channels, latent_height, latent_width)
                step_callback(
                    PipelineIntermediateState(
                        step=step_idx + 1,
                        order=1,
                        total_steps=total_steps,
                        timestep=int(t.item()),
                        latents=self._unpack_latents(latents, latent_height, latent_width),
                    ),
                )

        # Unpack back to 4D then add frame dim for the video-style VAE: (B, C, 1, H, W)
        latents = self._unpack_latents(latents, latent_height, latent_width)
        latents = latents.unsqueeze(2)
        return latents

    def _build_step_callback(self, context: InvocationContext) -> Callable[[PipelineIntermediateState], None]:
        """Return a callback that reports intermediate sampling state to the UI."""

        def step_callback(state: PipelineIntermediateState) -> None:
            context.util.sd_step_callback(state, BaseModelType.QwenImage)

        return step_callback

    def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[ModelPatchRaw, float]]:
        """Iterate over LoRA models to apply to the transformer."""
        for lora in self.transformer.loras:
            lora_info = context.models.load(lora.lora)
            if not isinstance(lora_info.model, ModelPatchRaw):
                raise TypeError(
                    f"Expected ModelPatchRaw for LoRA '{lora.lora.key}', got {type(lora_info.model).__name__}."
                )
            yield (lora_info.model, lora.weight)
            # Drop our reference promptly so the model cache can evict the LoRA.
            del lora_info

View File

@@ -0,0 +1,96 @@
import einops
import torch
from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from PIL import Image as PILImage
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
@invocation(
    "qwen_image_i2l",
    title="Image to Latents - Qwen Image",
    tags=["image", "latents", "vae", "i2l", "qwen_image"],
    category="image",
    version="1.0.0",
    classification=Classification.Prototype,
)
class QwenImageImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates latents from an image using the Qwen Image VAE."""

    image: ImageField = InputField(description="The image to encode.")
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
    width: int | None = InputField(
        default=None,
        description="Resize the image to this width before encoding. If not set, encodes at the image's original size.",
    )
    height: int | None = InputField(
        default=None,
        description="Resize the image to this height before encoding. If not set, encodes at the image's original size.",
    )

    @staticmethod
    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
        """Encode an image tensor to normalized Qwen Image latents.

        Args:
            vae_info: Loaded AutoencoderKLQwenImage model.
            image_tensor: Image batch, 4D (B, C, H, W) or 5D (B, C, F, H, W).

        Returns:
            5D latents (B, C, 1, H', W'), normalized with the VAE's
            per-channel latents_mean / latents_std.
        """
        with vae_info.model_on_device() as (_, vae):
            assert isinstance(vae, AutoencoderKLQwenImage)
            vae.disable_tiling()
            image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
            with torch.inference_mode():
                # The Qwen Image VAE expects 5D input: (B, C, num_frames, H, W)
                if image_tensor.dim() == 4:
                    image_tensor = image_tensor.unsqueeze(2)
                posterior = vae.encode(image_tensor).latent_dist
                # Use mode (argmax) for deterministic encoding, matching diffusers
                latents: torch.Tensor = posterior.mode().to(dtype=vae.dtype)
                # Normalize with per-channel latents_mean / latents_std
                latents_mean = (
                    torch.tensor(vae.config.latents_mean)
                    .view(1, vae.config.z_dim, 1, 1, 1)
                    .to(latents.device, latents.dtype)
                )
                latents_std = (
                    torch.tensor(vae.config.latents_std)
                    .view(1, vae.config.z_dim, 1, 1, 1)
                    .to(latents.device, latents.dtype)
                )
                latents = (latents - latents_mean) / latents_std
        return latents

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Load the image, optionally resize, encode with the VAE, and save the latents."""
        # Convert to RGB exactly once (the previous version converted twice on
        # the resize path, allocating a redundant intermediate image).
        image = context.images.get_pil(self.image.image_name).convert("RGB")
        # If target dimensions are specified, resize the image BEFORE encoding
        # (matching the diffusers pipeline which resizes in pixel space, not latent space).
        # NOTE: setting only one of width/height is silently ignored.
        if self.width is not None and self.height is not None:
            image = image.resize((self.width, self.height), resample=PILImage.LANCZOS)
        image_tensor = image_resized_to_grid_as_tensor(image)
        if image_tensor.dim() == 3:
            image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
        vae_info = context.models.load(self.vae.vae)
        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
        latents = latents.to("cpu")
        name = context.tensors.save(tensor=latents)
        return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

View File

@@ -0,0 +1,85 @@
from contextlib import nullcontext
import torch
from diffusers.models.autoencoders.autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from einops import rearrange
from PIL import Image
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
from invokeai.backend.util.devices import TorchDevice
@invocation(
    "qwen_image_l2i",
    title="Latents to Image - Qwen Image",
    tags=["latents", "image", "vae", "l2i", "qwen_image"],
    category="latents",
    version="1.0.0",
    classification=Classification.Prototype,
)
class QwenImageLatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Generates an image from latents using the Qwen Image VAE."""

    latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
    vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Denormalize the stored latents, decode them with the VAE, and save a PIL image."""
        latents = context.tensors.load(self.latents.latents_name)
        vae_info = context.models.load(self.vae.vae)
        assert isinstance(vae_info.model, AutoencoderKLQwenImage)
        with (
            SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes),
            vae_info.model_on_device() as (_, vae),
        ):
            context.util.signal_progress("Running VAE")
            assert isinstance(vae, AutoencoderKLQwenImage)
            latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
            vae.disable_tiling()
            # Tiled decoding is not used; nullcontext keeps the with-statement
            # shape consistent with other L2I nodes that optionally enable tiling.
            tiling_context = nullcontext()
            TorchDevice.empty_cache()
            with torch.inference_mode(), tiling_context:
                # The Qwen Image VAE uses per-channel latents_mean / latents_std
                # instead of a single scaling_factor.
                # Latents are 5D: (B, C, num_frames, H, W) — the unpack from the
                # denoise step already produces this shape.
                latents_mean = (
                    torch.tensor(vae.config.latents_mean)
                    .view(1, vae.config.z_dim, 1, 1, 1)
                    .to(latents.device, latents.dtype)
                )
                # NOTE: latents_std holds the RECIPROCAL of the configured std, so the
                # division below is effectively latents * std + mean (the inverse of
                # the (latents - mean) / std normalization applied at encode time).
                latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
                    latents.device, latents.dtype
                )
                latents = latents / latents_std + latents_mean
                img = vae.decode(latents, return_dict=False)[0]
            # Drop the temporal frame dimension: (B, C, 1, H, W) -> (B, C, H, W)
            img = img[:, :, 0]
            # Decoder output is in [-1, 1]; clamp, take the first batch element,
            # and move channels last for PIL.
            img = img.clamp(-1, 1)
            img = rearrange(img[0], "c h w -> h w c")
            img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())
        TorchDevice.empty_cache()
        image_dto = context.images.save(image=img_pil)
        return ImageOutput.build(image_dto)

View File

@@ -0,0 +1,115 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType
@invocation_output("qwen_image_lora_loader_output")
class QwenImageLoRALoaderOutput(BaseInvocationOutput):
    """Qwen Image LoRA Loader Output"""

    # None when no transformer was connected to the loader node.
    transformer: Optional[TransformerField] = OutputField(
        default=None, description=FieldDescriptions.transformer, title="Transformer"
    )
@invocation(
    "qwen_image_lora_loader",
    title="Apply LoRA - Qwen Image",
    tags=["lora", "model", "qwen_image"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class QwenImageLoRALoaderInvocation(BaseInvocation):
    """Apply a LoRA model to a Qwen Image transformer."""

    lora: ModelIdentifierField = InputField(
        description=FieldDescriptions.lora_model,
        title="LoRA",
        ui_model_base=BaseModelType.QwenImage,
        ui_model_type=ModelType.LoRA,
    )
    weight: float = InputField(default=1.0, description=FieldDescriptions.lora_weight)
    transformer: TransformerField | None = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Transformer",
    )

    def invoke(self, context: InvocationContext) -> QwenImageLoRALoaderOutput:
        """Validate the LoRA, then return a deep-copied transformer with it appended.

        Raises:
            ValueError: if the LoRA is unknown or already applied.
        """
        key = self.lora.key

        # Guard clauses: the LoRA must exist and must not already be attached.
        if not context.models.exists(key):
            raise ValueError(f"Unknown lora: {key}!")
        if self.transformer and any(existing.lora.key == key for existing in self.transformer.loras):
            raise ValueError(f'LoRA "{key}" already applied to transformer.')

        result = QwenImageLoRALoaderOutput()
        if self.transformer is not None:
            # Deep-copy so the upstream node's transformer field is not mutated.
            result.transformer = self.transformer.model_copy(deep=True)
            result.transformer.loras.append(LoRAField(lora=self.lora, weight=self.weight))
        return result
@invocation(
    "qwen_image_lora_collection_loader",
    title="Apply LoRA Collection - Qwen Image",
    tags=["lora", "model", "qwen_image"],
    category="model",
    version="1.0.0",
    classification=Classification.Prototype,
)
class QwenImageLoRACollectionLoader(BaseInvocation):
    """Applies a collection of LoRAs to a Qwen Image transformer."""

    loras: Optional[LoRAField | list[LoRAField]] = InputField(
        default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
    )
    transformer: Optional[TransformerField] = InputField(
        default=None,
        description=FieldDescriptions.transformer,
        input=Input.Connection,
        title="Transformer",
    )

    def invoke(self, context: InvocationContext) -> QwenImageLoRALoaderOutput:
        """Deep-copy the transformer and append each unique, known LoRA.

        Raises:
            ValueError: if any LoRA key is unknown to the model manager.
            (Was a bare ``Exception``; ``ValueError`` matches
            QwenImageLoRALoaderInvocation and remains catchable by existing
            broad handlers.)
        """
        output = QwenImageLoRALoaderOutput()
        # Normalize the single-LoRA case to a list; a None input becomes [None]
        # and is skipped by the guard inside the loop.
        loras = self.loras if isinstance(self.loras, list) else [self.loras]
        added_loras: list[str] = []

        if self.transformer is not None:
            # Deep-copy so the upstream node's transformer field is not mutated.
            output.transformer = self.transformer.model_copy(deep=True)

        for lora in loras:
            if lora is None:
                continue
            # Skip duplicates so the same LoRA is never applied twice.
            if lora.lora.key in added_loras:
                continue
            if not context.models.exists(lora.lora.key):
                raise ValueError(f"Unknown lora: {lora.lora.key}!")
            added_loras.append(lora.lora.key)
            # output.transformer is set iff self.transformer is not None.
            if output.transformer is not None:
                output.transformer.loras.append(lora)
        return output

View File

@@ -0,0 +1,107 @@
from typing import Optional
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import (
ModelIdentifierField,
QwenVLEncoderField,
TransformerField,
VAEField,
)
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, SubModelType
@invocation_output("qwen_image_model_loader_output")
class QwenImageModelLoaderOutput(BaseInvocationOutput):
    """Qwen Image model loader output."""

    # Diffusion transformer; the loader emits it with an empty LoRA list.
    transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer")
    # Qwen2.5-VL tokenizer + text encoder pair used for prompt conditioning.
    qwen_vl_encoder: QwenVLEncoderField = OutputField(
        description=FieldDescriptions.qwen_vl_encoder, title="Qwen VL Encoder"
    )
    # VAE submodel of the Qwen Image model.
    vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE")
@invocation(
    "qwen_image_model_loader",
    title="Main Model - Qwen Image",
    tags=["model", "qwen_image"],
    category="model",
    version="1.1.0",
    classification=Classification.Prototype,
)
class QwenImageModelLoaderInvocation(BaseInvocation):
    """Loads a Qwen Image model, outputting its submodels.

    The transformer is always loaded from the main model (Diffusers or GGUF).

    For GGUF quantized models, the VAE and Qwen VL encoder must come from a
    separate Diffusers model specified in the "Component Source" field.

    For Diffusers models, all components are extracted from the main model
    automatically. The "Component Source" field is ignored.
    """

    model: ModelIdentifierField = InputField(
        description=FieldDescriptions.qwen_image_model,
        input=Input.Direct,
        ui_model_base=BaseModelType.QwenImage,
        ui_model_type=ModelType.Main,
        title="Transformer",
    )
    component_source: Optional[ModelIdentifierField] = InputField(
        default=None,
        description="Diffusers Qwen Image model to extract the VAE and Qwen VL encoder from. "
        "Required when using a GGUF quantized transformer. "
        "Ignored when the main model is already in Diffusers format.",
        input=Input.Direct,
        ui_model_base=BaseModelType.QwenImage,
        ui_model_type=ModelType.Main,
        ui_model_format=ModelFormat.Diffusers,
        title="Component Source (Diffusers)",
    )

    def invoke(self, context: InvocationContext) -> QwenImageModelLoaderOutput:
        """Resolve the transformer, VAE, tokenizer, and text encoder submodels.

        Raises:
            ValueError: If the main model is not Diffusers and no valid
                Diffusers component source is provided.
        """
        main_config = context.models.get_config(self.model)

        # Decide which model supplies the VAE / tokenizer / text encoder.
        if main_config.format == ModelFormat.Diffusers:
            # A Diffusers main model carries all of its own components;
            # any configured component source is intentionally ignored.
            component_model = self.model
        elif self.component_source is None:
            raise ValueError(
                "No source for VAE and Qwen VL encoder. "
                "GGUF quantized models only contain the transformer — "
                "please set 'Component Source' to a Diffusers Qwen Image model "
                "to provide the VAE and text encoder."
            )
        else:
            # GGUF/checkpoint transformer: components come from the source model,
            # which must itself be a Diffusers model.
            source_config = context.models.get_config(self.component_source)
            if source_config.format != ModelFormat.Diffusers:
                raise ValueError(
                    f"The Component Source model must be in Diffusers format. "
                    f"The selected model '{source_config.name}' is in {source_config.format.value} format."
                )
            component_model = self.component_source

        # The transformer always comes from the main model selection.
        return QwenImageModelLoaderOutput(
            transformer=TransformerField(
                transformer=self.model.model_copy(update={"submodel_type": SubModelType.Transformer}),
                loras=[],
            ),
            qwen_vl_encoder=QwenVLEncoderField(
                tokenizer=component_model.model_copy(update={"submodel_type": SubModelType.Tokenizer}),
                text_encoder=component_model.model_copy(update={"submodel_type": SubModelType.TextEncoder}),
            ),
            vae=VAEField(vae=component_model.model_copy(update={"submodel_type": SubModelType.VAE})),
        )

View File

@@ -0,0 +1,298 @@
from typing import Literal
import torch
from PIL import Image as PILImage
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
UIComponent,
)
from invokeai.app.invocations.model import QwenVLEncoderField
from invokeai.app.invocations.primitives import QwenImageConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningFieldData,
QwenImageConditioningInfo,
)
# Prompt templates and drop indices for the two Qwen Image model modes.
# These are taken directly from the diffusers pipelines.

# Image editing mode (QwenImagePipeline)
_EDIT_SYSTEM_PROMPT = (
    "Describe the key features of the input image (color, shape, size, texture, objects, background), "
    "then explain how the user's text instruction should alter or modify the image. "
    "Generate a new image that meets the user's requirements while maintaining consistency "
    "with the original input where appropriate."
)
_EDIT_DROP_IDX = 64

# Text-to-image mode (QwenImagePipeline)
_GENERATE_SYSTEM_PROMPT = (
    "Describe the image by detailing the color, shape, size, texture, quantity, "
    "text, spatial relationships of the objects and background:"
)
_GENERATE_DROP_IDX = 34

_IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"


def _build_prompt(user_prompt: str, num_images: int) -> str:
    """Wrap the user prompt in the Qwen2.5-VL chat template.

    When ``num_images`` > 0 the edit-mode system prompt is used and one vision
    placeholder is prepended per reference image; otherwise the text-to-image
    system prompt is used with no placeholders.
    """
    if num_images > 0:
        # Edit mode: include vision placeholders for reference images
        system_prompt = _EDIT_SYSTEM_PROMPT
        user_content = _IMAGE_PLACEHOLDER * num_images + user_prompt
    else:
        # Generate mode: text-only prompt
        system_prompt = _GENERATE_SYSTEM_PROMPT
        user_content = user_prompt
    return (
        f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        f"<|im_start|>user\n{user_content}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
@invocation(
    "qwen_image_text_encoder",
    title="Prompt - Qwen Image",
    tags=["prompt", "conditioning", "qwen_image"],
    category="conditioning",
    version="1.2.0",
    classification=Classification.Prototype,
)
class QwenImageTextEncoderInvocation(BaseInvocation):
    """Encodes text and reference images for Qwen Image using Qwen2.5-VL."""

    prompt: str = InputField(description="Text prompt describing the desired edit.", ui_component=UIComponent.Textarea)
    reference_images: list[ImageField] = InputField(
        default=[],
        description="Reference images to guide the edit. The model can use multiple reference images.",
    )
    qwen_vl_encoder: QwenVLEncoderField = InputField(
        title="Qwen VL Encoder",
        description=FieldDescriptions.qwen_vl_encoder,
        input=Input.Connection,
    )
    quantization: Literal["none", "int8", "nf4"] = InputField(
        default="none",
        description="Quantize the Qwen VL encoder to reduce VRAM usage. "
        "'nf4' (4-bit) saves the most memory, 'int8' (8-bit) is a middle ground.",
    )

    @staticmethod
    def _resize_for_vl_encoder(image: PILImage.Image, target_pixels: int = 512 * 512) -> PILImage.Image:
        """Resize image to fit within target_pixels while preserving aspect ratio.

        Matches the diffusers pipeline's calculate_dimensions logic: the image is resized
        so its total pixel count is approximately target_pixels, with dimensions rounded to
        multiples of 32. This prevents large images from producing too many vision tokens
        which can overwhelm the text prompt.
        """
        w, h = image.size
        aspect = w / h
        # Compute dimensions that preserve aspect ratio at ~target_pixels total
        new_w = int((target_pixels * aspect) ** 0.5)
        new_h = int(target_pixels / new_w)
        # Round to multiples of 32
        new_w = max(32, (new_w // 32) * 32)
        new_h = max(32, (new_h // 32) * 32)
        if new_w != w or new_h != h:
            image = image.resize((new_w, new_h), resample=PILImage.LANCZOS)
        return image

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> QwenImageConditioningOutput:
        """Encode the prompt and reference images, then save the conditioning tensors.

        Returns a QwenImageConditioningOutput referencing the saved conditioning.
        """
        # Load and resize reference images to the ~512x512 (≈0.26M) pixel budget
        # used by _resize_for_vl_encoder's default (matching diffusers pipeline)
        pil_images: list[PILImage.Image] = []
        for img_field in self.reference_images:
            pil_img = context.images.get_pil(img_field.image_name)
            pil_img = self._resize_for_vl_encoder(pil_img.convert("RGB"))
            pil_images.append(pil_img)
        prompt_embeds, prompt_mask = self._encode(context, pil_images)
        # Detach and move to CPU so the cached conditioning does not pin GPU memory.
        prompt_embeds = prompt_embeds.detach().to("cpu")
        prompt_mask = prompt_mask.detach().to("cpu") if prompt_mask is not None else None
        conditioning_data = ConditioningFieldData(
            conditionings=[QwenImageConditioningInfo(prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_mask)]
        )
        conditioning_name = context.conditioning.save(conditioning_data)
        return QwenImageConditioningOutput.build(conditioning_name)

    def _encode(
        self, context: InvocationContext, images: list[PILImage.Image]
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Encode text prompt and reference images using Qwen2.5-VL.

        Matches the diffusers QwenImagePipeline._get_qwen_prompt_embeds logic:
        1. Format prompt with the edit-specific system template
        2. Run through Qwen2.5-VL to get hidden states
        3. Extract valid (non-padding) tokens and drop the system prefix
        4. Return padded embeddings + attention mask
        """
        from transformers import AutoTokenizer, Qwen2_5_VLProcessor

        # Fallback imports: the Qwen2.5-VL image/video processors were exposed
        # under different names in earlier transformers releases — TODO confirm
        # the minimum supported transformers version for each branch.
        try:
            from transformers import Qwen2_5_VLImageProcessor as _ImageProcessorCls
        except ImportError:
            from transformers.models.qwen2_vl.image_processing_qwen2_vl import (  # type: ignore[no-redef]
                Qwen2VLImageProcessor as _ImageProcessorCls,
            )
        try:
            from transformers import Qwen2_5_VLVideoProcessor as _VideoProcessorCls
        except ImportError:
            from transformers.models.qwen2_vl.video_processing_qwen2_vl import (  # type: ignore[no-redef]
                Qwen2VLVideoProcessor as _VideoProcessorCls,
            )

        # Format the prompt with one vision placeholder per reference image
        text = _build_prompt(self.prompt, len(images))

        # Build the processor
        tokenizer_config = context.models.get_config(self.qwen_vl_encoder.tokenizer)
        model_root = context.models.get_absolute_path(tokenizer_config)
        tokenizer_dir = model_root / "tokenizer"
        tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_dir), local_files_only=True)
        # Search the model directory for an image-processor config; fall back to
        # default-constructed processor settings if none is found on disk.
        image_processor = None
        for search_dir in [model_root / "processor", tokenizer_dir, model_root, model_root / "image_processor"]:
            if (search_dir / "preprocessor_config.json").exists():
                image_processor = _ImageProcessorCls.from_pretrained(str(search_dir), local_files_only=True)
                break
        if image_processor is None:
            image_processor = _ImageProcessorCls()
        processor = Qwen2_5_VLProcessor(
            tokenizer=tokenizer,
            image_processor=image_processor,
            video_processor=_VideoProcessorCls(),
        )
        context.util.signal_progress("Running Qwen2.5-VL text/vision encoder")
        # Both loaders return (model, device, cleanup); cleanup releases the
        # encoder (exits the cache context, or frees the BnB model) and MUST run.
        if self.quantization != "none":
            text_encoder, device, cleanup = self._load_quantized_encoder(context)
        else:
            text_encoder, device, cleanup = self._load_cached_encoder(context)
        try:
            model_inputs = processor(
                text=[text],
                images=images if images else None,
                padding=True,
                return_tensors="pt",
            ).to(device=device)
            # pixel_values/image_grid_thw are only present when images were given.
            outputs = text_encoder(
                input_ids=model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
                pixel_values=getattr(model_inputs, "pixel_values", None),
                image_grid_thw=getattr(model_inputs, "image_grid_thw", None),
                output_hidden_states=True,
            )
            # Use last hidden state (matching diffusers pipeline)
            hidden_states = outputs.hidden_states[-1]
            # Extract valid (non-padding) tokens using the attention mask,
            # then drop the system prompt prefix tokens.
            # The drop index differs between edit mode (64) and generate mode (34).
            drop_idx = _EDIT_DROP_IDX if images else _GENERATE_DROP_IDX
            attn_mask = model_inputs.attention_mask
            bool_mask = attn_mask.bool()
            valid_lengths = bool_mask.sum(dim=1)
            selected = hidden_states[bool_mask]
            split_hidden = torch.split(selected, valid_lengths.tolist(), dim=0)
            # Drop system prefix tokens and build padded output
            trimmed = [h[drop_idx:] for h in split_hidden]
            attn_mask_list = [torch.ones(h.size(0), dtype=torch.long, device=device) for h in trimmed]
            max_seq_len = max(h.size(0) for h in trimmed)
            # Right-pad each sequence (and its mask) with zeros to max_seq_len.
            prompt_embeds = torch.stack(
                [torch.cat([h, h.new_zeros(max_seq_len - h.size(0), h.size(1))]) for h in trimmed]
            )
            encoder_attention_mask = torch.stack(
                [torch.cat([m, m.new_zeros(max_seq_len - m.size(0))]) for m in attn_mask_list]
            )
            prompt_embeds = prompt_embeds.to(dtype=torch.bfloat16)
        finally:
            if cleanup is not None:
                cleanup()
        # If all tokens are valid (no padding), mask is not needed
        if encoder_attention_mask.all():
            encoder_attention_mask = None
        return prompt_embeds, encoder_attention_mask

    def _load_cached_encoder(self, context: InvocationContext):
        """Load the text encoder through the model cache (no quantization).

        Returns (text_encoder, device, cleanup). The cleanup callable exits the
        cache's model_on_device context, which was entered manually here.
        """
        from transformers import Qwen2_5_VLForConditionalGeneration

        text_encoder_info = context.models.load(self.qwen_vl_encoder.text_encoder)
        # Enter the context manually so the model stays on-device until the
        # caller invokes the returned cleanup.
        ctx = text_encoder_info.model_on_device()
        _, text_encoder = ctx.__enter__()
        device = get_effective_device(text_encoder)
        assert isinstance(text_encoder, Qwen2_5_VLForConditionalGeneration)
        return text_encoder, device, lambda: ctx.__exit__(None, None, None)

    def _load_quantized_encoder(self, context: InvocationContext):
        """Load the text encoder with BitsAndBytes quantization, bypassing the model cache.

        BnB-quantized models are pinned to GPU and can't be moved between devices,
        so they can't go through the standard model cache. The model is loaded fresh
        each time and freed after use via the cleanup callback.
        """
        import gc
        import warnings

        from transformers import BitsAndBytesConfig, Qwen2_5_VLForConditionalGeneration

        encoder_config = context.models.get_config(self.qwen_vl_encoder.text_encoder)
        model_root = context.models.get_absolute_path(encoder_config)
        encoder_path = model_root / "text_encoder"
        if self.quantization == "nf4":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_quant_type="nf4",
            )
        else:  # int8
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        context.util.signal_progress("Loading Qwen2.5-VL encoder (quantized)")
        with warnings.catch_warnings():
            # BnB int8 internally casts bfloat16→float16; the warning is harmless
            warnings.filterwarnings("ignore", message="MatMul8bitLt.*cast.*float16")
            text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                str(encoder_path),
                quantization_config=bnb_config,
                device_map="auto",
                torch_dtype=torch.bfloat16,
                local_files_only=True,
            )
        device = next(text_encoder.parameters()).device

        def cleanup():
            # Drop the only strong reference and force-collect so VRAM is
            # returned immediately rather than at the next GC cycle.
            nonlocal text_encoder
            del text_encoder
            gc.collect()
            torch.cuda.empty_cache()

        return text_encoder, device, cleanup

View File

@@ -9,6 +9,17 @@ from invokeai.app.util.misc import get_iso_timestamp
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
class BoardVisibility(str, Enum, metaclass=MetaEnum):
    """The visibility options for a board.

    Stored as a string in the boards table; see deserialize_board_record,
    which falls back to Private for unrecognized values.
    """

    # NOTE(review): the Shared and Public descriptions below are nearly
    # identical — confirm and document the intended behavioral difference
    # between the two levels (visible here only as distinct values matched
    # by SQL `board_visibility IN ('shared', 'public')`).
    Private = "private"
    """Only the board owner (and admins) can see and modify this board."""

    Shared = "shared"
    """All users can view this board, but only the owner (and admins) can modify it."""

    Public = "public"
    """All users can view this board; only the owner (and admins) can modify its structure."""
class BoardRecord(BaseModelExcludeNull):
"""Deserialized board record."""
@@ -28,6 +39,10 @@ class BoardRecord(BaseModelExcludeNull):
"""The name of the cover image of the board."""
archived: bool = Field(description="Whether or not the board is archived.")
"""Whether or not the board is archived."""
board_visibility: BoardVisibility = Field(
default=BoardVisibility.Private, description="The visibility of the board."
)
"""The visibility of the board (private, shared, or public)."""
def deserialize_board_record(board_dict: dict) -> BoardRecord:
@@ -44,6 +59,11 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
updated_at = board_dict.get("updated_at", get_iso_timestamp())
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
archived = board_dict.get("archived", False)
board_visibility_raw = board_dict.get("board_visibility", BoardVisibility.Private.value)
try:
board_visibility = BoardVisibility(board_visibility_raw)
except ValueError:
board_visibility = BoardVisibility.Private
return BoardRecord(
board_id=board_id,
@@ -54,6 +74,7 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
updated_at=updated_at,
deleted_at=deleted_at,
archived=archived,
board_visibility=board_visibility,
)
@@ -61,6 +82,7 @@ class BoardChanges(BaseModel, extra="forbid"):
board_name: Optional[str] = Field(default=None, description="The board's new name.", max_length=300)
cover_image_name: Optional[str] = Field(default=None, description="The name of the board's new cover image.")
archived: Optional[bool] = Field(default=None, description="Whether or not the board is archived")
board_visibility: Optional[BoardVisibility] = Field(default=None, description="The visibility of the board.")
class BoardRecordOrderBy(str, Enum, metaclass=MetaEnum):

View File

@@ -116,6 +116,17 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
(changes.archived, board_id),
)
# Change the visibility of a board
if changes.board_visibility is not None:
cursor.execute(
"""--sql
UPDATE boards
SET board_visibility = ?
WHERE board_id = ?;
""",
(changes.board_visibility.value, board_id),
)
except sqlite3.Error as e:
raise BoardRecordSaveException from e
return self.get(board_id)
@@ -155,7 +166,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
{archived_filter}
ORDER BY {order_by} {direction}
LIMIT ? OFFSET ?;
@@ -194,14 +205,14 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
SELECT COUNT(DISTINCT boards.board_id)
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1);
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'));
"""
else:
count_query = """
SELECT COUNT(DISTINCT boards.board_id)
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
AND boards.archived = 0;
"""
@@ -251,7 +262,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
{archived_filter}
ORDER BY LOWER(boards.board_name) {direction}
"""
@@ -260,7 +271,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
SELECT DISTINCT boards.*
FROM boards
LEFT JOIN shared_boards ON boards.board_id = shared_boards.board_id
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.is_public = 1)
WHERE (boards.user_id = ? OR shared_boards.user_id = ? OR boards.board_visibility IN ('shared', 'public'))
{archived_filter}
ORDER BY {order_by} {direction}
"""

View File

@@ -7,7 +7,11 @@ class BulkDownloadBase(ABC):
@abstractmethod
def handler(
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
self,
image_names: Optional[list[str]],
board_id: Optional[str],
bulk_download_item_id: Optional[str],
user_id: str = "system",
) -> None:
"""
Create a zip file containing the images specified by the given image names or board id.
@@ -15,6 +19,7 @@ class BulkDownloadBase(ABC):
:param image_names: A list of image names to include in the zip file.
:param board_id: The ID of the board. If provided, all images associated with the board will be included in the zip file.
:param bulk_download_item_id: The bulk_download_item_id that will be used to retrieve the bulk download item when it is prepared, if none is provided a uuid will be generated.
:param user_id: The ID of the user who initiated the download.
"""
@abstractmethod
@@ -42,3 +47,12 @@ class BulkDownloadBase(ABC):
:param bulk_download_item_name: The name of the bulk download item.
"""
@abstractmethod
def get_owner(self, bulk_download_item_name: str) -> Optional[str]:
"""
Get the user_id of the user who initiated the download.
:param bulk_download_item_name: The name of the bulk download item.
:return: The user_id of the owner, or None if not tracked.
"""

View File

@@ -25,15 +25,24 @@ class BulkDownloadService(BulkDownloadBase):
self._temp_directory = TemporaryDirectory()
self._bulk_downloads_folder = Path(self._temp_directory.name) / "bulk_downloads"
self._bulk_downloads_folder.mkdir(parents=True, exist_ok=True)
# Track which user owns each download so the fetch endpoint can enforce ownership
self._download_owners: dict[str, str] = {}
def handler(
self, image_names: Optional[list[str]], board_id: Optional[str], bulk_download_item_id: Optional[str]
self,
image_names: Optional[list[str]],
board_id: Optional[str],
bulk_download_item_id: Optional[str],
user_id: str = "system",
) -> None:
bulk_download_id: str = DEFAULT_BULK_DOWNLOAD_ID
bulk_download_item_id = bulk_download_item_id or uuid_string()
bulk_download_item_name = bulk_download_item_id + ".zip"
self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
# Record ownership so the fetch endpoint can verify the caller
self._download_owners[bulk_download_item_name] = user_id
self._signal_job_started(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
try:
image_dtos: list[ImageDTO] = []
@@ -46,16 +55,16 @@ class BulkDownloadService(BulkDownloadBase):
raise BulkDownloadParametersException()
bulk_download_item_name: str = self._create_zip_file(image_dtos, bulk_download_item_id)
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name)
self._signal_job_completed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
except (
ImageRecordNotFoundException,
BoardRecordNotFoundException,
BulkDownloadException,
BulkDownloadParametersException,
) as e:
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e, user_id)
except Exception as e:
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e)
self._signal_job_failed(bulk_download_id, bulk_download_item_id, bulk_download_item_name, e, user_id)
self._invoker.services.logger.error("Problem bulk downloading images.")
raise e
@@ -103,43 +112,60 @@ class BulkDownloadService(BulkDownloadBase):
return "".join([c for c in s if c.isalpha() or c.isdigit() or c == " " or c == "_" or c == "-"]).rstrip()
def _signal_job_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
user_id: str = "system",
) -> None:
"""Signal that a bulk download job has started."""
if self._invoker:
assert bulk_download_id is not None
self._invoker.services.events.emit_bulk_download_started(
bulk_download_id, bulk_download_item_id, bulk_download_item_name
bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id=user_id
)
def _signal_job_completed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
user_id: str = "system",
) -> None:
"""Signal that a bulk download job has completed."""
if self._invoker:
assert bulk_download_id is not None
assert bulk_download_item_name is not None
self._invoker.services.events.emit_bulk_download_complete(
bulk_download_id, bulk_download_item_id, bulk_download_item_name
bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id=user_id
)
def _signal_job_failed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, exception: Exception
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
exception: Exception,
user_id: str = "system",
) -> None:
"""Signal that a bulk download job has failed."""
if self._invoker:
assert bulk_download_id is not None
assert exception is not None
self._invoker.services.events.emit_bulk_download_error(
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception), user_id=user_id
)
def stop(self, *args, **kwargs):
self._temp_directory.cleanup()
def get_owner(self, bulk_download_item_name: str) -> Optional[str]:
return self._download_owners.get(bulk_download_item_name)
def delete(self, bulk_download_item_name: str) -> None:
path = self.get_path(bulk_download_item_name)
Path(path).unlink()
self._download_owners.pop(bulk_download_item_name, None)
def get_path(self, bulk_download_item_name: str) -> str:
path = str(self._bulk_downloads_folder / bulk_download_item_name)

View File

@@ -100,9 +100,9 @@ class EventServiceBase:
"""Emitted when a queue item's status changes"""
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", user_id: str = "system") -> None:
"""Emitted when a batch is enqueued"""
self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
self.dispatch(BatchEnqueuedEvent.build(enqueue_result, user_id))
def emit_queue_items_retried(self, retry_result: "RetryItemsResult") -> None:
"""Emitted when a list of queue items are retried"""
@@ -112,9 +112,9 @@ class EventServiceBase:
"""Emitted when a queue is cleared"""
self.dispatch(QueueClearedEvent.build(queue_id))
def emit_recall_parameters_updated(self, queue_id: str, parameters: dict) -> None:
def emit_recall_parameters_updated(self, queue_id: str, user_id: str, parameters: dict) -> None:
"""Emitted when recall parameters are updated"""
self.dispatch(RecallParametersUpdatedEvent.build(queue_id, parameters))
self.dispatch(RecallParametersUpdatedEvent.build(queue_id, user_id, parameters))
# endregion
@@ -194,23 +194,42 @@ class EventServiceBase:
# region Bulk image download
def emit_bulk_download_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
user_id: str = "system",
) -> None:
"""Emitted when a bulk image download is started"""
self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
self.dispatch(
BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
)
def emit_bulk_download_complete(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
user_id: str = "system",
) -> None:
"""Emitted when a bulk image download is complete"""
self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
self.dispatch(
BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, user_id)
)
def emit_bulk_download_error(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
error: str,
user_id: str = "system",
) -> None:
"""Emitted when a bulk image download has an error"""
self.dispatch(
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
BulkDownloadErrorEvent.build(
bulk_download_id, bulk_download_item_id, bulk_download_item_name, error, user_id
)
)
# endregion

View File

@@ -281,9 +281,10 @@ class BatchEnqueuedEvent(QueueEventBase):
)
priority: int = Field(description="The priority of the batch")
origin: str | None = Field(default=None, description="The origin of the batch")
user_id: str = Field(default="system", description="The ID of the user who enqueued the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
def build(cls, enqueue_result: EnqueueBatchResult, user_id: str = "system") -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
@@ -291,6 +292,7 @@ class BatchEnqueuedEvent(QueueEventBase):
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,
user_id=user_id,
)
@@ -609,6 +611,7 @@ class BulkDownloadEventBase(EventBase):
bulk_download_id: str = Field(description="The ID of the bulk image download")
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
user_id: str = Field(default="system", description="The ID of the user who initiated the download")
@payload_schema.register
@@ -619,12 +622,17 @@ class BulkDownloadStartedEvent(BulkDownloadEventBase):
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
user_id: str = "system",
) -> "BulkDownloadStartedEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
user_id=user_id,
)
@@ -636,12 +644,17 @@ class BulkDownloadCompleteEvent(BulkDownloadEventBase):
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
user_id: str = "system",
) -> "BulkDownloadCompleteEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
user_id=user_id,
)
@@ -655,13 +668,19 @@ class BulkDownloadErrorEvent(BulkDownloadEventBase):
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
error: str,
user_id: str = "system",
) -> "BulkDownloadErrorEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=error,
user_id=user_id,
)
@@ -671,8 +690,9 @@ class RecallParametersUpdatedEvent(QueueEventBase):
__event_name__ = "recall_parameters_updated"
user_id: str = Field(description="The ID of the user whose recall parameters were updated")
parameters: dict[str, Any] = Field(description="The recall parameters that were updated")
@classmethod
def build(cls, queue_id: str, parameters: dict[str, Any]) -> "RecallParametersUpdatedEvent":
return cls(queue_id=queue_id, parameters=parameters)
def build(cls, queue_id: str, user_id: str, parameters: dict[str, Any]) -> "RecallParametersUpdatedEvent":
return cls(queue_id=queue_id, user_id=user_id, parameters=parameters)

View File

@@ -46,3 +46,9 @@ class FastAPIEventService(EventServiceBase):
except asyncio.CancelledError as e:
raise e # Raise a proper error
except Exception:
import logging
logging.getLogger("InvokeAI").error(
f"Error dispatching event {getattr(event, '__event_name__', event)}", exc_info=True
)

View File

@@ -74,8 +74,8 @@ class ImageRecordStorageBase(ABC):
pass
@abstractmethod
def get_intermediates_count(self) -> int:
"""Gets a count of all intermediate images."""
def get_intermediates_count(self, user_id: Optional[str] = None) -> int:
"""Gets a count of intermediate images. If user_id is provided, only counts that user's intermediates."""
pass
@abstractmethod
@@ -97,6 +97,11 @@ class ImageRecordStorageBase(ABC):
"""Saves an image record."""
pass
@abstractmethod
def get_user_id(self, image_name: str) -> Optional[str]:
"""Gets the user_id of the image owner. Returns None if image not found."""
pass
@abstractmethod
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
"""Gets the most recent image for a board."""

View File

@@ -46,6 +46,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result))
def get_user_id(self, image_name: str) -> Optional[str]:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT user_id FROM images
WHERE image_name = ?;
""",
(image_name,),
)
result = cast(Optional[sqlite3.Row], cursor.fetchone())
if not result:
return None
return cast(Optional[str], dict(result).get("user_id"))
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
with self._db.transaction() as cursor:
try:
@@ -269,14 +283,14 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
except sqlite3.Error as e:
raise ImageRecordDeleteException from e
def get_intermediates_count(self) -> int:
def get_intermediates_count(self, user_id: Optional[str] = None) -> int:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
WHERE is_intermediate = TRUE;
"""
)
query = "SELECT COUNT(*) FROM images WHERE is_intermediate = TRUE"
params: list[str] = []
if user_id is not None:
query += " AND user_id = ?"
params.append(user_id)
cursor.execute(query, params)
count = cast(int, cursor.fetchone()[0])
return count

View File

@@ -143,8 +143,8 @@ class ImageServiceABC(ABC):
pass
@abstractmethod
def get_intermediates_count(self) -> int:
"""Gets the number of intermediate images."""
def get_intermediates_count(self, user_id: Optional[str] = None) -> int:
"""Gets the number of intermediate images. If user_id is provided, only counts that user's intermediates."""
pass
@abstractmethod

View File

@@ -310,9 +310,9 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem deleting image records and files")
raise e
def get_intermediates_count(self) -> int:
def get_intermediates_count(self, user_id: Optional[str] = None) -> int:
try:
return self.__invoker.services.image_records.get_intermediates_count()
return self.__invoker.services.image_records.get_intermediates_count(user_id=user_id)
except Exception as e:
self.__invoker.services.logger.error("Problem getting intermediates count")
raise e

View File

@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import torch
import yaml
from huggingface_hub import HfFolder
from huggingface_hub import get_token as hf_get_token
from pydantic.networks import AnyHttpUrl
from pydantic_core import Url
from requests import Session
@@ -1115,7 +1115,7 @@ class ModelInstallService(ModelInstallServiceBase):
) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests
if source.access_token is None:
source.access_token = HfFolder.get_token()
source.access_token = hf_get_token()
remote_files, metadata = self._remote_files_from_source(source)
return self._import_remote_model(
source=source,

View File

@@ -30,6 +30,7 @@ from invokeai.backend.model_manager.taxonomy import (
ModelType,
ModelVariantType,
Qwen3VariantType,
QwenImageVariantType,
SchedulerPredictionType,
ZImageVariantType,
)
@@ -109,7 +110,13 @@ class ModelRecordChanges(BaseModelExcludeNull):
# Checkpoint-specific changes
# TODO(MM2): Should we expose these? Feels footgun-y...
variant: Optional[
ModelVariantType | ClipVariantType | FluxVariantType | Flux2VariantType | ZImageVariantType | Qwen3VariantType
ModelVariantType
| ClipVariantType
| FluxVariantType
| Flux2VariantType
| ZImageVariantType
| QwenImageVariantType
| Qwen3VariantType
] = Field(description="The variant of the model.", default=None)
prediction_type: Optional[SchedulerPredictionType] = Field(
description="The prediction type of the model.", default=None

View File

@@ -78,13 +78,15 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
"""Gets the counts of queue items by destination"""
def get_counts_by_destination(
self, queue_id: str, destination: str, user_id: Optional[str] = None
) -> SessionQueueCountsByDestination:
"""Gets the counts of queue items by destination. If user_id is provided, only counts that user's items."""
pass
@abstractmethod
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
"""Gets the status of a batch"""
def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str] = None) -> BatchStatus:
"""Gets the status of a batch. If user_id is provided, only counts that user's items."""
pass
@abstractmethod
@@ -172,8 +174,9 @@ class SessionQueueBase(ABC):
self,
queue_id: str,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
user_id: Optional[str] = None,
) -> ItemIdsResult:
"""Gets all queue item ids that match the given parameters"""
"""Gets all queue item ids that match the given parameters. If user_id is provided, only returns items for that user."""
pass
@abstractmethod

View File

@@ -304,12 +304,6 @@ class SessionQueueStatus(BaseModel):
failed: int = Field(..., description="Number of queue items with status 'error'")
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
total: int = Field(..., description="Total number of queue items")
user_pending: Optional[int] = Field(
default=None, description="Number of queue items with status 'pending' for the current user"
)
user_in_progress: Optional[int] = Field(
default=None, description="Number of queue items with status 'in_progress' for the current user"
)
class SessionQueueCountsByDestination(BaseModel):

View File

@@ -151,7 +151,7 @@ class SqliteSessionQueue(SessionQueueBase):
priority=priority,
item_ids=item_ids,
)
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
self.__invoker.services.events.emit_batch_enqueued(enqueue_result, user_id=user_id)
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
@@ -765,15 +765,21 @@ class SqliteSessionQueue(SessionQueueBase):
self,
queue_id: str,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
user_id: Optional[str] = None,
) -> ItemIdsResult:
with self._db.transaction() as cursor_:
query = f"""--sql
query = """--sql
SELECT item_id
FROM session_queue
WHERE queue_id = ?
ORDER BY created_at {order_dir.value}
"""
query_params = [queue_id]
query_params: list[str] = [queue_id]
if user_id is not None:
query += " AND user_id = ?"
query_params.append(user_id)
query += f" ORDER BY created_at {order_dir.value}"
cursor_.execute(query, query_params)
result = cast(list[sqlite3.Row], cursor_.fetchall())
@@ -783,20 +789,7 @@ class SqliteSessionQueue(SessionQueueBase):
def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus:
with self._db.transaction() as cursor:
# Get total counts
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
# Get user-specific counts if user_id is provided (using a single query with CASE)
user_counts_result = []
# When user_id is provided (non-admin), only count that user's items
if user_id is not None:
cursor.execute(
"""--sql
@@ -807,48 +800,51 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id, user_id),
)
user_counts_result = cast(list[sqlite3.Row], cursor.fetchall())
else:
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
GROUP BY status
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] or 0 for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
# Process user-specific counts if available
user_pending = None
user_in_progress = None
if user_id is not None:
user_counts: dict[str, int] = {row[0]: row[1] for row in user_counts_result}
user_pending = user_counts.get("pending", 0)
user_in_progress = user_counts.get("in_progress", 0)
# For non-admin users, hide current item details if they don't own it
show_current_item = current_item is not None and (user_id is None or current_item.user_id == user_id)
return SessionQueueStatus(
queue_id=queue_id,
item_id=current_item.item_id if current_item else None,
session_id=current_item.session_id if current_item else None,
batch_id=current_item.batch_id if current_item else None,
item_id=current_item.item_id if show_current_item else None,
session_id=current_item.session_id if show_current_item else None,
batch_id=current_item.batch_id if show_current_item else None,
pending=counts.get("pending", 0),
in_progress=counts.get("in_progress", 0),
completed=counts.get("completed", 0),
failed=counts.get("failed", 0),
canceled=counts.get("canceled", 0),
total=total,
user_pending=user_pending,
user_in_progress=user_in_progress,
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
def get_batch_status(self, queue_id: str, batch_id: str, user_id: Optional[str] = None) -> BatchStatus:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
query = """--sql
SELECT status, count(*), origin, destination
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
WHERE queue_id = ? AND batch_id = ?
"""
params: list[str] = [queue_id, batch_id]
if user_id is not None:
query += " AND user_id = ?"
params.append(user_id)
query += " GROUP BY status"
cursor.execute(query, params)
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
@@ -868,18 +864,21 @@ class SqliteSessionQueue(SessionQueueBase):
total=total,
)
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
def get_counts_by_destination(
self, queue_id: str, destination: str, user_id: Optional[str] = None
) -> SessionQueueCountsByDestination:
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
query = """--sql
SELECT status, count(*)
FROM session_queue
WHERE queue_id = ?
AND destination = ?
GROUP BY status
""",
(queue_id, destination),
)
WHERE queue_id = ? AND destination = ?
"""
params: list[str] = [queue_id, destination]
if user_id is not None:
query += " AND user_id = ?"
params.append(user_id)
query += " GROUP BY status"
cursor.execute(query, params)
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in counts_result)

View File

@@ -30,6 +30,8 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_24 import
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_26 import build_migration_26
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_27 import build_migration_27
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_28 import build_migration_28
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_29 import build_migration_29
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
@@ -77,6 +79,8 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
migrator.register_migration(build_migration_25(app_config=config, logger=logger))
migrator.register_migration(build_migration_26(app_config=config, logger=logger))
migrator.register_migration(build_migration_27())
migrator.register_migration(build_migration_28())
migrator.register_migration(build_migration_29())
migrator.run_migrations()
return db

View File

@@ -0,0 +1,45 @@
"""Migration 28: Add per-user workflow isolation columns to workflow_library.
This migration adds the database columns required for multiuser workflow isolation
to the workflow_library table:
- user_id: the owner of the workflow (defaults to 'system' for existing workflows)
- is_public: whether the workflow is shared with all users
"""
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration28Callback:
    """Adds the multiuser ownership columns (user_id, is_public) to workflow_library."""

    def __call__(self, cursor: sqlite3.Cursor) -> None:
        self._update_workflow_library_table(cursor)

    def _update_workflow_library_table(self, cursor: sqlite3.Cursor) -> None:
        """Add the user_id and is_public columns (with indexes) when not already present."""
        # Inspect the current schema so the migration is safe to re-run.
        cursor.execute("PRAGMA table_info(workflow_library);")
        existing_columns = {row[1] for row in cursor.fetchall()}
        if "user_id" not in existing_columns:
            # Existing workflows are attributed to the built-in 'system' owner.
            cursor.execute("ALTER TABLE workflow_library ADD COLUMN user_id TEXT DEFAULT 'system';")
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflow_library_user_id ON workflow_library(user_id);")
        if "is_public" not in existing_columns:
            cursor.execute("ALTER TABLE workflow_library ADD COLUMN is_public BOOLEAN NOT NULL DEFAULT FALSE;")
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_workflow_library_is_public ON workflow_library(is_public);")
def build_migration_28() -> Migration:
    """Build the 27 -> 28 migration.

    Adds per-user workflow isolation to the workflow_library table:
    - user_id: identifies the owner of each workflow
    - is_public: controls whether a workflow is shared with all users
    """
    migration_28 = Migration(
        from_version=27,
        to_version=28,
        callback=Migration28Callback(),
    )
    return migration_28

View File

@@ -0,0 +1,53 @@
"""Migration 29: Add board_visibility column to boards table.
This migration adds a board_visibility column to the boards table to support
three visibility levels:
- 'private': only the board owner (and admins) can view/modify
- 'shared': all users can view, but only the owner (and admins) can modify
- 'public': all users can view; only the owner (and admins) can modify the
board structure (rename/archive/delete)
Existing boards with is_public = 1 are migrated to 'public'.
All other existing boards default to 'private'.
"""
import sqlite3
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
class Migration29Callback:
    """Adds the board_visibility column ('private'/'shared'/'public') to the boards table."""

    def __call__(self, cursor: sqlite3.Cursor) -> None:
        self._update_boards_table(cursor)

    def _update_boards_table(self, cursor: sqlite3.Cursor) -> None:
        """Add board_visibility to boards, migrating legacy is_public = 1 rows to 'public'."""
        # The boards table may be absent in some databases; there is nothing to migrate then.
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='boards';")
        if cursor.fetchone() is None:
            return
        cursor.execute("PRAGMA table_info(boards);")
        existing_columns = {row[1] for row in cursor.fetchall()}
        if "board_visibility" not in existing_columns:
            cursor.execute("ALTER TABLE boards ADD COLUMN board_visibility TEXT NOT NULL DEFAULT 'private';")
            cursor.execute("CREATE INDEX IF NOT EXISTS idx_boards_board_visibility ON boards(board_visibility);")
            # Migrate existing is_public = 1 boards to 'public'
            if "is_public" in existing_columns:
                cursor.execute("UPDATE boards SET board_visibility = 'public' WHERE is_public = 1;")
def build_migration_29() -> Migration:
    """Build the 28 -> 29 migration.

    Adds the board_visibility column to the boards table, supporting the
    'private', 'shared', and 'public' visibility levels.
    """
    migration_29 = Migration(
        from_version=28,
        to_version=29,
        callback=Migration29Callback(),
    )
    return migration_29

View File

@@ -131,6 +131,15 @@ class UserServiceBase(ABC):
"""
pass
@abstractmethod
def get_admin_email(self) -> str | None:
"""Get the email address of the first active admin user.
Returns:
Email address of the first active admin, or None if no admin exists
"""
pass
@abstractmethod
def count_admins(self) -> int:
"""Count active admin users.

View File

@@ -256,6 +256,20 @@ class UserService(UserServiceBase):
for row in rows
]
def get_admin_email(self) -> str | None:
"""Get the email address of the first active admin user."""
with self._db.transaction() as cursor:
cursor.execute(
"""
SELECT email FROM users
WHERE is_admin = TRUE AND is_active = TRUE
ORDER BY created_at ASC
LIMIT 1
""",
)
row = cursor.fetchone()
return row[0] if row else None
def count_admins(self) -> int:
"""Count active admin users."""
with self._db.transaction() as cursor:

View File

@@ -4,6 +4,7 @@ from typing import Optional
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.workflow_records.workflow_records_common import (
WORKFLOW_LIBRARY_DEFAULT_USER_ID,
Workflow,
WorkflowCategory,
WorkflowRecordDTO,
@@ -22,18 +23,18 @@ class WorkflowRecordsStorageBase(ABC):
pass
@abstractmethod
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
def create(self, workflow: WorkflowWithoutID, user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID) -> WorkflowRecordDTO:
"""Creates a workflow."""
pass
@abstractmethod
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
"""Updates a workflow."""
def update(self, workflow: Workflow, user_id: Optional[str] = None) -> WorkflowRecordDTO:
"""Updates a workflow. When user_id is provided, the UPDATE is scoped to that user."""
pass
@abstractmethod
def delete(self, workflow_id: str) -> None:
"""Deletes a workflow."""
def delete(self, workflow_id: str, user_id: Optional[str] = None) -> None:
"""Deletes a workflow. When user_id is provided, the DELETE is scoped to that user."""
pass
@abstractmethod
@@ -47,6 +48,8 @@ class WorkflowRecordsStorageBase(ABC):
query: Optional[str],
tags: Optional[list[str]],
has_been_opened: Optional[bool],
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
"""Gets many workflows."""
pass
@@ -56,6 +59,8 @@ class WorkflowRecordsStorageBase(ABC):
self,
categories: list[WorkflowCategory],
has_been_opened: Optional[bool] = None,
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> dict[str, int]:
"""Gets a dictionary of counts for each of the provided categories."""
pass
@@ -66,19 +71,28 @@ class WorkflowRecordsStorageBase(ABC):
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
has_been_opened: Optional[bool] = None,
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> dict[str, int]:
"""Gets a dictionary of counts for each of the provided tags."""
pass
@abstractmethod
def update_opened_at(self, workflow_id: str) -> None:
"""Open a workflow."""
def update_opened_at(self, workflow_id: str, user_id: Optional[str] = None) -> None:
"""Open a workflow. When user_id is provided, the UPDATE is scoped to that user."""
pass
@abstractmethod
def get_all_tags(
self,
categories: Optional[list[WorkflowCategory]] = None,
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> list[str]:
"""Gets all unique tags from workflows."""
pass
@abstractmethod
def update_is_public(self, workflow_id: str, is_public: bool, user_id: Optional[str] = None) -> WorkflowRecordDTO:
"""Updates the is_public field of a workflow. When user_id is provided, the UPDATE is scoped to that user."""
pass

View File

@@ -9,6 +9,9 @@ from invokeai.app.util.metaenum import MetaEnum
__workflow_meta_version__ = semver.Version.parse("1.0.0")
WORKFLOW_LIBRARY_DEFAULT_USER_ID = "system"
"""Default user_id for workflows created in single-user mode or migrated from pre-multiuser databases."""
class ExposedField(BaseModel):
nodeId: str
@@ -26,6 +29,7 @@ class WorkflowRecordOrderBy(str, Enum, metaclass=MetaEnum):
UpdatedAt = "updated_at"
OpenedAt = "opened_at"
Name = "name"
IsPublic = "is_public"
class WorkflowCategory(str, Enum, metaclass=MetaEnum):
@@ -100,6 +104,8 @@ class WorkflowRecordDTOBase(BaseModel):
opened_at: Optional[Union[datetime.datetime, str]] = Field(
default=None, description="The opened timestamp of the workflow."
)
user_id: str = Field(description="The id of the user who owns this workflow.")
is_public: bool = Field(description="Whether this workflow is shared with all users.")
class WorkflowRecordDTO(WorkflowRecordDTOBase):

View File

@@ -7,6 +7,7 @@ from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
from invokeai.app.services.workflow_records.workflow_records_common import (
WORKFLOW_LIBRARY_DEFAULT_USER_ID,
Workflow,
WorkflowCategory,
WorkflowNotFoundError,
@@ -36,7 +37,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at, user_id, is_public
FROM workflow_library
WHERE workflow_id = ?;
""",
@@ -47,7 +48,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
def create(self, workflow: WorkflowWithoutID) -> WorkflowRecordDTO:
def create(self, workflow: WorkflowWithoutID, user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be created via this method")
@@ -57,43 +58,98 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
"""--sql
INSERT OR IGNORE INTO workflow_library (
workflow_id,
workflow
workflow,
user_id
)
VALUES (?, ?);
VALUES (?, ?, ?);
""",
(workflow_with_id.id, workflow_with_id.model_dump_json()),
(workflow_with_id.id, workflow_with_id.model_dump_json(), user_id),
)
return self.get(workflow_with_id.id)
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
def update(self, workflow: Workflow, user_id: Optional[str] = None) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be updated")
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE workflow_library
SET workflow = ?
WHERE workflow_id = ? AND category = 'user';
""",
(workflow.model_dump_json(), workflow.id),
)
if user_id is not None:
cursor.execute(
"""--sql
UPDATE workflow_library
SET workflow = ?
WHERE workflow_id = ? AND category = 'user' AND user_id = ?;
""",
(workflow.model_dump_json(), workflow.id, user_id),
)
else:
cursor.execute(
"""--sql
UPDATE workflow_library
SET workflow = ?
WHERE workflow_id = ? AND category = 'user';
""",
(workflow.model_dump_json(), workflow.id),
)
return self.get(workflow.id)
def delete(self, workflow_id: str) -> None:
def delete(self, workflow_id: str, user_id: Optional[str] = None) -> None:
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be deleted")
with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE from workflow_library
WHERE workflow_id = ? AND category = 'user';
""",
(workflow_id,),
)
if user_id is not None:
cursor.execute(
"""--sql
DELETE from workflow_library
WHERE workflow_id = ? AND category = 'user' AND user_id = ?;
""",
(workflow_id, user_id),
)
else:
cursor.execute(
"""--sql
DELETE from workflow_library
WHERE workflow_id = ? AND category = 'user';
""",
(workflow_id,),
)
return None
def update_is_public(self, workflow_id: str, is_public: bool, user_id: Optional[str] = None) -> WorkflowRecordDTO:
"""Updates the is_public field of a workflow and manages the 'shared' tag automatically."""
record = self.get(workflow_id)
workflow = record.workflow
# Manage "shared" tag: add when public, remove when private
tags_list = [t.strip() for t in workflow.tags.split(",") if t.strip()] if workflow.tags else []
if is_public and "shared" not in tags_list:
tags_list.append("shared")
elif not is_public and "shared" in tags_list:
tags_list.remove("shared")
updated_tags = ", ".join(tags_list)
updated_workflow = workflow.model_copy(update={"tags": updated_tags})
with self._db.transaction() as cursor:
if user_id is not None:
cursor.execute(
"""--sql
UPDATE workflow_library
SET workflow = ?, is_public = ?
WHERE workflow_id = ? AND category = 'user' AND user_id = ?;
""",
(updated_workflow.model_dump_json(), is_public, workflow_id, user_id),
)
else:
cursor.execute(
"""--sql
UPDATE workflow_library
SET workflow = ?, is_public = ?
WHERE workflow_id = ? AND category = 'user';
""",
(updated_workflow.model_dump_json(), is_public, workflow_id),
)
return self.get(workflow_id)
def get_many(
self,
order_by: WorkflowRecordOrderBy,
@@ -104,6 +160,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
query: Optional[str] = None,
tags: Optional[list[str]] = None,
has_been_opened: Optional[bool] = None,
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
with self._db.transaction() as cursor:
# sanitize!
@@ -122,7 +180,9 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
created_at,
updated_at,
opened_at,
tags
tags,
user_id,
is_public
FROM workflow_library
"""
count_query = "SELECT COUNT(*) FROM workflow_library"
@@ -177,6 +237,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
conditions.append(query_condition)
params.extend([wildcard_query, wildcard_query, wildcard_query])
if user_id is not None:
# Scope to the given user but always include default workflows
conditions.append("(user_id = ? OR category = 'default')")
params.append(user_id)
if is_public is True:
conditions.append("is_public = TRUE")
elif is_public is False:
conditions.append("is_public = FALSE")
if conditions:
# If there are conditions, add a WHERE clause and then join the conditions
main_query += " WHERE "
@@ -226,6 +296,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
has_been_opened: Optional[bool] = None,
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> dict[str, int]:
if not tags:
return {}
@@ -248,6 +320,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
if user_id is not None:
# Scope to the given user but always include default workflows
base_conditions.append("(user_id = ? OR category = 'default')")
base_params.append(user_id)
if is_public is True:
base_conditions.append("is_public = TRUE")
elif is_public is False:
base_conditions.append("is_public = FALSE")
# For each tag to count, run a separate query
for tag in tags:
# Start with the base conditions
@@ -277,6 +359,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
self,
categories: list[WorkflowCategory],
has_been_opened: Optional[bool] = None,
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> dict[str, int]:
with self._db.transaction() as cursor:
result: dict[str, int] = {}
@@ -296,6 +380,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
elif has_been_opened is False:
base_conditions.append("opened_at IS NULL")
if user_id is not None:
# Scope to the given user but always include default workflows
base_conditions.append("(user_id = ? OR category = 'default')")
base_params.append(user_id)
if is_public is True:
base_conditions.append("is_public = TRUE")
elif is_public is False:
base_conditions.append("is_public = FALSE")
# For each category to count, run a separate query
for category in categories:
# Start with the base conditions
@@ -321,20 +415,32 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
return result
def update_opened_at(self, workflow_id: str) -> None:
def update_opened_at(self, workflow_id: str, user_id: Optional[str] = None) -> None:
with self._db.transaction() as cursor:
cursor.execute(
f"""--sql
UPDATE workflow_library
SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW')
WHERE workflow_id = ?;
""",
(workflow_id,),
)
if user_id is not None:
cursor.execute(
f"""--sql
UPDATE workflow_library
SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW')
WHERE workflow_id = ? AND user_id = ?;
""",
(workflow_id, user_id),
)
else:
cursor.execute(
f"""--sql
UPDATE workflow_library
SET opened_at = STRFTIME('{SQL_TIME_FORMAT}', 'NOW')
WHERE workflow_id = ?;
""",
(workflow_id,),
)
def get_all_tags(
self,
categories: Optional[list[WorkflowCategory]] = None,
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> list[str]:
with self._db.transaction() as cursor:
conditions: list[str] = []
@@ -349,6 +455,16 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
conditions.append(f"category IN ({placeholders})")
params.extend([category.value for category in categories])
if user_id is not None:
# Scope to the given user but always include default workflows
conditions.append("(user_id = ? OR category = 'default')")
params.append(user_id)
if is_public is True:
conditions.append("is_public = TRUE")
elif is_public is False:
conditions.append("is_public = FALSE")
stmt = """--sql
SELECT DISTINCT tags
FROM workflow_library

View File

@@ -93,6 +93,29 @@ COGVIEW4_LATENT_RGB_FACTORS = [
[-0.00955853, -0.00980067, -0.00977842],
]
# Qwen Image uses the same VAE as Wan 2.1 (16-channel).
# Factors from ComfyUI: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py
QWEN_IMAGE_LATENT_RGB_FACTORS = [
[-0.1299, -0.1692, 0.2932],
[0.0671, 0.0406, 0.0442],
[0.3568, 0.2548, 0.1747],
[0.0372, 0.2344, 0.1420],
[0.0313, 0.0189, -0.0328],
[0.0296, -0.0956, -0.0665],
[-0.3477, -0.4059, -0.2925],
[0.0166, 0.1902, 0.1975],
[-0.0412, 0.0267, -0.1364],
[-0.1293, 0.0740, 0.1636],
[0.0680, 0.3019, 0.1128],
[0.0032, 0.0581, 0.0639],
[-0.1251, 0.0927, 0.1699],
[0.0060, -0.0633, 0.0005],
[0.3477, 0.2275, 0.2950],
[0.1984, 0.0913, 0.1861],
]
QWEN_IMAGE_LATENT_RGB_BIAS = [-0.1835, -0.0868, -0.3360]
# FLUX.2 uses 32 latent channels.
# Factors from ComfyUI: https://github.com/Comfy-Org/ComfyUI/blob/main/comfy/latent_formats.py
FLUX2_LATENT_RGB_FACTORS = [
@@ -232,6 +255,9 @@ def diffusion_step_callback(
latent_rgb_factors = SD3_5_LATENT_RGB_FACTORS
elif base_model == BaseModelType.CogView4:
latent_rgb_factors = COGVIEW4_LATENT_RGB_FACTORS
elif base_model == BaseModelType.QwenImage:
latent_rgb_factors = QWEN_IMAGE_LATENT_RGB_FACTORS
latent_rgb_bias = QWEN_IMAGE_LATENT_RGB_BIAS
elif base_model == BaseModelType.Flux:
latent_rgb_factors = FLUX_LATENT_RGB_FACTORS
elif base_model == BaseModelType.Flux2:

View File

@@ -50,6 +50,7 @@ from invokeai.backend.model_manager.configs.lora import (
LoRA_LyCORIS_Anima_Config,
LoRA_LyCORIS_Flux2_Config,
LoRA_LyCORIS_FLUX_Config,
LoRA_LyCORIS_QwenImage_Config,
LoRA_LyCORIS_SD1_Config,
LoRA_LyCORIS_SD2_Config,
LoRA_LyCORIS_SDXL_Config,
@@ -71,6 +72,7 @@ from invokeai.backend.model_manager.configs.main import (
Main_Diffusers_CogView4_Config,
Main_Diffusers_Flux2_Config,
Main_Diffusers_FLUX_Config,
Main_Diffusers_QwenImage_Config,
Main_Diffusers_SD1_Config,
Main_Diffusers_SD2_Config,
Main_Diffusers_SD3_Config,
@@ -79,6 +81,7 @@ from invokeai.backend.model_manager.configs.main import (
Main_Diffusers_ZImage_Config,
Main_GGUF_Flux2_Config,
Main_GGUF_FLUX_Config,
Main_GGUF_QwenImage_Config,
Main_GGUF_ZImage_Config,
MainModelDefaultSettings,
)
@@ -163,6 +166,7 @@ AnyModelConfig = Annotated[
Annotated[Main_Diffusers_FLUX_Config, Main_Diffusers_FLUX_Config.get_tag()],
Annotated[Main_Diffusers_Flux2_Config, Main_Diffusers_Flux2_Config.get_tag()],
Annotated[Main_Diffusers_CogView4_Config, Main_Diffusers_CogView4_Config.get_tag()],
Annotated[Main_Diffusers_QwenImage_Config, Main_Diffusers_QwenImage_Config.get_tag()],
Annotated[Main_Diffusers_ZImage_Config, Main_Diffusers_ZImage_Config.get_tag()],
# Main (Pipeline) - checkpoint format
# IMPORTANT: FLUX.2 must be checked BEFORE FLUX.1 because FLUX.2 has specific validation
@@ -181,6 +185,7 @@ AnyModelConfig = Annotated[
Annotated[Main_BnBNF4_FLUX_Config, Main_BnBNF4_FLUX_Config.get_tag()],
Annotated[Main_GGUF_Flux2_Config, Main_GGUF_Flux2_Config.get_tag()],
Annotated[Main_GGUF_FLUX_Config, Main_GGUF_FLUX_Config.get_tag()],
Annotated[Main_GGUF_QwenImage_Config, Main_GGUF_QwenImage_Config.get_tag()],
Annotated[Main_GGUF_ZImage_Config, Main_GGUF_ZImage_Config.get_tag()],
# VAE - checkpoint format
Annotated[VAE_Checkpoint_SD1_Config, VAE_Checkpoint_SD1_Config.get_tag()],
@@ -213,6 +218,7 @@ AnyModelConfig = Annotated[
Annotated[LoRA_LyCORIS_Flux2_Config, LoRA_LyCORIS_Flux2_Config.get_tag()],
Annotated[LoRA_LyCORIS_FLUX_Config, LoRA_LyCORIS_FLUX_Config.get_tag()],
Annotated[LoRA_LyCORIS_ZImage_Config, LoRA_LyCORIS_ZImage_Config.get_tag()],
Annotated[LoRA_LyCORIS_QwenImage_Config, LoRA_LyCORIS_QwenImage_Config.get_tag()],
Annotated[LoRA_LyCORIS_Anima_Config, LoRA_LyCORIS_Anima_Config.get_tag()],
# LoRA - OMI format
Annotated[LoRA_OMI_SDXL_Config, LoRA_OMI_SDXL_Config.get_tag()],

View File

@@ -772,6 +772,85 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
raise NotAMatchError("model does not look like a Z-Image LoRA")
class LoRA_LyCORIS_QwenImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
    """Model config for Qwen Image Edit LoRA models in LyCORIS format."""

    # Pinned base type so the tagged-union discriminator routes to this config.
    base: Literal[BaseModelType.QwenImage] = Field(default=BaseModelType.QwenImage)

    @classmethod
    def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
        """Raise NotAMatchError unless the state dict looks like a Qwen Image LoRA.

        Qwen Image Edit LoRAs have keys like transformer_blocks.X.attn.to_k.lora_down.weight.
        """
        state_dict = mod.load_state_dict()
        # Positive signal: Qwen transformer-block prefixes (diffusers-style or Kohya naming).
        has_qwen_ie_keys = state_dict_has_any_keys_starting_with(
            state_dict,
            {
                "transformer_blocks.",
                "transformer.transformer_blocks.",
                "lora_unet_transformer_blocks_",  # Kohya format
            },
        )
        # Positive signal: LoRA-style weight suffixes (PEFT A/B, Kohya up/down, DoRA, LoKR).
        has_lora_suffix = state_dict_has_any_keys_ending_with(
            state_dict,
            {
                "lora_A.weight",
                "lora_B.weight",
                "lora_down.weight",
                "lora_up.weight",
                "dora_scale",
                "lokr_w1",
                "lokr_w2",  # LoKR format
            },
        )
        # Must NOT have diffusion_model.layers (Z-Image) or Flux-style keys.
        # Flux LoRAs can have transformer.single_transformer_blocks or transformer.transformer_blocks
        # (with the "transformer." prefix and "single_" variant) which would falsely match our check.
        # Flux Kohya LoRAs use lora_unet_double_blocks or lora_unet_single_blocks.
        has_z_image_keys = state_dict_has_any_keys_starting_with(state_dict, {"diffusion_model.layers."})
        has_flux_keys = state_dict_has_any_keys_starting_with(
            state_dict,
            {
                "double_blocks.",
                "single_blocks.",
                "single_transformer_blocks.",
                "transformer.single_transformer_blocks.",
                "lora_unet_double_blocks_",
                "lora_unet_single_blocks_",
                "lora_unet_single_transformer_blocks_",
            },
        )
        if has_qwen_ie_keys and has_lora_suffix and not has_z_image_keys and not has_flux_keys:
            return
        raise NotAMatchError("model does not match Qwen Image LoRA heuristics")

    @classmethod
    def _get_base_or_raise(cls, mod: ModelOnDisk) -> BaseModelType:
        """Return BaseModelType.QwenImage if the keys agree, else raise NotAMatchError.

        Re-applies the prefix heuristics from _validate_looks_like_lora, minus the
        LoRA-suffix check (validation has already established this is a LoRA).
        """
        state_dict = mod.load_state_dict()
        has_qwen_ie_keys = state_dict_has_any_keys_starting_with(
            state_dict,
            {"transformer_blocks.", "transformer.transformer_blocks.", "lora_unet_transformer_blocks_"},
        )
        # Exclude Z-Image and Flux architectures, which can share block-name shapes.
        has_z_image_keys = state_dict_has_any_keys_starting_with(state_dict, {"diffusion_model.layers."})
        has_flux_keys = state_dict_has_any_keys_starting_with(
            state_dict,
            {
                "double_blocks.",
                "single_blocks.",
                "single_transformer_blocks.",
                "transformer.single_transformer_blocks.",
                "lora_unet_double_blocks_",
                "lora_unet_single_blocks_",
                "lora_unet_single_transformer_blocks_",
            },
        )
        if has_qwen_ie_keys and not has_z_image_keys and not has_flux_keys:
            return BaseModelType.QwenImage
        raise NotAMatchError("model does not look like a Qwen Image Edit LoRA")
class LoRA_LyCORIS_Anima_Config(LoRA_LyCORIS_Config_Base, Config_Base):
"""Model config for Anima LoRA models in LyCORIS format."""

View File

@@ -28,6 +28,7 @@ from invokeai.backend.model_manager.taxonomy import (
ModelFormat,
ModelType,
ModelVariantType,
QwenImageVariantType,
SchedulerPredictionType,
SubModelType,
ZImageVariantType,
@@ -86,6 +87,8 @@ class MainModelDefaultSettings(BaseModel):
else:
# Distilled models (Klein 4B, Klein 9B) use fewer steps
return cls(steps=4, cfg_scale=1.0, width=1024, height=1024)
case BaseModelType.QwenImage:
return cls(steps=40, cfg_scale=4.0, width=1024, height=1024)
case _:
# TODO(psyche): Do we want defaults for other base types?
return None
@@ -196,9 +199,11 @@ class Main_SD_Checkpoint_Config_Base(Checkpoint_Config_Base, Main_Config_Base):
cls._validate_base(mod)
prediction_type = override_fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod)
prediction_type = override_fields.pop("prediction_type", None) or cls._get_scheduler_prediction_type_or_raise(
mod
)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod)
return cls(**override_fields, prediction_type=prediction_type, variant=variant)
@@ -471,7 +476,7 @@ class Main_Checkpoint_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Conf
cls._validate_does_not_look_like_gguf_quantized(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod)
return cls(**override_fields, variant=variant)
@@ -546,7 +551,7 @@ class Main_Checkpoint_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Con
cls._validate_does_not_look_like_gguf_quantized(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod)
return cls(**override_fields, variant=variant)
@@ -609,7 +614,7 @@ class Main_BnBNF4_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_B
cls._validate_model_looks_like_bnb_quantized(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod)
return cls(**override_fields, variant=variant)
@@ -660,7 +665,7 @@ class Main_GGUF_FLUX_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Bas
cls._validate_is_not_flux2(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod)
return cls(**override_fields, variant=variant)
@@ -718,7 +723,7 @@ class Main_GGUF_Flux2_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Ba
cls._validate_is_flux2(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod)
return cls(**override_fields, variant=variant)
@@ -779,9 +784,9 @@ class Main_Diffusers_FLUX_Config(Diffusers_Config_Base, Main_Config_Base, Config
},
)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
@@ -833,9 +838,9 @@ class Main_Diffusers_Flux2_Config(Diffusers_Config_Base, Main_Config_Base, Confi
},
)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
@@ -904,11 +909,13 @@ class Main_SD_Diffusers_Config_Base(Diffusers_Config_Base, Main_Config_Base):
cls._validate_base(mod)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod)
prediction_type = override_fields.get("prediction_type") or cls._get_scheduler_prediction_type_or_raise(mod)
prediction_type = override_fields.pop("prediction_type", None) or cls._get_scheduler_prediction_type_or_raise(
mod
)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
@@ -1014,9 +1021,9 @@ class Main_Diffusers_SD3_Config(Diffusers_Config_Base, Main_Config_Base, Config_
},
)
submodels = override_fields.get("submodels") or cls._get_submodels_or_raise(mod)
submodels = override_fields.pop("submodels", None) or cls._get_submodels_or_raise(mod)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
@@ -1089,7 +1096,7 @@ class Main_Diffusers_CogView4_Config(Diffusers_Config_Base, Main_Config_Base, Co
},
)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
@@ -1155,9 +1162,9 @@ class Main_Diffusers_ZImage_Config(Diffusers_Config_Base, Main_Config_Base, Conf
},
)
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
variant = override_fields.pop("variant", None) or cls._get_variant_or_raise(mod)
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod)
return cls(
**override_fields,
@@ -1201,7 +1208,7 @@ class Main_Checkpoint_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Co
cls._validate_does_not_look_like_gguf_quantized(mod)
variant = override_fields.get("variant", ZImageVariantType.Turbo)
variant = override_fields.pop("variant", None) or ZImageVariantType.Turbo
return cls(**override_fields, variant=variant)
@@ -1235,7 +1242,7 @@ class Main_GGUF_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_B
cls._validate_looks_like_gguf_quantized(mod)
variant = override_fields.get("variant", ZImageVariantType.Turbo)
variant = override_fields.pop("variant", None) or ZImageVariantType.Turbo
return cls(**override_fields, variant=variant)
@@ -1252,6 +1259,106 @@ class Main_GGUF_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_B
raise NotAMatchError("state dict does not look like GGUF quantized")
class Main_Diffusers_QwenImage_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
    """Model config for Qwen Image diffusers models (both txt2img and edit)."""

    base: Literal[BaseModelType.QwenImage] = Field(BaseModelType.QwenImage)
    # None means "not detected"; from_model_on_disk always fills it in.
    variant: QwenImageVariantType | None = Field(default=None)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config from an on-disk diffusers folder, or raise NotAMatchError.

        Explicit override_fields win over values detected from the model files.
        """
        raise_if_not_dir(mod)
        raise_for_override_fields(cls, override_fields)

        # This check implies the base type - no further validation needed.
        raise_for_class_name(
            common_config_paths(mod.path),
            {
                "QwenImagePlusPipeline",
                "QwenImageEditPlusPipeline",
                "QwenImagePipeline",
            },
        )

        # pop() (not get()) so the detected values are not passed twice via **override_fields.
        repo_variant = override_fields.pop("repo_variant", None) or cls._get_repo_variant_or_raise(mod)
        variant = override_fields.pop("variant", None) or cls._get_qwen_image_variant(mod)

        return cls(
            **override_fields,
            repo_variant=repo_variant,
            variant=variant,
        )

    @classmethod
    def _get_qwen_image_variant(cls, mod: ModelOnDisk) -> QwenImageVariantType:
        """Detect whether this is an edit or txt2img model from the pipeline class name."""
        import json

        model_index = mod.path / "model_index.json"
        if model_index.exists():
            with open(model_index) as f:
                config = json.load(f)
            class_name = config.get("_class_name", "")
            # e.g. "QwenImageEditPlusPipeline" -> Edit variant.
            if "Edit" in class_name:
                return QwenImageVariantType.Edit
        # Missing model_index.json or a non-Edit pipeline class -> plain txt2img model.
        return QwenImageVariantType.Generate
def _has_qwen_image_keys(state_dict: dict[str | int, Any]) -> bool:
"""Check if state dict contains Qwen Image Edit transformer keys.
Qwen Image Edit uses 'txt_in' and 'txt_norm' instead of 'context_embedder' (FLUX).
This distinguishes it from FLUX and other architectures.
"""
has_txt_in = any(isinstance(k, str) and k.startswith("txt_in.") for k in state_dict.keys())
has_txt_norm = any(isinstance(k, str) and k.startswith("txt_norm.") for k in state_dict.keys())
has_img_in = any(isinstance(k, str) and k.startswith("img_in.") for k in state_dict.keys())
# Must NOT have context_embedder (which would indicate FLUX)
has_context_embedder = any(isinstance(k, str) and "context_embedder" in k for k in state_dict.keys())
return has_txt_in and has_txt_norm and has_img_in and not has_context_embedder
class Main_GGUF_QwenImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
    """Model config for GGUF-quantized Qwen Image transformer models."""

    base: Literal[BaseModelType.QwenImage] = Field(default=BaseModelType.QwenImage)
    format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
    # None means "not detected"; from_model_on_disk always fills it in.
    variant: QwenImageVariantType | None = Field(default=None)

    @classmethod
    def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
        """Build a config from a single GGUF file, or raise NotAMatchError."""
        raise_if_not_file(mod)
        raise_for_override_fields(cls, override_fields)

        sd = mod.load_state_dict()
        # Architecture check first (txt_in/txt_norm/img_in keys), then quantization check.
        if not _has_qwen_image_keys(sd):
            raise NotAMatchError("state dict does not look like a Qwen Image Edit model")
        if not _has_ggml_tensors(sd):
            raise NotAMatchError("state dict does not look like GGUF quantized")

        # Infer variant from the state dict if not explicitly provided.
        # The Edit variant includes an extra tensor `__index_timestep_zero__` (used by the
        # `zero_cond_t` dual-modulation path in diffusers' QwenImageTransformer2DModel).
        # If the marker tensor is missing, fall back to the filename heuristic since older
        # or alternate GGUF converters may not emit it.
        explicit_variant = override_fields.pop("variant", None)
        if explicit_variant is None:
            if "__index_timestep_zero__" in sd:
                explicit_variant = QwenImageVariantType.Edit
            else:
                filename = mod.path.stem.lower()
                if "edit" in filename:
                    explicit_variant = QwenImageVariantType.Edit
                else:
                    explicit_variant = QwenImageVariantType.Generate

        return cls(**override_fields, variant=explicit_variant)
class Main_Checkpoint_Anima_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
"""Model config for Anima single-file checkpoint models (safetensors).

View File

@@ -57,6 +57,9 @@ from invokeai.backend.patches.lora_conversions.flux_xlabs_lora_conversion_utils
is_state_dict_likely_in_flux_xlabs_format,
lora_model_from_flux_xlabs_state_dict,
)
from invokeai.backend.patches.lora_conversions.qwen_image_lora_conversion_utils import (
lora_model_from_qwen_image_state_dict,
)
from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import lora_model_from_sd_state_dict
from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import convert_sdxl_keys_to_diffusers_format
from invokeai.backend.patches.lora_conversions.z_image_lora_conversion_utils import lora_model_from_z_image_state_dict
@@ -162,6 +165,8 @@ class LoRALoader(ModelLoader):
# Z-Image LoRAs use diffusers PEFT format with transformer and/or Qwen3 encoder layers.
# We set alpha=None to use rank as alpha (common default).
model = lora_model_from_z_image_state_dict(state_dict=state_dict, alpha=None)
elif self._model_base == BaseModelType.QwenImage:
model = lora_model_from_qwen_image_state_dict(state_dict=state_dict, alpha=None)
elif self._model_base == BaseModelType.Anima:
# Anima LoRAs use Kohya-style or diffusers PEFT format targeting Cosmos DiT blocks.
model = lora_model_from_anima_state_dict(state_dict=state_dict, alpha=None)

View File

@@ -0,0 +1,177 @@
from pathlib import Path
from typing import Optional
import accelerate
import torch
from invokeai.backend.model_manager.configs.base import Checkpoint_Config_Base, Diffusers_Config_Base
from invokeai.backend.model_manager.configs.factory import AnyModelConfig
from invokeai.backend.model_manager.configs.main import Main_GGUF_QwenImage_Config
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
ModelFormat,
ModelType,
QwenImageVariantType,
SubModelType,
)
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from invokeai.backend.quantization.gguf.loaders import gguf_sd_loader
from invokeai.backend.util.devices import TorchDevice
@ModelLoaderRegistry.register(base=BaseModelType.QwenImage, type=ModelType.Main, format=ModelFormat.Diffusers)
class QwenImageDiffusersModel(GenericDiffusersLoader):
    """Class to load Qwen Image Edit main models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Load one submodel of a Qwen Image diffusers pipeline, forced to bfloat16.

        Raises:
            NotImplementedError: if *config* is a single-file (checkpoint) config.
            Exception: if no submodel type was provided.
        """
        if isinstance(config, Checkpoint_Config_Base):
            raise NotImplementedError("CheckpointConfigBase is not implemented for Qwen Image Edit models.")

        if submodel_type is None:
            raise Exception("A submodel type must be provided when loading main pipelines.")

        model_path = Path(config.path)
        load_class = self.get_hf_load_class(model_path, submodel_type)
        # repo_variant (e.g. fp16 file suffix) only exists on diffusers-style configs.
        repo_variant = config.repo_variant if isinstance(config, Diffusers_Config_Base) else None
        variant = repo_variant.value if repo_variant else None
        model_path = model_path / submodel_type.value

        # We force bfloat16 for Qwen Image Edit models.
        # Use `dtype` (newer) with fallback to `torch_dtype` (older diffusers).
        dtype_kwarg = {"dtype": torch.bfloat16}
        try:
            result: AnyModel = load_class.from_pretrained(
                model_path,
                **dtype_kwarg,
                variant=variant,
                local_files_only=True,
            )
        except TypeError:
            # Older diffusers uses torch_dtype instead of dtype
            dtype_kwarg = {"torch_dtype": torch.bfloat16}
            result = load_class.from_pretrained(
                model_path,
                **dtype_kwarg,
                variant=variant,
                local_files_only=True,
            )
        except OSError as e:
            # A repo_variant was recorded but the variant-suffixed files are missing;
            # retry without the variant so the plain weight files are loaded.
            if variant and "no file named" in str(e):
                result = load_class.from_pretrained(model_path, **dtype_kwarg, local_files_only=True)
            else:
                raise e

        return result
@ModelLoaderRegistry.register(base=BaseModelType.QwenImage, type=ModelType.Main, format=ModelFormat.GGUFQuantized)
class QwenImageGGUFCheckpointModel(ModelLoader):
    """Class to load GGUF-quantized Qwen Image Edit transformer models."""

    def _load_model(
        self,
        config: AnyModelConfig,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        """Dispatch loading; only the Transformer submodel is supported for GGUF files.

        Raises:
            ValueError: if *config* is not checkpoint-based, or *submodel_type* is
                anything other than SubModelType.Transformer.
        """
        if not isinstance(config, Checkpoint_Config_Base):
            raise ValueError("Only CheckpointConfigBase models are currently supported here.")

        match submodel_type:
            case SubModelType.Transformer:
                return self._load_from_singlefile(config)

        raise ValueError(
            f"Only Transformer submodels are currently supported. Received: {submodel_type.value if submodel_type else 'None'}"
        )

    def _load_from_singlefile(self, config: AnyModelConfig) -> AnyModel:
        """Instantiate a diffusers QwenImageTransformer2DModel from a single GGUF file.

        Architecture hyperparameters (layer count, head count, channel dims) are
        inferred from the state dict so that differently-sized checkpoints load
        without a sidecar config.
        """
        from diffusers import QwenImageTransformer2DModel

        if not isinstance(config, Main_GGUF_QwenImage_Config):
            raise TypeError(f"Expected Main_GGUF_QwenImage_Config, got {type(config).__name__}.")

        model_path = Path(config.path)
        target_device = TorchDevice.choose_torch_device()
        compute_dtype = TorchDevice.choose_bfloat16_safe_dtype(target_device)
        sd = gguf_sd_loader(model_path, compute_dtype=compute_dtype)

        # Strip ComfyUI-style prefixes if present
        prefix_to_strip = None
        for prefix in ["model.diffusion_model.", "diffusion_model."]:
            if any(k.startswith(prefix) for k in sd.keys() if isinstance(k, str)):
                prefix_to_strip = prefix
                break
        if prefix_to_strip:
            stripped_sd = {}
            for key, value in sd.items():
                if isinstance(key, str) and key.startswith(prefix_to_strip):
                    stripped_sd[key[len(prefix_to_strip) :]] = value
                else:
                    stripped_sd[key] = value
            sd = stripped_sd

        # Auto-detect architecture from state dict
        # Layer count = highest transformer_blocks.<idx>. index seen, plus one.
        num_layers = 0
        for key in sd.keys():
            if isinstance(key, str) and key.startswith("transformer_blocks."):
                parts = key.split(".")
                if len(parts) >= 2:
                    try:
                        layer_idx = int(parts[1])
                        num_layers = max(num_layers, layer_idx + 1)
                    except ValueError:
                        pass

        # Detect dimensions from weights
        num_attention_heads = 24  # default
        attention_head_dim = 128  # default
        if "img_in.weight" in sd:
            w = sd["img_in.weight"]
            # GGMLTensor stores the logical shape separately from the quantized data.
            shape = w.tensor_shape if isinstance(w, GGMLTensor) else w.shape
            hidden_dim = shape[0]
            in_channels = shape[1]
            num_attention_heads = hidden_dim // attention_head_dim

        joint_attention_dim = 3584  # default
        if "txt_in.weight" in sd:
            w = sd["txt_in.weight"]
            shape = w.tensor_shape if isinstance(w, GGMLTensor) else w.shape
            joint_attention_dim = shape[1]

        model_config: dict = {
            "patch_size": 2,
            "in_channels": in_channels if "img_in.weight" in sd else 64,
            "out_channels": 16,
            "num_layers": num_layers if num_layers > 0 else 60,
            "attention_head_dim": attention_head_dim,
            "num_attention_heads": num_attention_heads,
            "joint_attention_dim": joint_attention_dim,
            "guidance_embeds": False,
            "axes_dims_rope": (16, 56, 56),
        }

        # zero_cond_t is only used by edit-variant models. It enables dual modulation
        # for noisy vs reference patches. Setting it on txt2img models produces garbage.
        # Also requires diffusers 0.37+ (the parameter doesn't exist in older versions).
        import inspect

        is_edit = getattr(config, "variant", None) == QwenImageVariantType.Edit
        if is_edit and "zero_cond_t" in inspect.signature(QwenImageTransformer2DModel.__init__).parameters:
            model_config["zero_cond_t"] = True

        # Build on the meta device (no allocation), then attach the GGUF tensors
        # directly via assign=True; strict=False tolerates auxiliary/marker tensors.
        with accelerate.init_empty_weights():
            model = QwenImageTransformer2DModel(**model_config)

        model.load_state_dict(sd, strict=False, assign=True)
        return model

View File

@@ -19,8 +19,7 @@ from pathlib import Path
from typing import Optional
import requests
from huggingface_hub import HfApi, configure_http_backend, hf_hub_url
from huggingface_hub.errors import RepositoryNotFoundError, RevisionNotFoundError
from huggingface_hub import hf_hub_url
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
@@ -47,7 +46,6 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
this module without an internet connection.
"""
self._requests = session or requests.Session()
configure_http_backend(backend_factory=lambda: self._requests)
@classmethod
def from_json(cls, json: str) -> HuggingFaceMetadata:
@@ -55,6 +53,22 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
metadata = HuggingFaceMetadata.model_validate_json(json)
return metadata
def _fetch_model_info(self, repo_id: str, variant: Optional[ModelRepoVariant] = None) -> dict:
"""Fetch model info from HuggingFace API using self._requests session.
This allows the session to be mocked in tests via requests_testadapter.
"""
url = f"https://huggingface.co/api/models/{repo_id}"
params: dict[str, str] = {"blobs": "True"}
if variant is not None:
params["revision"] = str(variant)
response = self._requests.get(url, params=params)
if response.status_code == 404:
raise UnknownMetadataException(f"'{repo_id}' not found.")
response.raise_for_status()
return response.json()
def from_id(self, id: str, variant: Optional[ModelRepoVariant] = None) -> AnyModelRepoMetadata:
"""Return a HuggingFaceMetadata object given the model's repo_id."""
# Little loop which tries fetching a revision corresponding to the selected variant.
@@ -67,10 +81,10 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
repo_id = id.split("::")[0] or id
while not model_info:
try:
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True, revision=variant)
except RepositoryNotFoundError as excp:
raise UnknownMetadataException(f"'{repo_id}' not found. See trace for details.") from excp
except RevisionNotFoundError:
model_info = self._fetch_model_info(repo_id, variant)
except UnknownMetadataException:
raise
except requests.HTTPError:
if variant is None:
raise
else:
@@ -80,15 +94,18 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
_, name = repo_id.split("/")
for s in model_info.siblings or []:
assert s.rfilename is not None
assert s.size is not None
for s in model_info.get("siblings") or []:
rfilename = s.get("rfilename")
size = s.get("size")
assert rfilename is not None
assert size is not None
lfs = s.get("lfs")
files.append(
RemoteModelFile(
url=hf_hub_url(repo_id, s.rfilename, revision=variant or "main"),
path=Path(name, s.rfilename),
size=s.size,
sha256=s.lfs.get("sha256") if s.lfs else None,
url=hf_hub_url(repo_id, rfilename, revision=variant or "main"),
path=Path(name, rfilename),
size=size,
sha256=lfs.get("sha256") if lfs else None,
)
)
@@ -115,10 +132,10 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
)
return HuggingFaceMetadata(
id=model_info.id,
id=model_info["id"],
name=name,
files=files,
api_response=json.dumps(model_info.__dict__, default=str),
api_response=json.dumps(model_info, default=str),
is_diffusers=is_diffusers,
ckpt_urls=ckpt_urls,
)

View File

@@ -17,7 +17,7 @@ remote repo.
from pathlib import Path
from typing import List, Literal, Optional, Union
from huggingface_hub import configure_http_backend, hf_hub_url
from huggingface_hub import hf_hub_url
from pydantic import BaseModel, Field, TypeAdapter
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
@@ -111,7 +111,6 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
full-precision model is returned.
"""
session = session or Session()
configure_http_backend(backend_factory=lambda: session) # used in testing
paths = filter_files([x.path for x in self.files], variant, subfolder, subfolders) # all files in the model

View File

@@ -9,7 +9,13 @@ from invokeai.backend.model_manager.configs.external_api import (
ExternalModelPanelSchema,
ExternalResolutionPreset,
)
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.taxonomy import (
AnyVariant,
BaseModelType,
ModelFormat,
ModelType,
QwenImageVariantType,
)
class StarterModelWithoutDependencies(BaseModel):
@@ -19,6 +25,7 @@ class StarterModelWithoutDependencies(BaseModel):
base: BaseModelType
type: ModelType
format: Optional[ModelFormat] = None
variant: Optional[AnyVariant] = None
is_installed: bool = False
capabilities: ExternalModelCapabilities | None = None
default_settings: ExternalApiModelDefaultSettings | None = None
@@ -659,6 +666,138 @@ cogview4 = StarterModel(
)
# endregion
# region Qwen Image Edit
# Full-precision diffusers pipeline for text-guided image editing.
qwen_image_edit = StarterModel(
    name="Qwen Image Edit 2511",
    base=BaseModelType.QwenImage,
    source="Qwen/Qwen-Image-Edit-2511",
    description="Qwen Image Edit 2511 full diffusers model. Supports text-guided image editing with multiple reference images. (~40GB)",
    type=ModelType.Main,
    variant=QwenImageVariantType.Edit,
)
# GGUF-quantized transformer-only builds of the edit model, from smallest (Q2_K)
# to highest quality (Q8_0).
qwen_image_edit_gguf_q4_k_m = StarterModel(
    name="Qwen Image Edit 2511 (Q4_K_M)",
    base=BaseModelType.QwenImage,
    source="https://huggingface.co/unsloth/Qwen-Image-Edit-2511-GGUF/resolve/main/qwen-image-edit-2511-Q4_K_M.gguf",
    description="Qwen Image Edit 2511 - Q4_K_M quantized transformer. Good quality/size balance. (~13GB)",
    type=ModelType.Main,
    format=ModelFormat.GGUFQuantized,
    variant=QwenImageVariantType.Edit,
)
qwen_image_edit_gguf_q2_k = StarterModel(
    name="Qwen Image Edit 2511 (Q2_K)",
    base=BaseModelType.QwenImage,
    source="https://huggingface.co/unsloth/Qwen-Image-Edit-2511-GGUF/resolve/main/qwen-image-edit-2511-Q2_K.gguf",
    description="Qwen Image Edit 2511 - Q2_K heavily quantized transformer. Smallest size, lower quality. (~7.5GB)",
    type=ModelType.Main,
    format=ModelFormat.GGUFQuantized,
    variant=QwenImageVariantType.Edit,
)
qwen_image_edit_gguf_q6_k = StarterModel(
    name="Qwen Image Edit 2511 (Q6_K)",
    base=BaseModelType.QwenImage,
    source="https://huggingface.co/unsloth/Qwen-Image-Edit-2511-GGUF/resolve/main/qwen-image-edit-2511-Q6_K.gguf",
    description="Qwen Image Edit 2511 - Q6_K quantized transformer. Near-lossless quality. (~17GB)",
    type=ModelType.Main,
    format=ModelFormat.GGUFQuantized,
    variant=QwenImageVariantType.Edit,
)
qwen_image_edit_gguf_q8_0 = StarterModel(
    name="Qwen Image Edit 2511 (Q8_0)",
    base=BaseModelType.QwenImage,
    source="https://huggingface.co/unsloth/Qwen-Image-Edit-2511-GGUF/resolve/main/qwen-image-edit-2511-Q8_0.gguf",
    description="Qwen Image Edit 2511 - Q8_0 quantized transformer. Highest quality quantization. (~22GB)",
    type=ModelType.Main,
    format=ModelFormat.GGUFQuantized,
    variant=QwenImageVariantType.Edit,
)
# Lightning distillation LoRAs for the edit model (few-step generation).
qwen_image_edit_lightning_4step = StarterModel(
    name="Qwen Image Edit Lightning (4-step, bf16)",
    base=BaseModelType.QwenImage,
    source="https://huggingface.co/lightx2v/Qwen-Image-Edit-2511-Lightning/resolve/main/Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors",
    description="Lightning distillation LoRA for Qwen Image Edit — enables generation in just 4 steps. "
    "Settings: Steps=4, CFG=1, Shift Override=3.",
    type=ModelType.LoRA,
)
qwen_image_edit_lightning_8step = StarterModel(
    name="Qwen Image Edit Lightning (8-step, bf16)",
    base=BaseModelType.QwenImage,
    source="https://huggingface.co/lightx2v/Qwen-Image-Edit-2511-Lightning/resolve/main/Qwen-Image-Edit-2511-Lightning-8steps-V1.0-bf16.safetensors",
    description="Lightning distillation LoRA for Qwen Image Edit — enables generation in 8 steps with better quality. "
    "Settings: Steps=8, CFG=1, Shift Override=3.",
    type=ModelType.LoRA,
)
# Qwen Image (txt2img)
# No explicit variant here: the Generate variant is detected at install time.
qwen_image = StarterModel(
    name="Qwen Image 2512",
    base=BaseModelType.QwenImage,
    source="Qwen/Qwen-Image-2512",
    description="Qwen Image 2512 full diffusers model. High-quality text-to-image generation. (~40GB)",
    type=ModelType.Main,
)
# GGUF-quantized transformer-only builds of the txt2img model.
qwen_image_gguf_q4_k_m = StarterModel(
    name="Qwen Image 2512 (Q4_K_M)",
    base=BaseModelType.QwenImage,
    source="https://huggingface.co/unsloth/Qwen-Image-2512-GGUF/resolve/main/qwen-image-2512-Q4_K_M.gguf",
    description="Qwen Image 2512 - Q4_K_M quantized transformer. Good quality/size balance. (~13GB)",
    type=ModelType.Main,
    format=ModelFormat.GGUFQuantized,
)
qwen_image_gguf_q2_k = StarterModel(
    name="Qwen Image 2512 (Q2_K)",
    base=BaseModelType.QwenImage,
    source="https://huggingface.co/unsloth/Qwen-Image-2512-GGUF/resolve/main/qwen-image-2512-Q2_K.gguf",
    description="Qwen Image 2512 - Q2_K heavily quantized transformer. Smallest size, lower quality. (~7.5GB)",
    type=ModelType.Main,
    format=ModelFormat.GGUFQuantized,
)
qwen_image_gguf_q6_k = StarterModel(
    name="Qwen Image 2512 (Q6_K)",
    base=BaseModelType.QwenImage,
    source="https://huggingface.co/unsloth/Qwen-Image-2512-GGUF/resolve/main/qwen-image-2512-Q6_K.gguf",
    description="Qwen Image 2512 - Q6_K quantized transformer. Near-lossless quality. (~17GB)",
    type=ModelType.Main,
    format=ModelFormat.GGUFQuantized,
)
qwen_image_gguf_q8_0 = StarterModel(
    name="Qwen Image 2512 (Q8_0)",
    base=BaseModelType.QwenImage,
    source="https://huggingface.co/unsloth/Qwen-Image-2512-GGUF/resolve/main/qwen-image-2512-Q8_0.gguf",
    description="Qwen Image 2512 - Q8_0 quantized transformer. Highest quality quantization. (~22GB)",
    type=ModelType.Main,
    format=ModelFormat.GGUFQuantized,
)
# Lightning distillation LoRAs for the txt2img model (few-step generation).
qwen_image_lightning_4step = StarterModel(
    name="Qwen Image Lightning (4-step, V2.0, bf16)",
    base=BaseModelType.QwenImage,
    source="https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V2.0-bf16.safetensors",
    description="Lightning distillation LoRA for Qwen Image — enables generation in just 4 steps. "
    "Settings: Steps=4, CFG=1, Shift Override=3.",
    type=ModelType.LoRA,
)
qwen_image_lightning_8step = StarterModel(
    name="Qwen Image Lightning (8-step, V2.0, bf16)",
    base=BaseModelType.QwenImage,
    source="https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V2.0-bf16.safetensors",
    description="Lightning distillation LoRA for Qwen Image — enables generation in 8 steps with better quality. "
    "Settings: Steps=8, CFG=1, Shift Override=3.",
    type=ModelType.LoRA,
)
# endregion
# region SigLIP
siglip = StarterModel(
name="SigLIP - google/siglip-so400m-patch14-384",
@@ -1225,6 +1364,20 @@ STARTER_MODELS: list[StarterModel] = [
flux2_klein_qwen3_4b_encoder,
flux2_klein_qwen3_8b_encoder,
cogview4,
qwen_image_edit,
qwen_image_edit_gguf_q2_k,
qwen_image_edit_gguf_q4_k_m,
qwen_image_edit_gguf_q6_k,
qwen_image_edit_gguf_q8_0,
qwen_image_edit_lightning_4step,
qwen_image_edit_lightning_8step,
qwen_image,
qwen_image_gguf_q2_k,
qwen_image_gguf_q4_k_m,
qwen_image_gguf_q6_k,
qwen_image_gguf_q8_0,
qwen_image_lightning_4step,
qwen_image_lightning_8step,
flux_krea,
flux_krea_quantized,
z_image_turbo,
@@ -1313,6 +1466,19 @@ flux2_klein_bundle: list[StarterModel] = [
flux2_klein_qwen3_4b_encoder,
]
# Curated Qwen Image starter bundle: the edit and generate denoisers with their
# mid/high-quality GGUF quantizations, plus the Lightning step-reduction LoRAs.
qwen_image_bundle: list[StarterModel] = [
    qwen_image_edit,
    qwen_image_edit_gguf_q4_k_m,
    qwen_image_edit_gguf_q8_0,
    qwen_image_edit_lightning_4step,
    qwen_image_edit_lightning_8step,
    qwen_image,
    qwen_image_gguf_q4_k_m,
    qwen_image_gguf_q8_0,
    qwen_image_lightning_4step,
    qwen_image_lightning_8step,
]
anima_bundle: list[StarterModel] = [
anima_preview3,
anima_qwen3_encoder,
@@ -1326,6 +1492,7 @@ STARTER_BUNDLES: dict[str, StarterModelBundle] = {
BaseModelType.Flux: StarterModelBundle(name="FLUX.1 dev", models=flux_bundle),
BaseModelType.Flux2: StarterModelBundle(name="FLUX.2 Klein", models=flux2_klein_bundle),
BaseModelType.ZImage: StarterModelBundle(name="Z-Image Turbo", models=zimage_bundle),
BaseModelType.QwenImage: StarterModelBundle(name="Qwen Image", models=qwen_image_bundle),
BaseModelType.Anima: StarterModelBundle(name="Anima", models=anima_bundle),
}

View File

@@ -54,6 +54,8 @@ class BaseModelType(str, Enum):
"""Indicates the model is associated with Z-Image model architecture, including Z-Image-Turbo."""
External = "external"
"""Indicates the model is hosted by an external provider."""
QwenImage = "qwen-image"
"""Indicates the model is associated with Qwen Image Edit 2511 model architecture."""
Anima = "anima"
"""Indicates the model is associated with Anima model architecture (Cosmos Predict2 DiT + LLM Adapter)."""
Unknown = "unknown"
@@ -148,6 +150,16 @@ class ZImageVariantType(str, Enum):
"""Z-Image Base - undistilled foundation model with full CFG and negative prompt support."""
class QwenImageVariantType(str, Enum):
    """Qwen Image model variants.

    Values are the serialized variant identifier strings ("generate"/"edit").
    """
    Generate = "generate"
    """Qwen Image - text-to-image generation model."""
    Edit = "edit"
    """Qwen Image Edit - image editing model with reference image support."""
class Qwen3VariantType(str, Enum):
"""Qwen3 text encoder variants based on model size."""
@@ -224,8 +236,28 @@ class FluxLoRAFormat(str, Enum):
AnyVariant: TypeAlias = Union[
ModelVariantType, ClipVariantType, FluxVariantType, Flux2VariantType, ZImageVariantType, Qwen3VariantType
ModelVariantType,
ClipVariantType,
FluxVariantType,
Flux2VariantType,
ZImageVariantType,
QwenImageVariantType,
Qwen3VariantType,
]
variant_type_adapter = TypeAdapter[
ModelVariantType | ClipVariantType | FluxVariantType | Flux2VariantType | ZImageVariantType | Qwen3VariantType
](ModelVariantType | ClipVariantType | FluxVariantType | Flux2VariantType | ZImageVariantType | Qwen3VariantType)
ModelVariantType
| ClipVariantType
| FluxVariantType
| Flux2VariantType
| ZImageVariantType
| QwenImageVariantType
| Qwen3VariantType
](
ModelVariantType
| ClipVariantType
| FluxVariantType
| Flux2VariantType
| ZImageVariantType
| QwenImageVariantType
| Qwen3VariantType
)

View File

@@ -0,0 +1,5 @@
# Qwen Image Edit LoRA prefix constants.
# These prefixes are used for key mapping when applying LoRA patches to Qwen Image Edit models.

# Prefix prepended to transformer-layer module paths when converting Qwen Image LoRA
# state dicts into patch layer keys.
QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX = "lora_transformer-"

View File

@@ -0,0 +1,197 @@
"""Qwen Image LoRA conversion utilities.
Qwen Image uses QwenImageTransformer2DModel architecture.
Supports multiple LoRA formats:
- Diffusers/PEFT: transformer_blocks.0.attn.to_k.lora_down.weight
- With prefix: transformer.transformer_blocks.0.attn.to_k.lora_down.weight
- Kohya: lora_unet_transformer_blocks_0_attn_to_k.lora_down.weight (underscores instead of dots)
"""
import re
from typing import Dict
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.qwen_image_lora_constants import (
QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX,
)
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
# Regex for Kohya-format Qwen Image LoRA keys.
# Example: lora_unet_transformer_blocks_0_attn_to_k
# Groups: (block_idx, sub_module_with_underscores)
_KOHYA_KEY_REGEX = re.compile(r"lora_unet_transformer_blocks_(\d+)_(.*)")
# Mapping from Kohya underscore-separated sub-module names to dot-separated model paths.
# The Kohya format uses underscores everywhere, but some underscores are part of the
# module name (e.g., add_k_proj, to_out). We match the longest prefix first.
_KOHYA_MODULE_MAP: list[tuple[str, str]] = [
# Attention projections
("attn_add_k_proj", "attn.add_k_proj"),
("attn_add_q_proj", "attn.add_q_proj"),
("attn_add_v_proj", "attn.add_v_proj"),
("attn_to_add_out", "attn.to_add_out"),
("attn_to_out_0", "attn.to_out.0"),
("attn_to_k", "attn.to_k"),
("attn_to_q", "attn.to_q"),
("attn_to_v", "attn.to_v"),
# Image stream MLP and modulation
("img_mlp_net_0_proj", "img_mlp.net.0.proj"),
("img_mlp_net_2", "img_mlp.net.2"),
("img_mod_1", "img_mod.1"),
# Text stream MLP and modulation
("txt_mlp_net_0_proj", "txt_mlp.net.0.proj"),
("txt_mlp_net_2", "txt_mlp.net.2"),
("txt_mod_1", "txt_mod.1"),
]
def is_state_dict_likely_kohya_qwen_image(state_dict: dict[str | int, torch.Tensor]) -> bool:
"""Check if the state dict uses Kohya-format Qwen Image LoRA keys."""
str_keys = [k for k in state_dict.keys() if isinstance(k, str)]
if not str_keys:
return False
# Check if any key matches the Kohya pattern
return any(k.startswith("lora_unet_transformer_blocks_") for k in str_keys)
def _convert_kohya_key(kohya_layer: str) -> str | None:
"""Convert a Kohya-format layer name to a dot-separated model module path.
Example: lora_unet_transformer_blocks_0_attn_to_k -> transformer_blocks.0.attn.to_k
"""
m = _KOHYA_KEY_REGEX.match(kohya_layer)
if not m:
return None
block_idx = m.group(1)
sub_module = m.group(2)
for kohya_name, model_path in _KOHYA_MODULE_MAP:
if sub_module == kohya_name:
return f"transformer_blocks.{block_idx}.{model_path}"
# Fallback: unknown sub-module, return None so caller can warn/skip
return None
def lora_model_from_qwen_image_state_dict(
    state_dict: Dict[str, torch.Tensor], alpha: float | None = None
) -> ModelPatchRaw:
    """Build a ModelPatchRaw from a Qwen Image LoRA state dict.

    Dispatches on the detected key format. Supported formats:
    - Diffusers/PEFT: transformer_blocks.0.attn.to_k.lora_down.weight
    - With prefix: transformer.transformer_blocks.0.attn.to_k.lora_down.weight
    - Kohya: lora_unet_transformer_blocks_0_attn_to_k.lora_down.weight
    """
    converter = (
        _convert_kohya_format
        if is_state_dict_likely_kohya_qwen_image(state_dict)
        else _convert_diffusers_format
    )
    return converter(state_dict, alpha)
def _convert_kohya_format(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> ModelPatchRaw:
    """Convert a Kohya-format state dict to a ModelPatchRaw.

    Keys look like ``lora_unet_transformer_blocks_0_attn_to_k.<param>`` where
    ``<param>`` may be e.g. ``lora_down.weight``, ``alpha``, or ``lokr_w1``.

    Note: ``alpha`` is accepted for signature parity with ``_convert_diffusers_format``
    but is not applied here — presumably because Kohya layer dicts carry their own
    per-layer ``alpha`` entries (TODO confirm this holds for all Qwen Image LoRAs).
    """
    # Group parameters by layer: "layer_name.param_name" -> grouped[layer_name][param_name].
    grouped: dict[str, dict[str, torch.Tensor]] = {}
    for key, value in state_dict.items():
        # Skip non-string keys and malformed keys with no "." separator (the original
        # split would raise ValueError on the latter).
        if not isinstance(key, str) or "." not in key:
            continue
        layer_name, param_name = key.split(".", 1)
        grouped.setdefault(layer_name, {})[param_name] = value

    layers: dict[str, BaseLayerPatch] = {}
    for kohya_layer, layer_dict in grouped.items():
        model_path = _convert_kohya_key(kohya_layer)
        if model_path is None:
            # Unrecognized sub-module: skip rather than fail the whole load.
            continue
        final_key = f"{QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX}{model_path}"
        layers[final_key] = any_lora_layer_from_state_dict(layer_dict)

    return ModelPatchRaw(layers=layers)
def _convert_diffusers_format(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> ModelPatchRaw:
    """Convert a Diffusers/PEFT-format state dict to a ModelPatchRaw."""
    # Some LoRAs prefix every key with "transformer."; drop it before building the
    # final patch key.
    prefix = "transformer."
    layers: dict[str, BaseLayerPatch] = {}
    for layer_key, layer_dict in _group_by_layer(state_dict).items():
        normalized = _normalize_lora_keys(layer_dict, alpha)
        clean_key = layer_key.removeprefix(prefix)
        layers[f"{QWEN_IMAGE_EDIT_LORA_TRANSFORMER_PREFIX}{clean_key}"] = any_lora_layer_from_state_dict(normalized)
    return ModelPatchRaw(layers=layers)
def _normalize_lora_keys(layer_dict: dict[str, torch.Tensor], alpha: float | None) -> dict[str, torch.Tensor]:
"""Normalize LoRA key names to internal format."""
if "lora_A.weight" in layer_dict:
values: dict[str, torch.Tensor] = {
"lora_down.weight": layer_dict["lora_A.weight"],
"lora_up.weight": layer_dict["lora_B.weight"],
}
if alpha is not None:
values["alpha"] = torch.tensor(alpha)
return values
elif "lora_down.weight" in layer_dict:
return layer_dict
else:
return layer_dict
def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
"""Group state dict keys by layer path."""
layer_dict: dict[str, dict[str, torch.Tensor]] = {}
known_suffixes = [
".lora_A.weight",
".lora_B.weight",
".lora_down.weight",
".lora_up.weight",
".dora_scale",
".alpha",
]
for key in state_dict:
if not isinstance(key, str):
continue
layer_name = None
key_name = None
for suffix in known_suffixes:
if key.endswith(suffix):
layer_name = key[: -len(suffix)]
key_name = suffix[1:]
break
if layer_name is None:
parts = key.rsplit(".", maxsplit=2)
layer_name = parts[0]
key_name = ".".join(parts[1:])
if layer_name not in layer_dict:
layer_dict[layer_name] = {}
layer_dict[layer_name][key_name] = state_dict[key]
return layer_dict

View File

@@ -17,7 +17,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from diffusers.utils.import_utils import is_xformers_available
from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
@@ -139,7 +139,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -151,7 +151,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: Optional[StableDiffusionSafetyChecker],
feature_extractor: Optional[CLIPFeatureExtractor],
feature_extractor: Optional[CLIPImageProcessor],
requires_safety_checker: bool = False,
):
super().__init__(

View File

@@ -88,6 +88,23 @@ class ZImageConditioningInfo:
return self
@dataclass
class QwenImageConditioningInfo:
"""Qwen Image Edit conditioning information from Qwen2.5-VL encoder."""
prompt_embeds: torch.Tensor
"""Text/image embeddings from Qwen2.5-VL encoder. Shape: (batch_size, seq_len, hidden_size)."""
prompt_embeds_mask: torch.Tensor | None = None
"""Attention mask for prompt_embeds. Shape: (batch_size, seq_len). 1 for valid, 0 for padding."""
def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None):
self.prompt_embeds = self.prompt_embeds.to(device=device, dtype=dtype)
if self.prompt_embeds_mask is not None:
self.prompt_embeds_mask = self.prompt_embeds_mask.to(device=device)
return self
@dataclass
class AnimaConditioningInfo:
"""Anima text conditioning information from Qwen3 0.6B encoder + T5-XXL tokenizer.
@@ -125,6 +142,7 @@ class ConditioningFieldData:
| List[SD3ConditioningInfo]
| List[CogView4ConditioningInfo]
| List[ZImageConditioningInfo]
| List[QwenImageConditioningInfo]
| List[AnimaConditioningInfo]
)

View File

@@ -6463,6 +6463,23 @@
"title": "Has Been Opened"
},
"description": "Whether to include/exclude recent workflows"
},
{
"name": "is_public",
"in": "query",
"required": false,
"schema": {
"anyOf": [
{
"type": "boolean"
},
{
"type": "null"
}
],
"title": "Is Public"
},
"description": "Filter by public/shared status"
}
],
"responses": {
@@ -6655,6 +6672,23 @@
"title": "Categories"
},
"description": "The categories to include"
},
{
"name": "is_public",
"in": "query",
"required": false,
"schema": {
"anyOf": [
{
"type": "boolean"
},
{
"type": "null"
}
],
"title": "Is Public"
},
"description": "Filter by public/shared status"
}
],
"responses": {
@@ -6744,6 +6778,23 @@
"title": "Has Been Opened"
},
"description": "Whether to include/exclude recent workflows"
},
{
"name": "is_public",
"in": "query",
"required": false,
"schema": {
"anyOf": [
{
"type": "boolean"
},
{
"type": "null"
}
],
"title": "Is Public"
},
"description": "Filter by public/shared status"
}
],
"responses": {
@@ -6812,6 +6863,23 @@
"title": "Has Been Opened"
},
"description": "Whether to include/exclude recent workflows"
},
{
"name": "is_public",
"in": "query",
"required": false,
"schema": {
"anyOf": [
{
"type": "boolean"
},
{
"type": "null"
}
],
"title": "Is Public"
},
"description": "Filter by public/shared status"
}
],
"responses": {
@@ -7352,6 +7420,67 @@
}
}
}
},
"/api/v1/workflows/i/{workflow_id}/is_public": {
"patch": {
"tags": ["workflows"],
"summary": "Update Workflow Is Public",
"description": "Updates whether a workflow is shared publicly",
"operationId": "update_workflow_is_public",
"parameters": [
{
"name": "workflow_id",
"in": "path",
"required": true,
"schema": {
"type": "string",
"title": "Workflow Id"
},
"description": "The workflow to update"
}
],
"requestBody": {
"content": {
"application/json": {
"schema": {
"properties": {
"is_public": {
"type": "boolean",
"title": "Is Public",
"description": "Whether the workflow should be shared publicly"
}
},
"type": "object",
"required": ["is_public"],
"title": "Body_update_workflow_is_public"
}
}
},
"required": true
},
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/WorkflowRecordDTO"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
}
}
}
},
"components": {
@@ -59137,10 +59266,20 @@
"workflow": {
"$ref": "#/components/schemas/Workflow",
"description": "The workflow."
},
"user_id": {
"type": "string",
"title": "User Id",
"description": "The id of the user who owns this workflow."
},
"is_public": {
"type": "boolean",
"title": "Is Public",
"description": "Whether this workflow is shared with all users."
}
},
"type": "object",
"required": ["workflow_id", "name", "created_at", "updated_at", "workflow"],
"required": ["workflow_id", "name", "created_at", "updated_at", "workflow", "user_id", "is_public"],
"title": "WorkflowRecordDTO"
},
"WorkflowRecordListItemWithThumbnailDTO": {
@@ -59222,15 +59361,35 @@
],
"title": "Thumbnail Url",
"description": "The URL of the workflow thumbnail."
},
"user_id": {
"type": "string",
"title": "User Id",
"description": "The id of the user who owns this workflow."
},
"is_public": {
"type": "boolean",
"title": "Is Public",
"description": "Whether this workflow is shared with all users."
}
},
"type": "object",
"required": ["workflow_id", "name", "created_at", "updated_at", "description", "category", "tags"],
"required": [
"workflow_id",
"name",
"created_at",
"updated_at",
"description",
"category",
"tags",
"user_id",
"is_public"
],
"title": "WorkflowRecordListItemWithThumbnailDTO"
},
"WorkflowRecordOrderBy": {
"type": "string",
"enum": ["created_at", "updated_at", "opened_at", "name"],
"enum": ["created_at", "updated_at", "opened_at", "name", "is_public"],
"title": "WorkflowRecordOrderBy",
"description": "The order by options for workflow records"
},
@@ -59303,10 +59462,20 @@
],
"title": "Thumbnail Url",
"description": "The URL of the workflow thumbnail."
},
"user_id": {
"type": "string",
"title": "User Id",
"description": "The id of the user who owns this workflow."
},
"is_public": {
"type": "boolean",
"title": "Is Public",
"description": "Whether this workflow is shared with all users."
}
},
"type": "object",
"required": ["workflow_id", "name", "created_at", "updated_at", "workflow"],
"required": ["workflow_id", "name", "created_at", "updated_at", "workflow", "user_id", "is_public"],
"title": "WorkflowRecordWithThumbnailDTO"
},
"WorkflowWithoutID": {

View File

@@ -161,7 +161,17 @@
"imagesWithCount_other": "{{count}} images",
"assetsWithCount_one": "{{count}} asset",
"assetsWithCount_other": "{{count}} assets",
"updateBoardError": "Error updating board"
"updateBoardError": "Error updating board",
"setBoardVisibility": "Set Board Visibility",
"setVisibilityPrivate": "Set Private",
"setVisibilityShared": "Set Shared",
"setVisibilityPublic": "Set Public",
"visibilityPrivate": "Private",
"visibilityShared": "Shared",
"visibilityPublic": "Public",
"visibilityBadgeShared": "Shared board",
"visibilityBadgePublic": "Public board",
"updateBoardVisibilityError": "Error updating board visibility"
},
"accordions": {
"generation": {
@@ -1199,7 +1209,9 @@
"numImages": "Num Images",
"modelPickerFallbackNoModelsInstalled": "No models installed.",
"modelPickerFallbackNoModelsInstalled2": "Visit the <LinkComponent>Model Manager</LinkComponent> to install models.",
"modelPickerFallbackNoModelsInstalledNonAdmin": "No models installed. Ask your InvokeAI administrator (<AdminEmailLink />) to install some models.",
"noModelsInstalledDesc1": "Install models with the",
"noModelsInstalledAskAdmin": "Ask your administrator to install some.",
"noModelSelected": "No Model Selected",
"noMatchingModels": "No matching models",
"noModelsInstalled": "No models installed",
@@ -1295,6 +1307,12 @@
"flux2KleinVaePlaceholder": "From main model",
"flux2KleinQwen3Encoder": "Qwen3 Encoder (optional)",
"flux2KleinQwen3EncoderPlaceholder": "From main model",
"qwenImageComponentSource": "VAE/Encoder Source (Diffusers)",
"qwenImageComponentSourcePlaceholder": "Required for GGUF models",
"qwenImageQuantization": "Encoder Quantization",
"qwenImageQuantizationNone": "None (bf16)",
"qwenImageQuantizationInt8": "8-bit (int8)",
"qwenImageQuantizationNf4": "4-bit (nf4)",
"upcastAttention": "Upcast Attention",
"uploadImage": "Upload Image",
"urlOrLocalPath": "URL or Local Path",
@@ -1562,6 +1580,7 @@
"info": "Info",
"invoke": {
"addingImagesTo": "Adding images to",
"boardNotWritable": "You do not have write access to board \"{{boardName}}\". Select a board you own or switch to Uncategorized.",
"modelDisabledForTrial": "Generating with {{modelName}} is not available on trial accounts. Visit your account settings to upgrade.",
"invoke": "Invoke",
"missingFieldTemplate": "Missing field template",
@@ -1588,6 +1607,7 @@
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",
"noCLIPEmbedModelSelected": "No CLIP Embed model selected for FLUX generation",
"noQwen3EncoderModelSelected": "No Qwen3 Encoder model selected for FLUX2 Klein generation",
"noQwenImageComponentSourceSelected": "GGUF Qwen Image models require a Diffusers Component Source for VAE/encoder",
"noZImageVaeSourceSelected": "No VAE source: Select VAE (FLUX) or Qwen3 Source model",
"noZImageQwen3EncoderSourceSelected": "No Qwen3 Encoder source: Select Qwen3 Encoder or Qwen3 Source model",
"noAnimaVaeModelSelected": "No Anima VAE model selected",
@@ -1641,6 +1661,7 @@
"sendToCanvas": "Send To Canvas",
"sendToUpscale": "Send To Upscale",
"showOptionsPanel": "Show Side Panel (O or T)",
"shift": "Shift",
"shuffle": "Shuffle Seed",
"steps": "Steps",
"strength": "Strength",
@@ -2317,6 +2338,8 @@
"tags": "Tags",
"yourWorkflows": "Your Workflows",
"recentlyOpened": "Recently Opened",
"sharedWorkflows": "Shared Workflows",
"shareWorkflow": "Shared workflow",
"noRecentWorkflows": "No Recent Workflows",
"private": "Private",
"shared": "Shared",
@@ -3051,6 +3074,7 @@
"tileOverlap": "Tile Overlap",
"postProcessingMissingModelWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install a post-processing (image to image) model.",
"missingModelsWarning": "Visit the <LinkComponent>Model Manager</LinkComponent> to install the required models:",
"missingModelsWarningNonAdmin": "Ask your InvokeAI administrator (<AdminEmailLink />) to install the required models:",
"mainModelDesc": "Main model (SD1.5 or SDXL architecture)",
"tileControlNetModelDesc": "Tile ControlNet model for the chosen main model architecture",
"upscaleModelDesc": "Upscale (image to image) model",
@@ -3159,6 +3183,7 @@
},
"workflows": {
"description": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results.",
"descriptionMultiuser": "Workflows are reusable templates that automate image generation tasks, allowing you to quickly perform complex operations and get consistent results. You may share your workflows with other users of the system by selecting 'Shared workflow' when you create or edit it.",
"learnMoreLink": "Learn more about creating workflows",
"browseTemplates": {
"title": "Browse Workflow Templates",
@@ -3237,9 +3262,11 @@
"toGetStartedLocal": "To get started, make sure to download or import models needed to run Invoke. Then, enter a prompt in the box and click <StrongComponent>Invoke</StrongComponent> to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the <StrongComponent>Gallery</StrongComponent> or edit them to the <StrongComponent>Canvas</StrongComponent>.",
"toGetStarted": "To get started, enter a prompt in the box and click <StrongComponent>Invoke</StrongComponent> to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the <StrongComponent>Gallery</StrongComponent> or edit them to the <StrongComponent>Canvas</StrongComponent>.",
"toGetStartedWorkflow": "To get started, fill in the fields on the left and press <StrongComponent>Invoke</StrongComponent> to generate your image. Want to explore more workflows? Click the <StrongComponent>folder icon</StrongComponent> next to the workflow title to see a list of other templates you can try.",
"toGetStartedNonAdmin": "To get started, ask your InvokeAI administrator (<AdminEmailLink />) to install the AI models needed to run Invoke. Then, enter a prompt in the box and click <StrongComponent>Invoke</StrongComponent> to generate your first image. Select a prompt template to improve results. You can choose to save your images directly to the <StrongComponent>Gallery</StrongComponent> or edit them to the <StrongComponent>Canvas</StrongComponent>.",
"gettingStartedSeries": "Want more guidance? Check out our <LinkComponent>Getting Started Series</LinkComponent> for tips on unlocking the full potential of the Invoke Studio.",
"lowVRAMMode": "For best performance, follow our <LinkComponent>Low VRAM guide</LinkComponent>.",
"noModelsInstalled": "It looks like you don't have any models installed! You can <DownloadStarterModelsButton>download a starter model bundle</DownloadStarterModelsButton> or <ImportModelsButton>import models</ImportModelsButton>."
"noModelsInstalled": "It looks like you don't have any models installed! You can <DownloadStarterModelsButton>download a starter model bundle</DownloadStarterModelsButton> or <ImportModelsButton>import models</ImportModelsButton>.",
"noModelsInstalledAskAdmin": "Ask your administrator to install some."
},
"whatsNew": {
"whatsNewInInvoke": "What's New in Invoke",

View File

@@ -12,10 +12,14 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
effect: (action) => {
log.debug(action.payload, 'Bulk download requested');
// If we have an item name, we are processing the bulk download locally and should use it as the toast id to
// prevent multiple toasts for the same item.
// Use a "preparing:" prefix so this toast cannot collide with the
// "ready to download" toast that arrives via the bulk_download_complete
// socket event. The background task can complete in under 20ms, so the
// socket event may arrive *before* this Redux middleware runs — without
// distinct IDs the "preparing" toast would overwrite the "ready" toast.
const itemName = action.payload.bulk_download_item_name;
toast({
id: action.payload.bulk_download_item_name ?? undefined,
id: itemName ? `preparing:${itemName}` : undefined,
title: t('gallery.bulkDownloadRequested'),
status: 'success',
// Show the response message if it exists, otherwise show the default message

View File

@@ -11,6 +11,7 @@ import {
kleinQwen3EncoderModelSelected,
kleinVaeModelSelected,
modelChanged,
qwenImageComponentSourceSelected,
resolutionPresetSelected,
setZImageScheduler,
syncedToOptimalDimension,
@@ -29,12 +30,18 @@ import {
selectBboxModelBase,
selectCanvasSlice,
} from 'features/controlLayers/store/selectors';
import { getEntityIdentifier, isAspectRatioID, isFlux2ReferenceImageConfig } from 'features/controlLayers/store/types';
import {
getEntityIdentifier,
isAspectRatioID,
isFlux2ReferenceImageConfig,
isQwenImageReferenceImageConfig,
} from 'features/controlLayers/store/types';
import {
initialFlux2ReferenceImage,
initialFluxKontextReferenceImage,
initialFLUXRedux,
initialIPAdapter,
initialQwenImageReferenceImage,
} from 'features/controlLayers/store/util';
import { SUPPORTS_REF_IMAGES_BASE_MODELS } from 'features/modelManagerV2/models';
import { zModelIdentifierField } from 'features/nodes/types/common';
@@ -49,6 +56,7 @@ import {
selectFluxVAEModels,
selectGlobalRefImageModels,
selectQwen3EncoderModels,
selectQwenImageDiffusersModels,
selectRegionalRefImageModels,
selectT5EncoderModels,
selectZImageDiffusersModels,
@@ -238,6 +246,44 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
}
// handle incompatible Qwen Image Edit component source - clear if switching away
const { qwenImageComponentSource } = state.params;
if (newBase !== 'qwen-image') {
if (qwenImageComponentSource) {
dispatch(qwenImageComponentSourceSelected(null));
modelsUpdatedDisabledOrCleared += 1;
}
} else {
// Switching to Qwen Image - auto-default component source to a matching diffusers model
if (!qwenImageComponentSource) {
const availableQwenImageDiffusers = selectQwenImageDiffusersModels(state);
// Look up the new model's variant to match generate vs edit
const modelConfigsResult = selectModelConfigsQuery(state);
let selectedVariant: string | null = null;
if (modelConfigsResult.data) {
const newModelConfig = modelConfigsAdapterSelectors.selectById(modelConfigsResult.data, newModel.key);
if (newModelConfig && 'variant' in newModelConfig && typeof newModelConfig.variant === 'string') {
selectedVariant = newModelConfig.variant;
}
}
// Find a diffusers model matching the variant; if no variant on denoiser, prefer "generate" then "edit"
const variantToMatch = selectedVariant ?? 'generate';
const matchingModel = availableQwenImageDiffusers.find(
(m) => 'variant' in m && m.variant === variantToMatch
);
const fallbackModel = availableQwenImageDiffusers.find(
(m) => 'variant' in m && m.variant !== variantToMatch
);
const diffusersModel = matchingModel ?? fallbackModel ?? availableQwenImageDiffusers[0];
if (diffusersModel) {
dispatch(qwenImageComponentSourceSelected(zModelIdentifierField.parse(diffusersModel)));
}
}
}
if (newModel.base !== 'external' && SUPPORTS_REF_IMAGES_BASE_MODELS.includes(newModel.base)) {
// Handle incompatible reference image models - switch to first compatible model, with some smart logic
// to choose the best available model based on the new main model.
@@ -280,6 +326,20 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
continue;
}
if (newBase === 'qwen-image') {
// Switching TO Qwen Image Edit - convert any non-qwen configs to qwen_image_reference_image
if (!isQwenImageReferenceImageConfig(entity.config)) {
dispatch(
refImageConfigChanged({
id: entity.id,
config: { ...initialQwenImageReferenceImage },
})
);
modelsUpdatedDisabledOrCleared += 1;
}
continue;
}
if (isFlux2ReferenceImageConfig(entity.config)) {
// Switching AWAY from FLUX.2 - convert flux2_reference_image to the appropriate config type
let newConfig;
@@ -304,6 +364,30 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
continue;
}
if (isQwenImageReferenceImageConfig(entity.config)) {
// Switching AWAY from Qwen Image Edit - convert to the appropriate config type
let newConfig;
if (newGlobalRefImageModel) {
const parsedModel = zModelIdentifierField.parse(newGlobalRefImageModel);
if (newModel.base === 'flux' && newModel.name.toLowerCase().includes('kontext')) {
newConfig = { ...initialFluxKontextReferenceImage, model: parsedModel };
} else if (newGlobalRefImageModel.type === 'flux_redux') {
newConfig = { ...initialFLUXRedux, model: parsedModel };
} else {
newConfig = { ...initialIPAdapter, model: parsedModel };
if (parsedModel.base === 'flux') {
newConfig.clipVisionModel = 'ViT-L';
}
}
} else {
// No compatible model found - fall back to an empty IP adapter config
newConfig = { ...initialIPAdapter };
}
dispatch(refImageConfigChanged({ id: entity.id, config: newConfig }));
modelsUpdatedDisabledOrCleared += 1;
continue;
}
// Standard handling for non-flux2 configs
const shouldUpdateModel =
(entity.config.model && entity.config.model.base !== newBase) ||
@@ -391,6 +475,32 @@ export const addModelSelectedListener = (startAppListening: AppStartListening) =
}
}
// Handle Qwen Image model changes within the same base (variant may change between generate/edit)
// Auto-update the component source diffusers model to match the new variant
if (
newBase === 'qwen-image' &&
state.params.model?.base === 'qwen-image' &&
newModel.key !== state.params.model?.key
) {
const modelConfigsResult = selectModelConfigsQuery(state);
if (modelConfigsResult.data) {
const newModelConfig = modelConfigsAdapterSelectors.selectById(modelConfigsResult.data, newModel.key);
const newVariant =
newModelConfig && 'variant' in newModelConfig && typeof newModelConfig.variant === 'string'
? newModelConfig.variant
: 'generate';
const availableQwenImageDiffusers = selectQwenImageDiffusersModels(state);
const matchingModel = availableQwenImageDiffusers.find((m) => 'variant' in m && m.variant === newVariant);
const fallbackModel = availableQwenImageDiffusers.find((m) => 'variant' in m && m.variant !== newVariant);
const diffusersModel = matchingModel ?? fallbackModel ?? availableQwenImageDiffusers[0];
if (diffusersModel) {
dispatch(qwenImageComponentSourceSelected(zModelIdentifierField.parse(diffusersModel)));
}
}
}
// Handle Z-Image scheduler when switching to Z-Image Base (zbase) model
// LCM is not supported for undistilled models, so reset to euler
if (newBase === 'z-image' && state.params.zImageScheduler === 'lcm') {

View File

@@ -867,7 +867,7 @@ const GroupToggleButtons = typedMemo(<T extends object>() => {
}
return (
<Flex gap={2} alignItems="center">
<Flex gap={2} alignItems="center" flexWrap="wrap">
{groups.map((group) => (
<GroupToggleButton key={group.id} group={group} />
))}
@@ -927,6 +927,7 @@ const GroupToggleButton = typedMemo(<T extends object>({ group }: { group: Group
size="xs"
variant="solid"
userSelect="none"
flexShrink={0}
bg={bg}
color={color}
borderColor={groupColor}

View File

@@ -3,6 +3,7 @@ import { Combobox, ConfirmationAlertDialog, Flex, FormControl, Text } from '@inv
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { selectCurrentUser } from 'features/auth/store/authSlice';
import {
changeBoardReset,
isModalOpenChanged,
@@ -13,6 +14,7 @@ import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useListAllBoardsQuery } from 'services/api/endpoints/boards';
import { useAddImagesToBoardMutation, useRemoveImagesFromBoardMutation } from 'services/api/endpoints/images';
import type { BoardDTO } from 'services/api/types';
const selectImagesToChange = createSelector(
selectChangeBoardModalSlice,
@@ -28,6 +30,7 @@ const ChangeBoardModal = () => {
useAssertSingleton('ChangeBoardModal');
const dispatch = useAppDispatch();
const currentBoardId = useAppSelector(selectSelectedBoardId);
const currentUser = useAppSelector(selectCurrentUser);
const [selectedBoardId, setSelectedBoardId] = useState<string | null>();
const { data: boards, isFetching } = useListAllBoardsQuery({ include_archived: true });
const isModalOpen = useAppSelector(selectIsModalOpen);
@@ -36,10 +39,20 @@ const ChangeBoardModal = () => {
const [removeImagesFromBoard] = useRemoveImagesFromBoardMutation();
const { t } = useTranslation();
// Returns true if the current user can write images to the given board.
// Write access is granted when: there is no authenticated user (single-user mode),
// the user is an admin, the user owns the board, or the board is public.
// 'shared' boards are therefore read-only for non-owner, non-admin users.
const canWriteToBoard = useCallback(
  (board: BoardDTO): boolean => {
    // NOTE: despite the name, this is also true in single-user mode (!currentUser).
    const isOwnerOrAdmin = !currentUser || currentUser.is_admin || board.user_id === currentUser.user_id;
    return isOwnerOrAdmin || board.board_visibility === 'public';
  },
  [currentUser]
);
const options = useMemo<ComboboxOption[]>(() => {
return [{ label: t('boards.uncategorized'), value: 'none' }]
.concat(
(boards ?? [])
.filter(canWriteToBoard)
.map((board) => ({
label: board.board_name,
value: board.board_id,
@@ -47,7 +60,7 @@ const ChangeBoardModal = () => {
.sort((a, b) => a.label.localeCompare(b.label))
)
.filter((board) => board.value !== currentBoardId);
}, [boards, currentBoardId, t]);
}, [boards, canWriteToBoard, currentBoardId, t]);
const value = useMemo(() => options.find((o) => o.value === selectedBoardId), [options, selectedBoardId]);

View File

@@ -34,7 +34,12 @@ import type {
FLUXReduxImageInfluence as FLUXReduxImageInfluenceType,
IPMethodV2,
} from 'features/controlLayers/store/types';
import { isFlux2ReferenceImageConfig, isFLUXReduxConfig, isIPAdapterConfig } from 'features/controlLayers/store/types';
import {
isFlux2ReferenceImageConfig,
isFLUXReduxConfig,
isIPAdapterConfig,
isQwenImageReferenceImageConfig,
} from 'features/controlLayers/store/types';
import type { SetGlobalReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setGlobalReferenceImageDndTarget } from 'features/dnd/dnd';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
@@ -124,8 +129,9 @@ const RefImageSettingsContent = memo(() => {
const isFLUX = useAppSelector(selectIsFLUX);
const isExternalModel = !!mainModelConfig && isExternalApiModelConfig(mainModelConfig);
// FLUX.2 Klein and external API models do not require a ref image model selection.
const showModelSelector = !isFlux2ReferenceImageConfig(config) && !isExternalModel;
// FLUX.2 Klein, Qwen Image Edit and external API models do not require a ref image model selection.
const showModelSelector =
!isFlux2ReferenceImageConfig(config) && !isQwenImageReferenceImageConfig(config) && !isExternalModel;
return (
<Flex flexDir="column" gap={2} position="relative" w="full">

View File

@@ -29,6 +29,7 @@ import type {
Flux2ReferenceImageConfig,
FluxKontextReferenceImageConfig,
IPAdapterConfig,
QwenImageReferenceImageConfig,
RegionalGuidanceIPAdapterConfig,
T2IAdapterConfig,
} from 'features/controlLayers/store/types';
@@ -37,6 +38,7 @@ import {
initialFlux2ReferenceImage,
initialFluxKontextReferenceImage,
initialIPAdapter,
initialQwenImageReferenceImage,
initialRegionalGuidanceIPAdapter,
initialT2IAdapter,
} from 'features/controlLayers/store/util';
@@ -78,7 +80,7 @@ export const selectDefaultControlAdapter = createSelector(
export const getDefaultRefImageConfig = (
getState: AppGetState
): IPAdapterConfig | FluxKontextReferenceImageConfig | Flux2ReferenceImageConfig => {
): IPAdapterConfig | FluxKontextReferenceImageConfig | Flux2ReferenceImageConfig | QwenImageReferenceImageConfig => {
const state = getState();
const mainModelConfig = selectMainModelConfig(state);
@@ -91,6 +93,11 @@ export const getDefaultRefImageConfig = (
return deepClone(initialFlux2ReferenceImage);
}
// Qwen Image Edit has built-in reference image support - no model needed
if (base === 'qwen-image') {
return deepClone(initialQwenImageReferenceImage);
}
if (base === 'flux' && mainModelConfig?.name?.toLowerCase().includes('kontext')) {
const config = deepClone(initialFluxKontextReferenceImage);
config.model = zModelIdentifierField.parse(mainModelConfig);

View File

@@ -261,6 +261,19 @@ const slice = createSlice({
}
state.kleinQwen3EncoderModel = result.data;
},
qwenImageComponentSourceSelected: (state, action: PayloadAction<ParameterModel | null>) => {
const result = zParamsState.shape.qwenImageComponentSource.safeParse(action.payload);
if (!result.success) {
return;
}
state.qwenImageComponentSource = result.data;
},
qwenImageQuantizationChanged: (state, action: PayloadAction<'none' | 'int8' | 'nf4'>) => {
state.qwenImageQuantization = action.payload;
},
qwenImageShiftChanged: (state, action: PayloadAction<number | null>) => {
state.qwenImageShift = action.payload;
},
vaePrecisionChanged: (state, action: PayloadAction<ParameterPrecision>) => {
state.vaePrecision = action.payload;
},
@@ -566,6 +579,9 @@ const resetState = (state: ParamsState): ParamsState => {
newState.animaT5EncoderModel = oldState.animaT5EncoderModel;
newState.kleinVaeModel = oldState.kleinVaeModel;
newState.kleinQwen3EncoderModel = oldState.kleinQwen3EncoderModel;
newState.qwenImageComponentSource = oldState.qwenImageComponentSource;
newState.qwenImageQuantization = oldState.qwenImageQuantization;
newState.qwenImageShift = oldState.qwenImageShift;
return newState;
};
@@ -613,6 +629,9 @@ export const {
zImageQwen3SourceModelSelected,
kleinVaeModelSelected,
kleinQwen3EncoderModelSelected,
qwenImageComponentSourceSelected,
qwenImageQuantizationChanged,
qwenImageShiftChanged,
setClipSkip,
shouldUseCpuNoiseChanged,
setColorCompensation,
@@ -691,6 +710,7 @@ export const selectIsZImage = createParamsSelector((params) => params.model?.bas
export const selectIsAnima = createParamsSelector((params) => params.model?.base === 'anima');
export const selectIsFlux2 = createParamsSelector((params) => params.model?.base === 'flux2');
export const selectIsExternal = createParamsSelector((params) => params.model?.base === 'external');
export const selectIsQwenImage = createParamsSelector((params) => params.model?.base === 'qwen-image');
export const selectIsFluxKontext = createParamsSelector((params) => {
if (params.model?.base === 'flux' && params.model?.name.toLowerCase().includes('kontext')) {
return true;
@@ -717,6 +737,9 @@ export const selectAnimaT5EncoderModel = createParamsSelector((params) => params
export const selectAnimaScheduler = createParamsSelector((params) => params.animaScheduler);
export const selectKleinVaeModel = createParamsSelector((params) => params.kleinVaeModel);
export const selectKleinQwen3EncoderModel = createParamsSelector((params) => params.kleinQwen3EncoderModel);
export const selectQwenImageComponentSource = createParamsSelector((params) => params.qwenImageComponentSource);
export const selectQwenImageQuantization = createParamsSelector((params) => params.qwenImageQuantization);
export const selectQwenImageShift = createParamsSelector((params) => params.qwenImageShift);
export const selectCFGScale = createParamsSelector((params) => params.cfgScale);
export const selectGuidance = createParamsSelector((params) => params.guidance);

View File

@@ -22,6 +22,7 @@ import {
isFlux2ReferenceImageConfig,
isFLUXReduxConfig,
isIPAdapterConfig,
isQwenImageReferenceImageConfig,
zRefImagesState,
} from './types';
import { getReferenceImageState, initialFluxKontextReferenceImage, initialFLUXRedux, initialIPAdapter } from './util';
@@ -106,8 +107,8 @@ const slice = createSlice({
return;
}
// FLUX.2 reference images don't have a model field - they use built-in support
if (isFlux2ReferenceImageConfig(entity.config)) {
// FLUX.2 and Qwen Image Edit reference images don't have a model field - they use built-in support
if (isFlux2ReferenceImageConfig(entity.config) || isQwenImageReferenceImageConfig(entity.config)) {
return;
}

View File

@@ -370,6 +370,13 @@ const zFlux2ReferenceImageConfig = z.object({
});
export type Flux2ReferenceImageConfig = z.infer<typeof zFlux2ReferenceImageConfig>;
// Qwen Image Edit has built-in reference image support - no separate model needed
const zQwenImageReferenceImageConfig = z.object({
  // Discriminator; matched by the isQwenImageReferenceImageConfig type guard.
  type: z.literal('qwen_image_reference_image'),
  // The reference image itself; nullable so the entity can exist before an image is chosen.
  // Unlike IP-Adapter / FLUX Redux configs there is no `model` field — Qwen Image Edit
  // consumes the reference image natively.
  image: zCroppableImageWithDims.nullable(),
});
export type QwenImageReferenceImageConfig = z.infer<typeof zQwenImageReferenceImageConfig>;
const zCanvasEntityBase = z.object({
id: zId,
name: zName,
@@ -385,6 +392,7 @@ export const zRefImageState = z.object({
zFLUXReduxConfig,
zFluxKontextReferenceImageConfig,
zFlux2ReferenceImageConfig,
zQwenImageReferenceImageConfig,
]),
});
export type RefImageState = z.infer<typeof zRefImageState>;
@@ -402,6 +410,10 @@ export const isFluxKontextReferenceImageConfig = (
export const isFlux2ReferenceImageConfig = (config: RefImageState['config']): config is Flux2ReferenceImageConfig =>
config.type === 'flux2_reference_image';
/** Type guard: narrows a ref-image config union member to the Qwen Image Edit variant. */
export const isQwenImageReferenceImageConfig = (
  config: RefImageState['config']
): config is QwenImageReferenceImageConfig => {
  return config.type === 'qwen_image_reference_image';
};
const zFillStyle = z.enum(['solid', 'grid', 'crosshatch', 'diagonal', 'horizontal', 'vertical']);
export type FillStyle = z.infer<typeof zFillStyle>;
export const isFillStyle = (v: unknown): v is FillStyle => zFillStyle.safeParse(v).success;
@@ -782,6 +794,10 @@ export const zParamsState = z.object({
// Flux2 Klein model components - uses Qwen3 instead of CLIP+T5
kleinVaeModel: zParameterVAEModel.nullable(), // Optional: Separate FLUX.2 VAE for Klein
kleinQwen3EncoderModel: zModelIdentifierField.nullable(), // Optional: Separate Qwen3 Encoder for Klein
// Qwen Image Edit model components - GGUF transformer needs a Diffusers source for VAE/encoder
qwenImageComponentSource: zParameterModel.nullable(), // Diffusers model providing VAE + text encoder
qwenImageQuantization: z.enum(['none', 'int8', 'nf4']), // BitsAndBytes quantization for Qwen VL encoder
qwenImageShift: z.number().nullable(), // Sigma schedule shift override (e.g. 3.0 for Lightning LoRAs)
// Z-Image Seed Variance Enhancer settings
zImageSeedVarianceEnabled: z.boolean(),
zImageSeedVarianceStrength: z.number().min(0).max(2),
@@ -859,6 +875,9 @@ export const getInitialParamsState = (): ParamsState => ({
animaScheduler: 'euler',
kleinVaeModel: null,
kleinQwen3EncoderModel: null,
qwenImageComponentSource: null,
qwenImageQuantization: 'none' as const,
qwenImageShift: null,
zImageSeedVarianceEnabled: false,
zImageSeedVarianceStrength: 0.1,
zImageSeedVarianceRandomizePercent: 50,

View File

@@ -15,6 +15,7 @@ import type {
FLUXReduxConfig,
ImageWithDims,
IPAdapterConfig,
QwenImageReferenceImageConfig,
RasterLayerAdjustments,
RefImageState,
RegionalGuidanceIPAdapterConfig,
@@ -117,6 +118,10 @@ export const initialFlux2ReferenceImage: Flux2ReferenceImageConfig = {
type: 'flux2_reference_image',
image: null,
};
// Default state for a Qwen Image Edit reference image entity.
// No `model` field: Qwen Image Edit has built-in reference-image support
// (same shape as initialFlux2ReferenceImage above).
export const initialQwenImageReferenceImage: QwenImageReferenceImageConfig = {
  type: 'qwen_image_reference_image',
  image: null,
};
export const initialT2IAdapter: T2IAdapterConfig = {
type: 't2i_adapter',
model: null,

View File

@@ -147,8 +147,8 @@ export const getGlobalReferenceImageWarnings = (
const { config } = entity;
// FLUX.2 reference images don't require a model - it's built-in
if (config.type !== 'flux2_reference_image') {
// FLUX.2 and Qwen Image Edit reference images don't require a model - it's built-in
if (config.type !== 'flux2_reference_image' && config.type !== 'qwen_image_reference_image') {
if (!('model' in config) || !config.model) {
// No model selected
warnings.push(WARNINGS.IP_ADAPTER_NO_MODEL_SELECTED);
@@ -159,8 +159,10 @@ export const getGlobalReferenceImageWarnings = (
}
if (!entity.config.image) {
// No image selected
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
// No image selected - for Qwen Image Edit, an image is optional (txt2img works without one)
if (config.type !== 'qwen_image_reference_image') {
warnings.push(WARNINGS.IP_ADAPTER_NO_IMAGE_SELECTED);
}
}
}

View File

@@ -434,6 +434,49 @@ export const replaceCanvasEntityObjectsWithImageDndTarget: DndTarget<
//#endregion
//#region Add To Board
/**
 * Determine whether the current user may move images out of the given source board.
 *
 * Moving is blocked only when the source is a shared board that the user neither
 * owns nor administers — such boards are view/use-only for non-owners. In every
 * ambiguous case (no auth, uncategorized, board not found in cache) we err on the
 * side of allowing the move rather than blocking a legitimate operation.
 */
const canMoveFromSourceBoard = (sourceBoardId: BoardId, getState: AppGetState): boolean => {
  const state = getState();
  const user = state.auth?.user;
  // No authenticated user means single-user mode; admins are never restricted.
  if (!user || user.is_admin) {
    return true;
  }
  // 'none' is the user's own Uncategorized bucket — always movable.
  if (sourceBoardId === 'none') {
    return true;
  }
  // Search the RTK Query cache for any boards-list query containing the source board.
  const cachedQueries = state.api?.queries;
  if (cachedQueries) {
    for (const cached of Object.values(cachedQueries)) {
      const data = cached?.data;
      if (!data || !Array.isArray(data)) {
        continue;
      }
      const boards = data as Array<{ board_id: string; user_id?: string; board_visibility?: string }>;
      const found = boards.find((b) => b.board_id === sourceBoardId);
      if (found) {
        // Owners may always move; otherwise only public boards permit it.
        return found.user_id === user.user_id || found.board_visibility === 'public';
      }
    }
  }
  // Board not present in any cached query — allow by default.
  return true;
};
const _addToBoard = buildTypeAndKey('add-to-board');
export type AddImageToBoardDndTargetData = DndData<
typeof _addToBoard.type,
@@ -447,16 +490,23 @@ export const addImageToBoardDndTarget: DndTarget<
..._addToBoard,
typeGuard: buildTypeGuard(_addToBoard.key),
getData: buildGetData(_addToBoard.key, _addToBoard.type),
isValid: ({ sourceData, targetData }) => {
isValid: ({ sourceData, targetData, getState }) => {
if (singleImageDndSource.typeGuard(sourceData)) {
const currentBoard = sourceData.payload.imageDTO.board_id ?? 'none';
const destinationBoard = targetData.payload.boardId;
return currentBoard !== destinationBoard;
if (currentBoard === destinationBoard) {
return false;
}
// Don't allow moving images from shared boards the user doesn't own
return canMoveFromSourceBoard(currentBoard, getState);
}
if (multipleImageDndSource.typeGuard(sourceData)) {
const currentBoard = sourceData.payload.board_id;
const destinationBoard = targetData.payload.boardId;
return currentBoard !== destinationBoard;
if (currentBoard === destinationBoard) {
return false;
}
return canMoveFromSourceBoard(currentBoard, getState);
}
return false;
},
@@ -491,15 +541,22 @@ export const removeImageFromBoardDndTarget: DndTarget<
..._removeFromBoard,
typeGuard: buildTypeGuard(_removeFromBoard.key),
getData: buildGetData(_removeFromBoard.key, _removeFromBoard.type),
isValid: ({ sourceData }) => {
isValid: ({ sourceData, getState }) => {
if (singleImageDndSource.typeGuard(sourceData)) {
const currentBoard = sourceData.payload.imageDTO.board_id ?? 'none';
return currentBoard !== 'none';
if (currentBoard === 'none') {
return false;
}
// Don't allow removing images from shared boards the user doesn't own
return canMoveFromSourceBoard(currentBoard, getState);
}
if (multipleImageDndSource.typeGuard(sourceData)) {
const currentBoard = sourceData.payload.board_id;
return currentBoard !== 'none';
if (currentBoard === 'none') {
return false;
}
return canMoveFromSourceBoard(currentBoard, getState);
}
return false;

View File

@@ -2,15 +2,26 @@ import type { ContextMenuProps } from '@invoke-ai/ui-library';
import { ContextMenu, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectCurrentUser } from 'features/auth/store/authSlice';
import { $boardToDelete } from 'features/gallery/components/Boards/DeleteBoardModal';
import { selectAutoAddBoardId, selectAutoAssignBoardOnClick } from 'features/gallery/store/gallerySelectors';
import { autoAddBoardIdChanged } from 'features/gallery/store/gallerySlice';
import { toast } from 'features/toast/toast';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArchiveBold, PiArchiveFill, PiDownloadBold, PiPlusBold, PiTrashSimpleBold } from 'react-icons/pi';
import {
PiArchiveBold,
PiArchiveFill,
PiDownloadBold,
PiGlobeBold,
PiLockBold,
PiPlusBold,
PiShareNetworkBold,
PiTrashSimpleBold,
} from 'react-icons/pi';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import { useBulkDownloadImagesMutation } from 'services/api/endpoints/images';
import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
import { useBoardName } from 'services/api/hooks/useBoardName';
import type { BoardDTO } from 'services/api/types';
@@ -23,6 +34,7 @@ const BoardContextMenu = ({ board, children }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const autoAssignBoardOnClick = useAppSelector(selectAutoAssignBoardOnClick);
const currentUser = useAppSelector(selectCurrentUser);
const selectIsSelectedForAutoAdd = useMemo(
() => createSelector(selectAutoAddBoardId, (autoAddBoardId) => board.board_id === autoAddBoardId),
[board.board_id]
@@ -35,6 +47,11 @@ const BoardContextMenu = ({ board, children }: Props) => {
const [bulkDownload] = useBulkDownloadImagesMutation();
// Only the board owner or admin can modify visibility
const canChangeVisibility = currentUser !== null && (currentUser.is_admin || board.user_id === currentUser.user_id);
const { canDeleteBoard } = useBoardAccess(board);
const handleSetAutoAdd = useCallback(() => {
dispatch(autoAddBoardIdChanged(board.board_id));
}, [board.board_id, dispatch]);
@@ -64,6 +81,26 @@ const BoardContextMenu = ({ board, children }: Props) => {
});
}, [board.board_id, updateBoard]);
const handleSetVisibility = useCallback(
async (visibility: 'private' | 'shared' | 'public') => {
try {
await updateBoard({
board_id: board.board_id,
changes: { board_visibility: visibility },
}).unwrap();
} catch {
toast({ status: 'error', title: t('boards.updateBoardVisibilityError') });
}
},
[board.board_id, t, updateBoard]
);
const handleSetVisibilityPrivate = useCallback(() => handleSetVisibility('private'), [handleSetVisibility]);
const handleSetVisibilityShared = useCallback(() => handleSetVisibility('shared'), [handleSetVisibility]);
const handleSetVisibilityPublic = useCallback(() => handleSetVisibility('public'), [handleSetVisibility]);
const setAsBoardToDelete = useCallback(() => {
$boardToDelete.set(board);
}, [board]);
@@ -83,18 +120,50 @@ const BoardContextMenu = ({ board, children }: Props) => {
</MenuItem>
{board.archived && (
<MenuItem icon={<PiArchiveBold />} onClick={handleUnarchive}>
<MenuItem icon={<PiArchiveBold />} onClick={handleUnarchive} isDisabled={!canDeleteBoard}>
{t('boards.unarchiveBoard')}
</MenuItem>
)}
{!board.archived && (
<MenuItem icon={<PiArchiveFill />} onClick={handleArchive}>
<MenuItem icon={<PiArchiveFill />} onClick={handleArchive} isDisabled={!canDeleteBoard}>
{t('boards.archiveBoard')}
</MenuItem>
)}
<MenuItem color="error.300" icon={<PiTrashSimpleBold />} onClick={setAsBoardToDelete} isDestructive>
{canChangeVisibility && (
<>
<MenuItem
icon={<PiLockBold />}
onClick={handleSetVisibilityPrivate}
isDisabled={board.board_visibility === 'private'}
>
{t('boards.setVisibilityPrivate')}
</MenuItem>
<MenuItem
icon={<PiShareNetworkBold />}
onClick={handleSetVisibilityShared}
isDisabled={board.board_visibility === 'shared'}
>
{t('boards.setVisibilityShared')}
</MenuItem>
<MenuItem
icon={<PiGlobeBold />}
onClick={handleSetVisibilityPublic}
isDisabled={board.board_visibility === 'public'}
>
{t('boards.setVisibilityPublic')}
</MenuItem>
</>
)}
<MenuItem
color="error.300"
icon={<PiTrashSimpleBold />}
onClick={setAsBoardToDelete}
isDestructive
isDisabled={!canDeleteBoard}
>
{t('boards.deleteBoard')}
</MenuItem>
</MenuGroup>
@@ -108,8 +177,14 @@ const BoardContextMenu = ({ board, children }: Props) => {
t,
handleBulkDownload,
board.archived,
board.board_visibility,
handleUnarchive,
handleArchive,
canChangeVisibility,
handleSetVisibilityPrivate,
handleSetVisibilityShared,
handleSetVisibilityPublic,
canDeleteBoard,
setAsBoardToDelete,
]
);

View File

@@ -7,6 +7,7 @@ import { memo, useCallback, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPencilBold } from 'react-icons/pi';
import { useUpdateBoardMutation } from 'services/api/endpoints/boards';
import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
import type { BoardDTO } from 'services/api/types';
type Props = {
@@ -19,6 +20,7 @@ export const BoardEditableTitle = memo(({ board, isSelected }: Props) => {
const isHovering = useBoolean(false);
const inputRef = useRef<HTMLInputElement>(null);
const [updateBoard, updateBoardResult] = useUpdateBoardMutation();
const { canRenameBoard } = useBoardAccess(board);
const onChange = useCallback(
async (board_name: string) => {
@@ -51,13 +53,13 @@ export const BoardEditableTitle = memo(({ board, isSelected }: Props) => {
fontWeight="semibold"
userSelect="none"
color={isSelected ? 'base.100' : 'base.300'}
onDoubleClick={editable.startEditing}
cursor="text"
onDoubleClick={canRenameBoard ? editable.startEditing : undefined}
cursor={canRenameBoard ? 'text' : 'default'}
noOfLines={1}
>
{editable.value}
</Text>
{isHovering.isTrue && (
{canRenameBoard && isHovering.isTrue && (
<IconButton
aria-label={t('common.editName')}
icon={<PiPencilBold />}

View File

@@ -18,8 +18,9 @@ import {
import { autoAddBoardIdChanged, boardIdSelected } from 'features/gallery/store/gallerySlice';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArchiveBold, PiImageSquare } from 'react-icons/pi';
import { PiArchiveBold, PiGlobeBold, PiImageSquare, PiShareNetworkBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
import type { BoardDTO } from 'services/api/types';
const _hover: SystemStyleObject = {
@@ -62,6 +63,8 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => {
const showOwner = currentUser?.is_admin && board.owner_username;
const { canWriteImages } = useBoardAccess(board);
return (
<Box position="relative" w="full" h={12}>
<BoardContextMenu board={board}>
@@ -99,6 +102,20 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => {
</Flex>
{autoAddBoardId === board.board_id && <AutoAddBadge />}
{board.archived && <Icon as={PiArchiveBold} fill="base.300" />}
{board.board_visibility === 'shared' && (
<Tooltip label={t('boards.visibilityBadgeShared')}>
<span>
<Icon as={PiShareNetworkBold} fill="blue.300" />
</span>
</Tooltip>
)}
{board.board_visibility === 'public' && (
<Tooltip label={t('boards.visibilityBadgePublic')}>
<span>
<Icon as={PiGlobeBold} fill="green.300" />
</span>
</Tooltip>
)}
<Flex justifyContent="flex-end">
<Text variant="subtext">
{board.image_count} | {board.asset_count}
@@ -108,7 +125,12 @@ const GalleryBoard = ({ board, isSelected }: GalleryBoardProps) => {
</Tooltip>
)}
</BoardContextMenu>
<DndDropTarget dndTarget={addImageToBoardDndTarget} dndTargetData={dndTargetData} label={t('gallery.move')} />
<DndDropTarget
dndTarget={addImageToBoardDndTarget}
dndTargetData={dndTargetData}
label={t('gallery.move')}
isDisabled={!canWriteImages}
/>
</Box>
);
};

View File

@@ -5,11 +5,15 @@ import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiFoldersBold } from 'react-icons/pi';
import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard';
export const ContextMenuItemChangeBoard = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const imageDTO = useImageDTOContext();
const selectedBoard = useSelectedBoard();
const { canWriteImages } = useBoardAccess(selectedBoard);
const onClick = useCallback(() => {
dispatch(imagesToChangeSelected([imageDTO.image_name]));
@@ -17,7 +21,7 @@ export const ContextMenuItemChangeBoard = memo(() => {
}, [dispatch, imageDTO]);
return (
<MenuItem icon={<PiFoldersBold />} onClickCapture={onClick}>
<MenuItem icon={<PiFoldersBold />} onClickCapture={onClick} isDisabled={!canWriteImages}>
{t('boards.changeBoard')}
</MenuItem>
);

View File

@@ -4,11 +4,15 @@ import { useImageDTOContext } from 'features/gallery/contexts/ImageDTOContext';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleBold } from 'react-icons/pi';
import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard';
export const ContextMenuItemDeleteImage = memo(() => {
const { t } = useTranslation();
const deleteImageModal = useDeleteImageModalApi();
const imageDTO = useImageDTOContext();
const selectedBoard = useSelectedBoard();
const { canWriteImages } = useBoardAccess(selectedBoard);
const onClick = useCallback(async () => {
try {
@@ -18,6 +22,10 @@ export const ContextMenuItemDeleteImage = memo(() => {
}
}, [deleteImageModal, imageDTO]);
if (!canWriteImages) {
return null;
}
return (
<IconMenuItem
icon={<PiTrashSimpleBold />}

View File

@@ -10,12 +10,16 @@ import {
useStarImagesMutation,
useUnstarImagesMutation,
} from 'services/api/endpoints/images';
import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard';
const MultipleSelectionMenuItems = () => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const selection = useAppSelector((s) => s.gallery.selection);
const deleteImageModal = useDeleteImageModalApi();
const selectedBoard = useSelectedBoard();
const { canWriteImages } = useBoardAccess(selectedBoard);
const [starImages] = useStarImagesMutation();
const [unstarImages] = useUnstarImagesMutation();
@@ -53,11 +57,16 @@ const MultipleSelectionMenuItems = () => {
<MenuItem icon={<PiDownloadSimpleBold />} onClickCapture={handleBulkDownload}>
{t('gallery.downloadSelection')}
</MenuItem>
<MenuItem icon={<PiFoldersBold />} onClickCapture={handleChangeBoard}>
<MenuItem icon={<PiFoldersBold />} onClickCapture={handleChangeBoard} isDisabled={!canWriteImages}>
{t('boards.changeBoard')}
</MenuItem>
<MenuDivider />
<MenuItem color="error.300" icon={<PiTrashSimpleBold />} onClickCapture={handleDeleteSelection}>
<MenuItem
color="error.300"
icon={<PiTrashSimpleBold />}
onClickCapture={handleDeleteSelection}
isDisabled={!canWriteImages}
>
{t('gallery.deleteSelection')}
</MenuItem>
</>

View File

@@ -108,6 +108,25 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
if (!element) {
return;
}
const monitorBinding = monitorForElements({
// This is a "global" drag start event, meaning that it is called for all drag events.
onDragStart: ({ source }) => {
// When we start dragging multiple images, set the dragging state to true if the dragged image is part of the
// selection. This is called for all drag events.
if (
multipleImageDndSource.typeGuard(source.data) &&
source.data.payload.image_names.includes(imageDTO.image_name)
) {
setIsDragging(true);
}
},
onDrop: () => {
// Always set the dragging state to false when a drop event occurs.
setIsDragging(false);
},
});
return combine(
firefoxDndFix(element),
draggable({
@@ -153,23 +172,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
}
},
}),
monitorForElements({
// This is a "global" drag start event, meaning that it is called for all drag events.
onDragStart: ({ source }) => {
// When we start dragging multiple images, set the dragging state to true if the dragged image is part of the
// selection. This is called for all drag events.
if (
multipleImageDndSource.typeGuard(source.data) &&
source.data.payload.image_names.includes(imageDTO.image_name)
) {
setIsDragging(true);
}
},
onDrop: () => {
// Always set the dragging state to false when a drop event occurs.
setIsDragging(false);
},
})
monitorBinding
);
}, [imageDTO, store]);

View File

@@ -5,6 +5,8 @@ import type { MouseEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiTrashSimpleFill } from 'react-icons/pi';
import { useBoardAccess } from 'services/api/hooks/useBoardAccess';
import { useSelectedBoard } from 'services/api/hooks/useSelectedBoard';
import type { ImageDTO } from 'services/api/types';
type Props = {
@@ -15,6 +17,8 @@ export const GalleryItemDeleteIconButton = memo(({ imageDTO }: Props) => {
const shift = useShiftModifier();
const { t } = useTranslation();
const deleteImageModal = useDeleteImageModalApi();
const selectedBoard = useSelectedBoard();
const { canWriteImages } = useBoardAccess(selectedBoard);
const onClick = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
@@ -24,7 +28,7 @@ export const GalleryItemDeleteIconButton = memo(({ imageDTO }: Props) => {
[deleteImageModal, imageDTO]
);
if (!shift) {
if (!shift || !canWriteImages) {
return null;
}

View File

@@ -0,0 +1,24 @@
import { ImageMetadataHandlers } from 'features/metadata/parsing';
import { describe, expect, it } from 'vitest';
import { ImageMetadataActions } from './ImageMetadataActions';
describe('ImageMetadataActions', () => {
  it('includes Qwen metadata handlers in the recall parameters UI', () => {
    // Invoke the component's render function directly (via `.type`) instead of using a
    // React renderer, then inspect the returned element tree for the expected handlers.
    // NOTE(review): this reaches into React element internals (`.type`, `props.children`)
    // and assumes the component's children are a flat array of <SingleMetadataDatum>-style
    // elements — it will break silently if the JSX gains a wrapper level; confirm acceptable.
    const element = (ImageMetadataActions as unknown as { type: (props: { metadata: unknown }) => unknown }).type({
      metadata: { model: { key: 'test' } },
    }) as {
      props: {
        children: Array<{ props?: { handler?: unknown } }>;
      };
    };
    // Collect the `handler` prop from every child that has one.
    const handlers = element.props.children
      .map((child) => child.props?.handler)
      .filter((handler): handler is unknown => handler !== undefined);
    expect(handlers).toContain(ImageMetadataHandlers.QwenImageComponentSource);
    expect(handlers).toContain(ImageMetadataHandlers.QwenImageQuantization);
    expect(handlers).toContain(ImageMetadataHandlers.QwenImageShift);
  });
});

View File

@@ -58,6 +58,9 @@ export const ImageMetadataActions = memo((props: Props) => {
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.RefinerScheduler} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.RefinerDenoisingStart} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.RefinerSteps} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.QwenImageComponentSource} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.QwenImageQuantization} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.QwenImageShift} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.CanvasLayers} />
<CollectionMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.RefImages} />
<CollectionMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.LoRAs} />

View File

@@ -1,7 +1,9 @@
import type { ButtonProps } from '@invoke-ai/ui-library';
import { Alert, AlertDescription, AlertIcon, Button, Divider, Flex, Link, Spinner, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import { InvokeLogoIcon } from 'common/components/InvokeLogoIcon';
import { selectCurrentUser } from 'features/auth/store/authSlice';
import { LOADING_SYMBOL, useHasImages } from 'features/gallery/hooks/useHasImages';
import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore';
import { navigationApi } from 'features/ui/layouts/navigation-api';
@@ -9,16 +11,26 @@ import type { PropsWithChildren } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { Trans, useTranslation } from 'react-i18next';
import { PiArrowSquareOutBold, PiImageBold } from 'react-icons/pi';
import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
import { useMainModels } from 'services/api/hooks/modelsByType';
export const NoContentForViewer = memo(() => {
const hasImages = useHasImages();
const [mainModels, { data }] = useMainModels();
const { data: setupStatus } = useGetSetupStatusQuery();
const user = useAppSelector(selectCurrentUser);
const { t } = useTranslation();
const isMultiuser = setupStatus?.multiuser_enabled ?? false;
const isAdmin = !isMultiuser || (user?.is_admin ?? false);
const adminEmail = setupStatus?.admin_email ?? null;
const modelsLoaded = data !== undefined;
const hasModels = mainModels.length > 0;
const showStarterBundles = useMemo(() => {
return data && mainModels.length === 0;
}, [mainModels.length, data]);
return modelsLoaded && !hasModels && isAdmin;
}, [modelsLoaded, hasModels, isAdmin]);
if (hasImages === LOADING_SYMBOL) {
// Blank bg w/ a spinner. The new user experience components below have an invoke logo, but it's not centered.
@@ -36,10 +48,18 @@ export const NoContentForViewer = memo(() => {
<Flex flexDir="column" gap={8} alignItems="center" textAlign="center" maxW="400px">
<InvokeLogoIcon w={32} h={32} />
<Flex flexDir="column" gap={4} alignItems="center" textAlign="center">
<GetStartedLocal />
{showStarterBundles && <StarterBundlesCallout />}
<Divider />
<LowVRAMAlert />
{isAdmin ? (
// Admin / single-user mode
<>
{modelsLoaded && hasModels ? <GetStartedWithModels /> : <GetStartedLocal />}
{showStarterBundles && <StarterBundlesCallout />}
<Divider />
<LowVRAMAlert />
</>
) : (
// Non-admin user in multiuser mode
<>{modelsLoaded && hasModels ? <GetStartedWithModels /> : <GetStartedNonAdmin adminEmail={adminEmail} />}</>
)}
</Flex>
</Flex>
);
@@ -99,6 +119,32 @@ const GetStartedLocal = () => {
);
};
const GetStartedWithModels = () => {
return (
<Text fontSize="md" color="base.200">
<Trans i18nKey="newUserExperience.toGetStarted" components={{ StrongComponent }} />
</Text>
);
};
// Message for non-admin users in multiuser mode when no models are installed.
// Links to the admin's email when the setup-status endpoint exposes one,
// otherwise falls back to a plain "your administrator" span.
const GetStartedNonAdmin = ({ adminEmail }: { adminEmail: string | null }) => {
  // NOTE(review): the "your administrator" fallback is hard-coded English while
  // the surrounding copy goes through i18n (Trans) — consider a translation key.
  const AdminEmailLink = adminEmail ? (
    <Link href={`mailto:${adminEmail}`} color="base.50">
      {adminEmail}
    </Link>
  ) : (
    <Text as="span" color="base.50">
      your administrator
    </Text>
  );
  return (
    <Text fontSize="md" color="base.200">
      <Trans i18nKey="newUserExperience.toGetStartedNonAdmin" components={{ StrongComponent, AdminEmailLink }} />
    </Text>
  );
};
const StarterBundlesCallout = () => {
const handleClickDownloadStarterModels = useCallback(() => {
navigationApi.switchToTab('models');

View File

@@ -0,0 +1,94 @@
import { describe, expect, it, vi } from 'vitest';
import { ImageMetadataHandlers, MetadataUtils } from './parsing';
// Minimal Redux-store stand-in: records dispatches and serves a fixed state
// with no model selected.
const createMockStore = () => {
  const dispatch = vi.fn();
  const getState = vi.fn(() => ({ params: { model: null } }));
  return { dispatch, getState };
};
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const createStore = () => createMockStore() as any;
// Tests for the Qwen image metadata handlers: they must stay invisible for
// images without Qwen keys, and recall values only when keys are present.
describe('Qwen metadata parsing', () => {
  it('does not report missing Qwen metadata keys as available', async () => {
    const store = createStore();
    // require: 'all' — every listed handler must parse for hasMetadata to be true.
    const hasMetadata = await MetadataUtils.hasMetadataByHandlers({
      metadata: {},
      handlers: [
        ImageMetadataHandlers.QwenImageComponentSource,
        ImageMetadataHandlers.QwenImageQuantization,
        ImageMetadataHandlers.QwenImageShift,
      ],
      store,
      require: 'all',
    });
    // Handlers reject when keys are absent, so hasMetadata should be false
    expect(hasMetadata).toBe(false);
  });
  it('does not recall Qwen values when metadata keys are absent', async () => {
    const store = createStore();
    const recalled = await MetadataUtils.recallByHandlers({
      metadata: {},
      handlers: [
        ImageMetadataHandlers.QwenImageComponentSource,
        ImageMetadataHandlers.QwenImageQuantization,
        ImageMetadataHandlers.QwenImageShift,
      ],
      store,
      silent: true,
    });
    // No keys present → handlers reject → 0 recalls, no dispatches
    expect(recalled.size).toBe(0);
    const mockStore = store as ReturnType<typeof createMockStore>;
    expect(mockStore.dispatch).not.toHaveBeenCalled();
  });
  it('recalls Qwen handlers with actual values when metadata keys are present', async () => {
    const store = createStore();
    const recalled = await MetadataUtils.recallByHandlers({
      // Minimal valid values for each of the three Qwen metadata keys.
      metadata: {
        qwen_image_component_source: { key: 'test-key', hash: 'test', name: 'Test', base: 'qwen-image', type: 'main' },
        qwen_image_quantization: 'nf4',
        qwen_image_shift: 3.0,
      },
      handlers: [
        ImageMetadataHandlers.QwenImageComponentSource,
        ImageMetadataHandlers.QwenImageQuantization,
        ImageMetadataHandlers.QwenImageShift,
      ],
      store,
      silent: true,
    });
    expect(recalled.size).toBe(3);
    const mockStore = store as ReturnType<typeof createMockStore>;
    // One dispatch per successfully recalled handler.
    expect(mockStore.dispatch).toHaveBeenCalledTimes(3);
  });
  it('recalls Qwen component source as null when key is present but value is null', async () => {
    const store = createStore();
    const recalled = await MetadataUtils.recallByHandlers({
      metadata: {
        qwen_image_component_source: null,
      },
      handlers: [ImageMetadataHandlers.QwenImageComponentSource],
      store,
      silent: true,
    });
    // Key is present with null value → handler resolves with null → 1 recall
    expect(recalled.size).toBe(1);
    const mockStore = store as ReturnType<typeof createMockStore>;
    expect(mockStore.dispatch).toHaveBeenCalledTimes(1);
  });
});

View File

@@ -16,6 +16,9 @@ import {
kleinVaeModelSelected,
negativePromptChanged,
positivePromptChanged,
qwenImageComponentSourceSelected,
qwenImageQuantizationChanged,
qwenImageShiftChanged,
refinerModelChanged,
selectBase,
setAnimaScheduler,
@@ -687,6 +690,83 @@ const ZImageSeedVarianceRandomizePercent: SingleMetadataHandler<number> = {
};
//#endregion ZImageSeedVarianceRandomizePercent
//#region QwenImageComponentSource
const QwenImageComponentSource: SingleMetadataHandler<ModelIdentifierField | null> = {
  [SingleMetadataKey]: true,
  type: 'QwenImageComponentSource',
  /**
   * Parse the `qwen_image_component_source` metadata key.
   * - Absent key → rejected promise, so the handler is not rendered for non-Qwen images.
   * - Explicit null → resolves null (no component source selected).
   * - Otherwise → validated as a ModelIdentifierField.
   */
  parse: (metadata, _store) => {
    const raw = getProperty(metadata, 'qwen_image_component_source');
    // Reject when the key is absent so the handler is not rendered for non-Qwen images
    if (raw === undefined) {
      return Promise.reject();
    }
    if (raw === null) {
      return Promise.resolve(null);
    }
    // safeParse: a malformed value becomes a rejection instead of a synchronous
    // throw from a Promise-returning function.
    const result = zModelIdentifierField.safeParse(raw);
    return result.success ? Promise.resolve(result.data) : Promise.reject(result.error);
  },
  recall: (value, store) => {
    store.dispatch(qwenImageComponentSourceSelected(value));
  },
  i18nKey: 'modelManager.qwenImageComponentSource',
  LabelComponent: MetadataLabel,
  ValueComponent: ({ value }: SingleMetadataValueProps<ModelIdentifierField | null>) => (
    <MetadataPrimitiveValue value={value ? value.name : 'None'} />
  ),
};
//#endregion QwenImageComponentSource
//#region QwenImageQuantization
const QwenImageQuantization: SingleMetadataHandler<'none' | 'int8' | 'nf4'> = {
  [SingleMetadataKey]: true,
  type: 'QwenImageQuantization',
  /**
   * Parse the `qwen_image_quantization` metadata key.
   * Rejects when the key is absent so the handler is not rendered for
   * non-Qwen images; otherwise the value must be one of the known modes.
   */
  parse: (metadata, _store) => {
    const raw = getProperty(metadata, 'qwen_image_quantization');
    // Reject when the key is absent so the handler is not rendered for non-Qwen images
    if (raw === undefined) {
      return Promise.reject();
    }
    // safeParse: a malformed value becomes a rejection instead of a synchronous
    // throw from a Promise-returning function.
    const result = z.enum(['none', 'int8', 'nf4']).safeParse(raw);
    return result.success ? Promise.resolve(result.data) : Promise.reject(result.error);
  },
  recall: (value, store) => {
    store.dispatch(qwenImageQuantizationChanged(value));
  },
  i18nKey: 'modelManager.qwenImageQuantization',
  LabelComponent: MetadataLabel,
  ValueComponent: ({ value }: SingleMetadataValueProps<'none' | 'int8' | 'nf4'>) => (
    <MetadataPrimitiveValue value={value} />
  ),
};
//#endregion QwenImageQuantization
//#region QwenImageShift
const QwenImageShift: SingleMetadataHandler<number | null> = {
  [SingleMetadataKey]: true,
  type: 'QwenImageShift',
  /**
   * Parse the `qwen_image_shift` metadata key.
   * - Absent key → rejected promise (handler hidden for non-Qwen images).
   * - Explicit null → resolves null, displayed as "Default".
   * - Otherwise → must be a number.
   */
  parse: (metadata, _store) => {
    const raw = getProperty(metadata, 'qwen_image_shift');
    // Reject when the key is absent so the handler is not rendered for non-Qwen images
    if (raw === undefined) {
      return Promise.reject();
    }
    if (raw === null) {
      return Promise.resolve(null);
    }
    // safeParse: a malformed value becomes a rejection instead of a synchronous
    // throw from a Promise-returning function.
    const result = z.number().safeParse(raw);
    return result.success ? Promise.resolve(result.data) : Promise.reject(result.error);
  },
  recall: (value, store) => {
    store.dispatch(qwenImageShiftChanged(value));
  },
  i18nKey: 'modelManager.qwenImageShift',
  LabelComponent: MetadataLabel,
  ValueComponent: ({ value }: SingleMetadataValueProps<number | null>) => (
    <MetadataPrimitiveValue value={value ?? 'Default'} />
  ),
};
//#endregion QwenImageShift
//#region ZImageShift
const ZImageShift: SingleMetadataHandler<number> = {
[SingleMetadataKey]: true,
@@ -1334,6 +1414,9 @@ export const ImageMetadataHandlers = {
ZImageSeedVarianceEnabled,
ZImageSeedVarianceStrength,
ZImageSeedVarianceRandomizePercent,
QwenImageComponentSource,
QwenImageQuantization,
QwenImageShift,
ZImageShift,
LoRAs,
CanvasLayers,

View File

@@ -4,7 +4,7 @@ import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/
import type { StarterModel } from 'services/api/types';
type ModelInstallArg = {
config: Pick<StarterModel, 'name' | 'base' | 'type' | 'description' | 'format'>;
config: Pick<StarterModel, 'name' | 'base' | 'type' | 'description' | 'format' | 'variant'>;
source: string;
};
@@ -32,7 +32,7 @@ export const useBuildModelInstallArg = () => {
);
const buildModelInstallArg = useCallback((starterModel: StarterModel): ModelInstallArg => {
const { name, base, type, source, description, format } = starterModel;
const { name, base, type, source, description, format, variant } = starterModel;
return {
config: {
@@ -41,6 +41,7 @@ export const useBuildModelInstallArg = () => {
type,
description,
format,
variant,
},
source,
};

View File

@@ -1,10 +1,11 @@
import { Button, Text, useToast } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsAuthenticated } from 'features/auth/store/authSlice';
import { selectCurrentUser, selectIsAuthenticated } from 'features/auth/store/authSlice';
import { setInstallModelsTabByName } from 'features/modelManagerV2/store/installModelsStore';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
import { useMainModels } from 'services/api/hooks/modelsByType';
const TOAST_ID = 'starterModels';
@@ -15,6 +16,11 @@ export const useStarterModelsToast = () => {
const [mainModels, { data }] = useMainModels();
const toast = useToast();
const isAuthenticated = useAppSelector(selectIsAuthenticated);
const { data: setupStatus } = useGetSetupStatusQuery();
const user = useAppSelector(selectCurrentUser);
const isMultiuser = setupStatus?.multiuser_enabled ?? false;
const isAdmin = !isMultiuser || (user?.is_admin ?? false);
useEffect(() => {
// Only show the toast if the user is authenticated
@@ -33,17 +39,17 @@ export const useStarterModelsToast = () => {
toast({
id: TOAST_ID,
title: t('modelManager.noModelsInstalled'),
description: <ToastDescription />,
description: isAdmin ? <AdminToastDescription /> : <NonAdminToastDescription />,
status: 'info',
isClosable: true,
duration: null,
onCloseComplete: () => setDidToast(true),
});
}
}, [data, didToast, isAuthenticated, mainModels.length, t, toast]);
}, [data, didToast, isAuthenticated, isAdmin, mainModels.length, t, toast]);
};
const ToastDescription = () => {
const AdminToastDescription = () => {
const { t } = useTranslation();
const toast = useToast();
@@ -62,3 +68,9 @@ const ToastDescription = () => {
</Text>
);
};
// Toast body for non-admin users: asks them to contact an admin to install models.
const NonAdminToastDescription = () => {
  const translation = useTranslation();
  const message = translation.t('modelManager.noModelsInstalledAskAdmin');
  return <Text fontSize="md">{message}</Text>;
};

View File

@@ -148,6 +148,7 @@ export const MODEL_BASE_TO_COLOR: Record<BaseModelType, string> = {
flux: 'gold',
flux2: 'gold',
cogview4: 'red',
'qwen-image': 'orange',
'z-image': 'cyan',
external: 'orange',
anima: 'invokePurple',
@@ -192,6 +193,7 @@ export const MODEL_BASE_TO_LONG_NAME: Record<BaseModelType, string> = {
flux: 'FLUX',
flux2: 'FLUX.2',
cogview4: 'CogView4',
'qwen-image': 'Qwen Image',
'z-image': 'Z-Image',
external: 'External',
anima: 'Anima',
@@ -211,6 +213,7 @@ export const MODEL_BASE_TO_SHORT_NAME: Record<BaseModelType, string> = {
flux: 'FLUX',
flux2: 'FLUX.2',
cogview4: 'CogView4',
'qwen-image': 'QwenImg',
'z-image': 'Z-Image',
external: 'External',
anima: 'Anima',
@@ -231,6 +234,8 @@ export const MODEL_VARIANT_TO_LONG_NAME: Record<AnyModelVariant, string> = {
zbase: 'Z-Image Base',
large: 'CLIP L',
gigantic: 'CLIP G',
generate: 'Qwen Image',
edit: 'Qwen Image Edit',
qwen3_4b: 'Qwen3 4B',
qwen3_8b: 'Qwen3 8B',
qwen3_06b: 'Qwen3 0.6B',
@@ -257,13 +262,14 @@ export const MODEL_FORMAT_TO_LONG_NAME: Record<ModelFormat, string> = {
export const SUPPORTS_OPTIMIZED_DENOISING_BASE_MODELS: BaseModelType[] = ['flux', 'sd-3', 'z-image'];
export const SUPPORTS_REF_IMAGES_BASE_MODELS: BaseModelType[] = ['sd-1', 'sdxl', 'flux', 'flux2'];
export const SUPPORTS_REF_IMAGES_BASE_MODELS: BaseModelType[] = ['sd-1', 'sdxl', 'flux', 'flux2', 'qwen-image'];
export const SUPPORTS_NEGATIVE_PROMPT_BASE_MODELS: BaseModelType[] = [
'sd-1',
'sd-2',
'sdxl',
'cogview4',
'qwen-image',
'sd-3',
'z-image',
'anima',

View File

@@ -37,7 +37,7 @@ export const ModelManager = memo(() => {
{t('common.modelManager')}
</Heading>
<Flex gap={2}>
<SyncModelsButton />
{canManageModels && <SyncModelsButton />}
{!!selectedModelKey && canManageModels && (
<Button size="sm" colorScheme="invokeYellow" leftIcon={<PiPlusBold />} onClick={handleClickAddModel}>
{t('modelManager.addModels')}

View File

@@ -1,5 +1,6 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useDoesWorkflowHaveUnsavedChanges } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
import { useIsCurrentWorkflowOwner } from 'features/workflowLibrary/hooks/useIsCurrentWorkflowOwner';
import { useSaveOrSaveAsWorkflow } from 'features/workflowLibrary/hooks/useSaveOrSaveAsWorkflow';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -8,6 +9,7 @@ import { PiFloppyDiskBold } from 'react-icons/pi';
const SaveWorkflowButton = () => {
const { t } = useTranslation();
const doesWorkflowHaveUnsavedChanges = useDoesWorkflowHaveUnsavedChanges();
const isCurrentWorkflowOwner = useIsCurrentWorkflowOwner();
const saveOrSaveAsWorkflow = useSaveOrSaveAsWorkflow();
return (
@@ -15,7 +17,7 @@ const SaveWorkflowButton = () => {
tooltip={t('workflows.saveWorkflow')}
aria-label={t('workflows.saveWorkflow')}
icon={<PiFloppyDiskBold />}
isDisabled={!doesWorkflowHaveUnsavedChanges}
isDisabled={!doesWorkflowHaveUnsavedChanges || !isCurrentWorkflowOwner}
onClick={saveOrSaveAsWorkflow}
pointerEvents="auto"
/>

View File

@@ -1,4 +1,6 @@
import { IconButton } from '@invoke-ai/ui-library';
import { useDoesWorkflowHaveUnsavedChanges } from 'features/nodes/components/sidePanel/workflow/IsolatedWorkflowBuilderWatcher';
import { useIsCurrentWorkflowOwner } from 'features/workflowLibrary/hooks/useIsCurrentWorkflowOwner';
import { useSaveOrSaveAsWorkflow } from 'features/workflowLibrary/hooks/useSaveOrSaveAsWorkflow';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -7,12 +9,15 @@ import { PiFloppyDiskBold } from 'react-icons/pi';
const SaveWorkflowButton = () => {
const { t } = useTranslation();
const saveOrSaveAsWorkflow = useSaveOrSaveAsWorkflow();
const doesWorkflowHaveUnsavedChanges = useDoesWorkflowHaveUnsavedChanges();
const isCurrentWorkflowOwner = useIsCurrentWorkflowOwner();
return (
<IconButton
tooltip={t('workflows.saveWorkflow')}
aria-label={t('workflows.saveWorkflow')}
icon={<PiFloppyDiskBold />}
isDisabled={!doesWorkflowHaveUnsavedChanges || !isCurrentWorkflowOwner}
onClick={saveOrSaveAsWorkflow}
pointerEvents="auto"
variant="ghost"

View File

@@ -1,8 +1,19 @@
import type { FormControlProps } from '@invoke-ai/ui-library';
import { Box, Flex, FormControl, FormControlGroup, FormLabel, Image, Input, Textarea } from '@invoke-ai/ui-library';
import {
Box,
Checkbox,
Flex,
FormControl,
FormControlGroup,
FormLabel,
Image,
Input,
Textarea,
} from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { selectCurrentUser } from 'features/auth/store/authSlice';
import {
workflowAuthorChanged,
workflowContactChanged,
@@ -25,7 +36,8 @@ import {
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetWorkflowQuery } from 'services/api/endpoints/workflows';
import { useGetSetupStatusQuery } from 'services/api/endpoints/auth';
import { useGetWorkflowQuery, useUpdateWorkflowIsPublicMutation } from 'services/api/endpoints/workflows';
import { WorkflowThumbnailEditor } from './WorkflowThumbnail/WorkflowThumbnailEditor';
@@ -95,6 +107,7 @@ const WorkflowGeneralTab = () => {
<FormLabel>{t('nodes.workflowName')}</FormLabel>
<Input variant="darkFilled" value={name} onChange={handleChangeName} />
</FormControl>
<ShareWorkflowCheckbox id={id} />
<Thumbnail id={id} />
<FormControl>
<FormLabel>{t('nodes.workflowVersion')}</FormLabel>
@@ -187,3 +200,40 @@ const Thumbnail = ({ id }: { id?: string | null }) => {
// This is a default workflow and it does not have a thumbnail set. Users may not edit the thumbnail.
return null;
};
// Checkbox that toggles a saved user workflow's public (shared) flag.
const ShareWorkflowCheckbox = ({ id }: { id?: string | null }) => {
  const { t } = useTranslation();
  const currentUser = useAppSelector(selectCurrentUser);
  const { data: setupStatus } = useGetSetupStatusQuery();
  // Skip the workflow query entirely for unsaved workflows (no id yet).
  const { data } = useGetWorkflowQuery(id ?? skipToken);
  const [updateIsPublic, { isLoading }] = useUpdateWorkflowIsPublicMutation();
  const handleChange = useCallback(
    (e: ChangeEvent<HTMLInputElement>) => {
      if (!id) {
        return;
      }
      updateIsPublic({ workflow_id: id, is_public: e.target.checked });
    },
    [id, updateIsPublic]
  );
  // Only render for saved workflows in the 'user' category (never defaults).
  if (!data || !id || data.workflow.meta.category !== 'user') {
    return null;
  }
  // In multiuser mode, only the owner or an admin may toggle sharing; in
  // single-user mode this check is skipped and the checkbox always renders.
  if (setupStatus?.multiuser_enabled) {
    const isOwner = currentUser !== null && data.user_id === currentUser.user_id;
    const isAdmin = currentUser?.is_admin ?? false;
    if (!isOwner && !isAdmin) {
      return null;
    }
  }
  return (
    <Flex alignItems="center" gap={2}>
      <Checkbox isChecked={data.is_public} onChange={handleChange} isDisabled={isLoading} />
      <FormLabel mb={0}>{t('workflows.shareWorkflow')}</FormLabel>
    </Flex>
  );
};

View File

@@ -41,6 +41,7 @@ export const WorkflowLibrarySideNav = () => {
<Flex flexDir="column" w="full" pb={2} gap={2}>
<WorkflowLibraryViewButton view="recent">{t('workflows.recentlyOpened')}</WorkflowLibraryViewButton>
<YourWorkflowsButton />
<WorkflowLibraryViewButton view="shared">{t('workflows.sharedWorkflows')}</WorkflowLibraryViewButton>
</Flex>
<Flex h="full" minH={0} overflow="hidden" flexDir="column">
<BrowseWorkflowsButton />

View File

@@ -32,6 +32,8 @@ const getCategories = (view: WorkflowLibraryView): WorkflowCategory[] => {
return ['user', 'default'];
case 'yours':
return ['user'];
case 'shared':
return ['user'];
default:
assert<Equals<typeof view, never>>(false);
}
@@ -44,6 +46,13 @@ const getHasBeenOpened = (view: WorkflowLibraryView): boolean | undefined => {
return undefined;
};
// Map a library view to the `is_public` query filter: the shared view pins it
// to true; every other view leaves the filter unset.
const getIsPublic = (view: WorkflowLibraryView): boolean | undefined =>
  view === 'shared' ? true : undefined;
const useInfiniteQueryAry = () => {
const orderBy = useAppSelector(selectWorkflowLibraryOrderBy);
const direction = useAppSelector(selectWorkflowLibraryDirection);
@@ -62,6 +71,7 @@ const useInfiniteQueryAry = () => {
query: debouncedSearchTerm,
tags: view === 'defaults' || view === 'yours' ? selectedTags : [],
has_been_opened: getHasBeenOpened(view),
is_public: getIsPublic(view),
} satisfies Parameters<typeof useListWorkflowsInfiniteInfiniteQuery>[0];
}, [orderBy, direction, view, debouncedSearchTerm, selectedTags]);

View File

@@ -1,13 +1,15 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Badge, Flex, Icon, Image, Spacer, Text } from '@invoke-ai/ui-library';
import { Badge, Flex, Icon, Image, Spacer, Switch, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectCurrentUser } from 'features/auth/store/authSlice';
import { selectWorkflowId } from 'features/nodes/store/selectors';
import { workflowModeChanged } from 'features/nodes/store/workflowLibrarySlice';
import { useLoadWorkflowWithDialog } from 'features/workflowLibrary/components/LoadWorkflowConfirmationAlertDialog';
import InvokeLogo from 'public/assets/images/invoke-symbol-wht-lrg.svg';
import { memo, useCallback, useMemo } from 'react';
import { type ChangeEvent, memo, type MouseEvent, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImage } from 'react-icons/pi';
import { useUpdateWorkflowIsPublicMutation } from 'services/api/endpoints/workflows';
import type { WorkflowRecordListItemWithThumbnailDTO } from 'services/api/types';
import { DeleteWorkflow } from './WorkflowLibraryListItemActions/DeleteWorkflow';
@@ -33,12 +35,21 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
const { t } = useTranslation();
const dispatch = useAppDispatch();
const workflowId = useAppSelector(selectWorkflowId);
const currentUser = useAppSelector(selectCurrentUser);
const loadWorkflowWithDialog = useLoadWorkflowWithDialog();
const isActive = useMemo(() => {
return workflowId === workflow.workflow_id;
}, [workflowId, workflow.workflow_id]);
const isOwner = useMemo(() => {
return currentUser !== null && workflow.user_id === currentUser.user_id;
}, [currentUser, workflow.user_id]);
const canEditOrDelete = useMemo(() => {
return isOwner || (currentUser?.is_admin ?? false);
}, [isOwner, currentUser]);
const tags = useMemo(() => {
if (!workflow.tags) {
return [];
@@ -102,6 +113,18 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
{t('workflows.opened')}
</Badge>
)}
{workflow.is_public && workflow.category !== 'default' && (
<Badge
color="invokeGreen.400"
borderColor="invokeGreen.700"
borderWidth={1}
bg="transparent"
flexShrink={0}
variant="subtle"
>
{t('workflows.shared')}
</Badge>
)}
{workflow.category === 'default' && (
<Image
src={InvokeLogo}
@@ -137,12 +160,13 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
</Text>
)}
<Spacer />
{isOwner && <ShareWorkflowToggle workflow={workflow} />}
{workflow.category === 'default' && <ViewWorkflow workflowId={workflow.workflow_id} />}
{workflow.category !== 'default' && (
<>
<EditWorkflow workflowId={workflow.workflow_id} />
{canEditOrDelete && <EditWorkflow workflowId={workflow.workflow_id} />}
<DownloadWorkflow workflowId={workflow.workflow_id} />
<DeleteWorkflow workflowId={workflow.workflow_id} />
{canEditOrDelete && <DeleteWorkflow workflowId={workflow.workflow_id} />}
</>
)}
</Flex>
@@ -152,6 +176,35 @@ export const WorkflowListItem = memo(({ workflow }: { workflow: WorkflowRecordLi
});
WorkflowListItem.displayName = 'WorkflowListItem';
// Inline switch on a workflow-library list row that toggles the workflow's
// public (shared) flag. Rendered only for rows owned by the current user
// (gated by the caller).
const ShareWorkflowToggle = memo(({ workflow }: { workflow: WorkflowRecordListItemWithThumbnailDTO }) => {
  const { t } = useTranslation();
  const [updateIsPublic, { isLoading }] = useUpdateWorkflowIsPublicMutation();
  const handleChange = useCallback(
    (e: ChangeEvent<HTMLInputElement>) => {
      // Don't let the toggle interaction bubble up to the row's click handler.
      e.stopPropagation();
      updateIsPublic({ workflow_id: workflow.workflow_id, is_public: e.target.checked });
    },
    [updateIsPublic, workflow.workflow_id]
  );
  // Swallow clicks on the wrapper for the same reason as above.
  const handleClick = useCallback((e: MouseEvent) => {
    e.stopPropagation();
  }, []);
  return (
    <Tooltip label={t('workflows.shareWorkflow')}>
      <Flex alignItems="center" gap={1} onClick={handleClick}>
        <Text variant="subtext" fontSize="xs">
          {t('workflows.shared')}
        </Text>
        <Switch size="sm" isChecked={workflow.is_public} onChange={handleChange} isDisabled={isLoading} />
      </Flex>
    </Tooltip>
  );
});
ShareWorkflowToggle.displayName = 'ShareWorkflowToggle';
const UserThumbnailFallback = memo(() => {
return (
<Flex

View File

@@ -12,7 +12,7 @@ import { useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { z } from 'zod';
const zOrderBy = z.enum(['opened_at', 'created_at', 'updated_at', 'name']);
const zOrderBy = z.enum(['opened_at', 'created_at', 'updated_at', 'name', 'is_public']);
type OrderBy = z.infer<typeof zOrderBy>;
const isOrderBy = (v: unknown): v is OrderBy => zOrderBy.safeParse(v).success;
@@ -32,6 +32,7 @@ export const WorkflowSortControl = () => {
created_at: t('workflows.created'),
updated_at: t('workflows.updated'),
name: t('workflows.name'),
is_public: t('workflows.shared'),
}),
[t]
);

View File

@@ -11,7 +11,7 @@ import {
} from 'services/api/types';
import z from 'zod';
const zWorkflowLibraryView = z.enum(['recent', 'yours', 'defaults']);
const zWorkflowLibraryView = z.enum(['recent', 'yours', 'shared', 'defaults']);
export type WorkflowLibraryView = z.infer<typeof zWorkflowLibraryView>;
const zWorkflowLibraryState = z.object({
@@ -55,6 +55,9 @@ const slice = createSlice({
if (action.payload === 'recent') {
state.orderBy = 'opened_at';
state.direction = 'DESC';
} else if (action.payload === 'shared') {
state.orderBy = 'name';
state.direction = 'ASC';
}
},
workflowLibraryTagToggled: (state, action: PayloadAction<string>) => {
@@ -121,5 +124,11 @@ export const WORKFLOW_LIBRARY_TAG_CATEGORIES: WorkflowTagCategory[] = [
];
export const WORKFLOW_LIBRARY_TAGS = WORKFLOW_LIBRARY_TAG_CATEGORIES.flatMap(({ tags }) => tags);
type WorkflowSortOption = 'opened_at' | 'created_at' | 'updated_at' | 'name';
export const WORKFLOW_LIBRARY_SORT_OPTIONS: WorkflowSortOption[] = ['opened_at', 'created_at', 'updated_at', 'name'];
type WorkflowSortOption = 'opened_at' | 'created_at' | 'updated_at' | 'name' | 'is_public';
export const WORKFLOW_LIBRARY_SORT_OPTIONS: WorkflowSortOption[] = [
'opened_at',
'created_at',
'updated_at',
'name',
'is_public',
];

View File

@@ -95,13 +95,25 @@ export const zBaseModelType = z.enum([
'flux',
'flux2',
'cogview4',
'qwen-image',
'z-image',
'external',
'anima',
'unknown',
]);
export type BaseModelType = z.infer<typeof zBaseModelType>;
export const zMainModelBase = z.enum(['sd-1', 'sd-2', 'sd-3', 'sdxl', 'flux', 'flux2', 'cogview4', 'z-image', 'anima']);
export const zMainModelBase = z.enum([
'sd-1',
'sd-2',
'sd-3',
'sdxl',
'flux',
'flux2',
'cogview4',
'qwen-image',
'z-image',
'anima',
]);
type MainModelBase = z.infer<typeof zMainModelBase>;
export const isMainModelBase = (base: unknown): base is MainModelBase => zMainModelBase.safeParse(base).success;
export const zModelType = z.enum([
@@ -147,6 +159,7 @@ export const zModelVariantType = z.enum(['normal', 'inpaint', 'depth']);
export const zFluxVariantType = z.enum(['dev', 'dev_fill', 'schnell']);
export const zFlux2VariantType = z.enum(['klein_4b', 'klein_9b', 'klein_9b_base']);
export const zZImageVariantType = z.enum(['turbo', 'zbase']);
const zQwenImageVariantType = z.enum(['generate', 'edit']);
export const zQwen3VariantType = z.enum(['qwen3_4b', 'qwen3_8b', 'qwen3_06b']);
export const zAnyModelVariant = z.union([
zModelVariantType,
@@ -154,6 +167,7 @@ export const zAnyModelVariant = z.union([
zFluxVariantType,
zFlux2VariantType,
zZImageVariantType,
zQwenImageVariantType,
zQwen3VariantType,
]);
export type AnyModelVariant = z.infer<typeof zAnyModelVariant>;

View File

@@ -22,7 +22,14 @@ type AddImageToImageArg = {
manager: CanvasManager;
l2i: Invocation<LatentToImageNodes>;
i2l: Invocation<
'i2l' | 'flux_vae_encode' | 'flux2_vae_encode' | 'sd3_i2l' | 'cogview4_i2l' | 'z_image_i2l' | 'anima_i2l'
| 'i2l'
| 'flux_vae_encode'
| 'flux2_vae_encode'
| 'sd3_i2l'
| 'cogview4_i2l'
| 'qwen_image_i2l'
| 'z_image_i2l'
| 'anima_i2l'
>;
noise?: Invocation<'noise'>;
denoise: Invocation<DenoiseLatentsNodes>;
@@ -46,6 +53,7 @@ export const addImageToImage = async ({
| 'flux2_vae_decode'
| 'sd3_l2i'
| 'cogview4_l2i'
| 'qwen_image_l2i'
| 'z_image_l2i'
| 'anima_l2i'
>
@@ -58,6 +66,7 @@ export const addImageToImage = async ({
if (
denoise.type === 'cogview4_denoise' ||
denoise.type === 'qwen_image_denoise' ||
denoise.type === 'flux_denoise' ||
denoise.type === 'flux2_denoise' ||
denoise.type === 'sd3_denoise' ||

View File

@@ -25,7 +25,14 @@ type AddInpaintArg = {
manager: CanvasManager;
l2i: Invocation<LatentToImageNodes>;
i2l: Invocation<
'i2l' | 'flux_vae_encode' | 'flux2_vae_encode' | 'sd3_i2l' | 'cogview4_i2l' | 'z_image_i2l' | 'anima_i2l'
| 'i2l'
| 'flux_vae_encode'
| 'flux2_vae_encode'
| 'sd3_i2l'
| 'cogview4_i2l'
| 'qwen_image_i2l'
| 'z_image_i2l'
| 'anima_i2l'
>;
noise?: Invocation<'noise'>;
denoise: Invocation<DenoiseLatentsNodes>;
@@ -57,6 +64,7 @@ export const addInpaint = async ({
if (
denoise.type === 'cogview4_denoise' ||
denoise.type === 'qwen_image_denoise' ||
denoise.type === 'flux_denoise' ||
denoise.type === 'flux2_denoise' ||
denoise.type === 'sd3_denoise' ||

Some files were not shown because too many files have changed in this diff Show More