Compare commits

..

31 Commits

Author SHA1 Message Date
psychedelicious
dc78a0e699 fix(ui): correctly fallback to error message when traceback is empty string 2024-05-24 12:15:51 +10:00
psychedelicious
08a42c3c03 tidy(ui): remove extraneous condition in socketInvocationError 2024-05-24 12:14:48 +10:00
psychedelicious
0758e9cb9b fix(ui): race condition with progress
There's a race condition where a canceled session may emit a progress event or two after it's been canceled, and the progress image isn't cleared out.

To resolve this, the system slice tracks canceled session ids. When a progress event comes in, we check the cancellations and skip setting the progress if canceled.
2024-05-24 12:01:02 +10:00
psychedelicious
fb93e686b2 feat(processor): add debug log stmts to session running callbacks 2024-05-24 11:28:55 +10:00
psychedelicious
350feeed56 fix(processor): fix race condition related to clearing the queue 2024-05-24 11:26:57 +10:00
psychedelicious
169b75b2b7 tidy(processor): remove test callbacks 2024-05-24 11:23:26 +10:00
psychedelicious
c88de180e7 tidy(queue): delete unused delete_queue_item method 2024-05-24 10:48:33 +10:00
psychedelicious
7d1844eaf2 chore: ruff 2024-05-24 10:21:01 +10:00
psychedelicious
a98ddedb95 docs(processor): update docstrings, comments 2024-05-24 10:20:20 +10:00
psychedelicious
6063487b20 feat(ui): handle enriched events 2024-05-24 09:30:07 +10:00
psychedelicious
9a4c167342 chore(ui): typegen 2024-05-24 09:30:07 +10:00
psychedelicious
19227fe4e6 feat(app): update test event callbacks 2024-05-24 09:30:07 +10:00
psychedelicious
db0ef8d316 feat(processor): update enriched errors & fail_queue_item() 2024-05-24 09:30:07 +10:00
psychedelicious
6a34176376 feat(events): add enriched errors to events 2024-05-24 09:30:07 +10:00
psychedelicious
d6696a7b97 feat(queue): session queue error handling
- Add handling for new error columns `error_type`, `error_message`, `error_traceback`.
- Update queue item model to include the new data. The `error_traceback` field has an alias of `error` for backwards compatibility.
- Add `fail_queue_item` method. This was previously handled by `cancel_queue_item`. Splitting this functionality makes failing a queue item a bit more explicit. We also don't need to handle multiple optional error args.
-
2024-05-24 09:30:01 +10:00
psychedelicious
0e81e7b460 feat(db): add error_type, error_message, rename error -> error_traceback to session_queue table 2024-05-24 09:28:48 +10:00
psychedelicious
7652fbc2e9 fix(processor): restore missing update of session 2024-05-24 09:26:33 +10:00
psychedelicious
a55b2f09e2 chore: ruff 2024-05-24 09:20:15 +10:00
psychedelicious
23b05344a3 feat(processor): get user/project from queue item w/ fallback 2024-05-24 09:20:15 +10:00
psychedelicious
80905ff3ea fix(app): fix logging of error classes instead of class names 2024-05-24 09:20:15 +10:00
psychedelicious
df5457231f feat(app): handle preparation errors as node errors
We were not handling node preparation errors as node errors before. Here's the explanation, copied from a comment that is no longer required:

---

TODO(psyche): Sessions only support errors on nodes, not on the session itself. When an error occurs outside
node execution, it bubbles up to the processor where it is treated as a queue item error.

Nodes are pydantic models. When we prepare a node in `session.next()`, we set its inputs. This can cause a
pydantic validation error. For example, consider a resize image node which has a constraint on its `width`
input field - it must be greater than zero. During preparation, if the width is set to zero, pydantic will
raise a validation error.

When this happens, it breaks the flow before `invocation` is set. We can't set an error on the invocation
because we didn't get far enough to get it - we don't know its id. Hence, we just set it as a queue item error.

---

This change wraps the node preparation step with exception handling. A new `NodeInputError` exception is raised when there is a validation error. This error has the node (in the state it was in just prior to the error) and an identifier of the input that failed.

This allows us to mark the node that failed preparation as errored, correctly making such errors _node_ errors and not _processor_ errors. It's much easier to diagnose these situations. The error messages look like this:

> Node b5ac87c6-0678-4b8c-96b9-d215aee12175 has invalid incoming input for height

Some of the exception handling logic is cleaned up.
2024-05-24 09:20:15 +10:00
psychedelicious
d30c1ad6dc docs(app): explain why errors are handled poorly 2024-05-24 09:20:15 +10:00
psychedelicious
b1f819ae8d tidy(app): "outputs" -> "output" 2024-05-24 09:20:15 +10:00
psychedelicious
eff359625a tidy(app): rearrange proccessor 2024-05-24 09:20:15 +10:00
psychedelicious
cef1585dfb feat(app): support multiple processor lifecycle callbacks 2024-05-24 09:19:55 +10:00
psychedelicious
cb8e9e1c7b feat(app): make things in session runner private 2024-05-24 09:19:55 +10:00
psychedelicious
f7c356d142 feat(app): iterate on processor split 2
- Use protocol to define callbacks, this allows them to have kwargs
- Shuffle the profiler around a bit
- Move `thread_limit` and `polling_interval` to `__init__`; `start` is called programmatically and will never get these args in practice
2024-05-24 09:19:55 +10:00
psychedelicious
efb069dd71 feat(app): iterate on processor split
- Add `OnNodeError` and `OnNonFatalProcessorError` callbacks
- Move all session/node callbacks to `SessionRunner` - this ensures we dump perf stats before resetting them and generally makes sense to me
- Remove `complete` event from `SessionRunner`, it's essentially the same as `OnAfterRunSession`
- Remove extraneous `next_invocation` block, which would treat a processor error as a node error
- Simplify loops
- Add some callbacks for testing, to be removed before merge
2024-05-24 09:19:55 +10:00
brandonrising
8edc25d35a Fix next node calling logic 2024-05-24 09:17:43 +10:00
brandonrising
82957bb826 Run ruff 2024-05-24 09:17:43 +10:00
brandonrising
e51a3025ea Break apart session processor and the running of each session into separate classes 2024-05-24 09:17:43 +10:00
128 changed files with 3231 additions and 5768 deletions

View File

@@ -18,7 +18,6 @@ help:
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
@echo "installer-zip Build the installer .zip file for the current version"
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
# Runs ruff, fixing any safely-fixable errors and formatting
ruff:
@@ -71,6 +70,3 @@ installer-zip:
tag-release:
cd installer && ./tag_release.sh
# Generate the OpenAPI Schema for the app
openapi:
python scripts/generate_openapi_schema.py

View File

@@ -64,7 +64,7 @@ GPU_DRIVER=nvidia
Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
## Even More Customizing!
## Even Moar Customizing!
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.

View File

@@ -154,18 +154,6 @@ This is caused by an invalid setting in the `invokeai.yaml` configuration file.
Check the [configuration docs] for more detail about the settings and how to specify them.
## `ModuleNotFoundError: No module named 'controlnet_aux'`
`controlnet_aux` is a dependency of Invoke and appears to have been packaged or distributed strangely. Sometimes, it doesn't install correctly. This is outside our control.
If you encounter this error, the solution is to remove the package from the `pip` cache and re-run the Invoke installer so a fresh, working version of `controlnet_aux` can be downloaded and installed:
- Run the Invoke launcher
- Choose the developer console option
- Run this command: `pip cache remove controlnet_aux`
- Close the terminal window
- Download and run the [installer](https://github.com/invoke-ai/InvokeAI/releases/latest), selecting your current install location
## Out of Memory Issues
The models are large, VRAM is expensive, and you may find yourself

View File

@@ -20,7 +20,7 @@ When you generate an image using text-to-image, multiple steps occur in latent s
4. The VAE decodes the final latent image from latent space into image space.
Image-to-image is a similar process, with only step 1 being different:
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how many noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process.
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how may noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process.
Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor).

View File

@@ -18,7 +18,6 @@ from ..services.boards.boards_default import BoardService
from ..services.bulk_download.bulk_download_default import BulkDownloadService
from ..services.config import InvokeAIAppConfig
from ..services.download import DownloadQueueService
from ..services.events.events_fastapievents import FastAPIEventService
from ..services.image_files.image_files_disk import DiskImageFileStorage
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
from ..services.images.images_default import ImageService
@@ -34,6 +33,7 @@ from ..services.session_processor.session_processor_default import DefaultSessio
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
from ..services.urls.urls_default import LocalUrlService
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from .events import FastAPIEventService
# TODO: is there a better way to achieve this?
@@ -103,6 +103,7 @@ class ApiDependencies:
)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
session_queue = SqliteSessionQueue(db=db)
urls = LocalUrlService()

View File

@@ -0,0 +1,52 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from typing import Any
from fastapi_events.dispatcher import dispatch
from ..services.events.events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
event_handler_id: int
__queue: Queue
__stop_event: threading.Event
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self.__queue = Queue()
self.__stop_event = threading.Event()
asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self.__stop_event.set()
self.__queue.put(None)
def dispatch(self, event_name: str, payload: Any) -> None:
self.__queue.put({"event_name": event_name, "payload": payload})
async def __dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self.__queue.get(block=False)
if not event: # Probably stopping
continue
dispatch(
event.get("event_name"),
payload=event.get("payload"),
middleware_id=self.event_handler_id,
)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException
from typing_extensions import Annotated
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.model_install import ModelInstallJob
from invokeai.app.services.model_records import (
DuplicateModelException,
InvalidModelException,

View File

@@ -1,125 +1,66 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from socketio import ASGIApp, AsyncServer
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadEventBase,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadEventBase,
DownloadProgressEvent,
DownloadStartedEvent,
FastAPIEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelEventBase,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEventBase,
QueueItemStatusChangedEvent,
register_events,
)
class QueueSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io queue room.
This is a pydantic model to ensure the data is in the correct format."""
queue_id: str
class BulkDownloadSubscriptionEvent(BaseModel):
"""Event data for subscribing to the socket.io bulk downloads room.
This is a pydantic model to ensure the data is in the correct format."""
bulk_download_id: str
QUEUE_EVENTS = {
InvocationStartedEvent,
InvocationDenoiseProgressEvent,
InvocationCompleteEvent,
InvocationErrorEvent,
QueueItemStatusChangedEvent,
BatchEnqueuedEvent,
QueueClearedEvent,
}
MODEL_EVENTS = {
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
ModelLoadStartedEvent,
ModelLoadCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
ModelInstallCompleteEvent,
ModelInstallCancelledEvent,
ModelInstallErrorEvent,
}
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
from ..services.events.events_base import EventServiceBase
class SocketIO:
_sub_queue = "subscribe_queue"
_unsub_queue = "unsubscribe_queue"
__sio: AsyncServer
__app: ASGIApp
_sub_bulk_download = "subscribe_bulk_download"
_unsub_bulk_download = "unsubscribe_bulk_download"
__sub_queue: str = "subscribe_queue"
__unsub_queue: str = "unsubscribe_queue"
__sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI):
self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
app.mount("/ws", self._app)
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
app.mount("/ws", self.__app)
self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue)
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
register_events(QUEUE_EVENTS, self._handle_queue_event)
register_events(MODEL_EVENTS, self._handle_model_event)
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["queue_id"],
)
async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.enter_room(sid, data["queue_id"])
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
if "queue_id" in data:
await self.__sio.leave_room(sid, data["queue_id"])
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
async def _handle_bulk_download_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["bulk_download_id"],
)
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.enter_room(sid, data["bulk_download_id"])
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.leave_room(sid, data["bulk_download_id"])

View File

@@ -3,7 +3,9 @@ import logging
import mimetypes
import socket
from contextlib import asynccontextmanager
from inspect import signature
from pathlib import Path
from typing import Any
import torch
import uvicorn
@@ -11,9 +13,11 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
from fastapi.openapi.utils import get_openapi
from fastapi.responses import HTMLResponse
from fastapi_events.handlers.local import local_handler
from fastapi_events.middleware import EventHandlerASGIMiddleware
from pydantic.json_schema import models_json_schema
from torch.backends.mps import is_available as is_mps_available
# for PyCharm:
@@ -21,8 +25,9 @@ from torch.backends.mps import is_available as is_mps_available
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config
from invokeai.app.util.custom_openapi import get_openapi_func
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import TorchDevice
from ..backend.util.logging import InvokeAILogger
@@ -39,6 +44,11 @@ from .api.routers import (
workflows,
)
from .api.sockets import SocketIO
from .invocations.baseinvocation import (
BaseInvocation,
UIConfigBase,
)
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
app_config = get_config()
@@ -108,7 +118,93 @@ app.include_router(app_info.app_router, prefix="/api")
app.include_router(session_queue.session_queue_router, prefix="/api")
app.include_router(workflows.workflows_router, prefix="/api")
app.openapi = get_openapi_func(app)
# Build a custom OpenAPI to include all outputs
# TODO: can outputs be included on metadata of invocation schemas somehow?
def custom_openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# Add all outputs
all_invocations = BaseInvocation.get_invocations()
output_types = set()
output_type_titles = {}
for invoker in all_invocations:
output_type = signature(invoker.invoke).return_annotation
output_types.add(output_type)
output_schemas = models_json_schema(
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
)
for schema_key, output_schema in output_schemas[1]["$defs"].items():
# TODO: note that we assume the schema_key here is the TYPE.__name__
# This could break in some cases, figure out a better way to do it
output_type_titles[schema_key] = output_schema["title"]
openapi_schema["components"]["schemas"][schema_key] = output_schema
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
# Some models don't end up in the schemas as standalone definitions
additional_schemas = models_json_schema(
[
(UIConfigBase, "serialization"),
(InputFieldJSONSchemaExtra, "serialization"),
(OutputFieldJSONSchemaExtra, "serialization"),
(ModelIdentifierField, "serialization"),
(ProgressImage, "serialization"),
],
ref_template="#/components/schemas/{model}",
)
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
openapi_schema["components"]["schemas"][schema_key] = schema_json
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": {},
"required": [],
}
# Add a reference to the output type to additionalProperties of the invoker schema
for invoker in all_invocations:
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
output_type = signature(obj=invoker.invoke).return_annotation
output_type_title = output_type_titles[output_type.__name__]
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
invoker_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
invoker_schema["class"] = "invocation"
# This code no longer seems to be necessary?
# Leave it here just in case
#
# from invokeai.backend.model_manager import get_model_config_formats
# formats = get_model_config_formats()
# for model_config_name, enum_set in formats.items():
# if model_config_name in openapi_schema["components"]["schemas"]:
# # print(f"Config with name {name} already defined")
# continue
# openapi_schema["components"]["schemas"][model_config_name] = {
# "title": model_config_name,
# "description": "An enumeration.",
# "type": "string",
# "enum": [v.value for v in enum_set],
# }
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
@app.get("/docs", include_in_schema=False)

View File

@@ -98,13 +98,11 @@ class BaseInvocationOutput(BaseModel):
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def register_output(cls, output: BaseInvocationOutput) -> None:
"""Registers an invocation output."""
cls._output_classes.add(output)
cls._typeadapter_needs_update = True
@classmethod
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
@@ -114,12 +112,11 @@ class BaseInvocationOutput(BaseModel):
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocationOutput = TypeAliasType(
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
if not cls._typeadapter:
InvocationOutputsUnion = TypeAliasType(
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
cls._typeadapter_needs_update = False
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
return cls._typeadapter
@classmethod
@@ -128,13 +125,12 @@ class BaseInvocationOutput(BaseModel):
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
# Because we use a pydantic Literal field with default value for the invocation type,
# it will be typed as optional in the OpenAPI schema. Make it required manually.
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "output"
schema["required"].extend(["type"])
@classmethod
@@ -171,7 +167,6 @@ class BaseInvocation(ABC, BaseModel):
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
_typeadapter_needs_update: ClassVar[bool] = False
@classmethod
def get_type(cls) -> str:
@@ -182,17 +177,15 @@ class BaseInvocation(ABC, BaseModel):
def register_invocation(cls, invocation: BaseInvocation) -> None:
"""Registers an invocation."""
cls._invocation_classes.add(invocation)
cls._typeadapter_needs_update = True
@classmethod
def get_typeadapter(cls) -> TypeAdapter[Any]:
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
if not cls._typeadapter or cls._typeadapter_needs_update:
AnyInvocation = TypeAliasType(
"AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
if not cls._typeadapter:
InvocationsUnion = TypeAliasType(
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
)
cls._typeadapter = TypeAdapter(AnyInvocation)
cls._typeadapter_needs_update = False
cls._typeadapter = TypeAdapter(InvocationsUnion)
return cls._typeadapter
@classmethod
@@ -228,7 +221,7 @@ class BaseInvocation(ABC, BaseModel):
return signature(cls.invoke).return_annotation
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
if uiconfig is not None:
@@ -244,7 +237,6 @@ class BaseInvocation(ABC, BaseModel):
schema["version"] = uiconfig.version
if "required" not in schema or not isinstance(schema["required"], list):
schema["required"] = []
schema["class"] = "invocation"
schema["required"].extend(["type", "id"])
@abstractmethod
@@ -318,7 +310,7 @@ class BaseInvocation(ABC, BaseModel):
protected_namespaces=(),
validate_assignment=True,
json_schema_extra=json_schema_extra,
json_schema_serialization_defaults_required=False,
json_schema_serialization_defaults_required=True,
coerce_numbers_to_str=True,
)

View File

@@ -65,7 +65,11 @@ class CompelInvocation(BaseInvocation):
@torch.no_grad()
def invoke(self, context: InvocationContext) -> ConditioningOutput:
tokenizer_info = context.models.load(self.clip.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(self.clip.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, CLIPTextModel)
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.clip.loras:
@@ -80,21 +84,19 @@ class CompelInvocation(BaseInvocation):
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
with (
# apply all patches while the model is on the target device
text_encoder_info as text_encoder,
tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
patched_tokenizer,
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
tokenizer,
ti_manager,
),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
):
assert isinstance(text_encoder, CLIPTextModel)
assert isinstance(tokenizer, CLIPTokenizer)
compel = Compel(
tokenizer=patched_tokenizer,
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@@ -104,7 +106,7 @@ class CompelInvocation(BaseInvocation):
conjunction = Compel.parse_prompt_string(self.prompt)
if context.config.get().log_tokenization:
log_tokenization_for_conjunction(conjunction, patched_tokenizer)
log_tokenization_for_conjunction(conjunction, tokenizer)
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
@@ -134,7 +136,11 @@ class SDXLPromptInvocationBase:
zero_on_empty: bool,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
tokenizer_info = context.models.load(clip_field.tokenizer)
tokenizer_model = tokenizer_info.model
assert isinstance(tokenizer_model, CLIPTokenizer)
text_encoder_info = context.models.load(clip_field.text_encoder)
text_encoder_model = text_encoder_info.model
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
# return zero on empty
if prompt == "" and zero_on_empty:
@@ -171,23 +177,20 @@ class SDXLPromptInvocationBase:
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
with (
# apply all patches while the model is on the target device
text_encoder_info as text_encoder,
tokenizer_info as tokenizer,
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
patched_tokenizer,
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
tokenizer,
ti_manager,
),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
):
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(tokenizer, CLIPTokenizer)
text_encoder = cast(CLIPTextModel, text_encoder)
compel = Compel(
tokenizer=patched_tokenizer,
tokenizer=tokenizer,
text_encoder=text_encoder,
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
@@ -200,7 +203,7 @@ class SDXLPromptInvocationBase:
if context.config.get().log_tokenization:
# TODO: better logging for and syntax
log_tokenization_for_conjunction(conjunction, patched_tokenizer)
log_tokenization_for_conjunction(conjunction, tokenizer)
# TODO: ask for optimizations? to not run text_encoder twice
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)

View File

@@ -50,7 +50,7 @@ from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput,
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType, LoadedModel
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
@@ -672,52 +672,54 @@ class DenoiseLatentsInvocation(BaseInvocation):
return controlnet_data
def prep_ip_adapter_image_prompts(
self,
context: InvocationContext,
ip_adapters: List[IPAdapterField],
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
image_prompts = []
for single_ip_adapter in ip_adapters:
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
assert isinstance(ip_adapter_model, IPAdapter)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
single_ipa_image_fields = [single_ipa_image_fields]
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
single_ipa_images, image_encoder_model
)
image_prompts.append((image_prompt_embeds, uncond_image_prompt_embeds))
return image_prompts
def prep_ip_adapter_data(
self,
context: InvocationContext,
ip_adapters: List[IPAdapterField],
image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
exit_stack: ExitStack,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
) -> Optional[List[IPAdapterData]]:
"""If IP-Adapter is enabled, then this function loads the requisite models and adds the image prompt conditioning data."""
ip_adapter_data_list = []
for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
ip_adapters, image_prompts, strict=True
):
ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
) -> Optional[list[IPAdapterData]]:
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
to the `conditioning_data` (in-place).
"""
if ip_adapter is None:
return None
mask_field = single_ip_adapter.mask
mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
if not isinstance(ip_adapter, list):
ip_adapter = [ip_adapter]
if len(ip_adapter) == 0:
return None
ip_adapter_data_list = []
for single_ip_adapter in ip_adapter:
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.models.load(single_ip_adapter.ip_adapter_model)
)
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
single_ipa_image_fields = single_ip_adapter.image
if not isinstance(single_ipa_image_fields, list):
single_ipa_image_fields = [single_ipa_image_fields]
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
single_ipa_images, image_encoder_model
)
mask = single_ip_adapter.mask
if mask is not None:
mask = context.tensors.load(mask.tensor_name)
mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
ip_adapter_data_list.append(
@@ -732,7 +734,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
)
return ip_adapter_data_list if len(ip_adapter_data_list) > 0 else None
return ip_adapter_data_list
def run_t2i_adapters(
self,
@@ -853,16 +855,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
# At some point, someone decided that schedulers that accept a generator should use the original seed with
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
# reproducibility.
#
# These Invoke-supported schedulers accept a generator as of 2024-06-04:
# - DDIMScheduler
# - DDPMScheduler
# - DPMSolverMultistepScheduler
# - EulerAncestralDiscreteScheduler
# - EulerDiscreteScheduler
# - KDPM2AncestralDiscreteScheduler
# - LCMScheduler
# - TCDScheduler
scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
if isinstance(scheduler, TCDScheduler):
scheduler_step_kwargs.update({"eta": 1.0})
@@ -920,20 +912,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=True,
)
ip_adapters: List[IPAdapterField] = []
if self.ip_adapter is not None:
# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
if isinstance(self.ip_adapter, list):
ip_adapters = self.ip_adapter
else:
ip_adapters = [self.ip_adapter]
# If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
# a series of image conditioning embeddings. This is being done here rather than in the
# big model context below in order to use less VRAM on low-VRAM systems.
# The image prompts are then passed to prep_ip_adapter_data().
image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
# get the unet's config so that we can pass the base to dispatch_progress()
unet_config = context.models.get_config(self.unet.unet.key)
@@ -952,9 +930,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
ExitStack() as exit_stack,
ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config),
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
unet_info as unet,
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
):
@@ -992,8 +970,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
ip_adapter_data = self.prep_ip_adapter_data(
context=context,
ip_adapters=ip_adapters,
image_prompts=image_prompts,
ip_adapter=self.ip_adapter,
exit_stack=exit_stack,
latent_height=latent_height,
latent_width=latent_width,
@@ -1308,7 +1285,7 @@ class ImageToLatentsInvocation(BaseInvocation):
title="Blend Latents",
tags=["latents", "blend"],
category="latents",
version="1.0.3",
version="1.0.2",
)
class BlendLatentsInvocation(BaseInvocation):
"""Blend two latents using a given alpha. Latents must have same size."""
@@ -1387,7 +1364,7 @@ class BlendLatentsInvocation(BaseInvocation):
TorchDevice.empty_cache()
name = context.tensors.save(tensor=blended_latents)
return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
return LatentsOutput.build(latents_name=name, latents=blended_latents)
# The Crop Latents node was copied from @skunkworxdark's implementation here:

View File

