diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index 08ed2edbf8..bdc6f287f5 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -49,6 +49,38 @@ class BulkDownloadSubscriptionEvent(BaseModel): bulk_download_id: str +QUEUE_EVENTS = { + InvocationStartedEvent, + InvocationDenoiseProgressEvent, + InvocationCompleteEvent, + InvocationErrorEvent, + SessionStartedEvent, + SessionCompleteEvent, + SessionCanceledEvent, + QueueItemStatusChangedEvent, + BatchEnqueuedEvent, + QueueClearedEvent, +} + +MODEL_EVENTS = { + DownloadCancelledEvent, + DownloadCompleteEvent, + DownloadErrorEvent, + DownloadProgressEvent, + DownloadStartedEvent, + ModelLoadStartedEvent, + ModelLoadCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, + ModelInstallStartedEvent, + ModelInstallCompleteEvent, + ModelInstallCancelledEvent, + ModelInstallErrorEvent, +} + +BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent} + + class SocketIO: _sub_queue = "subscribe_queue" _unsub_queue = "unsubscribe_queue" @@ -66,45 +98,9 @@ class SocketIO: self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download) self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download) - register_events( - { - InvocationStartedEvent, - InvocationDenoiseProgressEvent, - InvocationCompleteEvent, - InvocationErrorEvent, - SessionStartedEvent, - SessionCompleteEvent, - SessionCanceledEvent, - QueueItemStatusChangedEvent, - BatchEnqueuedEvent, - QueueClearedEvent, - }, - self._handle_queue_event, - ) - - register_events( - { - DownloadCancelledEvent, - DownloadCompleteEvent, - DownloadErrorEvent, - DownloadProgressEvent, - DownloadStartedEvent, - ModelLoadStartedEvent, - ModelLoadCompleteEvent, - ModelInstallDownloadProgressEvent, - ModelInstallDownloadsCompleteEvent, - ModelInstallStartedEvent, - ModelInstallCompleteEvent, - ModelInstallCancelledEvent, - ModelInstallErrorEvent, - }, - self._handle_model_event, - ) - - register_events( - {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}, - self._handle_bulk_image_download_event, - ) + register_events(QUEUE_EVENTS, self._handle_queue_event) + register_events(MODEL_EVENTS, self._handle_model_event) + register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event) async def _handle_sub_queue(self, sid: str, data: Any) -> None: await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)