Compare commits

..

2 Commits

Author SHA1 Message Date
psychedelicious
66c9f4708d Update invokeai_version.py 2024-05-21 07:11:09 +10:00
steffylo
32277193b6 fix(ui): retain denoise strength and opacity when changing image 2024-05-20 18:27:51 +10:00
54 changed files with 1876 additions and 3018 deletions

View File

@@ -18,7 +18,6 @@ from ..services.boards.boards_default import BoardService
from ..services.bulk_download.bulk_download_default import BulkDownloadService
from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService
from ..services.events.events_fastapievents import FastAPIEventService
from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService
@@ -34,6 +33,7 @@ from ..services.session_processor.session_processor_default import DefaultSessio
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
# TODO: is there a better way to achieve this?

View File

@@ -0,0 +1,52 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from typing import Any
from fastapi_events.dispatcher import dispatch
from ..services.events.events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
event_handler_id: int
__queue: Queue
__stop_event: threading.Event
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self.__queue = Queue()
self.__stop_event = threading.Event()
asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self.__stop_event.set()
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put({"event_name": event_name, "payload": payload})
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self.__queue.get(block=False)
if not event: # Probably stopping
continue
dispatch(
event.get("event_name"),
payload=event.get("payload"),
middleware_id=self.event_handler_id,
)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,

View File

@@ -203,7 +203,6 @@ async def get_batch_status(
responses={
200: {"model": SessionQueueItem},
},
response_model_exclude_none=True,
)
async def get_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),

View File

@@ -1,131 +1,66 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from socketio import ASGIApp, AsyncServer
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadEventBase,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
FastAPIEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelEventBase,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEventBase,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
SessionCompleteEvent,
SessionStartedEvent,
register_events,
)
class QueueSubscriptionEvent(BaseModel):
queue_id: str
class BulkDownloadSubscriptionEvent(BaseModel):
bulk_download_id: str
from ..services.events.events_base import EventServiceBase
class SocketIO:
_sub_queue = "subscribe_queue"
_unsub_queue = "unsubscribe_queue"
__sio: AsyncServer
__app: ASGIApp
_sub_bulk_download = "subscribe_bulk_download"
_unsub_bulk_download = "unsubscribe_bulk_download"
__sub_queue: str = "subscribe_queue"
__unsub_queue: str = "unsubscribe_queue"
__sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI):
self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
app.mount("/ws", self._app)
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
app.mount("/ws", self.__app)
self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue)
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
register_events(
{
InvocationStartedEvent,
InvocationDenoiseProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
SessionStartedEvent,
SessionCompleteEvent,
SessionCanceledEvent,
QueueItemStatusChangedEvent,
BatchEnqueuedEvent,
QueueClearedEvent,
},
self._handle_queue_event,
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)
async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["queue_id"],
)
register_events(
{
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
ModelLoadStartedEvent,
ModelLoadCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
ModelInstallCompleteEvent,
ModelInstallCancelledEvent,
ModelInstallErrorEvent,
},
self._handle_model_event,
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.enter_room(sid, data["queue_id"])
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.leave_room(sid, data["queue_id"])
async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
async def _handle_bulk_download_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["bulk_download_id"],
)
register_events(
{BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent},
self._handle_bulk_image_download_event,
)
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.enter_room(sid, data["bulk_download_id"])
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.queue_id)
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase]) -> None:
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"))
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
event_name, payload = event
await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.bulk_download_id)
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.leave_room(sid, data["bulk_download_id"])

View File

