diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index 6aedf47a1b..6f064df6cd 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -31,6 +31,7 @@ from invokeai.app.services.events.events_common import ( ModelInstallErrorEvent, ModelInstallStartedEvent, ModelLoadCompleteEvent, + ModelLoadEventBase, ModelLoadStartedEvent, QueueClearedEvent, QueueEventBase, @@ -134,7 +135,7 @@ class SocketIO: async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]): await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id) - async def _handle_model_load_event(self, event: FastAPIEvent[ModelEventBase]) -> None: + async def _handle_model_load_event(self, event: FastAPIEvent[ModelLoadEventBase]) -> None: await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id) async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None: diff --git a/invokeai/app/services/events/events_common.py b/invokeai/app/services/events/events_common.py index 7e4c17b800..4d70d14b92 100644 --- a/invokeai/app/services/events/events_common.py +++ b/invokeai/app/services/events/events_common.py @@ -383,20 +383,19 @@ class DownloadErrorEvent(DownloadEventBase): return cls(source=str(job.source), error_type=job.error_type, error=job.error) -class ModelEventBase(EventBase): +class ModelLoadEventBase(EventBase): """Base class for queue events""" queue_id: str = Field(description="The ID of the queue") @payload_schema.register -class ModelLoadStartedEvent(ModelEventBase): +class ModelLoadStartedEvent(ModelLoadEventBase): """Event model for model_load_started""" __event_name__ = "model_load_started" config: AnyModelConfig = Field(description="The model's config") - queue_id: str = Field(description="Queue ID to emit to") submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any") @classmethod @@ -405,13 +404,12 @@ class ModelLoadStartedEvent(ModelEventBase): @payload_schema.register -class ModelLoadCompleteEvent(ModelEventBase): +class ModelLoadCompleteEvent(ModelLoadEventBase): """Event model for model_load_complete""" __event_name__ = "model_load_complete" config: AnyModelConfig = Field(description="The model's config") - queue_id: str = Field(description="Queue ID to emit to") submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any") @classmethod @@ -419,6 +417,9 @@ class ModelLoadCompleteEvent(ModelEventBase): return cls(config=config, queue_id=queue_id, submodel_type=submodel_type) +class ModelEventBase(EventBase): + """Base class for model events""" + @payload_schema.register class ModelInstallDownloadStartedEvent(ModelEventBase): """Event model for model_install_download_started"""