@@ -106,7 +106,9 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
self._invoker.services.events.emit_bulk_download_started(
bulk_download_id, bulk_download_item_id, bulk_download_item_name
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
def _signal_job_completed(
@@ -116,8 +118,10 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
assert bulk_download_item_name is not None
self._invoker.services.events.emit_bulk_download_complete(
bulk_download_id, bulk_download_item_id, bulk_download_item_name
self._invoker.services.events.emit_bulk_download_completed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
def _signal_job_failed(
@@ -127,8 +131,11 @@ class BulkDownloadService(BulkDownloadBase):
if self._invoker:
assert bulk_download_id is not None
assert exception is not None
self._invoker.services.events.emit_bulk_download_error(
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
self._invoker.services.events.emit_bulk_download_failed(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=str(exception),
)
def stop(self, *args, **kwargs):

View File

@@ -8,13 +8,14 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from typing import Any, Dict, List, Optional, Set
import requests
from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.util.logging import InvokeAILogger
@@ -29,9 +30,6 @@ from .download_base import (
UnknownJobIDException,
)
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
# Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000
@@ -42,7 +40,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
def __init__(
self,
max_parallel_dl: int = 5,
event_bus: Optional["EventServiceBase"] = None,
event_bus: Optional[EventServiceBase] = None,
requests_session: Optional[requests.sessions.Session] = None,
):
"""
@@ -345,7 +343,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_started(job)
assert job.download_path
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
def _signal_job_progress(self, job: DownloadJob) -> None:
if job.on_progress:
@@ -356,7 +355,13 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_progress(job)
assert job.download_path
self._event_bus.emit_download_progress(
str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
def _signal_job_complete(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.COMPLETED
@@ -368,7 +373,10 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_complete(job)
assert job.download_path
self._event_bus.emit_download_complete(
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
)
def _signal_job_cancelled(self, job: DownloadJob) -> None:
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
@@ -382,7 +390,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_cancelled(job)
self._event_bus.emit_download_cancelled(str(job.source))
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
job.status = DownloadJobStatus.ERROR
@@ -395,7 +403,9 @@ class DownloadQueueService(DownloadQueueServiceBase):
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
)
if self._event_bus:
self._event_bus.emit_download_error(job)
assert job.error_type
assert job.error
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")

View File

@@ -1,195 +1,494 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import TYPE_CHECKING, Optional
from typing import Any, Dict, List, Optional, Union
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
EventBase,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueItemStatusChangedEvent,
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager import AnyModelConfig
from invokeai.backend.model_manager.config import SubModelType
class EventServiceBase:
queue_event: str = "queue_event"
bulk_download_event: str = "bulk_download_event"
download_event: str = "download_event"
model_event: str = "model_event"
"""Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event: "EventBase") -> None:
def dispatch(self, event_name: str, payload: Any) -> None:
pass
# region: Invocation
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
"""Bulk download events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.bulk_download_event,
payload={"event": event_name, "data": payload},
)
def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
"""Emitted when an invocation is started"""
self.dispatch(InvocationStartedEvent.build(queue_item, invocation))
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
"""Queue events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.queue_event,
payload={"event": event_name, "data": payload},
)
def emit_invocation_denoise_progress(
def __emit_download_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.download_event,
payload={"event": event_name, "data": payload},
)
def __emit_model_event(self, event_name: str, payload: dict) -> None:
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.model_event,
payload={"event": event_name, "data": payload},
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
intermediate_state: PipelineIntermediateState,
progress_image: "ProgressImage",
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node_id: str,
source_node_id: str,
progress_image: Optional[ProgressImage],
step: int,
order: int,
total_steps: int,
) -> None:
"""Emitted at each step during denoising of an invocation."""
self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))
"""Emitted when there is generation progress"""
self.__emit_queue_event(
event_name="generator_progress",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node_id": node_id,
"source_node_id": source_node_id,
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
"step": step,
"order": order,
"total_steps": total_steps,
},
)
def emit_invocation_complete(
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
result: dict,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation is complete"""
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output))
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"result": result,
},
)
def emit_invocation_error(
self,
queue_item: "SessionQueueItem",
invocation: "BaseInvocation",
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
error_type: str,
error_message: str,
error_traceback: str,
user_id: str | None,
project_id: str | None,
) -> None:
"""Emitted when an invocation encounters an error"""
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback))
"""Emitted when an invocation has completed"""
self.__emit_queue_event(
event_name="invocation_error",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
"error_type": error_type,
"error_message": error_message,
"error_traceback": error_traceback,
"user_id": user_id,
"project_id": project_id,
},
)
# endregion
def emit_invocation_started(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation has started"""
self.__emit_queue_event(
event_name="invocation_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"node": node,
"source_node_id": source_node_id,
},
)
# region Queue
def emit_graph_execution_complete(
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
) -> None:
"""Emitted when a session has completed all invocations"""
self.__emit_queue_event(
event_name="graph_execution_state_complete",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_model_load_started(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is requested"""
self.__emit_queue_event(
event_name="model_load_started",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_model_load_completed(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_queue_event(
event_name="model_load_completed",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
"model_config": model_config.model_dump(mode="json"),
"submodel_type": submodel_type,
},
)
def emit_session_canceled(
self,
queue_id: str,
queue_item_id: int,
queue_batch_id: str,
graph_execution_state_id: str,
) -> None:
"""Emitted when a session is canceled"""
self.__emit_queue_event(
event_name="session_canceled",
payload={
"queue_id": queue_id,
"queue_item_id": queue_item_id,
"queue_batch_id": queue_batch_id,
"graph_execution_state_id": graph_execution_state_id,
},
)
def emit_queue_item_status_changed(
self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus"
self,
session_queue_item: SessionQueueItem,
batch_status: BatchStatus,
queue_status: SessionQueueStatus,
) -> None:
"""Emitted when a queue item's status changes"""
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload={
"queue_id": queue_status.queue_id,
"queue_item": {
"queue_id": session_queue_item.queue_id,
"item_id": session_queue_item.item_id,
"status": session_queue_item.status,
"batch_id": session_queue_item.batch_id,
"session_id": session_queue_item.session_id,
"error_type": session_queue_item.error_type,
"error_message": session_queue_item.error_message,
"error_traceback": session_queue_item.error_traceback,
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
},
"batch_status": batch_status.model_dump(mode="json"),
"queue_status": queue_status.model_dump(mode="json"),
},
)
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
"""Emitted when a batch is enqueued"""
self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
self.__emit_queue_event(
event_name="batch_enqueued",
payload={
"queue_id": enqueue_result.queue_id,
"batch_id": enqueue_result.batch.batch_id,
"enqueued": enqueue_result.enqueued,
},
)
def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when a queue is cleared"""
self.dispatch(QueueClearedEvent.build(queue_id))
"""Emitted when the queue is cleared"""
self.__emit_queue_event(
event_name="queue_cleared",
payload={"queue_id": queue_id},
)
# endregion
def emit_download_started(self, source: str, download_path: str) -> None:
"""
Emit when a download job is started.
# region Download
:param url: The downloaded url
"""
self.__emit_download_event(
event_name="download_started",
payload={"source": source, "download_path": download_path},
)
def emit_download_started(self, job: "DownloadJob") -> None:
"""Emitted when a download is started"""
self.dispatch(DownloadStartedEvent.build(job))
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
"""
Emit "download_progress" events at regular intervals during a download job.
def emit_download_progress(self, job: "DownloadJob") -> None:
"""Emitted at intervals during a download"""
self.dispatch(DownloadProgressEvent.build(job))
:param source: The downloaded source
:param download_path: The local downloaded file
:param current_bytes: Number of bytes downloaded so far
:param total_bytes: The size of the file being downloaded (if known)
"""
self.__emit_download_event(
event_name="download_progress",
payload={
"source": source,
"download_path": download_path,
"current_bytes": current_bytes,
"total_bytes": total_bytes,
},
)
def emit_download_complete(self, job: "DownloadJob") -> None:
"""Emitted when a download is completed"""
self.dispatch(DownloadCompleteEvent.build(job))
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
"""
Emit a "download_complete" event at the end of a successful download.
def emit_download_cancelled(self, job: "DownloadJob") -> None:
"""Emitted when a download is cancelled"""
self.dispatch(DownloadCancelledEvent.build(job))
:param source: Source URL
:param download_path: Path to the locally downloaded file
:param total_bytes: The size of the downloaded file
"""
self.__emit_download_event(
event_name="download_complete",
payload={
"source": source,
"download_path": download_path,
"total_bytes": total_bytes,
},
)
def emit_download_error(self, job: "DownloadJob") -> None:
"""Emitted when a download encounters an error"""
self.dispatch(DownloadErrorEvent.build(job))
def emit_download_cancelled(self, source: str) -> None:
"""Emit a "download_cancelled" event in the event that the download was cancelled by user."""
self.__emit_download_event(
event_name="download_cancelled",
payload={
"source": source,
},
)
# endregion
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
"""
Emit a "download_error" event when an download job encounters an exception.
# region Model loading
:param source: Source URL
:param error_type: The name of the exception that raised the error
:param error: The traceback from this error
"""
self.__emit_download_event(
event_name="download_error",
payload={
"source": source,
"error_type": error_type,
"error": error,
},
)
def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
"""Emitted when a model load is started."""
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
def emit_model_load_complete(
self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
def emit_model_install_downloading(
self,
source: str,
local_path: str,
bytes: int,
total_bytes: int,
parts: List[Dict[str, Union[str, int]]],
id: int,
) -> None:
"""Emitted when a model load is complete."""
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
"""
Emit at intervals while the install job is in progress (remote models only).
# endregion
:param source: Source of the model
:param local_path: Where model is downloading to
:param parts: Progress of downloading URLs that comprise the model, if any.
:param bytes: Number of bytes downloaded so far.
:param total_bytes: Total size of download, including all files.
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
"""
self.__emit_model_event(
event_name="model_install_downloading",
payload={
"source": source,
"local_path": local_path,
"bytes": bytes,
"total_bytes": total_bytes,
"parts": parts,
"id": id,
},
)
# region Model install
def emit_model_install_downloads_done(self, source: str) -> None:
"""
Emit once when all parts are downloaded, but before the probing and registration start.
def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
"""Emitted at intervals while the install job is in progress (remote models only)."""
self.dispatch(ModelInstallDownloadProgressEvent.build(job))
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_downloads_done",
payload={"source": source},
)
def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job))
def emit_model_install_running(self, source: str) -> None:
"""
Emit once when an install job becomes active.
def emit_model_install_started(self, job: "ModelInstallJob") -> None:
"""Emitted once when an install job is started (after any download)."""
self.dispatch(ModelInstallStartedEvent.build(job))
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_running",
payload={"source": source},
)
def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
"""Emitted when an install job is completed successfully."""
self.dispatch(ModelInstallCompleteEvent.build(job))
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
"""
Emit when an install job is completed successfully.
def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
"""Emitted when an install job is cancelled."""
self.dispatch(ModelInstallCancelledEvent.build(job))
:param source: Source of the model; local path, repo_id or url
:param key: Model config record key
:param total_bytes: Size of the model (may be None for installation of a local path)
"""
self.__emit_model_event(
event_name="model_install_completed",
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
)
def emit_model_install_error(self, job: "ModelInstallJob") -> None:
"""Emitted when an install job encounters an exception."""
self.dispatch(ModelInstallErrorEvent.build(job))
def emit_model_install_cancelled(self, source: str, id: int) -> None:
"""
Emit when an install job is cancelled.
# endregion
:param source: Source of the model; local path, repo_id or url
"""
self.__emit_model_event(
event_name="model_install_cancelled",
payload={"source": source, "id": id},
)
# region Bulk image download
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
"""
Emit when an install job encounters an exception.
:param source: Source of the model
:param error_type: The name of the exception
:param error: A text description of the exception
"""
self.__emit_model_event(
event_name="model_install_error",
payload={"source": source, "error_type": error_type, "error": error, "id": id},
)
def emit_bulk_download_started(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk image download is started"""
self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
def emit_bulk_download_complete(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk image download is complete"""
self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
def emit_bulk_download_error(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> None:
"""Emitted when a bulk image download has an error"""
self.dispatch(
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
"""Emitted when a bulk download starts"""
self._emit_bulk_download_event(
event_name="bulk_download_started",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
# endregion
def emit_bulk_download_completed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> None:
"""Emitted when a bulk download completes"""
self._emit_bulk_download_event(
event_name="bulk_download_completed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
},
)
def emit_bulk_download_failed(
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> None:
"""Emitted when a bulk download fails"""
self._emit_bulk_download_event(
event_name="bulk_download_failed",
payload={
"bulk_download_id": bulk_download_id,
"bulk_download_item_id": bulk_download_item_id,
"bulk_download_item_name": bulk_download_item_name,
"error": error,
},
)

View File

@@ -1,592 +0,0 @@
from math import floor
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
from fastapi_events.handlers.local import local_handler
from fastapi_events.registry.payload_schema import registry as payload_schema
from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
BatchStatus,
EnqueueBatchResult,
SessionQueueItem,
SessionQueueStatus,
)
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
if TYPE_CHECKING:
from invokeai.app.services.download.download_base import DownloadJob
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
class EventBase(BaseModel):
"""Base class for all events. All events must inherit from this class.
Events must define a class attribute `__event_name__` to identify the event.
All other attributes should be defined as normal for a pydantic model.
A timestamp is automatically added to the event when it is created.
"""
__event_name__: ClassVar[str]
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
@classmethod
def get_events(cls) -> set[type["EventBase"]]:
"""Get a set of all event models."""
event_subclasses: set[type["EventBase"]] = set()
for subclass in cls.__subclasses__():
# We only want to include subclasses that are event models, not intermediary classes
if hasattr(subclass, "__event_name__"):
event_subclasses.add(subclass)
event_subclasses.update(subclass.get_events())
return event_subclasses
TEvent = TypeVar("TEvent", bound=EventBase, contravariant=True)
FastAPIEvent: TypeAlias = tuple[str, TEvent]
"""
A tuple representing a `fastapi-events` event, with the event name and payload.
Provide a generic type to `TEvent` to specify the payload type.
"""
class FastAPIEventFunc(Protocol, Generic[TEvent]):
def __call__(self, event: FastAPIEvent[TEvent]) -> Optional[Coroutine[Any, Any, None]]: ...
def register_events(events: set[type[TEvent]] | type[TEvent], func: FastAPIEventFunc[TEvent]) -> None:
"""Register a function to handle specific events.
:param events: An event or set of events to handle
:param func: The function to handle the events
"""
events = events if isinstance(events, set) else {events}
for event in events:
assert hasattr(event, "__event_name__")
local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
class QueueEventBase(EventBase):
"""Base class for queue events"""
queue_id: str = Field(description="The ID of the queue")
class QueueItemEventBase(QueueEventBase):
"""Base class for queue item events"""
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
class InvocationEventBase(QueueItemEventBase):
"""Base class for invocation events"""
session_id: str = Field(description="The ID of the session (aka graph execution state)")
queue_id: str = Field(description="The ID of the queue")
item_id: int = Field(description="The ID of the queue item")
batch_id: str = Field(description="The ID of the queue batch")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
invocation: AnyInvocation = Field(description="The ID of the invocation")
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
@payload_schema.register
class InvocationStartedEvent(InvocationEventBase):
"""Event model for invocation_started"""
__event_name__ = "invocation_started"
@classmethod
def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
)
@payload_schema.register
class InvocationDenoiseProgressEvent(InvocationEventBase):
"""Event model for invocation_denoise_progress"""
__event_name__ = "invocation_denoise_progress"
progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
step: int = Field(description="The current step of the invocation")
total_steps: int = Field(description="The total number of steps in the invocation")
order: int = Field(description="The order of the invocation in the session")
percentage: float = Field(description="The percentage of completion of the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: AnyInvocation,
intermediate_state: PipelineIntermediateState,
progress_image: ProgressImage,
) -> "InvocationDenoiseProgressEvent":
step = intermediate_state.step
total_steps = intermediate_state.total_steps
order = intermediate_state.order
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
progress_image=progress_image,
step=step,
total_steps=total_steps,
order=order,
percentage=cls.calc_percentage(step, total_steps, order),
)
@staticmethod
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
"""Calculate the percentage of completion of denoising."""
if total_steps == 0:
return 0.0
if scheduler_order == 2:
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
# order == 1
return (step + 1 + 1) / (total_steps + 1)
@payload_schema.register
class InvocationCompleteEvent(InvocationEventBase):
"""Event model for invocation_complete"""
__event_name__ = "invocation_complete"
result: AnyInvocationOutput = Field(description="The result of the invocation")
@classmethod
def build(
cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
) -> "InvocationCompleteEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
result=result,
)
@payload_schema.register
class InvocationErrorEvent(InvocationEventBase):
"""Event model for invocation_error"""
__event_name__ = "invocation_error"
error_type: str = Field(description="The error type")
error_message: str = Field(description="The error message")
error_traceback: str = Field(description="The error traceback")
user_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
project_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
@classmethod
def build(
cls,
queue_item: SessionQueueItem,
invocation: AnyInvocation,
error_type: str,
error_message: str,
error_traceback: str,
) -> "InvocationErrorEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
invocation=invocation,
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
user_id=getattr(queue_item, "user_id", None),
project_id=getattr(queue_item, "project_id", None),
)
@payload_schema.register
class QueueItemStatusChangedEvent(QueueItemEventBase):
"""Event model for queue_item_status_changed"""
__event_name__ = "queue_item_status_changed"
status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item")
error_type: Optional[str] = Field(default=None, description="The error type, if any")
error_message: Optional[str] = Field(default=None, description="The error message, if any")
error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any")
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
batch_status: BatchStatus = Field(description="The status of the batch")
queue_status: SessionQueueStatus = Field(description="The status of the queue")
session_id: str = Field(description="The ID of the session (aka graph execution state)")
@classmethod
def build(
cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus
) -> "QueueItemStatusChangedEvent":
return cls(
queue_id=queue_item.queue_id,
item_id=queue_item.item_id,
batch_id=queue_item.batch_id,
session_id=queue_item.session_id,
status=queue_item.status,
error_type=queue_item.error_type,
error_message=queue_item.error_message,
error_traceback=queue_item.error_traceback,
created_at=str(queue_item.created_at) if queue_item.created_at else None,
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
started_at=str(queue_item.started_at) if queue_item.started_at else None,
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
batch_status=batch_status,
queue_status=queue_status,
)
@payload_schema.register
class BatchEnqueuedEvent(QueueEventBase):
"""Event model for batch_enqueued"""
__event_name__ = "batch_enqueued"
batch_id: str = Field(description="The ID of the batch")
enqueued: int = Field(description="The number of invocations enqueued")
requested: int = Field(
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
)
priority: int = Field(description="The priority of the batch")
@classmethod
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
return cls(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
requested=enqueue_result.requested,
priority=enqueue_result.priority,
)
@payload_schema.register
class QueueClearedEvent(QueueEventBase):
"""Event model for queue_cleared"""
__event_name__ = "queue_cleared"
@classmethod
def build(cls, queue_id: str) -> "QueueClearedEvent":
return cls(queue_id=queue_id)
class DownloadEventBase(EventBase):
"""Base class for events associated with a download"""
source: str = Field(description="The source of the download")
@payload_schema.register
class DownloadStartedEvent(DownloadEventBase):
"""Event model for download_started"""
__event_name__ = "download_started"
download_path: str = Field(description="The local path where the download is saved")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadStartedEvent":
assert job.download_path
return cls(source=str(job.source), download_path=job.download_path.as_posix())
@payload_schema.register
class DownloadProgressEvent(DownloadEventBase):
"""Event model for download_progress"""
__event_name__ = "download_progress"
download_path: str = Field(description="The local path where the download is saved")
current_bytes: int = Field(description="The number of bytes downloaded so far")
total_bytes: int = Field(description="The total number of bytes to be downloaded")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadProgressEvent":
assert job.download_path
return cls(
source=str(job.source),
download_path=job.download_path.as_posix(),
current_bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register
class DownloadCompleteEvent(DownloadEventBase):
"""Event model for download_complete"""
__event_name__ = "download_complete"
download_path: str = Field(description="The local path where the download is saved")
total_bytes: int = Field(description="The total number of bytes downloaded")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent":
assert job.download_path
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
@payload_schema.register
class DownloadCancelledEvent(DownloadEventBase):
"""Event model for download_cancelled"""
__event_name__ = "download_cancelled"
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent":
return cls(source=str(job.source))
@payload_schema.register
class DownloadErrorEvent(DownloadEventBase):
"""Event model for download_error"""
__event_name__ = "download_error"
error_type: str = Field(description="The type of error")
error: str = Field(description="The error message")
@classmethod
def build(cls, job: "DownloadJob") -> "DownloadErrorEvent":
assert job.error_type
assert job.error
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
class ModelEventBase(EventBase):
"""Base class for events associated with a model"""
@payload_schema.register
class ModelLoadStartedEvent(ModelEventBase):
"""Event model for model_load_started"""
__event_name__ = "model_load_started"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelLoadCompleteEvent(ModelEventBase):
"""Event model for model_load_complete"""
__event_name__ = "model_load_complete"
config: AnyModelConfig = Field(description="The model's config")
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
@classmethod
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
return cls(config=config, submodel_type=submodel_type)
@payload_schema.register
class ModelInstallDownloadProgressEvent(ModelEventBase):
"""Event model for model_install_download_progress"""
__event_name__ = "model_install_download_progress"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
local_path: str = Field(description="Where model is downloading to")
bytes: int = Field(description="Number of bytes downloaded so far")
total_bytes: int = Field(description="Total size of download, including all files")
parts: list[dict[str, int | str]] = Field(
description="Progress of downloading URLs that comprise the model, if any"
)
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent":
parts: list[dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
return cls(
id=job.id,
source=str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
)
@payload_schema.register
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
"""Emitted once when an install job becomes active."""
__event_name__ = "model_install_downloads_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallStartedEvent(ModelEventBase):
"""Event model for model_install_started"""
__event_name__ = "model_install_started"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallCompleteEvent(ModelEventBase):
"""Event model for model_install_complete"""
__event_name__ = "model_install_complete"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
key: str = Field(description="Model config record key")
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
assert job.config_out is not None
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
@payload_schema.register
class ModelInstallCancelledEvent(ModelEventBase):
"""Event model for model_install_cancelled"""
__event_name__ = "model_install_cancelled"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
return cls(id=job.id, source=str(job.source))
@payload_schema.register
class ModelInstallErrorEvent(ModelEventBase):
"""Event model for model_install_error"""
__event_name__ = "model_install_error"
id: int = Field(description="The ID of the install job")
source: str = Field(description="Source of the model; local path, repo_id or url")
error_type: str = Field(description="The name of the exception")
error: str = Field(description="A text description of the exception")
@classmethod
def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
assert job.error_type is not None
assert job.error is not None
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
class BulkDownloadEventBase(EventBase):
"""Base class for events associated with a bulk image download"""
bulk_download_id: str = Field(description="The ID of the bulk image download")
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
@payload_schema.register
class BulkDownloadStartedEvent(BulkDownloadEventBase):
"""Event model for bulk_download_started"""
__event_name__ = "bulk_download_started"
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> "BulkDownloadStartedEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
@payload_schema.register
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
"""Event model for bulk_download_complete"""
__event_name__ = "bulk_download_complete"
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
) -> "BulkDownloadCompleteEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
)
@payload_schema.register
class BulkDownloadErrorEvent(BulkDownloadEventBase):
"""Event model for bulk_download_error"""
__event_name__ = "bulk_download_error"
error: str = Field(description="The error message")
@classmethod
def build(
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
) -> "BulkDownloadErrorEvent":
return cls(
bulk_download_id=bulk_download_id,
bulk_download_item_id=bulk_download_item_id,
bulk_download_item_name=bulk_download_item_name,
error=error,
)

View File

@@ -1,47 +0,0 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
import threading
from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_common import (
EventBase,
)
from .events_base import EventServiceBase
class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int) -> None:
self.event_handler_id = event_handler_id
self._queue = Queue[EventBase | None]()
self._stop_event = threading.Event()
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
super().__init__()
def stop(self, *args, **kwargs):
self._stop_event.set()
self._queue.put(None)
def dispatch(self, event: EventBase) -> None:
self._queue.put(event)
async def _dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = self._queue.get(block=False)
if not event: # Probably stopping
continue
# Leave the payloads as live pydantic models
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e:
raise e # Raise a proper error

View File

@@ -1,13 +1,11 @@
"""Initialization file for model install service package."""
from .model_install_base import (
ModelInstallServiceBase,
)
from .model_install_common import (
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
UnknownInstallJobException,
URLModelSource,

View File

@@ -1,19 +1,244 @@
# Copyright 2023 Lincoln D. Stein and the InvokeAI development team
"""Baseclass definitions for the model installer."""
import re
import traceback
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
class ModelInstallServiceBase(ABC):
@@ -57,7 +282,7 @@ class ModelInstallServiceBase(ABC):
@property
@abstractmethod
def event_bus(self) -> Optional["EventServiceBase"]:
def event_bus(self) -> Optional[EventServiceBase]:
"""Return the event service base object associated with the installer."""
@abstractmethod

View File

@@ -1,233 +0,0 @@
import re
import traceback
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.download import DownloadJob
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
class InstallStatus(str, Enum):
"""State of an install job running in the background."""
WAITING = "waiting" # waiting to be dequeued
DOWNLOADING = "downloading" # downloading of model files in process
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
RUNNING = "running" # being processed
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated with an error message
class ModelInstallPart(BaseModel):
url: AnyHttpUrl
path: Path
bytes: int = 0
total_bytes: int = 0
class UnknownInstallJobException(Exception):
"""Raised when the status of an unknown job is requested."""
class StringLikeSource(BaseModel):
"""
Base class for model sources, implements functions that lets the source be sorted and indexed.
These shenanigans let this stuff work:
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
mydict = {source1: 'model 1'}
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
assert source1 == source2
assert source1 == 'C:/users/mort/foo.safetensors'
"""
def __hash__(self) -> int:
"""Return hash of the path field, for indexing."""
return hash(str(self))
def __lt__(self, other: object) -> int:
"""Return comparison of the stringified version, for sorting."""
return str(self) < str(other)
def __eq__(self, other: object) -> bool:
"""Return equality on the stringified version."""
if isinstance(other, Path):
return str(self) == other.as_posix()
else:
return str(self) == str(other)
class LocalModelSource(StringLikeSource):
"""A local file or directory path."""
path: str | Path
inplace: Optional[bool] = False
type: Literal["local"] = "local"
# these methods allow the source to be used in a string-like way,
# for example as an index into a dict
def __str__(self) -> str:
"""Return string version of path when string rep needed."""
return Path(self.path).as_posix()
class HFModelSource(StringLikeSource):
"""
A HuggingFace repo_id with optional variant, sub-folder and access token.
Note that the variant option, if not provided to the constructor, will default to fp16, which is
what people (almost) always want.
"""
repo_id: str
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
subfolder: Optional[Path] = None
access_token: Optional[str] = None
type: Literal["hf"] = "hf"
@field_validator("repo_id")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def __str__(self) -> str:
"""Return string version of repoid when string rep needed."""
base: str = self.repo_id
if self.variant:
base += f":{self.variant or ''}"
if self.subfolder:
base += f":{self.subfolder}"
return base
class URLModelSource(StringLikeSource):
"""A generic URL point to a checkpoint file."""
url: AnyHttpUrl
access_token: Optional[str] = None
type: Literal["url"] = "url"
def __str__(self) -> str:
"""Return string version of the url when string rep needed."""
return str(self.url)
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
MODEL_SOURCE_TO_TYPE_MAP = {
URLModelSource: ModelSourceType.Url,
HFModelSource: ModelSourceType.HFRepoID,
LocalModelSource: ModelSourceType.Path,
}
class ModelInstallJob(BaseModel):
"""Object that tracks the current status of an install request."""
id: int = Field(description="Unique ID for this job")
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
config_in: Dict[str, Any] = Field(
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
)
config_out: Optional[AnyModelConfig] = Field(
default=None, description="After successful installation, this will hold the configuration object."
)
inplace: bool = Field(
default=False, description="Leave model in its current location; otherwise install under models directory"
)
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
bytes: int = Field(
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
)
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
error_traceback: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the exception traceback"
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
"""Record the error and traceback from an exception."""
self._exception = e
self.error = str(e)
self.error_traceback = self._format_error(e)
self.status = InstallStatus.ERROR
self.error_reason = self._exception.__class__.__name__ if self._exception else None
def cancel(self) -> None:
"""Call to cancel the job."""
self.status = InstallStatus.CANCELLED
@property
def error_type(self) -> Optional[str]:
"""Class name of the exception that led to status==ERROR."""
return self._exception.__class__.__name__ if self._exception else None
def _format_error(self, exception: Exception) -> str:
"""Error traceback."""
return "".join(traceback.format_exception(exception))
@property
def cancelled(self) -> bool:
"""Set status to CANCELLED."""
return self.status == InstallStatus.CANCELLED
@property
def errored(self) -> bool:
"""Return true if job has errored."""
return self.status == InstallStatus.ERROR
@property
def waiting(self) -> bool:
"""Return true if job is waiting to run."""
return self.status == InstallStatus.WAITING
@property
def downloading(self) -> bool:
"""Return true if job is downloading."""
return self.status == InstallStatus.DOWNLOADING
@property
def downloads_done(self) -> bool:
"""Return true if job's downloads ae done."""
return self.status == InstallStatus.DOWNLOADS_DONE
@property
def running(self) -> bool:
"""Return true if job is running."""
return self.status == InstallStatus.RUNNING
@property
def complete(self) -> bool:
"""Return true if job completed without errors."""
return self.status == InstallStatus.COMPLETED
@property
def in_terminal_state(self) -> bool:
"""Return true if job is in a terminal state."""
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]