@@ -27,7 +27,6 @@ import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import TorchDevice
@@ -183,14 +182,23 @@ def custom_openapi() -> dict[str, Any]:
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
invoker_schema["class"] = "invocation"
# Add all event schemas
for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
if "$defs" in json_schema:
for schema_key, schema in json_schema["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema
del json_schema["$defs"]
openapi_schema["components"]["schemas"][event.__name__] = json_schema
# This code no longer seems to be necessary?
# Leave it here just in case
#
# from invokeai.backend.model_manager import get_model_config_formats
# formats = get_model_config_formats()
# for model_config_name, enum_set in formats.items():
# if model_config_name in openapi_schema["components"]["schemas"]:
# # print(f"Config with name {name} already defined")
# continue
# openapi_schema["components"]["schemas"][model_config_name] = {
# "title": model_config_name,
# "description": "An enumeration.",
# "type": "string",
# "enum": [v.value for v in enum_set],
# }
app.openapi_schema = openapi_schema
return app.openapi_schema

View File

@@ -106,7 +106,9 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
self._invoker.services.events.emit_bulk_download_started(
bulk_download_id, bulk_download_item_id, bulk_download_item_name
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
def _signal_job_completed(
@@ -116,8 +118,10 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
assert bulk_download_item_name is not None
self._invoker.services.events.emit_bulk_download_complete(
bulk_download_id, bulk_download_item_id, bulk_download_item_name
self._invoker.services.events.emit_bulk_download_completed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
def _signal_job_failed(
@@ -127,8 +131,11 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
assert exception is not None
self._invoker.services.events.emit_bulk_download_error(
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
self._invoker.services.events.emit_bulk_download_failed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=str(exception),
)
def stop(self, *args, **kwargs):

View File

@@ -8,13 +8,14 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Optional, Set
import requests
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.util.logging import InvokeAILogger
@@ -29,9 +30,6 @@ from .download_base import (
UnknownJobIDException,
)
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
# Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000
@@ -42,7 +40,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
def __init__(
self,
max_parallel_dl: int = 5,
event_bus: Optional["EventServiceBase"] = None,
event_bus: Optional[EventServiceBase] = None,
requests_session: Optional[requests.sessions.Session] = None,
):
"""
@@ -345,7 +343,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_started(job)
assert job.download_path
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
def _signal_job_progress(self, job: DownloadJob) -> None:
if job.on_progress:
@@ -356,7 +355,13 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_progress(job)
assert job.download_path
self._event_bus.emit_download_progress(
str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED
@@ -368,7 +373,10 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_complete(job)
assert job.download_path
self._event_bus.emit_download_complete(
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
)
def _signal_job_cancelled(self, job: DownloadJob) -> None:
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
@@ -382,7 +390,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_cancelled(job)
self._event_bus.emit_download_cancelled(str(job.source))
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
job.status = DownloadJobStatus.ERROR
@@ -395,7 +403,9 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_error(job)
assert job.error_type
assert job.error
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")

View File

@@ -1,253 +1,490 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import TYPE_CHECKING, Optional
from typing import Any, Dict, List, Optional, Union
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
EventBase,
ExtraData,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
SessionCompleteEvent,
SessionStartedEvent,
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.config import SubModelType
class EventServiceBase:
queue_event: str = "queue_event"
bulk_download_event: str = "bulk_download_event"
download_event: str = "download_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event: "EventBase") -> None:
def dispatch(self, event_name: str, payload: Any) -> None:
pass
# region: Invocation
def emit_invocation_started(
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", extra: Optional[ExtraData] = None
) -> None:
"""Emitted when an invocation is started"""
self.dispatch(InvocationStartedEvent.build(queue_item, invocation, extra))
def emit_invocation_denoise_progress(
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
intermediate_state: PipelineIntermediateState,
progress_image: "ProgressImage",
extra: Optional[ExtraData] = None,
) -> None:
"""Emitted at each step during denoising of an invocation."""
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
"""Bulk download events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image, extra)
event_name=EventServiceBase.bulk_download_event,
payload={"event": event_name, "data": payload},
)
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
"""Queue events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.queue_event,
payload={"event": event_name, "data": payload},
)
def __emit_download_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.download_event,
payload={"event": event_name, "data": payload},
)
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.model_event,
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node_id: str,
source_node_id: str,
progress_image: Optional[ProgressImage],
step: int,
order: int,
total_steps: int,
) -> None:
"""Emitted when there is generation progress"""
self.__emit_queue_event(
event_name="generator_progress",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"source_node_id": source_node_id,
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
"step": step,
"order": order,
"total_steps": total_steps,
},
)
def emit_invocation_complete(
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
output: "BaseInvocationOutput",
extra: Optional[ExtraData] = None,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
result: dict,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation is complete"""
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output, extra))
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"result": result,
},
)
def emit_invocation_error(
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
error_type: str,
error: str,
extra: Optional[ExtraData] = None,
user_id: str | None,
project_id: str | None,
) -> None:
"""Emitted when an invocation encounters an error"""
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error, extra))
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"error_type": error_type,
"error": error,
"user_id": user_id,
"project_id": project_id,
},
)
# endregion
# region Session
def emit_session_started(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a session has started"""
self.dispatch(SessionStartedEvent.build(queue_item, extra))
def emit_session_complete(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a session has completed all invocations"""
self.dispatch(SessionCompleteEvent.build(queue_item, extra))
def emit_session_canceled(self, queue_item: "SessionQueueItem", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a session is canceled"""
self.dispatch(SessionCanceledEvent.build(queue_item, extra))
# endregion
# region Queue
def emit_queue_item_status_changed(
def emit_invocation_started(
self,
queue_item: "SessionQueueItem",
batch_status: "BatchStatus",
queue_status: "SessionQueueStatus",
extra: Optional[ExtraData] = None,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when a queue item's status changes"""
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status, extra))
"""Emitted when an invocation has started"""
self.__emit_queue_event(
event_name="invocation_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
},
)
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a batch is enqueued"""
self.dispatch(BatchEnqueuedEvent.build(enqueue_result, extra))
def emit_queue_cleared(self, queue_id: str, extra: Optional[ExtraData] = None) -> None:
"""Emitted when a queue is cleared"""
self.dispatch(QueueClearedEvent.build(queue_id, extra))
# endregion
# region Download
def emit_download_started(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a download is started"""
self.dispatch(DownloadStartedEvent.build(job, extra))
def emit_download_progress(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted at intervals during a download"""
self.dispatch(DownloadProgressEvent.build(job, extra))
def emit_download_complete(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a download is completed"""
self.dispatch(DownloadCompleteEvent.build(job, extra))
def emit_download_cancelled(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a download is cancelled"""
self.dispatch(DownloadCancelledEvent.build(job, extra))
def emit_download_error(self, job: "DownloadJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when a download encounters an error"""
self.dispatch(DownloadErrorEvent.build(job, extra))
# endregion
# region Model loading
def emit_graph_execution_complete(
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
) -> None:
"""Emitted when a session has completed all invocations"""
self.__emit_queue_event(
event_name="graph_execution_state_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_model_load_started(
self,
config: "AnyModelConfig",
submodel_type: Optional["SubModelType"] = None,
extra: Optional[ExtraData] = None,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model load is started."""
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type, extra))
"""Emitted when a model is requested"""
self.__emit_queue_event(
event_name="model_load_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_model_load_complete(
def emit_model_load_completed(
self,
config: "AnyModelConfig",
submodel_type: Optional["SubModelType"] = None,
extra: Optional[ExtraData] = None,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model load is complete."""
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type, extra))
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
event_name="model_load_completed",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
# endregion
def emit_session_canceled(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
) -> None:
"""Emitted when a session is canceled"""
self.__emit_queue_event(
event_name="session_canceled",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
# region Model install
def emit_queue_item_status_changed(
self,
session_queue_item: SessionQueueItem,
batch_status: BatchStatus,
queue_status: SessionQueueStatus,
) -> None:
"""Emitted when a queue item's status changes"""
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload={
"queue_id": queue_status.queue_id,
"queue_item": {
"queue_id": session_queue_item.queue_id,
"item_id": session_queue_item.item_id,
"status": session_queue_item.status,
"batch_id": session_queue_item.batch_id,
"session_id": session_queue_item.session_id,
"error": session_queue_item.error,
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
},
"batch_status": batch_status.model_dump(mode="json"),
"queue_status": queue_status.model_dump(mode="json"),
},
)
def emit_model_install_download_progress(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted at intervals while the install job is in progress (remote models only)."""
self.dispatch(ModelInstallDownloadProgressEvent.build(job, extra))
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
"""Emitted when a batch is enqueued"""
self.__emit_queue_event(
event_name="batch_enqueued",
payload={
"queue_id": enqueue_result.queue_id,
"batch_id": enqueue_result.batch.batch_id,
"enqueued": enqueue_result.enqueued,
},
)
def emit_model_install_downloads_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job, extra))
def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when the queue is cleared"""
self.__emit_queue_event(
event_name="queue_cleared",
payload={"queue_id": queue_id},
)
def emit_model_install_started(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted once when an install job is started (after any download)."""
self.dispatch(ModelInstallStartedEvent.build(job, extra))
def emit_download_started(self, source: str, download_path: str) -> None:
"""
Emit when a download job is started.
def emit_model_install_complete(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when an install job is completed successfully."""
self.dispatch(ModelInstallCompleteEvent.build(job, extra))
:param url: The downloaded url
"""
self.__emit_download_event(
event_name="download_started",
payload={"source": source, "download_path": download_path},
)
def emit_model_install_cancelled(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when an install job is cancelled."""
self.dispatch(ModelInstallCancelledEvent.build(job, extra))
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
"""
Emit "download_progress" events at regular intervals during a download job.
def emit_model_install_error(self, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> None:
"""Emitted when an install job encounters an exception."""
self.dispatch(ModelInstallErrorEvent.build(job, extra))
:param source: The downloaded source
:param download_path: The local downloaded file
:param current_bytes: Number of bytes downloaded so far
:param total_bytes: The size of the file being downloaded (if known)
"""
self.__emit_download_event(
event_name="download_progress",
payload={
"source": source,
"download_path": download_path,
"current_bytes": current_bytes,
"total_bytes": total_bytes,
},
)
# endregion
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
"""
Emit a "download_complete" event at the end of a successful download.
# region Bulk image download
:param source: Source URL
:param download_path: Path to the locally downloaded file
:param total_bytes: The size of the downloaded file
"""
self.__emit_download_event(
event_name="download_complete",
payload={
"source": source,
"download_path": download_path,
"total_bytes": total_bytes,
},
)
def emit_download_cancelled(self, source: str) -> None:
"""Emit a "download_cancelled" event in the event that the download was cancelled by user."""
self.__emit_download_event(
event_name="download_cancelled",
payload={
"source": source,
},
)
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
"""
Emit a "download_error" event when an download job encounters an exception.
:param source: Source URL
:param error_type: The name of the exception that raised the error
:param error: The traceback from this error
"""
self.__emit_download_event(
event_name="download_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)
def emit_model_install_downloading(
self,
source: str,
local_path: str,
bytes: int,
total_bytes: int,
parts: List[Dict[str, Union[str, int]]],
id: int,
) -> None:
"""
Emit at intervals while the install job is in progress (remote models only).
:param source: Source of the model
:param local_path: Where model is downloading to
:param parts: Progress of downloading URLs that comprise the model, if any.
:param bytes: Number of bytes downloaded so far.
:param total_bytes: Total size of download, including all files.
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
"""
self.__emit_model_event(
event_name="model_install_downloading",
payload={
"source": source,
"local_path": local_path,
"bytes": bytes,
"total_bytes": total_bytes,
"parts": parts,
"id": id,
},
)
def emit_model_install_downloads_done(self, source: str) -> None:
"""
Emit once when all parts are downloaded, but before the probing and registration start.
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_downloads_done",
payload={"source": source},
)
def emit_model_install_running(self, source: str) -> None:
"""
Emit once when an install job becomes active.
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_running",
payload={"source": source},
)
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
"""
Emit when an install job is completed successfully.
:param source: Source of the model; local path, repo_id or url
:param key: Model config record key
:param total_bytes: Size of the model (may be None for installation of a local path)
"""
self.__emit_model_event(
event_name="model_install_completed",
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
)
def emit_model_install_cancelled(self, source: str, id: int) -> None:
"""
Emit when an install job is cancelled.
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_cancelled",
payload={"source": source, "id": id},
)
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
"""
Emit when an install job encounters an exception.
:param source: Source of the model
:param error_type: The name of the exception
:param error: A text description of the exception
"""
self.__emit_model_event(
event_name="model_install_error",
payload={"source": source, "error_type": error_type, "error": error, "id": id},
)
def emit_bulk_download_started(
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk image download is started"""
self.dispatch(
BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra)
"""Emitted when a bulk download starts"""
self._emit_bulk_download_event(
event_name="bulk_download_started",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_complete(
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
def emit_bulk_download_completed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk image download is complete"""
self.dispatch(
BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, extra)
"""Emitted when a bulk download completes"""
self._emit_bulk_download_event(
event_name="bulk_download_completed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_error(
self,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
error: str,
extra: Optional[ExtraData] = None,
def emit_bulk_download_failed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> None:
"""Emitted when a bulk image download has an error"""
self.dispatch(
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error, extra)
"""Emitted when a bulk download fails"""
self._emit_bulk_download_event(
event_name="bulk_download_failed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
"error": error,
},
)
# endregion

View File

@@ -1,707 +0,0 @@
from math import floor
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
ExtraData: TypeAlias = dict[str, Any]
class EventBase(BaseModel):
"""Base class for all events. All events must inherit from this class.
Events must define a class attribute `__event_name__` to identify the event.
All other attributes should be defined as normal for a pydantic model.
A timestamp is automatically added to the event when it is created.
"""
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
extra: Optional[ExtraData] = Field(default=None, description="Extra data to include with the event")
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
@classmethod
def get_events(cls) -> set[type["EventBase"]]:
"""Get a set of all event models."""
event_subclasses: set[type["EventBase"]] = set()
for subclass in cls.__subclasses__():
# We only want to include subclasses that are event models, not intermediary classes
if hasattr(subclass, "__event_name__"):
event_subclasses.add(subclass)
event_subclasses.update(subclass.get_events())
return event_subclasses
TEvent = TypeVar("TEvent", bound=EventBase)
FastAPIEvent: TypeAlias = tuple[str, TEvent]
"""
A tuple representing a `fastapi-events` event, with the event name and payload.
Provide a generic type to `TEvent` to specify the payload type.
"""
class FastAPIEventFunc(Protocol):
def __call__(self, event: FastAPIEvent[Any]) -> Optional[Coroutine[Any, Any, None]]: ...
def register_events(events: set[type[TEvent]], func: FastAPIEventFunc) -> None:
"""Register a function to handle a list of events.
:param events: A list of event classes to handle
:param func: The function to handle the events
"""
for event in events:
assert hasattr(event, "__event_name__")
local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
class QueueEventBase(EventBase):
"""Base class for queue events"""
queue_id: str = Field(description="The ID of the queue")
class QueueItemEventBase(QueueEventBase):
"""Base class for queue item events"""
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
class SessionEventBase(QueueItemEventBase):
"""Base class for session (aka graph execution state) events"""
session_id: str = Field(description="The ID of the session (aka graph execution state)")
class InvocationEventBase(SessionEventBase):
"""Base class for invocation events"""
queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation_id: str = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
invocation_type: str = Field(description="The type of invocation")
class InvocationStartedEvent(InvocationEventBase):
"""Event model for invocation_started"""
__event_name__ = "invocation_started"
@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, extra: Optional[ExtraData] = None
) -> "InvocationStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
extra=extra,
)
class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Event model for invocation_denoise_progress"""
__event_name__ = "invocation_denoise_progress"
progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
step: int = Field(description="The current step of the invocation")
total_steps: int = Field(description="The total number of steps in the invocation")
order: int = Field(description="The order of the invocation in the session")
percentage: float = Field(description="The percentage of completion of the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
extra: Optional[ExtraData] = None,
) -> "InvocationDenoiseProgressEvent":
step = intermediate_state.step
total_steps = intermediate_state.total_steps
order = intermediate_state.order
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
progress_image=progress_image,
step=step,
total_steps=total_steps,
order=order,
percentage=cls.calc_percentage(step, total_steps, order),
extra=extra,
)
@staticmethod
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
"""Calculate the percentage of completion of denoising."""
if total_steps == 0:
return 0.0
if scheduler_order == 2:
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
# order == 1
return (step + 1 + 1) / (total_steps + 1)
class InvocationCompleteEvent(InvocationEventBase):
"""Event model for invocation_complete"""
__event_name__ = "invocation_complete"
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
result: BaseInvocationOutput,
extra: Optional[ExtraData] = None,
) -> "InvocationCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
result=result,
extra=extra,
)
class InvocationErrorEvent(InvocationEventBase):
"""Event model for invocation_error"""
__event_name__ = "invocation_error"
error_type: str = Field(description="The type of error")
error: str = Field(description="The error message")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: BaseInvocation,
error_type: str,
error: str,
extra: Optional[ExtraData] = None,
) -> "InvocationErrorEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation_id=invocation.id,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
invocation_type=invocation.get_type(),
error_type=error_type,
error=error,
extra=extra,
)
class SessionStartedEvent(SessionEventBase):
"""Event model for session_started"""
__event_name__ = "session_started"
@classmethod
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
extra=extra,
)
class SessionCompleteEvent(SessionEventBase):
"""Event model for session_complete"""
__event_name__ = "session_complete"
@classmethod
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
extra=extra,
)
class SessionCanceledEvent(SessionEventBase):
"""Event model for session_canceled"""
__event_name__ = "session_canceled"
@classmethod
def build(cls, queue_item: SessionQueueItem, extra: Optional[ExtraData] = None) -> "SessionCanceledEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
extra=extra,
)
class QueueItemStatusChangedEvent(QueueItemEventBase):
"""Event model for queue_item_status_changed"""
__event_name__ = "queue_item_status_changed"
status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item")
error: Optional[str] = Field(default=None, description="The error message, if any")
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
batch_status: BatchStatus = Field(description="The status of the batch")
queue_status: SessionQueueStatus = Field(description="The status of the queue")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
batch_status: BatchStatus,
queue_status: SessionQueueStatus,
extra: Optional[ExtraData] = None,
) -> "QueueItemStatusChangedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
status=queue_item.status,
error=queue_item.error,
created_at=str(queue_item.created_at) if queue_item.created_at else None,
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
started_at=str(queue_item.started_at) if queue_item.started_at else None,
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
queue_status=queue_status,
extra=extra,
)
class BatchEnqueuedEvent(QueueEventBase):
"""Event model for batch_enqueued"""
__event_name__ = "batch_enqueued"
batch_id: str = Field(description="The ID of the batch")
enqueued: int = Field(description="The number of invocations enqueued")
requested: int = Field(
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
)
priority: int = Field(description="The priority of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult, extra: Optional[ExtraData] = None) -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,
extra=extra,
)
class QueueClearedEvent(QueueEventBase):
"""Event model for queue_cleared"""
__event_name__ = "queue_cleared"
@classmethod
def build(cls, queue_id: str, extra: Optional[ExtraData] = None) -> "QueueClearedEvent":
return cls(
queue_id=queue_id,
extra=extra,
)
class DownloadEventBase(EventBase):
"""Base class for events associated with a download"""
source: str = Field(description="The source of the download")
class DownloadStartedEvent(DownloadEventBase):
"""Event model for download_started"""
__event_name__ = "download_started"
download_path: str = Field(description="The local path where the download is saved")
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadStartedEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
extra=extra,
)
class DownloadProgressEvent(DownloadEventBase):
"""Event model for download_progress"""
__event_name__ = "download_progress"
download_path: str = Field(description="The local path where the download is saved")
current_bytes: int = Field(description="The number of bytes downloaded so far")
total_bytes: int = Field(description="The total number of bytes to be downloaded")
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadProgressEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
extra=extra,
)
class DownloadCompleteEvent(DownloadEventBase):
"""Event model for download_complete"""
__event_name__ = "download_complete"
download_path: str = Field(description="The local path where the download is saved")
total_bytes: int = Field(description="The total number of bytes downloaded")
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCompleteEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
total_bytes=job.total_bytes,
extra=extra,
)
class DownloadCancelledEvent(DownloadEventBase):
"""Event model for download_cancelled"""
__event_name__ = "download_cancelled"
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadCancelledEvent":
return cls(
source=str(job.source),
extra=extra,
)
class DownloadErrorEvent(DownloadEventBase):
"""Event model for download_error"""
__event_name__ = "download_error"
error_type: str = Field(description="The type of error")
error: str = Field(description="The error message")
@classmethod
def build(cls, job: "DownloadJob", extra: Optional[ExtraData] = None) -> "DownloadErrorEvent":
assert job.error_type
assert job.error
return cls(
source=str(job.source),
error_type=job.error_type,
error=job.error,
extra=extra,
)
class ModelEventBase(EventBase):
"""Base class for events associated with a model"""
class ModelLoadStartedEvent(ModelEventBase):
"""Event model for model_load_started"""
__event_name__ = "model_load_started"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None
) -> "ModelLoadStartedEvent":
return cls(
config=config,
submodel_type=submodel_type,
extra=extra,
)
class ModelLoadCompleteEvent(ModelEventBase):
"""Event model for model_load_complete"""
__event_name__ = "model_load_complete"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None, extra: Optional[ExtraData] = None
) -> "ModelLoadCompleteEvent":
return cls(
config=config,
submodel_type=submodel_type,
extra=extra,
)
class ModelInstallDownloadProgressEvent(ModelEventBase):
"""Event model for model_install_download_progress"""
__event_name__ = "model_install_download_progress"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")
parts: list[dict[str, int | str]] = Field(
description="Progress of downloading URLs that comprise the model, if any"
)
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadProgressEvent":
parts: list[dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
return cls(
id=job.id,
source=str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
extra=extra,
)
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
"""Emitted once when an install job becomes active."""
__event_name__ = "model_install_downloads_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallDownloadsCompleteEvent":
return cls(
id=job.id,
source=str(job.source),
extra=extra,
)
class ModelInstallStartedEvent(ModelEventBase):
"""Event model for model_install_started"""
__event_name__ = "model_install_started"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallStartedEvent":
return cls(
id=job.id,
source=str(job.source),
extra=extra,
)
class ModelInstallCompleteEvent(ModelEventBase):
"""Event model for model_install_complete"""
__event_name__ = "model_install_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
key: str = Field(description="Model config record key")
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCompleteEvent":
assert job.config_out is not None
return cls(
id=job.id,
source=str(job.source),
key=(job.config_out.key),
total_bytes=job.total_bytes,
extra=extra,
)
class ModelInstallCancelledEvent(ModelEventBase):
"""Event model for model_install_cancelled"""
__event_name__ = "model_install_cancelled"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallCancelledEvent":
return cls(
id=job.id,
source=str(job.source),
extra=extra,
)
class ModelInstallErrorEvent(ModelEventBase):
"""Event model for model_install_error"""
__event_name__ = "model_install_error"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
error_type: str = Field(description="The name of the exception")
error: str = Field(description="A text description of the exception")
@classmethod
def build(cls, job: "ModelInstallJob", extra: Optional[ExtraData] = None) -> "ModelInstallErrorEvent":
assert job.error_type is not None
assert job.error is not None
return cls(
id=job.id,
source=str(job.source),
error_type=job.error_type,
error=job.error,
extra=extra,
)
class BulkDownloadEventBase(EventBase):
"""Base class for events associated with a bulk image download"""
bulk_download_id: str = Field(description="The ID of the bulk image download")
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
class BulkDownloadStartedEvent(BulkDownloadEventBase):
"""Event model for bulk_download_started"""
__event_name__ = "bulk_download_started"
@classmethod
def build(
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
) -> "BulkDownloadStartedEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
extra=extra,
)
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
"""Event model for bulk_download_complete"""
__event_name__ = "bulk_download_complete"
@classmethod
def build(
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
extra: Optional[ExtraData] = None,
) -> "BulkDownloadCompleteEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
extra=extra,
)
class BulkDownloadErrorEvent(BulkDownloadEventBase):
"""Event model for bulk_download_error"""
__event_name__ = "bulk_download_error"
error: str = Field(description="The error message")
@classmethod
def build(
cls,
bulk_download_id: str,
bulk_download_item_id: str,
bulk_download_item_name: str,
error: str,
extra: Optional[ExtraData] = None,
) -> "BulkDownloadErrorEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=error,
extra=extra,
)

View File

@@ -1,46 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_common import (
EventBase,
)
from .events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self._queue = Queue[EventBase | None]()
self._stop_event = threading.Event()
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self._stop_event.set()
self._queue.put(None)
def dispatch(self, event: EventBase) -> None:
self._queue.put(event)
async def _dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self._queue.get(block=False)
if not event: # Probably stopping
continue
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@@ -1,13 +1,11 @@
"""Initialization file for model install service package."""
from .model_install_base import (
ModelInstallServiceBase,
)
from .model_install_common import (
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
UnknownInstallJobException,
URLModelSource,

View File

@@ -1,19 +1,244 @@
# Copyright 2023 Lincoln D. Stein and the InvokeAI development team
"""Baseclass definitions for the model installer."""
import re
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
class ModelInstallServiceBase(ABC):
@@ -57,7 +282,7 @@ class ModelInstallServiceBase(ABC):
@property
@abstractmethod
def event_bus(self) -> Optional["EventServiceBase"]:
def event_bus(self) -> Optional[EventServiceBase]:
"""Return the event service base object associated with the installer."""
@abstractmethod

View File

@@ -1,233 +0,0 @@
import re
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]

View File

@@ -10,7 +10,7 @@ from pathlib import Path
from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
import torch
import yaml
@@ -20,8 +20,8 @@ from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import (
@@ -45,12 +45,13 @@ from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice
from .model_install_common import (
from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
StringLikeSource,
URLModelSource,
@@ -58,9 +59,6 @@ from .model_install_common import (
TMPDIR_PREFIX = "tmpinstall_"
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
class ModelInstallService(ModelInstallServiceBase):
"""class for InvokeAI model installation."""
@@ -70,7 +68,7 @@ class ModelInstallService(ModelInstallServiceBase):
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
event_bus: Optional["EventServiceBase"] = None,
event_bus: Optional[EventServiceBase] = None,
session: Optional[Session] = None,
):
"""
@@ -106,7 +104,7 @@ class ModelInstallService(ModelInstallServiceBase):
return self._record_store
@property
def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus
# make the invoker optional here because we don't need it and it
@@ -857,17 +855,35 @@ class ModelInstallService(ModelInstallServiceBase):
job.status = InstallStatus.RUNNING
self._logger.info(f"Model install started: {job.source}")
if self._event_bus:
self._event_bus.emit_model_install_started(job)
self._event_bus.emit_model_install_running(str(job.source))
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
if self._event_bus:
self._event_bus.emit_model_install_download_progress(job)
parts: List[Dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
assert job.bytes is not None
assert job.total_bytes is not None
self._event_bus.emit_model_install_downloading(
str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
id=job.id,
)
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.DOWNLOADS_DONE
self._logger.info(f"Model download complete: {job.source}")
if self._event_bus:
self._event_bus.emit_model_install_downloads_complete(job)
self._event_bus.emit_model_install_downloads_done(str(job.source))
def _signal_job_completed(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.COMPLETED
@@ -875,19 +891,24 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"Model install complete: {job.source}")
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
if self._event_bus:
self._event_bus.emit_model_install_complete(job)
assert job.local_path is not None
assert job.config_out is not None
key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
def _signal_job_errored(self, job: ModelInstallJob) -> None:
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
if self._event_bus:
assert job.error_type is not None
assert job.error is not None
self._event_bus.emit_model_install_error(job)
error_type = job.error_type
error = job.error
assert error_type is not None
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"Model install canceled: {job.source}")
if self._event_bus:
self._event_bus.emit_model_install_cancelled(job)
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
@staticmethod
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:

View File

@@ -4,6 +4,7 @@
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
@@ -14,12 +15,18 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader."""
@abstractmethod
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch.
:param context_data: Invocation context data used for event reporting
"""
@property

View File

@@ -5,6 +5,7 @@ from typing import Optional, Type
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
LoadedModel,
@@ -50,18 +51,25 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the checkpoint convert cache used by this loader."""
return self._convert_cache
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
"""
# We don't have an invoker during testing
# TODO(psyche): Mock this method on the invoker in the tests
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
if context_data:
self._emit_load_event(
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
)
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model: LoadedModel = implementation(
@@ -71,7 +79,40 @@ class ModelLoadService(ModelLoadServiceBase):
convert_cache=self._convert_cache,
).load_model(model_config, submodel_type)
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
if context_data:
self._emit_load_event(
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
loaded=True,
)
return loaded_model
def _emit_load_event(
self,
context_data: InvocationContextData,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
submodel_type: Optional[SubModelType] = None,
) -> None:
if not self._invoker:
return
if not loaded:
self._invoker.services.events.emit_model_load_started(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)
else:
self._invoker.services.events.emit_model_load_completed(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)

View File

@@ -4,16 +4,11 @@ from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
FastAPIEvent,
QueueClearedEvent,
QueueEventBase,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
register_events,
)
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
@@ -36,6 +31,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now_event = ThreadEvent()
self._cancel_event = ThreadEvent()
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
self._thread_limit = thread_limit
self._thread_semaphore = BoundedSemaphore(thread_limit)
self._polling_interval = polling_interval
@@ -52,8 +49,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None
)
register_events({SessionCanceledEvent, QueueClearedEvent, BatchEnqueuedEvent}, self._on_queue_event)
self._thread = Thread(
name="session_processor",
target=self._process,
@@ -72,25 +67,30 @@ class DefaultSessionProcessor(SessionProcessorBase):
def _poll_now(self) -> None:
self._poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent[QueueEventBase]) -> None:
_event_name, payload = event
async def _on_queue_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
if (
isinstance(payload, SessionCanceledEvent)
event_name == "session_canceled"
and self._queue_item
and self._queue_item.item_id == payload.item_id
and self._queue_item.item_id == event[1]["data"]["queue_item_id"]
):
self._cancel_event.set()
self._poll_now()
elif (
isinstance(payload, QueueClearedEvent)
event_name == "queue_cleared"
and self._queue_item
and self._queue_item.queue_id == payload.queue_id
and self._queue_item.queue_id == event[1]["data"]["queue_id"]
):
self._cancel_event.set()
self._poll_now()
elif isinstance(payload, BatchEnqueuedEvent):
elif event_name == "batch_enqueued":
self._poll_now()
elif isinstance(payload, QueueItemStatusChangedEvent) and payload.status in ("completed", "failed", "canceled"):
elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [
"completed",
"failed",
"canceled",
]:
self._poll_now()
def resume(self) -> SessionProcessorStatus:
@@ -139,7 +139,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.wait(self._polling_interval)
continue
self._invoker.services.events.emit_session_started(self._queue_item)
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear()
@@ -154,7 +153,16 @@ class DefaultSessionProcessor(SessionProcessorBase):
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
self._invoker.services.events.emit_invocation_started(self._queue_item, self._invocation)
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
@@ -181,12 +189,19 @@ class DefaultSessionProcessor(SessionProcessorBase):
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
self._queue_item, self._invocation, outputs
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt:
# TODO(MM2): I don't think this is ever raised...
# TODO(MM2): Create an event for this
pass
except CanceledException:
@@ -212,20 +227,30 @@ class DefaultSessionProcessor(SessionProcessorBase):
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_item=self._queue_item,
invocation=self._invocation,
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
user_id=None,
project_id=None,
)
pass
# The session is complete if the all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
self._invoker.services.session_queue.set_queue_item_session(
self._queue_item.item_id, self._queue_item.session
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
self._invoker.services.events.emit_session_complete(self._queue_item)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
@@ -256,9 +281,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
)
# Cancel the queue item
if self._queue_item is not None:
self._invoker.services.session_queue.set_queue_item_session(
self._queue_item.item_id, self._queue_item.session
)
self._invoker.services.session_queue.cancel_queue_item(
self._queue_item.item_id, error=traceback.format_exc()
)

