Compare commits

...

6 Commits

Author SHA1 Message Date
Mary Hipp
c2931f0bac add user_id and project_id as nullable args for image saves 2024-10-22 10:14:10 -04:00
Mary Hipp
e6f80ca9b2 send one imageDTO back with complete event so the UI can refresh the correct data 2024-10-21 20:22:51 -04:00
Mary Hipp
90ad720bb2 (ui) accept upload socket events and show toast 2024-10-21 20:07:33 -04:00
Mary Hipp
0b139ec7df update backend to use bulk_upload_id passed in from client and emit events correctly 2024-10-21 20:05:57 -04:00
Brandon Rising
2c77d62865 Some small typing/syntax/linting updates for bulk upload flow 2024-10-21 14:17:24 -04:00
Mary Hipp
233afe7b3a WIP 2024-10-18 15:26:28 -04:00
20 changed files with 722 additions and 40 deletions

View File

@@ -1,6 +1,6 @@
import io
import traceback
from typing import Optional
from typing import List, Optional
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
@@ -15,7 +15,7 @@ from invokeai.app.services.image_records.image_records_common import (
ImageRecordChanges,
ResourceOrigin,
)
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.images.images_common import ImageBulkUploadData, ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
@@ -26,6 +26,90 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
IMAGE_MAX_AGE = 31536000
class BulkUploadImageResponse(BaseModel):
    """Response for the bulk-upload endpoint."""

    sent: int  # number of files received in the request
    uploading: int  # number of files that passed validation and were queued for saving
def _get_str_info(pil_image: Image.Image, key: str, label: str) -> Optional[str]:
    """Return the string stored under `key` in the image's info dict, or None.

    Absence of the key is normal for external images, so a debug message is only
    logged when a value exists but is not a string (previously the "Failed to
    parse" message was also logged for images that simply had no embedded data).
    """
    raw = pil_image.info.get(key, None)
    if isinstance(raw, str):
        return raw
    if raw is not None:
        ApiDependencies.invoker.services.logger.debug(f"Failed to parse {label} for uploaded image")
    return None


@images_router.post(
    "/bulk-upload",
    operation_id="bulk_upload",
    responses={
        201: {"description": "The images are being prepared for upload"},
        415: {"description": "Images upload failed"},
    },
    status_code=201,
    response_model=BulkUploadImageResponse,
)
async def bulk_upload(
    bulk_upload_id: str,
    files: list[UploadFile],
    background_tasks: BackgroundTasks,
    request: Request,
    response: Response,
    board_id: Optional[str] = Query(default=None, description="The board to add this images to, if any"),
) -> BulkUploadImageResponse:
    """Uploads multiple images.

    Validates and decodes each file, then schedules the actual saving as a
    background task so the client gets a fast response. Progress is reported via
    socket.io events keyed by `bulk_upload_id`. Returns the number of files
    received (`sent`) and the number queued for processing (`uploading`).
    """
    upload_data_list: List[ImageBulkUploadData] = []
    for file in files:
        # Skip non-images and undecodable files rather than failing the whole batch.
        if not file.content_type or not file.content_type.startswith("image"):
            ApiDependencies.invoker.services.logger.error("Not an image")
            continue
        contents = await file.read()
        try:
            pil_image = Image.open(io.BytesIO(contents))
        except Exception:
            ApiDependencies.invoker.services.logger.error(traceback.format_exc())
            continue
        # TODO: retain non-invokeai metadata on upload?
        # Attempt to extract embedded InvokeAI metadata, workflow and graph.
        _metadata = _get_str_info(pil_image, "invokeai_metadata", "metadata")
        _workflow = _get_str_info(pil_image, "invokeai_workflow", "workflow")
        _graph = _get_str_info(pil_image, "invokeai_graph", "graph")
        upload_data_list.append(
            ImageBulkUploadData(
                image=pil_image,
                board_id=board_id,
                metadata=_metadata,
                workflow=_workflow,
                graph=_graph,
            )
        )
    # Process and save the images off the request thread.
    background_tasks.add_task(ApiDependencies.invoker.services.images.create_many, bulk_upload_id, upload_data_list)
    return BulkUploadImageResponse(sent=len(files), uploading=len(upload_data_list))