View File

@@ -10,7 +10,7 @@ from pathlib import Path
from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
import torch
import yaml
@@ -20,8 +20,8 @@ from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.backend.model_manager.config import (
@@ -45,12 +45,13 @@ from invokeai.backend.util import InvokeAILogger
from invokeai.backend.util.catch_sigint import catch_sigint
from invokeai.backend.util.devices import TorchDevice
from .model_install_common import (
from .model_install_base import (
MODEL_SOURCE_TO_TYPE_MAP,
HFModelSource,
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
ModelSource,
StringLikeSource,
URLModelSource,
@@ -58,9 +59,6 @@ from .model_install_common import (
TMPDIR_PREFIX = "tmpinstall_"
if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
class ModelInstallService(ModelInstallServiceBase):
"""class for InvokeAI model installation."""
@@ -70,7 +68,7 @@ class ModelInstallService(ModelInstallServiceBase):
app_config: InvokeAIAppConfig,
record_store: ModelRecordServiceBase,
download_queue: DownloadQueueServiceBase,
event_bus: Optional["EventServiceBase"] = None,
event_bus: Optional[EventServiceBase] = None,
session: Optional[Session] = None,
):
"""
@@ -106,7 +104,7 @@ class ModelInstallService(ModelInstallServiceBase):
return self._record_store
@property
def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
return self._event_bus
# make the invoker optional here because we don't need it and it
@@ -857,17 +855,35 @@ class ModelInstallService(ModelInstallServiceBase):
job.status = InstallStatus.RUNNING
self._logger.info(f"Model install started: {job.source}")
if self._event_bus:
self._event_bus.emit_model_install_started(job)
self._event_bus.emit_model_install_running(str(job.source))
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
if self._event_bus:
self._event_bus.emit_model_install_download_progress(job)
parts: List[Dict[str, str | int]] = [
{
"url": str(x.source),
"local_path": str(x.download_path),
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
]
assert job.bytes is not None
assert job.total_bytes is not None
self._event_bus.emit_model_install_downloading(
str(job.source),
local_path=job.local_path.as_posix(),
parts=parts,
bytes=job.bytes,
total_bytes=job.total_bytes,
id=job.id,
)
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.DOWNLOADS_DONE
self._logger.info(f"Model download complete: {job.source}")
if self._event_bus:
self._event_bus.emit_model_install_downloads_complete(job)
self._event_bus.emit_model_install_downloads_done(str(job.source))
def _signal_job_completed(self, job: ModelInstallJob) -> None:
job.status = InstallStatus.COMPLETED
@@ -875,19 +891,24 @@ class ModelInstallService(ModelInstallServiceBase):
self._logger.info(f"Model install complete: {job.source}")
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
if self._event_bus:
self._event_bus.emit_model_install_complete(job)
assert job.local_path is not None
assert job.config_out is not None
key = job.config_out.key
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
def _signal_job_errored(self, job: ModelInstallJob) -> None:
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
if self._event_bus:
assert job.error_type is not None
assert job.error is not None
self._event_bus.emit_model_install_error(job)
error_type = job.error_type
error = job.error
assert error_type is not None
assert error is not None
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
self._logger.info(f"Model install canceled: {job.source}")
if self._event_bus:
self._event_bus.emit_model_install_cancelled(job)
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
@staticmethod
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:

View File

@@ -4,6 +4,7 @@
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
@@ -14,12 +15,18 @@ class ModelLoadServiceBase(ABC):
"""Wrapper around AnyModelLoader."""
@abstractmethod
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch.
:param context_data: Invocation context data used for event reporting
"""
@property

View File

@@ -5,6 +5,7 @@ from typing import Optional, Type
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.invocation_context import InvocationContextData
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
LoadedModel,
@@ -50,18 +51,25 @@ class ModelLoadService(ModelLoadServiceBase):
"""Return the checkpoint convert cache used by this loader."""
return self._convert_cache
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(
self,
model_config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
context_data: Optional[InvocationContextData] = None,
) -> LoadedModel:
"""
Given a model's configuration, load it and return the LoadedModel object.
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
:param submodel: For main (pipeline models), the submodel to fetch.
:param context: Invocation context used for event reporting
"""
# We don't have an invoker during testing
# TODO(psyche): Mock this method on the invoker in the tests
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
if context_data:
self._emit_load_event(
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
)
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model: LoadedModel = implementation(
@@ -71,7 +79,40 @@ class ModelLoadService(ModelLoadServiceBase):
convert_cache=self._convert_cache,
).load_model(model_config, submodel_type)
if hasattr(self, "_invoker"):
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
if context_data:
self._emit_load_event(
context_data=context_data,
model_config=model_config,
submodel_type=submodel_type,
loaded=True,
)
return loaded_model
def _emit_load_event(
self,
context_data: InvocationContextData,
model_config: AnyModelConfig,
loaded: Optional[bool] = False,
submodel_type: Optional[SubModelType] = None,
) -> None:
if not self._invoker:
return
if not loaded:
self._invoker.services.events.emit_model_load_started(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)
else:
self._invoker.services.events.emit_model_load_completed(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
model_config=model_config,
submodel_type=submodel_type,
)

View File

@@ -4,14 +4,11 @@ from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
FastAPIEvent,
QueueClearedEvent,
QueueItemStatusChangedEvent,
register_events,
)
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
from invokeai.app.services.session_processor.session_processor_base import (
OnAfterRunNode,
@@ -63,11 +60,6 @@ class DefaultSessionRunner(SessionRunnerBase):
self._cancel_event = cancel_event
self._profiler = profiler
def _is_canceled(self) -> bool:
"""Check if the cancel event is set. This is also passed to the invocation context builder and called during
denoising to check if the session has been canceled."""
return self._cancel_event.is_set()
def run(self, queue_item: SessionQueueItem):
# Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here.
@@ -91,19 +83,13 @@ class DefaultSessionRunner(SessionRunnerBase):
)
break
if invocation is None or self._is_canceled():
if invocation is None or self._cancel_event.is_set():
break
self.run_node(invocation, queue_item)
# The session is complete if all invocations have been run or there is an error on the session.
# At this time, the queue item may be canceled, but the object itself here won't be updated yet. We must
# use the cancel event to check if the session is canceled.
if (
queue_item.session.is_complete()
or self._is_canceled()
or queue_item.status in ["failed", "canceled", "completed"]
):
if queue_item.session.is_complete() or self._cancel_event.is_set():
break
self._on_after_run_session(queue_item=queue_item)
@@ -122,7 +108,7 @@ class DefaultSessionRunner(SessionRunnerBase):
context = build_invocation_context(
data=data,
services=self._services,
is_canceled=self._is_canceled,
cancel_event=self._cancel_event,
)
# Invoke the node
@@ -136,12 +122,16 @@ class DefaultSessionRunner(SessionRunnerBase):
# TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here?
pass
except CanceledException:
# A CanceledException is raised during the denoising step callback if the cancel event is set. We don't need
# to do any handling here, and no error should be set - just pass and the cancellation will be handled
# correctly in the next iteration of the session runner loop.
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# See the comment in the processor's `_on_queue_item_status_changed()` method for more details on how we
# handle cancellation.
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
error_type = e.__class__.__name__
@@ -156,11 +146,7 @@ class DefaultSessionRunner(SessionRunnerBase):
)
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
"""Called before a session is run.
- Start the profiler if profiling is enabled.
- Run any callbacks registered for this event.
"""
"""Run before a session is executed"""
self._services.logger.debug(
f"On before run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
@@ -174,14 +160,7 @@ class DefaultSessionRunner(SessionRunnerBase):
callback(queue_item=queue_item)
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
"""Called after a session is run.
- Stop the profiler if profiling is enabled.
- Update the queue item's session object in the database.
- If not already canceled or failed, complete the queue item.
- Log and reset performance statistics.
- Run any callbacks registered for this event.
"""
"""Run after a session is executed"""
self._services.logger.debug(
f"On after run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
@@ -201,10 +180,14 @@ class DefaultSessionRunner(SessionRunnerBase):
# while the session is running.
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
# The queue item may have been canceled or failed while the session was running. We should only complete it
# if it is not already canceled or failed.
if queue_item.status not in ["canceled", "failed"]:
queue_item = self._services.session_queue.complete_queue_item(queue_item.item_id)
# TODO(psyche): This feels jumbled - we should review separation of concerns here.
# Send complete event. The events service will receive this and update the queue item's status.
self._services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
@@ -218,18 +201,21 @@ class DefaultSessionRunner(SessionRunnerBase):
pass
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Called before a node is run.
- Emits an invocation started event.
- Run any callbacks registered for this event.
"""
"""Run before a node is executed"""
self._services.logger.debug(
f"On before run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
)
# Send starting event
self._services.events.emit_invocation_started(queue_item=queue_item, invocation=invocation)
self._services.events.emit_invocation_started(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session_id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
)
for callback in self._on_before_run_node_callbacks:
callback(invocation=invocation, queue_item=queue_item)
@@ -237,18 +223,22 @@ class DefaultSessionRunner(SessionRunnerBase):
def _on_after_run_node(
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
):
"""Called after a node is run.
- Emits an invocation complete event.
- Run any callbacks registered for this event.
"""
"""Run after a node is executed"""
self._services.logger.debug(
f"On after run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
)
# Send complete event on successful runs
self._services.events.emit_invocation_complete(invocation=invocation, queue_item=queue_item, output=output)
self._services.events.emit_invocation_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
result=output.model_dump(),
)
for callback in self._on_after_run_node_callbacks:
callback(invocation=invocation, queue_item=queue_item, output=output)
@@ -261,14 +251,7 @@ class DefaultSessionRunner(SessionRunnerBase):
error_message: str,
error_traceback: str,
):
"""Called when a node errors. Node errors may occur when running or preparing the node..
- Set the node error on the session object.
- Log the error.
- Fail the queue item.
- Emits an invocation error event.
- Run any callbacks registered for this event.
"""
"""Run when a node errors"""
self._services.logger.debug(
f"On node error: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
@@ -282,19 +265,19 @@ class DefaultSessionRunner(SessionRunnerBase):
)
self._services.logger.error(error_traceback)
# Fail the queue item
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
queue_item = self._services.session_queue.fail_queue_item(
queue_item.item_id, error_type, error_message, error_traceback
)
# Send error event
self._services.events.emit_invocation_error(
queue_item=queue_item,
invocation=invocation,
queue_batch_id=queue_item.session_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session.id,
node=invocation.model_dump(),
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
user_id=getattr(queue_item, "user_id", None),
project_id=getattr(queue_item, "project_id", None),
)
for callback in self._on_node_error_callbacks:
@@ -332,9 +315,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now_event = ThreadEvent()
self._cancel_event = ThreadEvent()
register_events(QueueClearedEvent, self._on_queue_cleared)
register_events(BatchEnqueuedEvent, self._on_batch_enqueued)
register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
self._thread_semaphore = BoundedSemaphore(self._thread_limit)
@@ -369,25 +350,31 @@ class DefaultSessionProcessor(SessionProcessorBase):
def _poll_now(self) -> None:
self._poll_now_event.set()
async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
async def _on_queue_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
if (
event_name == "session_canceled"
and self._queue_item
and self._queue_item.item_id == event[1]["data"]["queue_item_id"]
):
self._cancel_event.set()
self._poll_now()
async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None:
self._poll_now()
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
# event, which the session runner checks between invocations. If set, the session runner loop is broken.
#
# Long-running nodes that cannot be interrupted easily present a challenge. `denoise_latents` is one such
# node, but it gets a step callback, called on each step of denoising. This callback checks if the queue item
# is canceled, and if it is, raises a `CanceledException` to stop execution immediately.
if event[1].status == "canceled":
self._cancel_event.set()
elif (
event_name == "queue_cleared"
and self._queue_item
and self._queue_item.queue_id == event[1]["data"]["queue_id"]
):
self._cancel_event.set()
self._poll_now()
elif event_name == "batch_enqueued":
self._poll_now()
elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [
"completed",
"failed",
"canceled",
]:
self._cancel_event.set()
self._poll_now()
def resume(self) -> SessionProcessorStatus:
@@ -476,22 +463,15 @@ class DefaultSessionProcessor(SessionProcessorBase):
error_message: str,
error_traceback: str,
) -> None:
"""Called when a non-fatal error occurs in the processor.
- Log the error.
- If a queue item is provided, update the queue item with the completed session & fail it.
- Run any callbacks registered for this event.
"""
# Non-fatal error in processor
self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}")
self._invoker.services.logger.error(error_traceback)
if queue_item is not None:
# Update the queue item with the completed session & fail it
queue_item = self._invoker.services.session_queue.set_queue_item_session(
queue_item.item_id, queue_item.session
)
queue_item = self._invoker.services.session_queue.fail_queue_item(
# Update the queue item with the completed session
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
# Fail the queue item
self._invoker.services.session_queue.fail_queue_item(
item_id=queue_item.item_id,
error_type=error_type,
error_message=error_message,

View File

@@ -73,11 +73,6 @@ class SessionQueueBase(ABC):
"""Gets the status of a batch"""
pass
@abstractmethod
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
"""Completes a session queue item"""
pass
@abstractmethod
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
"""Cancels a session queue item"""

View File

@@ -2,6 +2,10 @@ import sqlite3
import threading
from typing import Optional, Union, cast
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.session_queue.session_queue_common import (
@@ -38,7 +42,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__invoker = invoker
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
if prune_result.deleted > 0:
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
@@ -48,6 +52,60 @@ class SqliteSessionQueue(SessionQueueBase):
self.__conn = db.conn
self.__cursor = self.__conn.cursor()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
return event[1]["event"] in match_in
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
event_name = event[1]["event"]
# This was a match statement, but match is not supported on python 3.9
if event_name == "graph_execution_state_complete":
await self._handle_complete_event(event)
elif event_name == "invocation_error":
await self._handle_error_event(event)
elif event_name == "session_canceled":
await self._handle_cancel_event(event)
return event
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
# When a queue item has an error, we get an error event, then a completed event.
# Mark the queue item completed only if it isn't already marked completed, e.g.
# by a previously-handled error event.
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
except SessionQueueItemNotFoundError:
return
async def _handle_error_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
error_type = event[1]["data"]["error_type"]
error_message = event[1]["data"]["error_message"]
error_traceback = event[1]["data"]["error_traceback"]
queue_item = self.get_queue_item(item_id)
# always set to failed if have an error, even if previously the item was marked completed or canceled
queue_item = self._set_queue_item_status(
item_id=queue_item.item_id,
status="failed",
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
except SessionQueueItemNotFoundError:
return
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
try:
item_id = event[1]["data"]["queue_item_id"]
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
except SessionQueueItemNotFoundError:
return
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
@@ -248,7 +306,11 @@ class SqliteSessionQueue(SessionQueueBase):
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status)
self.__invoker.services.events.emit_queue_item_status_changed(
session_queue_item=queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
@@ -357,11 +419,15 @@ class SqliteSessionQueue(SessionQueueBase):
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
return queue_item
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["canceled", "failed", "completed"]:
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
queue_batch_id=queue_item.batch_id,
graph_execution_state_id=queue_item.session_id,
)
return queue_item
def fail_queue_item(
@@ -371,13 +437,21 @@ class SqliteSessionQueue(SessionQueueBase):
error_message: str,
error_traceback: str,
) -> SessionQueueItem:
queue_item = self._set_queue_item_status(
item_id=item_id,
status="failed",
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["canceled", "failed", "completed"]:
queue_item = self._set_queue_item_status(
item_id=item_id,
status="failed",
error_type=error_type,
error_message=error_message,
error_traceback=error_traceback,
)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
queue_batch_id=queue_item.batch_id,
graph_execution_state_id=queue_item.session_id,
)
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
@@ -413,10 +487,18 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
session_queue_item=current_queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
except Exception:
self.__conn.rollback()
@@ -456,10 +538,18 @@ class SqliteSessionQueue(SessionQueueBase):
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
queue_batch_id=current_queue_item.batch_id,
graph_execution_state_id=current_queue_item.session_id,
)
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
self.__invoker.services.events.emit_queue_item_status_changed(
current_queue_item, batch_status, queue_status
session_queue_item=current_queue_item,
batch_status=batch_status,
queue_status=queue_status,
)
except Exception:
self.__conn.rollback()

View File

@@ -2,19 +2,18 @@
import copy
import itertools
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
import networkx as nx
from pydantic import (
BaseModel,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
ValidationError,
field_validator,
)
from pydantic.fields import Field
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema
from pydantic_core import CoreSchema
# Importing * is bad karma but needed here for node detection
from invokeai.app.invocations import * # noqa: F401 F403
@@ -278,58 +277,73 @@ class CollectInvocation(BaseInvocation):
return CollectInvocationOutput(collection=copy.copy(self.collection))
class AnyInvocation(BaseInvocation):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
def validate_invocation(v: Any) -> "AnyInvocation":
return BaseInvocation.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation)
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocation.get_invocations()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
class AnyInvocationOutput(BaseInvocationOutput):
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
return BaseInvocationOutput.get_typeadapter().validate_python(v)
return core_schema.no_info_plain_validator_function(validate_invocation_output)
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# Nodes are too powerful, we have to make our own OpenAPI schema manually
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
oneOf: list[dict[str, str]] = []
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
for name in sorted(names):
oneOf.append({"$ref": f"#/components/schemas/{name}"})
return {"oneOf": oneOf}
class Graph(BaseModel):
id: str = Field(description="The id of this graph", default_factory=uuid_string)
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
edges: list[Edge] = Field(
description="The connections between nodes and their fields in this graph",
default_factory=list,
)
@field_validator("nodes", mode="plain")
@classmethod
def validate_nodes(cls, v: dict[str, Any]):
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
# Invocations register themselves as their python modules are executed. The union of all invocations is
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
#
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
# invocations will cause a graph to fail if they are used.
#
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
#
# This same pattern is used in `GraphExecutionState`.
nodes: dict[str, BaseInvocation] = {}
typeadapter = BaseInvocation.get_typeadapter()
for node_id, node in v.items():
nodes[node_id] = typeadapter.validate_python(node)
return nodes
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
# the generated schema as options for the `nodes` field.
#
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
# expected.
#
# You might be tempted to do something like this:
#
# ```py
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
# delattr(cloned_model, "validate_nodes")
# cloned_model.model_rebuild(force=True)
# json_schema = handler(cloned_model.__pydantic_core_schema__)
# ```
#
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
#
# This same pattern is used in `GraphExecutionState`.
class Graph(BaseModel):
id: Optional[str] = Field(default=None, description="The id of this graph")
nodes: dict[
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
] = Field(description="The nodes in this graph")
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
json_schema = handler(Graph.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph
@@ -760,7 +774,7 @@ class GraphExecutionState(BaseModel):
)
# The results of executed nodes
results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
# Errors raised when executing nodes
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
@@ -777,12 +791,52 @@ class GraphExecutionState(BaseModel):
default_factory=dict,
)
@field_validator("results", mode="plain")
@classmethod
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
results: dict[str, BaseInvocationOutput] = {}
typeadapter = BaseInvocationOutput.get_typeadapter()
for result_id, result in v.items():
results[result_id] = typeadapter.validate_python(result)
return results
@field_validator("graph")
def graph_is_valid(cls, v: Graph):
"""Validates that the graph is valid"""
v.validate_self()
return v
@classmethod
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state")
graph: Graph = Field(description="The graph being executed")
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
executed: set[str] = Field(description="The set of node ids that have been executed")
executed_history: list[str] = Field(
description="The list of node ids that have been executed, in order of execution"
)
results: dict[
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
] = Field(description="The results of node executions")
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
prepared_source_mapping: dict[str, str] = Field(
description="The map of prepared nodes to original graph nodes"
)
source_prepared_mapping: dict[str, set[str]] = Field(
description="The map of original graph nodes to prepared nodes"
)
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
json_schema = handler.resolve_ref_schema(json_schema)
return json_schema
def next(self) -> Optional[BaseInvocation]:
"""Gets the next node ready to execute."""

View File

@@ -1,6 +1,7 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Optional, Union
from PIL.Image import Image
from torch import Tensor
@@ -352,11 +353,11 @@ class ModelsInterface(InvocationContextInterface):
if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type)
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
else:
_submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type)
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
@@ -381,7 +382,7 @@ class ModelsInterface(InvocationContextInterface):
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
return self._services.model_manager.load.load_model(configs[0], submodel_type)
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
"""Gets a model's config.
@@ -448,10 +449,10 @@ class ConfigInterface(InvocationContextInterface):
class UtilInterface(InvocationContextInterface):
def __init__(
self, services: InvocationServices, data: InvocationContextData, is_canceled: Callable[[], bool]
self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event
) -> None:
super().__init__(services, data)
self._is_canceled = is_canceled
self._cancel_event = cancel_event
def is_canceled(self) -> bool:
"""Checks if the current session has been canceled.
@@ -459,7 +460,7 @@ class UtilInterface(InvocationContextInterface):
Returns:
True if the current session has been canceled, False if not.
"""
return self._is_canceled()
return self._cancel_event.is_set()
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
"""
@@ -534,7 +535,7 @@ class InvocationContext:
def build_invocation_context(
services: InvocationServices,
data: InvocationContextData,
is_canceled: Callable[[], bool],
cancel_event: threading.Event,
) -> InvocationContext:
"""Builds the invocation context for a specific invocation execution.
@@ -551,7 +552,7 @@ def build_invocation_context(
tensors = TensorsInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data)
config = ConfigInterface(services=services, data=data)
util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
util = UtilInterface(services=services, data=data, cancel_event=cancel_event)
conditioning = ConditioningInterface(services=services, data=data)
boards = BoardsInterface(services=services, data=data)

View File

@@ -1,116 +0,0 @@
from typing import Any, Callable, Optional
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from pydantic.json_schema import models_json_schema
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
"""Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema
for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and
component_schema."""
defs = component_schema.pop("$defs", {})
for schema_key, json_schema in defs.items():
if schema_key in openapi_schema["components"]["schemas"]:
continue
openapi_schema["components"]["schemas"][schema_key] = json_schema
def get_openapi_func(
app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
) -> Callable[[], dict[str, Any]]:
"""Gets the OpenAPI schema generator function.
Args:
app (FastAPI): The FastAPI app to generate the schema for.
post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the
generated schema before returning it. Defaults to None.
Returns:
Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is
cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour
matches FastAPI's default schema generation caching.
"""
def openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title=app.title,
description="An API for invoking AI image operations",
version="1.0.0",
routes=app.routes,
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
)
# We'll create a map of invocation type to output schema to make some types simpler on the client.
invocation_output_map_properties: dict[str, Any] = {}
invocation_output_map_required: list[str] = []
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
for output in BaseInvocationOutput.get_outputs():
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
move_defs_to_top_level(openapi_schema, json_schema)
openapi_schema["components"]["schemas"][output.__name__] = json_schema
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
# property, so we'll just do it all manually.
for invocation in BaseInvocation.get_invocations():
json_schema = invocation.model_json_schema(
mode="serialization", ref_template="#/components/schemas/{model}"
)
move_defs_to_top_level(openapi_schema, json_schema)
output_title = invocation.get_output_annotation().__name__
outputs_ref = {"$ref": f"#/components/schemas/{output_title}"}
json_schema["output"] = outputs_ref
openapi_schema["components"]["schemas"][invocation.__name__] = json_schema
# Add this invocation and its output to the output map
invocation_type = invocation.get_type()
invocation_output_map_properties[invocation_type] = json_schema["output"]
invocation_output_map_required.append(invocation_type)
# Add the output map to the schema
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
"type": "object",
"properties": invocation_output_map_properties,
"required": invocation_output_map_required,
}
# Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API.
# We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get
# a schema. This has something to do with schema refs - not totally clear. For whatever reason, using
# `models_json_schema` seems to work fine.
additional_models = [
*EventBase.get_events(),
UIConfigBase,
InputFieldJSONSchemaExtra,
OutputFieldJSONSchemaExtra,
ModelIdentifierField,
ProgressImage,
]
additional_schemas = models_json_schema(
[(m, "serialization") for m in additional_models],
ref_template="#/components/schemas/{model}",
)
# additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
move_defs_to_top_level(openapi_schema, additional_schemas[1])
if post_transform is not None:
openapi_schema = post_transform(openapi_schema)
openapi_schema["components"]["schemas"] = dict(sorted(openapi_schema["components"]["schemas"].items()))
app.openapi_schema = openapi_schema
return app.openapi_schema
return openapi

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable
import torch
from PIL import Image
@@ -13,36 +13,8 @@ if TYPE_CHECKING:
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContextData
# fast latents preview matrix for sdxl
# generated by @StAlKeR7779
SDXL_LATENT_RGB_FACTORS = [
# R G B
[0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
]
SDXL_SMOOTH_MATRIX = [
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
]
# origingally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
SD1_5_LATENT_RGB_FACTORS = [
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
]
def sample_to_lowres_estimated_image(
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
):
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
if smooth_matrix is not None:
@@ -75,12 +47,64 @@ def stable_diffusion_step_callback(
else:
sample = intermediate_state.latents
# TODO: This does not seem to be needed any more?
# # txt2img provides a Tensor in the step_callback
# # img2img provides a PipelineIntermediateState
# if isinstance(sample, PipelineIntermediateState):
# # this was an img2img
# print('img2img')
# latents = sample.latents
# step = sample.step
# else:
# print('txt2img')
# latents = sample
# step = intermediate_state.step
# TODO: only output a preview image when requested
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
# fast latents preview matrix for sdxl
# generated by @StAlKeR7779
sdxl_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3816, 0.4930, 0.5320],
[-0.3753, 0.1631, 0.1739],
[0.1770, 0.3588, -0.2048],
[-0.4350, -0.2644, -0.4289],
],
dtype=sample.dtype,
device=sample.device,
)
sdxl_smooth_matrix = torch.tensor(
[
[0.0358, 0.0964, 0.0358],
[0.0964, 0.4711, 0.0964],
[0.0358, 0.0964, 0.0358],
],
dtype=sample.dtype,
device=sample.device,
)
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
else:
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
# origingally adapted from code by @erucipe and @keturn here:
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
# these updated numbers for v1.5 are from @torridgristle
v1_5_latent_rgb_factors = torch.tensor(
[
# R G B
[0.3444, 0.1385, 0.0670], # L1
[0.1247, 0.4027, 0.1494], # L2
[-0.3192, 0.2513, 0.2103], # L3
[-0.1307, -0.1874, -0.7445], # L4
],
dtype=sample.dtype,
device=sample.device,
)
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
(width, height) = image.size
@@ -89,9 +113,15 @@ def stable_diffusion_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
events.emit_invocation_denoise_progress(
context_data.queue_item,
context_data.invocation,
intermediate_state,
ProgressImage(dataURL=dataURL, width=width, height=height),
events.emit_generator_progress(
queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.queue_item.session_id,
node_id=context_data.invocation.id,
source_node_id=context_data.source_invocation_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step,
order=intermediate_state.order,
total_steps=intermediate_state.total_steps,
)

View File

@@ -42,26 +42,10 @@ T = TypeVar("T")
@dataclass
class CacheRecord(Generic[T]):
"""
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
"""
"""Elements of the cache."""
key: str
model: T
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
loaded: bool = False
_locks: int = 0

View File

@@ -20,6 +20,7 @@ context. Use like this:
import gc
import math
import sys
import time
from contextlib import suppress
from logging import Logger
@@ -161,9 +162,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
if key in self._cached_models:
return
self.make_room(size)
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
cache_record = CacheRecord(key, model, size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@@ -258,37 +257,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
return
source_device = cache_entry.device
source_device = cache_entry.model.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
# This roundabout method for moving the model around is done to avoid
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
# When moving to VRAM, we copy (not move) each element of the state dict from
# RAM to a new state dict in VRAM, and then inject it into the model.
# This operation is slightly faster than running `to()` on the whole model.
#
# When the model needs to be removed from VRAM we simply delete the copy
# of the state dict in VRAM, and reinject the state dict that is cached
# in RAM into the model. So this operation is very fast.
start_model_to_time = time.time()
snapshot_before = self._capture_memory_snapshot()
try:
if cache_entry.state_dict is not None:
assert hasattr(cache_entry.model, "load_state_dict")
if target_device == self.storage_device:
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
else:
new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(torch.device(target_device), copy=True)
cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device)
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)
raise e
@@ -368,12 +347,43 @@ class ModelCache(ModelCacheBase[AnyModel]):
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
refs = sys.getrefcount(cache_entry.model)
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
# https://docs.python.org/3/library/gc.html#gc.get_referrers
# manualy clear local variable references of just finished function calls
# for some reason python don't want to collect it even by gc.collect() immidiately
if refs > 2:
while True:
cleared = False
for referrer in gc.get_referrers(cache_entry.model):
if type(referrer).__name__ == "frame":
# RuntimeError: cannot clear an executing frame
with suppress(RuntimeError):
referrer.clear()
cleared = True
# break
# repeat if referrers changes(due to frame clear), else exit loop
if cleared:
gc.collect()
else:
break
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
self.logger.debug(
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
f" refs: {refs}"
)
if not cache_entry.locked:
# Expected refs:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug(
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)