View File

@@ -16,7 +16,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemDTO,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults
@@ -104,8 +103,3 @@ class SessionQueueBase(ABC):
def get_queue_item(self, item_id: int) -> SessionQueueItem:
"""Gets a session queue item by ID"""
pass
@abstractmethod
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
"""Sets the session for a session queue item. Use this to update the session state."""
pass

View File

@@ -2,13 +2,10 @@ import sqlite3
import threading
from typing import Optional, Union, cast
from invokeai.app.services.events.events_common import (
FastAPIEvent,
InvocationErrorEvent,
SessionCanceledEvent,
SessionCompleteEvent,
register_events,
)
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.session_queue.session_queue_common import (
@@ -30,7 +27,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
calc_session_count,
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@@ -45,11 +41,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__invoker = invoker
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)
register_events(events={InvocationErrorEvent}, func=self._handle_error_event)
register_events(events={SessionCompleteEvent}, func=self._handle_complete_event)
register_events(events={SessionCanceledEvent}, func=self._handle_cancel_event)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
@@ -59,35 +51,51 @@ class SqliteSessionQueue(SessionQueueBase):
self.__conn = db.conn
self.__cursor = self.__conn.cursor()
async def _handle_complete_event(self, event: FastAPIEvent[SessionCompleteEvent]) -> None:
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
return event[1]["event"] in match_in
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
event_name = event[1]["event"]
# This was a match statement, but match is not supported on python 3.9
if event_name == "graph_execution_state_complete":
await self._handle_complete_event(event)
elif event_name == "invocation_error":
await self._handle_error_event(event)
elif event_name == "session_canceled":
await self._handle_cancel_event(event)
return event
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
# When a queue item has an error, we get an error event, then a completed event.
# Mark the queue item completed only if it isn't already marked completed, e.g.
# by a previously-handled error event.
_event_name, payload = event
queue_item = self.get_queue_item(payload.item_id)
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
self._set_queue_item_status(item_id=payload.item_id, status="completed")
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
except SessionQueueItemNotFoundError:
pass
return
async def _handle_error_event(self, event: FastAPIEvent[InvocationErrorEvent]) -> None:
async def _handle_error_event(self, event: FastAPIEvent) -> None:
try:
_event_name, payload = event
item_id = event[1]["data"]["queue_item_id"]
error = event[1]["data"]["error"]
queue_item = self.get_queue_item(item_id)
# always set to failed if have an error, even if previously the item was marked completed or canceled
self._set_queue_item_status(item_id=payload.item_id, status="failed", error=payload.error)
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
except SessionQueueItemNotFoundError:
pass
return
async def _handle_cancel_event(self, event: FastAPIEvent[SessionCanceledEvent]) -> None:
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
try:
_event_name, payload = event
queue_item = self.get_queue_item(payload.item_id)
item_id = event[1]["data"]["queue_item_id"]
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
self._set_queue_item_status(item_id=payload.item_id, status="canceled")
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
except SessionQueueItemNotFoundError:
pass
return
def _set_in_progress_to_canceled(self) -> None:
"""
@@ -284,7 +292,11 @@ class SqliteSessionQueue(SessionQueueBase):
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status)
self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
@@ -417,7 +429,12 @@ class SqliteSessionQueue(SessionQueueBase):
if queue_item.status not in ["canceled", "failed", "completed"]:
status = "failed" if error is not None else "canceled"
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here
self.__invoker.services.events.emit_session_canceled(queue_item)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
queue_batch_id=queue_item.batch_id,
graph_execution_state_id=queue_item.session_id,
)
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
@@ -453,11 +470,18 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self.__invoker.services.events.emit_session_canceled(current_queue_item)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
session_queue_item=current_queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
except Exception:
self.__conn.rollback()
@@ -497,11 +521,18 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self.__invoker.services.events.emit_session_canceled(current_queue_item)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
session_queue_item=current_queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
except Exception:
self.__conn.rollback()
@@ -531,29 +562,6 @@ class SqliteSessionQueue(SessionQueueBase):
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
try:
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
session_json = session.model_dump_json(warnings=False, exclude_none=True)
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET session = ?
WHERE item_id = ?
""",
(session_json, item_id),
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return self.get_queue_item(item_id)
def list_queue_items(
self,
queue_id: str,

View File

@@ -353,11 +353,11 @@ class ModelsInterface(InvocationContextInterface):
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type)
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
else:
_submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type)
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@@ -382,7 +382,7 @@ class ModelsInterface(InvocationContextInterface):
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
return self._services.model_manager.load.load_model(configs[0], submodel_type)
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Gets a model's config.

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable
import torch
from PIL import Image
@@ -13,36 +13,8 @@ if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContextData
# fast latents preview matrix for sdxl
# generated by @StAlKeR7779
SDXL_LATENT_RGB_FACTORS = [
# R G B
[0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
]
SDXL_SMOOTH_MATRIX = [
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
]
# origingally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
SD1_5_LATENT_RGB_FACTORS = [
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
]
def sample_to_lowres_estimated_image(
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
):
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
if smooth_matrix is not None:
@@ -75,12 +47,64 @@ def stable_diffusion_step_callback(
else:
sample = intermediate_state.latents
# TODO: This does not seem to be needed any more?
# # txt2img provides a Tensor in the step_callback
# # img2img provides a PipelineIntermediateState
# if isinstance(sample, PipelineIntermediateState):
# # this was an img2img
# print('img2img')
# latents = sample.latents
# step = sample.step
# else:
# print('txt2img')
# latents = sample
# step = intermediate_state.step
# TODO: only output a preview image when requested
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
# fast latents preview matrix for sdxl
# generated by @StAlKeR7779
sdxl_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
],
dtype=sample.dtype,
device=sample.device,
)
sdxl_smooth_matrix = torch.tensor(
[
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
],
dtype=sample.dtype,
device=sample.device,
)
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
else:
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
# origingally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
v1_5_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
],
dtype=sample.dtype,
device=sample.device,
)
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
(width, height) = image.size
@@ -89,9 +113,15 @@ def stable_diffusion_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
events.emit_invocation_denoise_progress(
context_data.queue_item,
context_data.invocation,
intermediate_state,
ProgressImage(dataURL=dataURL, width=width, height=height),
events.emit_generator_progress(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
node_id=context_data.invocation.id,
source_node_id=context_data.source_invocation_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
order=intermediate_state.order,
total_steps=intermediate_state.total_steps,
)

View File

@@ -35,23 +35,28 @@ import { addImageUploadedFulfilledListener } from 'app/store/middleware/listener
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addGraphExecutionStateCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationRetrievalError';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
import { addSessionRetrievalErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSessionRetrievalError';
import { addSocketSubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed';
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@@ -105,8 +110,12 @@ addInvocationErrorEventListener(startAppListening);
addInvocationStartedEventListener(startAppListening);
addSocketConnectedEventListener(startAppListening);
addSocketDisconnectedEventListener(startAppListening);
addSocketSubscribedEventListener(startAppListening);
addSocketUnsubscribedEventListener(startAppListening);
addModelLoadEventListener(startAppListening);
addModelInstallEventListener(startAppListening);
addSessionRetrievalErrorEventListener(startAppListening);
addInvocationRetrievalErrorEventListener(startAppListening);
addSocketQueueItemStatusChangedEventListener(startAppListening);
addBulkDownloadListeners(startAppListening);

View File

@@ -6,8 +6,8 @@ import { toast } from 'common/util/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import {
socketBulkDownloadComplete,
socketBulkDownloadError,
socketBulkDownloadCompleted,
socketBulkDownloadFailed,
socketBulkDownloadStarted,
} from 'services/events/actions';
@@ -56,7 +56,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
});
startAppListening({
actionCreator: socketBulkDownloadComplete,
actionCreator: socketBulkDownloadCompleted,
effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation completed');
@@ -89,7 +89,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
});
startAppListening({
actionCreator: socketBulkDownloadError,
actionCreator: socketBulkDownloadFailed,
effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation failed');

View File

@@ -133,8 +133,8 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) &&
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.invocation_source_id === processorNode.id
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
action.payload.data.source_node_id === processorNode.id
);
// We still have to check the output type

