mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-17 05:47:57 -05:00
Compare commits
31 Commits
improve-co
...
psyche/fea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc78a0e699 | ||
|
|
08a42c3c03 | ||
|
|
0758e9cb9b | ||
|
|
fb93e686b2 | ||
|
|
350feeed56 | ||
|
|
169b75b2b7 | ||
|
|
c88de180e7 | ||
|
|
7d1844eaf2 | ||
|
|
a98ddedb95 | ||
|
|
6063487b20 | ||
|
|
9a4c167342 | ||
|
|
19227fe4e6 | ||
|
|
db0ef8d316 | ||
|
|
6a34176376 | ||
|
|
d6696a7b97 | ||
|
|
0e81e7b460 | ||
|
|
7652fbc2e9 | ||
|
|
a55b2f09e2 | ||
|
|
23b05344a3 | ||
|
|
80905ff3ea | ||
|
|
df5457231f | ||
|
|
d30c1ad6dc | ||
|
|
b1f819ae8d | ||
|
|
eff359625a | ||
|
|
cef1585dfb | ||
|
|
cb8e9e1c7b | ||
|
|
f7c356d142 | ||
|
|
efb069dd71 | ||
|
|
8edc25d35a | ||
|
|
82957bb826 | ||
|
|
e51a3025ea |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -188,4 +188,3 @@ installer/install.sh
|
||||
installer/update.bat
|
||||
installer/update.sh
|
||||
installer/InvokeAI-Installer/
|
||||
.aider*
|
||||
|
||||
@@ -64,7 +64,7 @@ GPU_DRIVER=nvidia
|
||||
|
||||
Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
|
||||
|
||||
## Even More Customizing!
|
||||
## Even Moar Customizing!
|
||||
|
||||
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ When you generate an image using text-to-image, multiple steps occur in latent s
|
||||
4. The VAE decodes the final latent image from latent space into image space.
|
||||
|
||||
Image-to-image is a similar process, with only step 1 being different:
|
||||
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how many noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process.
|
||||
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how may noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process.
|
||||
|
||||
Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor).
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from ..services.boards.boards_default import BoardService
|
||||
from ..services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from ..services.config import InvokeAIAppConfig
|
||||
from ..services.download import DownloadQueueService
|
||||
from ..services.events.events_fastapievents import FastAPIEventService
|
||||
from ..services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
||||
from ..services.images.images_default import ImageService
|
||||
@@ -34,6 +33,7 @@ from ..services.session_processor.session_processor_default import DefaultSessio
|
||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from ..services.urls.urls_default import LocalUrlService
|
||||
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
# TODO: is there a better way to achieve this?
|
||||
@@ -103,6 +103,7 @@ class ApiDependencies:
|
||||
)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
|
||||
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
|
||||
52
invokeai/app/api/events.py
Normal file
52
invokeai/app/api/events.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from ..services.events.events_base import EventServiceBase
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
event_handler_id: int
|
||||
__queue: Queue
|
||||
__stop_event: threading.Event
|
||||
|
||||
def __init__(self, event_handler_id: int) -> None:
|
||||
self.event_handler_id = event_handler_id
|
||||
self.__queue = Queue()
|
||||
self.__stop_event = threading.Event()
|
||||
asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
|
||||
|
||||
super().__init__()
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self.__stop_event.set()
|
||||
self.__queue.put(None)
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
self.__queue.put({"event_name": event_name, "payload": payload})
|
||||
|
||||
async def __dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = self.__queue.get(block=False)
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
|
||||
dispatch(
|
||||
event.get("event_name"),
|
||||
payload=event.get("payload"),
|
||||
middleware_id=self.event_handler_id,
|
||||
)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
||||
@@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
|
||||
@@ -1,125 +1,66 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from socketio import ASGIApp, AsyncServer
|
||||
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BatchEnqueuedEvent,
|
||||
BulkDownloadCompleteEvent,
|
||||
BulkDownloadErrorEvent,
|
||||
BulkDownloadEventBase,
|
||||
BulkDownloadStartedEvent,
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadEventBase,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
FastAPIEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
ModelEventBase,
|
||||
ModelInstallCancelledEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallErrorEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelLoadStartedEvent,
|
||||
QueueClearedEvent,
|
||||
QueueEventBase,
|
||||
QueueItemStatusChangedEvent,
|
||||
register_events,
|
||||
)
|
||||
|
||||
|
||||
class QueueSubscriptionEvent(BaseModel):
|
||||
"""Event data for subscribing to the socket.io queue room.
|
||||
This is a pydantic model to ensure the data is in the correct format."""
|
||||
|
||||
queue_id: str
|
||||
|
||||
|
||||
class BulkDownloadSubscriptionEvent(BaseModel):
|
||||
"""Event data for subscribing to the socket.io bulk downloads room.
|
||||
This is a pydantic model to ensure the data is in the correct format."""
|
||||
|
||||
bulk_download_id: str
|
||||
|
||||
|
||||
QUEUE_EVENTS = {
|
||||
InvocationStartedEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationErrorEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
BatchEnqueuedEvent,
|
||||
QueueClearedEvent,
|
||||
}
|
||||
|
||||
MODEL_EVENTS = {
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
ModelLoadStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallCancelledEvent,
|
||||
ModelInstallErrorEvent,
|
||||
}
|
||||
|
||||
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
|
||||
from ..services.events.events_base import EventServiceBase
|
||||
|
||||
|
||||
class SocketIO:
|
||||
_sub_queue = "subscribe_queue"
|
||||
_unsub_queue = "unsubscribe_queue"
|
||||
__sio: AsyncServer
|
||||
__app: ASGIApp
|
||||
|
||||
_sub_bulk_download = "subscribe_bulk_download"
|
||||
_unsub_bulk_download = "unsubscribe_bulk_download"
|
||||
__sub_queue: str = "subscribe_queue"
|
||||
__unsub_queue: str = "unsubscribe_queue"
|
||||
|
||||
__sub_bulk_download: str = "subscribe_bulk_download"
|
||||
__unsub_bulk_download: str = "unsubscribe_bulk_download"
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
|
||||
app.mount("/ws", self._app)
|
||||
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
|
||||
app.mount("/ws", self.__app)
|
||||
|
||||
self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
|
||||
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
|
||||
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
|
||||
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
|
||||
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue)
|
||||
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
|
||||
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
|
||||
|
||||
register_events(QUEUE_EVENTS, self._handle_queue_event)
|
||||
register_events(MODEL_EVENTS, self._handle_model_event)
|
||||
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
|
||||
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
|
||||
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
|
||||
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)
|
||||
|
||||
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
|
||||
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
|
||||
async def _handle_queue_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["queue_id"],
|
||||
)
|
||||
|
||||
async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
|
||||
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
|
||||
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
|
||||
if "queue_id" in data:
|
||||
await self.__sio.enter_room(sid, data["queue_id"])
|
||||
|
||||
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
|
||||
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
||||
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
|
||||
if "queue_id" in data:
|
||||
await self.__sio.leave_room(sid, data["queue_id"])
|
||||
|
||||
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
|
||||
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
||||
async def _handle_model_event(self, event: Event) -> None:
|
||||
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
|
||||
|
||||
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
|
||||
async def _handle_bulk_download_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["bulk_download_id"],
|
||||
)
|
||||
|
||||
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
|
||||
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
|
||||
if "bulk_download_id" in data:
|
||||
await self.__sio.enter_room(sid, data["bulk_download_id"])
|
||||
|
||||
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)
|
||||
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
|
||||
if "bulk_download_id" in data:
|
||||
await self.__sio.leave_room(sid, data["bulk_download_id"])
|
||||
|
||||
@@ -27,7 +27,6 @@ import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
@@ -183,14 +182,23 @@ def custom_openapi() -> dict[str, Any]:
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
|
||||
invoker_schema["class"] = "invocation"
|
||||
|
||||
# Add all event schemas
|
||||
for event in sorted(EventBase.get_events(), key=lambda e: e.__name__):
|
||||
json_schema = event.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
|
||||
if "$defs" in json_schema:
|
||||
for schema_key, schema in json_schema["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = schema
|
||||
del json_schema["$defs"]
|
||||
openapi_schema["components"]["schemas"][event.__name__] = json_schema
|
||||
# This code no longer seems to be necessary?
|
||||
# Leave it here just in case
|
||||
#
|
||||
# from invokeai.backend.model_manager import get_model_config_formats
|
||||
# formats = get_model_config_formats()
|
||||
# for model_config_name, enum_set in formats.items():
|
||||
|
||||
# if model_config_name in openapi_schema["components"]["schemas"]:
|
||||
# # print(f"Config with name {name} already defined")
|
||||
# continue
|
||||
|
||||
# openapi_schema["components"]["schemas"][model_config_name] = {
|
||||
# "title": model_config_name,
|
||||
# "description": "An enumeration.",
|
||||
# "type": "string",
|
||||
# "enum": [v.value for v in enum_set],
|
||||
# }
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
@@ -65,7 +65,11 @@ class CompelInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
@@ -80,21 +84,19 @@ class CompelInvocation(BaseInvocation):
|
||||
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
|
||||
|
||||
with (
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info as text_encoder,
|
||||
tokenizer_info as tokenizer,
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
|
||||
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
|
||||
patched_tokenizer,
|
||||
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
text_encoder_info as text_encoder,
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
|
||||
):
|
||||
assert isinstance(text_encoder, CLIPTextModel)
|
||||
assert isinstance(tokenizer, CLIPTokenizer)
|
||||
compel = Compel(
|
||||
tokenizer=patched_tokenizer,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||
@@ -104,7 +106,7 @@ class CompelInvocation(BaseInvocation):
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
|
||||
if context.config.get().log_tokenization:
|
||||
log_tokenization_for_conjunction(conjunction, patched_tokenizer)
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
|
||||
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
|
||||
@@ -134,7 +136,11 @@ class SDXLPromptInvocationBase:
|
||||
zero_on_empty: bool,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
tokenizer_info = context.models.load(clip_field.tokenizer)
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(clip_field.text_encoder)
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
@@ -171,23 +177,20 @@ class SDXLPromptInvocationBase:
|
||||
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
|
||||
|
||||
with (
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info as text_encoder,
|
||||
tokenizer_info as tokenizer,
|
||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
|
||||
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
|
||||
patched_tokenizer,
|
||||
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
text_encoder_info as text_encoder,
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
|
||||
):
|
||||
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
assert isinstance(tokenizer, CLIPTokenizer)
|
||||
|
||||
text_encoder = cast(CLIPTextModel, text_encoder)
|
||||
compel = Compel(
|
||||
tokenizer=patched_tokenizer,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||
@@ -200,7 +203,7 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
if context.config.get().log_tokenization:
|
||||
# TODO: better logging for and syntax
|
||||
log_tokenization_for_conjunction(conjunction, patched_tokenizer)
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
|
||||
# TODO: ask for optimizations? to not run text_encoder twice
|
||||
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
|
||||
@@ -930,9 +930,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config),
|
||||
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
|
||||
unet_info as unet,
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
|
||||
):
|
||||
|
||||
@@ -106,7 +106,9 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
self._invoker.services.events.emit_bulk_download_started(
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
)
|
||||
|
||||
def _signal_job_completed(
|
||||
@@ -116,8 +118,10 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
assert bulk_download_item_name is not None
|
||||
self._invoker.services.events.emit_bulk_download_complete(
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name
|
||||
self._invoker.services.events.emit_bulk_download_completed(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
)
|
||||
|
||||
def _signal_job_failed(
|
||||
@@ -127,8 +131,11 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
assert exception is not None
|
||||
self._invoker.services.events.emit_bulk_download_error(
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
|
||||
self._invoker.services.events.emit_bulk_download_failed(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
error=str(exception),
|
||||
)
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
|
||||
@@ -8,13 +8,14 @@ import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import requests
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
@@ -29,9 +30,6 @@ from .download_base import (
|
||||
UnknownJobIDException,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
|
||||
# Maximum number of bytes to download during each call to requests.iter_content()
|
||||
DOWNLOAD_CHUNK_SIZE = 100000
|
||||
|
||||
@@ -42,7 +40,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
max_parallel_dl: int = 5,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
requests_session: Optional[requests.sessions.Session] = None,
|
||||
):
|
||||
"""
|
||||
@@ -345,7 +343,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_started(job)
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
|
||||
|
||||
def _signal_job_progress(self, job: DownloadJob) -> None:
|
||||
if job.on_progress:
|
||||
@@ -356,7 +355,13 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_progress(job)
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_progress(
|
||||
str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
current_bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
)
|
||||
|
||||
def _signal_job_complete(self, job: DownloadJob) -> None:
|
||||
job.status = DownloadJobStatus.COMPLETED
|
||||
@@ -368,7 +373,10 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_complete(job)
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_complete(
|
||||
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
|
||||
)
|
||||
|
||||
def _signal_job_cancelled(self, job: DownloadJob) -> None:
|
||||
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
|
||||
@@ -382,7 +390,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_cancelled(job)
|
||||
self._event_bus.emit_download_cancelled(str(job.source))
|
||||
|
||||
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
job.status = DownloadJobStatus.ERROR
|
||||
@@ -395,7 +403,9 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_error(job)
|
||||
assert job.error_type
|
||||
assert job.error
|
||||
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
|
||||
|
||||
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
|
||||
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")
|
||||
|
||||
@@ -1,195 +1,494 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BatchEnqueuedEvent,
|
||||
BulkDownloadCompleteEvent,
|
||||
BulkDownloadErrorEvent,
|
||||
BulkDownloadStartedEvent,
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
EventBase,
|
||||
InvocationCompleteEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
ModelInstallCancelledEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallErrorEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelLoadStartedEvent,
|
||||
QueueClearedEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.download.download_base import DownloadJob
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
queue_event: str = "queue_event"
|
||||
bulk_download_event: str = "bulk_download_event"
|
||||
download_event: str = "download_event"
|
||||
model_event: str = "model_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
|
||||
def dispatch(self, event: "EventBase") -> None:
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
pass
|
||||
|
||||
# region: Invocation
|
||||
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Bulk download events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.bulk_download_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
|
||||
"""Emitted when an invocation is started"""
|
||||
self.dispatch(InvocationStartedEvent.build(queue_item, invocation))
|
||||
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Queue events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.queue_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def emit_invocation_denoise_progress(
|
||||
def __emit_download_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.download_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def __emit_model_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.model_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
# Define events here for every event in the system.
|
||||
# This will make them easier to integrate until we find a schema generator.
|
||||
def emit_generator_progress(
|
||||
self,
|
||||
queue_item: "SessionQueueItem",
|
||||
invocation: "BaseInvocation",
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
progress_image: "ProgressImage",
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node_id: str,
|
||||
source_node_id: str,
|
||||
progress_image: Optional[ProgressImage],
|
||||
step: int,
|
||||
order: int,
|
||||
total_steps: int,
|
||||
) -> None:
|
||||
"""Emitted at each step during denoising of an invocation."""
|
||||
self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))
|
||||
"""Emitted when there is generation progress"""
|
||||
self.__emit_queue_event(
|
||||
event_name="generator_progress",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node_id": node_id,
|
||||
"source_node_id": source_node_id,
|
||||
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
|
||||
"step": step,
|
||||
"order": order,
|
||||
"total_steps": total_steps,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_invocation_complete(
|
||||
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
result: dict,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation is complete"""
|
||||
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output))
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_complete",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
"result": result,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_invocation_error(
|
||||
self,
|
||||
queue_item: "SessionQueueItem",
|
||||
invocation: "BaseInvocation",
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
error_type: str,
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
user_id: str | None,
|
||||
project_id: str | None,
|
||||
) -> None:
|
||||
"""Emitted when an invocation encounters an error"""
|
||||
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback))
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_error",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
"error_type": error_type,
|
||||
"error_message": error_message,
|
||||
"error_traceback": error_traceback,
|
||||
"user_id": user_id,
|
||||
"project_id": project_id,
|
||||
},
|
||||
)
|
||||
|
||||
# endregion
|
||||
def emit_invocation_started(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_started",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
},
|
||||
)
|
||||
|
||||
# region Queue
|
||||
def emit_graph_execution_complete(
|
||||
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
|
||||
) -> None:
|
||||
"""Emitted when a session has completed all invocations"""
|
||||
self.__emit_queue_event(
|
||||
event_name="graph_execution_state_complete",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_load_started(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Emitted when a model is requested"""
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_started",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_config": model_config.model_dump(mode="json"),
|
||||
"submodel_type": submodel_type,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_load_completed(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_completed",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_config": model_config.model_dump(mode="json"),
|
||||
"submodel_type": submodel_type,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_session_canceled(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
) -> None:
|
||||
"""Emitted when a session is canceled"""
|
||||
self.__emit_queue_event(
|
||||
event_name="session_canceled",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_queue_item_status_changed(
|
||||
self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus"
|
||||
self,
|
||||
session_queue_item: SessionQueueItem,
|
||||
batch_status: BatchStatus,
|
||||
queue_status: SessionQueueStatus,
|
||||
) -> None:
|
||||
"""Emitted when a queue item's status changes"""
|
||||
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_item_status_changed",
|
||||
payload={
|
||||
"queue_id": queue_status.queue_id,
|
||||
"queue_item": {
|
||||
"queue_id": session_queue_item.queue_id,
|
||||
"item_id": session_queue_item.item_id,
|
||||
"status": session_queue_item.status,
|
||||
"batch_id": session_queue_item.batch_id,
|
||||
"session_id": session_queue_item.session_id,
|
||||
"error_type": session_queue_item.error_type,
|
||||
"error_message": session_queue_item.error_message,
|
||||
"error_traceback": session_queue_item.error_traceback,
|
||||
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
||||
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
||||
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||
},
|
||||
"batch_status": batch_status.model_dump(mode="json"),
|
||||
"queue_status": queue_status.model_dump(mode="json"),
|
||||
},
|
||||
)
|
||||
|
||||
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
|
||||
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
|
||||
"""Emitted when a batch is enqueued"""
|
||||
self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
|
||||
self.__emit_queue_event(
|
||||
event_name="batch_enqueued",
|
||||
payload={
|
||||
"queue_id": enqueue_result.queue_id,
|
||||
"batch_id": enqueue_result.batch.batch_id,
|
||||
"enqueued": enqueue_result.enqueued,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_queue_cleared(self, queue_id: str) -> None:
|
||||
"""Emitted when a queue is cleared"""
|
||||
self.dispatch(QueueClearedEvent.build(queue_id))
|
||||
"""Emitted when the queue is cleared"""
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_cleared",
|
||||
payload={"queue_id": queue_id},
|
||||
)
|
||||
|
||||
# endregion
|
||||
def emit_download_started(self, source: str, download_path: str) -> None:
|
||||
"""
|
||||
Emit when a download job is started.
|
||||
|
||||
# region Download
|
||||
:param url: The downloaded url
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_started",
|
||||
payload={"source": source, "download_path": download_path},
|
||||
)
|
||||
|
||||
def emit_download_started(self, job: "DownloadJob") -> None:
|
||||
"""Emitted when a download is started"""
|
||||
self.dispatch(DownloadStartedEvent.build(job))
|
||||
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
|
||||
"""
|
||||
Emit "download_progress" events at regular intervals during a download job.
|
||||
|
||||
def emit_download_progress(self, job: "DownloadJob") -> None:
|
||||
"""Emitted at intervals during a download"""
|
||||
self.dispatch(DownloadProgressEvent.build(job))
|
||||
:param source: The downloaded source
|
||||
:param download_path: The local downloaded file
|
||||
:param current_bytes: Number of bytes downloaded so far
|
||||
:param total_bytes: The size of the file being downloaded (if known)
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_progress",
|
||||
payload={
|
||||
"source": source,
|
||||
"download_path": download_path,
|
||||
"current_bytes": current_bytes,
|
||||
"total_bytes": total_bytes,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_download_complete(self, job: "DownloadJob") -> None:
|
||||
"""Emitted when a download is completed"""
|
||||
self.dispatch(DownloadCompleteEvent.build(job))
|
||||
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
|
||||
"""
|
||||
Emit a "download_complete" event at the end of a successful download.
|
||||
|
||||
def emit_download_cancelled(self, job: "DownloadJob") -> None:
|
||||
"""Emitted when a download is cancelled"""
|
||||
self.dispatch(DownloadCancelledEvent.build(job))
|
||||
:param source: Source URL
|
||||
:param download_path: Path to the locally downloaded file
|
||||
:param total_bytes: The size of the downloaded file
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_complete",
|
||||
payload={
|
||||
"source": source,
|
||||
"download_path": download_path,
|
||||
"total_bytes": total_bytes,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_download_error(self, job: "DownloadJob") -> None:
|
||||
"""Emitted when a download encounters an error"""
|
||||
self.dispatch(DownloadErrorEvent.build(job))
|
||||
def emit_download_cancelled(self, source: str) -> None:
|
||||
"""Emit a "download_cancelled" event in the event that the download was cancelled by user."""
|
||||
self.__emit_download_event(
|
||||
event_name="download_cancelled",
|
||||
payload={
|
||||
"source": source,
|
||||
},
|
||||
)
|
||||
|
||||
# endregion
|
||||
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
|
||||
"""
|
||||
Emit a "download_error" event when an download job encounters an exception.
|
||||
|
||||
# region Model loading
|
||||
:param source: Source URL
|
||||
:param error_type: The name of the exception that raised the error
|
||||
:param error: The traceback from this error
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_error",
|
||||
payload={
|
||||
"source": source,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
|
||||
"""Emitted when a model load is started."""
|
||||
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
|
||||
|
||||
def emit_model_load_complete(
|
||||
self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
|
||||
def emit_model_install_downloading(
|
||||
self,
|
||||
source: str,
|
||||
local_path: str,
|
||||
bytes: int,
|
||||
total_bytes: int,
|
||||
parts: List[Dict[str, Union[str, int]]],
|
||||
id: int,
|
||||
) -> None:
|
||||
"""Emitted when a model load is complete."""
|
||||
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
|
||||
"""
|
||||
Emit at intervals while the install job is in progress (remote models only).
|
||||
|
||||
# endregion
|
||||
:param source: Source of the model
|
||||
:param local_path: Where model is downloading to
|
||||
:param parts: Progress of downloading URLs that comprise the model, if any.
|
||||
:param bytes: Number of bytes downloaded so far.
|
||||
:param total_bytes: Total size of download, including all files.
|
||||
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_downloading",
|
||||
payload={
|
||||
"source": source,
|
||||
"local_path": local_path,
|
||||
"bytes": bytes,
|
||||
"total_bytes": total_bytes,
|
||||
"parts": parts,
|
||||
"id": id,
|
||||
},
|
||||
)
|
||||
|
||||
# region Model install
|
||||
def emit_model_install_downloads_done(self, source: str) -> None:
|
||||
"""
|
||||
Emit once when all parts are downloaded, but before the probing and registration start.
|
||||
|
||||
def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
|
||||
"""Emitted at intervals while the install job is in progress (remote models only)."""
|
||||
self.dispatch(ModelInstallDownloadProgressEvent.build(job))
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_downloads_done",
|
||||
payload={"source": source},
|
||||
)
|
||||
|
||||
def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
|
||||
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job))
|
||||
def emit_model_install_running(self, source: str) -> None:
|
||||
"""
|
||||
Emit once when an install job becomes active.
|
||||
|
||||
def emit_model_install_started(self, job: "ModelInstallJob") -> None:
|
||||
"""Emitted once when an install job is started (after any download)."""
|
||||
self.dispatch(ModelInstallStartedEvent.build(job))
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_running",
|
||||
payload={"source": source},
|
||||
)
|
||||
|
||||
def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
|
||||
"""Emitted when an install job is completed successfully."""
|
||||
self.dispatch(ModelInstallCompleteEvent.build(job))
|
||||
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
|
||||
"""
|
||||
Emit when an install job is completed successfully.
|
||||
|
||||
def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
|
||||
"""Emitted when an install job is cancelled."""
|
||||
self.dispatch(ModelInstallCancelledEvent.build(job))
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
:param key: Model config record key
|
||||
:param total_bytes: Size of the model (may be None for installation of a local path)
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_completed",
|
||||
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
|
||||
)
|
||||
|
||||
def emit_model_install_error(self, job: "ModelInstallJob") -> None:
|
||||
"""Emitted when an install job encounters an exception."""
|
||||
self.dispatch(ModelInstallErrorEvent.build(job))
|
||||
def emit_model_install_cancelled(self, source: str, id: int) -> None:
|
||||
"""
|
||||
Emit when an install job is cancelled.
|
||||
|
||||
# endregion
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_cancelled",
|
||||
payload={"source": source, "id": id},
|
||||
)
|
||||
|
||||
# region Bulk image download
|
||||
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
|
||||
"""
|
||||
Emit when an install job encounters an exception.
|
||||
|
||||
:param source: Source of the model
|
||||
:param error_type: The name of the exception
|
||||
:param error: A text description of the exception
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_error",
|
||||
payload={"source": source, "error_type": error_type, "error": error, "id": id},
|
||||
)
|
||||
|
||||
def emit_bulk_download_started(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk image download is started"""
|
||||
self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
|
||||
|
||||
def emit_bulk_download_complete(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk image download is complete"""
|
||||
self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
|
||||
|
||||
def emit_bulk_download_error(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk image download has an error"""
|
||||
self.dispatch(
|
||||
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
|
||||
"""Emitted when a bulk download starts"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_started",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
},
|
||||
)
|
||||
|
||||
# endregion
|
||||
def emit_bulk_download_completed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk download completes"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_completed",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_bulk_download_failed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk download fails"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_failed",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,608 +0,0 @@
|
||||
from math import floor
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.registry.payload_schema import registry as payload_schema
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, field_validator
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.download.download_base import DownloadJob
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
|
||||
|
||||
class EventBase(BaseModel):
|
||||
"""Base class for all events. All events must inherit from this class.
|
||||
|
||||
Events must define a class attribute `__event_name__` to identify the event.
|
||||
|
||||
All other attributes should be defined as normal for a pydantic model.
|
||||
|
||||
A timestamp is automatically added to the event when it is created.
|
||||
"""
|
||||
|
||||
__event_name__: ClassVar[str]
|
||||
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
|
||||
|
||||
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
|
||||
|
||||
@classmethod
|
||||
def get_events(cls) -> set[type["EventBase"]]:
|
||||
"""Get a set of all event models."""
|
||||
|
||||
event_subclasses: set[type["EventBase"]] = set()
|
||||
for subclass in cls.__subclasses__():
|
||||
# We only want to include subclasses that are event models, not intermediary classes
|
||||
if hasattr(subclass, "__event_name__"):
|
||||
event_subclasses.add(subclass)
|
||||
event_subclasses.update(subclass.get_events())
|
||||
|
||||
return event_subclasses
|
||||
|
||||
|
||||
TEvent = TypeVar("TEvent", bound=EventBase, contravariant=True)
|
||||
|
||||
FastAPIEvent: TypeAlias = tuple[str, TEvent]
|
||||
"""
|
||||
A tuple representing a `fastapi-events` event, with the event name and payload.
|
||||
Provide a generic type to `TEvent` to specify the payload type.
|
||||
"""
|
||||
|
||||
|
||||
class FastAPIEventFunc(Protocol, Generic[TEvent]):
|
||||
def __call__(self, event: FastAPIEvent[TEvent]) -> Optional[Coroutine[Any, Any, None]]: ...
|
||||
|
||||
|
||||
def register_events(events: set[type[TEvent]] | type[TEvent], func: FastAPIEventFunc[TEvent]) -> None:
|
||||
"""Register a function to handle specific events.
|
||||
|
||||
:param events: An event or set of events to handle
|
||||
:param func: The function to handle the events
|
||||
"""
|
||||
events = events if isinstance(events, set) else {events}
|
||||
for event in events:
|
||||
assert hasattr(event, "__event_name__")
|
||||
local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
||||
|
||||
|
||||
class QueueEventBase(EventBase):
|
||||
"""Base class for queue events"""
|
||||
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
|
||||
|
||||
class QueueItemEventBase(QueueEventBase):
|
||||
"""Base class for queue item events"""
|
||||
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
|
||||
|
||||
class InvocationEventBase(QueueItemEventBase):
|
||||
"""Base class for invocation events"""
|
||||
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
invocation: SerializeAsAny[BaseInvocation] = Field(description="The ID of the invocation")
|
||||
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
||||
|
||||
@field_validator("invocation", mode="plain")
|
||||
@classmethod
|
||||
def validate_invocation(cls, v: Any):
|
||||
"""Validates the invocation using the dynamic type adapter."""
|
||||
|
||||
invocation = BaseInvocation.get_typeadapter().validate_python(v)
|
||||
return invocation
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class InvocationStartedEvent(InvocationEventBase):
|
||||
"""Event model for invocation_started"""
|
||||
|
||||
__event_name__ = "invocation_started"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem, invocation: BaseInvocation) -> "InvocationStartedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||
"""Event model for invocation_denoise_progress"""
|
||||
|
||||
__event_name__ = "invocation_denoise_progress"
|
||||
|
||||
progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
|
||||
step: int = Field(description="The current step of the invocation")
|
||||
total_steps: int = Field(description="The total number of steps in the invocation")
|
||||
order: int = Field(description="The order of the invocation in the session")
|
||||
percentage: float = Field(description="The percentage of completion of the invocation")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: BaseInvocation,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
progress_image: ProgressImage,
|
||||
) -> "InvocationDenoiseProgressEvent":
|
||||
step = intermediate_state.step
|
||||
total_steps = intermediate_state.total_steps
|
||||
order = intermediate_state.order
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
progress_image=progress_image,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
order=order,
|
||||
percentage=cls.calc_percentage(step, total_steps, order),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
|
||||
"""Calculate the percentage of completion of denoising."""
|
||||
if total_steps == 0:
|
||||
return 0.0
|
||||
if scheduler_order == 2:
|
||||
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
|
||||
# order == 1
|
||||
return (step + 1 + 1) / (total_steps + 1)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class InvocationCompleteEvent(InvocationEventBase):
|
||||
"""Event model for invocation_complete"""
|
||||
|
||||
__event_name__ = "invocation_complete"
|
||||
|
||||
result: SerializeAsAny[BaseInvocationOutput] = Field(description="The result of the invocation")
|
||||
|
||||
@field_validator("result", mode="plain")
|
||||
@classmethod
|
||||
def validate_results(cls, v: Any):
|
||||
"""Validates the invocation result using the dynamic type adapter."""
|
||||
|
||||
result = BaseInvocationOutput.get_typeadapter().validate_python(v)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, queue_item: SessionQueueItem, invocation: BaseInvocation, result: BaseInvocationOutput
|
||||
) -> "InvocationCompleteEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
result=result,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class InvocationErrorEvent(InvocationEventBase):
|
||||
"""Event model for invocation_error"""
|
||||
|
||||
__event_name__ = "invocation_error"
|
||||
|
||||
error_type: str = Field(description="The error type")
|
||||
error_message: str = Field(description="The error message")
|
||||
error_traceback: str = Field(description="The error traceback")
|
||||
user_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
|
||||
project_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: BaseInvocation,
|
||||
error_type: str,
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
) -> "InvocationErrorEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
user_id=getattr(queue_item, "user_id", None),
|
||||
project_id=getattr(queue_item, "project_id", None),
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
"""Event model for queue_item_status_changed"""
|
||||
|
||||
__event_name__ = "queue_item_status_changed"
|
||||
|
||||
status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item")
|
||||
error_type: Optional[str] = Field(default=None, description="The error type, if any")
|
||||
error_message: Optional[str] = Field(default=None, description="The error message, if any")
|
||||
error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any")
|
||||
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
|
||||
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
|
||||
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
|
||||
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
|
||||
batch_status: BatchStatus = Field(description="The status of the batch")
|
||||
queue_status: SessionQueueStatus = Field(description="The status of the queue")
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus
|
||||
) -> "QueueItemStatusChangedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
status=queue_item.status,
|
||||
error_type=queue_item.error_type,
|
||||
error_message=queue_item.error_message,
|
||||
error_traceback=queue_item.error_traceback,
|
||||
created_at=str(queue_item.created_at) if queue_item.created_at else None,
|
||||
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
|
||||
started_at=str(queue_item.started_at) if queue_item.started_at else None,
|
||||
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class BatchEnqueuedEvent(QueueEventBase):
|
||||
"""Event model for batch_enqueued"""
|
||||
|
||||
__event_name__ = "batch_enqueued"
|
||||
|
||||
batch_id: str = Field(description="The ID of the batch")
|
||||
enqueued: int = Field(description="The number of invocations enqueued")
|
||||
requested: int = Field(
|
||||
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
|
||||
)
|
||||
priority: int = Field(description="The priority of the batch")
|
||||
|
||||
@classmethod
|
||||
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
|
||||
return cls(
|
||||
queue_id=enqueue_result.queue_id,
|
||||
batch_id=enqueue_result.batch.batch_id,
|
||||
enqueued=enqueue_result.enqueued,
|
||||
requested=enqueue_result.requested,
|
||||
priority=enqueue_result.priority,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class QueueClearedEvent(QueueEventBase):
|
||||
"""Event model for queue_cleared"""
|
||||
|
||||
__event_name__ = "queue_cleared"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_id: str) -> "QueueClearedEvent":
|
||||
return cls(queue_id=queue_id)
|
||||
|
||||
|
||||
class DownloadEventBase(EventBase):
|
||||
"""Base class for events associated with a download"""
|
||||
|
||||
source: str = Field(description="The source of the download")
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class DownloadStartedEvent(DownloadEventBase):
|
||||
"""Event model for download_started"""
|
||||
|
||||
__event_name__ = "download_started"
|
||||
|
||||
download_path: str = Field(description="The local path where the download is saved")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadStartedEvent":
|
||||
assert job.download_path
|
||||
return cls(source=str(job.source), download_path=job.download_path.as_posix())
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class DownloadProgressEvent(DownloadEventBase):
|
||||
"""Event model for download_progress"""
|
||||
|
||||
__event_name__ = "download_progress"
|
||||
|
||||
download_path: str = Field(description="The local path where the download is saved")
|
||||
current_bytes: int = Field(description="The number of bytes downloaded so far")
|
||||
total_bytes: int = Field(description="The total number of bytes to be downloaded")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadProgressEvent":
|
||||
assert job.download_path
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
current_bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class DownloadCompleteEvent(DownloadEventBase):
|
||||
"""Event model for download_complete"""
|
||||
|
||||
__event_name__ = "download_complete"
|
||||
|
||||
download_path: str = Field(description="The local path where the download is saved")
|
||||
total_bytes: int = Field(description="The total number of bytes downloaded")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent":
|
||||
assert job.download_path
|
||||
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class DownloadCancelledEvent(DownloadEventBase):
|
||||
"""Event model for download_cancelled"""
|
||||
|
||||
__event_name__ = "download_cancelled"
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent":
|
||||
return cls(source=str(job.source))
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class DownloadErrorEvent(DownloadEventBase):
|
||||
"""Event model for download_error"""
|
||||
|
||||
__event_name__ = "download_error"
|
||||
|
||||
error_type: str = Field(description="The type of error")
|
||||
error: str = Field(description="The error message")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadErrorEvent":
|
||||
assert job.error_type
|
||||
assert job.error
|
||||
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
|
||||
|
||||
|
||||
class ModelEventBase(EventBase):
|
||||
"""Base class for events associated with a model"""
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelLoadStartedEvent(ModelEventBase):
|
||||
"""Event model for model_load_started"""
|
||||
|
||||
__event_name__ = "model_load_started"
|
||||
|
||||
config: AnyModelConfig = Field(description="The model's config")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
|
||||
|
||||
@classmethod
|
||||
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
|
||||
return cls(config=config, submodel_type=submodel_type)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelLoadCompleteEvent(ModelEventBase):
|
||||
"""Event model for model_load_complete"""
|
||||
|
||||
__event_name__ = "model_load_complete"
|
||||
|
||||
config: AnyModelConfig = Field(description="The model's config")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
|
||||
|
||||
@classmethod
|
||||
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
|
||||
return cls(config=config, submodel_type=submodel_type)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallDownloadProgressEvent(ModelEventBase):
|
||||
"""Event model for model_install_download_progress"""
|
||||
|
||||
__event_name__ = "model_install_download_progress"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
local_path: str = Field(description="Where model is downloading to")
|
||||
bytes: int = Field(description="Number of bytes downloaded so far")
|
||||
total_bytes: int = Field(description="Total size of download, including all files")
|
||||
parts: list[dict[str, int | str]] = Field(
|
||||
description="Progress of downloading URLs that comprise the model, if any"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent":
|
||||
parts: list[dict[str, str | int]] = [
|
||||
{
|
||||
"url": str(x.source),
|
||||
"local_path": str(x.download_path),
|
||||
"bytes": x.bytes,
|
||||
"total_bytes": x.total_bytes,
|
||||
}
|
||||
for x in job.download_parts
|
||||
]
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
local_path=job.local_path.as_posix(),
|
||||
parts=parts,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
|
||||
"""Emitted once when an install job becomes active."""
|
||||
|
||||
__event_name__ = "model_install_downloads_complete"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
|
||||
return cls(id=job.id, source=str(job.source))
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallStartedEvent(ModelEventBase):
|
||||
"""Event model for model_install_started"""
|
||||
|
||||
__event_name__ = "model_install_started"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
|
||||
return cls(id=job.id, source=str(job.source))
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallCompleteEvent(ModelEventBase):
|
||||
"""Event model for model_install_complete"""
|
||||
|
||||
__event_name__ = "model_install_complete"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
key: str = Field(description="Model config record key")
|
||||
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
|
||||
assert job.config_out is not None
|
||||
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallCancelledEvent(ModelEventBase):
|
||||
"""Event model for model_install_cancelled"""
|
||||
|
||||
__event_name__ = "model_install_cancelled"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
|
||||
return cls(id=job.id, source=str(job.source))
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallErrorEvent(ModelEventBase):
|
||||
"""Event model for model_install_error"""
|
||||
|
||||
__event_name__ = "model_install_error"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
error_type: str = Field(description="The name of the exception")
|
||||
error: str = Field(description="A text description of the exception")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
|
||||
assert job.error_type is not None
|
||||
assert job.error is not None
|
||||
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
|
||||
|
||||
|
||||
class BulkDownloadEventBase(EventBase):
|
||||
"""Base class for events associated with a bulk image download"""
|
||||
|
||||
bulk_download_id: str = Field(description="The ID of the bulk image download")
|
||||
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
|
||||
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class BulkDownloadStartedEvent(BulkDownloadEventBase):
|
||||
"""Event model for bulk_download_started"""
|
||||
|
||||
__event_name__ = "bulk_download_started"
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> "BulkDownloadStartedEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
|
||||
"""Event model for bulk_download_complete"""
|
||||
|
||||
__event_name__ = "bulk_download_complete"
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> "BulkDownloadCompleteEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class BulkDownloadErrorEvent(BulkDownloadEventBase):
|
||||
"""Event model for bulk_download_error"""
|
||||
|
||||
__event_name__ = "bulk_download_error"
|
||||
|
||||
error: str = Field(description="The error message")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
|
||||
) -> "BulkDownloadErrorEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
error=error,
|
||||
)
|
||||
@@ -1,47 +0,0 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from invokeai.app.services.events.events_common import (
|
||||
EventBase,
|
||||
)
|
||||
|
||||
from .events_base import EventServiceBase
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
def __init__(self, event_handler_id: int) -> None:
|
||||
self.event_handler_id = event_handler_id
|
||||
self._queue = Queue[EventBase | None]()
|
||||
self._stop_event = threading.Event()
|
||||
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||
|
||||
super().__init__()
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self._stop_event.set()
|
||||
self._queue.put(None)
|
||||
|
||||
def dispatch(self, event: EventBase) -> None:
|
||||
self._queue.put(event)
|
||||
|
||||
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = self._queue.get(block=False)
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
# Leave the payloads as live pydantic models
|
||||
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
||||
@@ -1,13 +1,11 @@
|
||||
"""Initialization file for model install service package."""
|
||||
|
||||
from .model_install_base import (
|
||||
ModelInstallServiceBase,
|
||||
)
|
||||
from .model_install_common import (
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
UnknownInstallJobException,
|
||||
URLModelSource,
|
||||
|
||||
@@ -1,19 +1,244 @@
|
||||
# Copyright 2023 Lincoln D. Stein and the InvokeAI development team
|
||||
"""Baseclass definitions for the model installer."""
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
DOWNLOADING = "downloading" # downloading of model files in process
|
||||
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
CANCELLED = "cancelled" # terminated with an error message
|
||||
|
||||
|
||||
class ModelInstallPart(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
path: Path
|
||||
bytes: int = 0
|
||||
total_bytes: int = 0
|
||||
|
||||
|
||||
class UnknownInstallJobException(Exception):
|
||||
"""Raised when the status of an unknown job is requested."""
|
||||
|
||||
|
||||
class StringLikeSource(BaseModel):
|
||||
"""
|
||||
Base class for model sources, implements functions that lets the source be sorted and indexed.
|
||||
|
||||
These shenanigans let this stuff work:
|
||||
|
||||
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
|
||||
mydict = {source1: 'model 1'}
|
||||
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
|
||||
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
|
||||
|
||||
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
|
||||
assert source1 == source2
|
||||
assert source1 == 'C:/users/mort/foo.safetensors'
|
||||
"""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the path field, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __lt__(self, other: object) -> int:
|
||||
"""Return comparison of the stringified version, for sorting."""
|
||||
return str(self) < str(other)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Return equality on the stringified version."""
|
||||
if isinstance(other, Path):
|
||||
return str(self) == other.as_posix()
|
||||
else:
|
||||
return str(self) == str(other)
|
||||
|
||||
|
||||
class LocalModelSource(StringLikeSource):
|
||||
"""A local file or directory path."""
|
||||
|
||||
path: str | Path
|
||||
inplace: Optional[bool] = False
|
||||
type: Literal["local"] = "local"
|
||||
|
||||
# these methods allow the source to be used in a string-like way,
|
||||
# for example as an index into a dict
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of path when string rep needed."""
|
||||
return Path(self.path).as_posix()
|
||||
|
||||
|
||||
class HFModelSource(StringLikeSource):
|
||||
"""
|
||||
A HuggingFace repo_id with optional variant, sub-folder and access token.
|
||||
Note that the variant option, if not provided to the constructor, will default to fp16, which is
|
||||
what people (almost) always want.
|
||||
"""
|
||||
|
||||
repo_id: str
|
||||
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
|
||||
subfolder: Optional[Path] = None
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["hf"] = "hf"
|
||||
|
||||
@field_validator("repo_id")
|
||||
@classmethod
|
||||
def proper_repo_id(cls, v: str) -> str: # noqa D102
|
||||
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
|
||||
raise ValueError(f"{v}: invalid repo_id format")
|
||||
return v
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = self.repo_id
|
||||
if self.variant:
|
||||
base += f":{self.variant or ''}"
|
||||
if self.subfolder:
|
||||
base += f":{self.subfolder}"
|
||||
return base
|
||||
|
||||
|
||||
class URLModelSource(StringLikeSource):
|
||||
"""A generic URL point to a checkpoint file."""
|
||||
|
||||
url: AnyHttpUrl
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["url"] = "url"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of the url when string rep needed."""
|
||||
return str(self.url)
|
||||
|
||||
|
||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
}
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
)
|
||||
config_out: Optional[AnyModelConfig] = Field(
|
||||
default=None, description="After successful installation, this will hold the configuration object."
|
||||
)
|
||||
inplace: bool = Field(
|
||||
default=False, description="Leave model in its current location; otherwise install under models directory"
|
||||
)
|
||||
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
|
||||
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
|
||||
bytes: int = Field(
|
||||
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
|
||||
)
|
||||
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
|
||||
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
||||
default=None, description="Metadata provided by the model source"
|
||||
)
|
||||
download_parts: Set[DownloadJob] = Field(
|
||||
default_factory=set, description="Download jobs contributing to this install"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||
)
|
||||
error_traceback: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the exception traceback"
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
"""Record the error and traceback from an exception."""
|
||||
self._exception = e
|
||||
self.error = str(e)
|
||||
self.error_traceback = self._format_error(e)
|
||||
self.status = InstallStatus.ERROR
|
||||
self.error_reason = self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Call to cancel the job."""
|
||||
self.status = InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def error_type(self) -> Optional[str]:
|
||||
"""Class name of the exception that led to status==ERROR."""
|
||||
return self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def _format_error(self, exception: Exception) -> str:
|
||||
"""Error traceback."""
|
||||
return "".join(traceback.format_exception(exception))
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
"""Set status to CANCELLED."""
|
||||
return self.status == InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
"""Return true if job has errored."""
|
||||
return self.status == InstallStatus.ERROR
|
||||
|
||||
@property
|
||||
def waiting(self) -> bool:
|
||||
"""Return true if job is waiting to run."""
|
||||
return self.status == InstallStatus.WAITING
|
||||
|
||||
@property
|
||||
def downloading(self) -> bool:
|
||||
"""Return true if job is downloading."""
|
||||
return self.status == InstallStatus.DOWNLOADING
|
||||
|
||||
@property
|
||||
def downloads_done(self) -> bool:
|
||||
"""Return true if job's downloads ae done."""
|
||||
return self.status == InstallStatus.DOWNLOADS_DONE
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if job is running."""
|
||||
return self.status == InstallStatus.RUNNING
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
"""Return true if job completed without errors."""
|
||||
return self.status == InstallStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def in_terminal_state(self) -> bool:
|
||||
"""Return true if job is in a terminal state."""
|
||||
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
|
||||
|
||||
|
||||
class ModelInstallServiceBase(ABC):
|
||||
@@ -57,7 +282,7 @@ class ModelInstallServiceBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def event_bus(self) -> Optional["EventServiceBase"]:
|
||||
def event_bus(self) -> Optional[EventServiceBase]:
|
||||
"""Return the event service base object associated with the installer."""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -1,233 +0,0 @@
|
||||
import re
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.download import DownloadJob
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
DOWNLOADING = "downloading" # downloading of model files in process
|
||||
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
CANCELLED = "cancelled" # terminated with an error message
|
||||
|
||||
|
||||
class ModelInstallPart(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
path: Path
|
||||
bytes: int = 0
|
||||
total_bytes: int = 0
|
||||
|
||||
|
||||
class UnknownInstallJobException(Exception):
|
||||
"""Raised when the status of an unknown job is requested."""
|
||||
|
||||
|
||||
class StringLikeSource(BaseModel):
|
||||
"""
|
||||
Base class for model sources, implements functions that lets the source be sorted and indexed.
|
||||
|
||||
These shenanigans let this stuff work:
|
||||
|
||||
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
|
||||
mydict = {source1: 'model 1'}
|
||||
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
|
||||
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
|
||||
|
||||
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
|
||||
assert source1 == source2
|
||||
assert source1 == 'C:/users/mort/foo.safetensors'
|
||||
"""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the path field, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __lt__(self, other: object) -> int:
|
||||
"""Return comparison of the stringified version, for sorting."""
|
||||
return str(self) < str(other)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Return equality on the stringified version."""
|
||||
if isinstance(other, Path):
|
||||
return str(self) == other.as_posix()
|
||||
else:
|
||||
return str(self) == str(other)
|
||||
|
||||
|
||||
class LocalModelSource(StringLikeSource):
|
||||
"""A local file or directory path."""
|
||||
|
||||
path: str | Path
|
||||
inplace: Optional[bool] = False
|
||||
type: Literal["local"] = "local"
|
||||
|
||||
# these methods allow the source to be used in a string-like way,
|
||||
# for example as an index into a dict
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of path when string rep needed."""
|
||||
return Path(self.path).as_posix()
|
||||
|
||||
|
||||
class HFModelSource(StringLikeSource):
|
||||
"""
|
||||
A HuggingFace repo_id with optional variant, sub-folder and access token.
|
||||
Note that the variant option, if not provided to the constructor, will default to fp16, which is
|
||||
what people (almost) always want.
|
||||
"""
|
||||
|
||||
repo_id: str
|
||||
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
|
||||
subfolder: Optional[Path] = None
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["hf"] = "hf"
|
||||
|
||||
@field_validator("repo_id")
|
||||
@classmethod
|
||||
def proper_repo_id(cls, v: str) -> str: # noqa D102
|
||||
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
|
||||
raise ValueError(f"{v}: invalid repo_id format")
|
||||
return v
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = self.repo_id
|
||||
if self.variant:
|
||||
base += f":{self.variant or ''}"
|
||||
if self.subfolder:
|
||||
base += f":{self.subfolder}"
|
||||
return base
|
||||
|
||||
|
||||
class URLModelSource(StringLikeSource):
|
||||
"""A generic URL point to a checkpoint file."""
|
||||
|
||||
url: AnyHttpUrl
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["url"] = "url"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of the url when string rep needed."""
|
||||
return str(self.url)
|
||||
|
||||
|
||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
}
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
)
|
||||
config_out: Optional[AnyModelConfig] = Field(
|
||||
default=None, description="After successful installation, this will hold the configuration object."
|
||||
)
|
||||
inplace: bool = Field(
|
||||
default=False, description="Leave model in its current location; otherwise install under models directory"
|
||||
)
|
||||
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
|
||||
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
|
||||
bytes: int = Field(
|
||||
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
|
||||
)
|
||||
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
|
||||
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
||||
default=None, description="Metadata provided by the model source"
|
||||
)
|
||||
download_parts: Set[DownloadJob] = Field(
|
||||
default_factory=set, description="Download jobs contributing to this install"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||
)
|
||||
error_traceback: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the exception traceback"
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
"""Record the error and traceback from an exception."""
|
||||
self._exception = e
|
||||
self.error = str(e)
|
||||
self.error_traceback = self._format_error(e)
|
||||
self.status = InstallStatus.ERROR
|
||||
self.error_reason = self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Call to cancel the job."""
|
||||
self.status = InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def error_type(self) -> Optional[str]:
|
||||
"""Class name of the exception that led to status==ERROR."""
|
||||
return self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def _format_error(self, exception: Exception) -> str:
|
||||
"""Error traceback."""
|
||||
return "".join(traceback.format_exception(exception))
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
"""Set status to CANCELLED."""
|
||||
return self.status == InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
"""Return true if job has errored."""
|
||||
return self.status == InstallStatus.ERROR
|
||||
|
||||
@property
|
||||
def waiting(self) -> bool:
|
||||
"""Return true if job is waiting to run."""
|
||||
return self.status == InstallStatus.WAITING
|
||||
|
||||
@property
|
||||
def downloading(self) -> bool:
|
||||
"""Return true if job is downloading."""
|
||||
return self.status == InstallStatus.DOWNLOADING
|
||||
|
||||
@property
|
||||
def downloads_done(self) -> bool:
|
||||
"""Return true if job's downloads ae done."""
|
||||
return self.status == InstallStatus.DOWNLOADS_DONE
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if job is running."""
|
||||
return self.status == InstallStatus.RUNNING
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
"""Return true if job completed without errors."""
|
||||
return self.status == InstallStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def in_terminal_state(self) -> bool:
|
||||
"""Return true if job is in a terminal state."""
|
||||
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
|
||||
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
@@ -20,8 +20,8 @@ from requests import Session
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@@ -45,12 +45,13 @@ from invokeai.backend.util import InvokeAILogger
|
||||
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .model_install_common import (
|
||||
from .model_install_base import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
StringLikeSource,
|
||||
URLModelSource,
|
||||
@@ -58,9 +59,6 @@ from .model_install_common import (
|
||||
|
||||
TMPDIR_PREFIX = "tmpinstall_"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
|
||||
|
||||
class ModelInstallService(ModelInstallServiceBase):
|
||||
"""class for InvokeAI model installation."""
|
||||
@@ -70,7 +68,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
session: Optional[Session] = None,
|
||||
):
|
||||
"""
|
||||
@@ -106,7 +104,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
return self._record_store
|
||||
|
||||
@property
|
||||
def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102
|
||||
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
||||
return self._event_bus
|
||||
|
||||
# make the invoker optional here because we don't need it and it
|
||||
@@ -857,17 +855,35 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.status = InstallStatus.RUNNING
|
||||
self._logger.info(f"Model install started: {job.source}")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_started(job)
|
||||
self._event_bus.emit_model_install_running(str(job.source))
|
||||
|
||||
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_download_progress(job)
|
||||
parts: List[Dict[str, str | int]] = [
|
||||
{
|
||||
"url": str(x.source),
|
||||
"local_path": str(x.download_path),
|
||||
"bytes": x.bytes,
|
||||
"total_bytes": x.total_bytes,
|
||||
}
|
||||
for x in job.download_parts
|
||||
]
|
||||
assert job.bytes is not None
|
||||
assert job.total_bytes is not None
|
||||
self._event_bus.emit_model_install_downloading(
|
||||
str(job.source),
|
||||
local_path=job.local_path.as_posix(),
|
||||
parts=parts,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
id=job.id,
|
||||
)
|
||||
|
||||
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.DOWNLOADS_DONE
|
||||
self._logger.info(f"Model download complete: {job.source}")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_downloads_complete(job)
|
||||
self._event_bus.emit_model_install_downloads_done(str(job.source))
|
||||
|
||||
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.COMPLETED
|
||||
@@ -875,19 +891,24 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"Model install complete: {job.source}")
|
||||
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_complete(job)
|
||||
assert job.local_path is not None
|
||||
assert job.config_out is not None
|
||||
key = job.config_out.key
|
||||
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
|
||||
|
||||
def _signal_job_errored(self, job: ModelInstallJob) -> None:
|
||||
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
|
||||
if self._event_bus:
|
||||
assert job.error_type is not None
|
||||
assert job.error is not None
|
||||
self._event_bus.emit_model_install_error(job)
|
||||
error_type = job.error_type
|
||||
error = job.error
|
||||
assert error_type is not None
|
||||
assert error is not None
|
||||
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
|
||||
|
||||
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
|
||||
self._logger.info(f"Model install canceled: {job.source}")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_cancelled(job)
|
||||
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
|
||||
|
||||
@staticmethod
|
||||
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
@@ -14,12 +15,18 @@ class ModelLoadServiceBase(ABC):
|
||||
"""Wrapper around AnyModelLoader."""
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
def load_model(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context_data: Invocation context data used for event reporting
|
||||
"""
|
||||
|
||||
@property
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional, Type
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import (
|
||||
LoadedModel,
|
||||
@@ -50,18 +51,25 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
return self._convert_cache
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
def load_model(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
|
||||
# We don't have an invoker during testing
|
||||
# TODO(psyche): Mock this method on the invoker in the tests
|
||||
if hasattr(self, "_invoker"):
|
||||
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
|
||||
if context_data:
|
||||
self._emit_load_event(
|
||||
context_data=context_data,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
|
||||
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
|
||||
loaded_model: LoadedModel = implementation(
|
||||
@@ -71,7 +79,40 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
convert_cache=self._convert_cache,
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
if hasattr(self, "_invoker"):
|
||||
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
|
||||
|
||||
if context_data:
|
||||
self._emit_load_event(
|
||||
context_data=context_data,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
loaded=True,
|
||||
)
|
||||
return loaded_model
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context_data: InvocationContextData,
|
||||
model_config: AnyModelConfig,
|
||||
loaded: Optional[bool] = False,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
if not self._invoker:
|
||||
return
|
||||
|
||||
if not loaded:
|
||||
self._invoker.services.events.emit_model_load_started(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
else:
|
||||
self._invoker.services.events.emit_model_load_completed(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
|
||||
@@ -4,14 +4,11 @@ from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BatchEnqueuedEvent,
|
||||
FastAPIEvent,
|
||||
QueueClearedEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
register_events,
|
||||
)
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||
from invokeai.app.services.session_processor.session_processor_base import (
|
||||
OnAfterRunNode,
|
||||
@@ -63,11 +60,6 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
self._cancel_event = cancel_event
|
||||
self._profiler = profiler
|
||||
|
||||
def _is_canceled(self) -> bool:
|
||||
"""Check if the cancel event is set. This is also passed to the invocation context builder and called during
|
||||
denoising to check if the session has been canceled."""
|
||||
return self._cancel_event.is_set()
|
||||
|
||||
def run(self, queue_item: SessionQueueItem):
|
||||
# Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here.
|
||||
|
||||
@@ -91,19 +83,13 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
)
|
||||
break
|
||||
|
||||
if invocation is None or self._is_canceled():
|
||||
if invocation is None or self._cancel_event.is_set():
|
||||
break
|
||||
|
||||
self.run_node(invocation, queue_item)
|
||||
|
||||
# The session is complete if all invocations have been run or there is an error on the session.
|
||||
# At this time, the queue item may be canceled, but the object itself here won't be updated yet. We must
|
||||
# use the cancel event to check if the session is canceled.
|
||||
if (
|
||||
queue_item.session.is_complete()
|
||||
or self._is_canceled()
|
||||
or queue_item.status in ["failed", "canceled", "completed"]
|
||||
):
|
||||
if queue_item.session.is_complete() or self._cancel_event.is_set():
|
||||
break
|
||||
|
||||
self._on_after_run_session(queue_item=queue_item)
|
||||
@@ -122,7 +108,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._services,
|
||||
is_canceled=self._is_canceled,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
|
||||
# Invoke the node
|
||||
@@ -136,12 +122,16 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here?
|
||||
pass
|
||||
except CanceledException:
|
||||
# A CanceledException is raised during the denoising step callback if the cancel event is set. We don't need
|
||||
# to do any handling here, and no error should be set - just pass and the cancellation will be handled
|
||||
# correctly in the next iteration of the session runner loop.
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# See the comment in the processor's `_on_queue_item_status_changed()` method for more details on how we
|
||||
# handle cancellation.
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
except Exception as e:
|
||||
error_type = e.__class__.__name__
|
||||
@@ -156,11 +146,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
)
|
||||
|
||||
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||
"""Called before a session is run.
|
||||
|
||||
- Start the profiler if profiling is enabled.
|
||||
- Run any callbacks registered for this event.
|
||||
"""
|
||||
"""Run before a session is executed"""
|
||||
|
||||
self._services.logger.debug(
|
||||
f"On before run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
|
||||
@@ -174,14 +160,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
callback(queue_item=queue_item)
|
||||
|
||||
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||
"""Called after a session is run.
|
||||
|
||||
- Stop the profiler if profiling is enabled.
|
||||
- Update the queue item's session object in the database.
|
||||
- If not already canceled or failed, complete the queue item.
|
||||
- Log and reset performance statistics.
|
||||
- Run any callbacks registered for this event.
|
||||
"""
|
||||
"""Run after a session is executed"""
|
||||
|
||||
self._services.logger.debug(
|
||||
f"On after run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
|
||||
@@ -201,10 +180,14 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# while the session is running.
|
||||
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||
|
||||
# The queue item may have been canceled or failed while the session was running. We should only complete it
|
||||
# if it is not already canceled or failed.
|
||||
if queue_item.status not in ["canceled", "failed"]:
|
||||
queue_item = self._services.session_queue.complete_queue_item(queue_item.item_id)
|
||||
# TODO(psyche): This feels jumbled - we should review separation of concerns here.
|
||||
# Send complete event. The events service will receive this and update the queue item's status.
|
||||
self._services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
)
|
||||
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
@@ -218,18 +201,21 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
pass
|
||||
|
||||
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Called before a node is run.
|
||||
|
||||
- Emits an invocation started event.
|
||||
- Run any callbacks registered for this event.
|
||||
"""
|
||||
"""Run before a node is executed"""
|
||||
|
||||
self._services.logger.debug(
|
||||
f"On before run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
|
||||
)
|
||||
|
||||
# Send starting event
|
||||
self._services.events.emit_invocation_started(queue_item=queue_item, invocation=invocation)
|
||||
self._services.events.emit_invocation_started(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
)
|
||||
|
||||
for callback in self._on_before_run_node_callbacks:
|
||||
callback(invocation=invocation, queue_item=queue_item)
|
||||
@@ -237,18 +223,22 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
def _on_after_run_node(
|
||||
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
|
||||
):
|
||||
"""Called after a node is run.
|
||||
|
||||
- Emits an invocation complete event.
|
||||
- Run any callbacks registered for this event.
|
||||
"""
|
||||
"""Run after a node is executed"""
|
||||
|
||||
self._services.logger.debug(
|
||||
f"On after run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
|
||||
)
|
||||
|
||||
# Send complete event on successful runs
|
||||
self._services.events.emit_invocation_complete(invocation=invocation, queue_item=queue_item, output=output)
|
||||
self._services.events.emit_invocation_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
result=output.model_dump(),
|
||||
)
|
||||
|
||||
for callback in self._on_after_run_node_callbacks:
|
||||
callback(invocation=invocation, queue_item=queue_item, output=output)
|
||||
@@ -261,14 +251,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
):
|
||||
"""Called when a node errors. Node errors may occur when running or preparing the node..
|
||||
|
||||
- Set the node error on the session object.
|
||||
- Log the error.
|
||||
- Fail the queue item.
|
||||
- Emits an invocation error event.
|
||||
- Run any callbacks registered for this event.
|
||||
"""
|
||||
"""Run when a node errors"""
|
||||
|
||||
self._services.logger.debug(
|
||||
f"On node error: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
|
||||
@@ -282,19 +265,19 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
)
|
||||
self._services.logger.error(error_traceback)
|
||||
|
||||
# Fail the queue item
|
||||
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||
queue_item = self._services.session_queue.fail_queue_item(
|
||||
queue_item.item_id, error_type, error_message, error_traceback
|
||||
)
|
||||
|
||||
# Send error event
|
||||
self._services.events.emit_invocation_error(
|
||||
queue_item=queue_item,
|
||||
invocation=invocation,
|
||||
queue_batch_id=queue_item.session_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
user_id=getattr(queue_item, "user_id", None),
|
||||
project_id=getattr(queue_item, "project_id", None),
|
||||
)
|
||||
|
||||
for callback in self._on_node_error_callbacks:
|
||||
@@ -332,9 +315,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._poll_now_event = ThreadEvent()
|
||||
self._cancel_event = ThreadEvent()
|
||||
|
||||
register_events(QueueClearedEvent, self._on_queue_cleared)
|
||||
register_events(BatchEnqueuedEvent, self._on_batch_enqueued)
|
||||
register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||
|
||||
self._thread_semaphore = BoundedSemaphore(self._thread_limit)
|
||||
|
||||
@@ -369,25 +350,31 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def _poll_now(self) -> None:
|
||||
self._poll_now_event.set()
|
||||
|
||||
async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
|
||||
if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
|
||||
async def _on_queue_event(self, event: FastAPIEvent) -> None:
|
||||
event_name = event[1]["event"]
|
||||
|
||||
if (
|
||||
event_name == "session_canceled"
|
||||
and self._queue_item
|
||||
and self._queue_item.item_id == event[1]["data"]["queue_item_id"]
|
||||
):
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
|
||||
async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None:
|
||||
self._poll_now()
|
||||
|
||||
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
|
||||
if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
|
||||
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
|
||||
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
|
||||
# event, which the session runner checks between invocations. If set, the session runner loop is broken.
|
||||
#
|
||||
# Long-running nodes that cannot be interrupted easily present a challenge. `denoise_latents` is one such
|
||||
# node, but it gets a step callback, called on each step of denoising. This callback checks if the queue item
|
||||
# is canceled, and if it is, raises a `CanceledException` to stop execution immediately.
|
||||
if event[1].status == "canceled":
|
||||
self._cancel_event.set()
|
||||
elif (
|
||||
event_name == "queue_cleared"
|
||||
and self._queue_item
|
||||
and self._queue_item.queue_id == event[1]["data"]["queue_id"]
|
||||
):
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
elif event_name == "batch_enqueued":
|
||||
self._poll_now()
|
||||
elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [
|
||||
"completed",
|
||||
"failed",
|
||||
"canceled",
|
||||
]:
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
|
||||
def resume(self) -> SessionProcessorStatus:
|
||||
@@ -476,22 +463,15 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
) -> None:
|
||||
"""Called when a non-fatal error occurs in the processor.
|
||||
|
||||
- Log the error.
|
||||
- If a queue item is provided, update the queue item with the completed session & fail it.
|
||||
- Run any callbacks registered for this event.
|
||||
"""
|
||||
|
||||
# Non-fatal error in processor
|
||||
self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}")
|
||||
self._invoker.services.logger.error(error_traceback)
|
||||
|
||||
if queue_item is not None:
|
||||
# Update the queue item with the completed session & fail it
|
||||
queue_item = self._invoker.services.session_queue.set_queue_item_session(
|
||||
queue_item.item_id, queue_item.session
|
||||
)
|
||||
queue_item = self._invoker.services.session_queue.fail_queue_item(
|
||||
# Update the queue item with the completed session
|
||||
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||
# Fail the queue item
|
||||
self._invoker.services.session_queue.fail_queue_item(
|
||||
item_id=queue_item.item_id,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
|
||||
@@ -73,11 +73,6 @@ class SessionQueueBase(ABC):
|
||||
"""Gets the status of a batch"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
"""Completes a session queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
"""Cancels a session queue item"""
|
||||
|
||||
@@ -2,6 +2,10 @@ import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
@@ -38,7 +42,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__invoker = invoker
|
||||
self._set_in_progress_to_canceled()
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
|
||||
if prune_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
|
||||
@@ -48,6 +52,60 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__conn = db.conn
|
||||
self.__cursor = self.__conn.cursor()
|
||||
|
||||
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
||||
return event[1]["event"] in match_in
|
||||
|
||||
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
|
||||
event_name = event[1]["event"]
|
||||
|
||||
# This was a match statement, but match is not supported on python 3.9
|
||||
if event_name == "graph_execution_state_complete":
|
||||
await self._handle_complete_event(event)
|
||||
elif event_name == "invocation_error":
|
||||
await self._handle_error_event(event)
|
||||
elif event_name == "session_canceled":
|
||||
await self._handle_cancel_event(event)
|
||||
return event
|
||||
|
||||
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
# When a queue item has an error, we get an error event, then a completed event.
|
||||
# Mark the queue item completed only if it isn't already marked completed, e.g.
|
||||
# by a previously-handled error event.
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
async def _handle_error_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
error_type = event[1]["data"]["error_type"]
|
||||
error_message = event[1]["data"]["error_message"]
|
||||
error_traceback = event[1]["data"]["error_traceback"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
# always set to failed if have an error, even if previously the item was marked completed or canceled
|
||||
queue_item = self._set_queue_item_status(
|
||||
item_id=queue_item.item_id,
|
||||
status="failed",
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""
|
||||
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
|
||||
@@ -248,7 +306,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
session_queue_item=queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
)
|
||||
return queue_item
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
@@ -357,11 +419,15 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return PruneResult(deleted=count)
|
||||
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||
return queue_item
|
||||
|
||||
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
)
|
||||
return queue_item
|
||||
|
||||
def fail_queue_item(
|
||||
@@ -371,13 +437,21 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
) -> SessionQueueItem:
|
||||
queue_item = self._set_queue_item_status(
|
||||
item_id=item_id,
|
||||
status="failed",
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||
queue_item = self._set_queue_item_status(
|
||||
item_id=item_id,
|
||||
status="failed",
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
)
|
||||
return queue_item
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
@@ -413,10 +487,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
queue_batch_id=current_queue_item.batch_id,
|
||||
graph_execution_state_id=current_queue_item.session_id,
|
||||
)
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
current_queue_item, batch_status, queue_status
|
||||
session_queue_item=current_queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
@@ -456,10 +538,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
queue_batch_id=current_queue_item.batch_id,
|
||||
graph_execution_state_id=current_queue_item.session_id,
|
||||
)
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
current_queue_item, batch_status, queue_status
|
||||
session_queue_item=current_queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
@@ -352,11 +353,11 @@ class ModelsInterface(InvocationContextInterface):
|
||||
|
||||
if isinstance(identifier, str):
|
||||
model = self._services.model_manager.store.get_model(identifier)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
|
||||
else:
|
||||
_submodel_type = submodel_type or identifier.submodel_type
|
||||
model = self._services.model_manager.store.get_model(identifier.key)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
|
||||
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
@@ -381,7 +382,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
if len(configs) > 1:
|
||||
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type)
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
|
||||
"""Gets a model's config.
|
||||
@@ -448,10 +449,10 @@ class ConfigInterface(InvocationContextInterface):
|
||||
|
||||
class UtilInterface(InvocationContextInterface):
|
||||
def __init__(
|
||||
self, services: InvocationServices, data: InvocationContextData, is_canceled: Callable[[], bool]
|
||||
self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event
|
||||
) -> None:
|
||||
super().__init__(services, data)
|
||||
self._is_canceled = is_canceled
|
||||
self._cancel_event = cancel_event
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
"""Checks if the current session has been canceled.
|
||||
@@ -459,7 +460,7 @@ class UtilInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
True if the current session has been canceled, False if not.
|
||||
"""
|
||||
return self._is_canceled()
|
||||
return self._cancel_event.is_set()
|
||||
|
||||
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
|
||||
"""
|
||||
@@ -534,7 +535,7 @@ class InvocationContext:
|
||||
def build_invocation_context(
|
||||
services: InvocationServices,
|
||||
data: InvocationContextData,
|
||||
is_canceled: Callable[[], bool],
|
||||
cancel_event: threading.Event,
|
||||
) -> InvocationContext:
|
||||
"""Builds the invocation context for a specific invocation execution.
|
||||
|
||||
@@ -551,7 +552,7 @@ def build_invocation_context(
|
||||
tensors = TensorsInterface(services=services, data=data)
|
||||
models = ModelsInterface(services=services, data=data)
|
||||
config = ConfigInterface(services=services, data=data)
|
||||
util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
|
||||
util = UtilInterface(services=services, data=data, cancel_event=cancel_event)
|
||||
conditioning = ConditioningInterface(services=services, data=data)
|
||||
boards = BoardsInterface(services=services, data=data)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
@@ -13,36 +13,8 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
|
||||
# fast latents preview matrix for sdxl
|
||||
# generated by @StAlKeR7779
|
||||
SDXL_LATENT_RGB_FACTORS = [
|
||||
# R G B
|
||||
[0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
]
|
||||
SDXL_SMOOTH_MATRIX = [
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
]
|
||||
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
SD1_5_LATENT_RGB_FACTORS = [
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
]
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(
|
||||
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
|
||||
):
|
||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
||||
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
|
||||
|
||||
if smooth_matrix is not None:
|
||||
@@ -75,12 +47,64 @@ def stable_diffusion_step_callback(
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
# TODO: This does not seem to be needed any more?
|
||||
# # txt2img provides a Tensor in the step_callback
|
||||
# # img2img provides a PipelineIntermediateState
|
||||
# if isinstance(sample, PipelineIntermediateState):
|
||||
# # this was an img2img
|
||||
# print('img2img')
|
||||
# latents = sample.latents
|
||||
# step = sample.step
|
||||
# else:
|
||||
# print('txt2img')
|
||||
# latents = sample
|
||||
# step = intermediate_state.step
|
||||
|
||||
# TODO: only output a preview image when requested
|
||||
|
||||
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
|
||||
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
|
||||
# fast latents preview matrix for sdxl
|
||||
# generated by @StAlKeR7779
|
||||
sdxl_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
sdxl_smooth_matrix = torch.tensor(
|
||||
[
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
|
||||
else:
|
||||
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
v1_5_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
|
||||
|
||||
(width, height) = image.size
|
||||
@@ -89,9 +113,15 @@ def stable_diffusion_step_callback(
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
events.emit_invocation_denoise_progress(
|
||||
context_data.queue_item,
|
||||
context_data.invocation,
|
||||
intermediate_state,
|
||||
ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
events.emit_generator_progress(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
node_id=context_data.invocation.id,
|
||||
source_node_id=context_data.source_invocation_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=intermediate_state.step,
|
||||
order=intermediate_state.order,
|
||||
total_steps=intermediate_state.total_steps,
|
||||
)
|
||||
|
||||
@@ -30,12 +30,8 @@ def convert_ldm_vae_to_diffusers(
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
with torch.no_grad():
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
del converted_vae_checkpoint # Free memory
|
||||
import gc
|
||||
gc.collect()
|
||||
vae.to(precision)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
vae.to(precision)
|
||||
|
||||
if dump_path:
|
||||
vae.save_pretrained(dump_path, safe_serialization=True)
|
||||
@@ -56,11 +52,7 @@ def convert_ckpt_to_diffusers(
|
||||
model to be written.
|
||||
"""
|
||||
pipe = download_from_original_stable_diffusion_ckpt(Path(checkpoint_path).as_posix(), **kwargs)
|
||||
with torch.no_grad():
|
||||
del kwargs # Free memory
|
||||
import gc
|
||||
gc.collect()
|
||||
pipe = pipe.to(precision)
|
||||
pipe = pipe.to(precision)
|
||||
|
||||
# TO DO: save correct repo variant
|
||||
if dump_path:
|
||||
@@ -83,11 +75,7 @@ def convert_controlnet_to_diffusers(
|
||||
model to be written.
|
||||
"""
|
||||
pipe = download_controlnet_from_original_ckpt(checkpoint_path.as_posix(), **kwargs)
|
||||
with torch.no_grad():
|
||||
del kwargs # Free memory
|
||||
import gc
|
||||
gc.collect()
|
||||
pipe = pipe.to(precision)
|
||||
pipe = pipe.to(precision)
|
||||
|
||||
# TO DO: save correct repo variant
|
||||
if dump_path:
|
||||
|
||||
@@ -42,26 +42,10 @@ T = TypeVar("T")
|
||||
|
||||
@dataclass
|
||||
class CacheRecord(Generic[T]):
|
||||
"""
|
||||
Elements of the cache:
|
||||
|
||||
key: Unique key for each model, same as used in the models database.
|
||||
model: Model in memory.
|
||||
state_dict: A read-only copy of the model's state dict in RAM. It will be
|
||||
used as a template for creating a copy in the VRAM.
|
||||
size: Size of the model
|
||||
loaded: True if the model's state dict is currently in VRAM
|
||||
|
||||
Before a model is executed, the state_dict template is copied into VRAM,
|
||||
and then injected into the model. When the model is finished, the VRAM
|
||||
copy of the state dict is deleted, and the RAM version is reinjected
|
||||
into the model.
|
||||
"""
|
||||
"""Elements of the cache."""
|
||||
|
||||
key: str
|
||||
model: T
|
||||
device: torch.device
|
||||
state_dict: Optional[Dict[str, torch.Tensor]]
|
||||
size: int
|
||||
loaded: bool = False
|
||||
_locks: int = 0
|
||||
|
||||
@@ -20,6 +20,7 @@ context. Use like this:
|
||||
|
||||
import gc
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from logging import Logger
|
||||
@@ -161,9 +162,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
if key in self._cached_models:
|
||||
return
|
||||
self.make_room(size)
|
||||
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
|
||||
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
|
||||
cache_record = CacheRecord(key, model, size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
|
||||
@@ -258,37 +257,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
|
||||
return
|
||||
|
||||
source_device = cache_entry.device
|
||||
source_device = cache_entry.model.device
|
||||
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# This would need to be revised to support multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
# This roundabout method for moving the model around is done to avoid
|
||||
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||
# When moving to VRAM, we copy (not move) each element of the state dict from
|
||||
# RAM to a new state dict in VRAM, and then inject it into the model.
|
||||
# This operation is slightly faster than running `to()` on the whole model.
|
||||
#
|
||||
# When the model needs to be removed from VRAM we simply delete the copy
|
||||
# of the state dict in VRAM, and reinject the state dict that is cached
|
||||
# in RAM into the model. So this operation is very fast.
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
|
||||
try:
|
||||
if cache_entry.state_dict is not None:
|
||||
assert hasattr(cache_entry.model, "load_state_dict")
|
||||
if target_device == self.storage_device:
|
||||
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
|
||||
else:
|
||||
new_dict: Dict[str, torch.Tensor] = {}
|
||||
for k, v in cache_entry.state_dict.items():
|
||||
new_dict[k] = v.to(torch.device(target_device), copy=True)
|
||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
cache_entry.model.to(target_device)
|
||||
cache_entry.device = target_device
|
||||
except Exception as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise e
|
||||
@@ -368,12 +347,43 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
|
||||
refs = sys.getrefcount(cache_entry.model)
|
||||
|
||||
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
|
||||
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
|
||||
# https://docs.python.org/3/library/gc.html#gc.get_referrers
|
||||
|
||||
# manualy clear local variable references of just finished function calls
|
||||
# for some reason python don't want to collect it even by gc.collect() immidiately
|
||||
if refs > 2:
|
||||
while True:
|
||||
cleared = False
|
||||
for referrer in gc.get_referrers(cache_entry.model):
|
||||
if type(referrer).__name__ == "frame":
|
||||
# RuntimeError: cannot clear an executing frame
|
||||
with suppress(RuntimeError):
|
||||
referrer.clear()
|
||||
cleared = True
|
||||
# break
|
||||
|
||||
# repeat if referrers changes(due to frame clear), else exit loop
|
||||
if cleared:
|
||||
gc.collect()
|
||||
else:
|
||||
break
|
||||
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self.logger.debug(
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
|
||||
f" refs: {refs}"
|
||||
)
|
||||
|
||||
if not cache_entry.locked:
|
||||
# Expected refs:
|
||||
# 1 from cache_entry
|
||||
# 1 from getrefcount function
|
||||
# 1 from onnx runtime object
|
||||
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Textual Inversion wrapper class."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
@@ -66,52 +66,35 @@ class TextualInversionModelRaw(RawModel):
|
||||
return result
|
||||
|
||||
|
||||
class TextualInversionManager(BaseTextualInversionManager):
|
||||
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
|
||||
# no type hints for BaseTextualInversionManager?
|
||||
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
|
||||
pad_tokens: Dict[int, List[int]]
|
||||
tokenizer: CLIPTokenizer
|
||||
|
||||
def __init__(self, tokenizer: CLIPTokenizer):
|
||||
self.pad_tokens: dict[int, list[int]] = {}
|
||||
self.pad_tokens = {}
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
||||
"""Given a list of tokens ids, expand any TI tokens to their corresponding pad tokens.
|
||||
|
||||
For example, suppose we have a `<ti_dog>` TI with 4 vectors that was added to the tokenizer with the following
|
||||
mapping of tokens to token_ids:
|
||||
```
|
||||
<ti_dog>: 49408
|
||||
<ti_dog-!pad-1>: 49409
|
||||
<ti_dog-!pad-2>: 49410
|
||||
<ti_dog-!pad-3>: 49411
|
||||
```
|
||||
`self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`.
|
||||
This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`.
|
||||
"""
|
||||
# Short circuit if there are no pad tokens to save a little time.
|
||||
if len(self.pad_tokens) == 0:
|
||||
return token_ids
|
||||
|
||||
# This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify
|
||||
# this assumption here.
|
||||
if token_ids[0] == self.tokenizer.bos_token_id:
|
||||
raise ValueError("token_ids must not start with bos_token_id")
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
raise ValueError("token_ids must not end with eos_token_id")
|
||||
|
||||
# Expand any TI tokens to their corresponding pad tokens.
|
||||
new_token_ids: list[int] = []
|
||||
new_token_ids = []
|
||||
for token_id in token_ids:
|
||||
new_token_ids.append(token_id)
|
||||
if token_id in self.pad_tokens:
|
||||
new_token_ids.extend(self.pad_tokens[token_id])
|
||||
|
||||
# Do not exceed the max model input size. The -2 here is compensating for
|
||||
# compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens.
|
||||
max_length = self.tokenizer.model_max_length - 2
|
||||
# Do not exceed the max model input size
|
||||
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
|
||||
# which first removes and then adds back the start and end tokens.
|
||||
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
|
||||
if len(new_token_ids) > max_length:
|
||||
# HACK: If TI token expansion causes us to exceed the max text encoder input length, we silently discard
|
||||
# tokens. Token expansion should happen in a way that is compatible with compel's default handling of long
|
||||
# prompts.
|
||||
new_token_ids = new_token_ids[0:max_length]
|
||||
|
||||
return new_token_ids
|
||||
|
||||
@@ -1104,7 +1104,7 @@
|
||||
"parameters": "Parameters",
|
||||
"parameterSet": "Parameter Recalled",
|
||||
"parameterSetDesc": "Recalled {{parameter}}",
|
||||
"parameterNotSet": "Parameter Not Recalled",
|
||||
"parameterNotSet": "Parameter Recalled",
|
||||
"parameterNotSetDesc": "Unable to recall {{parameter}}",
|
||||
"parameterNotSetDescWithMessage": "Unable to recall {{parameter}}: {{message}}",
|
||||
"parametersSet": "Parameters Recalled",
|
||||
|
||||
@@ -6,8 +6,8 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import type { MapStore } from 'nanostores';
|
||||
import { atom, map } from 'nanostores';
|
||||
import { useEffect, useMemo } from 'react';
|
||||
import { setEventListeners } from 'services/events/setEventListeners';
|
||||
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
|
||||
import { setEventListeners } from 'services/events/util/setEventListeners';
|
||||
import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client';
|
||||
import { io } from 'socket.io-client';
|
||||
|
||||
|
||||
@@ -35,22 +35,26 @@ import { addImageUploadedFulfilledListener } from 'app/store/middleware/listener
|
||||
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
|
||||
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
|
||||
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
|
||||
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
|
||||
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
|
||||
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
|
||||
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
|
||||
import { addGraphExecutionStateCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete';
|
||||
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
|
||||
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
|
||||
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
|
||||
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
|
||||
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
|
||||
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
|
||||
import { addSocketSubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed';
|
||||
import { addSocketUnsubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed';
|
||||
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
|
||||
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
|
||||
import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
@@ -98,11 +102,14 @@ addCommitStagingAreaImageListener(startAppListening);
|
||||
|
||||
// Socket.IO
|
||||
addGeneratorProgressEventListener(startAppListening);
|
||||
addGraphExecutionStateCompleteEventListener(startAppListening);
|
||||
addInvocationCompleteEventListener(startAppListening);
|
||||
addInvocationErrorEventListener(startAppListening);
|
||||
addInvocationStartedEventListener(startAppListening);
|
||||
addSocketConnectedEventListener(startAppListening);
|
||||
addSocketDisconnectedEventListener(startAppListening);
|
||||
addSocketSubscribedEventListener(startAppListening);
|
||||
addSocketUnsubscribedEventListener(startAppListening);
|
||||
addModelLoadEventListener(startAppListening);
|
||||
addModelInstallEventListener(startAppListening);
|
||||
addSocketQueueItemStatusChangedEventListener(startAppListening);
|
||||
|
||||
@@ -5,8 +5,8 @@ import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import {
|
||||
socketBulkDownloadComplete,
|
||||
socketBulkDownloadError,
|
||||
socketBulkDownloadCompleted,
|
||||
socketBulkDownloadFailed,
|
||||
socketBulkDownloadStarted,
|
||||
} from 'services/events/actions';
|
||||
|
||||
@@ -54,7 +54,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketBulkDownloadComplete,
|
||||
actionCreator: socketBulkDownloadCompleted,
|
||||
effect: async (action) => {
|
||||
log.debug(action.payload.data, 'Bulk download preparation completed');
|
||||
|
||||
@@ -80,7 +80,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketBulkDownloadError,
|
||||
actionCreator: socketBulkDownloadFailed,
|
||||
effect: async (action) => {
|
||||
log.debug(action.payload.data, 'Bulk download preparation failed');
|
||||
|
||||
|
||||
@@ -133,8 +133,8 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
|
||||
const [invocationCompleteAction] = await take(
|
||||
(action): action is ReturnType<typeof socketInvocationComplete> =>
|
||||
socketInvocationComplete.match(action) &&
|
||||
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.invocation_source_id === processorNode.id
|
||||
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.source_node_id === processorNode.id
|
||||
);
|
||||
|
||||
// We still have to check the output type
|
||||
|
||||
@@ -69,8 +69,8 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
|
||||
const [invocationCompleteAction] = await take(
|
||||
(action): action is ReturnType<typeof socketInvocationComplete> =>
|
||||
socketInvocationComplete.match(action) &&
|
||||
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.invocation_source_id === nodeId
|
||||
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.source_node_id === nodeId
|
||||
);
|
||||
|
||||
// We still have to check the output type
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketGeneratorProgress } from 'services/events/actions';
|
||||
@@ -12,9 +11,9 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
|
||||
startAppListening({
|
||||
actionCreator: socketGeneratorProgress,
|
||||
effect: (action) => {
|
||||
log.trace(parseify(action.payload), `Generator progress`);
|
||||
const { invocation_source_id, step, total_steps, progress_image } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
log.trace(action.payload, `Generator progress`);
|
||||
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
nes.progress = (step + 1) / total_steps;
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketGraphExecutionStateComplete } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
export const addGraphExecutionStateCompleteEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketGraphExecutionStateComplete,
|
||||
effect: (action) => {
|
||||
log.debug(action.payload, 'Session complete');
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -29,12 +29,12 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
actionCreator: socketInvocationComplete,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { data } = action.payload;
|
||||
log.debug({ data: parseify(data) }, `Invocation complete (${data.invocation.type})`);
|
||||
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
|
||||
|
||||
const { result, invocation_source_id } = data;
|
||||
const { result, node, queue_batch_id, source_node_id } = data;
|
||||
// This complete event has an associated image output
|
||||
if (isImageOutput(data.result) && !nodeTypeDenylist.includes(data.invocation.type)) {
|
||||
const { image_name } = data.result.image;
|
||||
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
|
||||
const { image_name } = result.image;
|
||||
const { canvas, gallery } = getState();
|
||||
|
||||
// This populates the `getImageDTO` cache
|
||||
@@ -48,7 +48,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
imageDTORequest.unsubscribe();
|
||||
|
||||
// Add canvas images to the staging area
|
||||
if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) {
|
||||
if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) {
|
||||
dispatch(addImageToStagingArea(imageDTO));
|
||||
}
|
||||
|
||||
@@ -114,7 +114,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
}
|
||||
}
|
||||
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.COMPLETED;
|
||||
if (nes.progress !== null) {
|
||||
|
||||
@@ -1,24 +1,52 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import ToastWithSessionRefDescription from 'features/toast/ToastWithSessionRefDescription';
|
||||
import { t } from 'i18next';
|
||||
import { startCase } from 'lodash-es';
|
||||
import { socketInvocationError } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
const getTitle = (errorType: string) => {
|
||||
if (errorType === 'OutOfMemoryError') {
|
||||
return t('toast.outOfMemoryError');
|
||||
}
|
||||
return t('toast.serverError');
|
||||
};
|
||||
|
||||
const getDescription = (errorType: string, sessionId: string, isLocal?: boolean) => {
|
||||
if (!isLocal) {
|
||||
if (errorType === 'OutOfMemoryError') {
|
||||
return ToastWithSessionRefDescription({
|
||||
message: t('toast.outOfMemoryDescription'),
|
||||
sessionId,
|
||||
});
|
||||
}
|
||||
return ToastWithSessionRefDescription({
|
||||
message: errorType,
|
||||
sessionId,
|
||||
});
|
||||
}
|
||||
return errorType;
|
||||
};
|
||||
|
||||
export const addInvocationErrorEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketInvocationError,
|
||||
effect: (action) => {
|
||||
const { invocation_source_id, invocation, error_type, error_message, error_traceback } = action.payload.data;
|
||||
log.error(parseify(action.payload), `Invocation error (${invocation.type})`);
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
effect: (action, { getState }) => {
|
||||
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
|
||||
const { source_node_id, error_type, error_message, error_traceback, graph_execution_state_id } =
|
||||
action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.FAILED;
|
||||
nes.progress = null;
|
||||
nes.progressImage = null;
|
||||
|
||||
nes.error = {
|
||||
error_type,
|
||||
error_message,
|
||||
@@ -26,6 +54,19 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
|
||||
};
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
|
||||
const errorType = startCase(error_type);
|
||||
const sessionId = graph_execution_state_id;
|
||||
const { isLocal } = getState().config;
|
||||
|
||||
toast({
|
||||
id: `INVOCATION_ERROR_${errorType}`,
|
||||
title: getTitle(errorType),
|
||||
status: 'error',
|
||||
duration: null,
|
||||
description: getDescription(errorType, sessionId, isLocal),
|
||||
updateDescription: isLocal ? true : false,
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketInvocationStarted } from 'services/events/actions';
|
||||
@@ -12,9 +11,9 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
|
||||
startAppListening({
|
||||
actionCreator: socketInvocationStarted,
|
||||
effect: (action) => {
|
||||
log.debug(parseify(action.payload), `Invocation started (${action.payload.data.invocation.type})`);
|
||||
const { invocation_source_id } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`);
|
||||
const { source_node_id } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
|
||||
@@ -3,14 +3,14 @@ import { api, LIST_TAG } from 'services/api';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import {
|
||||
socketModelInstallCancelled,
|
||||
socketModelInstallComplete,
|
||||
socketModelInstallDownloadProgress,
|
||||
socketModelInstallCompleted,
|
||||
socketModelInstallDownloading,
|
||||
socketModelInstallError,
|
||||
} from 'services/events/actions';
|
||||
|
||||
export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallDownloadProgress,
|
||||
actionCreator: socketModelInstallDownloading,
|
||||
effect: async (action, { dispatch }) => {
|
||||
const { bytes, total_bytes, id } = action.payload.data;
|
||||
|
||||
@@ -29,7 +29,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallComplete,
|
||||
actionCreator: socketModelInstallCompleted,
|
||||
effect: (action, { dispatch }) => {
|
||||
const { id } = action.payload.data;
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketModelLoadComplete, socketModelLoadStarted } from 'services/events/actions';
|
||||
import { socketModelLoadCompleted, socketModelLoadStarted } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
@@ -8,11 +8,10 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
|
||||
startAppListening({
|
||||
actionCreator: socketModelLoadStarted,
|
||||
effect: (action) => {
|
||||
const { config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = config;
|
||||
const { model_config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = model_config;
|
||||
|
||||
const extras: string[] = [base, type];
|
||||
|
||||
if (submodel_type) {
|
||||
extras.push(submodel_type);
|
||||
}
|
||||
@@ -24,10 +23,10 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelLoadComplete,
|
||||
actionCreator: socketModelLoadCompleted,
|
||||
effect: (action) => {
|
||||
const { config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = config;
|
||||
const { model_config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = model_config;
|
||||
|
||||
const extras: string[] = [base, type];
|
||||
if (submodel_type) {
|
||||
|
||||
@@ -3,8 +3,6 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
|
||||
import { socketQueueItemStatusChanged } from 'services/events/actions';
|
||||
@@ -14,38 +12,18 @@ const log = logger('socketio');
|
||||
export const addSocketQueueItemStatusChangedEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketQueueItemStatusChanged,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
effect: async (action, { dispatch }) => {
|
||||
// we've got new status for the queue item, batch and queue
|
||||
const {
|
||||
item_id,
|
||||
session_id,
|
||||
status,
|
||||
started_at,
|
||||
updated_at,
|
||||
completed_at,
|
||||
batch_status,
|
||||
queue_status,
|
||||
error_type,
|
||||
error_message,
|
||||
error_traceback,
|
||||
} = action.payload.data;
|
||||
const { queue_item, batch_status, queue_status } = action.payload.data;
|
||||
|
||||
log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`);
|
||||
log.debug(action.payload, `Queue item ${queue_item.item_id} status updated: ${queue_item.status}`);
|
||||
|
||||
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
|
||||
dispatch(
|
||||
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
|
||||
queueItemsAdapter.updateOne(draft, {
|
||||
id: String(item_id),
|
||||
changes: {
|
||||
status,
|
||||
started_at,
|
||||
updated_at: updated_at ?? undefined,
|
||||
completed_at: completed_at ?? undefined,
|
||||
error_type,
|
||||
error_message,
|
||||
error_traceback,
|
||||
},
|
||||
id: String(queue_item.item_id),
|
||||
changes: queue_item,
|
||||
});
|
||||
})
|
||||
);
|
||||
@@ -72,11 +50,11 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
|
||||
'CurrentSessionQueueItem',
|
||||
'NextSessionQueueItem',
|
||||
'InvocationCacheStatus',
|
||||
{ type: 'SessionQueueItem', id: item_id },
|
||||
{ type: 'SessionQueueItem', id: queue_item.item_id },
|
||||
])
|
||||
);
|
||||
|
||||
if (status === 'in_progress') {
|
||||
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
|
||||
forEach($nodeExecutionStates.get(), (nes) => {
|
||||
if (!nes) {
|
||||
return;
|
||||
@@ -89,25 +67,6 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
|
||||
clone.outputs = [];
|
||||
$nodeExecutionStates.setKey(clone.nodeId, clone);
|
||||
});
|
||||
} else if (status === 'failed' && error_type) {
|
||||
const isLocal = getState().config.isLocal ?? true;
|
||||
const sessionId = session_id;
|
||||
|
||||
toast({
|
||||
id: `INVOCATION_ERROR_${error_type}`,
|
||||
title: getTitleFromErrorType(error_type),
|
||||
status: 'error',
|
||||
duration: null,
|
||||
updateDescription: isLocal,
|
||||
description: (
|
||||
<ErrorToastDescription
|
||||
errorType={error_type}
|
||||
errorMessage={error_message}
|
||||
sessionId={sessionId}
|
||||
isLocal={isLocal}
|
||||
/>
|
||||
),
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
@@ -0,0 +1,14 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketSubscribedSession } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
export const addSocketSubscribedEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketSubscribedSession,
|
||||
effect: (action) => {
|
||||
log.debug(action.payload, 'Subscribed');
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,13 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketUnsubscribedSession } from 'services/events/actions';
|
||||
const log = logger('socketio');
|
||||
|
||||
export const addSocketUnsubscribedEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketUnsubscribedSession,
|
||||
effect: (action) => {
|
||||
log.debug(action.payload, 'Unsubscribed');
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -613,7 +613,7 @@ export const canvasSlice = createSlice({
|
||||
state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id);
|
||||
}
|
||||
|
||||
const queueItemStatus = action.payload.data.status;
|
||||
const queueItemStatus = action.payload.data.queue_item.status;
|
||||
if (queueItemStatus === 'canceled' || queueItemStatus === 'failed') {
|
||||
resetStagingAreaIfEmpty(state);
|
||||
}
|
||||
|
||||
@@ -72,12 +72,10 @@ export const ModelEdit = ({ form }: Props) => {
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect control={form.control} />
|
||||
</FormControl>
|
||||
{data.type === 'main' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={form.control} />
|
||||
</FormControl>
|
||||
)}
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={form.control} />
|
||||
</FormControl>
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import type { PayloadAction } from '@reduxjs/toolkit';
|
||||
import { createSlice } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { calculateStepPercentage } from 'features/system/util/calculateStepPercentage';
|
||||
import type { LogLevelName } from 'roarr';
|
||||
import {
|
||||
socketConnected,
|
||||
socketDisconnected,
|
||||
socketGeneratorProgress,
|
||||
socketGraphExecutionStateComplete,
|
||||
socketInvocationComplete,
|
||||
socketInvocationStarted,
|
||||
socketModelLoadComplete,
|
||||
socketModelLoadCompleted,
|
||||
socketModelLoadStarted,
|
||||
socketQueueItemStatusChanged,
|
||||
} from 'services/events/actions';
|
||||
@@ -96,7 +98,14 @@ export const systemSlice = createSlice({
|
||||
* Generator Progress
|
||||
*/
|
||||
builder.addCase(socketGeneratorProgress, (state, action) => {
|
||||
const { step, total_steps, progress_image, session_id, batch_id, percentage } = action.payload.data;
|
||||
const {
|
||||
step,
|
||||
total_steps,
|
||||
order,
|
||||
progress_image,
|
||||
graph_execution_state_id: session_id,
|
||||
queue_batch_id: batch_id,
|
||||
} = action.payload.data;
|
||||
|
||||
if (state.cancellations.includes(session_id)) {
|
||||
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
|
||||
@@ -107,7 +116,8 @@ export const systemSlice = createSlice({
|
||||
state.denoiseProgress = {
|
||||
step,
|
||||
total_steps,
|
||||
percentage,
|
||||
order,
|
||||
percentage: calculateStepPercentage(step, total_steps, order),
|
||||
progress_image,
|
||||
session_id,
|
||||
batch_id,
|
||||
@@ -124,19 +134,27 @@ export const systemSlice = createSlice({
|
||||
state.status = 'CONNECTED';
|
||||
});
|
||||
|
||||
/**
|
||||
* Graph Execution State Complete
|
||||
*/
|
||||
builder.addCase(socketGraphExecutionStateComplete, (state) => {
|
||||
state.denoiseProgress = null;
|
||||
state.status = 'CONNECTED';
|
||||
});
|
||||
|
||||
builder.addCase(socketModelLoadStarted, (state) => {
|
||||
state.status = 'LOADING_MODEL';
|
||||
});
|
||||
|
||||
builder.addCase(socketModelLoadComplete, (state) => {
|
||||
builder.addCase(socketModelLoadCompleted, (state) => {
|
||||
state.status = 'CONNECTED';
|
||||
});
|
||||
|
||||
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
|
||||
if (['completed', 'canceled', 'failed'].includes(action.payload.data.status)) {
|
||||
if (['completed', 'canceled', 'failed'].includes(action.payload.data.queue_item.status)) {
|
||||
state.status = 'CONNECTED';
|
||||
state.denoiseProgress = null;
|
||||
state.cancellations.push(action.payload.data.session_id);
|
||||
state.cancellations.push(action.payload.data.queue_item.session_id);
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
@@ -10,6 +10,7 @@ type DenoiseProgress = {
|
||||
progress_image: ProgressImage | null | undefined;
|
||||
step: number;
|
||||
total_steps: number;
|
||||
order: number;
|
||||
percentage: number;
|
||||
};
|
||||
|
||||
@@ -54,5 +55,5 @@ export interface SystemState {
|
||||
shouldUseWatermarker: boolean;
|
||||
status: SystemStatus;
|
||||
shouldEnableInformationalPopovers: boolean;
|
||||
cancellations: string[];
|
||||
cancellations: string[]
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
export const calculateStepPercentage = (step: number, total_steps: number, order: number) => {
|
||||
if (total_steps === 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// we add one extra to step so that the progress bar will be full when denoise completes
|
||||
|
||||
if (order === 2) {
|
||||
return Math.floor((step + 1 + 1) / 2) / Math.floor((total_steps + 1) / 2);
|
||||
}
|
||||
|
||||
return (step + 1 + 1) / (total_steps + 1);
|
||||
};
|
||||
@@ -1,59 +0,0 @@
|
||||
import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import { t } from 'i18next';
|
||||
import { useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCopyBold } from 'react-icons/pi';
|
||||
|
||||
function onCopy(sessionId: string) {
|
||||
navigator.clipboard.writeText(sessionId);
|
||||
}
|
||||
|
||||
const ERROR_TYPE_TO_TITLE: Record<string, string> = {
|
||||
OutOfMemoryError: 'toast.outOfMemoryError',
|
||||
};
|
||||
|
||||
const COMMERCIAL_ERROR_TYPE_TO_DESC: Record<string, string> = {
|
||||
OutOfMemoryError: 'toast.outOfMemoryErrorDesc',
|
||||
};
|
||||
|
||||
export const getTitleFromErrorType = (errorType: string) => {
|
||||
return t(ERROR_TYPE_TO_TITLE[errorType] ?? 'toast.serverError');
|
||||
};
|
||||
|
||||
type Props = { errorType: string; errorMessage?: string | null; sessionId: string; isLocal: boolean };
|
||||
|
||||
export default function ErrorToastDescription({ errorType, errorMessage, sessionId, isLocal }: Props) {
|
||||
const { t } = useTranslation();
|
||||
const description = useMemo(() => {
|
||||
// Special handling for commercial error types
|
||||
const descriptionTKey = isLocal ? null : COMMERCIAL_ERROR_TYPE_TO_DESC[errorType];
|
||||
if (descriptionTKey) {
|
||||
return t(descriptionTKey);
|
||||
}
|
||||
if (errorMessage) {
|
||||
return `${errorType}: ${errorMessage}`;
|
||||
}
|
||||
}, [errorMessage, errorType, isLocal, t]);
|
||||
return (
|
||||
<Flex flexDir="column">
|
||||
{description && <Text fontSize="md">{description}</Text>}
|
||||
{!isLocal && (
|
||||
<Flex gap="2" alignItems="center">
|
||||
<Text fontSize="sm" fontStyle="italic">
|
||||
{t('toast.sessionRef', { sessionId })}
|
||||
</Text>
|
||||
<IconButton
|
||||
size="sm"
|
||||
aria-label="Copy"
|
||||
icon={<PiCopyBold />}
|
||||
onClick={onCopy.bind(null, sessionId)}
|
||||
variant="ghost"
|
||||
sx={sx}
|
||||
/>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
const sx = { svg: { fill: 'base.50' } };
|
||||
@@ -0,0 +1,30 @@
|
||||
import { Flex, IconButton, Text } from '@invoke-ai/ui-library';
|
||||
import { t } from 'i18next';
|
||||
import { PiCopyBold } from 'react-icons/pi';
|
||||
|
||||
function onCopy(sessionId: string) {
|
||||
navigator.clipboard.writeText(sessionId);
|
||||
}
|
||||
|
||||
type Props = { message: string; sessionId: string };
|
||||
|
||||
export default function ToastWithSessionRefDescription({ message, sessionId }: Props) {
|
||||
return (
|
||||
<Flex flexDir="column">
|
||||
<Text fontSize="md">{message}</Text>
|
||||
<Flex gap="2" alignItems="center">
|
||||
<Text fontSize="sm">{t('toast.sessionRef', { sessionId })}</Text>
|
||||
<IconButton
|
||||
size="sm"
|
||||
aria-label="Copy"
|
||||
icon={<PiCopyBold />}
|
||||
onClick={onCopy.bind(null, sessionId)}
|
||||
variant="ghost"
|
||||
sx={sx}
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
const sx = { svg: { fill: 'base.50' } };
|
||||
File diff suppressed because one or more lines are too long
@@ -38,6 +38,7 @@ export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDT
|
||||
|
||||
// Models
|
||||
export type ModelType = S['ModelType'];
|
||||
export type SubModelType = S['SubModelType'];
|
||||
export type BaseModelType = S['BaseModelType'];
|
||||
|
||||
// Model Configs
|
||||
|
||||
@@ -1,58 +1,101 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type {
|
||||
BulkDownloadCompleteEvent,
|
||||
BulkDownloadCompletedEvent,
|
||||
BulkDownloadFailedEvent,
|
||||
BulkDownloadStartedEvent,
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
GeneratorProgressEvent,
|
||||
GraphExecutionStateCompleteEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationErrorEvent,
|
||||
InvocationRetrievalErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
ModelInstallCancelledEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallCompletedEvent,
|
||||
ModelInstallDownloadingEvent,
|
||||
ModelInstallErrorEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelLoadCompletedEvent,
|
||||
ModelLoadStartedEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
SessionRetrievalErrorEvent,
|
||||
} from 'services/events/types';
|
||||
|
||||
const createSocketAction = <T = undefined>(name: string) =>
|
||||
createAction<T extends undefined ? void : { data: T }>(`socket/${name}`);
|
||||
// Create actions for each socket
|
||||
// Middleware and redux can then respond to them as needed
|
||||
|
||||
export const socketConnected = createSocketAction('Connected');
|
||||
export const socketDisconnected = createSocketAction('Disconnected');
|
||||
export const socketInvocationStarted = createSocketAction<InvocationStartedEvent>('InvocationStartedEvent');
|
||||
export const socketInvocationComplete = createSocketAction<InvocationCompleteEvent>('InvocationCompleteEvent');
|
||||
export const socketInvocationError = createSocketAction<InvocationErrorEvent>('InvocationErrorEvent');
|
||||
export const socketGeneratorProgress = createSocketAction<InvocationDenoiseProgressEvent>(
|
||||
'InvocationDenoiseProgressEvent'
|
||||
);
|
||||
export const socketModelLoadStarted = createSocketAction<ModelLoadStartedEvent>('ModelLoadStartedEvent');
|
||||
export const socketModelLoadComplete = createSocketAction<ModelLoadCompleteEvent>('ModelLoadCompleteEvent');
|
||||
export const socketDownloadStarted = createSocketAction<DownloadStartedEvent>('DownloadStartedEvent');
|
||||
export const socketDownloadProgress = createSocketAction<DownloadProgressEvent>('DownloadProgressEvent');
|
||||
export const socketDownloadComplete = createSocketAction<DownloadCompleteEvent>('DownloadCompleteEvent');
|
||||
export const socketDownloadCancelled = createSocketAction<DownloadCancelledEvent>('DownloadCancelledEvent');
|
||||
export const socketDownloadError = createSocketAction<DownloadErrorEvent>('DownloadErrorEvent');
|
||||
export const socketModelInstallStarted = createSocketAction<ModelInstallStartedEvent>('ModelInstallStartedEvent');
|
||||
export const socketModelInstallDownloadProgress = createSocketAction<ModelInstallDownloadProgressEvent>(
|
||||
'ModelInstallDownloadProgressEvent'
|
||||
);
|
||||
export const socketModelInstallDownloadsComplete = createSocketAction<ModelInstallDownloadsCompleteEvent>(
|
||||
'ModelInstallDownloadsCompleteEvent'
|
||||
);
|
||||
export const socketModelInstallComplete = createSocketAction<ModelInstallCompleteEvent>('ModelInstallCompleteEvent');
|
||||
export const socketModelInstallError = createSocketAction<ModelInstallErrorEvent>('ModelInstallErrorEvent');
|
||||
export const socketModelInstallCancelled = createSocketAction<ModelInstallCancelledEvent>('ModelInstallCancelledEvent');
|
||||
export const socketQueueItemStatusChanged =
|
||||
createSocketAction<QueueItemStatusChangedEvent>('QueueItemStatusChangedEvent');
|
||||
export const socketBulkDownloadStarted = createSocketAction<BulkDownloadStartedEvent>('BulkDownloadStartedEvent');
|
||||
export const socketBulkDownloadComplete = createSocketAction<BulkDownloadCompleteEvent>('BulkDownloadCompleteEvent');
|
||||
export const socketBulkDownloadError = createSocketAction<BulkDownloadFailedEvent>('BulkDownloadFailedEvent');
|
||||
export const socketConnected = createAction('socket/socketConnected');
|
||||
|
||||
export const socketDisconnected = createAction('socket/socketDisconnected');
|
||||
|
||||
export const socketSubscribedSession = createAction<{
|
||||
sessionId: string;
|
||||
}>('socket/socketSubscribedSession');
|
||||
|
||||
export const socketUnsubscribedSession = createAction<{ sessionId: string }>('socket/socketUnsubscribedSession');
|
||||
|
||||
export const socketInvocationStarted = createAction<{
|
||||
data: InvocationStartedEvent;
|
||||
}>('socket/socketInvocationStarted');
|
||||
|
||||
export const socketInvocationComplete = createAction<{
|
||||
data: InvocationCompleteEvent;
|
||||
}>('socket/socketInvocationComplete');
|
||||
|
||||
export const socketInvocationError = createAction<{
|
||||
data: InvocationErrorEvent;
|
||||
}>('socket/socketInvocationError');
|
||||
|
||||
export const socketGraphExecutionStateComplete = createAction<{
|
||||
data: GraphExecutionStateCompleteEvent;
|
||||
}>('socket/socketGraphExecutionStateComplete');
|
||||
|
||||
export const socketGeneratorProgress = createAction<{
|
||||
data: GeneratorProgressEvent;
|
||||
}>('socket/socketGeneratorProgress');
|
||||
|
||||
export const socketModelLoadStarted = createAction<{
|
||||
data: ModelLoadStartedEvent;
|
||||
}>('socket/socketModelLoadStarted');
|
||||
|
||||
export const socketModelLoadCompleted = createAction<{
|
||||
data: ModelLoadCompletedEvent;
|
||||
}>('socket/socketModelLoadCompleted');
|
||||
|
||||
export const socketModelInstallDownloading = createAction<{
|
||||
data: ModelInstallDownloadingEvent;
|
||||
}>('socket/socketModelInstallDownloading');
|
||||
|
||||
export const socketModelInstallCompleted = createAction<{
|
||||
data: ModelInstallCompletedEvent;
|
||||
}>('socket/socketModelInstallCompleted');
|
||||
|
||||
export const socketModelInstallError = createAction<{
|
||||
data: ModelInstallErrorEvent;
|
||||
}>('socket/socketModelInstallError');
|
||||
|
||||
export const socketModelInstallCancelled = createAction<{
|
||||
data: ModelInstallCancelledEvent;
|
||||
}>('socket/socketModelInstallCancelled');
|
||||
|
||||
export const socketSessionRetrievalError = createAction<{
|
||||
data: SessionRetrievalErrorEvent;
|
||||
}>('socket/socketSessionRetrievalError');
|
||||
|
||||
export const socketInvocationRetrievalError = createAction<{
|
||||
data: InvocationRetrievalErrorEvent;
|
||||
}>('socket/socketInvocationRetrievalError');
|
||||
|
||||
export const socketQueueItemStatusChanged = createAction<{
|
||||
data: QueueItemStatusChangedEvent;
|
||||
}>('socket/socketQueueItemStatusChanged');
|
||||
|
||||
export const socketBulkDownloadStarted = createAction<{
|
||||
data: BulkDownloadStartedEvent;
|
||||
}>('socket/socketBulkDownloadStarted');
|
||||
|
||||
export const socketBulkDownloadCompleted = createAction<{
|
||||
data: BulkDownloadCompletedEvent;
|
||||
}>('socket/socketBulkDownloadCompleted');
|
||||
|
||||
export const socketBulkDownloadFailed = createAction<{
|
||||
data: BulkDownloadFailedEvent;
|
||||
}>('socket/socketBulkDownloadFailed');
|
||||
|
||||
@@ -1,73 +1,278 @@
|
||||
import type { Graph, GraphExecutionState, S } from 'services/api/types';
|
||||
import type { components } from 'services/api/schema';
|
||||
import type { AnyModelConfig, Graph, GraphExecutionState, SubModelType } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* A progress image, we get one for each step in the generation
|
||||
*/
|
||||
export type ProgressImage = {
|
||||
dataURL: string;
|
||||
width: number;
|
||||
height: number;
|
||||
};
|
||||
|
||||
export type AnyInvocation = NonNullable<NonNullable<Graph['nodes']>[string]>;
|
||||
|
||||
export type AnyResult = NonNullable<GraphExecutionState['results'][string]>;
|
||||
|
||||
export type ModelLoadStartedEvent = S['ModelLoadStartedEvent'];
|
||||
export type ModelLoadCompleteEvent = S['ModelLoadCompleteEvent'];
|
||||
|
||||
export type InvocationStartedEvent = Omit<S['InvocationStartedEvent'], 'invocation'> & { invocation: AnyInvocation };
|
||||
export type InvocationDenoiseProgressEvent = Omit<S['InvocationDenoiseProgressEvent'], 'invocation'> & {
|
||||
invocation: AnyInvocation;
|
||||
type BaseNode = {
|
||||
id: string;
|
||||
type: string;
|
||||
[key: string]: AnyInvocation[keyof AnyInvocation];
|
||||
};
|
||||
export type InvocationCompleteEvent = Omit<S['InvocationCompleteEvent'], 'result' | 'invocation'> & {
|
||||
|
||||
export type ModelLoadStartedEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
model_config: AnyModelConfig;
|
||||
submodel_type?: SubModelType | null;
|
||||
};
|
||||
|
||||
export type ModelLoadCompletedEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
model_config: AnyModelConfig;
|
||||
submodel_type?: SubModelType | null;
|
||||
};
|
||||
|
||||
export type ModelInstallDownloadingEvent = {
|
||||
bytes: number;
|
||||
local_path: string;
|
||||
source: string;
|
||||
timestamp: number;
|
||||
total_bytes: number;
|
||||
id: number;
|
||||
};
|
||||
|
||||
export type ModelInstallCompletedEvent = {
|
||||
key: number;
|
||||
source: string;
|
||||
timestamp: number;
|
||||
id: number;
|
||||
};
|
||||
|
||||
export type ModelInstallErrorEvent = {
|
||||
error: string;
|
||||
error_type: string;
|
||||
source: string;
|
||||
timestamp: number;
|
||||
id: number;
|
||||
};
|
||||
|
||||
export type ModelInstallCancelledEvent = {
|
||||
source: string;
|
||||
timestamp: number;
|
||||
id: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `generator_progress` socket.io event.
|
||||
*
|
||||
* @example socket.on('generator_progress', (data: GeneratorProgressEvent) => { ... }
|
||||
*/
|
||||
export type GeneratorProgressEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node_id: string;
|
||||
source_node_id: string;
|
||||
progress_image?: ProgressImage;
|
||||
step: number;
|
||||
order: number;
|
||||
total_steps: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `invocation_complete` socket.io event.
|
||||
*
|
||||
* `result` is a discriminated union with a `type` property as the discriminant.
|
||||
*
|
||||
* @example socket.on('invocation_complete', (data: InvocationCompleteEvent) => { ... }
|
||||
*/
|
||||
export type InvocationCompleteEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node: BaseNode;
|
||||
source_node_id: string;
|
||||
result: AnyResult;
|
||||
invocation: AnyInvocation;
|
||||
};
|
||||
export type InvocationErrorEvent = Omit<S['InvocationErrorEvent'], 'invocation'> & { invocation: AnyInvocation };
|
||||
export type ProgressImage = InvocationDenoiseProgressEvent['progress_image'];
|
||||
|
||||
export type ModelInstallDownloadProgressEvent = S['ModelInstallDownloadProgressEvent'];
|
||||
export type ModelInstallDownloadsCompleteEvent = S['ModelInstallDownloadsCompleteEvent'];
|
||||
export type ModelInstallCompleteEvent = S['ModelInstallCompleteEvent'];
|
||||
export type ModelInstallErrorEvent = S['ModelInstallErrorEvent'];
|
||||
export type ModelInstallStartedEvent = S['ModelInstallStartedEvent'];
|
||||
export type ModelInstallCancelledEvent = S['ModelInstallCancelledEvent'];
|
||||
/**
|
||||
* A `invocation_error` socket.io event.
|
||||
*
|
||||
* @example socket.on('invocation_error', (data: InvocationErrorEvent) => { ... }
|
||||
*/
|
||||
export type InvocationErrorEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node: BaseNode;
|
||||
source_node_id: string;
|
||||
error_type: string;
|
||||
error_message: string;
|
||||
error_traceback: string;
|
||||
};
|
||||
|
||||
export type DownloadStartedEvent = S['DownloadStartedEvent'];
|
||||
export type DownloadProgressEvent = S['DownloadProgressEvent'];
|
||||
export type DownloadCompleteEvent = S['DownloadCompleteEvent'];
|
||||
export type DownloadCancelledEvent = S['DownloadCancelledEvent'];
|
||||
export type DownloadErrorEvent = S['DownloadErrorEvent'];
|
||||
/**
|
||||
* A `invocation_started` socket.io event.
|
||||
*
|
||||
* @example socket.on('invocation_started', (data: InvocationStartedEvent) => { ... }
|
||||
*/
|
||||
export type InvocationStartedEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node: BaseNode;
|
||||
source_node_id: string;
|
||||
};
|
||||
|
||||
export type QueueItemStatusChangedEvent = S['QueueItemStatusChangedEvent'];
|
||||
/**
|
||||
* A `graph_execution_state_complete` socket.io event.
|
||||
*
|
||||
* @example socket.on('graph_execution_state_complete', (data: GraphExecutionStateCompleteEvent) => { ... }
|
||||
*/
|
||||
export type GraphExecutionStateCompleteEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
};
|
||||
|
||||
export type BulkDownloadStartedEvent = S['BulkDownloadStartedEvent'];
|
||||
export type BulkDownloadCompleteEvent = S['BulkDownloadCompleteEvent'];
|
||||
export type BulkDownloadFailedEvent = S['BulkDownloadErrorEvent'];
|
||||
/**
|
||||
* A `session_retrieval_error` socket.io event.
|
||||
*
|
||||
* @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... }
|
||||
*/
|
||||
export type SessionRetrievalErrorEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
error_type: string;
|
||||
error: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `invocation_retrieval_error` socket.io event.
|
||||
*
|
||||
* @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... }
|
||||
*/
|
||||
export type InvocationRetrievalErrorEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: number;
|
||||
queue_batch_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node_id: string;
|
||||
error_type: string;
|
||||
error: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* A `queue_item_status_changed` socket.io event.
|
||||
*
|
||||
* @example socket.on('queue_item_status_changed', (data: QueueItemStatusChangedEvent) => { ... }
|
||||
*/
|
||||
export type QueueItemStatusChangedEvent = {
|
||||
queue_id: string;
|
||||
queue_item: {
|
||||
queue_id: string;
|
||||
item_id: number;
|
||||
batch_id: string;
|
||||
session_id: string;
|
||||
status: components['schemas']['SessionQueueItemDTO']['status'];
|
||||
error_type?: string | null;
|
||||
error_message?: string | null;
|
||||
error_traceback?: string | null;
|
||||
created_at: string;
|
||||
updated_at: string;
|
||||
started_at: string | undefined;
|
||||
completed_at: string | undefined;
|
||||
};
|
||||
batch_status: {
|
||||
queue_id: string;
|
||||
batch_id: string;
|
||||
pending: number;
|
||||
in_progress: number;
|
||||
completed: number;
|
||||
failed: number;
|
||||
canceled: number;
|
||||
total: number;
|
||||
};
|
||||
queue_status: {
|
||||
queue_id: string;
|
||||
item_id?: number;
|
||||
batch_id?: string;
|
||||
session_id?: string;
|
||||
pending: number;
|
||||
in_progress: number;
|
||||
completed: number;
|
||||
failed: number;
|
||||
canceled: number;
|
||||
total: number;
|
||||
};
|
||||
};
|
||||
|
||||
type ClientEmitSubscribeQueue = {
|
||||
queue_id: string;
|
||||
};
|
||||
type ClientEmitUnsubscribeQueue = ClientEmitSubscribeQueue;
|
||||
|
||||
type ClientEmitUnsubscribeQueue = {
|
||||
queue_id: string;
|
||||
};
|
||||
|
||||
export type BulkDownloadStartedEvent = {
|
||||
bulk_download_id: string;
|
||||
bulk_download_item_id: string;
|
||||
bulk_download_item_name: string;
|
||||
};
|
||||
|
||||
export type BulkDownloadCompletedEvent = {
|
||||
bulk_download_id: string;
|
||||
bulk_download_item_id: string;
|
||||
bulk_download_item_name: string;
|
||||
};
|
||||
|
||||
export type BulkDownloadFailedEvent = {
|
||||
bulk_download_id: string;
|
||||
bulk_download_item_id: string;
|
||||
bulk_download_item_name: string;
|
||||
error: string;
|
||||
};
|
||||
|
||||
type ClientEmitSubscribeBulkDownload = {
|
||||
bulk_download_id: string;
|
||||
};
|
||||
type ClientEmitUnsubscribeBulkDownload = ClientEmitSubscribeBulkDownload;
|
||||
|
||||
type ClientEmitUnsubscribeBulkDownload = {
|
||||
bulk_download_id: string;
|
||||
};
|
||||
|
||||
export type ServerToClientEvents = {
|
||||
invocation_denoise_progress: (payload: InvocationDenoiseProgressEvent) => void;
|
||||
generator_progress: (payload: GeneratorProgressEvent) => void;
|
||||
invocation_complete: (payload: InvocationCompleteEvent) => void;
|
||||
invocation_error: (payload: InvocationErrorEvent) => void;
|
||||
invocation_started: (payload: InvocationStartedEvent) => void;
|
||||
download_started: (payload: DownloadStartedEvent) => void;
|
||||
download_progress: (payload: DownloadProgressEvent) => void;
|
||||
download_complete: (payload: DownloadCompleteEvent) => void;
|
||||
download_cancelled: (payload: DownloadCancelledEvent) => void;
|
||||
download_error: (payload: DownloadErrorEvent) => void;
|
||||
graph_execution_state_complete: (payload: GraphExecutionStateCompleteEvent) => void;
|
||||
model_load_started: (payload: ModelLoadStartedEvent) => void;
|
||||
model_install_started: (payload: ModelInstallStartedEvent) => void;
|
||||
model_install_download_progress: (payload: ModelInstallDownloadProgressEvent) => void;
|
||||
model_install_downloads_complete: (payload: ModelInstallDownloadsCompleteEvent) => void;
|
||||
model_install_complete: (payload: ModelInstallCompleteEvent) => void;
|
||||
model_load_completed: (payload: ModelLoadCompletedEvent) => void;
|
||||
model_install_downloading: (payload: ModelInstallDownloadingEvent) => void;
|
||||
model_install_completed: (payload: ModelInstallCompletedEvent) => void;
|
||||
model_install_error: (payload: ModelInstallErrorEvent) => void;
|
||||
model_install_cancelled: (payload: ModelInstallCancelledEvent) => void;
|
||||
model_load_complete: (payload: ModelLoadCompleteEvent) => void;
|
||||
model_install_canceled: (payload: ModelInstallCancelledEvent) => void;
|
||||
session_retrieval_error: (payload: SessionRetrievalErrorEvent) => void;
|
||||
invocation_retrieval_error: (payload: InvocationRetrievalErrorEvent) => void;
|
||||
queue_item_status_changed: (payload: QueueItemStatusChangedEvent) => void;
|
||||
bulk_download_started: (payload: BulkDownloadStartedEvent) => void;
|
||||
bulk_download_complete: (payload: BulkDownloadCompleteEvent) => void;
|
||||
bulk_download_error: (payload: BulkDownloadFailedEvent) => void;
|
||||
bulk_download_completed: (payload: BulkDownloadCompletedEvent) => void;
|
||||
bulk_download_failed: (payload: BulkDownloadFailedEvent) => void;
|
||||
};
|
||||
|
||||
export type ClientToServerEvents = {
|
||||
|
||||
@@ -4,29 +4,24 @@ import { $queueId } from 'app/store/nanostores/queueId';
|
||||
import type { AppDispatch } from 'app/store/store';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import {
|
||||
socketBulkDownloadComplete,
|
||||
socketBulkDownloadError,
|
||||
socketBulkDownloadCompleted,
|
||||
socketBulkDownloadFailed,
|
||||
socketBulkDownloadStarted,
|
||||
socketConnected,
|
||||
socketDisconnected,
|
||||
socketDownloadCancelled,
|
||||
socketDownloadComplete,
|
||||
socketDownloadError,
|
||||
socketDownloadProgress,
|
||||
socketDownloadStarted,
|
||||
socketGeneratorProgress,
|
||||
socketGraphExecutionStateComplete,
|
||||
socketInvocationComplete,
|
||||
socketInvocationError,
|
||||
socketInvocationRetrievalError,
|
||||
socketInvocationStarted,
|
||||
socketModelInstallCancelled,
|
||||
socketModelInstallComplete,
|
||||
socketModelInstallDownloadProgress,
|
||||
socketModelInstallDownloadsComplete,
|
||||
socketModelInstallCompleted,
|
||||
socketModelInstallDownloading,
|
||||
socketModelInstallError,
|
||||
socketModelInstallStarted,
|
||||
socketModelLoadComplete,
|
||||
socketModelLoadCompleted,
|
||||
socketModelLoadStarted,
|
||||
socketQueueItemStatusChanged,
|
||||
socketSessionRetrievalError,
|
||||
} from 'services/events/actions';
|
||||
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
|
||||
import type { Socket } from 'socket.io-client';
|
||||
@@ -36,7 +31,12 @@ type SetEventListenersArg = {
|
||||
dispatch: AppDispatch;
|
||||
};
|
||||
|
||||
export const setEventListeners = ({ socket, dispatch }: SetEventListenersArg) => {
|
||||
export const setEventListeners = (arg: SetEventListenersArg) => {
|
||||
const { socket, dispatch } = arg;
|
||||
|
||||
/**
|
||||
* Connect
|
||||
*/
|
||||
socket.on('connect', () => {
|
||||
dispatch(socketConnected());
|
||||
const queue_id = $queueId.get();
|
||||
@@ -46,6 +46,7 @@ export const setEventListeners = ({ socket, dispatch }: SetEventListenersArg) =>
|
||||
socket.emit('subscribe_bulk_download', { bulk_download_id });
|
||||
}
|
||||
});
|
||||
|
||||
socket.on('connect_error', (error) => {
|
||||
if (error && error.message) {
|
||||
const data: string | undefined = (error as unknown as { data: string | undefined }).data;
|
||||
@@ -59,70 +60,147 @@ export const setEventListeners = ({ socket, dispatch }: SetEventListenersArg) =>
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Disconnect
|
||||
*/
|
||||
socket.on('disconnect', () => {
|
||||
dispatch(socketDisconnected());
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation started
|
||||
*/
|
||||
socket.on('invocation_started', (data) => {
|
||||
dispatch(socketInvocationStarted({ data }));
|
||||
});
|
||||
socket.on('invocation_denoise_progress', (data) => {
|
||||
|
||||
/**
|
||||
* Generator progress
|
||||
*/
|
||||
socket.on('generator_progress', (data) => {
|
||||
dispatch(socketGeneratorProgress({ data }));
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation error
|
||||
*/
|
||||
socket.on('invocation_error', (data) => {
|
||||
dispatch(socketInvocationError({ data }));
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation complete
|
||||
*/
|
||||
socket.on('invocation_complete', (data) => {
|
||||
dispatch(socketInvocationComplete({ data }));
|
||||
dispatch(
|
||||
socketInvocationComplete({
|
||||
data,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Graph complete
|
||||
*/
|
||||
socket.on('graph_execution_state_complete', (data) => {
|
||||
dispatch(
|
||||
socketGraphExecutionStateComplete({
|
||||
data,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Model load started
|
||||
*/
|
||||
socket.on('model_load_started', (data) => {
|
||||
dispatch(socketModelLoadStarted({ data }));
|
||||
dispatch(
|
||||
socketModelLoadStarted({
|
||||
data,
|
||||
})
|
||||
);
|
||||
});
|
||||
socket.on('model_load_complete', (data) => {
|
||||
dispatch(socketModelLoadComplete({ data }));
|
||||
|
||||
/**
|
||||
* Model load completed
|
||||
*/
|
||||
socket.on('model_load_completed', (data) => {
|
||||
dispatch(
|
||||
socketModelLoadCompleted({
|
||||
data,
|
||||
})
|
||||
);
|
||||
});
|
||||
socket.on('download_started', (data) => {
|
||||
dispatch(socketDownloadStarted({ data }));
|
||||
|
||||
/**
|
||||
* Model Install Downloading
|
||||
*/
|
||||
socket.on('model_install_downloading', (data) => {
|
||||
dispatch(
|
||||
socketModelInstallDownloading({
|
||||
data,
|
||||
})
|
||||
);
|
||||
});
|
||||
socket.on('download_progress', (data) => {
|
||||
dispatch(socketDownloadProgress({ data }));
|
||||
});
|
||||
socket.on('download_complete', (data) => {
|
||||
dispatch(socketDownloadComplete({ data }));
|
||||
});
|
||||
socket.on('download_cancelled', (data) => {
|
||||
dispatch(socketDownloadCancelled({ data }));
|
||||
});
|
||||
socket.on('download_error', (data) => {
|
||||
dispatch(socketDownloadError({ data }));
|
||||
});
|
||||
socket.on('model_install_started', (data) => {
|
||||
dispatch(socketModelInstallStarted({ data }));
|
||||
});
|
||||
socket.on('model_install_download_progress', (data) => {
|
||||
dispatch(socketModelInstallDownloadProgress({ data }));
|
||||
});
|
||||
socket.on('model_install_downloads_complete', (data) => {
|
||||
dispatch(socketModelInstallDownloadsComplete({ data }));
|
||||
});
|
||||
socket.on('model_install_complete', (data) => {
|
||||
dispatch(socketModelInstallComplete({ data }));
|
||||
|
||||
/**
|
||||
* Model Install Completed
|
||||
*/
|
||||
socket.on('model_install_completed', (data) => {
|
||||
dispatch(
|
||||
socketModelInstallCompleted({
|
||||
data,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Model Install Error
|
||||
*/
|
||||
socket.on('model_install_error', (data) => {
|
||||
dispatch(socketModelInstallError({ data }));
|
||||
dispatch(
|
||||
socketModelInstallError({
|
||||
data,
|
||||
})
|
||||
);
|
||||
});
|
||||
socket.on('model_install_cancelled', (data) => {
|
||||
dispatch(socketModelInstallCancelled({ data }));
|
||||
|
||||
/**
|
||||
* Session retrieval error
|
||||
*/
|
||||
socket.on('session_retrieval_error', (data) => {
|
||||
dispatch(
|
||||
socketSessionRetrievalError({
|
||||
data,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
/**
|
||||
* Invocation retrieval error
|
||||
*/
|
||||
socket.on('invocation_retrieval_error', (data) => {
|
||||
dispatch(
|
||||
socketInvocationRetrievalError({
|
||||
data,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
socket.on('queue_item_status_changed', (data) => {
|
||||
dispatch(socketQueueItemStatusChanged({ data }));
|
||||
});
|
||||
|
||||
socket.on('bulk_download_started', (data) => {
|
||||
dispatch(socketBulkDownloadStarted({ data }));
|
||||
});
|
||||
socket.on('bulk_download_complete', (data) => {
|
||||
dispatch(socketBulkDownloadComplete({ data }));
|
||||
|
||||
socket.on('bulk_download_completed', (data) => {
|
||||
dispatch(socketBulkDownloadCompleted({ data }));
|
||||
});
|
||||
socket.on('bulk_download_error', (data) => {
|
||||
dispatch(socketBulkDownloadError({ data }));
|
||||
|
||||
socket.on('bulk_download_failed', (data) => {
|
||||
dispatch(socketBulkDownloadFailed({ data }));
|
||||
});
|
||||
};
|
||||
@@ -33,8 +33,8 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
# Core generation dependencies, pinned for reproducible builds.
|
||||
"accelerate==0.30.1",
|
||||
"clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"accelerate==0.29.2",
|
||||
"clip_anytorch==2.5.2", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
|
||||
"compel==2.0.2",
|
||||
"controlnet-aux==0.0.7",
|
||||
"diffusers[torch]==0.27.2",
|
||||
@@ -45,18 +45,18 @@ dependencies = [
|
||||
"onnxruntime==1.16.3",
|
||||
"opencv-python==4.9.0.80",
|
||||
"pytorch-lightning==2.1.3",
|
||||
"safetensors==0.4.3",
|
||||
"safetensors==0.4.2",
|
||||
"timm==0.6.13", # needed to override timm latest in controlnet_aux, see https://github.com/isl-org/ZoeDepth/issues/26
|
||||
"torch==2.2.2",
|
||||
"torchmetrics==0.11.4",
|
||||
"torchsde==0.2.6",
|
||||
"torchvision==0.17.2",
|
||||
"transformers==4.41.1",
|
||||
"transformers==4.39.3",
|
||||
|
||||
# Core application dependencies, pinned for reproducible builds.
|
||||
"fastapi-events==0.11.0",
|
||||
"fastapi==0.110.0",
|
||||
"huggingface-hub==0.23.1",
|
||||
"huggingface-hub==0.22.2",
|
||||
"pydantic-settings==2.2.1",
|
||||
"pydantic==2.6.3",
|
||||
"python-socketio==5.11.1",
|
||||
|
||||
@@ -9,11 +9,6 @@ import pytest
|
||||
from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException
|
||||
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException
|
||||
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BulkDownloadCompleteEvent,
|
||||
BulkDownloadErrorEvent,
|
||||
BulkDownloadStartedEvent,
|
||||
)
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageRecordNotFoundException,
|
||||
@@ -286,9 +281,9 @@ def assert_handler_success(
|
||||
|
||||
# Check that the correct events were emitted
|
||||
assert len(event_bus.events) == 2
|
||||
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
|
||||
assert isinstance(event_bus.events[1], BulkDownloadCompleteEvent)
|
||||
assert event_bus.events[1].bulk_download_item_name == os.path.basename(expected_zip_path)
|
||||
assert event_bus.events[0].event_name == "bulk_download_started"
|
||||
assert event_bus.events[1].event_name == "bulk_download_completed"
|
||||
assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path)
|
||||
|
||||
|
||||
def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
|
||||
@@ -334,9 +329,9 @@ def test_handler_on_generic_exception(
|
||||
event_bus: TestEventService = mock_invoker.services.events
|
||||
|
||||
assert len(event_bus.events) == 2
|
||||
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
|
||||
assert isinstance(event_bus.events[1], BulkDownloadErrorEvent)
|
||||
assert event_bus.events[1].error == exception.__str__()
|
||||
assert event_bus.events[0].event_name == "bulk_download_started"
|
||||
assert event_bus.events[1].event_name == "bulk_download_failed"
|
||||
assert event_bus.events[1].payload["error"] == exception.__str__()
|
||||
|
||||
|
||||
def execute_handler_test_on_error(
|
||||
@@ -349,9 +344,9 @@ def execute_handler_test_on_error(
|
||||
event_bus: TestEventService = mock_invoker.services.events
|
||||
|
||||
assert len(event_bus.events) == 2
|
||||
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
|
||||
assert isinstance(event_bus.events[1], BulkDownloadErrorEvent)
|
||||
assert event_bus.events[1].error == error.__str__()
|
||||
assert event_bus.events[0].event_name == "bulk_download_started"
|
||||
assert event_bus.events[1].event_name == "bulk_download_failed"
|
||||
assert event_bus.events[1].payload["error"] == error.__str__()
|
||||
|
||||
|
||||
def test_delete(tmp_path: Path):
|
||||
|
||||
@@ -10,13 +10,6 @@ from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
||||
from invokeai.app.services.events.events_common import (
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
)
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
# Prevent pytest deprecation warnings
|
||||
@@ -123,14 +116,14 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
queue.join()
|
||||
events = event_bus.events
|
||||
assert len(events) == 3
|
||||
assert isinstance(events[0], DownloadStartedEvent)
|
||||
assert isinstance(events[1], DownloadProgressEvent)
|
||||
assert isinstance(events[2], DownloadCompleteEvent)
|
||||
assert events[0].timestamp <= events[1].timestamp
|
||||
assert events[1].timestamp <= events[2].timestamp
|
||||
assert events[1].total_bytes > 0
|
||||
assert events[1].current_bytes <= events[1].total_bytes
|
||||
assert events[2].total_bytes == 32029
|
||||
assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
|
||||
assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
|
||||
assert events[0].event_name == "download_started"
|
||||
assert events[1].event_name == "download_progress"
|
||||
assert events[1].payload["total_bytes"] > 0
|
||||
assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
|
||||
assert events[2].event_name == "download_complete"
|
||||
assert events[2].payload["total_bytes"] == 32029
|
||||
|
||||
# test a failure
|
||||
event_bus.events = [] # reset our accumulator
|
||||
@@ -139,10 +132,10 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
events = event_bus.events
|
||||
print("\n".join([x.model_dump_json() for x in events]))
|
||||
assert len(events) == 1
|
||||
assert isinstance(events[0], DownloadErrorEvent)
|
||||
assert events[0].error_type == "HTTPError(NOT FOUND)"
|
||||
assert events[0].error is not None
|
||||
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].error)
|
||||
assert events[0].event_name == "download_error"
|
||||
assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
|
||||
assert events[0].payload["error"] is not None
|
||||
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
|
||||
queue.stop()
|
||||
|
||||
|
||||
@@ -209,6 +202,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
assert job.status == DownloadJobStatus.CANCELLED
|
||||
assert cancelled
|
||||
events = event_bus.events
|
||||
assert isinstance(events[-1], DownloadCancelledEvent)
|
||||
assert events[-1].source == "http://www.civitai.com/models/12345"
|
||||
assert events[-1].event_name == "download_cancelled"
|
||||
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
|
||||
queue.stop()
|
||||
|
||||
@@ -13,24 +13,16 @@ from pydantic_core import Url
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.events.events_common import (
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallStartedEvent,
|
||||
)
|
||||
from invokeai.app.services.model_install import (
|
||||
ModelInstallServiceBase,
|
||||
)
|
||||
from invokeai.app.services.model_install.model_install_common import (
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
|
||||
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
|
||||
from tests.test_nodes import TestEventService
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
OS = platform.uname().system
|
||||
|
||||
@@ -138,16 +130,17 @@ def test_background_install(
|
||||
assert job.total_bytes == size
|
||||
|
||||
# test that the expected events were issued
|
||||
bus: TestEventService = mm2_installer.event_bus
|
||||
bus = mm2_installer.event_bus
|
||||
assert bus
|
||||
assert hasattr(bus, "events")
|
||||
|
||||
assert len(bus.events) == 2
|
||||
assert isinstance(bus.events[0], ModelInstallStartedEvent)
|
||||
assert isinstance(bus.events[1], ModelInstallCompleteEvent)
|
||||
assert Path(bus.events[0].source) == source
|
||||
assert Path(bus.events[1].source) == source
|
||||
key = bus.events[1].key
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert "model_install_running" in event_names
|
||||
assert "model_install_completed" in event_names
|
||||
assert Path(bus.events[0].payload["source"]) == source
|
||||
assert Path(bus.events[1].payload["source"]) == source
|
||||
key = bus.events[1].payload["key"]
|
||||
assert key is not None
|
||||
|
||||
# see if the thing actually got installed at the expected location
|
||||
@@ -226,7 +219,7 @@ def test_delete_register(
|
||||
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
|
||||
|
||||
bus: TestEventService = mm2_installer.event_bus
|
||||
bus = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert store is not None
|
||||
assert bus is not None
|
||||
@@ -244,17 +237,20 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
|
||||
assert len(bus.events) == 4
|
||||
assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
|
||||
assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
|
||||
assert isinstance(bus.events[2], ModelInstallStartedEvent)
|
||||
assert isinstance(bus.events[3], ModelInstallCompleteEvent)
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert event_names == [
|
||||
"model_install_downloading",
|
||||
"model_install_downloads_done",
|
||||
"model_install_running",
|
||||
"model_install_completed",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
|
||||
bus: TestEventService = mm2_installer.event_bus
|
||||
bus = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert isinstance(bus, EventServiceBase)
|
||||
assert store is not None
|
||||
@@ -271,10 +267,15 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
|
||||
assert model_record.type == ModelType.Main
|
||||
assert model_record.format == ModelFormat.Diffusers
|
||||
|
||||
assert any(isinstance(x, ModelInstallStartedEvent) for x in bus.events)
|
||||
assert any(isinstance(x, ModelInstallDownloadProgressEvent) for x in bus.events)
|
||||
assert any(isinstance(x, ModelInstallCompleteEvent) for x in bus.events)
|
||||
assert hasattr(bus, "events") # the dummyeventservice has this
|
||||
assert len(bus.events) >= 3
|
||||
event_names = {x.event_name for x in bus.events}
|
||||
assert event_names == {
|
||||
"model_install_downloading",
|
||||
"model_install_downloads_done",
|
||||
"model_install_running",
|
||||
"model_install_completed",
|
||||
}
|
||||
|
||||
|
||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
|
||||
@@ -3,13 +3,16 @@
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
|
||||
@@ -36,7 +39,27 @@ from tests.backend.model_manager.model_metadata.metadata_examples import (
|
||||
RepoHFModelJson1,
|
||||
)
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
|
||||
class DummyEvent(BaseModel):
|
||||
"""Dummy Event to use with Dummy Event service."""
|
||||
|
||||
event_name: str
|
||||
payload: Dict[str, Any]
|
||||
|
||||
|
||||
class DummyEventService(EventServiceBase):
|
||||
"""Dummy event service for testing."""
|
||||
|
||||
events: List[DummyEvent]
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.events = []
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
"""Dispatch an event by appending it to self.events."""
|
||||
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
|
||||
|
||||
|
||||
# Create a temporary directory using the contents of `./data/invokeai_root` as the template
|
||||
@@ -104,7 +127,7 @@ def mm2_installer(
|
||||
) -> ModelInstallServiceBase:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
events = TestEventService()
|
||||
events = DummyEventService()
|
||||
store = ModelRecordServiceSQL(db)
|
||||
|
||||
installer = ModelInstallService(
|
||||
|
||||
@@ -20,7 +20,6 @@ from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa: F403
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database # noqa: F401
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
@@ -8,7 +10,6 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
)
|
||||
from invokeai.app.invocations.fields import InputField, OutputField
|
||||
from invokeai.app.invocations.image import ImageField
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
|
||||
@@ -116,10 +117,11 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
|
||||
)
|
||||
|
||||
|
||||
class TestEvent(EventBase):
|
||||
class TestEvent(BaseModel):
|
||||
__test__ = False # not a pytest test case
|
||||
|
||||
__event_name__ = "test_event"
|
||||
event_name: str
|
||||
payload: Any
|
||||
|
||||
|
||||
class TestEventService(EventServiceBase):
|
||||
@@ -127,10 +129,10 @@ class TestEventService(EventServiceBase):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.events: list[EventBase] = []
|
||||
self.events: list[TestEvent] = []
|
||||
|
||||
def dispatch(self, event: EventBase) -> None:
|
||||
self.events.append(event)
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"]))
|
||||
pass
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user