@images_router.post(
"/upload",
operation_id="upload_image",

View File

@@ -12,6 +12,11 @@ from invokeai.app.services.events.events_common import (
BulkDownloadErrorEvent,
BulkDownloadEventBase,
BulkDownloadStartedEvent,
BulkUploadCompletedEvent,
BulkUploadErrorEvent,
BulkUploadEventBase,
BulkUploadProgressEvent,
BulkUploadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
@@ -53,6 +58,13 @@ class BulkDownloadSubscriptionEvent(BaseModel):
bulk_download_id: str
class BulkUploadSubscriptionEvent(BaseModel):
    """Event data for subscribing to the socket.io bulk uploads room.

    This is a pydantic model to ensure the data is in the correct format."""

    bulk_upload_id: str  # room name a client joins to receive events for its bulk upload
QUEUE_EVENTS = {
InvocationStartedEvent,
InvocationProgressEvent,
@@ -80,6 +92,7 @@ MODEL_EVENTS = {
}
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
BULK_UPLOAD_EVENTS = {BulkUploadStartedEvent, BulkUploadCompletedEvent, BulkUploadProgressEvent, BulkUploadErrorEvent}
class SocketIO:
@@ -89,6 +102,9 @@ class SocketIO:
_sub_bulk_download = "subscribe_bulk_download"
_unsub_bulk_download = "unsubscribe_bulk_download"
_sub_bulk_upload = "subscribe_bulk_upload"
_unsub_bulk_upload = "unsubscribe_bulk_upload"
def __init__(self, app: FastAPI):
self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
@@ -98,10 +114,13 @@ class SocketIO:
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
self._sio.on(self._sub_bulk_upload, handler=self._handle_sub_bulk_upload)
self._sio.on(self._unsub_bulk_upload, handler=self._handle_unsub_bulk_upload)
register_events(QUEUE_EVENTS, self._handle_queue_event)
register_events(MODEL_EVENTS, self._handle_model_event)
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
register_events(BULK_UPLOAD_EVENTS, self._handle_bulk_image_upload_event)
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
@@ -115,6 +134,12 @@ class SocketIO:
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_sub_bulk_upload(self, sid: str, data: Any) -> None:
    # Join the client to the room named by bulk_upload_id so it receives that upload's events.
    await self._sio.enter_room(sid, BulkUploadSubscriptionEvent(**data).bulk_upload_id)
async def _handle_unsub_bulk_upload(self, sid: str, data: Any) -> None:
    # Remove the client from the bulk-upload room; payload is validated by the pydantic model.
    await self._sio.leave_room(sid, BulkUploadSubscriptionEvent(**data).bulk_upload_id)
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
@@ -123,3 +148,6 @@ class SocketIO:
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)
async def _handle_bulk_image_upload_event(self, event: FastAPIEvent[BulkUploadEventBase]) -> None:
    # Relay the bus event (name, payload) to the socket.io room matching its bulk_upload_id.
    await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_upload_id)

View File

@@ -1,13 +1,17 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Literal, Optional
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadStartedEvent,
BulkUploadCompletedEvent,
BulkUploadErrorEvent,
BulkUploadProgressEvent,
BulkUploadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
@@ -30,6 +34,7 @@ from invokeai.app.services.events.events_common import (
QueueClearedEvent,
QueueItemStatusChangedEvent,
)
from invokeai.app.services.images.images_common import ImageDTO
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
@@ -44,6 +49,8 @@ if TYPE_CHECKING:
)
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
UploadStatusType = Literal["started", "processing", "done", "error"]
class EventServiceBase:
"""Basic event bus, to have an empty stand-in when not needed"""
@@ -197,3 +204,23 @@ class EventServiceBase:
)
# endregion
# region Bulk image upload
def emit_bulk_upload_started(self, bulk_upload_id: str, total: int) -> None:
    """Emitted when a bulk image upload is started.

    `total` is the number of images queued for upload."""
    self.dispatch(BulkUploadStartedEvent.build(bulk_upload_id, total))
def emit_bulk_upload_progress(self, bulk_upload_id: str, completed: int, total: int) -> None:
    """Emitted each time progress is made during a bulk image upload."""
    self.dispatch(BulkUploadProgressEvent.build(bulk_upload_id, completed, total))
def emit_bulk_upload_complete(self, bulk_upload_id: str, total: int, image_DTO: ImageDTO) -> None:
    """Emitted when a bulk image upload is complete.

    `image_DTO` is one image from the upload so the client can refresh the correct data."""
    self.dispatch(BulkUploadCompletedEvent.build(bulk_upload_id, total=total, image_DTO=image_DTO))
def emit_bulk_upload_error(self, bulk_upload_id: str, error: str) -> None:
    """Emitted when a bulk image upload has an error.

    `error` is a human-readable message describing the failure."""
    self.dispatch(BulkUploadErrorEvent.build(bulk_upload_id, error))
# endregion

View File

@@ -4,6 +4,7 @@ from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
@@ -624,3 +625,80 @@ class BulkDownloadErrorEvent(BulkDownloadEventBase):
bulk_download_item_name=bulk_download_item_name,
error=error,
)
class BulkUploadEventBase(EventBase):
    """Base class for events associated with a bulk image upload."""

    # Description previously said "download" — copy-paste from BulkDownloadEventBase.
    bulk_upload_id: str = Field(description="The ID of the bulk image upload")