View File

@@ -69,8 +69,8 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) &&
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.invocation_source_id === nodeId
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
action.payload.data.source_node_id === nodeId
);
// We still have to check the output type

View File

@@ -12,8 +12,8 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
actionCreator: socketGeneratorProgress,
effect: (action) => {
log.trace(action.payload, `Generator progress`);
const { invocation_source_id, step, total_steps, progress_image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps;

View File

@@ -29,12 +29,12 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
actionCreator: socketInvocationComplete,
effect: async (action, { dispatch, getState }) => {
const { data } = action.payload;
log.debug({ data: parseify(data) }, `Invocation complete (${data.invocation_type})`);
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
const { result, invocation_source_id } = data;
const { result, node, queue_batch_id, source_node_id } = data;
// This complete event has an associated image output
if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation_type)) {
const { image_name } = data.result.image;
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
const { image_name } = result.image;
const { canvas, gallery } = getState();
// This populates the `getImageDTO` cache
@@ -48,7 +48,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
imageDTORequest.unsubscribe();
// Add canvas images to the staging area
if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) {
if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) {
dispatch(addImageToStagingArea(imageDTO));
}
@@ -114,7 +114,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
}
}
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {

View File

@@ -11,9 +11,9 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
startAppListening({
actionCreator: socketInvocationError,
effect: (action) => {
log.error(action.payload, `Invocation error (${action.payload.data.invocation_type})`);
const { invocation_source_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
const { source_node_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.FAILED;
nes.error = action.payload.data.error;

View File

@@ -0,0 +1,14 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketInvocationRetrievalError } from 'services/events/actions';
const log = logger('socketio');
export const addInvocationRetrievalErrorEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketInvocationRetrievalError,
effect: (action) => {
log.error(action.payload, `Invocation retrieval error (${action.payload.data.graph_execution_state_id})`);
},
});
};

View File

@@ -11,9 +11,9 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
startAppListening({
actionCreator: socketInvocationStarted,
effect: (action) => {
log.debug(action.payload, `Invocation started (${action.payload.data.invocation_type})`);
const { invocation_source_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`);
const { source_node_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
upsertExecutionState(nes.nodeId, nes);

View File

@@ -3,14 +3,14 @@ import { api, LIST_TAG } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import {
socketModelInstallCancelled,
socketModelInstallComplete,
socketModelInstallDownloadProgress,
socketModelInstallCompleted,
socketModelInstallDownloading,
socketModelInstallError,
} from 'services/events/actions';
export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketModelInstallDownloadProgress,
actionCreator: socketModelInstallDownloading,
effect: async (action, { dispatch }) => {
const { bytes, total_bytes, id } = action.payload.data;
@@ -29,7 +29,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
});
startAppListening({
actionCreator: socketModelInstallComplete,
actionCreator: socketModelInstallCompleted,
effect: (action, { dispatch }) => {
const { id } = action.payload.data;

View File

@@ -1,6 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketModelLoadComplete, socketModelLoadStarted } from 'services/events/actions';
import { socketModelLoadCompleted, socketModelLoadStarted } from 'services/events/actions';
const log = logger('socketio');
@@ -8,11 +8,10 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
startAppListening({
actionCreator: socketModelLoadStarted,
effect: (action) => {
const { config, submodel_type } = action.payload.data;
const { name, base, type } = config;
const { model_config, submodel_type } = action.payload.data;
const { name, base, type } = model_config;
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
}
@@ -24,10 +23,10 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
});
startAppListening({
actionCreator: socketModelLoadComplete,
actionCreator: socketModelLoadCompleted,
effect: (action) => {
const { config, submodel_type } = action.payload.data;
const { name, base, type } = config;
const { model_config, submodel_type } = action.payload.data;
const { name, base, type } = model_config;
const extras: string[] = [base, type];
if (submodel_type) {

View File

@@ -14,23 +14,16 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
actionCreator: socketQueueItemStatusChanged,
effect: async (action, { dispatch }) => {
// we've got new status for the queue item, batch and queue
const { item_id, status, started_at, updated_at, error, completed_at, batch_status, queue_status } =
action.payload.data;
const { queue_item, batch_status, queue_status } = action.payload.data;
log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`);
log.debug(action.payload, `Queue item ${queue_item.item_id} status updated: ${queue_item.status}`);
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
dispatch(
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
queueItemsAdapter.updateOne(draft, {
id: String(item_id),
changes: {
status,
started_at,
updated_at: updated_at ?? undefined,
error,
completed_at: completed_at ?? undefined,
},
id: String(queue_item.item_id),
changes: queue_item,
});
})
);
@@ -50,18 +43,23 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
queueApi.util.updateQueryData('getBatchStatus', { batch_id: batch_status.batch_id }, () => batch_status)
);
// Update the queue item status (this is the full queue item, including the session)
dispatch(
queueApi.util.updateQueryData('getQueueItem', queue_item.item_id, (draft) => {
if (!draft) {
return;
}
Object.assign(draft, queue_item);
})
);
// Invalidate caches for things we cannot update
// TODO: technically, we could possibly update the current session queue item, but feels safer to just request it again
dispatch(
queueApi.util.invalidateTags([
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'InvocationCacheStatus',
{ type: 'SessionQueueItem', id: item_id },
])
queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus'])
);
if (['in_progress'].includes(action.payload.data.status)) {
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
forEach($nodeExecutionStates.get(), (nes) => {
if (!nes) {
return;

View File

@@ -0,0 +1,14 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketSessionRetrievalError } from 'services/events/actions';
const log = logger('socketio');
export const addSessionRetrievalErrorEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketSessionRetrievalError,
effect: (action) => {
log.error(action.payload, `Session retrieval error (${action.payload.data.graph_execution_state_id})`);
},
});
};

View File

@@ -0,0 +1,14 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketSubscribedSession } from 'services/events/actions';
const log = logger('socketio');
export const addSocketSubscribedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketSubscribedSession,
effect: (action) => {
log.debug(action.payload, 'Subscribed');
},
});
};