View File

@@ -60,5 +60,5 @@ class ModelLocker(ModelLockerBase):
self._cache_entry.unlock()
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(0)
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.print_cuda_stats()

View File

@@ -1,7 +1,7 @@
"""Textual Inversion wrapper class."""
from pathlib import Path
from typing import Optional, Union
from typing import Dict, List, Optional, Union
import torch
from compel.embeddings_provider import BaseTextualInversionManager
@@ -66,52 +66,35 @@ class TextualInversionModelRaw(RawModel):
return result
class TextualInversionManager(BaseTextualInversionManager):
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
# no type hints for BaseTextualInversionManager?
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens: dict[int, list[int]] = {}
self.pad_tokens = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
"""Given a list of tokens ids, expand any TI tokens to their corresponding pad tokens.
For example, suppose we have a `<ti_dog>` TI with 4 vectors that was added to the tokenizer with the following
mapping of tokens to token_ids:
```
<ti_dog>: 49408
<ti_dog-!pad-1>: 49409
<ti_dog-!pad-2>: 49410
<ti_dog-!pad-3>: 49411
```
`self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`.
This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`.
"""
# Short circuit if there are no pad tokens to save a little time.
if len(self.pad_tokens) == 0:
return token_ids
# This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify
# this assumption here.
if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id")
# Expand any TI tokens to their corresponding pad tokens.
new_token_ids: list[int] = []
new_token_ids = []
for token_id in token_ids:
new_token_ids.append(token_id)
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size. The -2 here is compensating for
# compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens.
max_length = self.tokenizer.model_max_length - 2
# Do not exceed the max model input size
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
# which first removes and then adds back the start and end tokens.
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
if len(new_token_ids) > max_length:
# HACK: If TI token expansion causes us to exceed the max text encoder input length, we silently discard
# tokens. Token expansion should happen in a way that is compatible with compel's default handling of long
# prompts.
new_token_ids = new_token_ids[0:max_length]
return new_token_ids

View File

@@ -1021,8 +1021,7 @@
"float": "Kommazahlen",
"enum": "Aufzählung",
"fullyContainNodes": "Vollständig ausgewählte Nodes auswählen",
"editMode": "Im Workflow-Editor bearbeiten",
"resetToDefaultValue": "Auf Standardwert zurücksetzen"
"editMode": "Im Workflow-Editor bearbeiten"
},
"hrf": {
"enableHrf": "Korrektur für hohe Auflösungen",

View File

@@ -148,8 +148,6 @@
"viewingDesc": "Review images in a large gallery view",
"editing": "Editing",
"editingDesc": "Edit on the Control Layers canvas",
"comparing": "Comparing",
"comparingDesc": "Comparing two images",
"enabled": "Enabled",
"disabled": "Disabled"
},
@@ -377,23 +375,7 @@
"bulkDownloadRequestFailed": "Problem Preparing Download",
"bulkDownloadFailed": "Download Failed",
"problemDeletingImages": "Problem Deleting Images",
"problemDeletingImagesDesc": "One or more images could not be deleted",
"viewerImage": "Viewer Image",
"compareImage": "Compare Image",
"openInViewer": "Open in Viewer",
"selectForCompare": "Select for Compare",
"selectAnImageToCompare": "Select an Image to Compare",
"slider": "Slider",
"sideBySide": "Side-by-Side",
"hover": "Hover",
"swapImages": "Swap Images",
"compareOptions": "Comparison Options",
"stretchToFit": "Stretch to Fit",
"exitCompare": "Exit Compare",
"compareHelp1": "Hold <Kbd>Alt</Kbd> while clicking a gallery image or using the arrow keys to change the compare image.",
"compareHelp2": "Press <Kbd>M</Kbd> to cycle through comparison modes.",
"compareHelp3": "Press <Kbd>C</Kbd> to swap the compared images.",
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit."
"problemDeletingImagesDesc": "One or more images could not be deleted"
},
"hotkeys": {
"searchHotkeys": "Search Hotkeys",
@@ -1122,7 +1104,7 @@
"parameters": "Parameters",
"parameterSet": "Parameter Recalled",
"parameterSetDesc": "Recalled {{parameter}}",
"parameterNotSet": "Parameter Not Recalled",
"parameterNotSet": "Parameter Recalled",
"parameterNotSetDesc": "Unable to recall {{parameter}}",
"parameterNotSetDescWithMessage": "Unable to recall {{parameter}}: {{message}}",
"parametersSet": "Parameters Recalled",

View File

@@ -6,7 +6,7 @@
"settingsLabel": "Ajustes",
"img2img": "Imagen a Imagen",
"unifiedCanvas": "Lienzo Unificado",
"nodes": "Flujos de trabajo",
"nodes": "Editor del flujo de trabajo",
"upload": "Subir imagen",
"load": "Cargar",
"statusDisconnected": "Desconectado",
@@ -14,7 +14,7 @@
"discordLabel": "Discord",
"back": "Atrás",
"loading": "Cargando",
"postprocessing": "Postprocesado",
"postprocessing": "Tratamiento posterior",
"txt2img": "De texto a imagen",
"accept": "Aceptar",
"cancel": "Cancelar",
@@ -42,42 +42,7 @@
"copy": "Copiar",
"beta": "Beta",
"on": "En",
"aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:",
"installed": "Instalado",
"green": "Verde",
"editor": "Editor",
"orderBy": "Ordenar por",
"file": "Archivo",
"goTo": "Ir a",
"imageFailedToLoad": "No se puede cargar la imagen",
"saveAs": "Guardar Como",
"somethingWentWrong": "Algo salió mal",
"nextPage": "Página Siguiente",
"selected": "Seleccionado",
"tab": "Tabulador",
"positivePrompt": "Prompt Positivo",
"negativePrompt": "Prompt Negativo",
"error": "Error",
"format": "formato",
"unknown": "Desconocido",
"input": "Entrada",
"nodeEditor": "Editor de nodos",
"template": "Plantilla",
"prevPage": "Página Anterior",
"red": "Rojo",
"alpha": "Transparencia",
"outputs": "Salidas",
"editing": "Editando",
"learnMore": "Aprende más",
"enabled": "Activado",
"disabled": "Desactivado",
"folder": "Carpeta",
"updated": "Actualizado",
"created": "Creado",
"save": "Guardar",
"unknownError": "Error Desconocido",
"blue": "Azul",
"viewingDesc": "Revisar imágenes en una vista de galería grande"
"aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:"
},
"gallery": {
"galleryImageSize": "Tamaño de la imagen",
@@ -502,8 +467,7 @@
"about": "Acerca de",
"createIssue": "Crear un problema",
"resetUI": "Interfaz de usuario $t(accessibility.reset)",
"mode": "Modo",
"submitSupportTicket": "Enviar Ticket de Soporte"
"mode": "Modo"
},
"nodes": {
"zoomInNodes": "Acercar",
@@ -579,17 +543,5 @@
"layers_one": "Capa",
"layers_many": "Capas",
"layers_other": "Capas"
},
"controlnet": {
"crop": "Cortar",
"delete": "Eliminar",
"depthAnythingDescription": "Generación de mapa de profundidad usando la técnica de Depth Anything",
"duplicate": "Duplicar",
"colorMapDescription": "Genera un mapa de color desde la imagen",
"depthMidasDescription": "Crea un mapa de profundidad con Midas",
"balanced": "Equilibrado",
"beginEndStepPercent": "Inicio / Final Porcentaje de pasos",
"detectResolution": "Detectar resolución",
"beginEndStepPercentShort": "Inicio / Final %"
}
}

View File