@payload_schema.register
class BulkUploadStartedEvent(BulkUploadEventBase):
    """Event model for bulk_upload_started."""

    __event_name__ = "bulk_upload_started"

    # Fixed "numberof" typo in the description.
    total: int = Field(description="The total number of images")

    @classmethod
    def build(cls, bulk_upload_id: str, total: int) -> "BulkUploadStartedEvent":
        """Construct the event from its fields."""
        return cls(bulk_upload_id=bulk_upload_id, total=total)
@payload_schema.register
class BulkUploadCompletedEvent(BulkUploadEventBase):
    """Event model for bulk_upload_completed."""

    __event_name__ = "bulk_upload_completed"

    # Fixed "numberof" typo in the description.
    total: int = Field(description="The total number of images")
    image_DTO: ImageDTO = Field(description="An image from the upload so client can refetch correctly")

    @classmethod
    def build(cls, bulk_upload_id: str, total: int, image_DTO: ImageDTO) -> "BulkUploadCompletedEvent":
        """Construct the event from its fields."""
        return cls(bulk_upload_id=bulk_upload_id, total=total, image_DTO=image_DTO)
@payload_schema.register
class BulkUploadProgressEvent(BulkUploadEventBase):
    """Event model for bulk_upload_progress"""

    __event_name__ = "bulk_upload_progress"

    completed: int = Field(description="The completed number of images")
    total: int = Field(description="The total number of images")

    @classmethod
    def build(cls, bulk_upload_id: str, completed: int, total: int) -> "BulkUploadProgressEvent":
        """Construct the event from its fields; argument names mirror the payload."""
        return cls(bulk_upload_id=bulk_upload_id, completed=completed, total=total)
@payload_schema.register
class BulkUploadErrorEvent(BulkUploadEventBase):
    """Event model for bulk_upload_error"""

    __event_name__ = "bulk_upload_error"

    error: str = Field(description="The error message")

    @classmethod
    def build(cls, bulk_upload_id: str, error: str) -> "BulkUploadErrorEvent":
        """Construct the event from its fields."""
        return cls(bulk_upload_id=bulk_upload_id, error=error)

View File

@@ -34,6 +34,7 @@ class ImageFileStorageBase(ABC):
workflow: Optional[str] = None,
graph: Optional[str] = None,
thumbnail_size: int = 256,
project_id: Optional[str] = None,
) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""
pass

View File

@@ -54,6 +54,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
workflow: Optional[str] = None,
graph: Optional[str] = None,
thumbnail_size: int = 256,
project_id: Optional[str] = None,
) -> None:
try:
self.__validate_storage_folders()

View File

@@ -89,6 +89,8 @@ class ImageRecordStorageBase(ABC):
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[str] = None,
user_id: Optional[str] = None,
project_id: Optional[str] = None,
) -> datetime:
"""Saves an image record."""
pass

View File