View File

@@ -0,0 +1,13 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketUnsubscribedSession } from 'services/events/actions';
const log = logger('socketio');
export const addSocketUnsubscribedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketUnsubscribedSession,
effect: (action) => {
log.debug(action.payload, 'Unsubscribed');
},
});
};

View File

@@ -613,7 +613,7 @@ export const canvasSlice = createSlice({
state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id);
}
const queueItemStatus = action.payload.data.status;
const queueItemStatus = action.payload.data.queue_item.status;
if (queueItemStatus === 'canceled' || queueItemStatus === 'failed') {
resetStagingAreaIfEmpty(state);
}

View File

@@ -616,12 +616,24 @@ export const controlLayersSlice = createSlice({
iiLayerAdded: {
reducer: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => {
const { layerId, imageDTO } = action.payload;
// Retain opacity and denoising strength of existing initial image layer if exists
let opacity = 1;
let denoisingStrength = 0.75;
const iiLayer = state.layers.find((l) => l.id === layerId);
if (iiLayer) {
assert(isInitialImageLayer(iiLayer));
opacity = iiLayer.opacity;
denoisingStrength = iiLayer.denoisingStrength;
}
// Highlander! There can be only one!
state.layers = state.layers.filter((l) => (isInitialImageLayer(l) ? false : true));
const layer: InitialImageLayer = {
id: layerId,
type: 'initial_image_layer',
opacity: 1,
opacity,
x: 0,
y: 0,
bbox: null,
@@ -629,7 +641,7 @@ export const controlLayersSlice = createSlice({
isEnabled: true,
image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null,
isSelected: true,
denoisingStrength: 0.75,
denoisingStrength,
};
state.layers.push(layer);
exclusivelySelectLayer(state, layer.id);

View File

@@ -1,7 +1,8 @@
import type { UseToastOptions } from '@invoke-ai/ui-library';
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { calculateStepPercentage } from 'features/system/util/calculateStepPercentage';
import { makeToast } from 'features/system/util/makeToast';
import { t } from 'i18next';
import { startCase } from 'lodash-es';
@@ -13,10 +14,12 @@ import {
socketGraphExecutionStateComplete,
socketInvocationComplete,
socketInvocationError,
socketInvocationRetrievalError,
socketInvocationStarted,
socketModelLoadComplete,
socketModelLoadCompleted,
socketModelLoadStarted,
socketQueueItemStatusChanged,
socketSessionRetrievalError,
} from 'services/events/actions';
import type { Language, SystemState } from './types';
@@ -107,12 +110,20 @@ export const systemSlice = createSlice({
* Generator Progress
*/
builder.addCase(socketGeneratorProgress, (state, action) => {
const { step, total_steps, progress_image, session_id, batch_id, percentage } = action.payload.data;
const {
step,
total_steps,
order,
progress_image,
graph_execution_state_id: session_id,
queue_batch_id: batch_id,
} = action.payload.data;
state.denoiseProgress = {
step,
total_steps,
percentage,
order,
percentage: calculateStepPercentage(step, total_steps, order),
progress_image,
session_id,
batch_id,
@@ -141,12 +152,12 @@ export const systemSlice = createSlice({
state.status = 'LOADING_MODEL';
});
builder.addCase(socketModelLoadComplete, (state) => {
builder.addCase(socketModelLoadCompleted, (state) => {
state.status = 'CONNECTED';
});
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
if (['completed', 'canceled', 'failed'].includes(action.payload.data.status)) {
if (['completed', 'canceled', 'failed'].includes(action.payload.data.queue_item.status)) {
state.status = 'CONNECTED';
state.denoiseProgress = null;
}
@@ -157,7 +168,7 @@ export const systemSlice = createSlice({
/**
* Any server error
*/
builder.addCase(socketInvocationError, (state, action) => {
builder.addMatcher(isAnyServerError, (state, action) => {
state.toastQueue.push(
makeToast({
title: t('toast.serverError'),
@@ -183,6 +194,8 @@ export const {
setShouldEnableInformationalPopovers,
} = systemSlice.actions;
const isAnyServerError = isAnyOf(socketInvocationError, socketSessionRetrievalError, socketInvocationRetrievalError);
export const selectSystemSlice = (state: RootState) => state.system;
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */

View File

@@ -11,6 +11,7 @@ type DenoiseProgress = {
progress_image: ProgressImage | null | undefined;
step: number;
total_steps: number;
order: number;
percentage: number;
};

View File

@@ -0,0 +1,13 @@
export const calculateStepPercentage = (step: number, total_steps: number, order: number) => {
if (total_steps === 0) {
return 0;
}
// we add one extra to step so that the progress bar will be full when denoise completes
if (order === 2) {
return Math.floor((step + 1 + 1) / 2) / Math.floor((total_steps + 1) / 2);
}
return (step + 1 + 1) / (total_steps + 1);
};

File diff suppressed because one or more lines are too long

View File

@@ -39,6 +39,7 @@ export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDT
// Models
export type ModelType = S['ModelType'];
export type SubModelType = S['SubModelType'];
export type BaseModelType = S['BaseModelType'];
// Model Configs

View File

@@ -1,29 +1,22 @@
import { createAction } from '@reduxjs/toolkit';
import type {
BulkDownloadCompleteEvent,
BulkDownloadCompletedEvent,
BulkDownloadFailedEvent,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
GeneratorProgressEvent,
GraphExecutionStateCompleteEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationRetrievalErrorEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallCompletedEvent,
ModelInstallDownloadingEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadCompletedEvent,
ModelLoadStartedEvent,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
SessionCompleteEvent,
SessionStartedEvent,
SessionRetrievalErrorEvent,
} from 'services/events/types';
// Create actions for each socket
@@ -33,6 +26,12 @@ export const socketConnected = createAction('socket/socketConnected');
export const socketDisconnected = createAction('socket/socketDisconnected');
export const socketSubscribedSession = createAction<{
sessionId: string;
}>('socket/socketSubscribedSession');
export const socketUnsubscribedSession = createAction<{ sessionId: string }>('socket/socketUnsubscribedSession');
export const socketInvocationStarted = createAction<{
data: InvocationStartedEvent;
}>('socket/socketInvocationStarted');
@@ -45,65 +44,29 @@ export const socketInvocationError = createAction<{
data: InvocationErrorEvent;
}>('socket/socketInvocationError');
export const socketSessionStarted = createAction<{
data: SessionStartedEvent;
}>('socket/socketSessionStarted');
export const socketGraphExecutionStateComplete = createAction<{
data: SessionCompleteEvent;
data: GraphExecutionStateCompleteEvent;
}>('socket/socketGraphExecutionStateComplete');
export const socketSessionCanceled = createAction<{
data: SessionCanceledEvent;
}>('socket/socketSessionCanceled');
export const socketGeneratorProgress = createAction<{
data: InvocationDenoiseProgressEvent;
data: GeneratorProgressEvent;
}>('socket/socketGeneratorProgress');
export const socketModelLoadStarted = createAction<{
data: ModelLoadStartedEvent;
}>('socket/socketModelLoadStarted');
export const socketModelLoadComplete = createAction<{
data: ModelLoadCompleteEvent;
}>('socket/socketModelLoadComplete');
export const socketModelLoadCompleted = createAction<{
data: ModelLoadCompletedEvent;
}>('socket/socketModelLoadCompleted');
export const socketDownloadStarted = createAction<{
data: DownloadStartedEvent;
}>('socket/socketDownloadStarted');
export const socketModelInstallDownloading = createAction<{
data: ModelInstallDownloadingEvent;
}>('socket/socketModelInstallDownloading');
export const socketDownloadProgress = createAction<{
data: DownloadProgressEvent;
}>('socket/socketDownloadProgress');
export const socketDownloadComplete = createAction<{
data: DownloadCompleteEvent;
}>('socket/socketDownloadComplete');
export const socketDownloadCancelled = createAction<{
data: DownloadCancelledEvent;
}>('socket/socketDownloadCancelled');
export const socketDownloadError = createAction<{
data: DownloadErrorEvent;
}>('socket/socketDownloadError');
export const socketModelInstallStarted = createAction<{
data: ModelInstallStartedEvent;
}>('socket/socketModelInstallStarted');
export const socketModelInstallDownloadProgress = createAction<{
data: ModelInstallDownloadProgressEvent;
}>('socket/socketModelInstallDownloadProgress');
export const socketModelInstallDownloadsComplete = createAction<{
data: ModelInstallDownloadsCompleteEvent;
}>('socket/socketModelInstallDownloadsComplete');
export const socketModelInstallComplete = createAction<{
data: ModelInstallCompleteEvent;
}>('socket/socketModelInstallComplete');
export const socketModelInstallCompleted = createAction<{
data: ModelInstallCompletedEvent;
}>('socket/socketModelInstallCompleted');
export const socketModelInstallError = createAction<{
data: ModelInstallErrorEvent;
@@ -113,6 +76,14 @@ export const socketModelInstallCancelled = createAction<{
data: ModelInstallCancelledEvent;
}>('socket/socketModelInstallCancelled');
export const socketSessionRetrievalError = createAction<{
data: SessionRetrievalErrorEvent;
}>('socket/socketSessionRetrievalError');
export const socketInvocationRetrievalError = createAction<{
data: InvocationRetrievalErrorEvent;
}>('socket/socketInvocationRetrievalError');
export const socketQueueItemStatusChanged = createAction<{
data: QueueItemStatusChangedEvent;
}>('socket/socketQueueItemStatusChanged');
@@ -121,10 +92,10 @@ export const socketBulkDownloadStarted = createAction<{
data: BulkDownloadStartedEvent;
}>('socket/socketBulkDownloadStarted');
export const socketBulkDownloadComplete = createAction<{
data: BulkDownloadCompleteEvent;
}>('socket/socketBulkDownloadComplete');
export const socketBulkDownloadCompleted = createAction<{
data: BulkDownloadCompletedEvent;
}>('socket/socketBulkDownloadCompleted');
export const socketBulkDownloadError = createAction<{
export const socketBulkDownloadFailed = createAction<{
data: BulkDownloadFailedEvent;
}>('socket/socketBulkDownloadError');
}>('socket/socketBulkDownloadFailed');

View File

@@ -1,75 +1,275 @@
import type { Graph, GraphExecutionState, S } from 'services/api/types';
import type { components } from 'services/api/schema';
import type { AnyModelConfig, Graph, GraphExecutionState, SubModelType } from 'services/api/types';
/**
* A progress image, we get one for each step in the generation
*/
export type ProgressImage = {
dataURL: string;
width: number;
height: number;
};
export type AnyInvocation = NonNullable<NonNullable<Graph['nodes']>[string]>;
export type AnyResult = NonNullable<GraphExecutionState['results'][string]>;
export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
type BaseNode = {
id: string;
type: string;
[key: string]: AnyInvocation[keyof AnyInvocation];
};
export type InvocationStartedEvent = S['InvocationStartedEvent'];
export type InvocationDenoiseProgressEvent = S['InvocationDenoiseProgressEvent'];
export type InvocationCompleteEvent = Omit<S['InvocationCompleteEvent'], 'result'> & { result: AnyResult };
export type InvocationErrorEvent = S['InvocationErrorEvent'];
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
export type ModelLoadStartedEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
model_config: AnyModelConfig;
submodel_type?: SubModelType | null;
};
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
export type ModelInstallDownloadsCompleteEvent = S['ModelInstallDownloadsCompleteEvent'];
export type ModelInstallCompleteEvent = S['ModelInstallCompleteEvent'];
export type ModelInstallErrorEvent = S['ModelInstallErrorEvent'];
export type ModelInstallStartedEvent = S['ModelInstallStartedEvent'];
export type ModelInstallCancelledEvent = S['ModelInstallCancelledEvent'];
export type ModelLoadCompletedEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
model_config: AnyModelConfig;
submodel_type?: SubModelType | null;
};
export type DownloadStartedEvent = S['DownloadStartedEvent'];
export type DownloadProgressEvent = S['DownloadProgressEvent'];
export type DownloadCompleteEvent = S['DownloadCompleteEvent'];
export type DownloadCancelledEvent = S['DownloadCancelledEvent'];
export type DownloadErrorEvent = S['DownloadErrorEvent'];
export type ModelInstallDownloadingEvent = {
bytes: number;
local_path: string;
source: string;
timestamp: number;
total_bytes: number;
id: number;
};
export type SessionStartedEvent = S['SessionStartedEvent'];
export type SessionCompleteEvent = S['SessionCompleteEvent'];
export type SessionCanceledEvent = S['SessionCanceledEvent'];
export type ModelInstallCompletedEvent = {
key: number;
source: string;
timestamp: number;
id: number;
};
export type QueueItemStatusChangedEvent = S['QueueItemStatusChangedEvent'];
export type ModelInstallErrorEvent = {
error: string;
error_type: string;
source: string;
timestamp: number;
id: number;
};
export type BulkDownloadStartedEvent = S['BulkDownloadStartedEvent'];
export type BulkDownloadCompleteEvent = S['BulkDownloadCompleteEvent'];
export type BulkDownloadFailedEvent = S['BulkDownloadErrorEvent'];
export type ModelInstallCancelledEvent = {
source: string;
timestamp: number;
id: number;
};
/**
* A `generator_progress` socket.io event.
*
* @example socket.on('generator_progress', (data: GeneratorProgressEvent) => { ... }
*/
export type GeneratorProgressEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
node_id: string;
source_node_id: string;
progress_image?: ProgressImage;
step: number;
order: number;
total_steps: number;
};
/**
* A `invocation_complete` socket.io event.
*
* `result` is a discriminated union with a `type` property as the discriminant.
*
* @example socket.on('invocation_complete', (data: InvocationCompleteEvent) => { ... }
*/
export type InvocationCompleteEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
node: BaseNode;
source_node_id: string;
result: AnyResult;
};
/**
* A `invocation_error` socket.io event.
*
* @example socket.on('invocation_error', (data: InvocationErrorEvent) => { ... }
*/
export type InvocationErrorEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
node: BaseNode;
source_node_id: string;
error_type: string;
error: string;
};
/**
* A `invocation_started` socket.io event.
*
* @example socket.on('invocation_started', (data: InvocationStartedEvent) => { ... }
*/
export type InvocationStartedEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
node: BaseNode;
source_node_id: string;
};
/**
* A `graph_execution_state_complete` socket.io event.
*
* @example socket.on('graph_execution_state_complete', (data: GraphExecutionStateCompleteEvent) => { ... }
*/
export type GraphExecutionStateCompleteEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
};
/**
* A `session_retrieval_error` socket.io event.
*
* @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... }
*/
export type SessionRetrievalErrorEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
error_type: string;
error: string;
};
/**
* A `invocation_retrieval_error` socket.io event.
*
* @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... }
*/
export type InvocationRetrievalErrorEvent = {
queue_id: string;
queue_item_id: number;
queue_batch_id: string;
graph_execution_state_id: string;
node_id: string;
error_type: string;
error: string;
};
/**
* A `queue_item_status_changed` socket.io event.
*
* @example socket.on('queue_item_status_changed', (data: QueueItemStatusChangedEvent) => { ... }
*/
export type QueueItemStatusChangedEvent = {
queue_id: string;
queue_item: {
queue_id: string;
item_id: number;
batch_id: string;
session_id: string;
status: components['schemas']['SessionQueueItemDTO']['status'];
error: string | undefined;
created_at: string;
updated_at: string;
started_at: string | undefined;
completed_at: string | undefined;
};
batch_status: {
queue_id: string;
batch_id: string;
pending: number;
in_progress: number;
completed: number;
failed: number;
canceled: number;
total: number;
};
queue_status: {
queue_id: string;
item_id?: number;
batch_id?: string;
session_id?: string;
pending: number;
in_progress: number;
completed: number;
failed: number;
canceled: number;
total: number;
};
};
type ClientEmitSubscribeQueue = {
queue_id: string;
};
type ClientEmitUnsubscribeQueue = ClientEmitSubscribeQueue;
type ClientEmitUnsubscribeQueue = {
queue_id: string;
};
export type BulkDownloadStartedEvent = {
bulk_download_id: string;
bulk_download_item_id: string;
bulk_download_item_name: string;
};
export type BulkDownloadCompletedEvent = {
bulk_download_id: string;
bulk_download_item_id: string;
bulk_download_item_name: string;
};
export type BulkDownloadFailedEvent = {
bulk_download_id: string;
bulk_download_item_id: string;
bulk_download_item_name: string;
error: string;
};
type ClientEmitSubscribeBulkDownload = {
bulk_download_id: string;
};
type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload;
type ClientEmitUnsubscribeBulkDownload = {
bulk_download_id: string;
};
export type ServerToClientEvents = {
invocation_denoise_progress: (payload: InvocationDenoiseProgressEvent) => void;
generator_progress: (payload: GeneratorProgressEvent) => void;
invocation_complete: (payload: InvocationCompleteEvent) => void;
invocation_error: (payload: InvocationErrorEvent) => void;
invocation_started: (payload: InvocationStartedEvent) => void;
session_started: (payload: SessionStartedEvent) => void;
session_complete: (payload: SessionCompleteEvent) => void;
session_canceled: (payload: SessionCanceledEvent) => void;
download_started: (payload: DownloadStartedEvent) => void;
download_progress: (payload: DownloadProgressEvent) => void;
download_complete: (payload: DownloadCompleteEvent) => void;
download_cancelled: (payload: DownloadCancelledEvent) => void;
download_error: (payload: DownloadErrorEvent) => void;
graph_execution_state_complete: (payload: GraphExecutionStateCompleteEvent) => void;
model_load_started: (payload: ModelLoadStartedEvent) => void;
model_install_started: (payload: ModelInstallStartedEvent) => void;
model_install_download_progress: (payload: ModelInstallDownloadProgressEvent) => void;
model_install_downloads_complete: (payload: ModelInstallDownloadsCompleteEvent) => void;
model_install_complete: (payload: ModelInstallCompleteEvent) => void;
model_load_completed: (payload: ModelLoadCompletedEvent) => void;
model_install_downloading: (payload: ModelInstallDownloadingEvent) => void;
model_install_completed: (payload: ModelInstallCompletedEvent) => void;
model_install_error: (payload: ModelInstallErrorEvent) => void;
model_install_cancelled: (payload: ModelInstallCancelledEvent) => void;
model_load_complete: (payload: ModelLoadCompleteEvent) => void;
model_install_canceled: (payload: ModelInstallCancelledEvent) => void;
session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void;
invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void;
queue_item_status_changed: (payload: QueueItemStatusChangedEvent) => void;
bulk_download_started: (payload: BulkDownloadStartedEvent) => void;
bulk_download_complete: (payload: BulkDownloadCompleteEvent) => void;
bulk_download_error: (payload: BulkDownloadFailedEvent) => void;
bulk_download_completed: (payload: BulkDownloadCompletedEvent) => void;
bulk_download_failed: (payload: BulkDownloadFailedEvent) => void;
};
export type ClientToServerEvents = {

View File

@@ -5,32 +5,24 @@ import type { AppDispatch } from 'app/store/store';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import {
socketBulkDownloadComplete,
socketBulkDownloadError,
socketBulkDownloadCompleted,
socketBulkDownloadFailed,
socketBulkDownloadStarted,
socketConnected,
socketDisconnected,
socketDownloadCancelled,
socketDownloadComplete,
socketDownloadError,
socketDownloadProgress,
socketDownloadStarted,
socketGeneratorProgress,
socketGraphExecutionStateComplete,
socketInvocationComplete,
socketInvocationError,
socketInvocationRetrievalError,
socketInvocationStarted,
socketModelInstallCancelled,
socketModelInstallComplete,
socketModelInstallDownloadProgress,
socketModelInstallDownloadsComplete,
socketModelInstallCompleted,
socketModelInstallDownloading,
socketModelInstallError,
socketModelInstallStarted,
socketModelLoadComplete,
socketModelLoadCompleted,
socketModelLoadStarted,
socketQueueItemStatusChanged,
socketSessionCanceled,
socketSessionStarted,
socketSessionRetrievalError,
} from 'services/events/actions';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import type { Socket } from 'socket.io-client';
@@ -73,87 +65,131 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
}
});
/**
* Disconnect
*/
socket.on('disconnect', () => {
dispatch(socketDisconnected());
});
/**
* Invocation started
*/
socket.on('invocation_started', (data) => {
dispatch(socketInvocationStarted({ data }));
});
socket.on('invocation_denoise_progress', (data) => {
/**
* Generator progress
*/
socket.on('generator_progress', (data) => {
dispatch(socketGeneratorProgress({ data }));
});
/**
* Invocation error
*/
socket.on('invocation_error', (data) => {
dispatch(socketInvocationError({ data }));
});
/**
* Invocation complete
*/
socket.on('invocation_complete', (data) => {
dispatch(socketInvocationComplete({ data }));
dispatch(
socketInvocationComplete({
data,
})
);
});
socket.on('session_started', (data) => {
dispatch(socketSessionStarted({ data }));
});
socket.on('session_complete', (data) => {
dispatch(socketGraphExecutionStateComplete({ data }));
});
socket.on('session_canceled', (data) => {
dispatch(socketSessionCanceled({ data }));
/**
* Graph complete
*/
socket.on('graph_execution_state_complete', (data) => {
dispatch(
socketGraphExecutionStateComplete({
data,
})
);
});
/**
* Model load started
*/
socket.on('model_load_started', (data) => {
dispatch(socketModelLoadStarted({ data }));
dispatch(
socketModelLoadStarted({
data,
})
);
});
socket.on('model_load_complete', (data) => {
dispatch(socketModelLoadComplete({ data }));
/**
* Model load completed
*/
socket.on('model_load_completed', (data) => {
dispatch(
socketModelLoadCompleted({
data,
})
);
});
socket.on('download_started', (data) => {
dispatch(socketDownloadStarted({ data }));
/**
* Model Install Downloading
*/
socket.on('model_install_downloading', (data) => {
dispatch(
socketModelInstallDownloading({
data,
})
);
});
socket.on('download_progress', (data) => {
dispatch(socketDownloadProgress({ data }));
});
socket.on('download_complete', (data) => {
dispatch(socketDownloadComplete({ data }));
});
socket.on('download_cancelled', (data) => {
dispatch(socketDownloadCancelled({ data }));
});
socket.on('download_error', (data) => {
dispatch(socketDownloadError({ data }));
});
socket.on('model_install_started', (data) => {
dispatch(socketModelInstallStarted({ data }));
});
socket.on('model_install_download_progress', (data) => {
dispatch(socketModelInstallDownloadProgress({ data }));
});
socket.on('model_install_downloads_complete', (data) => {
dispatch(socketModelInstallDownloadsComplete({ data }));
});
socket.on('model_install_complete', (data) => {
dispatch(socketModelInstallComplete({ data }));
/**
* Model Install Completed
*/
socket.on('model_install_completed', (data) => {
dispatch(
socketModelInstallCompleted({
data,
})
);
});
/**
* Model Install Error
*/
socket.on('model_install_error', (data) => {
dispatch(socketModelInstallError({ data }));
dispatch(
socketModelInstallError({
data,
})
);
});
socket.on('model_install_cancelled', (data) => {
dispatch(socketModelInstallCancelled({ data }));
/**
* Session retrieval error
*/
socket.on('session_retrieval_error', (data) => {
dispatch(
socketSessionRetrievalError({
data,
})
);
});
/**
* Invocation retrieval error
*/
socket.on('invocation_retrieval_error', (data) => {
dispatch(
socketInvocationRetrievalError({
data,
})
);
});
socket.on('queue_item_status_changed', (data) => {
@@ -164,11 +200,11 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
dispatch(socketBulkDownloadStarted({ data }));
});
socket.on('bulk_download_complete', (data) => {
dispatch(socketBulkDownloadComplete({ data }));
socket.on('bulk_download_completed', (data) => {
dispatch(socketBulkDownloadCompleted({ data }));
});
socket.on('bulk_download_error', (data) => {
dispatch(socketBulkDownloadError({ data }));
socket.on('bulk_download_failed', (data) => {
dispatch(socketBulkDownloadFailed({ data }));
});
};

View File

@@ -1 +1 @@
__version__ = "4.2.1"
__version__ = "4.2.2"

View File

@@ -9,11 +9,6 @@ import pytest
from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.events.events_common import (
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadStartedEvent,
)
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecordNotFoundException,
@@ -286,9 +281,9 @@ def assert_handler_success(
# Check that the correct events were emitted
assert len(event_bus.events) == 2
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
assert isinstance(event_bus.events[1], BulkDownloadCompleteEvent)
assert event_bus.events[1].bulk_download_item_name == os.path.basename(expected_zip_path)
assert event_bus.events[0].event_name == "bulk_download_started"
assert event_bus.events[1].event_name == "bulk_download_completed"
assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path)
def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
@@ -334,9 +329,9 @@ def test_handler_on_generic_exception(
event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
assert isinstance(event_bus.events[1], BulkDownloadErrorEvent)
assert event_bus.events[1].error == exception.__str__()
assert event_bus.events[0].event_name == "bulk_download_started"
assert event_bus.events[1].event_name == "bulk_download_failed"
assert event_bus.events[1].payload["error"] == exception.__str__()
def execute_handler_test_on_error(
@@ -349,9 +344,9 @@ def execute_handler_test_on_error(
event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
assert isinstance(event_bus.events[1], BulkDownloadErrorEvent)
assert event_bus.events[1].error == error.__str__()
assert event_bus.events[0].event_name == "bulk_download_started"
assert event_bus.events[1].event_name == "bulk_download_failed"
assert event_bus.events[1].payload["error"] == error.__str__()
def test_delete(tmp_path: Path):

View File

@@ -10,13 +10,6 @@ from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from invokeai.app.services.events.events_common import (
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
)
from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings
@@ -123,14 +116,14 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
queue.join()
events = event_bus.events
assert len(events) == 3
assert isinstance(events[0], DownloadStartedEvent)
assert isinstance(events[1], DownloadProgressEvent)
assert isinstance(events[2], DownloadCompleteEvent)
assert events[0].timestamp <= events[1].timestamp
assert events[1].timestamp <= events[2].timestamp
assert events[1].total_bytes > 0
assert events[1].current_bytes <= events[1].total_bytes
assert events[2].total_bytes == 32029
assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
assert events[0].event_name == "download_started"
assert events[1].event_name == "download_progress"
assert events[1].payload["total_bytes"] > 0
assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
assert events[2].event_name == "download_complete"
assert events[2].payload["total_bytes"] == 32029
# test a failure
event_bus.events = [] # reset our accumulator
@@ -139,10 +132,10 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
events = event_bus.events
print("\n".join([x.model_dump_json() for x in events]))
assert len(events) == 1
assert isinstance(events[0], DownloadErrorEvent)
assert events[0].error_type == "HTTPError(NOT FOUND)"
assert events[0].error is not None
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].error)
assert events[0].event_name == "download_error"
assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
assert events[0].payload["error"] is not None
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
queue.stop()
@@ -209,6 +202,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
assert job.status == DownloadJobStatus.CANCELLED
assert cancelled
events = event_bus.events
assert isinstance(events[-1], DownloadCancelledEvent)
assert events[-1].source == "http://www.civitai.com/models/12345"
assert events[-1].event_name == "download_cancelled"
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
queue.stop()

View File

@@ -13,24 +13,16 @@ from pydantic_core import Url
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import (
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
)
from invokeai.app.services.model_install import (
ModelInstallServiceBase,
)
from invokeai.app.services.model_install.model_install_common import (
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
URLModelSource,
)
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
from tests.test_nodes import TestEventService
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
OS = platform.uname().system
@@ -138,16 +130,17 @@ def test_background_install(
assert job.total_bytes == size
# test that the expected events were issued
bus: TestEventService = mm2_installer.event_bus
bus = mm2_installer.event_bus
assert bus
assert hasattr(bus, "events")
assert len(bus.events) == 2
assert isinstance(bus.events[0], ModelInstallStartedEvent)
assert isinstance(bus.events[1], ModelInstallCompleteEvent)
assert Path(bus.events[0].source) == source
assert Path(bus.events[1].source) == source
key = bus.events[1].key
event_names = [x.event_name for x in bus.events]
assert "model_install_running" in event_names
assert "model_install_completed" in event_names
assert Path(bus.events[0].payload["source"]) == source
assert Path(bus.events[1].payload["source"]) == source
key = bus.events[1].payload["key"]
assert key is not None
# see if the thing actually got installed at the expected location
@@ -226,7 +219,7 @@ def test_delete_register(
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
bus: TestEventService = mm2_installer.event_bus
bus = mm2_installer.event_bus
store = mm2_installer.record_store
assert store is not None
assert bus is not None
@@ -244,17 +237,20 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
assert (mm2_app_config.models_path / model_record.path).exists()
assert len(bus.events) == 4
assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
assert isinstance(bus.events[2], ModelInstallStartedEvent)
assert isinstance(bus.events[3], ModelInstallCompleteEvent)
event_names = [x.event_name for x in bus.events]
assert event_names == [
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
]
@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
bus: TestEventService = mm2_installer.event_bus
bus = mm2_installer.event_bus
store = mm2_installer.record_store
assert isinstance(bus, EventServiceBase)
assert store is not None
@@ -271,10 +267,15 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
assert model_record.type == ModelType.Main
assert model_record.format == ModelFormat.Diffusers
assert any(isinstance(x, ModelInstallStartedEvent) for x in bus.events)
assert any(isinstance(x, ModelInstallDownloadProgressEvent) for x in bus.events)
assert any(isinstance(x, ModelInstallCompleteEvent) for x in bus.events)
assert hasattr(bus, "events") # the dummyeventservice has this
assert len(bus.events) >= 3
event_names = {x.event_name for x in bus.events}
assert event_names == {
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
}
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:

View File

@@ -3,13 +3,16 @@
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List
import pytest
from pydantic import BaseModel
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
@@ -36,7 +39,27 @@ from tests.backend.model_manager.model_metadata.metadata_examples import (
RepoHFModelJson1,
)
from tests.fixtures.sqlite_database import create_mock_sqlite_database
from tests.test_nodes import TestEventService
class DummyEvent(BaseModel):
"""Dummy Event to use with Dummy Event service."""
event_name: str
payload: Dict[str, Any]
class DummyEventService(EventServiceBase):
"""Dummy event service for testing."""
events: List[DummyEvent]
def __init__(self) -> None:
super().__init__()
self.events = []
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
# Create a temporary directory using the contents of `./data/invokeai_root` as the template
@@ -104,7 +127,7 @@ def mm2_installer(
) -> ModelInstallServiceBase:
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger)
events = TestEventService()
events = DummyEventService()
store = ModelRecordServiceSQL(db)
installer = ModelInstallService(

View File

@@ -20,7 +20,6 @@ from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_manager_fixtures import * # noqa: F403
from tests.fixtures.sqlite_database import create_mock_sqlite_database # noqa: F401
from tests.test_nodes import TestEventService

View File

@@ -1,5 +1,7 @@
from typing import Any, Callable, Union
from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@@ -8,7 +10,6 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.fields import InputField, OutputField
from invokeai.app.invocations.image import ImageField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -116,10 +117,11 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
)
class TestEvent(EventBase):
class TestEvent(BaseModel):
__test__ = False # not a pytest test case
__event_name__ = "test_event"
event_name: str
payload: Any
class TestEventService(EventServiceBase):
@@ -127,10 +129,10 @@ class TestEventService(EventServiceBase):
def __init__(self):
super().__init__()
self.events: list[EventBase] = []
self.events: list[TestEvent] = []
def dispatch(self, event: EventBase) -> None:
self.events.append(event)
def dispatch(self, event_name: str, payload: Any) -> None:
self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"]))
pass