@@ -45,7 +45,7 @@
"outputs": "Risultati",
"data": "Dati",
"somethingWentWrong": "Qualcosa è andato storto",
"copyError": "Errore $t(gallery.copy)",
"copyError": "$t(gallery.copy) Errore",
"input": "Ingresso",
"notInstalled": "Non $t(common.installed)",
"unknownError": "Errore sconosciuto",
@@ -85,11 +85,7 @@
"viewing": "Visualizza",
"viewingDesc": "Rivedi le immagini in un'ampia vista della galleria",
"editing": "Modifica",
"editingDesc": "Modifica nell'area Livelli di controllo",
"enabled": "Abilitato",
"disabled": "Disabilitato",
"comparingDesc": "Confronta due immagini",
"comparing": "Confronta"
"editingDesc": "Modifica nell'area Livelli di controllo"
},
"gallery": {
"galleryImageSize": "Dimensione dell'immagine",
@@ -126,30 +122,14 @@
"bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. L'operazione potrebbe richiedere alcuni istanti.",
"bulkDownloadRequestFailed": "Problema durante la preparazione del download",
"bulkDownloadFailed": "Scaricamento fallito",
"alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine",
"openInViewer": "Apri nel visualizzatore",
"selectForCompare": "Seleziona per il confronto",
"selectAnImageToCompare": "Seleziona un'immagine da confrontare",
"slider": "Cursore",
"sideBySide": "Fianco a Fianco",
"compareImage": "Immagine di confronto",
"viewerImage": "Immagine visualizzata",
"hover": "Al passaggio del mouse",
"swapImages": "Scambia le immagini",
"compareOptions": "Opzioni di confronto",
"stretchToFit": "Scala per adattare",
"exitCompare": "Esci dal confronto",
"compareHelp1": "Tieni premuto <Kbd>Alt</Kbd> mentre fai clic su un'immagine della galleria o usi i tasti freccia per cambiare l'immagine di confronto.",
"compareHelp2": "Premi <Kbd>M</Kbd> per scorrere le modalità di confronto.",
"compareHelp3": "Premi <Kbd>C</Kbd> per scambiare le immagini confrontate.",
"compareHelp4": "Premi <Kbd>Z</Kbd> o <Kbd>Esc</Kbd> per uscire."
"alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine"
},
"hotkeys": {
"keyboardShortcuts": "Tasti di scelta rapida",
"appHotkeys": "Applicazione",
"generalHotkeys": "Generale",
"galleryHotkeys": "Galleria",
"unifiedCanvasHotkeys": "Tela",
"unifiedCanvasHotkeys": "Tela Unificata",
"invoke": {
"title": "Invoke",
"desc": "Genera un'immagine"
@@ -167,8 +147,8 @@
"desc": "Apre e chiude il pannello delle opzioni"
},
"pinOptions": {
"title": "Fissa le opzioni",
"desc": "Fissa il pannello delle opzioni"
"title": "Appunta le opzioni",
"desc": "Blocca il pannello delle opzioni"
},
"toggleGallery": {
"title": "Attiva/disattiva galleria",
@@ -352,14 +332,14 @@
"title": "Annulla e cancella"
},
"resetOptionsAndGallery": {
"title": "Ripristina le opzioni e la galleria",
"desc": "Reimposta i pannelli delle opzioni e della galleria"
"title": "Ripristina Opzioni e Galleria",
"desc": "Reimposta le opzioni e i pannelli della galleria"
},
"searchHotkeys": "Cerca tasti di scelta rapida",
"noHotkeysFound": "Nessun tasto di scelta rapida trovato",
"toggleOptionsAndGallery": {
"desc": "Apre e chiude le opzioni e i pannelli della galleria",
"title": "Attiva/disattiva le opzioni e la galleria"
"title": "Attiva/disattiva le Opzioni e la Galleria"
},
"clearSearch": "Cancella ricerca",
"remixImage": {
@@ -368,7 +348,7 @@
},
"toggleViewer": {
"title": "Attiva/disattiva il visualizzatore di immagini",
"desc": "Passa dal visualizzatore immagini all'area di lavoro per la scheda corrente."
"desc": "Passa dal Visualizzatore immagini all'area di lavoro per la scheda corrente."
}
},
"modelManager": {
@@ -398,7 +378,7 @@
"convertToDiffusers": "Converti in Diffusori",
"convertToDiffusersHelpText2": "Questo processo sostituirà la voce in Gestione Modelli con la versione Diffusori dello stesso modello.",
"convertToDiffusersHelpText4": "Questo è un processo una tantum. Potrebbero essere necessari circa 30-60 secondi a seconda delle specifiche del tuo computer.",
"convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB in dimensione.",
"convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB di dimensioni.",
"convertToDiffusersHelpText6": "Vuoi convertire questo modello?",
"modelConverted": "Modello convertito",
"alpha": "Alpha",
@@ -548,7 +528,7 @@
"layer": {
"initialImageNoImageSelected": "Nessuna immagine iniziale selezionata",
"t2iAdapterIncompatibleDimensions": "L'adattatore T2I richiede che la dimensione dell'immagine sia un multiplo di {{multiple}}",
"controlAdapterNoModelSelected": "Nessun modello di adattatore di controllo selezionato",
"controlAdapterNoModelSelected": "Nessun modello di Adattatore di Controllo selezionato",
"controlAdapterIncompatibleBaseModel": "Il modello base dell'adattatore di controllo non è compatibile",
"controlAdapterNoImageSelected": "Nessuna immagine dell'adattatore di controllo selezionata",
"controlAdapterImageNotProcessed": "Immagine dell'adattatore di controllo non elaborata",
@@ -626,25 +606,25 @@
"canvasMerged": "Tela unita",
"sentToImageToImage": "Inviato a Generazione da immagine",
"sentToUnifiedCanvas": "Inviato alla Tela",
"parametersNotSet": "Parametri non richiamati",
"parametersNotSet": "Parametri non impostati",
"metadataLoadFailed": "Impossibile caricare i metadati",
"serverError": "Errore del Server",
"connected": "Connesso al server",
"connected": "Connesso al Server",
"canceled": "Elaborazione annullata",
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
"parameterSet": "Parametro richiamato",
"parameterNotSet": "Parametro non richiamato",
"parameterSet": "{{parameter}} impostato",
"parameterNotSet": "{{parameter}} non impostato",
"problemCopyingImage": "Impossibile copiare l'immagine",
"baseModelChangedCleared_one": "Cancellato o disabilitato {{count}} sottomodello incompatibile",
"baseModelChangedCleared_many": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
"baseModelChangedCleared_other": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
"baseModelChangedCleared_one": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modello incompatibile",
"baseModelChangedCleared_many": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili",
"baseModelChangedCleared_other": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili",
"imageSavingFailed": "Salvataggio dell'immagine non riuscito",
"canvasSentControlnetAssets": "Tela inviata a ControlNet & Risorse",
"problemCopyingCanvasDesc": "Impossibile copiare la tela",
"loadedWithWarnings": "Flusso di lavoro caricato con avvisi",
"canvasCopiedClipboard": "Tela copiata negli appunti",
"maskSavedAssets": "Maschera salvata nelle risorse",
"problemDownloadingCanvas": "Problema durante lo scarico della tela",
"problemDownloadingCanvas": "Problema durante il download della tela",
"problemMergingCanvas": "Problema nell'unione delle tele",
"imageUploaded": "Immagine caricata",
"addedToBoard": "Aggiunto alla bacheca",
@@ -678,17 +658,7 @@
"problemDownloadingImage": "Impossibile scaricare l'immagine",
"prunedQueue": "Coda ripulita",
"modelImportCanceled": "Importazione del modello annullata",
"parameters": "Parametri",
"parameterSetDesc": "{{parameter}} richiamato",
"parameterNotSetDesc": "Impossibile richiamare {{parameter}}",
"parameterNotSetDescWithMessage": "Impossibile richiamare {{parameter}}: {{message}}",
"parametersSet": "Parametri richiamati",
"errorCopied": "Errore copiato",
"outOfMemoryError": "Errore di memoria esaurita",
"baseModelChanged": "Modello base modificato",
"sessionRef": "Sessione: {{sessionId}}",
"somethingWentWrong": "Qualcosa è andato storto",
"outOfMemoryErrorDesc": "Le impostazioni della generazione attuale superano la capacità del sistema. Modifica le impostazioni e riprova."
"parameters": "Parametri"
},
"tooltip": {
"feature": {
@@ -704,7 +674,7 @@
"layer": "Livello",
"base": "Base",
"mask": "Maschera",
"maskingOptions": "Opzioni maschera",
"maskingOptions": "Opzioni di mascheramento",
"enableMask": "Abilita maschera",
"preserveMaskedArea": "Mantieni area mascherata",
"clearMask": "Cancella maschera (Shift+C)",
@@ -775,8 +745,7 @@
"mode": "Modalità",
"resetUI": "$t(accessibility.reset) l'Interfaccia Utente",
"createIssue": "Segnala un problema",
"about": "Informazioni",
"submitSupportTicket": "Invia ticket di supporto"
"about": "Informazioni"
},
"nodes": {
"zoomOutNodes": "Rimpicciolire",
@@ -821,7 +790,7 @@
"workflowNotes": "Note",
"versionUnknown": " Versione sconosciuta",
"unableToValidateWorkflow": "Impossibile convalidare il flusso di lavoro",
"updateApp": "Aggiorna Applicazione",
"updateApp": "Aggiorna App",
"unableToLoadWorkflow": "Impossibile caricare il flusso di lavoro",
"updateNode": "Aggiorna nodo",
"version": "Versione",
@@ -913,14 +882,11 @@
"missingNode": "Nodo di invocazione mancante",
"missingInvocationTemplate": "Modello di invocazione mancante",
"missingFieldTemplate": "Modello di campo mancante",
"singleFieldType": "{{name}} (Singola)",
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino delle impostazioni predefinite",
"boardAccessError": "Impossibile trovare la bacheca {{board_id}}, ripristino ai valori predefiniti",
"modelAccessError": "Impossibile trovare il modello {{key}}, ripristino ai valori predefiniti"
"singleFieldType": "{{name}} (Singola)"
},
"boards": {
"autoAddBoard": "Aggiungi automaticamente bacheca",
"menuItemAutoAdd": "Aggiungi automaticamente a questa bacheca",
"menuItemAutoAdd": "Aggiungi automaticamente a questa Bacheca",
"cancel": "Annulla",
"addBoard": "Aggiungi Bacheca",
"bottomMessage": "L'eliminazione di questa bacheca e delle sue immagini ripristinerà tutte le funzionalità che le stanno attualmente utilizzando.",
@@ -932,7 +898,7 @@
"myBoard": "Bacheca",
"searchBoard": "Cerca bacheche ...",
"noMatching": "Nessuna bacheca corrispondente",
"selectBoard": "Seleziona una bacheca",
"selectBoard": "Seleziona una Bacheca",
"uncategorized": "Non categorizzato",
"downloadBoard": "Scarica la bacheca",
"deleteBoardOnly": "solo la Bacheca",
@@ -953,7 +919,7 @@
"control": "Controllo",
"crop": "Ritaglia",
"depthMidas": "Profondità (Midas)",
"detectResolution": "Rileva la risoluzione",
"detectResolution": "Rileva risoluzione",
"controlMode": "Modalità di controllo",
"cannyDescription": "Canny rilevamento bordi",
"depthZoe": "Profondità (Zoe)",
@@ -964,7 +930,7 @@
"showAdvanced": "Mostra opzioni Avanzate",
"bgth": "Soglia rimozione sfondo",
"importImageFromCanvas": "Importa immagine dalla Tela",
"lineartDescription": "Converte l'immagine in linea",
"lineartDescription": "Converte l'immagine in lineart",
"importMaskFromCanvas": "Importa maschera dalla Tela",
"hideAdvanced": "Nascondi opzioni avanzate",
"resetControlImage": "Reimposta immagine di controllo",
@@ -980,7 +946,7 @@
"pidiDescription": "Elaborazione immagini PIDI",
"fill": "Riempie",
"colorMapDescription": "Genera una mappa dei colori dall'immagine",
"lineartAnimeDescription": "Elaborazione linea in stile anime",
"lineartAnimeDescription": "Elaborazione lineart in stile anime",
"imageResolution": "Risoluzione dell'immagine",
"colorMap": "Colore",
"lowThreshold": "Soglia inferiore",

View File

@@ -87,11 +87,7 @@
"viewing": "Просмотр",
"editing": "Редактирование",
"viewingDesc": "Просмотр изображений в режиме большой галереи",
"editingDesc": "Редактировать на холсте слоёв управления",
"enabled": "Включено",
"disabled": "Отключено",
"comparingDesc": "Сравнение двух изображений",
"comparing": "Сравнение"
"editingDesc": "Редактировать на холсте слоёв управления"
},
"gallery": {
"galleryImageSize": "Размер изображений",
@@ -128,23 +124,7 @@
"bulkDownloadRequested": "Подготовка к скачиванию",
"bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.",
"bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания",
"alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения",
"openInViewer": "Открыть в просмотрщике",
"selectForCompare": "Выбрать для сравнения",
"hover": "Наведение",
"swapImages": "Поменять местами",
"stretchToFit": "Растягивание до нужного размера",
"exitCompare": "Выйти из сравнения",
"compareHelp4": "Нажмите <Kbd>Z</Kbd> или <Kbd>Esc</Kbd> для выхода.",
"compareImage": "Сравнить изображение",
"viewerImage": "Изображение просмотрщика",
"selectAnImageToCompare": "Выберите изображение для сравнения",
"slider": "Слайдер",
"sideBySide": "Бок о бок",
"compareOptions": "Варианты сравнения",
"compareHelp1": "Удерживайте <Kbd>Alt</Kbd> при нажатии на изображение в галерее или при помощи клавиш со стрелками, чтобы изменить сравниваемое изображение.",
"compareHelp2": "Нажмите <Kbd>M</Kbd>, чтобы переключиться между режимами сравнения.",
"compareHelp3": "Нажмите <Kbd>C</Kbd>, чтобы поменять местами сравниваемые изображения."
"alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения"
},
"hotkeys": {
"keyboardShortcuts": "Горячие клавиши",
@@ -548,20 +528,7 @@
"missingFieldTemplate": "Отсутствует шаблон поля",
"addingImagesTo": "Добавление изображений в",
"invoke": "Создать",
"imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается",
"layer": {
"controlAdapterImageNotProcessed": "Изображение адаптера контроля не обработано",
"ipAdapterNoModelSelected": "IP адаптер не выбран",
"controlAdapterNoModelSelected": "не выбрана модель адаптера контроля",
"controlAdapterIncompatibleBaseModel": "несовместимая базовая модель адаптера контроля",
"controlAdapterNoImageSelected": "не выбрано изображение контрольного адаптера",
"initialImageNoImageSelected": "начальное изображение не выбрано",
"rgNoRegion": "регион не выбран",
"rgNoPromptsOrIPAdapters": "нет текстовых запросов или IP-адаптеров",
"ipAdapterIncompatibleBaseModel": "несовместимая базовая модель IP-адаптера",
"t2iAdapterIncompatibleDimensions": "Адаптер T2I требует, чтобы размеры изображения были кратны {{multiple}}",
"ipAdapterNoImageSelected": "изображение IP-адаптера не выбрано"
}
"imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается"
},
"isAllowedToUpscale": {
"useX2Model": "Изображение слишком велико для увеличения с помощью модели x4. Используйте модель x2",
@@ -639,12 +606,12 @@
"connected": "Подключено к серверу",
"canceled": "Обработка отменена",
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
"parameterNotSet": "Параметр не задан",
"parameterSet": "Параметр задан",
"parameterNotSet": "Параметр {{parameter}} не задан",
"parameterSet": "Параметр {{parameter}} задан",
"problemCopyingImage": "Не удается скопировать изображение",
"baseModelChangedCleared_one": "Очищена или отключена {{count}} несовместимая подмодель",
"baseModelChangedCleared_few": "Очищены или отключены {{count}} несовместимые подмодели",
"baseModelChangedCleared_many": "Очищены или отключены {{count}} несовместимых подмоделей",
"baseModelChangedCleared_one": "Базовая модель изменила, очистила или отключила {{count}} несовместимую подмодель",
"baseModelChangedCleared_few": "Базовая модель изменила, очистила или отключила {{count}} несовместимые подмодели",
"baseModelChangedCleared_many": "Базовая модель изменила, очистила или отключила {{count}} несовместимых подмоделей",
"imageSavingFailed": "Не удалось сохранить изображение",
"canvasSentControlnetAssets": "Холст отправлен в ControlNet и ресурсы",
"problemCopyingCanvasDesc": "Невозможно экспортировать базовый слой",
@@ -685,17 +652,7 @@
"resetInitialImage": "Сбросить начальное изображение",
"prunedQueue": "Урезанная очередь",
"modelImportCanceled": "Импорт модели отменен",
"parameters": "Параметры",
"parameterSetDesc": "Задан {{parameter}}",
"parameterNotSetDesc": "Невозможно задать {{parameter}}",
"baseModelChanged": "Базовая модель сменена",
"parameterNotSetDescWithMessage": "Не удалось задать {{parameter}}: {{message}}",
"parametersSet": "Параметры заданы",
"errorCopied": "Ошибка скопирована",
"sessionRef": "Сессия: {{sessionId}}",
"outOfMemoryError": "Ошибка нехватки памяти",
"outOfMemoryErrorDesc": "Ваши текущие настройки генерации превышают возможности системы. Пожалуйста, измените настройки и повторите попытку.",
"somethingWentWrong": "Что-то пошло не так"
"parameters": "Параметры"
},
"tooltip": {
"feature": {
@@ -782,8 +739,7 @@
"loadMore": "Загрузить больше",
"resetUI": "$t(accessibility.reset) интерфейс",
"createIssue": "Сообщить о проблеме",
"about": "Об этом",
"submitSupportTicket": "Отправить тикет в службу поддержки"
"about": "Об этом"
},
"nodes": {
"zoomInNodes": "Увеличьте масштаб",
@@ -876,7 +832,7 @@
"workflowName": "Название",
"collection": "Коллекция",
"unknownErrorValidatingWorkflow": "Неизвестная ошибка при проверке рабочего процесса",
"collectionFieldType": "{{name}} (Коллекция)",
"collectionFieldType": "Коллекция {{name}}",
"workflowNotes": "Примечания",
"string": "Строка",
"unknownNodeType": "Неизвестный тип узла",
@@ -892,7 +848,7 @@
"targetNodeDoesNotExist": "Недопустимое ребро: целевой/входной узел {{node}} не существует",
"mismatchedVersion": "Недопустимый узел: узел {{node}} типа {{type}} имеет несоответствующую версию (попробовать обновить?)",
"unknownFieldType": "$t(nodes.unknownField) тип: {{type}}",
"collectionOrScalarFieldType": "{{name}} (Один или коллекция)",
"collectionOrScalarFieldType": "Коллекция | Скаляр {{name}}",
"betaDesc": "Этот вызов находится в бета-версии. Пока он не станет стабильным, в нем могут происходить изменения при обновлении приложений. Мы планируем поддерживать этот вызов в течение длительного времени.",
"nodeVersion": "Версия узла",
"loadingNodes": "Загрузка узлов...",
@@ -914,16 +870,7 @@
"noFieldsViewMode": "В этом рабочем процессе нет выбранных полей для отображения. Просмотрите полный рабочий процесс для настройки значений.",
"graph": "График",
"showEdgeLabels": "Показать метки на ребрах",
"showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы",
"cannotMixAndMatchCollectionItemTypes": "Невозможно смешивать и сопоставлять типы элементов коллекции",
"missingNode": "Отсутствует узел вызова",
"missingInvocationTemplate": "Отсутствует шаблон вызова",
"missingFieldTemplate": "Отсутствующий шаблон поля",
"singleFieldType": "{{name}} (Один)",
"noGraph": "Нет графика",
"imageAccessError": "Невозможно найти изображение {{image_name}}, сбрасываем на значение по умолчанию",
"boardAccessError": "Невозможно найти доску {{board_id}}, сбрасываем на значение по умолчанию",
"modelAccessError": "Невозможно найти модель {{key}}, сброс на модель по умолчанию"
"showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы"
},
"controlnet": {
"amult": "a_mult",
@@ -1494,16 +1441,7 @@
"clearQueueAlertDialog2": "Вы уверены, что хотите очистить очередь?",
"item": "Элемент",
"graphFailedToQueue": "Не удалось поставить график в очередь",
"openQueue": "Открыть очередь",
"prompts_one": "Запрос",
"prompts_few": "Запроса",
"prompts_many": "Запросов",
"iterations_one": "Итерация",
"iterations_few": "Итерации",
"iterations_many": "Итераций",
"generations_one": "Генерация",
"generations_few": "Генерации",
"generations_many": "Генераций"
"openQueue": "Открыть очередь"
},
"sdxl": {
"refinerStart": "Запуск доработчика",

View File

@@ -1,6 +1,6 @@
{
"common": {
"nodes": "工作流程",
"nodes": "節點",
"img2img": "圖片轉圖片",
"statusDisconnected": "已中斷連線",
"back": "返回",
@@ -11,239 +11,17 @@
"reportBugLabel": "回報錯誤",
"githubLabel": "GitHub",
"hotkeysLabel": "快捷鍵",
"languagePickerLabel": "語言",
"languagePickerLabel": "切換語言",
"unifiedCanvas": "統一畫布",
"cancel": "取消",
"txt2img": "文字轉圖片",
"controlNet": "ControlNet",
"advanced": "進階",
"folder": "資料夾",
"installed": "已安裝",
"accept": "接受",
"goTo": "前往",
"input": "輸入",
"random": "隨機",
"selected": "已選擇",
"communityLabel": "社群",
"loading": "載入中",
"delete": "刪除",
"copy": "複製",
"error": "錯誤",
"file": "檔案",
"format": "格式",
"imageFailedToLoad": "無法載入圖片"
"txt2img": "文字轉圖片"
},
"accessibility": {
"invokeProgressBar": "Invoke 進度條",
"uploadImage": "上傳圖片",
"reset": "重",
"reset": "重",
"nextImage": "下一張圖片",
"previousImage": "上一張圖片",
"menu": "選單",
"loadMore": "載入更多",
"about": "關於",
"createIssue": "建立問題",
"resetUI": "$t(accessibility.reset) 介面",
"submitSupportTicket": "提交支援工單",
"mode": "模式"
},
"boards": {
"loading": "載入中…",
"movingImagesToBoard_other": "正在移動 {{count}} 張圖片至板上:",
"move": "移動",
"uncategorized": "未分類",
"cancel": "取消"
},
"metadata": {
"workflow": "工作流程",
"steps": "步數",
"model": "模型",
"seed": "種子",
"vae": "VAE",
"seamless": "無縫",
"metadata": "元數據",
"width": "寬度",
"height": "高度"
},
"accordions": {
"control": {
"title": "控制"
},
"compositing": {
"title": "合成"
},
"advanced": {
"title": "進階",
"options": "$t(accordions.advanced.title) 選項"
}
},
"hotkeys": {
"nodesHotkeys": "節點",
"cancel": {
"title": "取消"
},
"generalHotkeys": "一般",
"keyboardShortcuts": "快捷鍵",
"appHotkeys": "應用程式"
},
"modelManager": {
"advanced": "進階",
"allModels": "全部模型",
"variant": "變體",
"config": "配置",
"model": "模型",
"selected": "已選擇",
"huggingFace": "HuggingFace",
"install": "安裝",
"metadata": "元數據",
"delete": "刪除",
"description": "描述",
"cancel": "取消",
"convert": "轉換",
"manual": "手動",
"none": "無",
"name": "名稱",
"load": "載入",
"height": "高度",
"width": "寬度",
"search": "搜尋",
"vae": "VAE",
"settings": "設定"
},
"controlnet": {
"mlsd": "M-LSD",
"canny": "Canny",
"duplicate": "重複",
"none": "無",
"pidi": "PIDI",
"h": "H",
"balanced": "平衡",
"crop": "裁切",
"processor": "處理器",
"control": "控制",
"f": "F",
"lineart": "線條藝術",
"w": "W",
"hed": "HED",
"delete": "刪除"
},
"queue": {
"queue": "佇列",
"canceled": "已取消",
"failed": "已失敗",
"completed": "已完成",
"cancel": "取消",
"session": "工作階段",
"batch": "批量",
"item": "項目",
"completedIn": "完成於",
"notReady": "無法排隊"
},
"parameters": {
"cancel": {
"cancel": "取消"
},
"height": "高度",
"type": "類型",
"symmetry": "對稱性",
"images": "圖片",
"width": "寬度",
"coherenceMode": "模式",
"seed": "種子",
"general": "一般",
"strength": "強度",
"steps": "步數",
"info": "資訊"
},
"settings": {
"beta": "Beta",
"developer": "開發者",
"general": "一般",
"models": "模型"
},
"popovers": {
"paramModel": {
"heading": "模型"
},
"compositingCoherenceMode": {
"heading": "模式"
},
"paramSteps": {
"heading": "步數"
},
"controlNetProcessor": {
"heading": "處理器"
},
"paramVAE": {
"heading": "VAE"
},
"paramHeight": {
"heading": "高度"
},
"paramSeed": {
"heading": "種子"
},
"paramWidth": {
"heading": "寬度"
},
"refinerSteps": {
"heading": "步數"
}
},
"unifiedCanvas": {
"undo": "復原",
"mask": "遮罩",
"eraser": "橡皮擦",
"antialiasing": "抗鋸齒",
"redo": "重做",
"layer": "圖層",
"accept": "接受",
"brush": "刷子",
"move": "移動",
"brushSize": "大小"
},
"nodes": {
"workflowName": "名稱",
"notes": "註釋",
"workflowVersion": "版本",
"workflowNotes": "註釋",
"executionStateError": "錯誤",
"unableToUpdateNodes_other": "無法更新 {{count}} 個節點",
"integer": "整數",
"workflow": "工作流程",
"enum": "枚舉",
"edit": "編輯",
"string": "字串",
"workflowTags": "標籤",
"node": "節點",
"boolean": "布林值",
"workflowAuthor": "作者",
"version": "版本",
"executionStateCompleted": "已完成",
"edge": "邊緣",
"versionUnknown": " 版本未知"
},
"sdxl": {
"steps": "步數",
"loading": "載入中…",
"refiner": "精煉器"
},
"gallery": {
"copy": "複製",
"download": "下載",
"loading": "載入中"
},
"ui": {
"tabs": {
"models": "模型",
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
"queue": "佇列"
}
},
"models": {
"loading": "載入中"
},
"workflows": {
"name": "名稱"
"menu": "選單"
}
}

View File

@@ -19,13 +19,6 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
return extendTheme({
..._theme,
direction,
shadows: {
..._theme.shadows,
selectedForCompare:
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-400)',
hoverSelectedForCompare:
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-300)',
},
});
}, [direction]);

View File

@@ -6,8 +6,8 @@ import { useAppDispatch } from 'app/store/storeHooks';
import type { MapStore } from 'nanostores';
import { atom, map } from 'nanostores';
import { useEffect, useMemo } from 'react';
import { setEventListeners } from 'services/events/setEventListeners';
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
import { setEventListeners } from 'services/events/util/setEventListeners';
import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client';
import { io } from 'socket.io-client';

View File