@@ -344,6 +344,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[str] = None,
user_id: Optional[str] = None,
project_id: Optional[str] = None,
) -> datetime:
try:
self._lock.acquire()

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Optional
from typing import Callable, List, Optional
from PIL.Image import Image as PILImageType
@@ -10,7 +10,7 @@ from invokeai.app.services.image_records.image_records_common import (
ImageRecordChanges,
ResourceOrigin,
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.images.images_common import ImageBulkUploadData, ImageDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
@@ -58,6 +58,11 @@ class ImageServiceABC(ABC):
"""Creates an image, storing the file and its metadata."""
pass
@abstractmethod
def create_many(self, bulk_upload_id: str, upload_data_list: List[ImageBulkUploadData]):
    """Save a batch of uploaded images, storing each image file and its metadata.

    Implementations report progress via bulk-upload events keyed by `bulk_upload_id`."""
    pass
@abstractmethod
def update(
self,

View File

@@ -1,6 +1,7 @@
from typing import Optional
from pydantic import Field
from PIL.Image import Image as PILImageType
from pydantic import BaseModel, Field
from invokeai.app.services.image_records.image_records_common import ImageRecord
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
@@ -39,3 +40,20 @@ def image_record_to_dto(
thumbnail_url=thumbnail_url,
board_id=board_id,
)
class ImageBulkUploadData(BaseModel):
    """Carrier for one image in a bulk upload: the decoded PIL image plus data
    gathered at request time; name and dimensions are filled in during saving."""

    image: PILImageType  # non-pydantic type; requires arbitrary_types_allowed below
    image_name: Optional[str] = None  # assigned by the name service during save
    image_url: Optional[str] = None
    board_id: Optional[str] = None  # board to add the image to, if any
    metadata: Optional[str] = None  # raw invokeai_metadata string from the file, if any
    workflow: Optional[str] = None  # raw invokeai_workflow string, if any
    graph: Optional[str] = None  # raw invokeai_graph string, if any
    width: Optional[int] = None  # set from image.size during save
    height: Optional[int] = None  # set from image.size during save
    user_id: Optional[str] = None
    project_id: Optional[str] = None

    class Config:
        # Required so pydantic accepts the PIL Image type.
        # NOTE(review): pydantic v2 prefers `model_config = ConfigDict(arbitrary_types_allowed=True)`
        # — confirm which pydantic version the project targets.
        arbitrary_types_allowed = True

View File

@@ -1,6 +1,10 @@
from typing import Optional
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock
from typing import List, Optional
from PIL.Image import Image as PILImageType
from tqdm import tqdm
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_files.image_files_common import (
@@ -20,7 +24,7 @@ from invokeai.app.services.image_records.image_records_common import (
ResourceOrigin,
)
from invokeai.app.services.images.images_base import ImageServiceABC
from invokeai.app.services.images.images_common import ImageDTO, image_record_to_dto
from invokeai.app.services.images.images_common import ImageBulkUploadData, ImageDTO, image_record_to_dto
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
@@ -96,6 +100,99 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error(f"Problem saving image record and file: {str(e)}")
raise e
def create_many(self, bulk_upload_id: str, upload_data_list: List[ImageBulkUploadData]) -> None:
    """Save a batch of uploaded images concurrently, emitting socket events for progress.

    Emits `started` once, `progress` after each saved image, then either `completed`
    (with one ImageDTO so the client can refresh the correct data) or, if no image
    was saved successfully, a final `error` event. Individual failures are reported
    as `error` events but do not abort the rest of the batch.
    """
    total_images = len(upload_data_list)
    processed_counter = 0  # successfully saved images; guarded by progress_lock
    images_DTOs: list[ImageDTO] = []  # DTOs for successful uploads
    progress_lock = Lock()

    self.__invoker.services.events.emit_bulk_upload_started(
        bulk_upload_id=bulk_upload_id,
        total=total_images,
    )

    def process_and_save_image(image_data: ImageBulkUploadData) -> ImageDTO:
        """Save one image's record, optional board link, and file; return its DTO."""
        nonlocal processed_counter  # refer to the counter in the enclosing scope
        try:
            width, height = image_data.image.size
            image_data.width = width
            image_data.height = height
            image_name = self.__invoker.services.names.create_image_name()
            image_data.image_name = image_name
            self.__invoker.services.image_records.save(
                image_name=image_data.image_name,
                image_origin=ResourceOrigin.EXTERNAL,
                image_category=ImageCategory.USER,
                width=image_data.width,
                height=image_data.height,
                has_workflow=image_data.workflow is not None or image_data.graph is not None,
                is_intermediate=False,
                metadata=image_data.metadata,
                user_id=image_data.user_id,
                project_id=image_data.project_id,
            )
            if image_data.board_id is not None:
                self.__invoker.services.board_image_records.add_image_to_board(
                    board_id=image_data.board_id, image_name=image_data.image_name
                )
            self.__invoker.services.image_files.save(
                image_name=image_data.image_name,
                image=image_data.image,
                metadata=image_data.metadata,
                workflow=image_data.workflow,
                graph=image_data.graph,
                project_id=image_data.project_id,
            )
            image_dto = self.get_dto(image_data.image_name)
            self._on_changed(image_dto)
            with progress_lock:
                processed_counter += 1
            return image_dto
        except ImageRecordSaveException:
            self.__invoker.services.logger.error("Failed to save image record")
            raise
        except ImageFileSaveException:
            self.__invoker.services.logger.error("Failed to save image file")
            raise
        except Exception as e:
            self.__invoker.services.logger.error(f"Problem processing and saving image: {str(e)}")
            raise e

    # Leave one core free for the rest of the app.
    num_cores = os.cpu_count() or 1
    num_workers = max(1, num_cores - 1)
    pbar = tqdm(total=total_images, desc="Processing Images", unit="images", colour="green")
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(process_and_save_image, image) for image in upload_data_list]
        for future in as_completed(futures):
            try:
                image_DTO = future.result()
                images_DTOs.append(image_DTO)
                pbar.update(1)  # Update progress bar
                # Snapshot the counter under the lock (previously read unlocked).
                with progress_lock:
                    completed = processed_counter
                self.__invoker.services.events.emit_bulk_upload_progress(
                    bulk_upload_id=bulk_upload_id,
                    completed=completed,
                    total=total_images,
                )
            except Exception as e:
                self.__invoker.services.logger.error(f"Error in processing image: {str(e)}")
                self.__invoker.services.events.emit_bulk_upload_error(bulk_upload_id=bulk_upload_id, error=str(e))
    pbar.close()
    if images_DTOs:
        self.__invoker.services.events.emit_bulk_upload_complete(
            bulk_upload_id=bulk_upload_id, total=len(images_DTOs), image_DTO=images_DTOs[0]
        )
    else:
        # Previously images_DTOs[0] raised IndexError when every upload failed;
        # report the failure to the client instead.
        self.__invoker.services.events.emit_bulk_upload_error(
            bulk_upload_id=bulk_upload_id, error="No images were uploaded successfully"
        )
def update(
self,
image_name: str,

View File

@@ -324,6 +324,15 @@
"bulkDownloadRequestedDesc": "Your download request is being prepared. This may take a few moments.",
"bulkDownloadRequestFailed": "Problem Preparing Download",
"bulkDownloadFailed": "Download Failed",
"bulkUploadRequested": "Preparing Upload",
"bulkUploadStarted": "Uploading Images",
"bulkUploadStartedDesc": "Starting upload of {{x}} images",
"bulkUploadProgressDesc": "Uploading {{y}} of {{x}} images",
"bulkUploadComplete": "Upload Complete",
"bulkUploadCompleteDesc": "Successfully uploaded {{x}} images.",
"bulkUploadRequestFailed": "Problem Preparing Download",
"bulkUploadFailed": "Upload Failed or Partially Failed",
"viewerImage": "Viewer Image",
"compareImage": "Compare Image",
"openInViewer": "Open in Viewer",

View File

@@ -78,22 +78,9 @@ export const addImageUploadedFulfilledListener = (startAppListening: AppStartLis
lastUploadedToastTimeout = window.setTimeout(() => {
toastApi.close();
}, 3000);
/**
* We only want to change the board and view if this is the first upload of a batch, else we end up hijacking
* the user's gallery board and view selection:
* - User uploads multiple images
* - A couple uploads finish, but others are pending still
* - User changes the board selection
* - Pending uploads finish and change the board back to the original board
* - User is confused as to why the board changed
*
* Default to true to not require _all_ image upload handlers to set this value
*/
const isFirstUploadOfBatch = action.meta.arg.originalArgs.isFirstUploadOfBatch ?? true;
if (isFirstUploadOfBatch) {
dispatch(boardIdSelected({ boardId }));
dispatch(galleryViewChanged('assets'));
}
dispatch(boardIdSelected({ boardId }));
dispatch(galleryViewChanged('assets'));
return;
}

View File

@@ -1,4 +1,5 @@
import { logger } from 'app/logging/logger';
import { $queueId } from 'app/store/nanostores/queueId';
import { useAppSelector } from 'app/store/storeHooks';
import { useAssertSingleton } from 'common/hooks/useAssertSingleton';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
@@ -9,7 +10,7 @@ import { useCallback, useEffect, useState } from 'react';
import type { Accept, FileRejection } from 'react-dropzone';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { useUploadImageMutation } from 'services/api/endpoints/images';
import { useBulkUploadImagesMutation, useUploadImageMutation } from 'services/api/endpoints/images';
import type { PostUploadAction } from 'services/api/types';
const log = logger('gallery');
@@ -25,6 +26,7 @@ export const useFullscreenDropzone = () => {
const autoAddBoardId = useAppSelector(selectAutoAddBoardId);
const [isHandlingUpload, setIsHandlingUpload] = useState<boolean>(false);
const [uploadImage] = useUploadImageMutation();
const [bulkUploadImages] = useBulkUploadImagesMutation();
const activeTabName = useAppSelector(selectActiveTab);
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
@@ -37,7 +39,7 @@ export const useFullscreenDropzone = () => {
}, [activeTabName]);
const onDrop = useCallback(
(acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
async (acceptedFiles: Array<File>, fileRejections: Array<FileRejection>) => {
if (fileRejections.length > 0) {
const errors = fileRejections.map((rejection) => ({
errors: rejection.errors.map(({ message }) => message),
@@ -60,22 +62,39 @@ export const useFullscreenDropzone = () => {
return;
}
for (const [i, file] of acceptedFiles.entries()) {
if (acceptedFiles.length > 1) {
try {
toast({
id: 'BULK_UPLOAD',
title: t('gallery.bulkUploadRequested'),
status: 'info',
duration: null,
});
await bulkUploadImages({
bulk_upload_id: $queueId.get(),
files: acceptedFiles,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
}).unwrap();
} catch (error) {
toast({
status: 'error',
title: t('gallery.bulkUploadRequestFailed'),
});
throw error;
}
} else if (acceptedFiles[0]) {
uploadImage({
file,
file: acceptedFiles[0],
image_category: 'user',
is_intermediate: false,
postUploadAction: getPostUploadAction(),
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
// The `imageUploaded` listener does some extra logic, like switching to the asset view on upload on the
// first upload of a "batch".
isFirstUploadOfBatch: i === 0,
});
}
setIsHandlingUpload(false);
},
[t, maxImageUploadCount, uploadImage, getPostUploadAction, autoAddBoardId]
[t, maxImageUploadCount, uploadImage, getPostUploadAction, autoAddBoardId, bulkUploadImages]
);
const onDragOver = useCallback(() => {

View File

@@ -1,4 +1,5 @@
import { logger } from 'app/logging/logger';
import { $queueId } from 'app/store/nanostores/queueId';
import { useAppSelector } from 'app/store/storeHooks';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectMaxImageUploadCount } from 'features/system/store/configSlice';
@@ -7,7 +8,7 @@ import { useCallback } from 'react';
import type { FileRejection } from 'react-dropzone';
import { useDropzone } from 'react-dropzone';
import { useTranslation } from 'react-i18next';
import { useUploadImageMutation } from 'services/api/endpoints/images';
import { useBulkUploadImagesMutation, useUploadImageMutation } from 'services/api/endpoints/images';
import type { PostUploadAction } from 'services/api/types';
type UseImageUploadButtonArgs = {
@@ -46,21 +47,41 @@ export const useImageUploadButton = ({
const [uploadImage] = useUploadImageMutation();
const maxImageUploadCount = useAppSelector(selectMaxImageUploadCount);
const { t } = useTranslation();
const [bulkUploadImages] = useBulkUploadImagesMutation();
const onDropAccepted = useCallback(
(files: File[]) => {
for (const [i, file] of files.entries()) {
async (files: File[]) => {
if (files.length > 1) {
try {
toast({
id: 'BULK_UPLOAD',
title: t('gallery.bulkUploadRequested'),
status: 'info',
duration: null,
});
await bulkUploadImages({
bulk_upload_id: $queueId.get(),
files,
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
}).unwrap();
} catch (error) {
toast({
status: 'error',
title: t('gallery.bulkUploadRequestFailed'),
});
throw error;
}
} else if (files[0]) {
uploadImage({
file,
file: files[0],
image_category: 'user',
is_intermediate: false,
postUploadAction: postUploadAction ?? { type: 'TOAST' },
board_id: autoAddBoardId === 'none' ? undefined : autoAddBoardId,
isFirstUploadOfBatch: i === 0,
});
}
},
[autoAddBoardId, postUploadAction, uploadImage]
[autoAddBoardId, postUploadAction, uploadImage, bulkUploadImages, t]
);
const onDropRejected = useCallback(

View File

@@ -5,6 +5,7 @@ import type { BoardId } from 'features/gallery/store/types';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
import type { components, paths } from 'services/api/schema';
import type {
BulkUploadImageResponse,
DeleteBoardResult,
GraphAndWorkflowResponse,
ImageCategory,
@@ -272,7 +273,6 @@ export const imagesApi = api.injectEndpoints({
board_id?: string;
crop_visible?: boolean;
metadata?: SerializableObject;
isFirstUploadOfBatch?: boolean;
}
>({
query: ({ file, image_category, is_intermediate, session_id, board_id, crop_visible, metadata }) => {
@@ -321,6 +321,31 @@ export const imagesApi = api.injectEndpoints({
];
},
}),
bulkUploadImages: build.mutation<
BulkUploadImageResponse,
{
bulk_upload_id: string;
files: File[];
board_id?: string;
}
>({
query: ({ bulk_upload_id, files, board_id }) => {
const formData = new FormData();
for (const file of files) {
formData.append('files', file);
}
return {
url: buildImagesUrl('bulk-upload'),
method: 'POST',
body: formData,
params: {
bulk_upload_id,
board_id: board_id === 'none' ? undefined : board_id,
},
};
},
}),
deleteBoard: build.mutation<DeleteBoardResult, string>({
query: (board_id) => ({ url: buildBoardsUrl(board_id), method: 'DELETE' }),
@@ -556,6 +581,7 @@ export const {
useGetImageWorkflowQuery,
useLazyGetImageWorkflowQuery,
useUploadImageMutation,
useBulkUploadImagesMutation,
useClearIntermediatesMutation,
useAddImagesToBoardMutation,
useRemoveImagesFromBoardMutation,

View File

@@ -432,6 +432,26 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v1/images/bulk-upload": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
/**
* Bulk Upload
* @description Uploads multiple images
*/
post: operations["bulk_upload"];
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/images/upload": {
parameters: {
query?: never;
@@ -2065,6 +2085,11 @@ export type components = {
*/
image_names: string[];
};
/** Body_bulk_upload */
Body_bulk_upload: {
/** Files */
files: Blob[];
};
/** Body_cancel_by_batch_ids */
Body_cancel_by_batch_ids: {
/**
@@ -2555,6 +2580,104 @@ export type components = {
*/
bulk_download_item_name: string;
};
/**
* BulkUploadCompletedEvent
* @description Event model for bulk_upload_completed
*/
BulkUploadCompletedEvent: {
/**
* Timestamp
* @description The timestamp of the event
*/
timestamp: number;
/**
* Bulk Upload Id
* @description The ID of the bulk image download
*/
bulk_upload_id: string;
/**
* Total
* @description The total numberof images
*/
total: number;
/** @description An image from the upload so client can refetch correctly */
image_DTO: components["schemas"]["ImageDTO"];
};
/**
* BulkUploadErrorEvent
* @description Event model for bulk_upload_error
*/
BulkUploadErrorEvent: {
/**
* Timestamp
* @description The timestamp of the event
*/
timestamp: number;
/**
* Bulk Upload Id
* @description The ID of the bulk image download
*/
bulk_upload_id: string;
/**
* Error
* @description The error message
*/
error: string;
};
/** BulkUploadImageResponse */
BulkUploadImageResponse: {
/** Sent */
sent: number;
/** Uploading */
uploading: number;
};
/**
* BulkUploadProgressEvent
* @description Event model for bulk_upload_progress
*/
BulkUploadProgressEvent: {
/**
* Timestamp
* @description The timestamp of the event
*/
timestamp: number;
/**
* Bulk Upload Id
* @description The ID of the bulk image download
*/
bulk_upload_id: string;
/**
* Completed
* @description The completed number of images
*/
completed: number;
/**
* Total
* @description The total number of images
*/
total: number;
};
/**
* BulkUploadStartedEvent
* @description Event model for bulk_upload_started
*/
BulkUploadStartedEvent: {
/**
* Timestamp
* @description The timestamp of the event
*/
timestamp: number;
/**
* Bulk Upload Id
* @description The ID of the bulk image download
*/
bulk_upload_id: string;
/**
* Total
* @description The total numberof images
*/
total: number;
};
/**
* CLIPEmbedDiffusersConfig
* @description Model config for Clip Embeddings.
@@ -18273,6 +18396,50 @@ export interface operations {
};
};
};
bulk_upload: {
parameters: {
query: {
bulk_upload_id: string;
/** @description The board to add this images to, if any */
board_id?: string | null;
};
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"multipart/form-data": components["schemas"]["Body_bulk_upload"];
};
};
responses: {
/** @description The images are being prepared for upload */
201: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["BulkUploadImageResponse"];
};
};
/** @description Images upload failed */
415: {
headers: {
[name: string]: unknown;
};
content?: never;
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
upload_image: {
parameters: {
query: {

View File

@@ -244,3 +244,5 @@ export type PostUploadAction =
export type BoardRecordOrderBy = S['BoardRecordOrderBy'];
export type StarterModel = S['StarterModel'];
export type BulkUploadImageResponse = S['BulkUploadImageResponse'];

View File

@@ -1,4 +1,4 @@
import { ExternalLink } from '@invoke-ai/ui-library';
import { ExternalLink, Flex, Progress, Text } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { socketConnected } from 'app/store/middleware/listenerMiddleware/listeners/socketConnected';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
@@ -15,8 +15,11 @@ import { t } from 'i18next';
import { forEach, isNil, round } from 'lodash-es';
import type { ApiTagDescription } from 'services/api';
import { api, LIST_TAG } from 'services/api';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import { modelsApi } from 'services/api/endpoints/models';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { buildOnInvocationComplete } from 'services/events/onInvocationComplete';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import type { Socket } from 'socket.io-client';
@@ -42,6 +45,7 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
dispatch(socketConnected());
const queue_id = $queueId.get();
socket.emit('subscribe_queue', { queue_id });
socket.emit('subscribe_bulk_upload', { bulk_upload_id: $queueId.get() });
if (!$baseUrl.get()) {
const bulk_download_id = $bulkDownloadId.get();
socket.emit('subscribe_bulk_download', { bulk_download_id });
@@ -485,4 +489,100 @@ export const setEventListeners = ({ socket, store, setIsConnected }: SetEventLis
duration: null,
});
});
socket.on('bulk_upload_started', (data) => {
  // A bulk gallery upload batch has begun. Show a persistent toast (fixed id
  // 'BULK_UPLOAD') that the later progress/completed/error events update in place.
  log.debug({ data }, 'Bulk gallery upload started');
  const description = (
    <Flex flexDir="column" gap={2}>
      <Text>{t('gallery.bulkUploadStartedDesc', { x: data.total })}</Text>
      <Progress value={0} />
    </Flex>
  );
  toast({
    id: 'BULK_UPLOAD',
    title: t('gallery.bulkUploadStarted'),
    status: 'info',
    updateDescription: true,
    withCount: false,
    description,
    duration: null, // keep the toast open until a terminal event dismisses/updates it
  });
});
socket.on('bulk_upload_progress', (data) => {
  // Fix copy-pasted log message: this is the progress event, not "ready".
  log.debug({ data }, 'Bulk gallery upload progress');
  const { completed, total } = data;
  // Guard against a zero-image batch: completed / 0 would render NaN progress.
  const percentComplete = total > 0 ? (completed / total) * 100 : 0;
  toast({
    id: 'BULK_UPLOAD',
    // NOTE(review): intentionally reuses the "started" title so the single
    // BULK_UPLOAD toast keeps a stable heading while only the description updates.
    title: t('gallery.bulkUploadStarted'),
    status: 'info',
    updateDescription: true,
    withCount: false,
    description: (
      <Flex flexDir="column" gap={2}>
        <Text>{t('gallery.bulkUploadProgressDesc', { x: total, y: completed })}</Text>
        <Progress value={percentComplete} />
      </Flex>
    ),
    duration: null,
  });
});
socket.on('bulk_upload_completed', (data) => {
  // Fix copy-pasted log message: this is the completion event, not "ready".
  log.debug({ data }, 'Bulk gallery upload completed');
  const { total, image_DTO: imageDTO } = data;
  // Images uploaded without a board land in the virtual 'none' board.
  const boardId = imageDTO.board_id ?? 'none';
  toast({
    id: 'BULK_UPLOAD',
    title: t('gallery.bulkUploadComplete'),
    status: 'success',
    updateDescription: true,
    withCount: false,
    description: (
      <Flex flexDir="column" gap={2}>
        <Text>{t('gallery.bulkUploadCompleteDesc', { x: total })}</Text>
        <Progress value={100} />
      </Flex>
    ),
    duration: null,
  });
  // Keep the gallery cache in sync with the newly uploaded images.
  // NOTE(review): this bumps the board total by 1 even though `total` images
  // were uploaded — presumably acceptable because the invalidation below
  // forces a refetch of the real counts; confirm this is intended.
  dispatch(
    boardsApi.util.updateQueryData('getBoardImagesTotal', boardId, (draft) => {
      draft.total += 1;
    })
  );
  // Invalidate the affected board and image-list queries so they refetch.
  dispatch(
    imagesApi.util.invalidateTags([
      { type: 'Board', id: boardId },
      {
        type: 'ImageList',
        id: getListImagesUrl({
          board_id: boardId,
          categories: getCategories(imageDTO),
        }),
      },
    ])
  );
});
socket.on('bulk_upload_error', (data) => {
  // Fix copy-pasted log message: this is the bulk *upload* error, not download.
  log.error({ data }, 'Bulk gallery upload error');
  const { error } = data;
  // Surface the server-provided error string in the persistent BULK_UPLOAD toast.
  toast({
    id: 'BULK_UPLOAD',
    title: t('gallery.bulkUploadFailed'),
    status: 'error',
    updateDescription: true,
    withCount: false,
    description: error,
    duration: null,
  });
});
};

View File

@@ -5,6 +5,8 @@ type ClientEmitSubscribeQueue = { queue_id: string };
type ClientEmitUnsubscribeQueue = ClientEmitSubscribeQueue;
type ClientEmitSubscribeBulkDownload = { bulk_download_id: string };
type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload;
// Payloads for the client-emitted events that join/leave a bulk-upload socket
// room, keyed by the client-chosen bulk_upload_id (mirrors the queue and
// bulk-download subscribe/unsubscribe pairs above).
type ClientEmitSubscribeBulkUpload = { bulk_upload_id: string };
type ClientEmitUnsubscribeBulkUpload = ClientEmitSubscribeBulkUpload;
export type ServerToClientEvents = {
invocation_progress: (payload: S['InvocationProgressEvent']) => void;
@@ -31,6 +33,10 @@ export type ServerToClientEvents = {
bulk_download_started: (payload: S['BulkDownloadStartedEvent']) => void;
bulk_download_complete: (payload: S['BulkDownloadCompleteEvent']) => void;
bulk_download_error: (payload: S['BulkDownloadErrorEvent']) => void;
bulk_upload_started: (payload: S['BulkUploadStartedEvent']) => void;
bulk_upload_completed: (payload: S['BulkUploadCompletedEvent']) => void;
bulk_upload_progress: (payload: S['BulkUploadProgressEvent']) => void;
bulk_upload_error: (payload: S['BulkUploadErrorEvent']) => void;
};
export type ClientToServerEvents = {
@@ -40,6 +46,8 @@ export type ClientToServerEvents = {
unsubscribe_queue: (payload: ClientEmitUnsubscribeQueue) => void;
subscribe_bulk_download: (payload: ClientEmitSubscribeBulkDownload) => void;
unsubscribe_bulk_download: (payload: ClientEmitUnsubscribeBulkDownload) => void;
subscribe_bulk_upload: (payload: ClientEmitSubscribeBulkUpload) => void;
unsubscribe_bulk_upload: (payload: ClientEmitUnsubscribeBulkUpload) => void;
};
export type AppSocket = Socket<ServerToClientEvents, ClientToServerEvents>;