@@ -35,22 +35,26 @@ import { addImageUploadedFulfilledListener } from 'app/store/middleware/listener
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
import { addGraphExecutionStateCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete';
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
import { addSocketSubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed';
import { addSocketUnsubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed';
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
import type { AppDispatch, RootState } from 'app/store/store';
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
export const listenerMiddleware = createListenerMiddleware();
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
@@ -98,11 +102,14 @@ addCommitStagingAreaImageListener(startAppListening);
// Socket.IO
addGeneratorProgressEventListener(startAppListening);
addGraphExecutionStateCompleteEventListener(startAppListening);
addInvocationCompleteEventListener(startAppListening);
addInvocationErrorEventListener(startAppListening);
addInvocationStartedEventListener(startAppListening);
addSocketConnectedEventListener(startAppListening);
addSocketDisconnectedEventListener(startAppListening);
addSocketSubscribedEventListener(startAppListening);
addSocketUnsubscribedEventListener(startAppListening);
addModelLoadEventListener(startAppListening);
addModelInstallEventListener(startAppListening);
addSocketQueueItemStatusChangedEventListener(startAppListening);

View File

@@ -5,8 +5,8 @@ import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
import {
socketBulkDownloadComplete,
socketBulkDownloadError,
socketBulkDownloadCompleted,
socketBulkDownloadFailed,
socketBulkDownloadStarted,
} from 'services/events/actions';
@@ -54,7 +54,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
});
startAppListening({
actionCreator: socketBulkDownloadComplete,
actionCreator: socketBulkDownloadCompleted,
effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation completed');
@@ -80,7 +80,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
});
startAppListening({
actionCreator: socketBulkDownloadError,
actionCreator: socketBulkDownloadFailed,
effect: async (action) => {
log.debug(action.payload.data, 'Bulk download preparation failed');

View File

@@ -13,6 +13,7 @@ import {
isControlAdapterLayer,
} from 'features/controlLayers/store/controlLayersSlice';
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
import { isImageOutput } from 'features/nodes/types/common';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { isEqual } from 'lodash-es';
@@ -132,13 +133,13 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) &&
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.invocation_source_id === processorNode.id
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
action.payload.data.source_node_id === processorNode.id
);
// We still have to check the output type
assert(
invocationCompleteAction.payload.data.result.type === 'image_output',
isImageOutput(invocationCompleteAction.payload.data.result),
`Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
);
const { image_name } = invocationCompleteAction.payload.data.result.image;

View File

@@ -9,6 +9,7 @@ import {
selectControlAdapterById,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
import { isImageOutput } from 'features/nodes/types/common';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { imagesApi } from 'services/api/endpoints/images';
@@ -68,12 +69,12 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
const [invocationCompleteAction] = await take(
(action): action is ReturnType<typeof socketInvocationComplete> =>
socketInvocationComplete.match(action) &&
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
action.payload.data.invocation_source_id === nodeId
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
action.payload.data.source_node_id === nodeId
);
// We still have to check the output type
if (invocationCompleteAction.payload.data.result.type === 'image_output') {
if (isImageOutput(invocationCompleteAction.payload.data.result)) {
const { image_name } = invocationCompleteAction.payload.data.result.image;
// Wait for the ImageDTO to be received

View File

@@ -1,7 +1,7 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { imagesSelectors } from 'services/api/util';
@@ -11,7 +11,6 @@ export const galleryImageClicked = createAction<{
shiftKey: boolean;
ctrlKey: boolean;
metaKey: boolean;
altKey: boolean;
}>('gallery/imageClicked');
/**
@@ -29,7 +28,7 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
startAppListening({
actionCreator: galleryImageClicked,
effect: async (action, { dispatch, getState }) => {
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const { imageDTO, shiftKey, ctrlKey, metaKey } = action.payload;
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);
const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
@@ -42,13 +41,7 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
const imageDTOs = imagesSelectors.selectAll(listImagesData);
const selection = state.gallery.selection;
if (altKey) {
if (state.gallery.imageToCompare?.image_name === imageDTO.image_name) {
dispatch(imageToCompareChanged(null));
} else {
dispatch(imageToCompareChanged(imageDTO));
}
} else if (shiftKey) {
if (shiftKey) {
const rangeEndImageName = imageDTO.image_name;
const lastSelectedImage = selection[selection.length - 1]?.image_name;
const lastClickedIndex = imageDTOs.findIndex((n) => n.image_name === lastSelectedImage);

View File

@@ -14,8 +14,7 @@ import {
rgLayerIPAdapterImageChanged,
} from 'features/controlLayers/store/controlLayersSlice';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import { isValidDrop } from 'features/dnd/util/isValidDrop';
import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { imagesApi } from 'services/api/endpoints/images';
@@ -31,9 +30,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
effect: async (action, { dispatch, getState }) => {
const log = logger('dnd');
const { activeData, overData } = action.payload;
if (!isValidDrop(overData, activeData)) {
return;
}
if (activeData.payloadType === 'IMAGE_DTO') {
log.debug({ activeData, overData }, 'Image dropped');
@@ -54,7 +50,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
activeData.payload.imageDTO
) {
dispatch(imageSelected(activeData.payload.imageDTO));
dispatch(isImageViewerOpenChanged(true));
return;
}
@@ -187,18 +182,24 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
}
/**
* Image selected for compare
* TODO
* Image selection dropped on node image collection field
*/
if (
overData.actionType === 'SELECT_FOR_COMPARE' &&
activeData.payloadType === 'IMAGE_DTO' &&
activeData.payload.imageDTO
) {
const { imageDTO } = activeData.payload;
dispatch(imageToCompareChanged(imageDTO));
dispatch(isImageViewerOpenChanged(true));
return;
}
// if (
// overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
// activeData.payloadType === 'IMAGE_DTO' &&
// activeData.payload.imageDTO
// ) {
// const { fieldName, nodeId } = overData.context;
// dispatch(
// fieldValueChanged({
// nodeId,
// fieldName,
// value: [activeData.payload.imageDTO],
// })
// );
// return;
// }
/**
* Image dropped on user board

View File

@@ -1,7 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions';
@@ -12,9 +11,9 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
startAppListening({
actionCreator: socketGeneratorProgress,
effect: (action) => {
log.trace(parseify(action.payload), `Generator progress`);
const { invocation_source_id, step, total_steps, progress_image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
log.trace(action.payload, `Generator progress`);
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps;

View File

@@ -0,0 +1,14 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketGraphExecutionStateComplete } from 'services/events/actions';
const log = logger('socketio');
export const addGraphExecutionStateCompleteEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketGraphExecutionStateComplete,
effect: (action) => {
log.debug(action.payload, 'Session complete');
},
});
};

View File

@@ -11,6 +11,7 @@ import {
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { isImageOutput } from 'features/nodes/types/common';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards';
@@ -28,12 +29,12 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
actionCreator: socketInvocationComplete,
effect: async (action, { dispatch, getState }) => {
const { data } = action.payload;
log.debug({ data: parseify(data) }, `Invocation complete (${data.invocation.type})`);
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
const { result, invocation_source_id } = data;
const { result, node, queue_batch_id, source_node_id } = data;
// This complete event has an associated image output
if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) {
const { image_name } = data.result.image;
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
const { image_name } = result.image;
const { canvas, gallery } = getState();
// This populates the `getImageDTO` cache
@@ -47,7 +48,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
imageDTORequest.unsubscribe();
// Add canvas images to the staging area
if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) {
if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) {
dispatch(addImageToStagingArea(imageDTO));
}
@@ -113,7 +114,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
}
}
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {

View File

@@ -1,24 +1,52 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { toast } from 'features/toast/toast';
import ToastWithSessionRefDescription from 'features/toast/ToastWithSessionRefDescription';
import { t } from 'i18next';
import { startCase } from 'lodash-es';
import { socketInvocationError } from 'services/events/actions';
const log = logger('socketio');
const getTitle = (errorType: string) => {
if (errorType === 'OutOfMemoryError') {
return t('toast.outOfMemoryError');
}
return t('toast.serverError');
};
const getDescription = (errorType: string, sessionId: string, isLocal?: boolean) => {
if (!isLocal) {
if (errorType === 'OutOfMemoryError') {
return ToastWithSessionRefDescription({
message: t('toast.outOfMemoryDescription'),
sessionId,
});
}
return ToastWithSessionRefDescription({
message: errorType,
sessionId,
});
}
return errorType;
};
export const addInvocationErrorEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketInvocationError,
effect: (action) => {
const { invocation_source_id, invocation, error_type, error_message, error_traceback } = action.payload.data;
log.error(parseify(action.payload), `Invocation error (${invocation.type})`);
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
effect: (action, { getState }) => {
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
const { source_node_id, error_type, error_message, error_traceback, graph_execution_state_id } =
action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.FAILED;
nes.progress = null;
nes.progressImage = null;
nes.error = {
error_type,
error_message,
@@ -26,6 +54,19 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
};
upsertExecutionState(nes.nodeId, nes);
}
const errorType = startCase(error_type);
const sessionId = graph_execution_state_id;
const { isLocal } = getState().config;
toast({
id: `INVOCATION_ERROR_${errorType}`,
title: getTitle(errorType),
status: 'error',
duration: null,
description: getDescription(errorType, sessionId, isLocal),
updateDescription: isLocal ? true : false,
});
},
});
};

View File

@@ -1,7 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationStarted } from 'services/events/actions';
@@ -12,9 +11,9 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
startAppListening({
actionCreator: socketInvocationStarted,
effect: (action) => {
log.debug(parseify(action.payload), `Invocation started (${action.payload.data.invocation.type})`);
const { invocation_source_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`);
const { source_node_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
upsertExecutionState(nes.nodeId, nes);

View File

@@ -3,14 +3,14 @@ import { api, LIST_TAG } from 'services/api';
import { modelsApi } from 'services/api/endpoints/models';
import {
socketModelInstallCancelled,
socketModelInstallComplete,
socketModelInstallDownloadProgress,
socketModelInstallCompleted,
socketModelInstallDownloading,
socketModelInstallError,
} from 'services/events/actions';
export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketModelInstallDownloadProgress,
actionCreator: socketModelInstallDownloading,
effect: async (action, { dispatch }) => {
const { bytes, total_bytes, id } = action.payload.data;
@@ -29,7 +29,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
});
startAppListening({
actionCreator: socketModelInstallComplete,
actionCreator: socketModelInstallCompleted,
effect: (action, { dispatch }) => {
const { id } = action.payload.data;

View File

@@ -1,6 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketModelLoadComplete, socketModelLoadStarted } from 'services/events/actions';
import { socketModelLoadCompleted, socketModelLoadStarted } from 'services/events/actions';
const log = logger('socketio');
@@ -8,11 +8,10 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
startAppListening({
actionCreator: socketModelLoadStarted,
effect: (action) => {
const { config, submodel_type } = action.payload.data;
const { name, base, type } = config;
const { model_config, submodel_type } = action.payload.data;
const { name, base, type } = model_config;
const extras: string[] = [base, type];
if (submodel_type) {
extras.push(submodel_type);
}
@@ -24,10 +23,10 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
});
startAppListening({
actionCreator: socketModelLoadComplete,
actionCreator: socketModelLoadCompleted,
effect: (action) => {
const { config, submodel_type } = action.payload.data;
const { name, base, type } = config;
const { model_config, submodel_type } = action.payload.data;
const { name, base, type } = model_config;
const extras: string[] = [base, type];
if (submodel_type) {

View File

@@ -3,8 +3,6 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
import { toast } from 'features/toast/toast';
import { forEach } from 'lodash-es';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { socketQueueItemStatusChanged } from 'services/events/actions';
@@ -14,38 +12,18 @@ const log = logger('socketio');
export const addSocketQueueItemStatusChangedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketQueueItemStatusChanged,
effect: async (action, { dispatch, getState }) => {
effect: async (action, { dispatch }) => {
// we've got new status for the queue item, batch and queue
const {
item_id,
session_id,
status,
started_at,
updated_at,
completed_at,
batch_status,
queue_status,
error_type,
error_message,
error_traceback,
} = action.payload.data;
const { queue_item, batch_status, queue_status } = action.payload.data;
log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`);
log.debug(action.payload, `Queue item ${queue_item.item_id} status updated: ${queue_item.status}`);
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
dispatch(
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
queueItemsAdapter.updateOne(draft, {
id: String(item_id),
changes: {
status,
started_at,
updated_at: updated_at ?? undefined,
completed_at: completed_at ?? undefined,
error_type,
error_message,
error_traceback,
},
id: String(queue_item.item_id),
changes: queue_item,
});
})
);
@@ -72,11 +50,11 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'InvocationCacheStatus',
{ type: 'SessionQueueItem', id: item_id },
{ type: 'SessionQueueItem', id: queue_item.item_id },
])
);
if (status === 'in_progress') {
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
forEach($nodeExecutionStates.get(), (nes) => {
if (!nes) {
return;
@@ -89,25 +67,6 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
clone.outputs = [];
$nodeExecutionStates.setKey(clone.nodeId, clone);
});
} else if (status === 'failed' && error_type) {
const isLocal = getState().config.isLocal ?? true;
const sessionId = session_id;
toast({
id: `INVOCATION_ERROR_${error_type}`,
title: getTitleFromErrorType(error_type),
status: 'error',
duration: null,
updateDescription: isLocal,
description: (
<ErrorToastDescription
errorType={error_type}
errorMessage={error_message}
sessionId={sessionId}
isLocal={isLocal}
/>
),
});
}
},
});

View File

@@ -0,0 +1,14 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketSubscribedSession } from 'services/events/actions';
const log = logger('socketio');
export const addSocketSubscribedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketSubscribedSession,
effect: (action) => {
log.debug(action.payload, 'Subscribed');
},
});
};

View File

@@ -0,0 +1,13 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { socketUnsubscribedSession } from 'services/events/actions';
const log = logger('socketio');
export const addSocketUnsubscribedEventListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: socketUnsubscribedSession,
effect: (action) => {
log.debug(action.payload, 'Unsubscribed');
},
});
};

View File

@@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
import { parseify } from 'common/util/serialize';
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
import { $templates } from 'features/nodes/store/nodesSlice';
import { $needsFit } from 'features/nodes/store/reactFlowInstance';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import type { Templates } from 'features/nodes/store/types';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
@@ -65,7 +65,9 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
});
}
$needsFit.set(true);
requestAnimationFrame(() => {
$flow.get()?.fitView();
});
} catch (e) {
if (e instanceof WorkflowVersionError) {
// The workflow version was not recognized in the valid list of versions

View File

@@ -35,7 +35,6 @@ type IAIDndImageProps = FlexProps & {
draggableData?: TypesafeDraggableData;
dropLabel?: ReactNode;
isSelected?: boolean;
isSelectedForCompare?: boolean;
thumbnail?: boolean;
noContentFallback?: ReactElement;
useThumbailFallback?: boolean;
@@ -62,7 +61,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
draggableData,
dropLabel,
isSelected = false,
isSelectedForCompare = false,
thumbnail = false,
noContentFallback = defaultNoContentFallback,
uploadElement = defaultUploadElement,
@@ -167,11 +165,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
data-testid={dataTestId}
/>
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
<SelectionOverlay
isSelected={isSelected}
isSelectedForCompare={isSelectedForCompare}
isHovered={withHoverOverlay ? isHovered : false}
/>
<SelectionOverlay isSelected={isSelected} isHovered={withHoverOverlay ? isHovered : false} />
</Flex>
)}
{!imageDTO && !isUploadDisabled && (

View File

@@ -36,7 +36,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
pointerEvents={active ? 'auto' : 'none'}
>
<AnimatePresence>
{isValidDrop(data, active?.data.current) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
{isValidDrop(data, active) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
</AnimatePresence>
</Box>
);

View File

@@ -3,17 +3,10 @@ import { memo, useMemo } from 'react';
type Props = {
isSelected: boolean;
isSelectedForCompare: boolean;
isHovered: boolean;
};
const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props) => {
const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
const shadow = useMemo(() => {
if (isSelectedForCompare && isHovered) {
return 'hoverSelectedForCompare';
}
if (isSelectedForCompare && !isHovered) {
return 'selectedForCompare';
}
if (isSelected && isHovered) {
return 'hoverSelected';
}
@@ -24,7 +17,7 @@ const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props
return 'hoverUnselected';
}
return undefined;
}, [isHovered, isSelected, isSelectedForCompare]);
}, [isHovered, isSelected]);
return (
<Box
className="selection-box"
@@ -34,7 +27,7 @@ const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props
bottom={0}
insetInlineStart={0}
borderRadius="base"
opacity={isSelected || isSelectedForCompare ? 1 : 0.7}
opacity={isSelected ? 1 : 0.7}
transitionProperty="common"
transitionDuration="0.1s"
pointerEvents="none"

View File

@@ -1,21 +0,0 @@
import { useCallback, useMemo, useState } from 'react';
export const useBoolean = (initialValue: boolean) => {
const [isTrue, set] = useState(initialValue);
const setTrue = useCallback(() => set(true), []);
const setFalse = useCallback(() => set(false), []);
const toggle = useCallback(() => set((v) => !v), []);
const api = useMemo(
() => ({
isTrue,
set,
setTrue,
setFalse,
toggle,
}),
[isTrue, set, setTrue, setFalse, toggle]
);
return api;
};

View File

@@ -1,7 +1,3 @@
export const stopPropagation = (e: React.MouseEvent) => {
e.stopPropagation();
};
export const preventDefault = (e: React.MouseEvent) => {
e.preventDefault();
};

View File

@@ -613,7 +613,7 @@ export const canvasSlice = createSlice({
state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id);
}
const queueItemStatus = action.payload.data.status;
const queueItemStatus = action.payload.data.queue_item.status;
if (queueItemStatus === 'canceled' || queueItemStatus === 'failed') {
resetStagingAreaIfEmpty(state);
}

View File

@@ -1,13 +1,7 @@
import { deepClone } from 'common/util/deepClone';
import { zModelIdentifierField } from 'features/nodes/types/common';
import { merge, omit } from 'lodash-es';
import type {
AnyInvocation,
BaseModelType,
ControlNetModelConfig,
ImageDTO,
T2IAdapterModelConfig,
} from 'services/api/types';
import type { BaseModelType, ControlNetModelConfig, Graph, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
import { z } from 'zod';
const zId = z.string().min(1);
@@ -153,7 +147,7 @@ const zBeginEndStepPct = z
const zControlAdapterBase = z.object({
id: zId,
weight: z.number().gte(-1).lte(2),
weight: z.number().gte(0).lte(1),
image: zImageWithDims.nullable(),
processedImage: zImageWithDims.nullable(),
processorConfig: zProcessorConfig.nullable(),
@@ -189,7 +183,7 @@ export const isIPMethodV2 = (v: unknown): v is IPMethodV2 => zIPMethodV2.safePar
export const zIPAdapterConfigV2 = z.object({
id: zId,
type: z.literal('ip_adapter'),
weight: z.number().gte(-1).lte(2),
weight: z.number().gte(0).lte(1),
method: zIPMethodV2,
image: zImageWithDims.nullable(),
model: zModelIdentifierField.nullable(),
@@ -222,7 +216,10 @@ type ProcessorData<T extends ProcessorTypeV2> = {
labelTKey: string;
descriptionTKey: string;
buildDefaults(baseModel?: BaseModelType): Extract<ProcessorConfig, { type: T }>;
buildNode(image: ImageWithDims, config: Extract<ProcessorConfig, { type: T }>): Extract<AnyInvocation, { type: T }>;
buildNode(
image: ImageWithDims,
config: Extract<ProcessorConfig, { type: T }>
): Extract<Graph['nodes'][string], { type: T }>;
};
const minDim = (image: ImageWithDims): number => Math.min(image.width, image.height);

View File

@@ -54,7 +54,7 @@ const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)';
const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
export const STAGE_BG_DATAURL =
const STAGE_BG_DATAURL =
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAAEsmlUWHRYTUw6Y29tLmFkb2JlLnhtcAAAAAAAPD94cGFja2V0IGJlZ2luPSLvu78iIGlkPSJXNU0wTXBDZWhpSHpyZVN6TlRjemtjOWQiPz4KPHg6eG1wbWV0YSB4bWxuczp4PSJhZG9iZTpuczptZXRhLyIgeDp4bXB0az0iWE1QIENvcmUgNS41LjAiPgogPHJkZjpSREYgeG1sbnM6cmRmPSJodHRwOi8vd3d3LnczLm9yZy8xOTk5LzAyLzIyLXJkZi1zeW50YXgtbnMjIj4KICA8cmRmOkRlc2NyaXB0aW9uIHJkZjphYm91dD0iIgogICAgeG1sbnM6ZXhpZj0iaHR0cDovL25zLmFkb2JlLmNvbS9leGlmLzEuMC8iCiAgICB4bWxuczp0aWZmPSJodHRwOi8vbnMuYWRvYmUuY29tL3RpZmYvMS4wLyIKICAgIHhtbG5zOnBob3Rvc2hvcD0iaHR0cDovL25zLmFkb2JlLmNvbS9waG90b3Nob3AvMS4wLyIKICAgIHhtbG5zOnhtcD0iaHR0cDovL25zLmFkb2JlLmNvbS94YXAvMS4wLyIKICAgIHhtbG5zOnhtcE1NPSJodHRwOi8vbnMuYWRvYmUuY29tL3hhcC8xLjAvbW0vIgogICAgeG1sbnM6c3RFdnQ9Imh0dHA6Ly9ucy5hZG9iZS5jb20veGFwLzEuMC9zVHlwZS9SZXNvdXJjZUV2ZW50IyIKICAgZXhpZjpQaXhlbFhEaW1lbnNpb249IjIwIgogICBleGlmOlBpeGVsWURpbWVuc2lvbj0iMjAiCiAgIGV4aWY6Q29sb3JTcGFjZT0iMSIKICAgdGlmZjpJbWFnZVdpZHRoPSIyMCIKICAgdGlmZjpJbWFnZUxlbmd0aD0iMjAiCiAgIHRpZmY6UmVzb2x1dGlvblVuaXQ9IjIiCiAgIHRpZmY6WFJlc29sdXRpb249IjMwMC8xIgogICB0aWZmOllSZXNvbHV0aW9uPSIzMDAvMSIKICAgcGhvdG9zaG9wOkNvbG9yTW9kZT0iMyIKICAgcGhvdG9zaG9wOklDQ1Byb2ZpbGU9InNSR0IgSUVDNjE5NjYtMi4xIgogICB4bXA6TW9kaWZ5RGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCIKICAgeG1wOk1ldGFkYXRhRGF0ZT0iMjAyNC0wNC0yM1QwODoyMDo0NysxMDowMCI+CiAgIDx4bXBNTTpIaXN0b3J5PgogICAgPHJkZjpTZXE+CiAgICAgPHJkZjpsaQogICAgICBzdEV2dDphY3Rpb249InByb2R1Y2VkIgogICAgICBzdEV2dDpzb2Z0d2FyZUFnZW50PSJBZmZpbml0eSBQaG90byAxLjEwLjgiCiAgICAgIHN0RXZ0OndoZW49IjIwMjQtMDQtMjNUMDg6MjA6NDcrMTA6MDAiLz4KICAgIDwvcmRmOlNlcT4KICAgPC94bXBNTTpIaXN0b3J5PgogIDwvcmRmOkRlc2NyaXB0aW9uPgogPC9yZGY6UkRGPgo8L3g6eG1wbWV0YT4KPD94cGFja2V0IGVuZD0iciI/Pn9pdVgAAAGBaUNDUHNSR0IgSUVDNjE5NjYtMi4xAAAokXWR3yuDURjHP5uJmKghFy6WxpVpqMWNMgm1tGbKr5vt3S+1d3t73y3JrXKrKHHj1wV/AbfKtVJESq53TdywXs9rakv2nJ7zfM73nOfpnOeAPZJRVMPhAzWb18NTAffC4pK7oYiDTjpw4YgqhjYeCgWpaR8P2Kx457Vq1T73rzXHE4YCtkbhMUXT88LTwsG1vGbxrnC7ko7Ghc+F+3W5oPC9pcfKXLQ4VeYvi/VIeALsbcLuVBXHqlhJ66qwvByPmikov/exXuJMZOfnJPaId2MQZooAbmaYZAI/g4zK7MfLEAOyoka+7yd/lpzkKjJrrKOzSoo0efpFLUj1hMSk6AkZGdat/v/tq5EcHipXdwag/sU033qhYQdK26b5eWyapROoe4arbCU/dwQj76JvVzTPIbRuwsV1RYvtweUWdD1pUT36I9WJ25NJeD2DlkVw3ULTcrlnv/ucPkJkQ77qBvYPoE/Ot658AxagZ8FoS/a7AAAACXBIWXMAAC4jAAAuIwF4pT92AAAAL0lEQVQ4jWM8ffo0A25gYmKCR5YJjxxBMKp5ZGhm/P//Px7pM2fO0MrmUc0jQzMAB2EIhZC3pUYAAAAASUVORK5CYII=';
const mapId = (object: { id: string }) => object.id;

View File

@@ -18,7 +18,7 @@ type BaseDropData = {
id: string;
};
export type CurrentImageDropData = BaseDropData & {
type CurrentImageDropData = BaseDropData & {
actionType: 'SET_CURRENT_IMAGE';
};
@@ -79,14 +79,6 @@ export type RemoveFromBoardDropData = BaseDropData & {
actionType: 'REMOVE_FROM_BOARD';
};
export type SelectForCompareDropData = BaseDropData & {
actionType: 'SELECT_FOR_COMPARE';
context: {
firstImageName?: string | null;
secondImageName?: string | null;
};
};
export type TypesafeDroppableData =
| CurrentImageDropData
| ControlAdapterDropData
@@ -97,8 +89,7 @@ export type TypesafeDroppableData =
| CALayerImageDropData
| IPALayerImageDropData
| RGLayerIPAdapterImageDropData
| IILayerImageDropData
| SelectForCompareDropData;
| IILayerImageDropData;
type BaseDragData = {
id: string;
@@ -143,7 +134,7 @@ export type UseDraggableTypesafeReturnValue = Omit<ReturnType<typeof useOriginal
over: TypesafeOver | null;
};
interface TypesafeActive extends Omit<Active, 'data'> {
export interface TypesafeActive extends Omit<Active, 'data'> {
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
}

View File

@@ -1,14 +1,14 @@
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import type { TypesafeActive, TypesafeDroppableData } from 'features/dnd/types';
export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?: TypesafeDraggableData | null) => {
if (!overData || !activeData) {
export const isValidDrop = (overData: TypesafeDroppableData | undefined, active: TypesafeActive | null) => {
if (!overData || !active?.data.current) {
return false;
}
const { actionType } = overData;
const { payloadType } = activeData;
const { payloadType } = active.data.current;
if (overData.id === activeData.id) {
if (overData.id === active.data.current.id) {
return false;
}
@@ -29,8 +29,6 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
return payloadType === 'IMAGE_DTO';
case 'SET_NODES_IMAGE':
return payloadType === 'IMAGE_DTO';
case 'SELECT_FOR_COMPARE':
return payloadType === 'IMAGE_DTO';
case 'ADD_TO_BOARD': {
// If the board is the same, don't allow the drop
@@ -42,7 +40,7 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = activeData.payload;
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id ?? 'none';
const destinationBoard = overData.context.boardId;
@@ -51,7 +49,7 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
if (payloadType === 'GALLERY_SELECTION') {
// Assume all images are on the same board - this is true for the moment
const currentBoard = activeData.payload.boardId;
const currentBoard = active.data.current.payload.boardId;
const destinationBoard = overData.context.boardId;
return currentBoard !== destinationBoard;
}
@@ -69,14 +67,14 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
// Check if the image's board is the board we are dragging onto
if (payloadType === 'IMAGE_DTO') {
const { imageDTO } = activeData.payload;
const { imageDTO } = active.data.current.payload;
const currentBoard = imageDTO.board_id ?? 'none';
return currentBoard !== 'none';
}
if (payloadType === 'GALLERY_SELECTION') {
const currentBoard = activeData.payload.boardId;
const currentBoard = active.data.current.payload.boardId;
return currentBoard !== 'none';
}

View File

@@ -162,7 +162,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
</Flex>
)}
{isSelectedForAutoAdd && <AutoAddIcon />}
<SelectionOverlay isSelected={isSelected} isSelectedForCompare={false} isHovered={isHovered} />
<SelectionOverlay isSelected={isSelected} isHovered={isHovered} />
<Flex
position="absolute"
bottom={0}

View File

@@ -117,7 +117,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
>
{boardName}
</Flex>
<SelectionOverlay isSelected={isSelected} isSelectedForCompare={false} isHovered={isHovered} />
<SelectionOverlay isSelected={isSelected} isHovered={isHovered} />
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="md">{t('unifiedCanvas.move')}</Text>} />
</Flex>
</Tooltip>

View File

@@ -10,7 +10,6 @@ import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice';
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
import { useImageActions } from 'features/gallery/hooks/useImageActions';
import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/actions';
import { imageToCompareChanged } from 'features/gallery/store/gallerySlice';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
@@ -28,7 +27,6 @@ import {
PiDownloadSimpleBold,
PiFlowArrowBold,
PiFoldersBold,
PiImagesBold,
PiPlantBold,
PiQuotesBold,
PiShareFatBold,
@@ -46,7 +44,6 @@ type SingleSelectionMenuItemsProps = {
const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
const { imageDTO } = props;
const optimalDimension = useAppSelector(selectOptimalDimension);
const maySelectForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name !== imageDTO.image_name);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const isCanvasEnabled = useFeatureStatus('canvas');
@@ -120,10 +117,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
downloadImage(imageDTO.image_url, imageDTO.image_name);
}, [downloadImage, imageDTO.image_name, imageDTO.image_url]);
const handleSelectImageForCompare = useCallback(() => {
dispatch(imageToCompareChanged(imageDTO));
}, [dispatch, imageDTO]);
return (
<>
<MenuItem as="a" href={imageDTO.image_url} target="_blank" icon={<PiShareFatBold />}>
@@ -137,9 +130,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
<MenuItem icon={<PiDownloadSimpleBold />} onClickCapture={handleDownloadImage}>
{t('parameters.downloadImage')}
</MenuItem>
<MenuItem icon={<PiImagesBold />} isDisabled={!maySelectForCompare} onClick={handleSelectImageForCompare}>
{t('gallery.selectForCompare')}
</MenuItem>
<MenuDivider />
<MenuItem
icon={getAndLoadEmbeddedWorkflowResult.isLoading ? <SpinnerIcon /> : <PiFlowArrowBold />}

View File

@@ -11,7 +11,7 @@ import type { GallerySelectionDraggableData, ImageDraggableData, TypesafeDraggab
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
import { useMultiselect } from 'features/gallery/hooks/useMultiselect';
import { useScrollIntoView } from 'features/gallery/hooks/useScrollIntoView';
import { imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import type { MouseEvent } from 'react';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
@@ -46,7 +46,6 @@ const GalleryImage = (props: HoverableImageProps) => {
const { t } = useTranslation();
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageName);
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
const customStarUi = useStore($customStarUI);
@@ -106,7 +105,6 @@ const GalleryImage = (props: HoverableImageProps) => {
const onDoubleClick = useCallback(() => {
dispatch(isImageViewerOpenChanged(true));
dispatch(imageToCompareChanged(null));
}, [dispatch]);
const handleMouseOut = useCallback(() => {
@@ -154,7 +152,6 @@ const GalleryImage = (props: HoverableImageProps) => {
imageDTO={imageDTO}
draggableData={draggableData}
isSelected={isSelected}
isSelectedForCompare={isSelectedForCompare}
minSize={0}
imageSx={imageSx}
isDropDisabled={true}

View File

@@ -1,140 +0,0 @@
import {
Button,
ButtonGroup,
Flex,
Icon,
IconButton,
Kbd,
ListItem,
Tooltip,
UnorderedList,
} from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
comparedImagesSwapped,
comparisonFitChanged,
comparisonModeChanged,
comparisonModeCycled,
imageToCompareChanged,
} from 'features/gallery/store/gallerySlice';
import { memo, useCallback } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { Trans, useTranslation } from 'react-i18next';
import { PiArrowsOutBold, PiQuestion, PiSwapBold, PiXBold } from 'react-icons/pi';
export const CompareToolbar = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const comparisonMode = useAppSelector((s) => s.gallery.comparisonMode);
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
const setComparisonModeSlider = useCallback(() => {
dispatch(comparisonModeChanged('slider'));
}, [dispatch]);
const setComparisonModeSideBySide = useCallback(() => {
dispatch(comparisonModeChanged('side-by-side'));
}, [dispatch]);
const setComparisonModeHover = useCallback(() => {
dispatch(comparisonModeChanged('hover'));
}, [dispatch]);
const swapImages = useCallback(() => {
dispatch(comparedImagesSwapped());
}, [dispatch]);
useHotkeys('c', swapImages, [swapImages]);
const toggleComparisonFit = useCallback(() => {
dispatch(comparisonFitChanged(comparisonFit === 'contain' ? 'fill' : 'contain'));
}, [dispatch, comparisonFit]);
const exitCompare = useCallback(() => {
dispatch(imageToCompareChanged(null));
}, [dispatch]);
useHotkeys('esc', exitCompare, [exitCompare]);
const nextMode = useCallback(() => {
dispatch(comparisonModeCycled());
}, [dispatch]);
useHotkeys('m', nextMode, [nextMode]);
return (
<Flex w="full" gap={2}>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineEnd="auto">
<IconButton
icon={<PiSwapBold />}
aria-label={`${t('gallery.swapImages')} (C)`}
tooltip={`${t('gallery.swapImages')} (C)`}
onClick={swapImages}
/>
{comparisonMode !== 'side-by-side' && (
<IconButton
aria-label={t('gallery.stretchToFit')}
tooltip={t('gallery.stretchToFit')}
onClick={toggleComparisonFit}
colorScheme={comparisonFit === 'fill' ? 'invokeBlue' : 'base'}
variant="outline"
icon={<PiArrowsOutBold />}
/>
)}
</Flex>
</Flex>
<Flex flex={1} gap={4} justifyContent="center">
<ButtonGroup variant="outline">
<Button
flexShrink={0}
onClick={setComparisonModeSlider}
colorScheme={comparisonMode === 'slider' ? 'invokeBlue' : 'base'}
>
{t('gallery.slider')}
</Button>
<Button
flexShrink={0}
onClick={setComparisonModeSideBySide}
colorScheme={comparisonMode === 'side-by-side' ? 'invokeBlue' : 'base'}
>
{t('gallery.sideBySide')}
</Button>
<Button
flexShrink={0}
onClick={setComparisonModeHover}
colorScheme={comparisonMode === 'hover' ? 'invokeBlue' : 'base'}
>
{t('gallery.hover')}
</Button>
</ButtonGroup>
</Flex>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto" alignItems="center">
<Tooltip label={<CompareHelp />}>
<Flex alignItems="center">
<Icon boxSize={8} color="base.500" as={PiQuestion} lineHeight={0} />
</Flex>
</Tooltip>
<IconButton
icon={<PiXBold />}
aria-label={`${t('gallery.exitCompare')} (Esc)`}
tooltip={`${t('gallery.exitCompare')} (Esc)`}
onClick={exitCompare}
/>
</Flex>
</Flex>
</Flex>
);
});
CompareToolbar.displayName = 'CompareToolbar';
const CompareHelp = () => {
return (
<UnorderedList>
<ListItem>
<Trans i18nKey="gallery.compareHelp1" components={{ Kbd: <Kbd /> }}></Trans>
</ListItem>
<ListItem>
<Trans i18nKey="gallery.compareHelp2" components={{ Kbd: <Kbd /> }}></Trans>
</ListItem>
<ListItem>
<Trans i18nKey="gallery.compareHelp3" components={{ Kbd: <Kbd /> }}></Trans>
</ListItem>
<ListItem>
<Trans i18nKey="gallery.compareHelp4" components={{ Kbd: <Kbd /> }}></Trans>
</ListItem>
</UnorderedList>
);
};

View File

@@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import IAIDndImage from 'common/components/IAIDndImage';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import type { TypesafeDraggableData } from 'features/dnd/types';
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
@@ -22,7 +22,21 @@ const selectLastSelectedImageName = createSelector(
(lastSelectedImage) => lastSelectedImage?.image_name
);
const CurrentImagePreview = () => {
type Props = {
isDragDisabled?: boolean;
isDropDisabled?: boolean;
withNextPrevButtons?: boolean;
withMetadata?: boolean;
alwaysShowProgress?: boolean;
};
const CurrentImagePreview = ({
isDragDisabled = false,
isDropDisabled = false,
withNextPrevButtons = true,
withMetadata = true,
alwaysShowProgress = false,
}: Props) => {
const { t } = useTranslation();
const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails);
const imageName = useAppSelector(selectLastSelectedImageName);
@@ -41,6 +55,14 @@ const CurrentImagePreview = () => {
}
}, [imageDTO]);
const droppableData = useMemo<TypesafeDroppableData | undefined>(
() => ({
id: 'current-image',
actionType: 'SET_CURRENT_IMAGE',
}),
[]
);
// Show and hide the next/prev buttons on mouse move
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = useState<boolean>(false);
const timeoutId = useRef(0);
@@ -64,27 +86,30 @@ const CurrentImagePreview = () => {
justifyContent="center"
position="relative"
>
{hasDenoiseProgress && shouldShowProgressInViewer ? (
{hasDenoiseProgress && (shouldShowProgressInViewer || alwaysShowProgress) ? (
<ProgressImage />
) : (
<IAIDndImage
imageDTO={imageDTO}
droppableData={droppableData}
draggableData={draggableData}
isDropDisabled={true}
isDragDisabled={isDragDisabled}
isDropDisabled={isDropDisabled}
isUploadDisabled={true}
fitContainer
useThumbailFallback
dropLabel={t('gallery.setCurrentImage')}
noContentFallback={<IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />}
dataTestId="image-preview"
/>
)}
{shouldShowImageDetails && imageDTO && (
{shouldShowImageDetails && imageDTO && withMetadata && (
<Box position="absolute" opacity={0.8} top={0} width="full" height="full" borderRadius="base">
<ImageMetadataViewer image={imageDTO} />
</Box>
)}
<AnimatePresence>
{shouldShowNextPrevButtons && imageDTO && (
{withNextPrevButtons && shouldShowNextPrevButtons && imageDTO && (
<Box
as={motion.div}
key="nextPrevButtons"

View File

@@ -1,41 +0,0 @@
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import type { Dimensions } from 'features/canvas/store/canvasTypes';
import { selectComparisonImages } from 'features/gallery/components/ImageViewer/common';
import { ImageComparisonHover } from 'features/gallery/components/ImageViewer/ImageComparisonHover';
import { ImageComparisonSideBySide } from 'features/gallery/components/ImageViewer/ImageComparisonSideBySide';
import { ImageComparisonSlider } from 'features/gallery/components/ImageViewer/ImageComparisonSlider';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiImagesBold } from 'react-icons/pi';
type Props = {
containerDims: Dimensions;
};
export const ImageComparison = memo(({ containerDims }: Props) => {
const { t } = useTranslation();
const comparisonMode = useAppSelector((s) => s.gallery.comparisonMode);
const { firstImage, secondImage } = useAppSelector(selectComparisonImages);
if (!firstImage || !secondImage) {
// Should rarely/never happen - we don't render this component unless we have images to compare
return <IAINoContentFallback label={t('gallery.selectAnImageToCompare')} icon={PiImagesBold} />;
}
if (comparisonMode === 'slider') {
return <ImageComparisonSlider containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />;
}
if (comparisonMode === 'side-by-side') {
return (
<ImageComparisonSideBySide containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />
);
}
if (comparisonMode === 'hover') {
return <ImageComparisonHover containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />;
}
});
ImageComparison.displayName = 'ImageComparison';

View File

@@ -1,47 +0,0 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import IAIDroppable from 'common/components/IAIDroppable';
import type { CurrentImageDropData, SelectForCompareDropData } from 'features/dnd/types';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { selectComparisonImages } from './common';
const setCurrentImageDropData: CurrentImageDropData = {
id: 'current-image',
actionType: 'SET_CURRENT_IMAGE',
};
export const ImageComparisonDroppable = memo(() => {
const { t } = useTranslation();
const imageViewer = useImageViewer();
const { firstImage, secondImage } = useAppSelector(selectComparisonImages);
const selectForCompareDropData = useMemo<SelectForCompareDropData>(
() => ({
id: 'image-comparison',
actionType: 'SELECT_FOR_COMPARE',
context: {
firstImageName: firstImage?.image_name,
secondImageName: secondImage?.image_name,
},
}),
[firstImage?.image_name, secondImage?.image_name]
);
if (!imageViewer.isOpen) {
return (
<Flex position="absolute" top={0} right={0} bottom={0} left={0} gap={2} pointerEvents="none">
<IAIDroppable data={setCurrentImageDropData} dropLabel={t('gallery.openInViewer')} />
</Flex>
);
}
return (
<Flex position="absolute" top={0} right={0} bottom={0} left={0} gap={2} pointerEvents="none">
<IAIDroppable data={selectForCompareDropData} dropLabel={t('gallery.selectForCompare')} />
</Flex>
);
});
ImageComparisonDroppable.displayName = 'ImageComparisonDroppable';

View File

@@ -1,117 +0,0 @@
import { Box, Flex, Image } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { useBoolean } from 'common/hooks/useBoolean';
import { preventDefault } from 'common/util/stopPropagation';
import type { Dimensions } from 'features/canvas/store/canvasTypes';
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
import { memo, useMemo, useRef } from 'react';
import type { ComparisonProps } from './common';
import { fitDimsToContainer, getSecondImageDims } from './common';
export const ImageComparisonHover = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
const imageContainerRef = useRef<HTMLDivElement>(null);
const mouseOver = useBoolean(false);
const fittedDims = useMemo<Dimensions>(
() => fitDimsToContainer(containerDims, firstImage),
[containerDims, firstImage]
);
const compareImageDims = useMemo<Dimensions>(
() => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage),
[comparisonFit, fittedDims, firstImage, secondImage]
);
return (
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
<Flex
id="image-comparison-wrapper"
w="full"
h="full"
maxW="full"
maxH="full"
position="absolute"
alignItems="center"
justifyContent="center"
>
<Box
ref={imageContainerRef}
position="relative"
id="image-comparison-hover-image-container"
w={fittedDims.width}
h={fittedDims.height}
maxW="full"
maxH="full"
userSelect="none"
overflow="hidden"
borderRadius="base"
>
<Image
id="image-comparison-hover-first-image"
src={firstImage.image_url}
fallbackSrc={firstImage.thumbnail_url}
w={fittedDims.width}
h={fittedDims.height}
maxW="full"
maxH="full"
objectFit="cover"
objectPosition="top left"
/>
<ImageComparisonLabel type="first" opacity={mouseOver.isTrue ? 0 : 1} />
<Box
id="image-comparison-hover-second-image-container"
position="absolute"
top={0}
left={0}
right={0}
bottom={0}
overflow="hidden"
opacity={mouseOver.isTrue ? 1 : 0}
transitionDuration="0.2s"
transitionProperty="common"
>
<Box
id="image-comparison-hover-bg"
position="absolute"
top={0}
left={0}
right={0}
bottom={0}
backgroundImage={STAGE_BG_DATAURL}
backgroundRepeat="repeat"
opacity={0.2}
/>
<Image
position="relative"
id="image-comparison-hover-second-image"
src={secondImage.image_url}
fallbackSrc={secondImage.thumbnail_url}
w={compareImageDims.width}
h={compareImageDims.height}
maxW={fittedDims.width}
maxH={fittedDims.height}
objectFit={comparisonFit}
objectPosition="top left"
/>
<ImageComparisonLabel type="second" opacity={mouseOver.isTrue ? 1 : 0} />
</Box>
<Box
id="image-comparison-hover-interaction-overlay"
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
onMouseOver={mouseOver.setTrue}
onMouseOut={mouseOver.setFalse}
onContextMenu={preventDefault}
userSelect="none"
/>
</Box>
</Flex>
</Flex>
);
});
ImageComparisonHover.displayName = 'ImageComparisonHover';

View File

@@ -1,33 +0,0 @@
import type { TextProps } from '@invoke-ai/ui-library';
import { Text } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { DROP_SHADOW } from './common';
type Props = TextProps & {
type: 'first' | 'second';
};
export const ImageComparisonLabel = memo(({ type, ...rest }: Props) => {
const { t } = useTranslation();
return (
<Text
position="absolute"
bottom={4}
insetInlineEnd={type === 'first' ? undefined : 4}
insetInlineStart={type === 'first' ? 4 : undefined}
textOverflow="clip"
whiteSpace="nowrap"
filter={DROP_SHADOW}
color="base.50"
transitionDuration="0.2s"
transitionProperty="common"
{...rest}
>
{type === 'first' ? t('gallery.viewerImage') : t('gallery.compareImage')}
</Text>
);
});
ImageComparisonLabel.displayName = 'ImageComparisonLabel';

View File

@@ -1,70 +0,0 @@
import { Flex, Image } from '@invoke-ai/ui-library';
import type { ComparisonProps } from 'features/gallery/components/ImageViewer/common';
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
import { memo, useCallback, useRef } from 'react';
import type { ImperativePanelGroupHandle } from 'react-resizable-panels';
import { Panel, PanelGroup } from 'react-resizable-panels';
export const ImageComparisonSideBySide = memo(({ firstImage, secondImage }: ComparisonProps) => {
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
const onDoubleClickHandle = useCallback(() => {
if (!panelGroupRef.current) {
return;
}
panelGroupRef.current.setLayout([50, 50]);
}, []);
return (
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
<Flex w="full" h="full" maxW="full" maxH="full" position="absolute" alignItems="center" justifyContent="center">
<PanelGroup ref={panelGroupRef} direction="horizontal" id="image-comparison-side-by-side">
<Panel minSize={20}>
<Flex position="relative" w="full" h="full" alignItems="center" justifyContent="center">
<Flex position="absolute" maxW="full" maxH="full" aspectRatio={firstImage.width / firstImage.height}>
<Image
id="image-comparison-side-by-side-first-image"
w={firstImage.width}
h={firstImage.height}
maxW="full"
maxH="full"
src={firstImage.image_url}
fallbackSrc={firstImage.thumbnail_url}
objectFit="contain"
borderRadius="base"
/>
<ImageComparisonLabel type="first" />
</Flex>
</Flex>
</Panel>
<ResizeHandle
id="image-comparison-side-by-side-handle"
onDoubleClick={onDoubleClickHandle}
orientation="vertical"
/>
<Panel minSize={20}>
<Flex position="relative" w="full" h="full" alignItems="center" justifyContent="center">
<Flex position="absolute" maxW="full" maxH="full" aspectRatio={secondImage.width / secondImage.height}>
<Image
id="image-comparison-side-by-side-first-image"
w={secondImage.width}
h={secondImage.height}
maxW="full"
maxH="full"
src={secondImage.image_url}
fallbackSrc={secondImage.thumbnail_url}
objectFit="contain"
borderRadius="base"
/>
<ImageComparisonLabel type="second" />
</Flex>
</Flex>
</Panel>
</PanelGroup>
</Flex>
</Flex>
);
});
ImageComparisonSideBySide.displayName = 'ImageComparisonSideBySide';

View File

@@ -1,215 +0,0 @@
import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { preventDefault } from 'common/util/stopPropagation';
import type { Dimensions } from 'features/canvas/store/canvasTypes';
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
import type { ComparisonProps } from './common';
import { DROP_SHADOW, fitDimsToContainer, getSecondImageDims } from './common';
const INITIAL_POS = '50%';
const HANDLE_WIDTH = 2;
const HANDLE_WIDTH_PX = `${HANDLE_WIDTH}px`;
const HANDLE_HITBOX = 20;
const HANDLE_HITBOX_PX = `${HANDLE_HITBOX}px`;
const HANDLE_INNER_LEFT_PX = `${HANDLE_HITBOX / 2 - HANDLE_WIDTH / 2}px`;
const HANDLE_LEFT_INITIAL_PX = `calc(${INITIAL_POS} - ${HANDLE_HITBOX / 2}px)`;
export const ImageComparisonSlider = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
// How far the handle is from the left - this will be a CSS calculation that takes into account the handle width
const [left, setLeft] = useState(HANDLE_LEFT_INITIAL_PX);
// How wide the first image is
const [width, setWidth] = useState(INITIAL_POS);
const handleRef = useRef<HTMLDivElement>(null);
// To manage aspect ratios, we need to know the size of the container
const imageContainerRef = useRef<HTMLDivElement>(null);
// To keep things smooth, we use RAF to update the handle position & gate it to 60fps
const rafRef = useRef<number | null>(null);
const lastMoveTimeRef = useRef<number>(0);
const fittedDims = useMemo<Dimensions>(
() => fitDimsToContainer(containerDims, firstImage),
[containerDims, firstImage]
);
const compareImageDims = useMemo<Dimensions>(
() => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage),
[comparisonFit, fittedDims, firstImage, secondImage]
);
const updateHandlePos = useCallback((clientX: number) => {
if (!handleRef.current || !imageContainerRef.current) {
return;
}
lastMoveTimeRef.current = performance.now();
const { x, width } = imageContainerRef.current.getBoundingClientRect();
const rawHandlePos = ((clientX - x) * 100) / width;
const handleWidthPct = (HANDLE_WIDTH * 100) / width;
const newHandlePos = Math.min(100 - handleWidthPct, Math.max(0, rawHandlePos));
setWidth(`${newHandlePos}%`);
setLeft(`calc(${newHandlePos}% - ${HANDLE_HITBOX / 2}px)`);
}, []);
const onMouseMove = useCallback(
(e: MouseEvent) => {
if (rafRef.current === null && performance.now() > lastMoveTimeRef.current + 1000 / 60) {
rafRef.current = window.requestAnimationFrame(() => {
updateHandlePos(e.clientX);
rafRef.current = null;
});
}
},
[updateHandlePos]
);
const onMouseUp = useCallback(() => {
window.removeEventListener('mousemove', onMouseMove);
}, [onMouseMove]);
const onMouseDown = useCallback(
(e: React.MouseEvent<HTMLDivElement>) => {
// Update the handle position immediately on click
updateHandlePos(e.clientX);
window.addEventListener('mouseup', onMouseUp, { once: true });
window.addEventListener('mousemove', onMouseMove);
},
[onMouseMove, onMouseUp, updateHandlePos]
);
useEffect(
() => () => {
if (rafRef.current !== null) {
cancelAnimationFrame(rafRef.current);
}
},
[]
);
return (
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
<Flex
id="image-comparison-wrapper"
w="full"
h="full"
maxW="full"
maxH="full"
position="absolute"
alignItems="center"
justifyContent="center"
>
<Box
ref={imageContainerRef}
position="relative"
id="image-comparison-image-container"
w={fittedDims.width}
h={fittedDims.height}
maxW="full"
maxH="full"
userSelect="none"
overflow="hidden"
borderRadius="base"
>
<Box
id="image-comparison-bg"
position="absolute"
top={0}
left={0}
right={0}
bottom={0}
backgroundImage={STAGE_BG_DATAURL}
backgroundRepeat="repeat"
opacity={0.2}
/>
<Image
position="relative"
id="image-comparison-second-image"
src={secondImage.image_url}
fallbackSrc={secondImage.thumbnail_url}
w={compareImageDims.width}
h={compareImageDims.height}
maxW={fittedDims.width}
maxH={fittedDims.height}
objectFit={comparisonFit}
objectPosition="top left"
/>
<ImageComparisonLabel type="second" />
<Box
id="image-comparison-first-image-container"
position="absolute"
top={0}
left={0}
right={0}
bottom={0}
w={width}
overflow="hidden"
>
<Image
id="image-comparison-first-image"
src={firstImage.image_url}
fallbackSrc={firstImage.thumbnail_url}
w={fittedDims.width}
h={fittedDims.height}
objectFit="cover"
objectPosition="top left"
/>
<ImageComparisonLabel type="first" />
</Box>
<Flex
id="image-comparison-handle"
ref={handleRef}
position="absolute"
top={0}
bottom={0}
left={left}
w={HANDLE_HITBOX_PX}
cursor="ew-resize"
filter={DROP_SHADOW}
opacity={0.8}
color="base.50"
>
<Box
id="image-comparison-handle-divider"
w={HANDLE_WIDTH_PX}
h="full"
bg="currentColor"
shadow="dark-lg"
position="absolute"
top={0}
left={HANDLE_INNER_LEFT_PX}
/>
<Flex
id="image-comparison-handle-icons"
gap={4}
position="absolute"
left="50%"
top="50%"
transform="translate(-50%, 0)"
filter={DROP_SHADOW}
>
<Icon as={PiCaretLeftBold} />
<Icon as={PiCaretRightBold} />
</Flex>
</Flex>
<Box
id="image-comparison-interaction-overlay"
position="absolute"
top={0}
right={0}
bottom={0}
left={0}
onMouseDown={onMouseDown}
onContextMenu={preventDefault}
userSelect="none"
cursor="ew-resize"
/>
</Box>
</Flex>
</Flex>
);
});
ImageComparisonSlider.displayName = 'ImageComparisonSlider';

View File

@@ -1,16 +1,36 @@
import { Box, Flex } from '@invoke-ai/ui-library';
import { CompareToolbar } from 'features/gallery/components/ImageViewer/CompareToolbar';
import CurrentImagePreview from 'features/gallery/components/ImageViewer/CurrentImagePreview';
import { ImageComparison } from 'features/gallery/components/ImageViewer/ImageComparison';
import { ViewerToolbar } from 'features/gallery/components/ImageViewer/ViewerToolbar';
import { memo } from 'react';
import { useMeasure } from 'react-use';
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import type { InvokeTabName } from 'features/ui/store/tabMap';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo, useMemo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useImageViewer } from './useImageViewer';
import CurrentImageButtons from './CurrentImageButtons';
import CurrentImagePreview from './CurrentImagePreview';
import { ViewerToggleMenu } from './ViewerToggleMenu';
const VIEWER_ENABLED_TABS: InvokeTabName[] = ['canvas', 'generation', 'workflows'];
export const ImageViewer = memo(() => {
const imageViewer = useImageViewer();
const [containerRef, containerDims] = useMeasure<HTMLDivElement>();
const { isOpen, onToggle, onClose } = useImageViewer();
const activeTabName = useAppSelector(activeTabNameSelector);
const isViewerEnabled = useMemo(() => VIEWER_ENABLED_TABS.includes(activeTabName), [activeTabName]);
const shouldShowViewer = useMemo(() => {
if (!isViewerEnabled) {
return false;
}
return isOpen;
}, [isOpen, isViewerEnabled]);
useHotkeys('z', onToggle, { enabled: isViewerEnabled }, [isViewerEnabled, onToggle]);
useHotkeys('esc', onClose, { enabled: isViewerEnabled }, [isViewerEnabled, onClose]);
if (!shouldShowViewer) {
return null;
}
return (
<Flex
@@ -26,13 +46,25 @@ export const ImageViewer = memo(() => {
rowGap={4}
alignItems="center"
justifyContent="center"
zIndex={10} // reactflow puts its minimap at 5, so we need to be above that
>
{imageViewer.isComparing && <CompareToolbar />}
{!imageViewer.isComparing && <ViewerToolbar />}
<Box ref={containerRef} w="full" h="full">
{!imageViewer.isComparing && <CurrentImagePreview />}
{imageViewer.isComparing && <ImageComparison containerDims={containerDims} />}
</Box>
<Flex w="full" gap={2}>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineEnd="auto">
<ToggleProgressButton />
<ToggleMetadataViewerButton />
</Flex>
</Flex>
<Flex flex={1} gap={2} justifyContent="center">
<CurrentImageButtons />
</Flex>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto">
<ViewerToggleMenu />
</Flex>
</Flex>
</Flex>
<CurrentImagePreview />
</Flex>
);
});

View File

@@ -0,0 +1,45 @@
import { Flex } from '@invoke-ai/ui-library';
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
import { memo } from 'react';
import CurrentImageButtons from './CurrentImageButtons';
import CurrentImagePreview from './CurrentImagePreview';
export const ImageViewerWorkflows = memo(() => {
return (
<Flex
layerStyle="first"
borderRadius="base"
position="absolute"
flexDirection="column"
top={0}
right={0}
bottom={0}
left={0}
p={2}
rowGap={4}
alignItems="center"
justifyContent="center"
zIndex={10} // reactflow puts its minimap at 5, so we need to be above that
>
<Flex w="full" gap={2}>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineEnd="auto">
<ToggleProgressButton />
<ToggleMetadataViewerButton />
</Flex>
</Flex>
<Flex flex={1} gap={2} justifyContent="center">
<CurrentImageButtons />
</Flex>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto" />
</Flex>
</Flex>
<CurrentImagePreview />
</Flex>
);
});
ImageViewerWorkflows.displayName = 'ImageViewerWorkflows';

View File

@@ -9,35 +9,33 @@ import {
PopoverTrigger,
Text,
} from '@invoke-ai/ui-library';
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { PiCaretDownBold, PiCheckBold, PiEyeBold, PiPencilBold } from 'react-icons/pi';
import { useImageViewer } from './useImageViewer';
export const ViewerToggleMenu = () => {
const { t } = useTranslation();
const imageViewer = useImageViewer();
useHotkeys('z', imageViewer.onToggle, [imageViewer]);
useHotkeys('esc', imageViewer.onClose, [imageViewer]);
const { isOpen, onClose, onOpen } = useImageViewer();
return (
<Popover isLazy>
<PopoverTrigger>
<Button variant="outline" data-testid="toggle-viewer-menu-button" pointerEvents="auto">
<Button variant="outline" data-testid="toggle-viewer-menu-button">
<Flex gap={3} w="full" alignItems="center">
{imageViewer.isOpen ? <Icon as={PiEyeBold} /> : <Icon as={PiPencilBold} />}
<Text fontSize="md">{imageViewer.isOpen ? t('common.viewing') : t('common.editing')}</Text>
{isOpen ? <Icon as={PiEyeBold} /> : <Icon as={PiPencilBold} />}
<Text fontSize="md">{isOpen ? t('common.viewing') : t('common.editing')}</Text>
<Icon as={PiCaretDownBold} />
</Flex>
</Button>
</PopoverTrigger>
<PopoverContent p={2} pointerEvents="auto">
<PopoverContent p={2}>
<PopoverArrow />
<PopoverBody>
<Flex flexDir="column">
<Button onClick={imageViewer.onOpen} variant="ghost" h="auto" w="auto" p={2}>
<Button onClick={onOpen} variant="ghost" h="auto" w="auto" p={2}>
<Flex gap={2} w="full">
<Icon as={PiCheckBold} visibility={imageViewer.isOpen ? 'visible' : 'hidden'} />
<Icon as={PiCheckBold} visibility={isOpen ? 'visible' : 'hidden'} />
<Flex flexDir="column" gap={2} alignItems="flex-start">
<Text fontWeight="semibold" color="base.100">
{t('common.viewing')}
@@ -48,9 +46,9 @@ export const ViewerToggleMenu = () => {
</Flex>
</Flex>
</Button>
<Button onClick={imageViewer.onClose} variant="ghost" h="auto" w="auto" p={2}>
<Button onClick={onClose} variant="ghost" h="auto" w="auto" p={2}>
<Flex gap={2} w="full">
<Icon as={PiCheckBold} visibility={imageViewer.isOpen ? 'hidden' : 'visible'} />
<Icon as={PiCheckBold} visibility={isOpen ? 'hidden' : 'visible'} />
<Flex flexDir="column" gap={2} alignItems="flex-start">
<Text fontWeight="semibold" color="base.100">
{t('common.editing')}

View File

@@ -1,33 +0,0 @@
import { Flex } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { memo } from 'react';
import CurrentImageButtons from './CurrentImageButtons';
import { ViewerToggleMenu } from './ViewerToggleMenu';
export const ViewerToolbar = memo(() => {
const tab = useAppSelector(activeTabNameSelector);
return (
<Flex w="full" gap={2}>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineEnd="auto">
<ToggleProgressButton />
<ToggleMetadataViewerButton />
</Flex>
</Flex>
<Flex flex={1} gap={2} justifyContent="center">
<CurrentImageButtons />
</Flex>
<Flex flex={1} justifyContent="center">
<Flex gap={2} marginInlineStart="auto">
{tab !== 'workflows' && <ViewerToggleMenu />}
</Flex>
</Flex>
</Flex>
);
});
ViewerToolbar.displayName = 'ViewerToolbar';

View File

@@ -1,64 +0,0 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { Dimensions } from 'features/canvas/store/canvasTypes';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import type { ComparisonFit } from 'features/gallery/store/types';
import type { ImageDTO } from 'services/api/types';
export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
export type ComparisonProps = {
firstImage: ImageDTO;
secondImage: ImageDTO;
containerDims: Dimensions;
};
export const fitDimsToContainer = (containerDims: Dimensions, imageDims: Dimensions): Dimensions => {
// Fall back to the image's dimensions if the container has no dimensions
if (containerDims.width === 0 || containerDims.height === 0) {
return { width: imageDims.width, height: imageDims.height };
}
// Fall back to the image's dimensions if the image fits within the container
if (imageDims.width <= containerDims.width && imageDims.height <= containerDims.height) {
return { width: imageDims.width, height: imageDims.height };
}
const targetAspectRatio = containerDims.width / containerDims.height;
const imageAspectRatio = imageDims.width / imageDims.height;
let width: number;
let height: number;
if (imageAspectRatio > targetAspectRatio) {
// Image is wider than container's aspect ratio
width = containerDims.width;
height = width / imageAspectRatio;
} else {
// Image is taller than container's aspect ratio
height = containerDims.height;
width = height * imageAspectRatio;
}
return { width, height };
};
/**
* Gets the dimensions of the second image in a comparison based on the comparison fit mode.
*/
export const getSecondImageDims = (
comparisonFit: ComparisonFit,
fittedDims: Dimensions,
firstImageDims: Dimensions,
secondImageDims: Dimensions
): Dimensions => {
const width =
comparisonFit === 'fill' ? fittedDims.width : (fittedDims.width * secondImageDims.width) / firstImageDims.width;
const height =
comparisonFit === 'fill' ? fittedDims.height : (fittedDims.height * secondImageDims.height) / firstImageDims.height;
return { width, height };
};
export const selectComparisonImages = createMemoizedSelector(selectGallerySlice, (gallerySlice) => {
const firstImage = gallerySlice.selection.slice(-1)[0] ?? null;
const secondImage = gallerySlice.imageToCompare;
return { firstImage, secondImage };
});

View File

@@ -1,31 +0,0 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { useCallback } from 'react';
export const useImageViewer = () => {
const dispatch = useAppDispatch();
const isComparing = useAppSelector((s) => s.gallery.imageToCompare !== null);
const isOpen = useAppSelector((s) => s.gallery.isImageViewerOpen);
const onClose = useCallback(() => {
if (isComparing && isOpen) {
dispatch(imageToCompareChanged(null));
} else {
dispatch(isImageViewerOpenChanged(false));
}
}, [dispatch, isComparing, isOpen]);
const onOpen = useCallback(() => {
dispatch(isImageViewerOpenChanged(true));
}, [dispatch]);
const onToggle = useCallback(() => {
if (isComparing && isOpen) {
dispatch(imageToCompareChanged(null));
} else {
dispatch(isImageViewerOpenChanged(!isOpen));
}
}, [dispatch, isComparing, isOpen]);
return { isOpen, onOpen, onClose, onToggle, isComparing };
};

View File

@@ -0,0 +1,22 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
import { useCallback } from 'react';
export const useImageViewer = () => {
const dispatch = useAppDispatch();
const isOpen = useAppSelector((s) => s.gallery.isImageViewerOpen);
const onClose = useCallback(() => {
dispatch(isImageViewerOpenChanged(false));
}, [dispatch]);
const onOpen = useCallback(() => {
dispatch(isImageViewerOpenChanged(true));
}, [dispatch]);
const onToggle = useCallback(() => {
dispatch(isImageViewerOpenChanged(!isOpen));
}, [dispatch, isOpen]);
return { isOpen, onOpen, onClose, onToggle };
};

View File

@@ -14,7 +14,7 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
const NextPrevImageButtons = () => {
const { t } = useTranslation();
const { prevImage, nextImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();
const { handleLeftImage, handleRightImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();
const {
areMoreImagesAvailable,
@@ -30,7 +30,7 @@ const NextPrevImageButtons = () => {
aria-label={t('accessibility.previousImage')}
icon={<PiCaretLeftBold size={64} />}
variant="unstyled"
onClick={prevImage}
onClick={handleLeftImage}
boxSize={16}
sx={nextPrevButtonStyles}
/>
@@ -42,7 +42,7 @@ const NextPrevImageButtons = () => {
aria-label={t('accessibility.nextImage')}
icon={<PiCaretRightBold size={64} />}
variant="unstyled"
onClick={nextImage}
onClick={handleRightImage}
boxSize={16}
sx={nextPrevButtonStyles}
/>

View File

@@ -27,16 +27,16 @@ export const useGalleryHotkeys = () => {
useGalleryNavigation();
useHotkeys(
['left', 'alt+left'],
(e) => {
canNavigateGallery && handleLeftImage(e.altKey);
'left',
() => {
canNavigateGallery && handleLeftImage();
},
[handleLeftImage, canNavigateGallery]
);
useHotkeys(
['right', 'alt+right'],
(e) => {
'right',
() => {
if (!canNavigateGallery) {
return;
}
@@ -45,29 +45,29 @@ export const useGalleryHotkeys = () => {
return;
}
if (!isOnLastImage) {
handleRightImage(e.altKey);
handleRightImage();
}
},
[isOnLastImage, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleRightImage, canNavigateGallery]
);
useHotkeys(
['up', 'alt+up'],
(e) => {
handleUpImage(e.altKey);
'up',
() => {
handleUpImage();
},
{ preventDefault: true },
[handleUpImage]
);
useHotkeys(
['down', 'alt+down'],
(e) => {
'down',
() => {
if (!areImagesBelowCurrent && areMoreImagesAvailable && !isFetching) {
handleLoadMoreImages();
return;
}
handleDownImage(e.altKey);
handleDownImage();
},
{ preventDefault: true },
[areImagesBelowCurrent, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleDownImage]

View File

@@ -1,11 +1,11 @@
import { useAltModifier } from '@invoke-ai/ui-library';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
import { imageItemContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridItemContainer';
import { imageListContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridListContainer';
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
import { imageSelected, imageToCompareChanged } from 'features/gallery/store/gallerySlice';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { getIsVisible } from 'features/gallery/util/getIsVisible';
import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign';
import { clamp } from 'lodash-es';
@@ -106,12 +106,10 @@ const getImageFuncs = {
};
type UseGalleryNavigationReturn = {
handleLeftImage: (alt?: boolean) => void;
handleRightImage: (alt?: boolean) => void;
handleUpImage: (alt?: boolean) => void;
handleDownImage: (alt?: boolean) => void;
prevImage: () => void;
nextImage: () => void;
handleLeftImage: () => void;
handleRightImage: () => void;
handleUpImage: () => void;
handleDownImage: () => void;
isOnFirstImage: boolean;
isOnLastImage: boolean;
areImagesBelowCurrent: boolean;
@@ -125,15 +123,7 @@ type UseGalleryNavigationReturn = {
*/
export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
const dispatch = useAppDispatch();
const alt = useAltModifier();
const lastSelectedImage = useAppSelector((s) => {
const lastSelected = s.gallery.selection.slice(-1)[0] ?? null;
if (alt) {
return s.gallery.imageToCompare ?? lastSelected;
} else {
return lastSelected;
}
});
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
const {
queryResult: { data },
} = useGalleryImages();
@@ -146,7 +136,7 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
}, [lastSelectedImage, data]);
const handleNavigation = useCallback(
(direction: 'left' | 'right' | 'up' | 'down', alt?: boolean) => {
(direction: 'left' | 'right' | 'up' | 'down') => {
if (!data) {
return;
}
@@ -154,14 +144,10 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
if (!image || index === lastSelectedImageIndex) {
return;
}
if (alt) {
dispatch(imageToCompareChanged(image));
} else {
dispatch(imageSelected(image));
}
dispatch(imageSelected(image));
scrollToImage(image.image_name, index);
},
[data, lastSelectedImageIndex, dispatch]
[dispatch, lastSelectedImageIndex, data]
);
const isOnFirstImage = useMemo(() => lastSelectedImageIndex === 0, [lastSelectedImageIndex]);
@@ -176,41 +162,21 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
return lastSelectedImageIndex + imagesPerRow < loadedImagesCount;
}, [lastSelectedImageIndex, loadedImagesCount]);
const handleLeftImage = useCallback(
(alt?: boolean) => {
handleNavigation('left', alt);
},
[handleNavigation]
);
const handleLeftImage = useCallback(() => {
handleNavigation('left');
}, [handleNavigation]);
const handleRightImage = useCallback(
(alt?: boolean) => {
handleNavigation('right', alt);
},
[handleNavigation]
);
const handleRightImage = useCallback(() => {
handleNavigation('right');
}, [handleNavigation]);
const handleUpImage = useCallback(
(alt?: boolean) => {
handleNavigation('up', alt);
},
[handleNavigation]
);
const handleUpImage = useCallback(() => {
handleNavigation('up');
}, [handleNavigation]);
const handleDownImage = useCallback(
(alt?: boolean) => {
handleNavigation('down', alt);
},
[handleNavigation]
);
const nextImage = useCallback(() => {
handleRightImage();
}, [handleRightImage]);
const prevImage = useCallback(() => {
handleLeftImage();
}, [handleLeftImage]);
const handleDownImage = useCallback(() => {
handleNavigation('down');
}, [handleNavigation]);
return {
handleLeftImage,
@@ -220,7 +186,5 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
isOnFirstImage,
isOnLastImage,
areImagesBelowCurrent,
nextImage,
prevImage,
};
};

View File

@@ -36,7 +36,6 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
shiftKey: e.shiftKey,
ctrlKey: e.ctrlKey,
metaKey: e.metaKey,
altKey: e.altKey,
})
);
},

View File

@@ -6,7 +6,7 @@ import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import type { BoardId, ComparisonMode, GalleryState, GalleryView } from './types';
import type { BoardId, GalleryState, GalleryView } from './types';
import { IMAGE_LIMIT, INITIAL_IMAGE_LIMIT } from './types';
const initialGalleryState: GalleryState = {
@@ -22,9 +22,6 @@ const initialGalleryState: GalleryState = {
limit: INITIAL_IMAGE_LIMIT,
offset: 0,
isImageViewerOpen: true,
imageToCompare: null,
comparisonMode: 'slider',
comparisonFit: 'fill',
};
export const gallerySlice = createSlice({
@@ -37,28 +34,6 @@ export const gallerySlice = createSlice({
selectionChanged: (state, action: PayloadAction<ImageDTO[]>) => {
state.selection = uniqBy(action.payload, (i) => i.image_name);
},
imageToCompareChanged: (state, action: PayloadAction<ImageDTO | null>) => {
state.imageToCompare = action.payload;
if (action.payload) {
state.isImageViewerOpen = true;
}
},
comparisonModeChanged: (state, action: PayloadAction<ComparisonMode>) => {
state.comparisonMode = action.payload;
},
comparisonModeCycled: (state) => {
switch (state.comparisonMode) {
case 'slider':
state.comparisonMode = 'side-by-side';
break;
case 'side-by-side':
state.comparisonMode = 'hover';
break;
case 'hover':
state.comparisonMode = 'slider';
break;
}
},
shouldAutoSwitchChanged: (state, action: PayloadAction<boolean>) => {
state.shouldAutoSwitch = action.payload;
},
@@ -104,16 +79,6 @@ export const gallerySlice = createSlice({
isImageViewerOpenChanged: (state, action: PayloadAction<boolean>) => {
state.isImageViewerOpen = action.payload;
},
comparedImagesSwapped: (state) => {
if (state.imageToCompare) {
const oldSelection = state.selection;
state.selection = [state.imageToCompare];
state.imageToCompare = oldSelection[0] ?? null;
}
},
comparisonFitChanged: (state, action: PayloadAction<'contain' | 'fill'>) => {
state.comparisonFit = action.payload;
},
},
extraReducers: (builder) => {
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
@@ -152,11 +117,6 @@ export const {
moreImagesLoaded,
alwaysShowImageSizeBadgeChanged,
isImageViewerOpenChanged,
imageToCompareChanged,
comparisonModeChanged,
comparedImagesSwapped,
comparisonFitChanged,
comparisonModeCycled,
} = gallerySlice.actions;
const isAnyBoardDeleted = isAnyOf(
@@ -178,13 +138,5 @@ export const galleryPersistConfig: PersistConfig<GalleryState> = {
name: gallerySlice.name,
initialState: initialGalleryState,
migrate: migrateGalleryState,
persistDenylist: [
'selection',
'selectedBoardId',
'galleryView',
'offset',
'limit',
'isImageViewerOpen',
'imageToCompare',
],
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'offset', 'limit', 'isImageViewerOpen'],
};

View File

@@ -7,8 +7,6 @@ export const IMAGE_LIMIT = 20;
export type GalleryView = 'images' | 'assets';
export type BoardId = 'none' | (string & Record<never, never>);
export type ComparisonMode = 'slider' | 'side-by-side' | 'hover';
export type ComparisonFit = 'contain' | 'fill';
export type GalleryState = {
selection: ImageDTO[];
@@ -22,8 +20,5 @@ export type GalleryState = {
offset: number;
limit: number;
alwaysShowImageSizeBadge: boolean;
imageToCompare: ImageDTO | null;
comparisonMode: ComparisonMode;
comparisonFit: ComparisonFit;
isImageViewerOpen: boolean;
};

View File

@@ -1,7 +1,4 @@
import { getStore } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import { objectKeys } from 'common/util/objectKeys';
import { shouldConcatPromptsChanged } from 'features/controlLayers/store/controlLayersSlice';
import type { Layer } from 'features/controlLayers/store/types';
import type { LoRA } from 'features/lora/store/loraSlice';
import type {
@@ -19,7 +16,6 @@ import { validators } from 'features/metadata/util/validators';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { toast } from 'features/toast/toast';
import { t } from 'i18next';
import { size } from 'lodash-es';
import { assert } from 'tsafe';
import { parsers } from './parsers';
@@ -380,25 +376,54 @@ export const handlers = {
}),
} as const;
type ParsedValue = Awaited<ReturnType<(typeof handlers)[keyof typeof handlers]['parse']>>;
type RecallResults = Partial<Record<keyof typeof handlers, ParsedValue>>;
export const parseAndRecallPrompts = async (metadata: unknown) => {
const keysToRecall: (keyof typeof handlers)[] = [
'positivePrompt',
'negativePrompt',
'sdxlPositiveStylePrompt',
'sdxlNegativeStylePrompt',
];
const recalled = await recallKeys(keysToRecall, metadata);
if (size(recalled) > 0) {
const results = await Promise.allSettled([
handlers.positivePrompt.parse(metadata).then((positivePrompt) => {
if (!handlers.positivePrompt.recall) {
return;
}
handlers.positivePrompt?.recall(positivePrompt);
}),
handlers.negativePrompt.parse(metadata).then((negativePrompt) => {
if (!handlers.negativePrompt.recall) {
return;
}
handlers.negativePrompt?.recall(negativePrompt);
}),
handlers.sdxlPositiveStylePrompt.parse(metadata).then((sdxlPositiveStylePrompt) => {
if (!handlers.sdxlPositiveStylePrompt.recall) {
return;
}
handlers.sdxlPositiveStylePrompt?.recall(sdxlPositiveStylePrompt);
}),
handlers.sdxlNegativeStylePrompt.parse(metadata).then((sdxlNegativeStylePrompt) => {
if (!handlers.sdxlNegativeStylePrompt.recall) {
return;
}
handlers.sdxlNegativeStylePrompt?.recall(sdxlNegativeStylePrompt);
}),
]);
if (results.some((result) => result.status === 'fulfilled')) {
parameterSetToast(t('metadata.allPrompts'));
}
};
export const parseAndRecallImageDimensions = async (metadata: unknown) => {
const recalled = recallKeys(['width', 'height'], metadata);
if (size(recalled) > 0) {
const results = await Promise.allSettled([
handlers.width.parse(metadata).then((width) => {
if (!handlers.width.recall) {
return;
}
handlers.width?.recall(width);
}),
handlers.height.parse(metadata).then((height) => {
if (!handlers.height.recall) {
return;
}
handlers.height?.recall(height);
}),
]);
if (results.some((result) => result.status === 'fulfilled')) {
parameterSetToast(t('metadata.imageDimensions'));
}
};
@@ -413,20 +438,28 @@ export const parseAndRecallAllMetadata = async (
toControlLayers: boolean,
skip: (keyof typeof handlers)[] = []
) => {
const skipKeys = deepClone(skip);
const skipKeys = skip ?? [];
if (toControlLayers) {
skipKeys.push(...TO_CONTROL_LAYERS_SKIP_KEYS);
} else {
skipKeys.push(...NOT_TO_CONTROL_LAYERS_SKIP_KEYS);
}
const results = await Promise.allSettled(
objectKeys(handlers)
.filter((key) => !skipKeys.includes(key))
.map((key) => {
const { parse, recall } = handlers[key];
return parse(metadata).then((value) => {
if (!recall) {
return;
}
/* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */
recall(value);
});
})
);
// We may need to take some further action depending on what was recalled. For example, we need to disable SDXL prompt
// concat if the negative or positive style prompt was set. Because the recalling is all async, we need to collect all
// results
const keysToRecall = objectKeys(handlers).filter((key) => !skipKeys.includes(key));
const recalled = await recallKeys(keysToRecall, metadata);
if (size(recalled) > 0) {
if (results.some((result) => result.status === 'fulfilled')) {
toast({
id: 'PARAMETER_SET',
title: t('toast.parametersSet'),
@@ -440,43 +473,3 @@ export const parseAndRecallAllMetadata = async (
});
}
};
/**
* Recalls a set of keys from metadata.
* Includes special handling for some metadata where recalling may have side effects. For example, recalling a "style"
* prompt that is different from the "positive" or "negative" prompt should disable prompt concatenation.
* @param keysToRecall An array of keys to recall.
* @param metadata The metadata to recall from
* @returns A promise that resolves to an object containing the recalled values.
*/
const recallKeys = async (keysToRecall: (keyof typeof handlers)[], metadata: unknown): Promise<RecallResults> => {
const { dispatch } = getStore();
const recalled: RecallResults = {};
for (const key of keysToRecall) {
const { parse, recall } = handlers[key];
if (!recall) {
continue;
}
try {
const value = await parse(metadata);
/* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */
await recall(value);
recalled[key] = value;
} catch {
// no-op
}
}
if (
(recalled['sdxlPositiveStylePrompt'] && recalled['sdxlPositiveStylePrompt'] !== recalled['positivePrompt']) ||
(recalled['sdxlNegativeStylePrompt'] && recalled['sdxlNegativeStylePrompt'] !== recalled['negativePrompt'])
) {
// If we set the negative style prompt or positive style prompt, we should disable prompt concat
dispatch(shouldConcatPromptsChanged(false));
} else {
// Otherwise, we should enable prompt concat
dispatch(shouldConcatPromptsChanged(true));
}
return recalled;
};

View File

@@ -1,7 +1,6 @@
import { getStore } from 'app/store/nanostores/store';
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
import type { ModelIdentifier } from 'features/nodes/types/v2/common';
import { modelsApi } from 'services/api/endpoints/models';
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
@@ -108,30 +107,19 @@ export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
/**
* Fetches the model key from a model identifier. This includes fetching the key for MM1 format model identifiers.
* @param modelIdentifier The model identifier. This can be a MM1 or MM2 identifier. In every case, we attempt to fetch
* the model config from the server to ensure that the model identifier is valid and represents an installed model.
* @param modelIdentifier The model identifier. The MM2 format `{key: string}` simply extracts the key. The MM1 format
* `{model_name: string, base_model: BaseModelType}` must do a network request to fetch the key.
* @param type The type of model to fetch. This is used to fetch the key for MM1 format model identifiers.
* @param message An optional custom message to include in the error if the model identifier is invalid.
* @returns A promise that resolves to the model key.
* @throws {InvalidModelConfigError} If the model identifier is invalid.
*/
export const getModelKey = async (
modelIdentifier: unknown | ModelIdentifierField | ModelIdentifier,
type: ModelType,
message?: string
): Promise<string> => {
export const getModelKey = async (modelIdentifier: unknown, type: ModelType, message?: string): Promise<string> => {
if (isModelIdentifier(modelIdentifier)) {
try {
// Check if the model exists by key
return (await fetchModelConfig(modelIdentifier.key)).key;
} catch {
// If not, fetch the model key by name and base model
return (await fetchModelConfigByAttrs(modelIdentifier.name, modelIdentifier.base, type)).key;
}
} else if (isModelIdentifierV2(modelIdentifier)) {
// Try by old-format model identifier
return modelIdentifier.key;
}
if (isModelIdentifierV2(modelIdentifier)) {
return (await fetchModelConfigByAttrs(modelIdentifier.model_name, modelIdentifier.base_model, type)).key;
}
// Nope, couldn't find it
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
};

View File

@@ -72,12 +72,10 @@ export const ModelEdit = ({ form }: Props) => {
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
<BaseModelSelect control={form.control} />
</FormControl>
{data.type === 'main' && (
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={form.control} />
</FormControl>
)}
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
<FormLabel>{t('modelManager.variant')}</FormLabel>
<ModelVariantSelect control={form.control} />
</FormControl>
{data.type === 'main' && data.format === 'checkpoint' && (
<>
<FormControl flexDir="column" alignItems="flex-start" gap={1}>

View File

@@ -19,7 +19,7 @@ import {
redo,
undo,
} from 'features/nodes/store/nodesSlice';
import { $flow, $needsFit } from 'features/nodes/store/reactFlowInstance';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import type { CSSProperties, MouseEvent } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
@@ -68,7 +68,6 @@ export const Flow = memo(() => {
const nodes = useAppSelector((s) => s.nodes.present.nodes);
const edges = useAppSelector((s) => s.nodes.present.edges);
const viewport = useStore($viewport);
const needsFit = useStore($needsFit);
const mayUndo = useAppSelector((s) => s.nodes.past.length > 0);
const mayRedo = useAppSelector((s) => s.nodes.future.length > 0);
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
@@ -93,16 +92,8 @@ export const Flow = memo(() => {
const onNodesChange: OnNodesChange = useCallback(
(nodeChanges) => {
dispatch(nodesChanged(nodeChanges));
const flow = $flow.get();
if (!flow) {
return;
}
if (needsFit) {
$needsFit.set(false);
flow.fitView();
}
},
[dispatch, needsFit]
[dispatch]
);
const onEdgesChange: OnEdgesChange = useCallback(

View File

@@ -15,20 +15,27 @@ const ViewportControls = () => {
const { t } = useTranslation();
const { zoomIn, zoomOut, fitView } = useReactFlow();
const dispatch = useAppDispatch();
// const shouldShowFieldTypeLegend = useAppSelector(
// (s) => s.nodes.present.shouldShowFieldTypeLegend
// );
const shouldShowMinimapPanel = useAppSelector((s) => s.workflowSettings.shouldShowMinimapPanel);
const handleClickedZoomIn = useCallback(() => {
zoomIn({ duration: 300 });
zoomIn();
}, [zoomIn]);
const handleClickedZoomOut = useCallback(() => {
zoomOut({ duration: 300 });
zoomOut();
}, [zoomOut]);
const handleClickedFitView = useCallback(() => {
fitView({ duration: 300 });
fitView();
}, [fitView]);
// const handleClickedToggleFieldTypeLegend = useCallback(() => {
// dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
// }, [shouldShowFieldTypeLegend, dispatch]);
const handleClickedToggleMiniMapPanel = useCallback(() => {
dispatch(shouldShowMinimapPanelChanged(!shouldShowMinimapPanel));
}, [shouldShowMinimapPanel, dispatch]);

View File

@@ -1,7 +1,9 @@
import 'reactflow/dist/style.css';
import { Flex } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
import QueueControls from 'features/queue/components/QueueControls';
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
import { usePanelStorage } from 'features/ui/hooks/usePanelStorage';
@@ -19,8 +21,14 @@ import { WorkflowName } from './WorkflowName';
const panelGroupStyles: CSSProperties = { height: '100%', width: '100%' };
const selector = createMemoizedSelector(selectWorkflowSlice, (workflow) => {
return {
mode: workflow.mode,
};
});
const NodeEditorPanelGroup = () => {
const mode = useAppSelector((s) => s.workflow.mode);
const { mode } = useAppSelector(selector);
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
const panelStorage = usePanelStorage();

Some files were not shown because too many files have changed in this diff Show More