mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-16 14:28:03 -05:00
Compare commits
31 Commits
v4.2.4
...
psyche/fea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc78a0e699 | ||
|
|
08a42c3c03 | ||
|
|
0758e9cb9b | ||
|
|
fb93e686b2 | ||
|
|
350feeed56 | ||
|
|
169b75b2b7 | ||
|
|
c88de180e7 | ||
|
|
7d1844eaf2 | ||
|
|
a98ddedb95 | ||
|
|
6063487b20 | ||
|
|
9a4c167342 | ||
|
|
19227fe4e6 | ||
|
|
db0ef8d316 | ||
|
|
6a34176376 | ||
|
|
d6696a7b97 | ||
|
|
0e81e7b460 | ||
|
|
7652fbc2e9 | ||
|
|
a55b2f09e2 | ||
|
|
23b05344a3 | ||
|
|
80905ff3ea | ||
|
|
df5457231f | ||
|
|
d30c1ad6dc | ||
|
|
b1f819ae8d | ||
|
|
eff359625a | ||
|
|
cef1585dfb | ||
|
|
cb8e9e1c7b | ||
|
|
f7c356d142 | ||
|
|
efb069dd71 | ||
|
|
8edc25d35a | ||
|
|
82957bb826 | ||
|
|
e51a3025ea |
4
Makefile
4
Makefile
@@ -18,7 +18,6 @@ help:
|
||||
@echo "frontend-typegen Generate types for the frontend from the OpenAPI schema"
|
||||
@echo "installer-zip Build the installer .zip file for the current version"
|
||||
@echo "tag-release Tag the GitHub repository with the current version (use at release time only!)"
|
||||
@echo "openapi Generate the OpenAPI schema for the app, outputting to stdout"
|
||||
|
||||
# Runs ruff, fixing any safely-fixable errors and formatting
|
||||
ruff:
|
||||
@@ -71,6 +70,3 @@ installer-zip:
|
||||
tag-release:
|
||||
cd installer && ./tag_release.sh
|
||||
|
||||
# Generate the OpenAPI Schema for the app
|
||||
openapi:
|
||||
python scripts/generate_openapi_schema.py
|
||||
|
||||
@@ -64,7 +64,7 @@ GPU_DRIVER=nvidia
|
||||
|
||||
Any environment variables supported by InvokeAI can be set here - please see the [Configuration docs](https://invoke-ai.github.io/InvokeAI/features/CONFIGURATION/) for further detail.
|
||||
|
||||
## Even More Customizing!
|
||||
## Even Moar Customizing!
|
||||
|
||||
See the `docker-compose.yml` file. The `command` instruction can be uncommented and used to run arbitrary startup commands. Some examples below.
|
||||
|
||||
|
||||
@@ -154,18 +154,6 @@ This is caused by an invalid setting in the `invokeai.yaml` configuration file.
|
||||
|
||||
Check the [configuration docs] for more detail about the settings and how to specify them.
|
||||
|
||||
## `ModuleNotFoundError: No module named 'controlnet_aux'`
|
||||
|
||||
`controlnet_aux` is a dependency of Invoke and appears to have been packaged or distributed strangely. Sometimes, it doesn't install correctly. This is outside our control.
|
||||
|
||||
If you encounter this error, the solution is to remove the package from the `pip` cache and re-run the Invoke installer so a fresh, working version of `controlnet_aux` can be downloaded and installed:
|
||||
|
||||
- Run the Invoke launcher
|
||||
- Choose the developer console option
|
||||
- Run this command: `pip cache remove controlnet_aux`
|
||||
- Close the terminal window
|
||||
- Download and run the [installer](https://github.com/invoke-ai/InvokeAI/releases/latest), selecting your current install location
|
||||
|
||||
## Out of Memory Issues
|
||||
|
||||
The models are large, VRAM is expensive, and you may find yourself
|
||||
|
||||
@@ -20,7 +20,7 @@ When you generate an image using text-to-image, multiple steps occur in latent s
|
||||
4. The VAE decodes the final latent image from latent space into image space.
|
||||
|
||||
Image-to-image is a similar process, with only step 1 being different:
|
||||
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how many noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process.
|
||||
1. The input image is encoded from image space into latent space by the VAE. Noise is then added to the input latent image. Denoising Strength dictates how may noise steps are added, and the amount of noise added at each step. A Denoising Strength of 0 means there are 0 steps and no noise added, resulting in an unchanged image, while a Denoising Strength of 1 results in the image being completely replaced with noise and a full set of denoising steps are performance. The process is then the same as steps 2-4 in the text-to-image process.
|
||||
|
||||
Furthermore, a model provides the CLIP prompt tokenizer, the VAE, and a U-Net (where noise prediction occurs given a prompt and initial noise tensor).
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from ..services.boards.boards_default import BoardService
|
||||
from ..services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from ..services.config import InvokeAIAppConfig
|
||||
from ..services.download import DownloadQueueService
|
||||
from ..services.events.events_fastapievents import FastAPIEventService
|
||||
from ..services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from ..services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
||||
from ..services.images.images_default import ImageService
|
||||
@@ -34,6 +33,7 @@ from ..services.session_processor.session_processor_default import DefaultSessio
|
||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from ..services.urls.urls_default import LocalUrlService
|
||||
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from .events import FastAPIEventService
|
||||
|
||||
|
||||
# TODO: is there a better way to achieve this?
|
||||
@@ -103,6 +103,7 @@ class ApiDependencies:
|
||||
)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
|
||||
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
|
||||
52
invokeai/app/api/events.py
Normal file
52
invokeai/app/api/events.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
from typing import Any
|
||||
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from ..services.events.events_base import EventServiceBase
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
event_handler_id: int
|
||||
__queue: Queue
|
||||
__stop_event: threading.Event
|
||||
|
||||
def __init__(self, event_handler_id: int) -> None:
|
||||
self.event_handler_id = event_handler_id
|
||||
self.__queue = Queue()
|
||||
self.__stop_event = threading.Event()
|
||||
asyncio.create_task(self.__dispatch_from_queue(stop_event=self.__stop_event))
|
||||
|
||||
super().__init__()
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self.__stop_event.set()
|
||||
self.__queue.put(None)
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
self.__queue.put({"event_name": event_name, "payload": payload})
|
||||
|
||||
async def __dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = self.__queue.get(block=False)
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
|
||||
dispatch(
|
||||
event.get("event_name"),
|
||||
payload=event.get("payload"),
|
||||
middleware_id=self.event_handler_id,
|
||||
)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
||||
@@ -17,7 +17,7 @@ from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.model_images.model_images_common import ModelImageFileNotFoundException
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
from invokeai.app.services.model_install import ModelInstallJob
|
||||
from invokeai.app.services.model_records import (
|
||||
DuplicateModelException,
|
||||
InvalidModelException,
|
||||
|
||||
@@ -1,125 +1,66 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from socketio import ASGIApp, AsyncServer
|
||||
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BatchEnqueuedEvent,
|
||||
BulkDownloadCompleteEvent,
|
||||
BulkDownloadErrorEvent,
|
||||
BulkDownloadEventBase,
|
||||
BulkDownloadStartedEvent,
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadEventBase,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
FastAPIEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
ModelEventBase,
|
||||
ModelInstallCancelledEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallErrorEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelLoadStartedEvent,
|
||||
QueueClearedEvent,
|
||||
QueueEventBase,
|
||||
QueueItemStatusChangedEvent,
|
||||
register_events,
|
||||
)
|
||||
|
||||
|
||||
class QueueSubscriptionEvent(BaseModel):
|
||||
"""Event data for subscribing to the socket.io queue room.
|
||||
This is a pydantic model to ensure the data is in the correct format."""
|
||||
|
||||
queue_id: str
|
||||
|
||||
|
||||
class BulkDownloadSubscriptionEvent(BaseModel):
|
||||
"""Event data for subscribing to the socket.io bulk downloads room.
|
||||
This is a pydantic model to ensure the data is in the correct format."""
|
||||
|
||||
bulk_download_id: str
|
||||
|
||||
|
||||
QUEUE_EVENTS = {
|
||||
InvocationStartedEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationCompleteEvent,
|
||||
InvocationErrorEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
BatchEnqueuedEvent,
|
||||
QueueClearedEvent,
|
||||
}
|
||||
|
||||
MODEL_EVENTS = {
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
ModelLoadStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallCancelledEvent,
|
||||
ModelInstallErrorEvent,
|
||||
}
|
||||
|
||||
BULK_DOWNLOAD_EVENTS = {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent}
|
||||
from ..services.events.events_base import EventServiceBase
|
||||
|
||||
|
||||
class SocketIO:
|
||||
_sub_queue = "subscribe_queue"
|
||||
_unsub_queue = "unsubscribe_queue"
|
||||
__sio: AsyncServer
|
||||
__app: ASGIApp
|
||||
|
||||
_sub_bulk_download = "subscribe_bulk_download"
|
||||
_unsub_bulk_download = "unsubscribe_bulk_download"
|
||||
__sub_queue: str = "subscribe_queue"
|
||||
__unsub_queue: str = "unsubscribe_queue"
|
||||
|
||||
__sub_bulk_download: str = "subscribe_bulk_download"
|
||||
__unsub_bulk_download: str = "unsubscribe_bulk_download"
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
|
||||
app.mount("/ws", self._app)
|
||||
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
|
||||
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
|
||||
app.mount("/ws", self.__app)
|
||||
|
||||
self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
|
||||
self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
|
||||
self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
|
||||
self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)
|
||||
self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue)
|
||||
self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
|
||||
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
|
||||
|
||||
register_events(QUEUE_EVENTS, self._handle_queue_event)
|
||||
register_events(MODEL_EVENTS, self._handle_model_event)
|
||||
register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)
|
||||
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
|
||||
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
|
||||
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)
|
||||
|
||||
async def _handle_sub_queue(self, sid: str, data: Any) -> None:
|
||||
await self._sio.enter_room(sid, QueueSubscriptionEvent(**data).queue_id)
|
||||
async def _handle_queue_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["queue_id"],
|
||||
)
|
||||
|
||||
async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
|
||||
await self._sio.leave_room(sid, QueueSubscriptionEvent(**data).queue_id)
|
||||
async def _handle_sub_queue(self, sid, data, *args, **kwargs) -> None:
|
||||
if "queue_id" in data:
|
||||
await self.__sio.enter_room(sid, data["queue_id"])
|
||||
|
||||
async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
|
||||
await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
||||
async def _handle_unsub_queue(self, sid, data, *args, **kwargs) -> None:
|
||||
if "queue_id" in data:
|
||||
await self.__sio.leave_room(sid, data["queue_id"])
|
||||
|
||||
async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
|
||||
await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent(**data).bulk_download_id)
|
||||
async def _handle_model_event(self, event: Event) -> None:
|
||||
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
|
||||
|
||||
async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]):
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)
|
||||
async def _handle_bulk_download_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["bulk_download_id"],
|
||||
)
|
||||
|
||||
async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase | DownloadEventBase]) -> None:
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))
|
||||
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
|
||||
if "bulk_download_id" in data:
|
||||
await self.__sio.enter_room(sid, data["bulk_download_id"])
|
||||
|
||||
async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
|
||||
await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)
|
||||
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
|
||||
if "bulk_download_id" in data:
|
||||
await self.__sio.leave_room(sid, data["bulk_download_id"])
|
||||
|
||||
@@ -3,7 +3,9 @@ import logging
|
||||
import mimetypes
|
||||
import socket
|
||||
from contextlib import asynccontextmanager
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
@@ -11,9 +13,11 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.middleware import EventHandlerASGIMiddleware
|
||||
from pydantic.json_schema import models_json_schema
|
||||
from torch.backends.mps import is_available as is_mps_available
|
||||
|
||||
# for PyCharm:
|
||||
@@ -21,8 +25,9 @@ from torch.backends.mps import is_available as is_mps_available
|
||||
import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.util.custom_openapi import get_openapi_func
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from ..backend.util.logging import InvokeAILogger
|
||||
@@ -39,6 +44,11 @@ from .api.routers import (
|
||||
workflows,
|
||||
)
|
||||
from .api.sockets import SocketIO
|
||||
from .invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
UIConfigBase,
|
||||
)
|
||||
from .invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
|
||||
app_config = get_config()
|
||||
|
||||
@@ -108,7 +118,93 @@ app.include_router(app_info.app_router, prefix="/api")
|
||||
app.include_router(session_queue.session_queue_router, prefix="/api")
|
||||
app.include_router(workflows.workflows_router, prefix="/api")
|
||||
|
||||
app.openapi = get_openapi_func(app)
|
||||
|
||||
# Build a custom OpenAPI to include all outputs
|
||||
# TODO: can outputs be included on metadata of invocation schemas somehow?
|
||||
def custom_openapi() -> dict[str, Any]:
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
description="An API for invoking AI image operations",
|
||||
version="1.0.0",
|
||||
routes=app.routes,
|
||||
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
|
||||
)
|
||||
|
||||
# Add all outputs
|
||||
all_invocations = BaseInvocation.get_invocations()
|
||||
output_types = set()
|
||||
output_type_titles = {}
|
||||
for invoker in all_invocations:
|
||||
output_type = signature(invoker.invoke).return_annotation
|
||||
output_types.add(output_type)
|
||||
|
||||
output_schemas = models_json_schema(
|
||||
models=[(o, "serialization") for o in output_types], ref_template="#/components/schemas/{model}"
|
||||
)
|
||||
for schema_key, output_schema in output_schemas[1]["$defs"].items():
|
||||
# TODO: note that we assume the schema_key here is the TYPE.__name__
|
||||
# This could break in some cases, figure out a better way to do it
|
||||
output_type_titles[schema_key] = output_schema["title"]
|
||||
openapi_schema["components"]["schemas"][schema_key] = output_schema
|
||||
openapi_schema["components"]["schemas"][schema_key]["class"] = "output"
|
||||
|
||||
# Some models don't end up in the schemas as standalone definitions
|
||||
additional_schemas = models_json_schema(
|
||||
[
|
||||
(UIConfigBase, "serialization"),
|
||||
(InputFieldJSONSchemaExtra, "serialization"),
|
||||
(OutputFieldJSONSchemaExtra, "serialization"),
|
||||
(ModelIdentifierField, "serialization"),
|
||||
(ProgressImage, "serialization"),
|
||||
],
|
||||
ref_template="#/components/schemas/{model}",
|
||||
)
|
||||
for schema_key, schema_json in additional_schemas[1]["$defs"].items():
|
||||
openapi_schema["components"]["schemas"][schema_key] = schema_json
|
||||
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
# Add a reference to the output type to additionalProperties of the invoker schema
|
||||
for invoker in all_invocations:
|
||||
invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute
|
||||
output_type = signature(obj=invoker.invoke).return_annotation
|
||||
output_type_title = output_type_titles[output_type.__name__]
|
||||
invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"]
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_type_title}"}
|
||||
invoker_schema["output"] = outputs_ref
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["properties"][invoker.get_type()] = outputs_ref
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"]["required"].append(invoker.get_type())
|
||||
invoker_schema["class"] = "invocation"
|
||||
|
||||
# This code no longer seems to be necessary?
|
||||
# Leave it here just in case
|
||||
#
|
||||
# from invokeai.backend.model_manager import get_model_config_formats
|
||||
# formats = get_model_config_formats()
|
||||
# for model_config_name, enum_set in formats.items():
|
||||
|
||||
# if model_config_name in openapi_schema["components"]["schemas"]:
|
||||
# # print(f"Config with name {name} already defined")
|
||||
# continue
|
||||
|
||||
# openapi_schema["components"]["schemas"][model_config_name] = {
|
||||
# "title": model_config_name,
|
||||
# "description": "An enumeration.",
|
||||
# "type": "string",
|
||||
# "enum": [v.value for v in enum_set],
|
||||
# }
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
|
||||
app.openapi = custom_openapi # type: ignore [method-assign] # this is a valid assignment
|
||||
|
||||
|
||||
@app.get("/docs", include_in_schema=False)
|
||||
|
||||
@@ -98,13 +98,11 @@ class BaseInvocationOutput(BaseModel):
|
||||
|
||||
_output_classes: ClassVar[set[BaseInvocationOutput]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
_typeadapter_needs_update: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def register_output(cls, output: BaseInvocationOutput) -> None:
|
||||
"""Registers an invocation output."""
|
||||
cls._output_classes.add(output)
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_outputs(cls) -> Iterable[BaseInvocationOutput]:
|
||||
@@ -114,12 +112,11 @@ class BaseInvocationOutput(BaseModel):
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation output types."""
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocationOutput = TypeAliasType(
|
||||
"AnyInvocationOutput", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
if not cls._typeadapter:
|
||||
InvocationOutputsUnion = TypeAliasType(
|
||||
"InvocationOutputsUnion", Annotated[Union[tuple(cls._output_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocationOutput)
|
||||
cls._typeadapter_needs_update = False
|
||||
cls._typeadapter = TypeAdapter(InvocationOutputsUnion)
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
@@ -128,13 +125,12 @@ class BaseInvocationOutput(BaseModel):
|
||||
return (i.get_type() for i in BaseInvocationOutput.get_outputs())
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocationOutput]) -> None:
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation output's OpenAPI schema."""
|
||||
# Because we use a pydantic Literal field with default value for the invocation type,
|
||||
# it will be typed as optional in the OpenAPI schema. Make it required manually.
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = []
|
||||
schema["class"] = "output"
|
||||
schema["required"].extend(["type"])
|
||||
|
||||
@classmethod
|
||||
@@ -171,7 +167,6 @@ class BaseInvocation(ABC, BaseModel):
|
||||
|
||||
_invocation_classes: ClassVar[set[BaseInvocation]] = set()
|
||||
_typeadapter: ClassVar[Optional[TypeAdapter[Any]]] = None
|
||||
_typeadapter_needs_update: ClassVar[bool] = False
|
||||
|
||||
@classmethod
|
||||
def get_type(cls) -> str:
|
||||
@@ -182,17 +177,15 @@ class BaseInvocation(ABC, BaseModel):
|
||||
def register_invocation(cls, invocation: BaseInvocation) -> None:
|
||||
"""Registers an invocation."""
|
||||
cls._invocation_classes.add(invocation)
|
||||
cls._typeadapter_needs_update = True
|
||||
|
||||
@classmethod
|
||||
def get_typeadapter(cls) -> TypeAdapter[Any]:
|
||||
"""Gets a pydantc TypeAdapter for the union of all invocation types."""
|
||||
if not cls._typeadapter or cls._typeadapter_needs_update:
|
||||
AnyInvocation = TypeAliasType(
|
||||
"AnyInvocation", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||
if not cls._typeadapter:
|
||||
InvocationsUnion = TypeAliasType(
|
||||
"InvocationsUnion", Annotated[Union[tuple(cls._invocation_classes)], Field(discriminator="type")]
|
||||
)
|
||||
cls._typeadapter = TypeAdapter(AnyInvocation)
|
||||
cls._typeadapter_needs_update = False
|
||||
cls._typeadapter = TypeAdapter(InvocationsUnion)
|
||||
return cls._typeadapter
|
||||
|
||||
@classmethod
|
||||
@@ -228,7 +221,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
return signature(cls.invoke).return_annotation
|
||||
|
||||
@staticmethod
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseInvocation]) -> None:
|
||||
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None:
|
||||
"""Adds various UI-facing attributes to the invocation's OpenAPI schema."""
|
||||
uiconfig = cast(UIConfigBase | None, getattr(model_class, "UIConfig", None))
|
||||
if uiconfig is not None:
|
||||
@@ -244,7 +237,6 @@ class BaseInvocation(ABC, BaseModel):
|
||||
schema["version"] = uiconfig.version
|
||||
if "required" not in schema or not isinstance(schema["required"], list):
|
||||
schema["required"] = []
|
||||
schema["class"] = "invocation"
|
||||
schema["required"].extend(["type", "id"])
|
||||
|
||||
@abstractmethod
|
||||
@@ -318,7 +310,7 @@ class BaseInvocation(ABC, BaseModel):
|
||||
protected_namespaces=(),
|
||||
validate_assignment=True,
|
||||
json_schema_extra=json_schema_extra,
|
||||
json_schema_serialization_defaults_required=False,
|
||||
json_schema_serialization_defaults_required=True,
|
||||
coerce_numbers_to_str=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -65,7 +65,11 @@ class CompelInvocation(BaseInvocation):
|
||||
@torch.no_grad()
|
||||
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
||||
tokenizer_info = context.models.load(self.clip.tokenizer)
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(self.clip.text_encoder)
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, CLIPTextModel)
|
||||
|
||||
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
|
||||
for lora in self.clip.loras:
|
||||
@@ -80,21 +84,19 @@ class CompelInvocation(BaseInvocation):
|
||||
ti_list = generate_ti_list(self.prompt, text_encoder_info.config.base, context)
|
||||
|
||||
with (
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info as text_encoder,
|
||||
tokenizer_info as tokenizer,
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
|
||||
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
|
||||
patched_tokenizer,
|
||||
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
text_encoder_info as text_encoder,
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, self.clip.skipped_layers),
|
||||
):
|
||||
assert isinstance(text_encoder, CLIPTextModel)
|
||||
assert isinstance(tokenizer, CLIPTokenizer)
|
||||
compel = Compel(
|
||||
tokenizer=patched_tokenizer,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||
@@ -104,7 +106,7 @@ class CompelInvocation(BaseInvocation):
|
||||
conjunction = Compel.parse_prompt_string(self.prompt)
|
||||
|
||||
if context.config.get().log_tokenization:
|
||||
log_tokenization_for_conjunction(conjunction, patched_tokenizer)
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
|
||||
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
|
||||
@@ -134,7 +136,11 @@ class SDXLPromptInvocationBase:
|
||||
zero_on_empty: bool,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
tokenizer_info = context.models.load(clip_field.tokenizer)
|
||||
tokenizer_model = tokenizer_info.model
|
||||
assert isinstance(tokenizer_model, CLIPTokenizer)
|
||||
text_encoder_info = context.models.load(clip_field.text_encoder)
|
||||
text_encoder_model = text_encoder_info.model
|
||||
assert isinstance(text_encoder_model, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
|
||||
# return zero on empty
|
||||
if prompt == "" and zero_on_empty:
|
||||
@@ -171,23 +177,20 @@ class SDXLPromptInvocationBase:
|
||||
ti_list = generate_ti_list(prompt, text_encoder_info.config.base, context)
|
||||
|
||||
with (
|
||||
# apply all patches while the model is on the target device
|
||||
text_encoder_info as text_encoder,
|
||||
tokenizer_info as tokenizer,
|
||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
|
||||
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
|
||||
patched_tokenizer,
|
||||
ModelPatcher.apply_ti(tokenizer_model, text_encoder_model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
text_encoder_info as text_encoder,
|
||||
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||
ModelPatcher.apply_clip_skip(text_encoder_model, clip_field.skipped_layers),
|
||||
):
|
||||
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
|
||||
assert isinstance(tokenizer, CLIPTokenizer)
|
||||
|
||||
text_encoder = cast(CLIPTextModel, text_encoder)
|
||||
compel = Compel(
|
||||
tokenizer=patched_tokenizer,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
textual_inversion_manager=ti_manager,
|
||||
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
|
||||
@@ -200,7 +203,7 @@ class SDXLPromptInvocationBase:
|
||||
|
||||
if context.config.get().log_tokenization:
|
||||
# TODO: better logging for and syntax
|
||||
log_tokenization_for_conjunction(conjunction, patched_tokenizer)
|
||||
log_tokenization_for_conjunction(conjunction, tokenizer)
|
||||
|
||||
# TODO: ask for optimizations? to not run text_encoder twice
|
||||
c, _options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
|
||||
@@ -50,7 +50,7 @@ from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput,
|
||||
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
||||
from invokeai.backend.model_manager.config import MainConfigBase, ModelVariantType
|
||||
@@ -672,52 +672,54 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
return controlnet_data
|
||||
|
||||
def prep_ip_adapter_image_prompts(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
ip_adapters: List[IPAdapterField],
|
||||
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
|
||||
image_prompts = []
|
||||
for single_ip_adapter in ip_adapters:
|
||||
with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
|
||||
assert isinstance(ip_adapter_model, IPAdapter)
|
||||
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
single_ipa_image_fields = [single_ipa_image_fields]
|
||||
|
||||
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
|
||||
with image_encoder_model_info as image_encoder_model:
|
||||
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||
single_ipa_images, image_encoder_model
|
||||
)
|
||||
image_prompts.append((image_prompt_embeds, uncond_image_prompt_embeds))
|
||||
|
||||
return image_prompts
|
||||
|
||||
def prep_ip_adapter_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
ip_adapters: List[IPAdapterField],
|
||||
image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]],
|
||||
exit_stack: ExitStack,
|
||||
latent_height: int,
|
||||
latent_width: int,
|
||||
dtype: torch.dtype,
|
||||
) -> Optional[List[IPAdapterData]]:
|
||||
"""If IP-Adapter is enabled, then this function loads the requisite models and adds the image prompt conditioning data."""
|
||||
ip_adapter_data_list = []
|
||||
for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
|
||||
ip_adapters, image_prompts, strict=True
|
||||
):
|
||||
ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
|
||||
) -> Optional[list[IPAdapterData]]:
|
||||
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
|
||||
to the `conditioning_data` (in-place).
|
||||
"""
|
||||
if ip_adapter is None:
|
||||
return None
|
||||
|
||||
mask_field = single_ip_adapter.mask
|
||||
mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
|
||||
# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
|
||||
if not isinstance(ip_adapter, list):
|
||||
ip_adapter = [ip_adapter]
|
||||
|
||||
if len(ip_adapter) == 0:
|
||||
return None
|
||||
|
||||
ip_adapter_data_list = []
|
||||
for single_ip_adapter in ip_adapter:
|
||||
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
|
||||
context.models.load(single_ip_adapter.ip_adapter_model)
|
||||
)
|
||||
|
||||
image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
|
||||
# `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
|
||||
single_ipa_image_fields = single_ip_adapter.image
|
||||
if not isinstance(single_ipa_image_fields, list):
|
||||
single_ipa_image_fields = [single_ipa_image_fields]
|
||||
|
||||
single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
|
||||
|
||||
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
|
||||
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
|
||||
with image_encoder_model_info as image_encoder_model:
|
||||
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
|
||||
single_ipa_images, image_encoder_model
|
||||
)
|
||||
|
||||
mask = single_ip_adapter.mask
|
||||
if mask is not None:
|
||||
mask = context.tensors.load(mask.tensor_name)
|
||||
mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
|
||||
|
||||
ip_adapter_data_list.append(
|
||||
@@ -732,7 +734,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
)
|
||||
)
|
||||
|
||||
return ip_adapter_data_list if len(ip_adapter_data_list) > 0 else None
|
||||
return ip_adapter_data_list
|
||||
|
||||
def run_t2i_adapters(
|
||||
self,
|
||||
@@ -853,16 +855,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
# At some point, someone decided that schedulers that accept a generator should use the original seed with
|
||||
# all bits flipped. I don't know the original rationale for this, but now we must keep it like this for
|
||||
# reproducibility.
|
||||
#
|
||||
# These Invoke-supported schedulers accept a generator as of 2024-06-04:
|
||||
# - DDIMScheduler
|
||||
# - DDPMScheduler
|
||||
# - DPMSolverMultistepScheduler
|
||||
# - EulerAncestralDiscreteScheduler
|
||||
# - EulerDiscreteScheduler
|
||||
# - KDPM2AncestralDiscreteScheduler
|
||||
# - LCMScheduler
|
||||
# - TCDScheduler
|
||||
scheduler_step_kwargs.update({"generator": torch.Generator(device=device).manual_seed(seed ^ 0xFFFFFFFF)})
|
||||
if isinstance(scheduler, TCDScheduler):
|
||||
scheduler_step_kwargs.update({"eta": 1.0})
|
||||
@@ -920,20 +912,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
|
||||
ip_adapters: List[IPAdapterField] = []
|
||||
if self.ip_adapter is not None:
|
||||
# ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
|
||||
if isinstance(self.ip_adapter, list):
|
||||
ip_adapters = self.ip_adapter
|
||||
else:
|
||||
ip_adapters = [self.ip_adapter]
|
||||
|
||||
# If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
|
||||
# a series of image conditioning embeddings. This is being done here rather than in the
|
||||
# big model context below in order to use less VRAM on low-VRAM systems.
|
||||
# The image prompts are then passed to prep_ip_adapter_data().
|
||||
image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
|
||||
|
||||
# get the unet's config so that we can pass the base to dispatch_progress()
|
||||
unet_config = context.models.get_config(self.unet.unet.key)
|
||||
|
||||
@@ -952,9 +930,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
ModelPatcher.apply_freeu(unet_info.model, self.unet.freeu_config),
|
||||
set_seamless(unet_info.model, self.unet.seamless_axes), # FIXME
|
||||
unet_info as unet,
|
||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
|
||||
):
|
||||
@@ -992,8 +970,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
|
||||
ip_adapter_data = self.prep_ip_adapter_data(
|
||||
context=context,
|
||||
ip_adapters=ip_adapters,
|
||||
image_prompts=image_prompts,
|
||||
ip_adapter=self.ip_adapter,
|
||||
exit_stack=exit_stack,
|
||||
latent_height=latent_height,
|
||||
latent_width=latent_width,
|
||||
@@ -1308,7 +1285,7 @@ class ImageToLatentsInvocation(BaseInvocation):
|
||||
title="Blend Latents",
|
||||
tags=["latents", "blend"],
|
||||
category="latents",
|
||||
version="1.0.3",
|
||||
version="1.0.2",
|
||||
)
|
||||
class BlendLatentsInvocation(BaseInvocation):
|
||||
"""Blend two latents using a given alpha. Latents must have same size."""
|
||||
@@ -1387,7 +1364,7 @@ class BlendLatentsInvocation(BaseInvocation):
|
||||
TorchDevice.empty_cache()
|
||||
|
||||
name = context.tensors.save(tensor=blended_latents)
|
||||
return LatentsOutput.build(latents_name=name, latents=blended_latents, seed=self.latents_a.seed)
|
||||
return LatentsOutput.build(latents_name=name, latents=blended_latents)
|
||||
|
||||
|
||||
# The Crop Latents node was copied from @skunkworxdark's implementation here:
|
||||
|
||||
@@ -106,7 +106,9 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
self._invoker.services.events.emit_bulk_download_started(
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
)
|
||||
|
||||
def _signal_job_completed(
|
||||
@@ -116,8 +118,10 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
assert bulk_download_item_name is not None
|
||||
self._invoker.services.events.emit_bulk_download_complete(
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name
|
||||
self._invoker.services.events.emit_bulk_download_completed(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
)
|
||||
|
||||
def _signal_job_failed(
|
||||
@@ -127,8 +131,11 @@ class BulkDownloadService(BulkDownloadBase):
|
||||
if self._invoker:
|
||||
assert bulk_download_id is not None
|
||||
assert exception is not None
|
||||
self._invoker.services.events.emit_bulk_download_error(
|
||||
bulk_download_id, bulk_download_item_id, bulk_download_item_name, str(exception)
|
||||
self._invoker.services.events.emit_bulk_download_failed(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
error=str(exception),
|
||||
)
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
|
||||
@@ -8,13 +8,14 @@ import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import requests
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests import HTTPError
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
@@ -29,9 +30,6 @@ from .download_base import (
|
||||
UnknownJobIDException,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
|
||||
# Maximum number of bytes to download during each call to requests.iter_content()
|
||||
DOWNLOAD_CHUNK_SIZE = 100000
|
||||
|
||||
@@ -42,7 +40,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
def __init__(
|
||||
self,
|
||||
max_parallel_dl: int = 5,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
requests_session: Optional[requests.sessions.Session] = None,
|
||||
):
|
||||
"""
|
||||
@@ -345,7 +343,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_started(job)
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())
|
||||
|
||||
def _signal_job_progress(self, job: DownloadJob) -> None:
|
||||
if job.on_progress:
|
||||
@@ -356,7 +355,13 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_progress(job)
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_progress(
|
||||
str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
current_bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
)
|
||||
|
||||
def _signal_job_complete(self, job: DownloadJob) -> None:
|
||||
job.status = DownloadJobStatus.COMPLETED
|
||||
@@ -368,7 +373,10 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_complete(job)
|
||||
assert job.download_path
|
||||
self._event_bus.emit_download_complete(
|
||||
str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes
|
||||
)
|
||||
|
||||
def _signal_job_cancelled(self, job: DownloadJob) -> None:
|
||||
if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
|
||||
@@ -382,7 +390,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_cancelled(job)
|
||||
self._event_bus.emit_download_cancelled(str(job.source))
|
||||
|
||||
def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
job.status = DownloadJobStatus.ERROR
|
||||
@@ -395,7 +403,9 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
|
||||
)
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_download_error(job)
|
||||
assert job.error_type
|
||||
assert job.error
|
||||
self._event_bus.emit_download_error(str(job.source), error_type=job.error_type, error=job.error)
|
||||
|
||||
def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
|
||||
self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")
|
||||
|
||||
@@ -1,195 +1,494 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BatchEnqueuedEvent,
|
||||
BulkDownloadCompleteEvent,
|
||||
BulkDownloadErrorEvent,
|
||||
BulkDownloadStartedEvent,
|
||||
DownloadCancelledEvent,
|
||||
DownloadCompleteEvent,
|
||||
DownloadErrorEvent,
|
||||
DownloadProgressEvent,
|
||||
DownloadStartedEvent,
|
||||
EventBase,
|
||||
InvocationCompleteEvent,
|
||||
InvocationDenoiseProgressEvent,
|
||||
InvocationErrorEvent,
|
||||
InvocationStartedEvent,
|
||||
ModelInstallCancelledEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallErrorEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelLoadCompleteEvent,
|
||||
ModelLoadStartedEvent,
|
||||
QueueClearedEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.download.download_base import DownloadJob
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager import AnyModelConfig
|
||||
from invokeai.backend.model_manager.config import SubModelType
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
queue_event: str = "queue_event"
|
||||
bulk_download_event: str = "bulk_download_event"
|
||||
download_event: str = "download_event"
|
||||
model_event: str = "model_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
|
||||
def dispatch(self, event: "EventBase") -> None:
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
pass
|
||||
|
||||
# region: Invocation
|
||||
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Bulk download events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.bulk_download_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def emit_invocation_started(self, queue_item: "SessionQueueItem", invocation: "BaseInvocation") -> None:
|
||||
"""Emitted when an invocation is started"""
|
||||
self.dispatch(InvocationStartedEvent.build(queue_item, invocation))
|
||||
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Queue events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.queue_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def emit_invocation_denoise_progress(
|
||||
def __emit_download_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.download_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
def __emit_model_event(self, event_name: str, payload: dict) -> None:
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.model_event,
|
||||
payload={"event": event_name, "data": payload},
|
||||
)
|
||||
|
||||
# Define events here for every event in the system.
|
||||
# This will make them easier to integrate until we find a schema generator.
|
||||
def emit_generator_progress(
|
||||
self,
|
||||
queue_item: "SessionQueueItem",
|
||||
invocation: "BaseInvocation",
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
progress_image: "ProgressImage",
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node_id: str,
|
||||
source_node_id: str,
|
||||
progress_image: Optional[ProgressImage],
|
||||
step: int,
|
||||
order: int,
|
||||
total_steps: int,
|
||||
) -> None:
|
||||
"""Emitted at each step during denoising of an invocation."""
|
||||
self.dispatch(InvocationDenoiseProgressEvent.build(queue_item, invocation, intermediate_state, progress_image))
|
||||
"""Emitted when there is generation progress"""
|
||||
self.__emit_queue_event(
|
||||
event_name="generator_progress",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node_id": node_id,
|
||||
"source_node_id": source_node_id,
|
||||
"progress_image": progress_image.model_dump(mode="json") if progress_image is not None else None,
|
||||
"step": step,
|
||||
"order": order,
|
||||
"total_steps": total_steps,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_invocation_complete(
|
||||
self, queue_item: "SessionQueueItem", invocation: "BaseInvocation", output: "BaseInvocationOutput"
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
result: dict,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation is complete"""
|
||||
self.dispatch(InvocationCompleteEvent.build(queue_item, invocation, output))
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_complete",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
"result": result,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_invocation_error(
|
||||
self,
|
||||
queue_item: "SessionQueueItem",
|
||||
invocation: "BaseInvocation",
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
error_type: str,
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
user_id: str | None,
|
||||
project_id: str | None,
|
||||
) -> None:
|
||||
"""Emitted when an invocation encounters an error"""
|
||||
self.dispatch(InvocationErrorEvent.build(queue_item, invocation, error_type, error_message, error_traceback))
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_error",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
"error_type": error_type,
|
||||
"error_message": error_message,
|
||||
"error_traceback": error_traceback,
|
||||
"user_id": user_id,
|
||||
"project_id": project_id,
|
||||
},
|
||||
)
|
||||
|
||||
# endregion
|
||||
def emit_invocation_started(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_started",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"node": node,
|
||||
"source_node_id": source_node_id,
|
||||
},
|
||||
)
|
||||
|
||||
# region Queue
|
||||
def emit_graph_execution_complete(
|
||||
self, queue_id: str, queue_item_id: int, queue_batch_id: str, graph_execution_state_id: str
|
||||
) -> None:
|
||||
"""Emitted when a session has completed all invocations"""
|
||||
self.__emit_queue_event(
|
||||
event_name="graph_execution_state_complete",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_load_started(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Emitted when a model is requested"""
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_started",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_config": model_config.model_dump(mode="json"),
|
||||
"submodel_type": submodel_type,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_load_completed(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_completed",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
"model_config": model_config.model_dump(mode="json"),
|
||||
"submodel_type": submodel_type,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_session_canceled(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: int,
|
||||
queue_batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
) -> None:
|
||||
"""Emitted when a session is canceled"""
|
||||
self.__emit_queue_event(
|
||||
event_name="session_canceled",
|
||||
payload={
|
||||
"queue_id": queue_id,
|
||||
"queue_item_id": queue_item_id,
|
||||
"queue_batch_id": queue_batch_id,
|
||||
"graph_execution_state_id": graph_execution_state_id,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_queue_item_status_changed(
|
||||
self, queue_item: "SessionQueueItem", batch_status: "BatchStatus", queue_status: "SessionQueueStatus"
|
||||
self,
|
||||
session_queue_item: SessionQueueItem,
|
||||
batch_status: BatchStatus,
|
||||
queue_status: SessionQueueStatus,
|
||||
) -> None:
|
||||
"""Emitted when a queue item's status changes"""
|
||||
self.dispatch(QueueItemStatusChangedEvent.build(queue_item, batch_status, queue_status))
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_item_status_changed",
|
||||
payload={
|
||||
"queue_id": queue_status.queue_id,
|
||||
"queue_item": {
|
||||
"queue_id": session_queue_item.queue_id,
|
||||
"item_id": session_queue_item.item_id,
|
||||
"status": session_queue_item.status,
|
||||
"batch_id": session_queue_item.batch_id,
|
||||
"session_id": session_queue_item.session_id,
|
||||
"error_type": session_queue_item.error_type,
|
||||
"error_message": session_queue_item.error_message,
|
||||
"error_traceback": session_queue_item.error_traceback,
|
||||
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
||||
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
||||
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||
"completed_at": str(session_queue_item.completed_at) if session_queue_item.completed_at else None,
|
||||
},
|
||||
"batch_status": batch_status.model_dump(mode="json"),
|
||||
"queue_status": queue_status.model_dump(mode="json"),
|
||||
},
|
||||
)
|
||||
|
||||
def emit_batch_enqueued(self, enqueue_result: "EnqueueBatchResult") -> None:
|
||||
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
|
||||
"""Emitted when a batch is enqueued"""
|
||||
self.dispatch(BatchEnqueuedEvent.build(enqueue_result))
|
||||
self.__emit_queue_event(
|
||||
event_name="batch_enqueued",
|
||||
payload={
|
||||
"queue_id": enqueue_result.queue_id,
|
||||
"batch_id": enqueue_result.batch.batch_id,
|
||||
"enqueued": enqueue_result.enqueued,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_queue_cleared(self, queue_id: str) -> None:
|
||||
"""Emitted when a queue is cleared"""
|
||||
self.dispatch(QueueClearedEvent.build(queue_id))
|
||||
"""Emitted when the queue is cleared"""
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_cleared",
|
||||
payload={"queue_id": queue_id},
|
||||
)
|
||||
|
||||
# endregion
|
||||
def emit_download_started(self, source: str, download_path: str) -> None:
|
||||
"""
|
||||
Emit when a download job is started.
|
||||
|
||||
# region Download
|
||||
:param url: The downloaded url
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_started",
|
||||
payload={"source": source, "download_path": download_path},
|
||||
)
|
||||
|
||||
def emit_download_started(self, job: "DownloadJob") -> None:
|
||||
"""Emitted when a download is started"""
|
||||
self.dispatch(DownloadStartedEvent.build(job))
|
||||
def emit_download_progress(self, source: str, download_path: str, current_bytes: int, total_bytes: int) -> None:
|
||||
"""
|
||||
Emit "download_progress" events at regular intervals during a download job.
|
||||
|
||||
def emit_download_progress(self, job: "DownloadJob") -> None:
|
||||
"""Emitted at intervals during a download"""
|
||||
self.dispatch(DownloadProgressEvent.build(job))
|
||||
:param source: The downloaded source
|
||||
:param download_path: The local downloaded file
|
||||
:param current_bytes: Number of bytes downloaded so far
|
||||
:param total_bytes: The size of the file being downloaded (if known)
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_progress",
|
||||
payload={
|
||||
"source": source,
|
||||
"download_path": download_path,
|
||||
"current_bytes": current_bytes,
|
||||
"total_bytes": total_bytes,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_download_complete(self, job: "DownloadJob") -> None:
|
||||
"""Emitted when a download is completed"""
|
||||
self.dispatch(DownloadCompleteEvent.build(job))
|
||||
def emit_download_complete(self, source: str, download_path: str, total_bytes: int) -> None:
|
||||
"""
|
||||
Emit a "download_complete" event at the end of a successful download.
|
||||
|
||||
def emit_download_cancelled(self, job: "DownloadJob") -> None:
|
||||
"""Emitted when a download is cancelled"""
|
||||
self.dispatch(DownloadCancelledEvent.build(job))
|
||||
:param source: Source URL
|
||||
:param download_path: Path to the locally downloaded file
|
||||
:param total_bytes: The size of the downloaded file
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_complete",
|
||||
payload={
|
||||
"source": source,
|
||||
"download_path": download_path,
|
||||
"total_bytes": total_bytes,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_download_error(self, job: "DownloadJob") -> None:
|
||||
"""Emitted when a download encounters an error"""
|
||||
self.dispatch(DownloadErrorEvent.build(job))
|
||||
def emit_download_cancelled(self, source: str) -> None:
|
||||
"""Emit a "download_cancelled" event in the event that the download was cancelled by user."""
|
||||
self.__emit_download_event(
|
||||
event_name="download_cancelled",
|
||||
payload={
|
||||
"source": source,
|
||||
},
|
||||
)
|
||||
|
||||
# endregion
|
||||
def emit_download_error(self, source: str, error_type: str, error: str) -> None:
|
||||
"""
|
||||
Emit a "download_error" event when an download job encounters an exception.
|
||||
|
||||
# region Model loading
|
||||
:param source: Source URL
|
||||
:param error_type: The name of the exception that raised the error
|
||||
:param error: The traceback from this error
|
||||
"""
|
||||
self.__emit_download_event(
|
||||
event_name="download_error",
|
||||
payload={
|
||||
"source": source,
|
||||
"error_type": error_type,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_model_load_started(self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None) -> None:
|
||||
"""Emitted when a model load is started."""
|
||||
self.dispatch(ModelLoadStartedEvent.build(config, submodel_type))
|
||||
|
||||
def emit_model_load_complete(
|
||||
self, config: "AnyModelConfig", submodel_type: Optional["SubModelType"] = None
|
||||
def emit_model_install_downloading(
|
||||
self,
|
||||
source: str,
|
||||
local_path: str,
|
||||
bytes: int,
|
||||
total_bytes: int,
|
||||
parts: List[Dict[str, Union[str, int]]],
|
||||
id: int,
|
||||
) -> None:
|
||||
"""Emitted when a model load is complete."""
|
||||
self.dispatch(ModelLoadCompleteEvent.build(config, submodel_type))
|
||||
"""
|
||||
Emit at intervals while the install job is in progress (remote models only).
|
||||
|
||||
# endregion
|
||||
:param source: Source of the model
|
||||
:param local_path: Where model is downloading to
|
||||
:param parts: Progress of downloading URLs that comprise the model, if any.
|
||||
:param bytes: Number of bytes downloaded so far.
|
||||
:param total_bytes: Total size of download, including all files.
|
||||
This emits a Dict with keys "source", "local_path", "bytes" and "total_bytes".
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_downloading",
|
||||
payload={
|
||||
"source": source,
|
||||
"local_path": local_path,
|
||||
"bytes": bytes,
|
||||
"total_bytes": total_bytes,
|
||||
"parts": parts,
|
||||
"id": id,
|
||||
},
|
||||
)
|
||||
|
||||
# region Model install
|
||||
def emit_model_install_downloads_done(self, source: str) -> None:
|
||||
"""
|
||||
Emit once when all parts are downloaded, but before the probing and registration start.
|
||||
|
||||
def emit_model_install_download_progress(self, job: "ModelInstallJob") -> None:
|
||||
"""Emitted at intervals while the install job is in progress (remote models only)."""
|
||||
self.dispatch(ModelInstallDownloadProgressEvent.build(job))
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_downloads_done",
|
||||
payload={"source": source},
|
||||
)
|
||||
|
||||
def emit_model_install_downloads_complete(self, job: "ModelInstallJob") -> None:
|
||||
self.dispatch(ModelInstallDownloadsCompleteEvent.build(job))
|
||||
def emit_model_install_running(self, source: str) -> None:
|
||||
"""
|
||||
Emit once when an install job becomes active.
|
||||
|
||||
def emit_model_install_started(self, job: "ModelInstallJob") -> None:
|
||||
"""Emitted once when an install job is started (after any download)."""
|
||||
self.dispatch(ModelInstallStartedEvent.build(job))
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_running",
|
||||
payload={"source": source},
|
||||
)
|
||||
|
||||
def emit_model_install_complete(self, job: "ModelInstallJob") -> None:
|
||||
"""Emitted when an install job is completed successfully."""
|
||||
self.dispatch(ModelInstallCompleteEvent.build(job))
|
||||
def emit_model_install_completed(self, source: str, key: str, id: int, total_bytes: Optional[int] = None) -> None:
|
||||
"""
|
||||
Emit when an install job is completed successfully.
|
||||
|
||||
def emit_model_install_cancelled(self, job: "ModelInstallJob") -> None:
|
||||
"""Emitted when an install job is cancelled."""
|
||||
self.dispatch(ModelInstallCancelledEvent.build(job))
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
:param key: Model config record key
|
||||
:param total_bytes: Size of the model (may be None for installation of a local path)
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_completed",
|
||||
payload={"source": source, "total_bytes": total_bytes, "key": key, "id": id},
|
||||
)
|
||||
|
||||
def emit_model_install_error(self, job: "ModelInstallJob") -> None:
|
||||
"""Emitted when an install job encounters an exception."""
|
||||
self.dispatch(ModelInstallErrorEvent.build(job))
|
||||
def emit_model_install_cancelled(self, source: str, id: int) -> None:
|
||||
"""
|
||||
Emit when an install job is cancelled.
|
||||
|
||||
# endregion
|
||||
:param source: Source of the model; local path, repo_id or url
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_cancelled",
|
||||
payload={"source": source, "id": id},
|
||||
)
|
||||
|
||||
# region Bulk image download
|
||||
def emit_model_install_error(self, source: str, error_type: str, error: str, id: int) -> None:
|
||||
"""
|
||||
Emit when an install job encounters an exception.
|
||||
|
||||
:param source: Source of the model
|
||||
:param error_type: The name of the exception
|
||||
:param error: A text description of the exception
|
||||
"""
|
||||
self.__emit_model_event(
|
||||
event_name="model_install_error",
|
||||
payload={"source": source, "error_type": error_type, "error": error, "id": id},
|
||||
)
|
||||
|
||||
def emit_bulk_download_started(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk image download is started"""
|
||||
self.dispatch(BulkDownloadStartedEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
|
||||
|
||||
def emit_bulk_download_complete(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk image download is complete"""
|
||||
self.dispatch(BulkDownloadCompleteEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name))
|
||||
|
||||
def emit_bulk_download_error(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk image download has an error"""
|
||||
self.dispatch(
|
||||
BulkDownloadErrorEvent.build(bulk_download_id, bulk_download_item_id, bulk_download_item_name, error)
|
||||
"""Emitted when a bulk download starts"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_started",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
},
|
||||
)
|
||||
|
||||
# endregion
|
||||
def emit_bulk_download_completed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk download completes"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_completed",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
},
|
||||
)
|
||||
|
||||
def emit_bulk_download_failed(
|
||||
self, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
|
||||
) -> None:
|
||||
"""Emitted when a bulk download fails"""
|
||||
self._emit_bulk_download_event(
|
||||
event_name="bulk_download_failed",
|
||||
payload={
|
||||
"bulk_download_id": bulk_download_id,
|
||||
"bulk_download_item_id": bulk_download_item_id,
|
||||
"bulk_download_item_name": bulk_download_item_name,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1,592 +0,0 @@
|
||||
from math import floor
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Coroutine, Generic, Optional, Protocol, TypeAlias, TypeVar
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.registry.payload_schema import registry as payload_schema
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
BatchStatus,
|
||||
EnqueueBatchResult,
|
||||
SessionQueueItem,
|
||||
SessionQueueStatus,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import AnyInvocation, AnyInvocationOutput
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.download.download_base import DownloadJob
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob
|
||||
|
||||
|
||||
class EventBase(BaseModel):
|
||||
"""Base class for all events. All events must inherit from this class.
|
||||
|
||||
Events must define a class attribute `__event_name__` to identify the event.
|
||||
|
||||
All other attributes should be defined as normal for a pydantic model.
|
||||
|
||||
A timestamp is automatically added to the event when it is created.
|
||||
"""
|
||||
|
||||
__event_name__: ClassVar[str]
|
||||
timestamp: int = Field(description="The timestamp of the event", default_factory=get_timestamp)
|
||||
|
||||
model_config = ConfigDict(json_schema_serialization_defaults_required=True)
|
||||
|
||||
@classmethod
|
||||
def get_events(cls) -> set[type["EventBase"]]:
|
||||
"""Get a set of all event models."""
|
||||
|
||||
event_subclasses: set[type["EventBase"]] = set()
|
||||
for subclass in cls.__subclasses__():
|
||||
# We only want to include subclasses that are event models, not intermediary classes
|
||||
if hasattr(subclass, "__event_name__"):
|
||||
event_subclasses.add(subclass)
|
||||
event_subclasses.update(subclass.get_events())
|
||||
|
||||
return event_subclasses
|
||||
|
||||
|
||||
TEvent = TypeVar("TEvent", bound=EventBase, contravariant=True)
|
||||
|
||||
FastAPIEvent: TypeAlias = tuple[str, TEvent]
|
||||
"""
|
||||
A tuple representing a `fastapi-events` event, with the event name and payload.
|
||||
Provide a generic type to `TEvent` to specify the payload type.
|
||||
"""
|
||||
|
||||
|
||||
class FastAPIEventFunc(Protocol, Generic[TEvent]):
|
||||
def __call__(self, event: FastAPIEvent[TEvent]) -> Optional[Coroutine[Any, Any, None]]: ...
|
||||
|
||||
|
||||
def register_events(events: set[type[TEvent]] | type[TEvent], func: FastAPIEventFunc[TEvent]) -> None:
|
||||
"""Register a function to handle specific events.
|
||||
|
||||
:param events: An event or set of events to handle
|
||||
:param func: The function to handle the events
|
||||
"""
|
||||
events = events if isinstance(events, set) else {events}
|
||||
for event in events:
|
||||
assert hasattr(event, "__event_name__")
|
||||
local_handler.register(event_name=event.__event_name__, _func=func) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
||||
|
||||
|
||||
class QueueEventBase(EventBase):
|
||||
"""Base class for queue events"""
|
||||
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
|
||||
|
||||
class QueueItemEventBase(QueueEventBase):
|
||||
"""Base class for queue item events"""
|
||||
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
|
||||
|
||||
class InvocationEventBase(QueueItemEventBase):
|
||||
"""Base class for invocation events"""
|
||||
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
item_id: int = Field(description="The ID of the queue item")
|
||||
batch_id: str = Field(description="The ID of the queue batch")
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
invocation: AnyInvocation = Field(description="The ID of the invocation")
|
||||
invocation_source_id: str = Field(description="The ID of the prepared invocation's source node")
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class InvocationStartedEvent(InvocationEventBase):
|
||||
"""Event model for invocation_started"""
|
||||
|
||||
__event_name__ = "invocation_started"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_item: SessionQueueItem, invocation: AnyInvocation) -> "InvocationStartedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class InvocationDenoiseProgressEvent(InvocationEventBase):
|
||||
"""Event model for invocation_denoise_progress"""
|
||||
|
||||
__event_name__ = "invocation_denoise_progress"
|
||||
|
||||
progress_image: ProgressImage = Field(description="The progress image sent at each step during processing")
|
||||
step: int = Field(description="The current step of the invocation")
|
||||
total_steps: int = Field(description="The total number of steps in the invocation")
|
||||
order: int = Field(description="The order of the invocation in the session")
|
||||
percentage: float = Field(description="The percentage of completion of the invocation")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: AnyInvocation,
|
||||
intermediate_state: PipelineIntermediateState,
|
||||
progress_image: ProgressImage,
|
||||
) -> "InvocationDenoiseProgressEvent":
|
||||
step = intermediate_state.step
|
||||
total_steps = intermediate_state.total_steps
|
||||
order = intermediate_state.order
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
progress_image=progress_image,
|
||||
step=step,
|
||||
total_steps=total_steps,
|
||||
order=order,
|
||||
percentage=cls.calc_percentage(step, total_steps, order),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def calc_percentage(step: int, total_steps: int, scheduler_order: float) -> float:
|
||||
"""Calculate the percentage of completion of denoising."""
|
||||
if total_steps == 0:
|
||||
return 0.0
|
||||
if scheduler_order == 2:
|
||||
return floor((step + 1 + 1) / 2) / floor((total_steps + 1) / 2)
|
||||
# order == 1
|
||||
return (step + 1 + 1) / (total_steps + 1)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class InvocationCompleteEvent(InvocationEventBase):
|
||||
"""Event model for invocation_complete"""
|
||||
|
||||
__event_name__ = "invocation_complete"
|
||||
|
||||
result: AnyInvocationOutput = Field(description="The result of the invocation")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, queue_item: SessionQueueItem, invocation: AnyInvocation, result: AnyInvocationOutput
|
||||
) -> "InvocationCompleteEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
result=result,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class InvocationErrorEvent(InvocationEventBase):
|
||||
"""Event model for invocation_error"""
|
||||
|
||||
__event_name__ = "invocation_error"
|
||||
|
||||
error_type: str = Field(description="The error type")
|
||||
error_message: str = Field(description="The error message")
|
||||
error_traceback: str = Field(description="The error traceback")
|
||||
user_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
|
||||
project_id: Optional[str] = Field(default=None, description="The ID of the user who created the invocation")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
queue_item: SessionQueueItem,
|
||||
invocation: AnyInvocation,
|
||||
error_type: str,
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
) -> "InvocationErrorEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
invocation=invocation,
|
||||
invocation_source_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
user_id=getattr(queue_item, "user_id", None),
|
||||
project_id=getattr(queue_item, "project_id", None),
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class QueueItemStatusChangedEvent(QueueItemEventBase):
|
||||
"""Event model for queue_item_status_changed"""
|
||||
|
||||
__event_name__ = "queue_item_status_changed"
|
||||
|
||||
status: QUEUE_ITEM_STATUS = Field(description="The new status of the queue item")
|
||||
error_type: Optional[str] = Field(default=None, description="The error type, if any")
|
||||
error_message: Optional[str] = Field(default=None, description="The error message, if any")
|
||||
error_traceback: Optional[str] = Field(default=None, description="The error traceback, if any")
|
||||
created_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was created")
|
||||
updated_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was last updated")
|
||||
started_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was started")
|
||||
completed_at: Optional[str] = Field(default=None, description="The timestamp when the queue item was completed")
|
||||
batch_status: BatchStatus = Field(description="The status of the batch")
|
||||
queue_status: SessionQueueStatus = Field(description="The status of the queue")
|
||||
session_id: str = Field(description="The ID of the session (aka graph execution state)")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, queue_item: SessionQueueItem, batch_status: BatchStatus, queue_status: SessionQueueStatus
|
||||
) -> "QueueItemStatusChangedEvent":
|
||||
return cls(
|
||||
queue_id=queue_item.queue_id,
|
||||
item_id=queue_item.item_id,
|
||||
batch_id=queue_item.batch_id,
|
||||
session_id=queue_item.session_id,
|
||||
status=queue_item.status,
|
||||
error_type=queue_item.error_type,
|
||||
error_message=queue_item.error_message,
|
||||
error_traceback=queue_item.error_traceback,
|
||||
created_at=str(queue_item.created_at) if queue_item.created_at else None,
|
||||
updated_at=str(queue_item.updated_at) if queue_item.updated_at else None,
|
||||
started_at=str(queue_item.started_at) if queue_item.started_at else None,
|
||||
completed_at=str(queue_item.completed_at) if queue_item.completed_at else None,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class BatchEnqueuedEvent(QueueEventBase):
|
||||
"""Event model for batch_enqueued"""
|
||||
|
||||
__event_name__ = "batch_enqueued"
|
||||
|
||||
batch_id: str = Field(description="The ID of the batch")
|
||||
enqueued: int = Field(description="The number of invocations enqueued")
|
||||
requested: int = Field(
|
||||
description="The number of invocations initially requested to be enqueued (may be less than enqueued if queue was full)"
|
||||
)
|
||||
priority: int = Field(description="The priority of the batch")
|
||||
|
||||
@classmethod
|
||||
def build(cls, enqueue_result: EnqueueBatchResult) -> "BatchEnqueuedEvent":
|
||||
return cls(
|
||||
queue_id=enqueue_result.queue_id,
|
||||
batch_id=enqueue_result.batch.batch_id,
|
||||
enqueued=enqueue_result.enqueued,
|
||||
requested=enqueue_result.requested,
|
||||
priority=enqueue_result.priority,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class QueueClearedEvent(QueueEventBase):
|
||||
"""Event model for queue_cleared"""
|
||||
|
||||
__event_name__ = "queue_cleared"
|
||||
|
||||
@classmethod
|
||||
def build(cls, queue_id: str) -> "QueueClearedEvent":
|
||||
return cls(queue_id=queue_id)
|
||||
|
||||
|
||||
class DownloadEventBase(EventBase):
|
||||
"""Base class for events associated with a download"""
|
||||
|
||||
source: str = Field(description="The source of the download")
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class DownloadStartedEvent(DownloadEventBase):
|
||||
"""Event model for download_started"""
|
||||
|
||||
__event_name__ = "download_started"
|
||||
|
||||
download_path: str = Field(description="The local path where the download is saved")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadStartedEvent":
|
||||
assert job.download_path
|
||||
return cls(source=str(job.source), download_path=job.download_path.as_posix())
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class DownloadProgressEvent(DownloadEventBase):
|
||||
"""Event model for download_progress"""
|
||||
|
||||
__event_name__ = "download_progress"
|
||||
|
||||
download_path: str = Field(description="The local path where the download is saved")
|
||||
current_bytes: int = Field(description="The number of bytes downloaded so far")
|
||||
total_bytes: int = Field(description="The total number of bytes to be downloaded")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadProgressEvent":
|
||||
assert job.download_path
|
||||
return cls(
|
||||
source=str(job.source),
|
||||
download_path=job.download_path.as_posix(),
|
||||
current_bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class DownloadCompleteEvent(DownloadEventBase):
|
||||
"""Event model for download_complete"""
|
||||
|
||||
__event_name__ = "download_complete"
|
||||
|
||||
download_path: str = Field(description="The local path where the download is saved")
|
||||
total_bytes: int = Field(description="The total number of bytes downloaded")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadCompleteEvent":
|
||||
assert job.download_path
|
||||
return cls(source=str(job.source), download_path=job.download_path.as_posix(), total_bytes=job.total_bytes)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class DownloadCancelledEvent(DownloadEventBase):
|
||||
"""Event model for download_cancelled"""
|
||||
|
||||
__event_name__ = "download_cancelled"
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadCancelledEvent":
|
||||
return cls(source=str(job.source))
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class DownloadErrorEvent(DownloadEventBase):
|
||||
"""Event model for download_error"""
|
||||
|
||||
__event_name__ = "download_error"
|
||||
|
||||
error_type: str = Field(description="The type of error")
|
||||
error: str = Field(description="The error message")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "DownloadJob") -> "DownloadErrorEvent":
|
||||
assert job.error_type
|
||||
assert job.error
|
||||
return cls(source=str(job.source), error_type=job.error_type, error=job.error)
|
||||
|
||||
|
||||
class ModelEventBase(EventBase):
|
||||
"""Base class for events associated with a model"""
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelLoadStartedEvent(ModelEventBase):
|
||||
"""Event model for model_load_started"""
|
||||
|
||||
__event_name__ = "model_load_started"
|
||||
|
||||
config: AnyModelConfig = Field(description="The model's config")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
|
||||
|
||||
@classmethod
|
||||
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadStartedEvent":
|
||||
return cls(config=config, submodel_type=submodel_type)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelLoadCompleteEvent(ModelEventBase):
|
||||
"""Event model for model_load_complete"""
|
||||
|
||||
__event_name__ = "model_load_complete"
|
||||
|
||||
config: AnyModelConfig = Field(description="The model's config")
|
||||
submodel_type: Optional[SubModelType] = Field(default=None, description="The submodel type, if any")
|
||||
|
||||
@classmethod
|
||||
def build(cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> "ModelLoadCompleteEvent":
|
||||
return cls(config=config, submodel_type=submodel_type)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallDownloadProgressEvent(ModelEventBase):
|
||||
"""Event model for model_install_download_progress"""
|
||||
|
||||
__event_name__ = "model_install_download_progress"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
local_path: str = Field(description="Where model is downloading to")
|
||||
bytes: int = Field(description="Number of bytes downloaded so far")
|
||||
total_bytes: int = Field(description="Total size of download, including all files")
|
||||
parts: list[dict[str, int | str]] = Field(
|
||||
description="Progress of downloading URLs that comprise the model, if any"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadProgressEvent":
|
||||
parts: list[dict[str, str | int]] = [
|
||||
{
|
||||
"url": str(x.source),
|
||||
"local_path": str(x.download_path),
|
||||
"bytes": x.bytes,
|
||||
"total_bytes": x.total_bytes,
|
||||
}
|
||||
for x in job.download_parts
|
||||
]
|
||||
return cls(
|
||||
id=job.id,
|
||||
source=str(job.source),
|
||||
local_path=job.local_path.as_posix(),
|
||||
parts=parts,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallDownloadsCompleteEvent(ModelEventBase):
|
||||
"""Emitted once when an install job becomes active."""
|
||||
|
||||
__event_name__ = "model_install_downloads_complete"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallDownloadsCompleteEvent":
|
||||
return cls(id=job.id, source=str(job.source))
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallStartedEvent(ModelEventBase):
|
||||
"""Event model for model_install_started"""
|
||||
|
||||
__event_name__ = "model_install_started"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallStartedEvent":
|
||||
return cls(id=job.id, source=str(job.source))
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallCompleteEvent(ModelEventBase):
|
||||
"""Event model for model_install_complete"""
|
||||
|
||||
__event_name__ = "model_install_complete"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
key: str = Field(description="Model config record key")
|
||||
total_bytes: Optional[int] = Field(description="Size of the model (may be None for installation of a local path)")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallCompleteEvent":
|
||||
assert job.config_out is not None
|
||||
return cls(id=job.id, source=str(job.source), key=(job.config_out.key), total_bytes=job.total_bytes)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallCancelledEvent(ModelEventBase):
|
||||
"""Event model for model_install_cancelled"""
|
||||
|
||||
__event_name__ = "model_install_cancelled"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallCancelledEvent":
|
||||
return cls(id=job.id, source=str(job.source))
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class ModelInstallErrorEvent(ModelEventBase):
|
||||
"""Event model for model_install_error"""
|
||||
|
||||
__event_name__ = "model_install_error"
|
||||
|
||||
id: int = Field(description="The ID of the install job")
|
||||
source: str = Field(description="Source of the model; local path, repo_id or url")
|
||||
error_type: str = Field(description="The name of the exception")
|
||||
error: str = Field(description="A text description of the exception")
|
||||
|
||||
@classmethod
|
||||
def build(cls, job: "ModelInstallJob") -> "ModelInstallErrorEvent":
|
||||
assert job.error_type is not None
|
||||
assert job.error is not None
|
||||
return cls(id=job.id, source=str(job.source), error_type=job.error_type, error=job.error)
|
||||
|
||||
|
||||
class BulkDownloadEventBase(EventBase):
|
||||
"""Base class for events associated with a bulk image download"""
|
||||
|
||||
bulk_download_id: str = Field(description="The ID of the bulk image download")
|
||||
bulk_download_item_id: str = Field(description="The ID of the bulk image download item")
|
||||
bulk_download_item_name: str = Field(description="The name of the bulk image download item")
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class BulkDownloadStartedEvent(BulkDownloadEventBase):
|
||||
"""Event model for bulk_download_started"""
|
||||
|
||||
__event_name__ = "bulk_download_started"
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> "BulkDownloadStartedEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class BulkDownloadCompleteEvent(BulkDownloadEventBase):
|
||||
"""Event model for bulk_download_complete"""
|
||||
|
||||
__event_name__ = "bulk_download_complete"
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str
|
||||
) -> "BulkDownloadCompleteEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
)
|
||||
|
||||
|
||||
@payload_schema.register
|
||||
class BulkDownloadErrorEvent(BulkDownloadEventBase):
|
||||
"""Event model for bulk_download_error"""
|
||||
|
||||
__event_name__ = "bulk_download_error"
|
||||
|
||||
error: str = Field(description="The error message")
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, bulk_download_id: str, bulk_download_item_id: str, bulk_download_item_name: str, error: str
|
||||
) -> "BulkDownloadErrorEvent":
|
||||
return cls(
|
||||
bulk_download_id=bulk_download_id,
|
||||
bulk_download_item_id=bulk_download_item_id,
|
||||
bulk_download_item_name=bulk_download_item_name,
|
||||
error=error,
|
||||
)
|
||||
@@ -1,47 +0,0 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from queue import Empty, Queue
|
||||
|
||||
from fastapi_events.dispatcher import dispatch
|
||||
|
||||
from invokeai.app.services.events.events_common import (
|
||||
EventBase,
|
||||
)
|
||||
|
||||
from .events_base import EventServiceBase
|
||||
|
||||
|
||||
class FastAPIEventService(EventServiceBase):
|
||||
def __init__(self, event_handler_id: int) -> None:
|
||||
self.event_handler_id = event_handler_id
|
||||
self._queue = Queue[EventBase | None]()
|
||||
self._stop_event = threading.Event()
|
||||
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||
|
||||
super().__init__()
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self._stop_event.set()
|
||||
self._queue.put(None)
|
||||
|
||||
def dispatch(self, event: EventBase) -> None:
|
||||
self._queue.put(event)
|
||||
|
||||
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
event = self._queue.get(block=False)
|
||||
if not event: # Probably stopping
|
||||
continue
|
||||
# Leave the payloads as live pydantic models
|
||||
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
|
||||
|
||||
except Empty:
|
||||
await asyncio.sleep(0.1)
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError as e:
|
||||
raise e # Raise a proper error
|
||||
@@ -1,13 +1,11 @@
|
||||
"""Initialization file for model install service package."""
|
||||
|
||||
from .model_install_base import (
|
||||
ModelInstallServiceBase,
|
||||
)
|
||||
from .model_install_common import (
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
UnknownInstallJobException,
|
||||
URLModelSource,
|
||||
|
||||
@@ -1,19 +1,244 @@
|
||||
# Copyright 2023 Lincoln D. Stein and the InvokeAI development team
|
||||
"""Baseclass definitions for the model installer."""
|
||||
|
||||
import re
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_common import ModelInstallJob, ModelSource
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
DOWNLOADING = "downloading" # downloading of model files in process
|
||||
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
CANCELLED = "cancelled" # terminated with an error message
|
||||
|
||||
|
||||
class ModelInstallPart(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
path: Path
|
||||
bytes: int = 0
|
||||
total_bytes: int = 0
|
||||
|
||||
|
||||
class UnknownInstallJobException(Exception):
|
||||
"""Raised when the status of an unknown job is requested."""
|
||||
|
||||
|
||||
class StringLikeSource(BaseModel):
|
||||
"""
|
||||
Base class for model sources, implements functions that lets the source be sorted and indexed.
|
||||
|
||||
These shenanigans let this stuff work:
|
||||
|
||||
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
|
||||
mydict = {source1: 'model 1'}
|
||||
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
|
||||
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
|
||||
|
||||
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
|
||||
assert source1 == source2
|
||||
assert source1 == 'C:/users/mort/foo.safetensors'
|
||||
"""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the path field, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __lt__(self, other: object) -> int:
|
||||
"""Return comparison of the stringified version, for sorting."""
|
||||
return str(self) < str(other)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Return equality on the stringified version."""
|
||||
if isinstance(other, Path):
|
||||
return str(self) == other.as_posix()
|
||||
else:
|
||||
return str(self) == str(other)
|
||||
|
||||
|
||||
class LocalModelSource(StringLikeSource):
|
||||
"""A local file or directory path."""
|
||||
|
||||
path: str | Path
|
||||
inplace: Optional[bool] = False
|
||||
type: Literal["local"] = "local"
|
||||
|
||||
# these methods allow the source to be used in a string-like way,
|
||||
# for example as an index into a dict
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of path when string rep needed."""
|
||||
return Path(self.path).as_posix()
|
||||
|
||||
|
||||
class HFModelSource(StringLikeSource):
|
||||
"""
|
||||
A HuggingFace repo_id with optional variant, sub-folder and access token.
|
||||
Note that the variant option, if not provided to the constructor, will default to fp16, which is
|
||||
what people (almost) always want.
|
||||
"""
|
||||
|
||||
repo_id: str
|
||||
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
|
||||
subfolder: Optional[Path] = None
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["hf"] = "hf"
|
||||
|
||||
@field_validator("repo_id")
|
||||
@classmethod
|
||||
def proper_repo_id(cls, v: str) -> str: # noqa D102
|
||||
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
|
||||
raise ValueError(f"{v}: invalid repo_id format")
|
||||
return v
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = self.repo_id
|
||||
if self.variant:
|
||||
base += f":{self.variant or ''}"
|
||||
if self.subfolder:
|
||||
base += f":{self.subfolder}"
|
||||
return base
|
||||
|
||||
|
||||
class URLModelSource(StringLikeSource):
|
||||
"""A generic URL point to a checkpoint file."""
|
||||
|
||||
url: AnyHttpUrl
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["url"] = "url"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of the url when string rep needed."""
|
||||
return str(self.url)
|
||||
|
||||
|
||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
}
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
)
|
||||
config_out: Optional[AnyModelConfig] = Field(
|
||||
default=None, description="After successful installation, this will hold the configuration object."
|
||||
)
|
||||
inplace: bool = Field(
|
||||
default=False, description="Leave model in its current location; otherwise install under models directory"
|
||||
)
|
||||
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
|
||||
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
|
||||
bytes: int = Field(
|
||||
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
|
||||
)
|
||||
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
|
||||
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
||||
default=None, description="Metadata provided by the model source"
|
||||
)
|
||||
download_parts: Set[DownloadJob] = Field(
|
||||
default_factory=set, description="Download jobs contributing to this install"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||
)
|
||||
error_traceback: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the exception traceback"
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
"""Record the error and traceback from an exception."""
|
||||
self._exception = e
|
||||
self.error = str(e)
|
||||
self.error_traceback = self._format_error(e)
|
||||
self.status = InstallStatus.ERROR
|
||||
self.error_reason = self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Call to cancel the job."""
|
||||
self.status = InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def error_type(self) -> Optional[str]:
|
||||
"""Class name of the exception that led to status==ERROR."""
|
||||
return self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def _format_error(self, exception: Exception) -> str:
|
||||
"""Error traceback."""
|
||||
return "".join(traceback.format_exception(exception))
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
"""Set status to CANCELLED."""
|
||||
return self.status == InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
"""Return true if job has errored."""
|
||||
return self.status == InstallStatus.ERROR
|
||||
|
||||
@property
|
||||
def waiting(self) -> bool:
|
||||
"""Return true if job is waiting to run."""
|
||||
return self.status == InstallStatus.WAITING
|
||||
|
||||
@property
|
||||
def downloading(self) -> bool:
|
||||
"""Return true if job is downloading."""
|
||||
return self.status == InstallStatus.DOWNLOADING
|
||||
|
||||
@property
|
||||
def downloads_done(self) -> bool:
|
||||
"""Return true if job's downloads ae done."""
|
||||
return self.status == InstallStatus.DOWNLOADS_DONE
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if job is running."""
|
||||
return self.status == InstallStatus.RUNNING
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
"""Return true if job completed without errors."""
|
||||
return self.status == InstallStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def in_terminal_state(self) -> bool:
|
||||
"""Return true if job is in a terminal state."""
|
||||
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
|
||||
|
||||
|
||||
class ModelInstallServiceBase(ABC):
|
||||
@@ -57,7 +282,7 @@ class ModelInstallServiceBase(ABC):
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def event_bus(self) -> Optional["EventServiceBase"]:
|
||||
def event_bus(self) -> Optional[EventServiceBase]:
|
||||
"""Return the event service base object associated with the installer."""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -1,233 +0,0 @@
|
||||
import re
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.download import DownloadJob
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
"""State of an install job running in the background."""
|
||||
|
||||
WAITING = "waiting" # waiting to be dequeued
|
||||
DOWNLOADING = "downloading" # downloading of model files in process
|
||||
DOWNLOADS_DONE = "downloads_done" # downloading done, waiting to run
|
||||
RUNNING = "running" # being processed
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
CANCELLED = "cancelled" # terminated with an error message
|
||||
|
||||
|
||||
class ModelInstallPart(BaseModel):
|
||||
url: AnyHttpUrl
|
||||
path: Path
|
||||
bytes: int = 0
|
||||
total_bytes: int = 0
|
||||
|
||||
|
||||
class UnknownInstallJobException(Exception):
|
||||
"""Raised when the status of an unknown job is requested."""
|
||||
|
||||
|
||||
class StringLikeSource(BaseModel):
|
||||
"""
|
||||
Base class for model sources, implements functions that lets the source be sorted and indexed.
|
||||
|
||||
These shenanigans let this stuff work:
|
||||
|
||||
source1 = LocalModelSource(path='C:/users/mort/foo.safetensors')
|
||||
mydict = {source1: 'model 1'}
|
||||
assert mydict['C:/users/mort/foo.safetensors'] == 'model 1'
|
||||
assert mydict[LocalModelSource(path='C:/users/mort/foo.safetensors')] == 'model 1'
|
||||
|
||||
source2 = LocalModelSource(path=Path('C:/users/mort/foo.safetensors'))
|
||||
assert source1 == source2
|
||||
assert source1 == 'C:/users/mort/foo.safetensors'
|
||||
"""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return hash of the path field, for indexing."""
|
||||
return hash(str(self))
|
||||
|
||||
def __lt__(self, other: object) -> int:
|
||||
"""Return comparison of the stringified version, for sorting."""
|
||||
return str(self) < str(other)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
"""Return equality on the stringified version."""
|
||||
if isinstance(other, Path):
|
||||
return str(self) == other.as_posix()
|
||||
else:
|
||||
return str(self) == str(other)
|
||||
|
||||
|
||||
class LocalModelSource(StringLikeSource):
|
||||
"""A local file or directory path."""
|
||||
|
||||
path: str | Path
|
||||
inplace: Optional[bool] = False
|
||||
type: Literal["local"] = "local"
|
||||
|
||||
# these methods allow the source to be used in a string-like way,
|
||||
# for example as an index into a dict
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of path when string rep needed."""
|
||||
return Path(self.path).as_posix()
|
||||
|
||||
|
||||
class HFModelSource(StringLikeSource):
|
||||
"""
|
||||
A HuggingFace repo_id with optional variant, sub-folder and access token.
|
||||
Note that the variant option, if not provided to the constructor, will default to fp16, which is
|
||||
what people (almost) always want.
|
||||
"""
|
||||
|
||||
repo_id: str
|
||||
variant: Optional[ModelRepoVariant] = ModelRepoVariant.FP16
|
||||
subfolder: Optional[Path] = None
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["hf"] = "hf"
|
||||
|
||||
@field_validator("repo_id")
|
||||
@classmethod
|
||||
def proper_repo_id(cls, v: str) -> str: # noqa D102
|
||||
if not re.match(r"^([.\w-]+/[.\w-]+)$", v):
|
||||
raise ValueError(f"{v}: invalid repo_id format")
|
||||
return v
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of repoid when string rep needed."""
|
||||
base: str = self.repo_id
|
||||
if self.variant:
|
||||
base += f":{self.variant or ''}"
|
||||
if self.subfolder:
|
||||
base += f":{self.subfolder}"
|
||||
return base
|
||||
|
||||
|
||||
class URLModelSource(StringLikeSource):
|
||||
"""A generic URL point to a checkpoint file."""
|
||||
|
||||
url: AnyHttpUrl
|
||||
access_token: Optional[str] = None
|
||||
type: Literal["url"] = "url"
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return string version of the url when string rep needed."""
|
||||
return str(self.url)
|
||||
|
||||
|
||||
ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")]
|
||||
|
||||
MODEL_SOURCE_TO_TYPE_MAP = {
|
||||
URLModelSource: ModelSourceType.Url,
|
||||
HFModelSource: ModelSourceType.HFRepoID,
|
||||
LocalModelSource: ModelSourceType.Path,
|
||||
}
|
||||
|
||||
|
||||
class ModelInstallJob(BaseModel):
|
||||
"""Object that tracks the current status of an install request."""
|
||||
|
||||
id: int = Field(description="Unique ID for this job")
|
||||
status: InstallStatus = Field(default=InstallStatus.WAITING, description="Current status of install process")
|
||||
error_reason: Optional[str] = Field(default=None, description="Information about why the job failed")
|
||||
config_in: Dict[str, Any] = Field(
|
||||
default_factory=dict, description="Configuration information (e.g. 'description') to apply to model."
|
||||
)
|
||||
config_out: Optional[AnyModelConfig] = Field(
|
||||
default=None, description="After successful installation, this will hold the configuration object."
|
||||
)
|
||||
inplace: bool = Field(
|
||||
default=False, description="Leave model in its current location; otherwise install under models directory"
|
||||
)
|
||||
source: ModelSource = Field(description="Source (URL, repo_id, or local path) of model")
|
||||
local_path: Path = Field(description="Path to locally-downloaded model; may be the same as the source")
|
||||
bytes: int = Field(
|
||||
default=0, description="For a remote model, the number of bytes downloaded so far (may not be available)"
|
||||
)
|
||||
total_bytes: int = Field(default=0, description="Total size of the model to be installed")
|
||||
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
||||
default=None, description="Metadata provided by the model source"
|
||||
)
|
||||
download_parts: Set[DownloadJob] = Field(
|
||||
default_factory=set, description="Download jobs contributing to this install"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||
)
|
||||
error_traceback: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the exception traceback"
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
"""Record the error and traceback from an exception."""
|
||||
self._exception = e
|
||||
self.error = str(e)
|
||||
self.error_traceback = self._format_error(e)
|
||||
self.status = InstallStatus.ERROR
|
||||
self.error_reason = self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Call to cancel the job."""
|
||||
self.status = InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def error_type(self) -> Optional[str]:
|
||||
"""Class name of the exception that led to status==ERROR."""
|
||||
return self._exception.__class__.__name__ if self._exception else None
|
||||
|
||||
def _format_error(self, exception: Exception) -> str:
|
||||
"""Error traceback."""
|
||||
return "".join(traceback.format_exception(exception))
|
||||
|
||||
@property
|
||||
def cancelled(self) -> bool:
|
||||
"""Set status to CANCELLED."""
|
||||
return self.status == InstallStatus.CANCELLED
|
||||
|
||||
@property
|
||||
def errored(self) -> bool:
|
||||
"""Return true if job has errored."""
|
||||
return self.status == InstallStatus.ERROR
|
||||
|
||||
@property
|
||||
def waiting(self) -> bool:
|
||||
"""Return true if job is waiting to run."""
|
||||
return self.status == InstallStatus.WAITING
|
||||
|
||||
@property
|
||||
def downloading(self) -> bool:
|
||||
"""Return true if job is downloading."""
|
||||
return self.status == InstallStatus.DOWNLOADING
|
||||
|
||||
@property
|
||||
def downloads_done(self) -> bool:
|
||||
"""Return true if job's downloads ae done."""
|
||||
return self.status == InstallStatus.DOWNLOADS_DONE
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""Return true if job is running."""
|
||||
return self.status == InstallStatus.RUNNING
|
||||
|
||||
@property
|
||||
def complete(self) -> bool:
|
||||
"""Return true if job completed without errors."""
|
||||
return self.status == InstallStatus.COMPLETED
|
||||
|
||||
@property
|
||||
def in_terminal_state(self) -> bool:
|
||||
"""Return true if job is in a terminal state."""
|
||||
return self.status in [InstallStatus.COMPLETED, InstallStatus.ERROR, InstallStatus.CANCELLED]
|
||||
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
@@ -20,8 +20,8 @@ from requests import Session
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_install.model_install_base import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@@ -45,12 +45,13 @@ from invokeai.backend.util import InvokeAILogger
|
||||
from invokeai.backend.util.catch_sigint import catch_sigint
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .model_install_common import (
|
||||
from .model_install_base import (
|
||||
MODEL_SOURCE_TO_TYPE_MAP,
|
||||
HFModelSource,
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
ModelSource,
|
||||
StringLikeSource,
|
||||
URLModelSource,
|
||||
@@ -58,9 +59,6 @@ from .model_install_common import (
|
||||
|
||||
TMPDIR_PREFIX = "tmpinstall_"
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
|
||||
|
||||
class ModelInstallService(ModelInstallServiceBase):
|
||||
"""class for InvokeAI model installation."""
|
||||
@@ -70,7 +68,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
app_config: InvokeAIAppConfig,
|
||||
record_store: ModelRecordServiceBase,
|
||||
download_queue: DownloadQueueServiceBase,
|
||||
event_bus: Optional["EventServiceBase"] = None,
|
||||
event_bus: Optional[EventServiceBase] = None,
|
||||
session: Optional[Session] = None,
|
||||
):
|
||||
"""
|
||||
@@ -106,7 +104,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
return self._record_store
|
||||
|
||||
@property
|
||||
def event_bus(self) -> Optional["EventServiceBase"]: # noqa D102
|
||||
def event_bus(self) -> Optional[EventServiceBase]: # noqa D102
|
||||
return self._event_bus
|
||||
|
||||
# make the invoker optional here because we don't need it and it
|
||||
@@ -857,17 +855,35 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.status = InstallStatus.RUNNING
|
||||
self._logger.info(f"Model install started: {job.source}")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_started(job)
|
||||
self._event_bus.emit_model_install_running(str(job.source))
|
||||
|
||||
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_download_progress(job)
|
||||
parts: List[Dict[str, str | int]] = [
|
||||
{
|
||||
"url": str(x.source),
|
||||
"local_path": str(x.download_path),
|
||||
"bytes": x.bytes,
|
||||
"total_bytes": x.total_bytes,
|
||||
}
|
||||
for x in job.download_parts
|
||||
]
|
||||
assert job.bytes is not None
|
||||
assert job.total_bytes is not None
|
||||
self._event_bus.emit_model_install_downloading(
|
||||
str(job.source),
|
||||
local_path=job.local_path.as_posix(),
|
||||
parts=parts,
|
||||
bytes=job.bytes,
|
||||
total_bytes=job.total_bytes,
|
||||
id=job.id,
|
||||
)
|
||||
|
||||
def _signal_job_downloads_done(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.DOWNLOADS_DONE
|
||||
self._logger.info(f"Model download complete: {job.source}")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_downloads_complete(job)
|
||||
self._event_bus.emit_model_install_downloads_done(str(job.source))
|
||||
|
||||
def _signal_job_completed(self, job: ModelInstallJob) -> None:
|
||||
job.status = InstallStatus.COMPLETED
|
||||
@@ -875,19 +891,24 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._logger.info(f"Model install complete: {job.source}")
|
||||
self._logger.debug(f"{job.local_path} registered key {job.config_out.key}")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_complete(job)
|
||||
assert job.local_path is not None
|
||||
assert job.config_out is not None
|
||||
key = job.config_out.key
|
||||
self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id)
|
||||
|
||||
def _signal_job_errored(self, job: ModelInstallJob) -> None:
|
||||
self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}")
|
||||
if self._event_bus:
|
||||
assert job.error_type is not None
|
||||
assert job.error is not None
|
||||
self._event_bus.emit_model_install_error(job)
|
||||
error_type = job.error_type
|
||||
error = job.error
|
||||
assert error_type is not None
|
||||
assert error is not None
|
||||
self._event_bus.emit_model_install_error(str(job.source), error_type, error, id=job.id)
|
||||
|
||||
def _signal_job_cancelled(self, job: ModelInstallJob) -> None:
|
||||
self._logger.info(f"Model install canceled: {job.source}")
|
||||
if self._event_bus:
|
||||
self._event_bus.emit_model_install_cancelled(job)
|
||||
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
|
||||
|
||||
@staticmethod
|
||||
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
@@ -14,12 +15,18 @@ class ModelLoadServiceBase(ABC):
|
||||
"""Wrapper around AnyModelLoader."""
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
def load_model(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context_data: Invocation context data used for event reporting
|
||||
"""
|
||||
|
||||
@property
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional, Type
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import (
|
||||
LoadedModel,
|
||||
@@ -50,18 +51,25 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
"""Return the checkpoint convert cache used by this loader."""
|
||||
return self._convert_cache
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
def load_model(
|
||||
self,
|
||||
model_config: AnyModelConfig,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
context_data: Optional[InvocationContextData] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Given a model's configuration, load it and return the LoadedModel object.
|
||||
|
||||
:param model_config: Model configuration record (as returned by ModelRecordBase.get_model())
|
||||
:param submodel: For main (pipeline models), the submodel to fetch.
|
||||
:param context: Invocation context used for event reporting
|
||||
"""
|
||||
|
||||
# We don't have an invoker during testing
|
||||
# TODO(psyche): Mock this method on the invoker in the tests
|
||||
if hasattr(self, "_invoker"):
|
||||
self._invoker.services.events.emit_model_load_started(model_config, submodel_type)
|
||||
if context_data:
|
||||
self._emit_load_event(
|
||||
context_data=context_data,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
|
||||
implementation, model_config, submodel_type = self._registry.get_implementation(model_config, submodel_type) # type: ignore
|
||||
loaded_model: LoadedModel = implementation(
|
||||
@@ -71,7 +79,40 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
convert_cache=self._convert_cache,
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
if hasattr(self, "_invoker"):
|
||||
self._invoker.services.events.emit_model_load_complete(model_config, submodel_type)
|
||||
|
||||
if context_data:
|
||||
self._emit_load_event(
|
||||
context_data=context_data,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
loaded=True,
|
||||
)
|
||||
return loaded_model
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context_data: InvocationContextData,
|
||||
model_config: AnyModelConfig,
|
||||
loaded: Optional[bool] = False,
|
||||
submodel_type: Optional[SubModelType] = None,
|
||||
) -> None:
|
||||
if not self._invoker:
|
||||
return
|
||||
|
||||
if not loaded:
|
||||
self._invoker.services.events.emit_model_load_started(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
else:
|
||||
self._invoker.services.events.emit_model_load_completed(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
model_config=model_config,
|
||||
submodel_type=submodel_type,
|
||||
)
|
||||
|
||||
@@ -4,14 +4,11 @@ from threading import BoundedSemaphore, Thread
|
||||
from threading import Event as ThreadEvent
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||
from invokeai.app.services.events.events_common import (
|
||||
BatchEnqueuedEvent,
|
||||
FastAPIEvent,
|
||||
QueueClearedEvent,
|
||||
QueueItemStatusChangedEvent,
|
||||
register_events,
|
||||
)
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||
from invokeai.app.services.session_processor.session_processor_base import (
|
||||
OnAfterRunNode,
|
||||
@@ -63,11 +60,6 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
self._cancel_event = cancel_event
|
||||
self._profiler = profiler
|
||||
|
||||
def _is_canceled(self) -> bool:
|
||||
"""Check if the cancel event is set. This is also passed to the invocation context builder and called during
|
||||
denoising to check if the session has been canceled."""
|
||||
return self._cancel_event.is_set()
|
||||
|
||||
def run(self, queue_item: SessionQueueItem):
|
||||
# Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here.
|
||||
|
||||
@@ -91,19 +83,13 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
)
|
||||
break
|
||||
|
||||
if invocation is None or self._is_canceled():
|
||||
if invocation is None or self._cancel_event.is_set():
|
||||
break
|
||||
|
||||
self.run_node(invocation, queue_item)
|
||||
|
||||
# The session is complete if all invocations have been run or there is an error on the session.
|
||||
# At this time, the queue item may be canceled, but the object itself here won't be updated yet. We must
|
||||
# use the cancel event to check if the session is canceled.
|
||||
if (
|
||||
queue_item.session.is_complete()
|
||||
or self._is_canceled()
|
||||
or queue_item.status in ["failed", "canceled", "completed"]
|
||||
):
|
||||
if queue_item.session.is_complete() or self._cancel_event.is_set():
|
||||
break
|
||||
|
||||
self._on_after_run_session(queue_item=queue_item)
|
||||
@@ -122,7 +108,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._services,
|
||||
is_canceled=self._is_canceled,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
|
||||
# Invoke the node
|
||||
@@ -136,12 +122,16 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here?
|
||||
pass
|
||||
except CanceledException:
|
||||
# A CanceledException is raised during the denoising step callback if the cancel event is set. We don't need
|
||||
# to do any handling here, and no error should be set - just pass and the cancellation will be handled
|
||||
# correctly in the next iteration of the session runner loop.
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# See the comment in the processor's `_on_queue_item_status_changed()` method for more details on how we
|
||||
# handle cancellation.
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
except Exception as e:
|
||||
error_type = e.__class__.__name__
|
||||
@@ -156,11 +146,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
)
|
||||
|
||||
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||
"""Called before a session is run.
|
||||
|
||||
- Start the profiler if profiling is enabled.
|
||||
- Run any callbacks registered for this event.
|
||||
"""
|
||||
"""Run before a session is executed"""
|
||||
|
||||
self._services.logger.debug(
|
||||
f"On before run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
|
||||
@@ -174,14 +160,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
callback(queue_item=queue_item)
|
||||
|
||||
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||
"""Called after a session is run.
|
||||
|
||||
- Stop the profiler if profiling is enabled.
|
||||
- Update the queue item's session object in the database.
|
||||
- If not already canceled or failed, complete the queue item.
|
||||
- Log and reset performance statistics.
|
||||
- Run any callbacks registered for this event.
|
||||
"""
|
||||
"""Run after a session is executed"""
|
||||
|
||||
self._services.logger.debug(
|
||||
f"On after run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
|
||||
@@ -201,10 +180,14 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# while the session is running.
|
||||
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||
|
||||
# The queue item may have been canceled or failed while the session was running. We should only complete it
|
||||
# if it is not already canceled or failed.
|
||||
if queue_item.status not in ["canceled", "failed"]:
|
||||
queue_item = self._services.session_queue.complete_queue_item(queue_item.item_id)
|
||||
# TODO(psyche): This feels jumbled - we should review separation of concerns here.
|
||||
# Send complete event. The events service will receive this and update the queue item's status.
|
||||
self._services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
)
|
||||
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
@@ -218,18 +201,21 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
pass
|
||||
|
||||
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Called before a node is run.
|
||||
|
||||
- Emits an invocation started event.
|
||||
- Run any callbacks registered for this event.
|
||||
"""
|
||||
"""Run before a node is executed"""
|
||||
|
||||
self._services.logger.debug(
|
||||
f"On before run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
|
||||
)
|
||||
|
||||
# Send starting event
|
||||
self._services.events.emit_invocation_started(queue_item=queue_item, invocation=invocation)
|
||||
self._services.events.emit_invocation_started(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
)
|
||||
|
||||
for callback in self._on_before_run_node_callbacks:
|
||||
callback(invocation=invocation, queue_item=queue_item)
|
||||
@@ -237,18 +223,22 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
def _on_after_run_node(
|
||||
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
|
||||
):
|
||||
"""Called after a node is run.
|
||||
|
||||
- Emits an invocation complete event.
|
||||
- Run any callbacks registered for this event.
|
||||
"""
|
||||
"""Run after a node is executed"""
|
||||
|
||||
self._services.logger.debug(
|
||||
f"On after run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
|
||||
)
|
||||
|
||||
# Send complete event on successful runs
|
||||
self._services.events.emit_invocation_complete(invocation=invocation, queue_item=queue_item, output=output)
|
||||
self._services.events.emit_invocation_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
result=output.model_dump(),
|
||||
)
|
||||
|
||||
for callback in self._on_after_run_node_callbacks:
|
||||
callback(invocation=invocation, queue_item=queue_item, output=output)
|
||||
@@ -261,14 +251,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
):
|
||||
"""Called when a node errors. Node errors may occur when running or preparing the node..
|
||||
|
||||
- Set the node error on the session object.
|
||||
- Log the error.
|
||||
- Fail the queue item.
|
||||
- Emits an invocation error event.
|
||||
- Run any callbacks registered for this event.
|
||||
"""
|
||||
"""Run when a node errors"""
|
||||
|
||||
self._services.logger.debug(
|
||||
f"On node error: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
|
||||
@@ -282,19 +265,19 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
)
|
||||
self._services.logger.error(error_traceback)
|
||||
|
||||
# Fail the queue item
|
||||
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||
queue_item = self._services.session_queue.fail_queue_item(
|
||||
queue_item.item_id, error_type, error_message, error_traceback
|
||||
)
|
||||
|
||||
# Send error event
|
||||
self._services.events.emit_invocation_error(
|
||||
queue_item=queue_item,
|
||||
invocation=invocation,
|
||||
queue_batch_id=queue_item.session_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session.id,
|
||||
node=invocation.model_dump(),
|
||||
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
user_id=getattr(queue_item, "user_id", None),
|
||||
project_id=getattr(queue_item, "project_id", None),
|
||||
)
|
||||
|
||||
for callback in self._on_node_error_callbacks:
|
||||
@@ -332,9 +315,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._poll_now_event = ThreadEvent()
|
||||
self._cancel_event = ThreadEvent()
|
||||
|
||||
register_events(QueueClearedEvent, self._on_queue_cleared)
|
||||
register_events(BatchEnqueuedEvent, self._on_batch_enqueued)
|
||||
register_events(QueueItemStatusChangedEvent, self._on_queue_item_status_changed)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||
|
||||
self._thread_semaphore = BoundedSemaphore(self._thread_limit)
|
||||
|
||||
@@ -369,25 +350,31 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def _poll_now(self) -> None:
|
||||
self._poll_now_event.set()
|
||||
|
||||
async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
|
||||
if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
|
||||
async def _on_queue_event(self, event: FastAPIEvent) -> None:
|
||||
event_name = event[1]["event"]
|
||||
|
||||
if (
|
||||
event_name == "session_canceled"
|
||||
and self._queue_item
|
||||
and self._queue_item.item_id == event[1]["data"]["queue_item_id"]
|
||||
):
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
|
||||
async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None:
|
||||
self._poll_now()
|
||||
|
||||
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
|
||||
if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
|
||||
# When the queue item is canceled via HTTP, the queue item status is set to `"canceled"` and this event is
|
||||
# emitted. We need to respond to this event and stop graph execution. This is done by setting the cancel
|
||||
# event, which the session runner checks between invocations. If set, the session runner loop is broken.
|
||||
#
|
||||
# Long-running nodes that cannot be interrupted easily present a challenge. `denoise_latents` is one such
|
||||
# node, but it gets a step callback, called on each step of denoising. This callback checks if the queue item
|
||||
# is canceled, and if it is, raises a `CanceledException` to stop execution immediately.
|
||||
if event[1].status == "canceled":
|
||||
self._cancel_event.set()
|
||||
elif (
|
||||
event_name == "queue_cleared"
|
||||
and self._queue_item
|
||||
and self._queue_item.queue_id == event[1]["data"]["queue_id"]
|
||||
):
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
elif event_name == "batch_enqueued":
|
||||
self._poll_now()
|
||||
elif event_name == "queue_item_status_changed" and event[1]["data"]["queue_item"]["status"] in [
|
||||
"completed",
|
||||
"failed",
|
||||
"canceled",
|
||||
]:
|
||||
self._cancel_event.set()
|
||||
self._poll_now()
|
||||
|
||||
def resume(self) -> SessionProcessorStatus:
|
||||
@@ -476,22 +463,15 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
) -> None:
|
||||
"""Called when a non-fatal error occurs in the processor.
|
||||
|
||||
- Log the error.
|
||||
- If a queue item is provided, update the queue item with the completed session & fail it.
|
||||
- Run any callbacks registered for this event.
|
||||
"""
|
||||
|
||||
# Non-fatal error in processor
|
||||
self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}")
|
||||
self._invoker.services.logger.error(error_traceback)
|
||||
|
||||
if queue_item is not None:
|
||||
# Update the queue item with the completed session & fail it
|
||||
queue_item = self._invoker.services.session_queue.set_queue_item_session(
|
||||
queue_item.item_id, queue_item.session
|
||||
)
|
||||
queue_item = self._invoker.services.session_queue.fail_queue_item(
|
||||
# Update the queue item with the completed session
|
||||
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||
# Fail the queue item
|
||||
self._invoker.services.session_queue.fail_queue_item(
|
||||
item_id=queue_item.item_id,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
|
||||
@@ -73,11 +73,6 @@ class SessionQueueBase(ABC):
|
||||
"""Gets the status of a batch"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
"""Completes a session queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
"""Cancels a session queue item"""
|
||||
|
||||
@@ -2,6 +2,10 @@ import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
@@ -38,7 +42,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__invoker = invoker
|
||||
self._set_in_progress_to_canceled()
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
|
||||
if prune_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
|
||||
@@ -48,6 +52,60 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__conn = db.conn
|
||||
self.__cursor = self.__conn.cursor()
|
||||
|
||||
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
||||
return event[1]["event"] in match_in
|
||||
|
||||
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
|
||||
event_name = event[1]["event"]
|
||||
|
||||
# This was a match statement, but match is not supported on python 3.9
|
||||
if event_name == "graph_execution_state_complete":
|
||||
await self._handle_complete_event(event)
|
||||
elif event_name == "invocation_error":
|
||||
await self._handle_error_event(event)
|
||||
elif event_name == "session_canceled":
|
||||
await self._handle_cancel_event(event)
|
||||
return event
|
||||
|
||||
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
# When a queue item has an error, we get an error event, then a completed event.
|
||||
# Mark the queue item completed only if it isn't already marked completed, e.g.
|
||||
# by a previously-handled error event.
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
async def _handle_error_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
error_type = event[1]["data"]["error_type"]
|
||||
error_message = event[1]["data"]["error_message"]
|
||||
error_traceback = event[1]["data"]["error_traceback"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
# always set to failed if have an error, even if previously the item was marked completed or canceled
|
||||
queue_item = self._set_queue_item_status(
|
||||
item_id=queue_item.item_id,
|
||||
status="failed",
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""
|
||||
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
|
||||
@@ -248,7 +306,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
session_queue_item=queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
)
|
||||
return queue_item
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
@@ -357,11 +419,15 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
return PruneResult(deleted=count)
|
||||
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||
return queue_item
|
||||
|
||||
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
)
|
||||
return queue_item
|
||||
|
||||
def fail_queue_item(
|
||||
@@ -371,13 +437,21 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
) -> SessionQueueItem:
|
||||
queue_item = self._set_queue_item_status(
|
||||
item_id=item_id,
|
||||
status="failed",
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||
queue_item = self._set_queue_item_status(
|
||||
item_id=item_id,
|
||||
status="failed",
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
)
|
||||
return queue_item
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
@@ -413,10 +487,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
queue_batch_id=current_queue_item.batch_id,
|
||||
graph_execution_state_id=current_queue_item.session_id,
|
||||
)
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
current_queue_item, batch_status, queue_status
|
||||
session_queue_item=current_queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
@@ -456,10 +538,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
queue_batch_id=current_queue_item.batch_id,
|
||||
graph_execution_state_id=current_queue_item.session_id,
|
||||
)
|
||||
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(
|
||||
current_queue_item, batch_status, queue_status
|
||||
session_queue_item=current_queue_item,
|
||||
batch_status=batch_status,
|
||||
queue_status=queue_status,
|
||||
)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
|
||||
@@ -2,19 +2,18 @@
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
from typing import Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
import networkx as nx
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
GetCoreSchemaHandler,
|
||||
GetJsonSchemaHandler,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
)
|
||||
from pydantic.fields import Field
|
||||
from pydantic.json_schema import JsonSchemaValue
|
||||
from pydantic_core import core_schema
|
||||
from pydantic_core import CoreSchema
|
||||
|
||||
# Importing * is bad karma but needed here for node detection
|
||||
from invokeai.app.invocations import * # noqa: F401 F403
|
||||
@@ -278,58 +277,73 @@ class CollectInvocation(BaseInvocation):
|
||||
return CollectInvocationOutput(collection=copy.copy(self.collection))
|
||||
|
||||
|
||||
class AnyInvocation(BaseInvocation):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
|
||||
def validate_invocation(v: Any) -> "AnyInvocation":
|
||||
return BaseInvocation.get_typeadapter().validate_python(v)
|
||||
|
||||
return core_schema.no_info_plain_validator_function(validate_invocation)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
oneOf: list[dict[str, str]] = []
|
||||
names = [i.__name__ for i in BaseInvocation.get_invocations()]
|
||||
for name in sorted(names):
|
||||
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
||||
return {"oneOf": oneOf}
|
||||
|
||||
|
||||
class AnyInvocationOutput(BaseInvocationOutput):
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler):
|
||||
def validate_invocation_output(v: Any) -> "AnyInvocationOutput":
|
||||
return BaseInvocationOutput.get_typeadapter().validate_python(v)
|
||||
|
||||
return core_schema.no_info_plain_validator_function(validate_invocation_output)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(
|
||||
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
|
||||
) -> JsonSchemaValue:
|
||||
# Nodes are too powerful, we have to make our own OpenAPI schema manually
|
||||
# No but really, because the schema is dynamic depending on loaded nodes, we need to generate it manually
|
||||
|
||||
oneOf: list[dict[str, str]] = []
|
||||
names = [i.__name__ for i in BaseInvocationOutput.get_outputs()]
|
||||
for name in sorted(names):
|
||||
oneOf.append({"$ref": f"#/components/schemas/{name}"})
|
||||
return {"oneOf": oneOf}
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: str = Field(description="The id of this graph", default_factory=uuid_string)
|
||||
# TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
|
||||
nodes: dict[str, AnyInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
||||
nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
|
||||
edges: list[Edge] = Field(
|
||||
description="The connections between nodes and their fields in this graph",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
@field_validator("nodes", mode="plain")
|
||||
@classmethod
|
||||
def validate_nodes(cls, v: dict[str, Any]):
|
||||
"""Validates the nodes in the graph by retrieving a union of all node types and validating each node."""
|
||||
|
||||
# Invocations register themselves as their python modules are executed. The union of all invocations is
|
||||
# constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
|
||||
#
|
||||
# It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
|
||||
# we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
|
||||
# invocations will cause a graph to fail if they are used.
|
||||
#
|
||||
# We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
|
||||
# pydantic validation entirely. This allows us to validate the nodes using the union of invocations at runtime.
|
||||
#
|
||||
# This same pattern is used in `GraphExecutionState`.
|
||||
|
||||
nodes: dict[str, BaseInvocation] = {}
|
||||
typeadapter = BaseInvocation.get_typeadapter()
|
||||
for node_id, node in v.items():
|
||||
nodes[node_id] = typeadapter.validate_python(node)
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
# We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
|
||||
# fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
|
||||
# the generated schema as options for the `nodes` field.
|
||||
#
|
||||
# The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
|
||||
# with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
|
||||
# expected.
|
||||
#
|
||||
# You might be tempted to do something like this:
|
||||
#
|
||||
# ```py
|
||||
# cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
|
||||
# delattr(cloned_model, "validate_nodes")
|
||||
# cloned_model.model_rebuild(force=True)
|
||||
# json_schema = handler(cloned_model.__pydantic_core_schema__)
|
||||
# ```
|
||||
#
|
||||
# Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
|
||||
# to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
|
||||
#
|
||||
# This same pattern is used in `GraphExecutionState`.
|
||||
|
||||
class Graph(BaseModel):
|
||||
id: Optional[str] = Field(default=None, description="The id of this graph")
|
||||
nodes: dict[
|
||||
str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
|
||||
] = Field(description="The nodes in this graph")
|
||||
edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")
|
||||
|
||||
json_schema = handler(Graph.__pydantic_core_schema__)
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
return json_schema
|
||||
|
||||
def add_node(self, node: BaseInvocation) -> None:
|
||||
"""Adds a node to a graph
|
||||
|
||||
@@ -760,7 +774,7 @@ class GraphExecutionState(BaseModel):
|
||||
)
|
||||
|
||||
# The results of executed nodes
|
||||
results: dict[str, AnyInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
||||
results: dict[str, BaseInvocationOutput] = Field(description="The results of node executions", default_factory=dict)
|
||||
|
||||
# Errors raised when executing nodes
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)
|
||||
@@ -777,12 +791,52 @@ class GraphExecutionState(BaseModel):
|
||||
default_factory=dict,
|
||||
)
|
||||
|
||||
@field_validator("results", mode="plain")
|
||||
@classmethod
|
||||
def validate_results(cls, v: dict[str, BaseInvocationOutput]):
|
||||
"""Validates the results in the GES by retrieving a union of all output types and validating each result."""
|
||||
|
||||
# See the comment in `Graph.validate_nodes` for an explanation of this logic.
|
||||
results: dict[str, BaseInvocationOutput] = {}
|
||||
typeadapter = BaseInvocationOutput.get_typeadapter()
|
||||
for result_id, result in v.items():
|
||||
results[result_id] = typeadapter.validate_python(result)
|
||||
return results
|
||||
|
||||
@field_validator("graph")
|
||||
def graph_is_valid(cls, v: Graph):
|
||||
"""Validates that the graph is valid"""
|
||||
v.validate_self()
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
|
||||
# See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.
|
||||
class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(description="The id of the execution state")
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
|
||||
executed: set[str] = Field(description="The set of node ids that have been executed")
|
||||
executed_history: list[str] = Field(
|
||||
description="The list of node ids that have been executed, in order of execution"
|
||||
)
|
||||
results: dict[
|
||||
str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
|
||||
] = Field(description="The results of node executions")
|
||||
errors: dict[str, str] = Field(description="Errors raised when executing nodes")
|
||||
prepared_source_mapping: dict[str, str] = Field(
|
||||
description="The map of prepared nodes to original graph nodes"
|
||||
)
|
||||
source_prepared_mapping: dict[str, set[str]] = Field(
|
||||
description="The map of original graph nodes to prepared nodes"
|
||||
)
|
||||
|
||||
json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
|
||||
json_schema = handler.resolve_ref_schema(json_schema)
|
||||
return json_schema
|
||||
|
||||
def next(self) -> Optional[BaseInvocation]:
|
||||
"""Gets the next node ready to execute."""
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from PIL.Image import Image
|
||||
from torch import Tensor
|
||||
@@ -352,11 +353,11 @@ class ModelsInterface(InvocationContextInterface):
|
||||
|
||||
if isinstance(identifier, str):
|
||||
model = self._services.model_manager.store.get_model(identifier)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type)
|
||||
return self._services.model_manager.load.load_model(model, submodel_type, self._data)
|
||||
else:
|
||||
_submodel_type = submodel_type or identifier.submodel_type
|
||||
model = self._services.model_manager.store.get_model(identifier.key)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type)
|
||||
return self._services.model_manager.load.load_model(model, _submodel_type, self._data)
|
||||
|
||||
def load_by_attrs(
|
||||
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
|
||||
@@ -381,7 +382,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
if len(configs) > 1:
|
||||
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")
|
||||
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type)
|
||||
return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)
|
||||
|
||||
def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
|
||||
"""Gets a model's config.
|
||||
@@ -448,10 +449,10 @@ class ConfigInterface(InvocationContextInterface):
|
||||
|
||||
class UtilInterface(InvocationContextInterface):
|
||||
def __init__(
|
||||
self, services: InvocationServices, data: InvocationContextData, is_canceled: Callable[[], bool]
|
||||
self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event
|
||||
) -> None:
|
||||
super().__init__(services, data)
|
||||
self._is_canceled = is_canceled
|
||||
self._cancel_event = cancel_event
|
||||
|
||||
def is_canceled(self) -> bool:
|
||||
"""Checks if the current session has been canceled.
|
||||
@@ -459,7 +460,7 @@ class UtilInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
True if the current session has been canceled, False if not.
|
||||
"""
|
||||
return self._is_canceled()
|
||||
return self._cancel_event.is_set()
|
||||
|
||||
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
|
||||
"""
|
||||
@@ -534,7 +535,7 @@ class InvocationContext:
|
||||
def build_invocation_context(
|
||||
services: InvocationServices,
|
||||
data: InvocationContextData,
|
||||
is_canceled: Callable[[], bool],
|
||||
cancel_event: threading.Event,
|
||||
) -> InvocationContext:
|
||||
"""Builds the invocation context for a specific invocation execution.
|
||||
|
||||
@@ -551,7 +552,7 @@ def build_invocation_context(
|
||||
tensors = TensorsInterface(services=services, data=data)
|
||||
models = ModelsInterface(services=services, data=data)
|
||||
config = ConfigInterface(services=services, data=data)
|
||||
util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
|
||||
util = UtilInterface(services=services, data=data, cancel_event=cancel_event)
|
||||
conditioning = ConditioningInterface(services=services, data=data)
|
||||
boards = BoardsInterface(services=services, data=data)
|
||||
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from pydantic.json_schema import models_json_schema
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, UIConfigBase
|
||||
from invokeai.app.invocations.fields import InputFieldJSONSchemaExtra, OutputFieldJSONSchemaExtra
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.events.events_common import EventBase
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
|
||||
|
||||
def move_defs_to_top_level(openapi_schema: dict[str, Any], component_schema: dict[str, Any]) -> None:
|
||||
"""Moves a component schema's $defs to the top level of the openapi schema. Useful when generating a schema
|
||||
for a single model that needs to be added back to the top level of the schema. Mutates openapi_schema and
|
||||
component_schema."""
|
||||
|
||||
defs = component_schema.pop("$defs", {})
|
||||
for schema_key, json_schema in defs.items():
|
||||
if schema_key in openapi_schema["components"]["schemas"]:
|
||||
continue
|
||||
openapi_schema["components"]["schemas"][schema_key] = json_schema
|
||||
|
||||
|
||||
def get_openapi_func(
|
||||
app: FastAPI, post_transform: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None
|
||||
) -> Callable[[], dict[str, Any]]:
|
||||
"""Gets the OpenAPI schema generator function.
|
||||
|
||||
Args:
|
||||
app (FastAPI): The FastAPI app to generate the schema for.
|
||||
post_transform (Optional[Callable[[dict[str, Any]], dict[str, Any]]], optional): A function to apply to the
|
||||
generated schema before returning it. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Callable[[], dict[str, Any]]: The OpenAPI schema generator function. When first called, the generated schema is
|
||||
cached in `app.openapi_schema`. On subsequent calls, the cached schema is returned. This caching behaviour
|
||||
matches FastAPI's default schema generation caching.
|
||||
"""
|
||||
|
||||
def openapi() -> dict[str, Any]:
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
title=app.title,
|
||||
description="An API for invoking AI image operations",
|
||||
version="1.0.0",
|
||||
routes=app.routes,
|
||||
separate_input_output_schemas=False, # https://fastapi.tiangolo.com/how-to/separate-openapi-schemas/
|
||||
)
|
||||
|
||||
# We'll create a map of invocation type to output schema to make some types simpler on the client.
|
||||
invocation_output_map_properties: dict[str, Any] = {}
|
||||
invocation_output_map_required: list[str] = []
|
||||
|
||||
# We need to manually add all outputs to the schema - pydantic doesn't add them because they aren't used directly.
|
||||
for output in BaseInvocationOutput.get_outputs():
|
||||
json_schema = output.model_json_schema(mode="serialization", ref_template="#/components/schemas/{model}")
|
||||
move_defs_to_top_level(openapi_schema, json_schema)
|
||||
openapi_schema["components"]["schemas"][output.__name__] = json_schema
|
||||
|
||||
# Technically, invocations are added to the schema by pydantic, but we still need to manually set their output
|
||||
# property, so we'll just do it all manually.
|
||||
for invocation in BaseInvocation.get_invocations():
|
||||
json_schema = invocation.model_json_schema(
|
||||
mode="serialization", ref_template="#/components/schemas/{model}"
|
||||
)
|
||||
move_defs_to_top_level(openapi_schema, json_schema)
|
||||
output_title = invocation.get_output_annotation().__name__
|
||||
outputs_ref = {"$ref": f"#/components/schemas/{output_title}"}
|
||||
json_schema["output"] = outputs_ref
|
||||
openapi_schema["components"]["schemas"][invocation.__name__] = json_schema
|
||||
|
||||
# Add this invocation and its output to the output map
|
||||
invocation_type = invocation.get_type()
|
||||
invocation_output_map_properties[invocation_type] = json_schema["output"]
|
||||
invocation_output_map_required.append(invocation_type)
|
||||
|
||||
# Add the output map to the schema
|
||||
openapi_schema["components"]["schemas"]["InvocationOutputMap"] = {
|
||||
"type": "object",
|
||||
"properties": invocation_output_map_properties,
|
||||
"required": invocation_output_map_required,
|
||||
}
|
||||
|
||||
# Some models don't end up in the schemas as standalone definitions because they aren't used directly in the API.
|
||||
# We need to add them manually here. WARNING: Pydantic can choke if you call `model.model_json_schema()` to get
|
||||
# a schema. This has something to do with schema refs - not totally clear. For whatever reason, using
|
||||
# `models_json_schema` seems to work fine.
|
||||
additional_models = [
|
||||
*EventBase.get_events(),
|
||||
UIConfigBase,
|
||||
InputFieldJSONSchemaExtra,
|
||||
OutputFieldJSONSchemaExtra,
|
||||
ModelIdentifierField,
|
||||
ProgressImage,
|
||||
]
|
||||
|
||||
additional_schemas = models_json_schema(
|
||||
[(m, "serialization") for m in additional_models],
|
||||
ref_template="#/components/schemas/{model}",
|
||||
)
|
||||
# additional_schemas[1] is a dict of $defs that we need to add to the top level of the schema
|
||||
move_defs_to_top_level(openapi_schema, additional_schemas[1])
|
||||
|
||||
if post_transform is not None:
|
||||
openapi_schema = post_transform(openapi_schema)
|
||||
|
||||
openapi_schema["components"]["schemas"] = dict(sorted(openapi_schema["components"]["schemas"].items()))
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
|
||||
return openapi
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
@@ -13,36 +13,8 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContextData
|
||||
|
||||
# fast latents preview matrix for sdxl
|
||||
# generated by @StAlKeR7779
|
||||
SDXL_LATENT_RGB_FACTORS = [
|
||||
# R G B
|
||||
[0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
]
|
||||
SDXL_SMOOTH_MATRIX = [
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
]
|
||||
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
SD1_5_LATENT_RGB_FACTORS = [
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
]
|
||||
|
||||
|
||||
def sample_to_lowres_estimated_image(
|
||||
samples: torch.Tensor, latent_rgb_factors: torch.Tensor, smooth_matrix: Optional[torch.Tensor] = None
|
||||
):
|
||||
def sample_to_lowres_estimated_image(samples, latent_rgb_factors, smooth_matrix=None):
|
||||
latent_image = samples[0].permute(1, 2, 0) @ latent_rgb_factors
|
||||
|
||||
if smooth_matrix is not None:
|
||||
@@ -75,12 +47,64 @@ def stable_diffusion_step_callback(
|
||||
else:
|
||||
sample = intermediate_state.latents
|
||||
|
||||
# TODO: This does not seem to be needed any more?
|
||||
# # txt2img provides a Tensor in the step_callback
|
||||
# # img2img provides a PipelineIntermediateState
|
||||
# if isinstance(sample, PipelineIntermediateState):
|
||||
# # this was an img2img
|
||||
# print('img2img')
|
||||
# latents = sample.latents
|
||||
# step = sample.step
|
||||
# else:
|
||||
# print('txt2img')
|
||||
# latents = sample
|
||||
# step = intermediate_state.step
|
||||
|
||||
# TODO: only output a preview image when requested
|
||||
|
||||
if base_model in [BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner]:
|
||||
sdxl_latent_rgb_factors = torch.tensor(SDXL_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
sdxl_smooth_matrix = torch.tensor(SDXL_SMOOTH_MATRIX, dtype=sample.dtype, device=sample.device)
|
||||
# fast latents preview matrix for sdxl
|
||||
# generated by @StAlKeR7779
|
||||
sdxl_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3816, 0.4930, 0.5320],
|
||||
[-0.3753, 0.1631, 0.1739],
|
||||
[0.1770, 0.3588, -0.2048],
|
||||
[-0.4350, -0.2644, -0.4289],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
sdxl_smooth_matrix = torch.tensor(
|
||||
[
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
[0.0964, 0.4711, 0.0964],
|
||||
[0.0358, 0.0964, 0.0358],
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
image = sample_to_lowres_estimated_image(sample, sdxl_latent_rgb_factors, sdxl_smooth_matrix)
|
||||
else:
|
||||
v1_5_latent_rgb_factors = torch.tensor(SD1_5_LATENT_RGB_FACTORS, dtype=sample.dtype, device=sample.device)
|
||||
# origingally adapted from code by @erucipe and @keturn here:
|
||||
# https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/7
|
||||
|
||||
# these updated numbers for v1.5 are from @torridgristle
|
||||
v1_5_latent_rgb_factors = torch.tensor(
|
||||
[
|
||||
# R G B
|
||||
[0.3444, 0.1385, 0.0670], # L1
|
||||
[0.1247, 0.4027, 0.1494], # L2
|
||||
[-0.3192, 0.2513, 0.2103], # L3
|
||||
[-0.1307, -0.1874, -0.7445], # L4
|
||||
],
|
||||
dtype=sample.dtype,
|
||||
device=sample.device,
|
||||
)
|
||||
|
||||
image = sample_to_lowres_estimated_image(sample, v1_5_latent_rgb_factors)
|
||||
|
||||
(width, height) = image.size
|
||||
@@ -89,9 +113,15 @@ def stable_diffusion_step_callback(
|
||||
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
events.emit_invocation_denoise_progress(
|
||||
context_data.queue_item,
|
||||
context_data.invocation,
|
||||
intermediate_state,
|
||||
ProgressImage(dataURL=dataURL, width=width, height=height),
|
||||
events.emit_generator_progress(
|
||||
queue_id=context_data.queue_item.queue_id,
|
||||
queue_item_id=context_data.queue_item.item_id,
|
||||
queue_batch_id=context_data.queue_item.batch_id,
|
||||
graph_execution_state_id=context_data.queue_item.session_id,
|
||||
node_id=context_data.invocation.id,
|
||||
source_node_id=context_data.source_invocation_id,
|
||||
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
|
||||
step=intermediate_state.step,
|
||||
order=intermediate_state.order,
|
||||
total_steps=intermediate_state.total_steps,
|
||||
)
|
||||
|
||||
@@ -42,26 +42,10 @@ T = TypeVar("T")
|
||||
|
||||
@dataclass
|
||||
class CacheRecord(Generic[T]):
|
||||
"""
|
||||
Elements of the cache:
|
||||
|
||||
key: Unique key for each model, same as used in the models database.
|
||||
model: Model in memory.
|
||||
state_dict: A read-only copy of the model's state dict in RAM. It will be
|
||||
used as a template for creating a copy in the VRAM.
|
||||
size: Size of the model
|
||||
loaded: True if the model's state dict is currently in VRAM
|
||||
|
||||
Before a model is executed, the state_dict template is copied into VRAM,
|
||||
and then injected into the model. When the model is finished, the VRAM
|
||||
copy of the state dict is deleted, and the RAM version is reinjected
|
||||
into the model.
|
||||
"""
|
||||
"""Elements of the cache."""
|
||||
|
||||
key: str
|
||||
model: T
|
||||
device: torch.device
|
||||
state_dict: Optional[Dict[str, torch.Tensor]]
|
||||
size: int
|
||||
loaded: bool = False
|
||||
_locks: int = 0
|
||||
|
||||
@@ -20,6 +20,7 @@ context. Use like this:
|
||||
|
||||
import gc
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from logging import Logger
|
||||
@@ -161,9 +162,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
if key in self._cached_models:
|
||||
return
|
||||
self.make_room(size)
|
||||
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
|
||||
cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size)
|
||||
cache_record = CacheRecord(key, model, size)
|
||||
self._cached_models[key] = cache_record
|
||||
self._cache_stack.append(key)
|
||||
|
||||
@@ -258,37 +257,17 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
|
||||
return
|
||||
|
||||
source_device = cache_entry.device
|
||||
source_device = cache_entry.model.device
|
||||
|
||||
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
|
||||
# This would need to be revised to support multi-GPU.
|
||||
if torch.device(source_device).type == torch.device(target_device).type:
|
||||
return
|
||||
|
||||
# This roundabout method for moving the model around is done to avoid
|
||||
# the cost of moving the model from RAM to VRAM and then back from VRAM to RAM.
|
||||
# When moving to VRAM, we copy (not move) each element of the state dict from
|
||||
# RAM to a new state dict in VRAM, and then inject it into the model.
|
||||
# This operation is slightly faster than running `to()` on the whole model.
|
||||
#
|
||||
# When the model needs to be removed from VRAM we simply delete the copy
|
||||
# of the state dict in VRAM, and reinject the state dict that is cached
|
||||
# in RAM into the model. So this operation is very fast.
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
|
||||
try:
|
||||
if cache_entry.state_dict is not None:
|
||||
assert hasattr(cache_entry.model, "load_state_dict")
|
||||
if target_device == self.storage_device:
|
||||
cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True)
|
||||
else:
|
||||
new_dict: Dict[str, torch.Tensor] = {}
|
||||
for k, v in cache_entry.state_dict.items():
|
||||
new_dict[k] = v.to(torch.device(target_device), copy=True)
|
||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||
cache_entry.model.to(target_device)
|
||||
cache_entry.device = target_device
|
||||
except Exception as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise e
|
||||
@@ -368,12 +347,43 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
|
||||
model_key = self._cache_stack[pos]
|
||||
cache_entry = self._cached_models[model_key]
|
||||
|
||||
refs = sys.getrefcount(cache_entry.model)
|
||||
|
||||
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
|
||||
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
|
||||
# https://docs.python.org/3/library/gc.html#gc.get_referrers
|
||||
|
||||
# manualy clear local variable references of just finished function calls
|
||||
# for some reason python don't want to collect it even by gc.collect() immidiately
|
||||
if refs > 2:
|
||||
while True:
|
||||
cleared = False
|
||||
for referrer in gc.get_referrers(cache_entry.model):
|
||||
if type(referrer).__name__ == "frame":
|
||||
# RuntimeError: cannot clear an executing frame
|
||||
with suppress(RuntimeError):
|
||||
referrer.clear()
|
||||
cleared = True
|
||||
# break
|
||||
|
||||
# repeat if referrers changes(due to frame clear), else exit loop
|
||||
if cleared:
|
||||
gc.collect()
|
||||
else:
|
||||
break
|
||||
|
||||
device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None
|
||||
self.logger.debug(
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}"
|
||||
f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded},"
|
||||
f" refs: {refs}"
|
||||
)
|
||||
|
||||
if not cache_entry.locked:
|
||||
# Expected refs:
|
||||
# 1 from cache_entry
|
||||
# 1 from getrefcount function
|
||||
# 1 from onnx runtime object
|
||||
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
|
||||
self.logger.debug(
|
||||
f"Removing {model_key} from RAM cache to free at least {(size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
|
||||
)
|
||||
|
||||
@@ -60,5 +60,5 @@ class ModelLocker(ModelLockerBase):
|
||||
|
||||
self._cache_entry.unlock()
|
||||
if not self._cache.lazy_offloading:
|
||||
self._cache.offload_unlocked_models(0)
|
||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||
self._cache.print_cuda_stats()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Textual Inversion wrapper class."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
@@ -66,52 +66,35 @@ class TextualInversionModelRaw(RawModel):
|
||||
return result
|
||||
|
||||
|
||||
class TextualInversionManager(BaseTextualInversionManager):
|
||||
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
|
||||
# no type hints for BaseTextualInversionManager?
|
||||
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
|
||||
pad_tokens: Dict[int, List[int]]
|
||||
tokenizer: CLIPTokenizer
|
||||
|
||||
def __init__(self, tokenizer: CLIPTokenizer):
|
||||
self.pad_tokens: dict[int, list[int]] = {}
|
||||
self.pad_tokens = {}
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
|
||||
"""Given a list of tokens ids, expand any TI tokens to their corresponding pad tokens.
|
||||
|
||||
For example, suppose we have a `<ti_dog>` TI with 4 vectors that was added to the tokenizer with the following
|
||||
mapping of tokens to token_ids:
|
||||
```
|
||||
<ti_dog>: 49408
|
||||
<ti_dog-!pad-1>: 49409
|
||||
<ti_dog-!pad-2>: 49410
|
||||
<ti_dog-!pad-3>: 49411
|
||||
```
|
||||
`self.pad_tokens` would be set to `{49408: [49408, 49409, 49410, 49411]}`.
|
||||
This function is responsible for expanding `49408` in the token_ids list to `[49408, 49409, 49410, 49411]`.
|
||||
"""
|
||||
# Short circuit if there are no pad tokens to save a little time.
|
||||
if len(self.pad_tokens) == 0:
|
||||
return token_ids
|
||||
|
||||
# This function assumes that compel has not included the BOS and EOS tokens in the token_ids list. We verify
|
||||
# this assumption here.
|
||||
if token_ids[0] == self.tokenizer.bos_token_id:
|
||||
raise ValueError("token_ids must not start with bos_token_id")
|
||||
if token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
raise ValueError("token_ids must not end with eos_token_id")
|
||||
|
||||
# Expand any TI tokens to their corresponding pad tokens.
|
||||
new_token_ids: list[int] = []
|
||||
new_token_ids = []
|
||||
for token_id in token_ids:
|
||||
new_token_ids.append(token_id)
|
||||
if token_id in self.pad_tokens:
|
||||
new_token_ids.extend(self.pad_tokens[token_id])
|
||||
|
||||
# Do not exceed the max model input size. The -2 here is compensating for
|
||||
# compel.embeddings_provider.get_token_ids(), which first removes and then adds back the start and end tokens.
|
||||
max_length = self.tokenizer.model_max_length - 2
|
||||
# Do not exceed the max model input size
|
||||
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
|
||||
# which first removes and then adds back the start and end tokens.
|
||||
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
|
||||
if len(new_token_ids) > max_length:
|
||||
# HACK: If TI token expansion causes us to exceed the max text encoder input length, we silently discard
|
||||
# tokens. Token expansion should happen in a way that is compatible with compel's default handling of long
|
||||
# prompts.
|
||||
new_token_ids = new_token_ids[0:max_length]
|
||||
|
||||
return new_token_ids
|
||||
|
||||
@@ -1021,8 +1021,7 @@
|
||||
"float": "Kommazahlen",
|
||||
"enum": "Aufzählung",
|
||||
"fullyContainNodes": "Vollständig ausgewählte Nodes auswählen",
|
||||
"editMode": "Im Workflow-Editor bearbeiten",
|
||||
"resetToDefaultValue": "Auf Standardwert zurücksetzen"
|
||||
"editMode": "Im Workflow-Editor bearbeiten"
|
||||
},
|
||||
"hrf": {
|
||||
"enableHrf": "Korrektur für hohe Auflösungen",
|
||||
|
||||
@@ -148,8 +148,6 @@
|
||||
"viewingDesc": "Review images in a large gallery view",
|
||||
"editing": "Editing",
|
||||
"editingDesc": "Edit on the Control Layers canvas",
|
||||
"comparing": "Comparing",
|
||||
"comparingDesc": "Comparing two images",
|
||||
"enabled": "Enabled",
|
||||
"disabled": "Disabled"
|
||||
},
|
||||
@@ -377,23 +375,7 @@
|
||||
"bulkDownloadRequestFailed": "Problem Preparing Download",
|
||||
"bulkDownloadFailed": "Download Failed",
|
||||
"problemDeletingImages": "Problem Deleting Images",
|
||||
"problemDeletingImagesDesc": "One or more images could not be deleted",
|
||||
"viewerImage": "Viewer Image",
|
||||
"compareImage": "Compare Image",
|
||||
"openInViewer": "Open in Viewer",
|
||||
"selectForCompare": "Select for Compare",
|
||||
"selectAnImageToCompare": "Select an Image to Compare",
|
||||
"slider": "Slider",
|
||||
"sideBySide": "Side-by-Side",
|
||||
"hover": "Hover",
|
||||
"swapImages": "Swap Images",
|
||||
"compareOptions": "Comparison Options",
|
||||
"stretchToFit": "Stretch to Fit",
|
||||
"exitCompare": "Exit Compare",
|
||||
"compareHelp1": "Hold <Kbd>Alt</Kbd> while clicking a gallery image or using the arrow keys to change the compare image.",
|
||||
"compareHelp2": "Press <Kbd>M</Kbd> to cycle through comparison modes.",
|
||||
"compareHelp3": "Press <Kbd>C</Kbd> to swap the compared images.",
|
||||
"compareHelp4": "Press <Kbd>Z</Kbd> or <Kbd>Esc</Kbd> to exit."
|
||||
"problemDeletingImagesDesc": "One or more images could not be deleted"
|
||||
},
|
||||
"hotkeys": {
|
||||
"searchHotkeys": "Search Hotkeys",
|
||||
@@ -1122,7 +1104,7 @@
|
||||
"parameters": "Parameters",
|
||||
"parameterSet": "Parameter Recalled",
|
||||
"parameterSetDesc": "Recalled {{parameter}}",
|
||||
"parameterNotSet": "Parameter Not Recalled",
|
||||
"parameterNotSet": "Parameter Recalled",
|
||||
"parameterNotSetDesc": "Unable to recall {{parameter}}",
|
||||
"parameterNotSetDescWithMessage": "Unable to recall {{parameter}}: {{message}}",
|
||||
"parametersSet": "Parameters Recalled",
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
"settingsLabel": "Ajustes",
|
||||
"img2img": "Imagen a Imagen",
|
||||
"unifiedCanvas": "Lienzo Unificado",
|
||||
"nodes": "Flujos de trabajo",
|
||||
"nodes": "Editor del flujo de trabajo",
|
||||
"upload": "Subir imagen",
|
||||
"load": "Cargar",
|
||||
"statusDisconnected": "Desconectado",
|
||||
@@ -14,7 +14,7 @@
|
||||
"discordLabel": "Discord",
|
||||
"back": "Atrás",
|
||||
"loading": "Cargando",
|
||||
"postprocessing": "Postprocesado",
|
||||
"postprocessing": "Tratamiento posterior",
|
||||
"txt2img": "De texto a imagen",
|
||||
"accept": "Aceptar",
|
||||
"cancel": "Cancelar",
|
||||
@@ -42,42 +42,7 @@
|
||||
"copy": "Copiar",
|
||||
"beta": "Beta",
|
||||
"on": "En",
|
||||
"aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:",
|
||||
"installed": "Instalado",
|
||||
"green": "Verde",
|
||||
"editor": "Editor",
|
||||
"orderBy": "Ordenar por",
|
||||
"file": "Archivo",
|
||||
"goTo": "Ir a",
|
||||
"imageFailedToLoad": "No se puede cargar la imagen",
|
||||
"saveAs": "Guardar Como",
|
||||
"somethingWentWrong": "Algo salió mal",
|
||||
"nextPage": "Página Siguiente",
|
||||
"selected": "Seleccionado",
|
||||
"tab": "Tabulador",
|
||||
"positivePrompt": "Prompt Positivo",
|
||||
"negativePrompt": "Prompt Negativo",
|
||||
"error": "Error",
|
||||
"format": "formato",
|
||||
"unknown": "Desconocido",
|
||||
"input": "Entrada",
|
||||
"nodeEditor": "Editor de nodos",
|
||||
"template": "Plantilla",
|
||||
"prevPage": "Página Anterior",
|
||||
"red": "Rojo",
|
||||
"alpha": "Transparencia",
|
||||
"outputs": "Salidas",
|
||||
"editing": "Editando",
|
||||
"learnMore": "Aprende más",
|
||||
"enabled": "Activado",
|
||||
"disabled": "Desactivado",
|
||||
"folder": "Carpeta",
|
||||
"updated": "Actualizado",
|
||||
"created": "Creado",
|
||||
"save": "Guardar",
|
||||
"unknownError": "Error Desconocido",
|
||||
"blue": "Azul",
|
||||
"viewingDesc": "Revisar imágenes en una vista de galería grande"
|
||||
"aboutDesc": "¿Utilizas Invoke para trabajar? Mira aquí:"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Tamaño de la imagen",
|
||||
@@ -502,8 +467,7 @@
|
||||
"about": "Acerca de",
|
||||
"createIssue": "Crear un problema",
|
||||
"resetUI": "Interfaz de usuario $t(accessibility.reset)",
|
||||
"mode": "Modo",
|
||||
"submitSupportTicket": "Enviar Ticket de Soporte"
|
||||
"mode": "Modo"
|
||||
},
|
||||
"nodes": {
|
||||
"zoomInNodes": "Acercar",
|
||||
@@ -579,17 +543,5 @@
|
||||
"layers_one": "Capa",
|
||||
"layers_many": "Capas",
|
||||
"layers_other": "Capas"
|
||||
},
|
||||
"controlnet": {
|
||||
"crop": "Cortar",
|
||||
"delete": "Eliminar",
|
||||
"depthAnythingDescription": "Generación de mapa de profundidad usando la técnica de Depth Anything",
|
||||
"duplicate": "Duplicar",
|
||||
"colorMapDescription": "Genera un mapa de color desde la imagen",
|
||||
"depthMidasDescription": "Crea un mapa de profundidad con Midas",
|
||||
"balanced": "Equilibrado",
|
||||
"beginEndStepPercent": "Inicio / Final Porcentaje de pasos",
|
||||
"detectResolution": "Detectar resolución",
|
||||
"beginEndStepPercentShort": "Inicio / Final %"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@
|
||||
"outputs": "Risultati",
|
||||
"data": "Dati",
|
||||
"somethingWentWrong": "Qualcosa è andato storto",
|
||||
"copyError": "Errore $t(gallery.copy)",
|
||||
"copyError": "$t(gallery.copy) Errore",
|
||||
"input": "Ingresso",
|
||||
"notInstalled": "Non $t(common.installed)",
|
||||
"unknownError": "Errore sconosciuto",
|
||||
@@ -85,11 +85,7 @@
|
||||
"viewing": "Visualizza",
|
||||
"viewingDesc": "Rivedi le immagini in un'ampia vista della galleria",
|
||||
"editing": "Modifica",
|
||||
"editingDesc": "Modifica nell'area Livelli di controllo",
|
||||
"enabled": "Abilitato",
|
||||
"disabled": "Disabilitato",
|
||||
"comparingDesc": "Confronta due immagini",
|
||||
"comparing": "Confronta"
|
||||
"editingDesc": "Modifica nell'area Livelli di controllo"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Dimensione dell'immagine",
|
||||
@@ -126,30 +122,14 @@
|
||||
"bulkDownloadRequestedDesc": "La tua richiesta di download è in preparazione. L'operazione potrebbe richiedere alcuni istanti.",
|
||||
"bulkDownloadRequestFailed": "Problema durante la preparazione del download",
|
||||
"bulkDownloadFailed": "Scaricamento fallito",
|
||||
"alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine",
|
||||
"openInViewer": "Apri nel visualizzatore",
|
||||
"selectForCompare": "Seleziona per il confronto",
|
||||
"selectAnImageToCompare": "Seleziona un'immagine da confrontare",
|
||||
"slider": "Cursore",
|
||||
"sideBySide": "Fianco a Fianco",
|
||||
"compareImage": "Immagine di confronto",
|
||||
"viewerImage": "Immagine visualizzata",
|
||||
"hover": "Al passaggio del mouse",
|
||||
"swapImages": "Scambia le immagini",
|
||||
"compareOptions": "Opzioni di confronto",
|
||||
"stretchToFit": "Scala per adattare",
|
||||
"exitCompare": "Esci dal confronto",
|
||||
"compareHelp1": "Tieni premuto <Kbd>Alt</Kbd> mentre fai clic su un'immagine della galleria o usi i tasti freccia per cambiare l'immagine di confronto.",
|
||||
"compareHelp2": "Premi <Kbd>M</Kbd> per scorrere le modalità di confronto.",
|
||||
"compareHelp3": "Premi <Kbd>C</Kbd> per scambiare le immagini confrontate.",
|
||||
"compareHelp4": "Premi <Kbd>Z</Kbd> o <Kbd>Esc</Kbd> per uscire."
|
||||
"alwaysShowImageSizeBadge": "Mostra sempre le dimensioni dell'immagine"
|
||||
},
|
||||
"hotkeys": {
|
||||
"keyboardShortcuts": "Tasti di scelta rapida",
|
||||
"appHotkeys": "Applicazione",
|
||||
"generalHotkeys": "Generale",
|
||||
"galleryHotkeys": "Galleria",
|
||||
"unifiedCanvasHotkeys": "Tela",
|
||||
"unifiedCanvasHotkeys": "Tela Unificata",
|
||||
"invoke": {
|
||||
"title": "Invoke",
|
||||
"desc": "Genera un'immagine"
|
||||
@@ -167,8 +147,8 @@
|
||||
"desc": "Apre e chiude il pannello delle opzioni"
|
||||
},
|
||||
"pinOptions": {
|
||||
"title": "Fissa le opzioni",
|
||||
"desc": "Fissa il pannello delle opzioni"
|
||||
"title": "Appunta le opzioni",
|
||||
"desc": "Blocca il pannello delle opzioni"
|
||||
},
|
||||
"toggleGallery": {
|
||||
"title": "Attiva/disattiva galleria",
|
||||
@@ -352,14 +332,14 @@
|
||||
"title": "Annulla e cancella"
|
||||
},
|
||||
"resetOptionsAndGallery": {
|
||||
"title": "Ripristina le opzioni e la galleria",
|
||||
"desc": "Reimposta i pannelli delle opzioni e della galleria"
|
||||
"title": "Ripristina Opzioni e Galleria",
|
||||
"desc": "Reimposta le opzioni e i pannelli della galleria"
|
||||
},
|
||||
"searchHotkeys": "Cerca tasti di scelta rapida",
|
||||
"noHotkeysFound": "Nessun tasto di scelta rapida trovato",
|
||||
"toggleOptionsAndGallery": {
|
||||
"desc": "Apre e chiude le opzioni e i pannelli della galleria",
|
||||
"title": "Attiva/disattiva le opzioni e la galleria"
|
||||
"title": "Attiva/disattiva le Opzioni e la Galleria"
|
||||
},
|
||||
"clearSearch": "Cancella ricerca",
|
||||
"remixImage": {
|
||||
@@ -368,7 +348,7 @@
|
||||
},
|
||||
"toggleViewer": {
|
||||
"title": "Attiva/disattiva il visualizzatore di immagini",
|
||||
"desc": "Passa dal visualizzatore immagini all'area di lavoro per la scheda corrente."
|
||||
"desc": "Passa dal Visualizzatore immagini all'area di lavoro per la scheda corrente."
|
||||
}
|
||||
},
|
||||
"modelManager": {
|
||||
@@ -398,7 +378,7 @@
|
||||
"convertToDiffusers": "Converti in Diffusori",
|
||||
"convertToDiffusersHelpText2": "Questo processo sostituirà la voce in Gestione Modelli con la versione Diffusori dello stesso modello.",
|
||||
"convertToDiffusersHelpText4": "Questo è un processo una tantum. Potrebbero essere necessari circa 30-60 secondi a seconda delle specifiche del tuo computer.",
|
||||
"convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB in dimensione.",
|
||||
"convertToDiffusersHelpText5": "Assicurati di avere spazio su disco sufficiente. I modelli generalmente variano tra 2 GB e 7 GB di dimensioni.",
|
||||
"convertToDiffusersHelpText6": "Vuoi convertire questo modello?",
|
||||
"modelConverted": "Modello convertito",
|
||||
"alpha": "Alpha",
|
||||
@@ -548,7 +528,7 @@
|
||||
"layer": {
|
||||
"initialImageNoImageSelected": "Nessuna immagine iniziale selezionata",
|
||||
"t2iAdapterIncompatibleDimensions": "L'adattatore T2I richiede che la dimensione dell'immagine sia un multiplo di {{multiple}}",
|
||||
"controlAdapterNoModelSelected": "Nessun modello di adattatore di controllo selezionato",
|
||||
"controlAdapterNoModelSelected": "Nessun modello di Adattatore di Controllo selezionato",
|
||||
"controlAdapterIncompatibleBaseModel": "Il modello base dell'adattatore di controllo non è compatibile",
|
||||
"controlAdapterNoImageSelected": "Nessuna immagine dell'adattatore di controllo selezionata",
|
||||
"controlAdapterImageNotProcessed": "Immagine dell'adattatore di controllo non elaborata",
|
||||
@@ -626,25 +606,25 @@
|
||||
"canvasMerged": "Tela unita",
|
||||
"sentToImageToImage": "Inviato a Generazione da immagine",
|
||||
"sentToUnifiedCanvas": "Inviato alla Tela",
|
||||
"parametersNotSet": "Parametri non richiamati",
|
||||
"parametersNotSet": "Parametri non impostati",
|
||||
"metadataLoadFailed": "Impossibile caricare i metadati",
|
||||
"serverError": "Errore del Server",
|
||||
"connected": "Connesso al server",
|
||||
"connected": "Connesso al Server",
|
||||
"canceled": "Elaborazione annullata",
|
||||
"uploadFailedInvalidUploadDesc": "Deve essere una singola immagine PNG o JPEG",
|
||||
"parameterSet": "Parametro richiamato",
|
||||
"parameterNotSet": "Parametro non richiamato",
|
||||
"parameterSet": "{{parameter}} impostato",
|
||||
"parameterNotSet": "{{parameter}} non impostato",
|
||||
"problemCopyingImage": "Impossibile copiare l'immagine",
|
||||
"baseModelChangedCleared_one": "Cancellato o disabilitato {{count}} sottomodello incompatibile",
|
||||
"baseModelChangedCleared_many": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
|
||||
"baseModelChangedCleared_other": "Cancellati o disabilitati {{count}} sottomodelli incompatibili",
|
||||
"baseModelChangedCleared_one": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modello incompatibile",
|
||||
"baseModelChangedCleared_many": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili",
|
||||
"baseModelChangedCleared_other": "Il modello base è stato modificato, cancellato o disabilitato {{count}} sotto-modelli incompatibili",
|
||||
"imageSavingFailed": "Salvataggio dell'immagine non riuscito",
|
||||
"canvasSentControlnetAssets": "Tela inviata a ControlNet & Risorse",
|
||||
"problemCopyingCanvasDesc": "Impossibile copiare la tela",
|
||||
"loadedWithWarnings": "Flusso di lavoro caricato con avvisi",
|
||||
"canvasCopiedClipboard": "Tela copiata negli appunti",
|
||||
"maskSavedAssets": "Maschera salvata nelle risorse",
|
||||
"problemDownloadingCanvas": "Problema durante lo scarico della tela",
|
||||
"problemDownloadingCanvas": "Problema durante il download della tela",
|
||||
"problemMergingCanvas": "Problema nell'unione delle tele",
|
||||
"imageUploaded": "Immagine caricata",
|
||||
"addedToBoard": "Aggiunto alla bacheca",
|
||||
@@ -678,17 +658,7 @@
|
||||
"problemDownloadingImage": "Impossibile scaricare l'immagine",
|
||||
"prunedQueue": "Coda ripulita",
|
||||
"modelImportCanceled": "Importazione del modello annullata",
|
||||
"parameters": "Parametri",
|
||||
"parameterSetDesc": "{{parameter}} richiamato",
|
||||
"parameterNotSetDesc": "Impossibile richiamare {{parameter}}",
|
||||
"parameterNotSetDescWithMessage": "Impossibile richiamare {{parameter}}: {{message}}",
|
||||
"parametersSet": "Parametri richiamati",
|
||||
"errorCopied": "Errore copiato",
|
||||
"outOfMemoryError": "Errore di memoria esaurita",
|
||||
"baseModelChanged": "Modello base modificato",
|
||||
"sessionRef": "Sessione: {{sessionId}}",
|
||||
"somethingWentWrong": "Qualcosa è andato storto",
|
||||
"outOfMemoryErrorDesc": "Le impostazioni della generazione attuale superano la capacità del sistema. Modifica le impostazioni e riprova."
|
||||
"parameters": "Parametri"
|
||||
},
|
||||
"tooltip": {
|
||||
"feature": {
|
||||
@@ -704,7 +674,7 @@
|
||||
"layer": "Livello",
|
||||
"base": "Base",
|
||||
"mask": "Maschera",
|
||||
"maskingOptions": "Opzioni maschera",
|
||||
"maskingOptions": "Opzioni di mascheramento",
|
||||
"enableMask": "Abilita maschera",
|
||||
"preserveMaskedArea": "Mantieni area mascherata",
|
||||
"clearMask": "Cancella maschera (Shift+C)",
|
||||
@@ -775,8 +745,7 @@
|
||||
"mode": "Modalità",
|
||||
"resetUI": "$t(accessibility.reset) l'Interfaccia Utente",
|
||||
"createIssue": "Segnala un problema",
|
||||
"about": "Informazioni",
|
||||
"submitSupportTicket": "Invia ticket di supporto"
|
||||
"about": "Informazioni"
|
||||
},
|
||||
"nodes": {
|
||||
"zoomOutNodes": "Rimpicciolire",
|
||||
@@ -821,7 +790,7 @@
|
||||
"workflowNotes": "Note",
|
||||
"versionUnknown": " Versione sconosciuta",
|
||||
"unableToValidateWorkflow": "Impossibile convalidare il flusso di lavoro",
|
||||
"updateApp": "Aggiorna Applicazione",
|
||||
"updateApp": "Aggiorna App",
|
||||
"unableToLoadWorkflow": "Impossibile caricare il flusso di lavoro",
|
||||
"updateNode": "Aggiorna nodo",
|
||||
"version": "Versione",
|
||||
@@ -913,14 +882,11 @@
|
||||
"missingNode": "Nodo di invocazione mancante",
|
||||
"missingInvocationTemplate": "Modello di invocazione mancante",
|
||||
"missingFieldTemplate": "Modello di campo mancante",
|
||||
"singleFieldType": "{{name}} (Singola)",
|
||||
"imageAccessError": "Impossibile trovare l'immagine {{image_name}}, ripristino delle impostazioni predefinite",
|
||||
"boardAccessError": "Impossibile trovare la bacheca {{board_id}}, ripristino ai valori predefiniti",
|
||||
"modelAccessError": "Impossibile trovare il modello {{key}}, ripristino ai valori predefiniti"
|
||||
"singleFieldType": "{{name}} (Singola)"
|
||||
},
|
||||
"boards": {
|
||||
"autoAddBoard": "Aggiungi automaticamente bacheca",
|
||||
"menuItemAutoAdd": "Aggiungi automaticamente a questa bacheca",
|
||||
"menuItemAutoAdd": "Aggiungi automaticamente a questa Bacheca",
|
||||
"cancel": "Annulla",
|
||||
"addBoard": "Aggiungi Bacheca",
|
||||
"bottomMessage": "L'eliminazione di questa bacheca e delle sue immagini ripristinerà tutte le funzionalità che le stanno attualmente utilizzando.",
|
||||
@@ -932,7 +898,7 @@
|
||||
"myBoard": "Bacheca",
|
||||
"searchBoard": "Cerca bacheche ...",
|
||||
"noMatching": "Nessuna bacheca corrispondente",
|
||||
"selectBoard": "Seleziona una bacheca",
|
||||
"selectBoard": "Seleziona una Bacheca",
|
||||
"uncategorized": "Non categorizzato",
|
||||
"downloadBoard": "Scarica la bacheca",
|
||||
"deleteBoardOnly": "solo la Bacheca",
|
||||
@@ -953,7 +919,7 @@
|
||||
"control": "Controllo",
|
||||
"crop": "Ritaglia",
|
||||
"depthMidas": "Profondità (Midas)",
|
||||
"detectResolution": "Rileva la risoluzione",
|
||||
"detectResolution": "Rileva risoluzione",
|
||||
"controlMode": "Modalità di controllo",
|
||||
"cannyDescription": "Canny rilevamento bordi",
|
||||
"depthZoe": "Profondità (Zoe)",
|
||||
@@ -964,7 +930,7 @@
|
||||
"showAdvanced": "Mostra opzioni Avanzate",
|
||||
"bgth": "Soglia rimozione sfondo",
|
||||
"importImageFromCanvas": "Importa immagine dalla Tela",
|
||||
"lineartDescription": "Converte l'immagine in linea",
|
||||
"lineartDescription": "Converte l'immagine in lineart",
|
||||
"importMaskFromCanvas": "Importa maschera dalla Tela",
|
||||
"hideAdvanced": "Nascondi opzioni avanzate",
|
||||
"resetControlImage": "Reimposta immagine di controllo",
|
||||
@@ -980,7 +946,7 @@
|
||||
"pidiDescription": "Elaborazione immagini PIDI",
|
||||
"fill": "Riempie",
|
||||
"colorMapDescription": "Genera una mappa dei colori dall'immagine",
|
||||
"lineartAnimeDescription": "Elaborazione linea in stile anime",
|
||||
"lineartAnimeDescription": "Elaborazione lineart in stile anime",
|
||||
"imageResolution": "Risoluzione dell'immagine",
|
||||
"colorMap": "Colore",
|
||||
"lowThreshold": "Soglia inferiore",
|
||||
|
||||
@@ -87,11 +87,7 @@
|
||||
"viewing": "Просмотр",
|
||||
"editing": "Редактирование",
|
||||
"viewingDesc": "Просмотр изображений в режиме большой галереи",
|
||||
"editingDesc": "Редактировать на холсте слоёв управления",
|
||||
"enabled": "Включено",
|
||||
"disabled": "Отключено",
|
||||
"comparingDesc": "Сравнение двух изображений",
|
||||
"comparing": "Сравнение"
|
||||
"editingDesc": "Редактировать на холсте слоёв управления"
|
||||
},
|
||||
"gallery": {
|
||||
"galleryImageSize": "Размер изображений",
|
||||
@@ -128,23 +124,7 @@
|
||||
"bulkDownloadRequested": "Подготовка к скачиванию",
|
||||
"bulkDownloadRequestedDesc": "Ваш запрос на скачивание готовится. Это может занять несколько минут.",
|
||||
"bulkDownloadRequestFailed": "Возникла проблема при подготовке скачивания",
|
||||
"alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения",
|
||||
"openInViewer": "Открыть в просмотрщике",
|
||||
"selectForCompare": "Выбрать для сравнения",
|
||||
"hover": "Наведение",
|
||||
"swapImages": "Поменять местами",
|
||||
"stretchToFit": "Растягивание до нужного размера",
|
||||
"exitCompare": "Выйти из сравнения",
|
||||
"compareHelp4": "Нажмите <Kbd>Z</Kbd> или <Kbd>Esc</Kbd> для выхода.",
|
||||
"compareImage": "Сравнить изображение",
|
||||
"viewerImage": "Изображение просмотрщика",
|
||||
"selectAnImageToCompare": "Выберите изображение для сравнения",
|
||||
"slider": "Слайдер",
|
||||
"sideBySide": "Бок о бок",
|
||||
"compareOptions": "Варианты сравнения",
|
||||
"compareHelp1": "Удерживайте <Kbd>Alt</Kbd> при нажатии на изображение в галерее или при помощи клавиш со стрелками, чтобы изменить сравниваемое изображение.",
|
||||
"compareHelp2": "Нажмите <Kbd>M</Kbd>, чтобы переключиться между режимами сравнения.",
|
||||
"compareHelp3": "Нажмите <Kbd>C</Kbd>, чтобы поменять местами сравниваемые изображения."
|
||||
"alwaysShowImageSizeBadge": "Всегда показывать значок размера изображения"
|
||||
},
|
||||
"hotkeys": {
|
||||
"keyboardShortcuts": "Горячие клавиши",
|
||||
@@ -548,20 +528,7 @@
|
||||
"missingFieldTemplate": "Отсутствует шаблон поля",
|
||||
"addingImagesTo": "Добавление изображений в",
|
||||
"invoke": "Создать",
|
||||
"imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается",
|
||||
"layer": {
|
||||
"controlAdapterImageNotProcessed": "Изображение адаптера контроля не обработано",
|
||||
"ipAdapterNoModelSelected": "IP адаптер не выбран",
|
||||
"controlAdapterNoModelSelected": "не выбрана модель адаптера контроля",
|
||||
"controlAdapterIncompatibleBaseModel": "несовместимая базовая модель адаптера контроля",
|
||||
"controlAdapterNoImageSelected": "не выбрано изображение контрольного адаптера",
|
||||
"initialImageNoImageSelected": "начальное изображение не выбрано",
|
||||
"rgNoRegion": "регион не выбран",
|
||||
"rgNoPromptsOrIPAdapters": "нет текстовых запросов или IP-адаптеров",
|
||||
"ipAdapterIncompatibleBaseModel": "несовместимая базовая модель IP-адаптера",
|
||||
"t2iAdapterIncompatibleDimensions": "Адаптер T2I требует, чтобы размеры изображения были кратны {{multiple}}",
|
||||
"ipAdapterNoImageSelected": "изображение IP-адаптера не выбрано"
|
||||
}
|
||||
"imageNotProcessedForControlAdapter": "Изображение адаптера контроля №{{number}} не обрабатывается"
|
||||
},
|
||||
"isAllowedToUpscale": {
|
||||
"useX2Model": "Изображение слишком велико для увеличения с помощью модели x4. Используйте модель x2",
|
||||
@@ -639,12 +606,12 @@
|
||||
"connected": "Подключено к серверу",
|
||||
"canceled": "Обработка отменена",
|
||||
"uploadFailedInvalidUploadDesc": "Должно быть одно изображение в формате PNG или JPEG",
|
||||
"parameterNotSet": "Параметр не задан",
|
||||
"parameterSet": "Параметр задан",
|
||||
"parameterNotSet": "Параметр {{parameter}} не задан",
|
||||
"parameterSet": "Параметр {{parameter}} задан",
|
||||
"problemCopyingImage": "Не удается скопировать изображение",
|
||||
"baseModelChangedCleared_one": "Очищена или отключена {{count}} несовместимая подмодель",
|
||||
"baseModelChangedCleared_few": "Очищены или отключены {{count}} несовместимые подмодели",
|
||||
"baseModelChangedCleared_many": "Очищены или отключены {{count}} несовместимых подмоделей",
|
||||
"baseModelChangedCleared_one": "Базовая модель изменила, очистила или отключила {{count}} несовместимую подмодель",
|
||||
"baseModelChangedCleared_few": "Базовая модель изменила, очистила или отключила {{count}} несовместимые подмодели",
|
||||
"baseModelChangedCleared_many": "Базовая модель изменила, очистила или отключила {{count}} несовместимых подмоделей",
|
||||
"imageSavingFailed": "Не удалось сохранить изображение",
|
||||
"canvasSentControlnetAssets": "Холст отправлен в ControlNet и ресурсы",
|
||||
"problemCopyingCanvasDesc": "Невозможно экспортировать базовый слой",
|
||||
@@ -685,17 +652,7 @@
|
||||
"resetInitialImage": "Сбросить начальное изображение",
|
||||
"prunedQueue": "Урезанная очередь",
|
||||
"modelImportCanceled": "Импорт модели отменен",
|
||||
"parameters": "Параметры",
|
||||
"parameterSetDesc": "Задан {{parameter}}",
|
||||
"parameterNotSetDesc": "Невозможно задать {{parameter}}",
|
||||
"baseModelChanged": "Базовая модель сменена",
|
||||
"parameterNotSetDescWithMessage": "Не удалось задать {{parameter}}: {{message}}",
|
||||
"parametersSet": "Параметры заданы",
|
||||
"errorCopied": "Ошибка скопирована",
|
||||
"sessionRef": "Сессия: {{sessionId}}",
|
||||
"outOfMemoryError": "Ошибка нехватки памяти",
|
||||
"outOfMemoryErrorDesc": "Ваши текущие настройки генерации превышают возможности системы. Пожалуйста, измените настройки и повторите попытку.",
|
||||
"somethingWentWrong": "Что-то пошло не так"
|
||||
"parameters": "Параметры"
|
||||
},
|
||||
"tooltip": {
|
||||
"feature": {
|
||||
@@ -782,8 +739,7 @@
|
||||
"loadMore": "Загрузить больше",
|
||||
"resetUI": "$t(accessibility.reset) интерфейс",
|
||||
"createIssue": "Сообщить о проблеме",
|
||||
"about": "Об этом",
|
||||
"submitSupportTicket": "Отправить тикет в службу поддержки"
|
||||
"about": "Об этом"
|
||||
},
|
||||
"nodes": {
|
||||
"zoomInNodes": "Увеличьте масштаб",
|
||||
@@ -876,7 +832,7 @@
|
||||
"workflowName": "Название",
|
||||
"collection": "Коллекция",
|
||||
"unknownErrorValidatingWorkflow": "Неизвестная ошибка при проверке рабочего процесса",
|
||||
"collectionFieldType": "{{name}} (Коллекция)",
|
||||
"collectionFieldType": "Коллекция {{name}}",
|
||||
"workflowNotes": "Примечания",
|
||||
"string": "Строка",
|
||||
"unknownNodeType": "Неизвестный тип узла",
|
||||
@@ -892,7 +848,7 @@
|
||||
"targetNodeDoesNotExist": "Недопустимое ребро: целевой/входной узел {{node}} не существует",
|
||||
"mismatchedVersion": "Недопустимый узел: узел {{node}} типа {{type}} имеет несоответствующую версию (попробовать обновить?)",
|
||||
"unknownFieldType": "$t(nodes.unknownField) тип: {{type}}",
|
||||
"collectionOrScalarFieldType": "{{name}} (Один или коллекция)",
|
||||
"collectionOrScalarFieldType": "Коллекция | Скаляр {{name}}",
|
||||
"betaDesc": "Этот вызов находится в бета-версии. Пока он не станет стабильным, в нем могут происходить изменения при обновлении приложений. Мы планируем поддерживать этот вызов в течение длительного времени.",
|
||||
"nodeVersion": "Версия узла",
|
||||
"loadingNodes": "Загрузка узлов...",
|
||||
@@ -914,16 +870,7 @@
|
||||
"noFieldsViewMode": "В этом рабочем процессе нет выбранных полей для отображения. Просмотрите полный рабочий процесс для настройки значений.",
|
||||
"graph": "График",
|
||||
"showEdgeLabels": "Показать метки на ребрах",
|
||||
"showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы",
|
||||
"cannotMixAndMatchCollectionItemTypes": "Невозможно смешивать и сопоставлять типы элементов коллекции",
|
||||
"missingNode": "Отсутствует узел вызова",
|
||||
"missingInvocationTemplate": "Отсутствует шаблон вызова",
|
||||
"missingFieldTemplate": "Отсутствующий шаблон поля",
|
||||
"singleFieldType": "{{name}} (Один)",
|
||||
"noGraph": "Нет графика",
|
||||
"imageAccessError": "Невозможно найти изображение {{image_name}}, сбрасываем на значение по умолчанию",
|
||||
"boardAccessError": "Невозможно найти доску {{board_id}}, сбрасываем на значение по умолчанию",
|
||||
"modelAccessError": "Невозможно найти модель {{key}}, сброс на модель по умолчанию"
|
||||
"showEdgeLabelsHelp": "Показать метки на ребрах, указывающие на соединенные узлы"
|
||||
},
|
||||
"controlnet": {
|
||||
"amult": "a_mult",
|
||||
@@ -1494,16 +1441,7 @@
|
||||
"clearQueueAlertDialog2": "Вы уверены, что хотите очистить очередь?",
|
||||
"item": "Элемент",
|
||||
"graphFailedToQueue": "Не удалось поставить график в очередь",
|
||||
"openQueue": "Открыть очередь",
|
||||
"prompts_one": "Запрос",
|
||||
"prompts_few": "Запроса",
|
||||
"prompts_many": "Запросов",
|
||||
"iterations_one": "Итерация",
|
||||
"iterations_few": "Итерации",
|
||||
"iterations_many": "Итераций",
|
||||
"generations_one": "Генерация",
|
||||
"generations_few": "Генерации",
|
||||
"generations_many": "Генераций"
|
||||
"openQueue": "Открыть очередь"
|
||||
},
|
||||
"sdxl": {
|
||||
"refinerStart": "Запуск доработчика",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"common": {
|
||||
"nodes": "工作流程",
|
||||
"nodes": "節點",
|
||||
"img2img": "圖片轉圖片",
|
||||
"statusDisconnected": "已中斷連線",
|
||||
"back": "返回",
|
||||
@@ -11,239 +11,17 @@
|
||||
"reportBugLabel": "回報錯誤",
|
||||
"githubLabel": "GitHub",
|
||||
"hotkeysLabel": "快捷鍵",
|
||||
"languagePickerLabel": "語言",
|
||||
"languagePickerLabel": "切換語言",
|
||||
"unifiedCanvas": "統一畫布",
|
||||
"cancel": "取消",
|
||||
"txt2img": "文字轉圖片",
|
||||
"controlNet": "ControlNet",
|
||||
"advanced": "進階",
|
||||
"folder": "資料夾",
|
||||
"installed": "已安裝",
|
||||
"accept": "接受",
|
||||
"goTo": "前往",
|
||||
"input": "輸入",
|
||||
"random": "隨機",
|
||||
"selected": "已選擇",
|
||||
"communityLabel": "社群",
|
||||
"loading": "載入中",
|
||||
"delete": "刪除",
|
||||
"copy": "複製",
|
||||
"error": "錯誤",
|
||||
"file": "檔案",
|
||||
"format": "格式",
|
||||
"imageFailedToLoad": "無法載入圖片"
|
||||
"txt2img": "文字轉圖片"
|
||||
},
|
||||
"accessibility": {
|
||||
"invokeProgressBar": "Invoke 進度條",
|
||||
"uploadImage": "上傳圖片",
|
||||
"reset": "重置",
|
||||
"reset": "重設",
|
||||
"nextImage": "下一張圖片",
|
||||
"previousImage": "上一張圖片",
|
||||
"menu": "選單",
|
||||
"loadMore": "載入更多",
|
||||
"about": "關於",
|
||||
"createIssue": "建立問題",
|
||||
"resetUI": "$t(accessibility.reset) 介面",
|
||||
"submitSupportTicket": "提交支援工單",
|
||||
"mode": "模式"
|
||||
},
|
||||
"boards": {
|
||||
"loading": "載入中…",
|
||||
"movingImagesToBoard_other": "正在移動 {{count}} 張圖片至板上:",
|
||||
"move": "移動",
|
||||
"uncategorized": "未分類",
|
||||
"cancel": "取消"
|
||||
},
|
||||
"metadata": {
|
||||
"workflow": "工作流程",
|
||||
"steps": "步數",
|
||||
"model": "模型",
|
||||
"seed": "種子",
|
||||
"vae": "VAE",
|
||||
"seamless": "無縫",
|
||||
"metadata": "元數據",
|
||||
"width": "寬度",
|
||||
"height": "高度"
|
||||
},
|
||||
"accordions": {
|
||||
"control": {
|
||||
"title": "控制"
|
||||
},
|
||||
"compositing": {
|
||||
"title": "合成"
|
||||
},
|
||||
"advanced": {
|
||||
"title": "進階",
|
||||
"options": "$t(accordions.advanced.title) 選項"
|
||||
}
|
||||
},
|
||||
"hotkeys": {
|
||||
"nodesHotkeys": "節點",
|
||||
"cancel": {
|
||||
"title": "取消"
|
||||
},
|
||||
"generalHotkeys": "一般",
|
||||
"keyboardShortcuts": "快捷鍵",
|
||||
"appHotkeys": "應用程式"
|
||||
},
|
||||
"modelManager": {
|
||||
"advanced": "進階",
|
||||
"allModels": "全部模型",
|
||||
"variant": "變體",
|
||||
"config": "配置",
|
||||
"model": "模型",
|
||||
"selected": "已選擇",
|
||||
"huggingFace": "HuggingFace",
|
||||
"install": "安裝",
|
||||
"metadata": "元數據",
|
||||
"delete": "刪除",
|
||||
"description": "描述",
|
||||
"cancel": "取消",
|
||||
"convert": "轉換",
|
||||
"manual": "手動",
|
||||
"none": "無",
|
||||
"name": "名稱",
|
||||
"load": "載入",
|
||||
"height": "高度",
|
||||
"width": "寬度",
|
||||
"search": "搜尋",
|
||||
"vae": "VAE",
|
||||
"settings": "設定"
|
||||
},
|
||||
"controlnet": {
|
||||
"mlsd": "M-LSD",
|
||||
"canny": "Canny",
|
||||
"duplicate": "重複",
|
||||
"none": "無",
|
||||
"pidi": "PIDI",
|
||||
"h": "H",
|
||||
"balanced": "平衡",
|
||||
"crop": "裁切",
|
||||
"processor": "處理器",
|
||||
"control": "控制",
|
||||
"f": "F",
|
||||
"lineart": "線條藝術",
|
||||
"w": "W",
|
||||
"hed": "HED",
|
||||
"delete": "刪除"
|
||||
},
|
||||
"queue": {
|
||||
"queue": "佇列",
|
||||
"canceled": "已取消",
|
||||
"failed": "已失敗",
|
||||
"completed": "已完成",
|
||||
"cancel": "取消",
|
||||
"session": "工作階段",
|
||||
"batch": "批量",
|
||||
"item": "項目",
|
||||
"completedIn": "完成於",
|
||||
"notReady": "無法排隊"
|
||||
},
|
||||
"parameters": {
|
||||
"cancel": {
|
||||
"cancel": "取消"
|
||||
},
|
||||
"height": "高度",
|
||||
"type": "類型",
|
||||
"symmetry": "對稱性",
|
||||
"images": "圖片",
|
||||
"width": "寬度",
|
||||
"coherenceMode": "模式",
|
||||
"seed": "種子",
|
||||
"general": "一般",
|
||||
"strength": "強度",
|
||||
"steps": "步數",
|
||||
"info": "資訊"
|
||||
},
|
||||
"settings": {
|
||||
"beta": "Beta",
|
||||
"developer": "開發者",
|
||||
"general": "一般",
|
||||
"models": "模型"
|
||||
},
|
||||
"popovers": {
|
||||
"paramModel": {
|
||||
"heading": "模型"
|
||||
},
|
||||
"compositingCoherenceMode": {
|
||||
"heading": "模式"
|
||||
},
|
||||
"paramSteps": {
|
||||
"heading": "步數"
|
||||
},
|
||||
"controlNetProcessor": {
|
||||
"heading": "處理器"
|
||||
},
|
||||
"paramVAE": {
|
||||
"heading": "VAE"
|
||||
},
|
||||
"paramHeight": {
|
||||
"heading": "高度"
|
||||
},
|
||||
"paramSeed": {
|
||||
"heading": "種子"
|
||||
},
|
||||
"paramWidth": {
|
||||
"heading": "寬度"
|
||||
},
|
||||
"refinerSteps": {
|
||||
"heading": "步數"
|
||||
}
|
||||
},
|
||||
"unifiedCanvas": {
|
||||
"undo": "復原",
|
||||
"mask": "遮罩",
|
||||
"eraser": "橡皮擦",
|
||||
"antialiasing": "抗鋸齒",
|
||||
"redo": "重做",
|
||||
"layer": "圖層",
|
||||
"accept": "接受",
|
||||
"brush": "刷子",
|
||||
"move": "移動",
|
||||
"brushSize": "大小"
|
||||
},
|
||||
"nodes": {
|
||||
"workflowName": "名稱",
|
||||
"notes": "註釋",
|
||||
"workflowVersion": "版本",
|
||||
"workflowNotes": "註釋",
|
||||
"executionStateError": "錯誤",
|
||||
"unableToUpdateNodes_other": "無法更新 {{count}} 個節點",
|
||||
"integer": "整數",
|
||||
"workflow": "工作流程",
|
||||
"enum": "枚舉",
|
||||
"edit": "編輯",
|
||||
"string": "字串",
|
||||
"workflowTags": "標籤",
|
||||
"node": "節點",
|
||||
"boolean": "布林值",
|
||||
"workflowAuthor": "作者",
|
||||
"version": "版本",
|
||||
"executionStateCompleted": "已完成",
|
||||
"edge": "邊緣",
|
||||
"versionUnknown": " 版本未知"
|
||||
},
|
||||
"sdxl": {
|
||||
"steps": "步數",
|
||||
"loading": "載入中…",
|
||||
"refiner": "精煉器"
|
||||
},
|
||||
"gallery": {
|
||||
"copy": "複製",
|
||||
"download": "下載",
|
||||
"loading": "載入中"
|
||||
},
|
||||
"ui": {
|
||||
"tabs": {
|
||||
"models": "模型",
|
||||
"queueTab": "$t(ui.tabs.queue) $t(common.tab)",
|
||||
"queue": "佇列"
|
||||
}
|
||||
},
|
||||
"models": {
|
||||
"loading": "載入中"
|
||||
},
|
||||
"workflows": {
|
||||
"name": "名稱"
|
||||
"menu": "選單"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,13 +19,6 @@ function ThemeLocaleProvider({ children }: ThemeLocaleProviderProps) {
|
||||
return extendTheme({
|
||||
..._theme,
|
||||
direction,
|
||||
shadows: {
|
||||
..._theme.shadows,
|
||||
selectedForCompare:
|
||||
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-400)',
|
||||
hoverSelectedForCompare:
|
||||
'0px 0px 0px 1px var(--invoke-colors-base-900), 0px 0px 0px 4px var(--invoke-colors-green-300)',
|
||||
},
|
||||
});
|
||||
}, [direction]);
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import type { MapStore } from 'nanostores';
|
||||
import { atom, map } from 'nanostores';
|
||||
import { useEffect, useMemo } from 'react';
|
||||
import { setEventListeners } from 'services/events/setEventListeners';
|
||||
import type { ClientToServerEvents, ServerToClientEvents } from 'services/events/types';
|
||||
import { setEventListeners } from 'services/events/util/setEventListeners';
|
||||
import type { ManagerOptions, Socket, SocketOptions } from 'socket.io-client';
|
||||
import { io } from 'socket.io-client';
|
||||
|
||||
|
||||
@@ -35,22 +35,26 @@ import { addImageUploadedFulfilledListener } from 'app/store/middleware/listener
|
||||
import { addModelSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelSelected';
|
||||
import { addModelsLoadedListener } from 'app/store/middleware/listenerMiddleware/listeners/modelsLoaded';
|
||||
import { addDynamicPromptsListener } from 'app/store/middleware/listenerMiddleware/listeners/promptChanged';
|
||||
import { addSetDefaultSettingsListener } from 'app/store/middleware/listenerMiddleware/listeners/setDefaultSettings';
|
||||
import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected';
|
||||
import { addSocketDisconnectedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketDisconnected';
|
||||
import { addGeneratorProgressEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress';
|
||||
import { addGraphExecutionStateCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketGraphExecutionStateComplete';
|
||||
import { addInvocationCompleteEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete';
|
||||
import { addInvocationErrorEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError';
|
||||
import { addInvocationStartedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted';
|
||||
import { addModelInstallEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelInstall';
|
||||
import { addModelLoadEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketModelLoad';
|
||||
import { addSocketQueueItemStatusChangedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged';
|
||||
import { addSocketSubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketSubscribed';
|
||||
import { addSocketUnsubscribedEventListener } from 'app/store/middleware/listenerMiddleware/listeners/socketio/socketUnsubscribed';
|
||||
import { addStagingAreaImageSavedListener } from 'app/store/middleware/listenerMiddleware/listeners/stagingAreaImageSaved';
|
||||
import { addUpdateAllNodesRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested';
|
||||
import { addUpscaleRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested';
|
||||
import { addWorkflowLoadRequestedListener } from 'app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addSetDefaultSettingsListener } from './listeners/setDefaultSettings';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
export type AppStartListening = TypedStartListening<RootState, AppDispatch>;
|
||||
@@ -98,11 +102,14 @@ addCommitStagingAreaImageListener(startAppListening);
|
||||
|
||||
// Socket.IO
|
||||
addGeneratorProgressEventListener(startAppListening);
|
||||
addGraphExecutionStateCompleteEventListener(startAppListening);
|
||||
addInvocationCompleteEventListener(startAppListening);
|
||||
addInvocationErrorEventListener(startAppListening);
|
||||
addInvocationStartedEventListener(startAppListening);
|
||||
addSocketConnectedEventListener(startAppListening);
|
||||
addSocketDisconnectedEventListener(startAppListening);
|
||||
addSocketSubscribedEventListener(startAppListening);
|
||||
addSocketUnsubscribedEventListener(startAppListening);
|
||||
addModelLoadEventListener(startAppListening);
|
||||
addModelInstallEventListener(startAppListening);
|
||||
addSocketQueueItemStatusChangedEventListener(startAppListening);
|
||||
|
||||
@@ -5,8 +5,8 @@ import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import {
|
||||
socketBulkDownloadComplete,
|
||||
socketBulkDownloadError,
|
||||
socketBulkDownloadCompleted,
|
||||
socketBulkDownloadFailed,
|
||||
socketBulkDownloadStarted,
|
||||
} from 'services/events/actions';
|
||||
|
||||
@@ -54,7 +54,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketBulkDownloadComplete,
|
||||
actionCreator: socketBulkDownloadCompleted,
|
||||
effect: async (action) => {
|
||||
log.debug(action.payload.data, 'Bulk download preparation completed');
|
||||
|
||||
@@ -80,7 +80,7 @@ export const addBulkDownloadListeners = (startAppListening: AppStartListening) =
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketBulkDownloadError,
|
||||
actionCreator: socketBulkDownloadFailed,
|
||||
effect: async (action) => {
|
||||
log.debug(action.payload.data, 'Bulk download preparation failed');
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
isControlAdapterLayer,
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { isEqual } from 'lodash-es';
|
||||
@@ -132,13 +133,13 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni
|
||||
const [invocationCompleteAction] = await take(
|
||||
(action): action is ReturnType<typeof socketInvocationComplete> =>
|
||||
socketInvocationComplete.match(action) &&
|
||||
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.invocation_source_id === processorNode.id
|
||||
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.source_node_id === processorNode.id
|
||||
);
|
||||
|
||||
// We still have to check the output type
|
||||
assert(
|
||||
invocationCompleteAction.payload.data.result.type === 'image_output',
|
||||
isImageOutput(invocationCompleteAction.payload.data.result),
|
||||
`Processor did not return an image output, got: ${invocationCompleteAction.payload.data.result}`
|
||||
);
|
||||
const { image_name } = invocationCompleteAction.payload.data.result.image;
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
selectControlAdapterById,
|
||||
} from 'features/controlAdapters/store/controlAdaptersSlice';
|
||||
import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
@@ -68,12 +69,12 @@ export const addControlNetImageProcessedListener = (startAppListening: AppStartL
|
||||
const [invocationCompleteAction] = await take(
|
||||
(action): action is ReturnType<typeof socketInvocationComplete> =>
|
||||
socketInvocationComplete.match(action) &&
|
||||
action.payload.data.batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.invocation_source_id === nodeId
|
||||
action.payload.data.queue_batch_id === enqueueResult.batch.batch_id &&
|
||||
action.payload.data.source_node_id === nodeId
|
||||
);
|
||||
|
||||
// We still have to check the output type
|
||||
if (invocationCompleteAction.payload.data.result.type === 'image_output') {
|
||||
if (isImageOutput(invocationCompleteAction.payload.data.result)) {
|
||||
const { image_name } = invocationCompleteAction.payload.data.result.image;
|
||||
|
||||
// Wait for the ImageDTO to be received
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesSelectors } from 'services/api/util';
|
||||
@@ -11,7 +11,6 @@ export const galleryImageClicked = createAction<{
|
||||
shiftKey: boolean;
|
||||
ctrlKey: boolean;
|
||||
metaKey: boolean;
|
||||
altKey: boolean;
|
||||
}>('gallery/imageClicked');
|
||||
|
||||
/**
|
||||
@@ -29,7 +28,7 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
|
||||
startAppListening({
|
||||
actionCreator: galleryImageClicked,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { imageDTO, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
|
||||
const { imageDTO, shiftKey, ctrlKey, metaKey } = action.payload;
|
||||
const state = getState();
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
const { data: listImagesData } = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||
@@ -42,13 +41,7 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
|
||||
const imageDTOs = imagesSelectors.selectAll(listImagesData);
|
||||
const selection = state.gallery.selection;
|
||||
|
||||
if (altKey) {
|
||||
if (state.gallery.imageToCompare?.image_name === imageDTO.image_name) {
|
||||
dispatch(imageToCompareChanged(null));
|
||||
} else {
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
}
|
||||
} else if (shiftKey) {
|
||||
if (shiftKey) {
|
||||
const rangeEndImageName = imageDTO.image_name;
|
||||
const lastSelectedImage = selection[selection.length - 1]?.image_name;
|
||||
const lastClickedIndex = imageDTOs.findIndex((n) => n.image_name === lastSelectedImage);
|
||||
|
||||
@@ -14,8 +14,7 @@ import {
|
||||
rgLayerIPAdapterImageChanged,
|
||||
} from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import { isValidDrop } from 'features/dnd/util/isValidDrop';
|
||||
import { imageSelected, imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
@@ -31,9 +30,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const log = logger('dnd');
|
||||
const { activeData, overData } = action.payload;
|
||||
if (!isValidDrop(overData, activeData)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (activeData.payloadType === 'IMAGE_DTO') {
|
||||
log.debug({ activeData, overData }, 'Image dropped');
|
||||
@@ -54,7 +50,6 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
dispatch(imageSelected(activeData.payload.imageDTO));
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -187,18 +182,24 @@ export const addImageDroppedListener = (startAppListening: AppStartListening) =>
|
||||
}
|
||||
|
||||
/**
|
||||
* Image selected for compare
|
||||
* TODO
|
||||
* Image selection dropped on node image collection field
|
||||
*/
|
||||
if (
|
||||
overData.actionType === 'SELECT_FOR_COMPARE' &&
|
||||
activeData.payloadType === 'IMAGE_DTO' &&
|
||||
activeData.payload.imageDTO
|
||||
) {
|
||||
const { imageDTO } = activeData.payload;
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
return;
|
||||
}
|
||||
// if (
|
||||
// overData.actionType === 'SET_MULTI_NODES_IMAGE' &&
|
||||
// activeData.payloadType === 'IMAGE_DTO' &&
|
||||
// activeData.payload.imageDTO
|
||||
// ) {
|
||||
// const { fieldName, nodeId } = overData.context;
|
||||
// dispatch(
|
||||
// fieldValueChanged({
|
||||
// nodeId,
|
||||
// fieldName,
|
||||
// value: [activeData.payload.imageDTO],
|
||||
// })
|
||||
// );
|
||||
// return;
|
||||
// }
|
||||
|
||||
/**
|
||||
* Image dropped on user board
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketGeneratorProgress } from 'services/events/actions';
|
||||
@@ -12,9 +11,9 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
|
||||
startAppListening({
|
||||
actionCreator: socketGeneratorProgress,
|
||||
effect: (action) => {
|
||||
log.trace(parseify(action.payload), `Generator progress`);
|
||||
const { invocation_source_id, step, total_steps, progress_image } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
log.trace(action.payload, `Generator progress`);
|
||||
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
nes.progress = (step + 1) / total_steps;
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketGraphExecutionStateComplete } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
export const addGraphExecutionStateCompleteEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketGraphExecutionStateComplete,
|
||||
effect: (action) => {
|
||||
log.debug(action.payload, 'Session complete');
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
@@ -28,12 +29,12 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
actionCreator: socketInvocationComplete,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
const { data } = action.payload;
|
||||
log.debug({ data: parseify(data) }, `Invocation complete (${data.invocation.type})`);
|
||||
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
|
||||
|
||||
const { result, invocation_source_id } = data;
|
||||
const { result, node, queue_batch_id, source_node_id } = data;
|
||||
// This complete event has an associated image output
|
||||
if (data.result.type === 'image_output' && !nodeTypeDenylist.includes(data.invocation.type)) {
|
||||
const { image_name } = data.result.image;
|
||||
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
|
||||
const { image_name } = result.image;
|
||||
const { canvas, gallery } = getState();
|
||||
|
||||
// This populates the `getImageDTO` cache
|
||||
@@ -47,7 +48,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
imageDTORequest.unsubscribe();
|
||||
|
||||
// Add canvas images to the staging area
|
||||
if (canvas.batchIds.includes(data.batch_id) && data.invocation_source_id === CANVAS_OUTPUT) {
|
||||
if (canvas.batchIds.includes(queue_batch_id) && data.source_node_id === CANVAS_OUTPUT) {
|
||||
dispatch(addImageToStagingArea(imageDTO));
|
||||
}
|
||||
|
||||
@@ -113,7 +114,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
}
|
||||
}
|
||||
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.COMPLETED;
|
||||
if (nes.progress !== null) {
|
||||
|
||||
@@ -1,24 +1,52 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import ToastWithSessionRefDescription from 'features/toast/ToastWithSessionRefDescription';
|
||||
import { t } from 'i18next';
|
||||
import { startCase } from 'lodash-es';
|
||||
import { socketInvocationError } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
const getTitle = (errorType: string) => {
|
||||
if (errorType === 'OutOfMemoryError') {
|
||||
return t('toast.outOfMemoryError');
|
||||
}
|
||||
return t('toast.serverError');
|
||||
};
|
||||
|
||||
const getDescription = (errorType: string, sessionId: string, isLocal?: boolean) => {
|
||||
if (!isLocal) {
|
||||
if (errorType === 'OutOfMemoryError') {
|
||||
return ToastWithSessionRefDescription({
|
||||
message: t('toast.outOfMemoryDescription'),
|
||||
sessionId,
|
||||
});
|
||||
}
|
||||
return ToastWithSessionRefDescription({
|
||||
message: errorType,
|
||||
sessionId,
|
||||
});
|
||||
}
|
||||
return errorType;
|
||||
};
|
||||
|
||||
export const addInvocationErrorEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketInvocationError,
|
||||
effect: (action) => {
|
||||
const { invocation_source_id, invocation, error_type, error_message, error_traceback } = action.payload.data;
|
||||
log.error(parseify(action.payload), `Invocation error (${invocation.type})`);
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
effect: (action, { getState }) => {
|
||||
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
|
||||
const { source_node_id, error_type, error_message, error_traceback, graph_execution_state_id } =
|
||||
action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.FAILED;
|
||||
nes.progress = null;
|
||||
nes.progressImage = null;
|
||||
|
||||
nes.error = {
|
||||
error_type,
|
||||
error_message,
|
||||
@@ -26,6 +54,19 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
|
||||
};
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
|
||||
const errorType = startCase(error_type);
|
||||
const sessionId = graph_execution_state_id;
|
||||
const { isLocal } = getState().config;
|
||||
|
||||
toast({
|
||||
id: `INVOCATION_ERROR_${errorType}`,
|
||||
title: getTitle(errorType),
|
||||
status: 'error',
|
||||
duration: null,
|
||||
description: getDescription(errorType, sessionId, isLocal),
|
||||
updateDescription: isLocal ? true : false,
|
||||
});
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketInvocationStarted } from 'services/events/actions';
|
||||
@@ -12,9 +11,9 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
|
||||
startAppListening({
|
||||
actionCreator: socketInvocationStarted,
|
||||
effect: (action) => {
|
||||
log.debug(parseify(action.payload), `Invocation started (${action.payload.data.invocation.type})`);
|
||||
const { invocation_source_id } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[invocation_source_id]);
|
||||
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`);
|
||||
const { source_node_id } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
|
||||
@@ -3,14 +3,14 @@ import { api, LIST_TAG } from 'services/api';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import {
|
||||
socketModelInstallCancelled,
|
||||
socketModelInstallComplete,
|
||||
socketModelInstallDownloadProgress,
|
||||
socketModelInstallCompleted,
|
||||
socketModelInstallDownloading,
|
||||
socketModelInstallError,
|
||||
} from 'services/events/actions';
|
||||
|
||||
export const addModelInstallEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallDownloadProgress,
|
||||
actionCreator: socketModelInstallDownloading,
|
||||
effect: async (action, { dispatch }) => {
|
||||
const { bytes, total_bytes, id } = action.payload.data;
|
||||
|
||||
@@ -29,7 +29,7 @@ export const addModelInstallEventListener = (startAppListening: AppStartListenin
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelInstallComplete,
|
||||
actionCreator: socketModelInstallCompleted,
|
||||
effect: (action, { dispatch }) => {
|
||||
const { id } = action.payload.data;
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketModelLoadComplete, socketModelLoadStarted } from 'services/events/actions';
|
||||
import { socketModelLoadCompleted, socketModelLoadStarted } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
@@ -8,11 +8,10 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
|
||||
startAppListening({
|
||||
actionCreator: socketModelLoadStarted,
|
||||
effect: (action) => {
|
||||
const { config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = config;
|
||||
const { model_config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = model_config;
|
||||
|
||||
const extras: string[] = [base, type];
|
||||
|
||||
if (submodel_type) {
|
||||
extras.push(submodel_type);
|
||||
}
|
||||
@@ -24,10 +23,10 @@ export const addModelLoadEventListener = (startAppListening: AppStartListening)
|
||||
});
|
||||
|
||||
startAppListening({
|
||||
actionCreator: socketModelLoadComplete,
|
||||
actionCreator: socketModelLoadCompleted,
|
||||
effect: (action) => {
|
||||
const { config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = config;
|
||||
const { model_config, submodel_type } = action.payload.data;
|
||||
const { name, base, type } = model_config;
|
||||
|
||||
const extras: string[] = [base, type];
|
||||
if (submodel_type) {
|
||||
|
||||
@@ -3,8 +3,6 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import ErrorToastDescription, { getTitleFromErrorType } from 'features/toast/ErrorToastDescription';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
|
||||
import { socketQueueItemStatusChanged } from 'services/events/actions';
|
||||
@@ -14,38 +12,18 @@ const log = logger('socketio');
|
||||
export const addSocketQueueItemStatusChangedEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketQueueItemStatusChanged,
|
||||
effect: async (action, { dispatch, getState }) => {
|
||||
effect: async (action, { dispatch }) => {
|
||||
// we've got new status for the queue item, batch and queue
|
||||
const {
|
||||
item_id,
|
||||
session_id,
|
||||
status,
|
||||
started_at,
|
||||
updated_at,
|
||||
completed_at,
|
||||
batch_status,
|
||||
queue_status,
|
||||
error_type,
|
||||
error_message,
|
||||
error_traceback,
|
||||
} = action.payload.data;
|
||||
const { queue_item, batch_status, queue_status } = action.payload.data;
|
||||
|
||||
log.debug(action.payload, `Queue item ${item_id} status updated: ${status}`);
|
||||
log.debug(action.payload, `Queue item ${queue_item.item_id} status updated: ${queue_item.status}`);
|
||||
|
||||
// Update this specific queue item in the list of queue items (this is the queue item DTO, without the session)
|
||||
dispatch(
|
||||
queueApi.util.updateQueryData('listQueueItems', undefined, (draft) => {
|
||||
queueItemsAdapter.updateOne(draft, {
|
||||
id: String(item_id),
|
||||
changes: {
|
||||
status,
|
||||
started_at,
|
||||
updated_at: updated_at ?? undefined,
|
||||
completed_at: completed_at ?? undefined,
|
||||
error_type,
|
||||
error_message,
|
||||
error_traceback,
|
||||
},
|
||||
id: String(queue_item.item_id),
|
||||
changes: queue_item,
|
||||
});
|
||||
})
|
||||
);
|
||||
@@ -72,11 +50,11 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
|
||||
'CurrentSessionQueueItem',
|
||||
'NextSessionQueueItem',
|
||||
'InvocationCacheStatus',
|
||||
{ type: 'SessionQueueItem', id: item_id },
|
||||
{ type: 'SessionQueueItem', id: queue_item.item_id },
|
||||
])
|
||||
);
|
||||
|
||||
if (status === 'in_progress') {
|
||||
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
|
||||
forEach($nodeExecutionStates.get(), (nes) => {
|
||||
if (!nes) {
|
||||
return;
|
||||
@@ -89,25 +67,6 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
|
||||
clone.outputs = [];
|
||||
$nodeExecutionStates.setKey(clone.nodeId, clone);
|
||||
});
|
||||
} else if (status === 'failed' && error_type) {
|
||||
const isLocal = getState().config.isLocal ?? true;
|
||||
const sessionId = session_id;
|
||||
|
||||
toast({
|
||||
id: `INVOCATION_ERROR_${error_type}`,
|
||||
title: getTitleFromErrorType(error_type),
|
||||
status: 'error',
|
||||
duration: null,
|
||||
updateDescription: isLocal,
|
||||
description: (
|
||||
<ErrorToastDescription
|
||||
errorType={error_type}
|
||||
errorMessage={error_message}
|
||||
sessionId={sessionId}
|
||||
isLocal={isLocal}
|
||||
/>
|
||||
),
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
@@ -0,0 +1,14 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketSubscribedSession } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
|
||||
export const addSocketSubscribedEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketSubscribedSession,
|
||||
effect: (action) => {
|
||||
log.debug(action.payload, 'Subscribed');
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,13 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { socketUnsubscribedSession } from 'services/events/actions';
|
||||
const log = logger('socketio');
|
||||
|
||||
export const addSocketUnsubscribedEventListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: socketUnsubscribedSession,
|
||||
effect: (action) => {
|
||||
log.debug(action.payload, 'Unsubscribed');
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -3,7 +3,7 @@ import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { $needsFit } from 'features/nodes/store/reactFlowInstance';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
|
||||
import { graphToWorkflow } from 'features/nodes/util/workflow/graphToWorkflow';
|
||||
@@ -65,7 +65,9 @@ export const addWorkflowLoadRequestedListener = (startAppListening: AppStartList
|
||||
});
|
||||
}
|
||||
|
||||
$needsFit.set(true);
|
||||
requestAnimationFrame(() => {
|
||||
$flow.get()?.fitView();
|
||||
});
|
||||
} catch (e) {
|
||||
if (e instanceof WorkflowVersionError) {
|
||||
// The workflow version was not recognized in the valid list of versions
|
||||
|
||||
@@ -35,7 +35,6 @@ type IAIDndImageProps = FlexProps & {
|
||||
draggableData?: TypesafeDraggableData;
|
||||
dropLabel?: ReactNode;
|
||||
isSelected?: boolean;
|
||||
isSelectedForCompare?: boolean;
|
||||
thumbnail?: boolean;
|
||||
noContentFallback?: ReactElement;
|
||||
useThumbailFallback?: boolean;
|
||||
@@ -62,7 +61,6 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
draggableData,
|
||||
dropLabel,
|
||||
isSelected = false,
|
||||
isSelectedForCompare = false,
|
||||
thumbnail = false,
|
||||
noContentFallback = defaultNoContentFallback,
|
||||
uploadElement = defaultUploadElement,
|
||||
@@ -167,11 +165,7 @@ const IAIDndImage = (props: IAIDndImageProps) => {
|
||||
data-testid={dataTestId}
|
||||
/>
|
||||
{withMetadataOverlay && <ImageMetadataOverlay imageDTO={imageDTO} />}
|
||||
<SelectionOverlay
|
||||
isSelected={isSelected}
|
||||
isSelectedForCompare={isSelectedForCompare}
|
||||
isHovered={withHoverOverlay ? isHovered : false}
|
||||
/>
|
||||
<SelectionOverlay isSelected={isSelected} isHovered={withHoverOverlay ? isHovered : false} />
|
||||
</Flex>
|
||||
)}
|
||||
{!imageDTO && !isUploadDisabled && (
|
||||
|
||||
@@ -36,7 +36,7 @@ const IAIDroppable = (props: IAIDroppableProps) => {
|
||||
pointerEvents={active ? 'auto' : 'none'}
|
||||
>
|
||||
<AnimatePresence>
|
||||
{isValidDrop(data, active?.data.current) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
|
||||
{isValidDrop(data, active) && <IAIDropOverlay isOver={isOver} label={dropLabel} />}
|
||||
</AnimatePresence>
|
||||
</Box>
|
||||
);
|
||||
|
||||
@@ -3,17 +3,10 @@ import { memo, useMemo } from 'react';
|
||||
|
||||
type Props = {
|
||||
isSelected: boolean;
|
||||
isSelectedForCompare: boolean;
|
||||
isHovered: boolean;
|
||||
};
|
||||
const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props) => {
|
||||
const SelectionOverlay = ({ isSelected, isHovered }: Props) => {
|
||||
const shadow = useMemo(() => {
|
||||
if (isSelectedForCompare && isHovered) {
|
||||
return 'hoverSelectedForCompare';
|
||||
}
|
||||
if (isSelectedForCompare && !isHovered) {
|
||||
return 'selectedForCompare';
|
||||
}
|
||||
if (isSelected && isHovered) {
|
||||
return 'hoverSelected';
|
||||
}
|
||||
@@ -24,7 +17,7 @@ const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props
|
||||
return 'hoverUnselected';
|
||||
}
|
||||
return undefined;
|
||||
}, [isHovered, isSelected, isSelectedForCompare]);
|
||||
}, [isHovered, isSelected]);
|
||||
return (
|
||||
<Box
|
||||
className="selection-box"
|
||||
@@ -34,7 +27,7 @@ const SelectionOverlay = ({ isSelected, isSelectedForCompare, isHovered }: Props
|
||||
bottom={0}
|
||||
insetInlineStart={0}
|
||||
borderRadius="base"
|
||||
opacity={isSelected || isSelectedForCompare ? 1 : 0.7}
|
||||
opacity={isSelected ? 1 : 0.7}
|
||||
transitionProperty="common"
|
||||
transitionDuration="0.1s"
|
||||
pointerEvents="none"
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
|
||||
export const useBoolean = (initialValue: boolean) => {
|
||||
const [isTrue, set] = useState(initialValue);
|
||||
const setTrue = useCallback(() => set(true), []);
|
||||
const setFalse = useCallback(() => set(false), []);
|
||||
const toggle = useCallback(() => set((v) => !v), []);
|
||||
|
||||
const api = useMemo(
|
||||
() => ({
|
||||
isTrue,
|
||||
set,
|
||||
setTrue,
|
||||
setFalse,
|
||||
toggle,
|
||||
}),
|
||||
[isTrue, set, setTrue, setFalse, toggle]
|
||||
);
|
||||
|
||||
return api;
|
||||
};
|
||||
@@ -1,7 +1,3 @@
|
||||
export const stopPropagation = (e: React.MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
};
|
||||
|
||||
export const preventDefault = (e: React.MouseEvent) => {
|
||||
e.preventDefault();
|
||||
};
|
||||
|
||||
@@ -613,7 +613,7 @@ export const canvasSlice = createSlice({
|
||||
state.batchIds = state.batchIds.filter((id) => id !== batch_status.batch_id);
|
||||
}
|
||||
|
||||
const queueItemStatus = action.payload.data.status;
|
||||
const queueItemStatus = action.payload.data.queue_item.status;
|
||||
if (queueItemStatus === 'canceled' || queueItemStatus === 'failed') {
|
||||
resetStagingAreaIfEmpty(state);
|
||||
}
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { zModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { merge, omit } from 'lodash-es';
|
||||
import type {
|
||||
AnyInvocation,
|
||||
BaseModelType,
|
||||
ControlNetModelConfig,
|
||||
ImageDTO,
|
||||
T2IAdapterModelConfig,
|
||||
} from 'services/api/types';
|
||||
import type { BaseModelType, ControlNetModelConfig, Graph, ImageDTO, T2IAdapterModelConfig } from 'services/api/types';
|
||||
import { z } from 'zod';
|
||||
|
||||
const zId = z.string().min(1);
|
||||
@@ -153,7 +147,7 @@ const zBeginEndStepPct = z
|
||||
|
||||
const zControlAdapterBase = z.object({
|
||||
id: zId,
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
weight: z.number().gte(0).lte(1),
|
||||
image: zImageWithDims.nullable(),
|
||||
processedImage: zImageWithDims.nullable(),
|
||||
processorConfig: zProcessorConfig.nullable(),
|
||||
@@ -189,7 +183,7 @@ export const isIPMethodV2 = (v: unknown): v is IPMethodV2 => zIPMethodV2.safePar
|
||||
export const zIPAdapterConfigV2 = z.object({
|
||||
id: zId,
|
||||
type: z.literal('ip_adapter'),
|
||||
weight: z.number().gte(-1).lte(2),
|
||||
weight: z.number().gte(0).lte(1),
|
||||
method: zIPMethodV2,
|
||||
image: zImageWithDims.nullable(),
|
||||
model: zModelIdentifierField.nullable(),
|
||||
@@ -222,7 +216,10 @@ type ProcessorData<T extends ProcessorTypeV2> = {
|
||||
labelTKey: string;
|
||||
descriptionTKey: string;
|
||||
buildDefaults(baseModel?: BaseModelType): Extract<ProcessorConfig, { type: T }>;
|
||||
buildNode(image: ImageWithDims, config: Extract<ProcessorConfig, { type: T }>): Extract<AnyInvocation, { type: T }>;
|
||||
buildNode(
|
||||
image: ImageWithDims,
|
||||
config: Extract<ProcessorConfig, { type: T }>
|
||||
): Extract<Graph['nodes'][string], { type: T }>;
|
||||
};
|
||||
|
||||
const minDim = (image: ImageWithDims): number => Math.min(image.width, image.height);
|
||||
|
||||
@@ -54,7 +54,7 @@ const BBOX_SELECTED_STROKE = 'rgba(78, 190, 255, 1)';
|
||||
const BRUSH_BORDER_INNER_COLOR = 'rgba(0,0,0,1)';
|
||||
const BRUSH_BORDER_OUTER_COLOR = 'rgba(255,255,255,0.8)';
|
||||
// This is invokeai/frontend/web/public/assets/images/transparent_bg.png as a dataURL
|
||||
export const STAGE_BG_DATAURL =
|
||||
const STAGE_BG_DATAURL =
|
||||
'';
|
||||
|
||||
const mapId = (object: { id: string }) => object.id;
|
||||
|
||||
@@ -18,7 +18,7 @@ type BaseDropData = {
|
||||
id: string;
|
||||
};
|
||||
|
||||
export type CurrentImageDropData = BaseDropData & {
|
||||
type CurrentImageDropData = BaseDropData & {
|
||||
actionType: 'SET_CURRENT_IMAGE';
|
||||
};
|
||||
|
||||
@@ -79,14 +79,6 @@ export type RemoveFromBoardDropData = BaseDropData & {
|
||||
actionType: 'REMOVE_FROM_BOARD';
|
||||
};
|
||||
|
||||
export type SelectForCompareDropData = BaseDropData & {
|
||||
actionType: 'SELECT_FOR_COMPARE';
|
||||
context: {
|
||||
firstImageName?: string | null;
|
||||
secondImageName?: string | null;
|
||||
};
|
||||
};
|
||||
|
||||
export type TypesafeDroppableData =
|
||||
| CurrentImageDropData
|
||||
| ControlAdapterDropData
|
||||
@@ -97,8 +89,7 @@ export type TypesafeDroppableData =
|
||||
| CALayerImageDropData
|
||||
| IPALayerImageDropData
|
||||
| RGLayerIPAdapterImageDropData
|
||||
| IILayerImageDropData
|
||||
| SelectForCompareDropData;
|
||||
| IILayerImageDropData;
|
||||
|
||||
type BaseDragData = {
|
||||
id: string;
|
||||
@@ -143,7 +134,7 @@ export type UseDraggableTypesafeReturnValue = Omit<ReturnType<typeof useOriginal
|
||||
over: TypesafeOver | null;
|
||||
};
|
||||
|
||||
interface TypesafeActive extends Omit<Active, 'data'> {
|
||||
export interface TypesafeActive extends Omit<Active, 'data'> {
|
||||
data: React.MutableRefObject<TypesafeDraggableData | undefined>;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import type { TypesafeActive, TypesafeDroppableData } from 'features/dnd/types';
|
||||
|
||||
export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?: TypesafeDraggableData | null) => {
|
||||
if (!overData || !activeData) {
|
||||
export const isValidDrop = (overData: TypesafeDroppableData | undefined, active: TypesafeActive | null) => {
|
||||
if (!overData || !active?.data.current) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const { actionType } = overData;
|
||||
const { payloadType } = activeData;
|
||||
const { payloadType } = active.data.current;
|
||||
|
||||
if (overData.id === activeData.id) {
|
||||
if (overData.id === active.data.current.id) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -29,8 +29,6 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SET_NODES_IMAGE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'SELECT_FOR_COMPARE':
|
||||
return payloadType === 'IMAGE_DTO';
|
||||
case 'ADD_TO_BOARD': {
|
||||
// If the board is the same, don't allow the drop
|
||||
|
||||
@@ -42,7 +40,7 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
|
||||
|
||||
// Check if the image's board is the board we are dragging onto
|
||||
if (payloadType === 'IMAGE_DTO') {
|
||||
const { imageDTO } = activeData.payload;
|
||||
const { imageDTO } = active.data.current.payload;
|
||||
const currentBoard = imageDTO.board_id ?? 'none';
|
||||
const destinationBoard = overData.context.boardId;
|
||||
|
||||
@@ -51,7 +49,7 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
|
||||
|
||||
if (payloadType === 'GALLERY_SELECTION') {
|
||||
// Assume all images are on the same board - this is true for the moment
|
||||
const currentBoard = activeData.payload.boardId;
|
||||
const currentBoard = active.data.current.payload.boardId;
|
||||
const destinationBoard = overData.context.boardId;
|
||||
return currentBoard !== destinationBoard;
|
||||
}
|
||||
@@ -69,14 +67,14 @@ export const isValidDrop = (overData?: TypesafeDroppableData | null, activeData?
|
||||
|
||||
// Check if the image's board is the board we are dragging onto
|
||||
if (payloadType === 'IMAGE_DTO') {
|
||||
const { imageDTO } = activeData.payload;
|
||||
const { imageDTO } = active.data.current.payload;
|
||||
const currentBoard = imageDTO.board_id ?? 'none';
|
||||
|
||||
return currentBoard !== 'none';
|
||||
}
|
||||
|
||||
if (payloadType === 'GALLERY_SELECTION') {
|
||||
const currentBoard = activeData.payload.boardId;
|
||||
const currentBoard = active.data.current.payload.boardId;
|
||||
return currentBoard !== 'none';
|
||||
}
|
||||
|
||||
|
||||
@@ -162,7 +162,7 @@ const GalleryBoard = ({ board, isSelected, setBoardToDelete }: GalleryBoardProps
|
||||
</Flex>
|
||||
)}
|
||||
{isSelectedForAutoAdd && <AutoAddIcon />}
|
||||
<SelectionOverlay isSelected={isSelected} isSelectedForCompare={false} isHovered={isHovered} />
|
||||
<SelectionOverlay isSelected={isSelected} isHovered={isHovered} />
|
||||
<Flex
|
||||
position="absolute"
|
||||
bottom={0}
|
||||
|
||||
@@ -117,7 +117,7 @@ const NoBoardBoard = memo(({ isSelected }: Props) => {
|
||||
>
|
||||
{boardName}
|
||||
</Flex>
|
||||
<SelectionOverlay isSelected={isSelected} isSelectedForCompare={false} isHovered={isHovered} />
|
||||
<SelectionOverlay isSelected={isSelected} isHovered={isHovered} />
|
||||
<IAIDroppable data={droppableData} dropLabel={<Text fontSize="md">{t('unifiedCanvas.move')}</Text>} />
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
|
||||
@@ -10,7 +10,6 @@ import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice';
|
||||
import { useImageActions } from 'features/gallery/hooks/useImageActions';
|
||||
import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/actions';
|
||||
import { imageToCompareChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectOptimalDimension } from 'features/parameters/store/generationSlice';
|
||||
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
|
||||
@@ -28,7 +27,6 @@ import {
|
||||
PiDownloadSimpleBold,
|
||||
PiFlowArrowBold,
|
||||
PiFoldersBold,
|
||||
PiImagesBold,
|
||||
PiPlantBold,
|
||||
PiQuotesBold,
|
||||
PiShareFatBold,
|
||||
@@ -46,7 +44,6 @@ type SingleSelectionMenuItemsProps = {
|
||||
const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
const { imageDTO } = props;
|
||||
const optimalDimension = useAppSelector(selectOptimalDimension);
|
||||
const maySelectForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name !== imageDTO.image_name);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const isCanvasEnabled = useFeatureStatus('canvas');
|
||||
@@ -120,10 +117,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
downloadImage(imageDTO.image_url, imageDTO.image_name);
|
||||
}, [downloadImage, imageDTO.image_name, imageDTO.image_url]);
|
||||
|
||||
const handleSelectImageForCompare = useCallback(() => {
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
}, [dispatch, imageDTO]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<MenuItem as="a" href={imageDTO.image_url} target="_blank" icon={<PiShareFatBold />}>
|
||||
@@ -137,9 +130,6 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => {
|
||||
<MenuItem icon={<PiDownloadSimpleBold />} onClickCapture={handleDownloadImage}>
|
||||
{t('parameters.downloadImage')}
|
||||
</MenuItem>
|
||||
<MenuItem icon={<PiImagesBold />} isDisabled={!maySelectForCompare} onClick={handleSelectImageForCompare}>
|
||||
{t('gallery.selectForCompare')}
|
||||
</MenuItem>
|
||||
<MenuDivider />
|
||||
<MenuItem
|
||||
icon={getAndLoadEmbeddedWorkflowResult.isLoading ? <SpinnerIcon /> : <PiFlowArrowBold />}
|
||||
|
||||
@@ -11,7 +11,7 @@ import type { GallerySelectionDraggableData, ImageDraggableData, TypesafeDraggab
|
||||
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
|
||||
import { useMultiselect } from 'features/gallery/hooks/useMultiselect';
|
||||
import { useScrollIntoView } from 'features/gallery/hooks/useScrollIntoView';
|
||||
import { imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import type { MouseEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -46,7 +46,6 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
const { t } = useTranslation();
|
||||
const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId);
|
||||
const alwaysShowImageSizeBadge = useAppSelector((s) => s.gallery.alwaysShowImageSizeBadge);
|
||||
const isSelectedForCompare = useAppSelector((s) => s.gallery.imageToCompare?.image_name === imageName);
|
||||
const { handleClick, isSelected, areMultiplesSelected } = useMultiselect(imageDTO);
|
||||
|
||||
const customStarUi = useStore($customStarUI);
|
||||
@@ -106,7 +105,6 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
|
||||
const onDoubleClick = useCallback(() => {
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
dispatch(imageToCompareChanged(null));
|
||||
}, [dispatch]);
|
||||
|
||||
const handleMouseOut = useCallback(() => {
|
||||
@@ -154,7 +152,6 @@ const GalleryImage = (props: HoverableImageProps) => {
|
||||
imageDTO={imageDTO}
|
||||
draggableData={draggableData}
|
||||
isSelected={isSelected}
|
||||
isSelectedForCompare={isSelectedForCompare}
|
||||
minSize={0}
|
||||
imageSx={imageSx}
|
||||
isDropDisabled={true}
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
import {
|
||||
Button,
|
||||
ButtonGroup,
|
||||
Flex,
|
||||
Icon,
|
||||
IconButton,
|
||||
Kbd,
|
||||
ListItem,
|
||||
Tooltip,
|
||||
UnorderedList,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
comparedImagesSwapped,
|
||||
comparisonFitChanged,
|
||||
comparisonModeChanged,
|
||||
comparisonModeCycled,
|
||||
imageToCompareChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { Trans, useTranslation } from 'react-i18next';
|
||||
import { PiArrowsOutBold, PiQuestion, PiSwapBold, PiXBold } from 'react-icons/pi';
|
||||
|
||||
export const CompareToolbar = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const comparisonMode = useAppSelector((s) => s.gallery.comparisonMode);
|
||||
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
|
||||
const setComparisonModeSlider = useCallback(() => {
|
||||
dispatch(comparisonModeChanged('slider'));
|
||||
}, [dispatch]);
|
||||
const setComparisonModeSideBySide = useCallback(() => {
|
||||
dispatch(comparisonModeChanged('side-by-side'));
|
||||
}, [dispatch]);
|
||||
const setComparisonModeHover = useCallback(() => {
|
||||
dispatch(comparisonModeChanged('hover'));
|
||||
}, [dispatch]);
|
||||
const swapImages = useCallback(() => {
|
||||
dispatch(comparedImagesSwapped());
|
||||
}, [dispatch]);
|
||||
useHotkeys('c', swapImages, [swapImages]);
|
||||
const toggleComparisonFit = useCallback(() => {
|
||||
dispatch(comparisonFitChanged(comparisonFit === 'contain' ? 'fill' : 'contain'));
|
||||
}, [dispatch, comparisonFit]);
|
||||
const exitCompare = useCallback(() => {
|
||||
dispatch(imageToCompareChanged(null));
|
||||
}, [dispatch]);
|
||||
useHotkeys('esc', exitCompare, [exitCompare]);
|
||||
const nextMode = useCallback(() => {
|
||||
dispatch(comparisonModeCycled());
|
||||
}, [dispatch]);
|
||||
useHotkeys('m', nextMode, [nextMode]);
|
||||
|
||||
return (
|
||||
<Flex w="full" gap={2}>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineEnd="auto">
|
||||
<IconButton
|
||||
icon={<PiSwapBold />}
|
||||
aria-label={`${t('gallery.swapImages')} (C)`}
|
||||
tooltip={`${t('gallery.swapImages')} (C)`}
|
||||
onClick={swapImages}
|
||||
/>
|
||||
{comparisonMode !== 'side-by-side' && (
|
||||
<IconButton
|
||||
aria-label={t('gallery.stretchToFit')}
|
||||
tooltip={t('gallery.stretchToFit')}
|
||||
onClick={toggleComparisonFit}
|
||||
colorScheme={comparisonFit === 'fill' ? 'invokeBlue' : 'base'}
|
||||
variant="outline"
|
||||
icon={<PiArrowsOutBold />}
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex flex={1} gap={4} justifyContent="center">
|
||||
<ButtonGroup variant="outline">
|
||||
<Button
|
||||
flexShrink={0}
|
||||
onClick={setComparisonModeSlider}
|
||||
colorScheme={comparisonMode === 'slider' ? 'invokeBlue' : 'base'}
|
||||
>
|
||||
{t('gallery.slider')}
|
||||
</Button>
|
||||
<Button
|
||||
flexShrink={0}
|
||||
onClick={setComparisonModeSideBySide}
|
||||
colorScheme={comparisonMode === 'side-by-side' ? 'invokeBlue' : 'base'}
|
||||
>
|
||||
{t('gallery.sideBySide')}
|
||||
</Button>
|
||||
<Button
|
||||
flexShrink={0}
|
||||
onClick={setComparisonModeHover}
|
||||
colorScheme={comparisonMode === 'hover' ? 'invokeBlue' : 'base'}
|
||||
>
|
||||
{t('gallery.hover')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineStart="auto" alignItems="center">
|
||||
<Tooltip label={<CompareHelp />}>
|
||||
<Flex alignItems="center">
|
||||
<Icon boxSize={8} color="base.500" as={PiQuestion} lineHeight={0} />
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
<IconButton
|
||||
icon={<PiXBold />}
|
||||
aria-label={`${t('gallery.exitCompare')} (Esc)`}
|
||||
tooltip={`${t('gallery.exitCompare')} (Esc)`}
|
||||
onClick={exitCompare}
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
CompareToolbar.displayName = 'CompareToolbar';
|
||||
|
||||
const CompareHelp = () => {
|
||||
return (
|
||||
<UnorderedList>
|
||||
<ListItem>
|
||||
<Trans i18nKey="gallery.compareHelp1" components={{ Kbd: <Kbd /> }}></Trans>
|
||||
</ListItem>
|
||||
<ListItem>
|
||||
<Trans i18nKey="gallery.compareHelp2" components={{ Kbd: <Kbd /> }}></Trans>
|
||||
</ListItem>
|
||||
<ListItem>
|
||||
<Trans i18nKey="gallery.compareHelp3" components={{ Kbd: <Kbd /> }}></Trans>
|
||||
</ListItem>
|
||||
<ListItem>
|
||||
<Trans i18nKey="gallery.compareHelp4" components={{ Kbd: <Kbd /> }}></Trans>
|
||||
</ListItem>
|
||||
</UnorderedList>
|
||||
);
|
||||
};
|
||||
@@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIDndImage from 'common/components/IAIDndImage';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import type { TypesafeDraggableData } from 'features/dnd/types';
|
||||
import type { TypesafeDraggableData, TypesafeDroppableData } from 'features/dnd/types';
|
||||
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
@@ -22,7 +22,21 @@ const selectLastSelectedImageName = createSelector(
|
||||
(lastSelectedImage) => lastSelectedImage?.image_name
|
||||
);
|
||||
|
||||
const CurrentImagePreview = () => {
|
||||
type Props = {
|
||||
isDragDisabled?: boolean;
|
||||
isDropDisabled?: boolean;
|
||||
withNextPrevButtons?: boolean;
|
||||
withMetadata?: boolean;
|
||||
alwaysShowProgress?: boolean;
|
||||
};
|
||||
|
||||
const CurrentImagePreview = ({
|
||||
isDragDisabled = false,
|
||||
isDropDisabled = false,
|
||||
withNextPrevButtons = true,
|
||||
withMetadata = true,
|
||||
alwaysShowProgress = false,
|
||||
}: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const shouldShowImageDetails = useAppSelector((s) => s.ui.shouldShowImageDetails);
|
||||
const imageName = useAppSelector(selectLastSelectedImageName);
|
||||
@@ -41,6 +55,14 @@ const CurrentImagePreview = () => {
|
||||
}
|
||||
}, [imageDTO]);
|
||||
|
||||
const droppableData = useMemo<TypesafeDroppableData | undefined>(
|
||||
() => ({
|
||||
id: 'current-image',
|
||||
actionType: 'SET_CURRENT_IMAGE',
|
||||
}),
|
||||
[]
|
||||
);
|
||||
|
||||
// Show and hide the next/prev buttons on mouse move
|
||||
const [shouldShowNextPrevButtons, setShouldShowNextPrevButtons] = useState<boolean>(false);
|
||||
const timeoutId = useRef(0);
|
||||
@@ -64,27 +86,30 @@ const CurrentImagePreview = () => {
|
||||
justifyContent="center"
|
||||
position="relative"
|
||||
>
|
||||
{hasDenoiseProgress && shouldShowProgressInViewer ? (
|
||||
{hasDenoiseProgress && (shouldShowProgressInViewer || alwaysShowProgress) ? (
|
||||
<ProgressImage />
|
||||
) : (
|
||||
<IAIDndImage
|
||||
imageDTO={imageDTO}
|
||||
droppableData={droppableData}
|
||||
draggableData={draggableData}
|
||||
isDropDisabled={true}
|
||||
isDragDisabled={isDragDisabled}
|
||||
isDropDisabled={isDropDisabled}
|
||||
isUploadDisabled={true}
|
||||
fitContainer
|
||||
useThumbailFallback
|
||||
dropLabel={t('gallery.setCurrentImage')}
|
||||
noContentFallback={<IAINoContentFallback icon={PiImageBold} label={t('gallery.noImageSelected')} />}
|
||||
dataTestId="image-preview"
|
||||
/>
|
||||
)}
|
||||
{shouldShowImageDetails && imageDTO && (
|
||||
{shouldShowImageDetails && imageDTO && withMetadata && (
|
||||
<Box position="absolute" opacity={0.8} top={0} width="full" height="full" borderRadius="base">
|
||||
<ImageMetadataViewer image={imageDTO} />
|
||||
</Box>
|
||||
)}
|
||||
<AnimatePresence>
|
||||
{shouldShowNextPrevButtons && imageDTO && (
|
||||
{withNextPrevButtons && shouldShowNextPrevButtons && imageDTO && (
|
||||
<Box
|
||||
as={motion.div}
|
||||
key="nextPrevButtons"
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import type { Dimensions } from 'features/canvas/store/canvasTypes';
|
||||
import { selectComparisonImages } from 'features/gallery/components/ImageViewer/common';
|
||||
import { ImageComparisonHover } from 'features/gallery/components/ImageViewer/ImageComparisonHover';
|
||||
import { ImageComparisonSideBySide } from 'features/gallery/components/ImageViewer/ImageComparisonSideBySide';
|
||||
import { ImageComparisonSlider } from 'features/gallery/components/ImageViewer/ImageComparisonSlider';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiImagesBold } from 'react-icons/pi';
|
||||
|
||||
type Props = {
|
||||
containerDims: Dimensions;
|
||||
};
|
||||
|
||||
export const ImageComparison = memo(({ containerDims }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const comparisonMode = useAppSelector((s) => s.gallery.comparisonMode);
|
||||
const { firstImage, secondImage } = useAppSelector(selectComparisonImages);
|
||||
|
||||
if (!firstImage || !secondImage) {
|
||||
// Should rarely/never happen - we don't render this component unless we have images to compare
|
||||
return <IAINoContentFallback label={t('gallery.selectAnImageToCompare')} icon={PiImagesBold} />;
|
||||
}
|
||||
|
||||
if (comparisonMode === 'slider') {
|
||||
return <ImageComparisonSlider containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />;
|
||||
}
|
||||
|
||||
if (comparisonMode === 'side-by-side') {
|
||||
return (
|
||||
<ImageComparisonSideBySide containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />
|
||||
);
|
||||
}
|
||||
|
||||
if (comparisonMode === 'hover') {
|
||||
return <ImageComparisonHover containerDims={containerDims} firstImage={firstImage} secondImage={secondImage} />;
|
||||
}
|
||||
});
|
||||
|
||||
ImageComparison.displayName = 'ImageComparison';
|
||||
@@ -1,47 +0,0 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import IAIDroppable from 'common/components/IAIDroppable';
|
||||
import type { CurrentImageDropData, SelectForCompareDropData } from 'features/dnd/types';
|
||||
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { selectComparisonImages } from './common';
|
||||
|
||||
const setCurrentImageDropData: CurrentImageDropData = {
|
||||
id: 'current-image',
|
||||
actionType: 'SET_CURRENT_IMAGE',
|
||||
};
|
||||
|
||||
export const ImageComparisonDroppable = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const imageViewer = useImageViewer();
|
||||
const { firstImage, secondImage } = useAppSelector(selectComparisonImages);
|
||||
const selectForCompareDropData = useMemo<SelectForCompareDropData>(
|
||||
() => ({
|
||||
id: 'image-comparison',
|
||||
actionType: 'SELECT_FOR_COMPARE',
|
||||
context: {
|
||||
firstImageName: firstImage?.image_name,
|
||||
secondImageName: secondImage?.image_name,
|
||||
},
|
||||
}),
|
||||
[firstImage?.image_name, secondImage?.image_name]
|
||||
);
|
||||
|
||||
if (!imageViewer.isOpen) {
|
||||
return (
|
||||
<Flex position="absolute" top={0} right={0} bottom={0} left={0} gap={2} pointerEvents="none">
|
||||
<IAIDroppable data={setCurrentImageDropData} dropLabel={t('gallery.openInViewer')} />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex position="absolute" top={0} right={0} bottom={0} left={0} gap={2} pointerEvents="none">
|
||||
<IAIDroppable data={selectForCompareDropData} dropLabel={t('gallery.selectForCompare')} />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ImageComparisonDroppable.displayName = 'ImageComparisonDroppable';
|
||||
@@ -1,117 +0,0 @@
|
||||
import { Box, Flex, Image } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useBoolean } from 'common/hooks/useBoolean';
|
||||
import { preventDefault } from 'common/util/stopPropagation';
|
||||
import type { Dimensions } from 'features/canvas/store/canvasTypes';
|
||||
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
|
||||
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
|
||||
import { memo, useMemo, useRef } from 'react';
|
||||
|
||||
import type { ComparisonProps } from './common';
|
||||
import { fitDimsToContainer, getSecondImageDims } from './common';
|
||||
|
||||
export const ImageComparisonHover = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
|
||||
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
|
||||
const imageContainerRef = useRef<HTMLDivElement>(null);
|
||||
const mouseOver = useBoolean(false);
|
||||
const fittedDims = useMemo<Dimensions>(
|
||||
() => fitDimsToContainer(containerDims, firstImage),
|
||||
[containerDims, firstImage]
|
||||
);
|
||||
const compareImageDims = useMemo<Dimensions>(
|
||||
() => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage),
|
||||
[comparisonFit, fittedDims, firstImage, secondImage]
|
||||
);
|
||||
return (
|
||||
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
|
||||
<Flex
|
||||
id="image-comparison-wrapper"
|
||||
w="full"
|
||||
h="full"
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
position="absolute"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
>
|
||||
<Box
|
||||
ref={imageContainerRef}
|
||||
position="relative"
|
||||
id="image-comparison-hover-image-container"
|
||||
w={fittedDims.width}
|
||||
h={fittedDims.height}
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
userSelect="none"
|
||||
overflow="hidden"
|
||||
borderRadius="base"
|
||||
>
|
||||
<Image
|
||||
id="image-comparison-hover-first-image"
|
||||
src={firstImage.image_url}
|
||||
fallbackSrc={firstImage.thumbnail_url}
|
||||
w={fittedDims.width}
|
||||
h={fittedDims.height}
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
objectFit="cover"
|
||||
objectPosition="top left"
|
||||
/>
|
||||
<ImageComparisonLabel type="first" opacity={mouseOver.isTrue ? 0 : 1} />
|
||||
|
||||
<Box
|
||||
id="image-comparison-hover-second-image-container"
|
||||
position="absolute"
|
||||
top={0}
|
||||
left={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
overflow="hidden"
|
||||
opacity={mouseOver.isTrue ? 1 : 0}
|
||||
transitionDuration="0.2s"
|
||||
transitionProperty="common"
|
||||
>
|
||||
<Box
|
||||
id="image-comparison-hover-bg"
|
||||
position="absolute"
|
||||
top={0}
|
||||
left={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
backgroundImage={STAGE_BG_DATAURL}
|
||||
backgroundRepeat="repeat"
|
||||
opacity={0.2}
|
||||
/>
|
||||
<Image
|
||||
position="relative"
|
||||
id="image-comparison-hover-second-image"
|
||||
src={secondImage.image_url}
|
||||
fallbackSrc={secondImage.thumbnail_url}
|
||||
w={compareImageDims.width}
|
||||
h={compareImageDims.height}
|
||||
maxW={fittedDims.width}
|
||||
maxH={fittedDims.height}
|
||||
objectFit={comparisonFit}
|
||||
objectPosition="top left"
|
||||
/>
|
||||
<ImageComparisonLabel type="second" opacity={mouseOver.isTrue ? 1 : 0} />
|
||||
</Box>
|
||||
<Box
|
||||
id="image-comparison-hover-interaction-overlay"
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
onMouseOver={mouseOver.setTrue}
|
||||
onMouseOut={mouseOver.setFalse}
|
||||
onContextMenu={preventDefault}
|
||||
userSelect="none"
|
||||
/>
|
||||
</Box>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ImageComparisonHover.displayName = 'ImageComparisonHover';
|
||||
@@ -1,33 +0,0 @@
|
||||
import type { TextProps } from '@invoke-ai/ui-library';
|
||||
import { Text } from '@invoke-ai/ui-library';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { DROP_SHADOW } from './common';
|
||||
|
||||
type Props = TextProps & {
|
||||
type: 'first' | 'second';
|
||||
};
|
||||
|
||||
export const ImageComparisonLabel = memo(({ type, ...rest }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<Text
|
||||
position="absolute"
|
||||
bottom={4}
|
||||
insetInlineEnd={type === 'first' ? undefined : 4}
|
||||
insetInlineStart={type === 'first' ? 4 : undefined}
|
||||
textOverflow="clip"
|
||||
whiteSpace="nowrap"
|
||||
filter={DROP_SHADOW}
|
||||
color="base.50"
|
||||
transitionDuration="0.2s"
|
||||
transitionProperty="common"
|
||||
{...rest}
|
||||
>
|
||||
{type === 'first' ? t('gallery.viewerImage') : t('gallery.compareImage')}
|
||||
</Text>
|
||||
);
|
||||
});
|
||||
|
||||
ImageComparisonLabel.displayName = 'ImageComparisonLabel';
|
||||
@@ -1,70 +0,0 @@
|
||||
import { Flex, Image } from '@invoke-ai/ui-library';
|
||||
import type { ComparisonProps } from 'features/gallery/components/ImageViewer/common';
|
||||
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
|
||||
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
|
||||
import { memo, useCallback, useRef } from 'react';
|
||||
import type { ImperativePanelGroupHandle } from 'react-resizable-panels';
|
||||
import { Panel, PanelGroup } from 'react-resizable-panels';
|
||||
|
||||
export const ImageComparisonSideBySide = memo(({ firstImage, secondImage }: ComparisonProps) => {
|
||||
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
|
||||
const onDoubleClickHandle = useCallback(() => {
|
||||
if (!panelGroupRef.current) {
|
||||
return;
|
||||
}
|
||||
panelGroupRef.current.setLayout([50, 50]);
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
|
||||
<Flex w="full" h="full" maxW="full" maxH="full" position="absolute" alignItems="center" justifyContent="center">
|
||||
<PanelGroup ref={panelGroupRef} direction="horizontal" id="image-comparison-side-by-side">
|
||||
<Panel minSize={20}>
|
||||
<Flex position="relative" w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<Flex position="absolute" maxW="full" maxH="full" aspectRatio={firstImage.width / firstImage.height}>
|
||||
<Image
|
||||
id="image-comparison-side-by-side-first-image"
|
||||
w={firstImage.width}
|
||||
h={firstImage.height}
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
src={firstImage.image_url}
|
||||
fallbackSrc={firstImage.thumbnail_url}
|
||||
objectFit="contain"
|
||||
borderRadius="base"
|
||||
/>
|
||||
<ImageComparisonLabel type="first" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Panel>
|
||||
<ResizeHandle
|
||||
id="image-comparison-side-by-side-handle"
|
||||
onDoubleClick={onDoubleClickHandle}
|
||||
orientation="vertical"
|
||||
/>
|
||||
|
||||
<Panel minSize={20}>
|
||||
<Flex position="relative" w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<Flex position="absolute" maxW="full" maxH="full" aspectRatio={secondImage.width / secondImage.height}>
|
||||
<Image
|
||||
id="image-comparison-side-by-side-first-image"
|
||||
w={secondImage.width}
|
||||
h={secondImage.height}
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
src={secondImage.image_url}
|
||||
fallbackSrc={secondImage.thumbnail_url}
|
||||
objectFit="contain"
|
||||
borderRadius="base"
|
||||
/>
|
||||
<ImageComparisonLabel type="second" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Panel>
|
||||
</PanelGroup>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ImageComparisonSideBySide.displayName = 'ImageComparisonSideBySide';
|
||||
@@ -1,215 +0,0 @@
|
||||
import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { preventDefault } from 'common/util/stopPropagation';
|
||||
import type { Dimensions } from 'features/canvas/store/canvasTypes';
|
||||
import { STAGE_BG_DATAURL } from 'features/controlLayers/util/renderers';
|
||||
import { ImageComparisonLabel } from 'features/gallery/components/ImageViewer/ImageComparisonLabel';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { PiCaretLeftBold, PiCaretRightBold } from 'react-icons/pi';
|
||||
|
||||
import type { ComparisonProps } from './common';
|
||||
import { DROP_SHADOW, fitDimsToContainer, getSecondImageDims } from './common';
|
||||
|
||||
const INITIAL_POS = '50%';
|
||||
const HANDLE_WIDTH = 2;
|
||||
const HANDLE_WIDTH_PX = `${HANDLE_WIDTH}px`;
|
||||
const HANDLE_HITBOX = 20;
|
||||
const HANDLE_HITBOX_PX = `${HANDLE_HITBOX}px`;
|
||||
const HANDLE_INNER_LEFT_PX = `${HANDLE_HITBOX / 2 - HANDLE_WIDTH / 2}px`;
|
||||
const HANDLE_LEFT_INITIAL_PX = `calc(${INITIAL_POS} - ${HANDLE_HITBOX / 2}px)`;
|
||||
|
||||
export const ImageComparisonSlider = memo(({ firstImage, secondImage, containerDims }: ComparisonProps) => {
|
||||
const comparisonFit = useAppSelector((s) => s.gallery.comparisonFit);
|
||||
// How far the handle is from the left - this will be a CSS calculation that takes into account the handle width
|
||||
const [left, setLeft] = useState(HANDLE_LEFT_INITIAL_PX);
|
||||
// How wide the first image is
|
||||
const [width, setWidth] = useState(INITIAL_POS);
|
||||
const handleRef = useRef<HTMLDivElement>(null);
|
||||
// To manage aspect ratios, we need to know the size of the container
|
||||
const imageContainerRef = useRef<HTMLDivElement>(null);
|
||||
// To keep things smooth, we use RAF to update the handle position & gate it to 60fps
|
||||
const rafRef = useRef<number | null>(null);
|
||||
const lastMoveTimeRef = useRef<number>(0);
|
||||
|
||||
const fittedDims = useMemo<Dimensions>(
|
||||
() => fitDimsToContainer(containerDims, firstImage),
|
||||
[containerDims, firstImage]
|
||||
);
|
||||
|
||||
const compareImageDims = useMemo<Dimensions>(
|
||||
() => getSecondImageDims(comparisonFit, fittedDims, firstImage, secondImage),
|
||||
[comparisonFit, fittedDims, firstImage, secondImage]
|
||||
);
|
||||
|
||||
const updateHandlePos = useCallback((clientX: number) => {
|
||||
if (!handleRef.current || !imageContainerRef.current) {
|
||||
return;
|
||||
}
|
||||
lastMoveTimeRef.current = performance.now();
|
||||
const { x, width } = imageContainerRef.current.getBoundingClientRect();
|
||||
const rawHandlePos = ((clientX - x) * 100) / width;
|
||||
const handleWidthPct = (HANDLE_WIDTH * 100) / width;
|
||||
const newHandlePos = Math.min(100 - handleWidthPct, Math.max(0, rawHandlePos));
|
||||
setWidth(`${newHandlePos}%`);
|
||||
setLeft(`calc(${newHandlePos}% - ${HANDLE_HITBOX / 2}px)`);
|
||||
}, []);
|
||||
|
||||
const onMouseMove = useCallback(
|
||||
(e: MouseEvent) => {
|
||||
if (rafRef.current === null && performance.now() > lastMoveTimeRef.current + 1000 / 60) {
|
||||
rafRef.current = window.requestAnimationFrame(() => {
|
||||
updateHandlePos(e.clientX);
|
||||
rafRef.current = null;
|
||||
});
|
||||
}
|
||||
},
|
||||
[updateHandlePos]
|
||||
);
|
||||
|
||||
const onMouseUp = useCallback(() => {
|
||||
window.removeEventListener('mousemove', onMouseMove);
|
||||
}, [onMouseMove]);
|
||||
|
||||
const onMouseDown = useCallback(
|
||||
(e: React.MouseEvent<HTMLDivElement>) => {
|
||||
// Update the handle position immediately on click
|
||||
updateHandlePos(e.clientX);
|
||||
window.addEventListener('mouseup', onMouseUp, { once: true });
|
||||
window.addEventListener('mousemove', onMouseMove);
|
||||
},
|
||||
[onMouseMove, onMouseUp, updateHandlePos]
|
||||
);
|
||||
|
||||
useEffect(
|
||||
() => () => {
|
||||
if (rafRef.current !== null) {
|
||||
cancelAnimationFrame(rafRef.current);
|
||||
}
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex w="full" h="full" maxW="full" maxH="full" position="relative" alignItems="center" justifyContent="center">
|
||||
<Flex
|
||||
id="image-comparison-wrapper"
|
||||
w="full"
|
||||
h="full"
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
position="absolute"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
>
|
||||
<Box
|
||||
ref={imageContainerRef}
|
||||
position="relative"
|
||||
id="image-comparison-image-container"
|
||||
w={fittedDims.width}
|
||||
h={fittedDims.height}
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
userSelect="none"
|
||||
overflow="hidden"
|
||||
borderRadius="base"
|
||||
>
|
||||
<Box
|
||||
id="image-comparison-bg"
|
||||
position="absolute"
|
||||
top={0}
|
||||
left={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
backgroundImage={STAGE_BG_DATAURL}
|
||||
backgroundRepeat="repeat"
|
||||
opacity={0.2}
|
||||
/>
|
||||
<Image
|
||||
position="relative"
|
||||
id="image-comparison-second-image"
|
||||
src={secondImage.image_url}
|
||||
fallbackSrc={secondImage.thumbnail_url}
|
||||
w={compareImageDims.width}
|
||||
h={compareImageDims.height}
|
||||
maxW={fittedDims.width}
|
||||
maxH={fittedDims.height}
|
||||
objectFit={comparisonFit}
|
||||
objectPosition="top left"
|
||||
/>
|
||||
<ImageComparisonLabel type="second" />
|
||||
<Box
|
||||
id="image-comparison-first-image-container"
|
||||
position="absolute"
|
||||
top={0}
|
||||
left={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
w={width}
|
||||
overflow="hidden"
|
||||
>
|
||||
<Image
|
||||
id="image-comparison-first-image"
|
||||
src={firstImage.image_url}
|
||||
fallbackSrc={firstImage.thumbnail_url}
|
||||
w={fittedDims.width}
|
||||
h={fittedDims.height}
|
||||
objectFit="cover"
|
||||
objectPosition="top left"
|
||||
/>
|
||||
<ImageComparisonLabel type="first" />
|
||||
</Box>
|
||||
<Flex
|
||||
id="image-comparison-handle"
|
||||
ref={handleRef}
|
||||
position="absolute"
|
||||
top={0}
|
||||
bottom={0}
|
||||
left={left}
|
||||
w={HANDLE_HITBOX_PX}
|
||||
cursor="ew-resize"
|
||||
filter={DROP_SHADOW}
|
||||
opacity={0.8}
|
||||
color="base.50"
|
||||
>
|
||||
<Box
|
||||
id="image-comparison-handle-divider"
|
||||
w={HANDLE_WIDTH_PX}
|
||||
h="full"
|
||||
bg="currentColor"
|
||||
shadow="dark-lg"
|
||||
position="absolute"
|
||||
top={0}
|
||||
left={HANDLE_INNER_LEFT_PX}
|
||||
/>
|
||||
<Flex
|
||||
id="image-comparison-handle-icons"
|
||||
gap={4}
|
||||
position="absolute"
|
||||
left="50%"
|
||||
top="50%"
|
||||
transform="translate(-50%, 0)"
|
||||
filter={DROP_SHADOW}
|
||||
>
|
||||
<Icon as={PiCaretLeftBold} />
|
||||
<Icon as={PiCaretRightBold} />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Box
|
||||
id="image-comparison-interaction-overlay"
|
||||
position="absolute"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
onMouseDown={onMouseDown}
|
||||
onContextMenu={preventDefault}
|
||||
userSelect="none"
|
||||
cursor="ew-resize"
|
||||
/>
|
||||
</Box>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ImageComparisonSlider.displayName = 'ImageComparisonSlider';
|
||||
@@ -1,16 +1,36 @@
|
||||
import { Box, Flex } from '@invoke-ai/ui-library';
|
||||
import { CompareToolbar } from 'features/gallery/components/ImageViewer/CompareToolbar';
|
||||
import CurrentImagePreview from 'features/gallery/components/ImageViewer/CurrentImagePreview';
|
||||
import { ImageComparison } from 'features/gallery/components/ImageViewer/ImageComparison';
|
||||
import { ViewerToolbar } from 'features/gallery/components/ImageViewer/ViewerToolbar';
|
||||
import { memo } from 'react';
|
||||
import { useMeasure } from 'react-use';
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
|
||||
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
|
||||
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import type { InvokeTabName } from 'features/ui/store/tabMap';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
|
||||
import { useImageViewer } from './useImageViewer';
|
||||
import CurrentImageButtons from './CurrentImageButtons';
|
||||
import CurrentImagePreview from './CurrentImagePreview';
|
||||
import { ViewerToggleMenu } from './ViewerToggleMenu';
|
||||
|
||||
const VIEWER_ENABLED_TABS: InvokeTabName[] = ['canvas', 'generation', 'workflows'];
|
||||
|
||||
export const ImageViewer = memo(() => {
|
||||
const imageViewer = useImageViewer();
|
||||
const [containerRef, containerDims] = useMeasure<HTMLDivElement>();
|
||||
const { isOpen, onToggle, onClose } = useImageViewer();
|
||||
const activeTabName = useAppSelector(activeTabNameSelector);
|
||||
const isViewerEnabled = useMemo(() => VIEWER_ENABLED_TABS.includes(activeTabName), [activeTabName]);
|
||||
const shouldShowViewer = useMemo(() => {
|
||||
if (!isViewerEnabled) {
|
||||
return false;
|
||||
}
|
||||
return isOpen;
|
||||
}, [isOpen, isViewerEnabled]);
|
||||
|
||||
useHotkeys('z', onToggle, { enabled: isViewerEnabled }, [isViewerEnabled, onToggle]);
|
||||
useHotkeys('esc', onClose, { enabled: isViewerEnabled }, [isViewerEnabled, onClose]);
|
||||
|
||||
if (!shouldShowViewer) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@@ -26,13 +46,25 @@ export const ImageViewer = memo(() => {
|
||||
rowGap={4}
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
zIndex={10} // reactflow puts its minimap at 5, so we need to be above that
|
||||
>
|
||||
{imageViewer.isComparing && <CompareToolbar />}
|
||||
{!imageViewer.isComparing && <ViewerToolbar />}
|
||||
<Box ref={containerRef} w="full" h="full">
|
||||
{!imageViewer.isComparing && <CurrentImagePreview />}
|
||||
{imageViewer.isComparing && <ImageComparison containerDims={containerDims} />}
|
||||
</Box>
|
||||
<Flex w="full" gap={2}>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineEnd="auto">
|
||||
<ToggleProgressButton />
|
||||
<ToggleMetadataViewerButton />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex flex={1} gap={2} justifyContent="center">
|
||||
<CurrentImageButtons />
|
||||
</Flex>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineStart="auto">
|
||||
<ViewerToggleMenu />
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<CurrentImagePreview />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
|
||||
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
|
||||
import { memo } from 'react';
|
||||
|
||||
import CurrentImageButtons from './CurrentImageButtons';
|
||||
import CurrentImagePreview from './CurrentImagePreview';
|
||||
|
||||
export const ImageViewerWorkflows = memo(() => {
|
||||
return (
|
||||
<Flex
|
||||
layerStyle="first"
|
||||
borderRadius="base"
|
||||
position="absolute"
|
||||
flexDirection="column"
|
||||
top={0}
|
||||
right={0}
|
||||
bottom={0}
|
||||
left={0}
|
||||
p={2}
|
||||
rowGap={4}
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
zIndex={10} // reactflow puts its minimap at 5, so we need to be above that
|
||||
>
|
||||
<Flex w="full" gap={2}>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineEnd="auto">
|
||||
<ToggleProgressButton />
|
||||
<ToggleMetadataViewerButton />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex flex={1} gap={2} justifyContent="center">
|
||||
<CurrentImageButtons />
|
||||
</Flex>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineStart="auto" />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<CurrentImagePreview />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ImageViewerWorkflows.displayName = 'ImageViewerWorkflows';
|
||||
@@ -9,35 +9,33 @@ import {
|
||||
PopoverTrigger,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useImageViewer } from 'features/gallery/components/ImageViewer/useImageViewer';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCaretDownBold, PiCheckBold, PiEyeBold, PiPencilBold } from 'react-icons/pi';
|
||||
|
||||
import { useImageViewer } from './useImageViewer';
|
||||
|
||||
export const ViewerToggleMenu = () => {
|
||||
const { t } = useTranslation();
|
||||
const imageViewer = useImageViewer();
|
||||
useHotkeys('z', imageViewer.onToggle, [imageViewer]);
|
||||
useHotkeys('esc', imageViewer.onClose, [imageViewer]);
|
||||
const { isOpen, onClose, onOpen } = useImageViewer();
|
||||
|
||||
return (
|
||||
<Popover isLazy>
|
||||
<PopoverTrigger>
|
||||
<Button variant="outline" data-testid="toggle-viewer-menu-button" pointerEvents="auto">
|
||||
<Button variant="outline" data-testid="toggle-viewer-menu-button">
|
||||
<Flex gap={3} w="full" alignItems="center">
|
||||
{imageViewer.isOpen ? <Icon as={PiEyeBold} /> : <Icon as={PiPencilBold} />}
|
||||
<Text fontSize="md">{imageViewer.isOpen ? t('common.viewing') : t('common.editing')}</Text>
|
||||
{isOpen ? <Icon as={PiEyeBold} /> : <Icon as={PiPencilBold} />}
|
||||
<Text fontSize="md">{isOpen ? t('common.viewing') : t('common.editing')}</Text>
|
||||
<Icon as={PiCaretDownBold} />
|
||||
</Flex>
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent p={2} pointerEvents="auto">
|
||||
<PopoverContent p={2}>
|
||||
<PopoverArrow />
|
||||
<PopoverBody>
|
||||
<Flex flexDir="column">
|
||||
<Button onClick={imageViewer.onOpen} variant="ghost" h="auto" w="auto" p={2}>
|
||||
<Button onClick={onOpen} variant="ghost" h="auto" w="auto" p={2}>
|
||||
<Flex gap={2} w="full">
|
||||
<Icon as={PiCheckBold} visibility={imageViewer.isOpen ? 'visible' : 'hidden'} />
|
||||
<Icon as={PiCheckBold} visibility={isOpen ? 'visible' : 'hidden'} />
|
||||
<Flex flexDir="column" gap={2} alignItems="flex-start">
|
||||
<Text fontWeight="semibold" color="base.100">
|
||||
{t('common.viewing')}
|
||||
@@ -48,9 +46,9 @@ export const ViewerToggleMenu = () => {
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Button>
|
||||
<Button onClick={imageViewer.onClose} variant="ghost" h="auto" w="auto" p={2}>
|
||||
<Button onClick={onClose} variant="ghost" h="auto" w="auto" p={2}>
|
||||
<Flex gap={2} w="full">
|
||||
<Icon as={PiCheckBold} visibility={imageViewer.isOpen ? 'hidden' : 'visible'} />
|
||||
<Icon as={PiCheckBold} visibility={isOpen ? 'hidden' : 'visible'} />
|
||||
<Flex flexDir="column" gap={2} alignItems="flex-start">
|
||||
<Text fontWeight="semibold" color="base.100">
|
||||
{t('common.editing')}
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { ToggleMetadataViewerButton } from 'features/gallery/components/ImageViewer/ToggleMetadataViewerButton';
|
||||
import { ToggleProgressButton } from 'features/gallery/components/ImageViewer/ToggleProgressButton';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { memo } from 'react';
|
||||
|
||||
import CurrentImageButtons from './CurrentImageButtons';
|
||||
import { ViewerToggleMenu } from './ViewerToggleMenu';
|
||||
|
||||
export const ViewerToolbar = memo(() => {
|
||||
const tab = useAppSelector(activeTabNameSelector);
|
||||
return (
|
||||
<Flex w="full" gap={2}>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineEnd="auto">
|
||||
<ToggleProgressButton />
|
||||
<ToggleMetadataViewerButton />
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex flex={1} gap={2} justifyContent="center">
|
||||
<CurrentImageButtons />
|
||||
</Flex>
|
||||
<Flex flex={1} justifyContent="center">
|
||||
<Flex gap={2} marginInlineStart="auto">
|
||||
{tab !== 'workflows' && <ViewerToggleMenu />}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
ViewerToolbar.displayName = 'ViewerToolbar';
|
||||
@@ -1,64 +0,0 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import type { Dimensions } from 'features/canvas/store/canvasTypes';
|
||||
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import type { ComparisonFit } from 'features/gallery/store/types';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
|
||||
|
||||
export type ComparisonProps = {
|
||||
firstImage: ImageDTO;
|
||||
secondImage: ImageDTO;
|
||||
containerDims: Dimensions;
|
||||
};
|
||||
|
||||
export const fitDimsToContainer = (containerDims: Dimensions, imageDims: Dimensions): Dimensions => {
|
||||
// Fall back to the image's dimensions if the container has no dimensions
|
||||
if (containerDims.width === 0 || containerDims.height === 0) {
|
||||
return { width: imageDims.width, height: imageDims.height };
|
||||
}
|
||||
|
||||
// Fall back to the image's dimensions if the image fits within the container
|
||||
if (imageDims.width <= containerDims.width && imageDims.height <= containerDims.height) {
|
||||
return { width: imageDims.width, height: imageDims.height };
|
||||
}
|
||||
|
||||
const targetAspectRatio = containerDims.width / containerDims.height;
|
||||
const imageAspectRatio = imageDims.width / imageDims.height;
|
||||
|
||||
let width: number;
|
||||
let height: number;
|
||||
|
||||
if (imageAspectRatio > targetAspectRatio) {
|
||||
// Image is wider than container's aspect ratio
|
||||
width = containerDims.width;
|
||||
height = width / imageAspectRatio;
|
||||
} else {
|
||||
// Image is taller than container's aspect ratio
|
||||
height = containerDims.height;
|
||||
width = height * imageAspectRatio;
|
||||
}
|
||||
return { width, height };
|
||||
};
|
||||
|
||||
/**
|
||||
* Gets the dimensions of the second image in a comparison based on the comparison fit mode.
|
||||
*/
|
||||
export const getSecondImageDims = (
|
||||
comparisonFit: ComparisonFit,
|
||||
fittedDims: Dimensions,
|
||||
firstImageDims: Dimensions,
|
||||
secondImageDims: Dimensions
|
||||
): Dimensions => {
|
||||
const width =
|
||||
comparisonFit === 'fill' ? fittedDims.width : (fittedDims.width * secondImageDims.width) / firstImageDims.width;
|
||||
const height =
|
||||
comparisonFit === 'fill' ? fittedDims.height : (fittedDims.height * secondImageDims.height) / firstImageDims.height;
|
||||
|
||||
return { width, height };
|
||||
};
|
||||
export const selectComparisonImages = createMemoizedSelector(selectGallerySlice, (gallerySlice) => {
|
||||
const firstImage = gallerySlice.selection.slice(-1)[0] ?? null;
|
||||
const secondImage = gallerySlice.imageToCompare;
|
||||
return { firstImage, secondImage };
|
||||
});
|
||||
@@ -1,31 +0,0 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { imageToCompareChanged, isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
export const useImageViewer = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const isComparing = useAppSelector((s) => s.gallery.imageToCompare !== null);
|
||||
const isOpen = useAppSelector((s) => s.gallery.isImageViewerOpen);
|
||||
|
||||
const onClose = useCallback(() => {
|
||||
if (isComparing && isOpen) {
|
||||
dispatch(imageToCompareChanged(null));
|
||||
} else {
|
||||
dispatch(isImageViewerOpenChanged(false));
|
||||
}
|
||||
}, [dispatch, isComparing, isOpen]);
|
||||
|
||||
const onOpen = useCallback(() => {
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
}, [dispatch]);
|
||||
|
||||
const onToggle = useCallback(() => {
|
||||
if (isComparing && isOpen) {
|
||||
dispatch(imageToCompareChanged(null));
|
||||
} else {
|
||||
dispatch(isImageViewerOpenChanged(!isOpen));
|
||||
}
|
||||
}, [dispatch, isComparing, isOpen]);
|
||||
|
||||
return { isOpen, onOpen, onClose, onToggle, isComparing };
|
||||
};
|
||||
@@ -0,0 +1,22 @@
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { isImageViewerOpenChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { useCallback } from 'react';
|
||||
|
||||
export const useImageViewer = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const isOpen = useAppSelector((s) => s.gallery.isImageViewerOpen);
|
||||
|
||||
const onClose = useCallback(() => {
|
||||
dispatch(isImageViewerOpenChanged(false));
|
||||
}, [dispatch]);
|
||||
|
||||
const onOpen = useCallback(() => {
|
||||
dispatch(isImageViewerOpenChanged(true));
|
||||
}, [dispatch]);
|
||||
|
||||
const onToggle = useCallback(() => {
|
||||
dispatch(isImageViewerOpenChanged(!isOpen));
|
||||
}, [dispatch, isOpen]);
|
||||
|
||||
return { isOpen, onOpen, onClose, onToggle };
|
||||
};
|
||||
@@ -14,7 +14,7 @@ const nextPrevButtonStyles: ChakraProps['sx'] = {
|
||||
const NextPrevImageButtons = () => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const { prevImage, nextImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();
|
||||
const { handleLeftImage, handleRightImage, isOnFirstImage, isOnLastImage } = useGalleryNavigation();
|
||||
|
||||
const {
|
||||
areMoreImagesAvailable,
|
||||
@@ -30,7 +30,7 @@ const NextPrevImageButtons = () => {
|
||||
aria-label={t('accessibility.previousImage')}
|
||||
icon={<PiCaretLeftBold size={64} />}
|
||||
variant="unstyled"
|
||||
onClick={prevImage}
|
||||
onClick={handleLeftImage}
|
||||
boxSize={16}
|
||||
sx={nextPrevButtonStyles}
|
||||
/>
|
||||
@@ -42,7 +42,7 @@ const NextPrevImageButtons = () => {
|
||||
aria-label={t('accessibility.nextImage')}
|
||||
icon={<PiCaretRightBold size={64} />}
|
||||
variant="unstyled"
|
||||
onClick={nextImage}
|
||||
onClick={handleRightImage}
|
||||
boxSize={16}
|
||||
sx={nextPrevButtonStyles}
|
||||
/>
|
||||
|
||||
@@ -27,16 +27,16 @@ export const useGalleryHotkeys = () => {
|
||||
useGalleryNavigation();
|
||||
|
||||
useHotkeys(
|
||||
['left', 'alt+left'],
|
||||
(e) => {
|
||||
canNavigateGallery && handleLeftImage(e.altKey);
|
||||
'left',
|
||||
() => {
|
||||
canNavigateGallery && handleLeftImage();
|
||||
},
|
||||
[handleLeftImage, canNavigateGallery]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
['right', 'alt+right'],
|
||||
(e) => {
|
||||
'right',
|
||||
() => {
|
||||
if (!canNavigateGallery) {
|
||||
return;
|
||||
}
|
||||
@@ -45,29 +45,29 @@ export const useGalleryHotkeys = () => {
|
||||
return;
|
||||
}
|
||||
if (!isOnLastImage) {
|
||||
handleRightImage(e.altKey);
|
||||
handleRightImage();
|
||||
}
|
||||
},
|
||||
[isOnLastImage, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleRightImage, canNavigateGallery]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
['up', 'alt+up'],
|
||||
(e) => {
|
||||
handleUpImage(e.altKey);
|
||||
'up',
|
||||
() => {
|
||||
handleUpImage();
|
||||
},
|
||||
{ preventDefault: true },
|
||||
[handleUpImage]
|
||||
);
|
||||
|
||||
useHotkeys(
|
||||
['down', 'alt+down'],
|
||||
(e) => {
|
||||
'down',
|
||||
() => {
|
||||
if (!areImagesBelowCurrent && areMoreImagesAvailable && !isFetching) {
|
||||
handleLoadMoreImages();
|
||||
return;
|
||||
}
|
||||
handleDownImage(e.altKey);
|
||||
handleDownImage();
|
||||
},
|
||||
{ preventDefault: true },
|
||||
[areImagesBelowCurrent, areMoreImagesAvailable, handleLoadMoreImages, isFetching, handleDownImage]
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { useAltModifier } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
|
||||
import { imageItemContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridItemContainer';
|
||||
import { imageListContainerTestId } from 'features/gallery/components/ImageGrid/ImageGridListContainer';
|
||||
import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types';
|
||||
import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages';
|
||||
import { imageSelected, imageToCompareChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { getIsVisible } from 'features/gallery/util/getIsVisible';
|
||||
import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign';
|
||||
import { clamp } from 'lodash-es';
|
||||
@@ -106,12 +106,10 @@ const getImageFuncs = {
|
||||
};
|
||||
|
||||
type UseGalleryNavigationReturn = {
|
||||
handleLeftImage: (alt?: boolean) => void;
|
||||
handleRightImage: (alt?: boolean) => void;
|
||||
handleUpImage: (alt?: boolean) => void;
|
||||
handleDownImage: (alt?: boolean) => void;
|
||||
prevImage: () => void;
|
||||
nextImage: () => void;
|
||||
handleLeftImage: () => void;
|
||||
handleRightImage: () => void;
|
||||
handleUpImage: () => void;
|
||||
handleDownImage: () => void;
|
||||
isOnFirstImage: boolean;
|
||||
isOnLastImage: boolean;
|
||||
areImagesBelowCurrent: boolean;
|
||||
@@ -125,15 +123,7 @@ type UseGalleryNavigationReturn = {
|
||||
*/
|
||||
export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
const dispatch = useAppDispatch();
|
||||
const alt = useAltModifier();
|
||||
const lastSelectedImage = useAppSelector((s) => {
|
||||
const lastSelected = s.gallery.selection.slice(-1)[0] ?? null;
|
||||
if (alt) {
|
||||
return s.gallery.imageToCompare ?? lastSelected;
|
||||
} else {
|
||||
return lastSelected;
|
||||
}
|
||||
});
|
||||
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
|
||||
const {
|
||||
queryResult: { data },
|
||||
} = useGalleryImages();
|
||||
@@ -146,7 +136,7 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
}, [lastSelectedImage, data]);
|
||||
|
||||
const handleNavigation = useCallback(
|
||||
(direction: 'left' | 'right' | 'up' | 'down', alt?: boolean) => {
|
||||
(direction: 'left' | 'right' | 'up' | 'down') => {
|
||||
if (!data) {
|
||||
return;
|
||||
}
|
||||
@@ -154,14 +144,10 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
if (!image || index === lastSelectedImageIndex) {
|
||||
return;
|
||||
}
|
||||
if (alt) {
|
||||
dispatch(imageToCompareChanged(image));
|
||||
} else {
|
||||
dispatch(imageSelected(image));
|
||||
}
|
||||
dispatch(imageSelected(image));
|
||||
scrollToImage(image.image_name, index);
|
||||
},
|
||||
[data, lastSelectedImageIndex, dispatch]
|
||||
[dispatch, lastSelectedImageIndex, data]
|
||||
);
|
||||
|
||||
const isOnFirstImage = useMemo(() => lastSelectedImageIndex === 0, [lastSelectedImageIndex]);
|
||||
@@ -176,41 +162,21 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
return lastSelectedImageIndex + imagesPerRow < loadedImagesCount;
|
||||
}, [lastSelectedImageIndex, loadedImagesCount]);
|
||||
|
||||
const handleLeftImage = useCallback(
|
||||
(alt?: boolean) => {
|
||||
handleNavigation('left', alt);
|
||||
},
|
||||
[handleNavigation]
|
||||
);
|
||||
const handleLeftImage = useCallback(() => {
|
||||
handleNavigation('left');
|
||||
}, [handleNavigation]);
|
||||
|
||||
const handleRightImage = useCallback(
|
||||
(alt?: boolean) => {
|
||||
handleNavigation('right', alt);
|
||||
},
|
||||
[handleNavigation]
|
||||
);
|
||||
const handleRightImage = useCallback(() => {
|
||||
handleNavigation('right');
|
||||
}, [handleNavigation]);
|
||||
|
||||
const handleUpImage = useCallback(
|
||||
(alt?: boolean) => {
|
||||
handleNavigation('up', alt);
|
||||
},
|
||||
[handleNavigation]
|
||||
);
|
||||
const handleUpImage = useCallback(() => {
|
||||
handleNavigation('up');
|
||||
}, [handleNavigation]);
|
||||
|
||||
const handleDownImage = useCallback(
|
||||
(alt?: boolean) => {
|
||||
handleNavigation('down', alt);
|
||||
},
|
||||
[handleNavigation]
|
||||
);
|
||||
|
||||
const nextImage = useCallback(() => {
|
||||
handleRightImage();
|
||||
}, [handleRightImage]);
|
||||
|
||||
const prevImage = useCallback(() => {
|
||||
handleLeftImage();
|
||||
}, [handleLeftImage]);
|
||||
const handleDownImage = useCallback(() => {
|
||||
handleNavigation('down');
|
||||
}, [handleNavigation]);
|
||||
|
||||
return {
|
||||
handleLeftImage,
|
||||
@@ -220,7 +186,5 @@ export const useGalleryNavigation = (): UseGalleryNavigationReturn => {
|
||||
isOnFirstImage,
|
||||
isOnLastImage,
|
||||
areImagesBelowCurrent,
|
||||
nextImage,
|
||||
prevImage,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -36,7 +36,6 @@ export const useMultiselect = (imageDTO?: ImageDTO) => {
|
||||
shiftKey: e.shiftKey,
|
||||
ctrlKey: e.ctrlKey,
|
||||
metaKey: e.metaKey,
|
||||
altKey: e.altKey,
|
||||
})
|
||||
);
|
||||
},
|
||||
|
||||
@@ -6,7 +6,7 @@ import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import type { BoardId, ComparisonMode, GalleryState, GalleryView } from './types';
|
||||
import type { BoardId, GalleryState, GalleryView } from './types';
|
||||
import { IMAGE_LIMIT, INITIAL_IMAGE_LIMIT } from './types';
|
||||
|
||||
const initialGalleryState: GalleryState = {
|
||||
@@ -22,9 +22,6 @@ const initialGalleryState: GalleryState = {
|
||||
limit: INITIAL_IMAGE_LIMIT,
|
||||
offset: 0,
|
||||
isImageViewerOpen: true,
|
||||
imageToCompare: null,
|
||||
comparisonMode: 'slider',
|
||||
comparisonFit: 'fill',
|
||||
};
|
||||
|
||||
export const gallerySlice = createSlice({
|
||||
@@ -37,28 +34,6 @@ export const gallerySlice = createSlice({
|
||||
selectionChanged: (state, action: PayloadAction<ImageDTO[]>) => {
|
||||
state.selection = uniqBy(action.payload, (i) => i.image_name);
|
||||
},
|
||||
imageToCompareChanged: (state, action: PayloadAction<ImageDTO | null>) => {
|
||||
state.imageToCompare = action.payload;
|
||||
if (action.payload) {
|
||||
state.isImageViewerOpen = true;
|
||||
}
|
||||
},
|
||||
comparisonModeChanged: (state, action: PayloadAction<ComparisonMode>) => {
|
||||
state.comparisonMode = action.payload;
|
||||
},
|
||||
comparisonModeCycled: (state) => {
|
||||
switch (state.comparisonMode) {
|
||||
case 'slider':
|
||||
state.comparisonMode = 'side-by-side';
|
||||
break;
|
||||
case 'side-by-side':
|
||||
state.comparisonMode = 'hover';
|
||||
break;
|
||||
case 'hover':
|
||||
state.comparisonMode = 'slider';
|
||||
break;
|
||||
}
|
||||
},
|
||||
shouldAutoSwitchChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.shouldAutoSwitch = action.payload;
|
||||
},
|
||||
@@ -104,16 +79,6 @@ export const gallerySlice = createSlice({
|
||||
isImageViewerOpenChanged: (state, action: PayloadAction<boolean>) => {
|
||||
state.isImageViewerOpen = action.payload;
|
||||
},
|
||||
comparedImagesSwapped: (state) => {
|
||||
if (state.imageToCompare) {
|
||||
const oldSelection = state.selection;
|
||||
state.selection = [state.imageToCompare];
|
||||
state.imageToCompare = oldSelection[0] ?? null;
|
||||
}
|
||||
},
|
||||
comparisonFitChanged: (state, action: PayloadAction<'contain' | 'fill'>) => {
|
||||
state.comparisonFit = action.payload;
|
||||
},
|
||||
},
|
||||
extraReducers: (builder) => {
|
||||
builder.addMatcher(isAnyBoardDeleted, (state, action) => {
|
||||
@@ -152,11 +117,6 @@ export const {
|
||||
moreImagesLoaded,
|
||||
alwaysShowImageSizeBadgeChanged,
|
||||
isImageViewerOpenChanged,
|
||||
imageToCompareChanged,
|
||||
comparisonModeChanged,
|
||||
comparedImagesSwapped,
|
||||
comparisonFitChanged,
|
||||
comparisonModeCycled,
|
||||
} = gallerySlice.actions;
|
||||
|
||||
const isAnyBoardDeleted = isAnyOf(
|
||||
@@ -178,13 +138,5 @@ export const galleryPersistConfig: PersistConfig<GalleryState> = {
|
||||
name: gallerySlice.name,
|
||||
initialState: initialGalleryState,
|
||||
migrate: migrateGalleryState,
|
||||
persistDenylist: [
|
||||
'selection',
|
||||
'selectedBoardId',
|
||||
'galleryView',
|
||||
'offset',
|
||||
'limit',
|
||||
'isImageViewerOpen',
|
||||
'imageToCompare',
|
||||
],
|
||||
persistDenylist: ['selection', 'selectedBoardId', 'galleryView', 'offset', 'limit', 'isImageViewerOpen'],
|
||||
};
|
||||
|
||||
@@ -7,8 +7,6 @@ export const IMAGE_LIMIT = 20;
|
||||
|
||||
export type GalleryView = 'images' | 'assets';
|
||||
export type BoardId = 'none' | (string & Record<never, never>);
|
||||
export type ComparisonMode = 'slider' | 'side-by-side' | 'hover';
|
||||
export type ComparisonFit = 'contain' | 'fill';
|
||||
|
||||
export type GalleryState = {
|
||||
selection: ImageDTO[];
|
||||
@@ -22,8 +20,5 @@ export type GalleryState = {
|
||||
offset: number;
|
||||
limit: number;
|
||||
alwaysShowImageSizeBadge: boolean;
|
||||
imageToCompare: ImageDTO | null;
|
||||
comparisonMode: ComparisonMode;
|
||||
comparisonFit: ComparisonFit;
|
||||
isImageViewerOpen: boolean;
|
||||
};
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { objectKeys } from 'common/util/objectKeys';
|
||||
import { shouldConcatPromptsChanged } from 'features/controlLayers/store/controlLayersSlice';
|
||||
import type { Layer } from 'features/controlLayers/store/types';
|
||||
import type { LoRA } from 'features/lora/store/loraSlice';
|
||||
import type {
|
||||
@@ -19,7 +16,6 @@ import { validators } from 'features/metadata/util/validators';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { t } from 'i18next';
|
||||
import { size } from 'lodash-es';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import { parsers } from './parsers';
|
||||
@@ -380,25 +376,54 @@ export const handlers = {
|
||||
}),
|
||||
} as const;
|
||||
|
||||
type ParsedValue = Awaited<ReturnType<(typeof handlers)[keyof typeof handlers]['parse']>>;
|
||||
type RecallResults = Partial<Record<keyof typeof handlers, ParsedValue>>;
|
||||
|
||||
export const parseAndRecallPrompts = async (metadata: unknown) => {
|
||||
const keysToRecall: (keyof typeof handlers)[] = [
|
||||
'positivePrompt',
|
||||
'negativePrompt',
|
||||
'sdxlPositiveStylePrompt',
|
||||
'sdxlNegativeStylePrompt',
|
||||
];
|
||||
const recalled = await recallKeys(keysToRecall, metadata);
|
||||
if (size(recalled) > 0) {
|
||||
const results = await Promise.allSettled([
|
||||
handlers.positivePrompt.parse(metadata).then((positivePrompt) => {
|
||||
if (!handlers.positivePrompt.recall) {
|
||||
return;
|
||||
}
|
||||
handlers.positivePrompt?.recall(positivePrompt);
|
||||
}),
|
||||
handlers.negativePrompt.parse(metadata).then((negativePrompt) => {
|
||||
if (!handlers.negativePrompt.recall) {
|
||||
return;
|
||||
}
|
||||
handlers.negativePrompt?.recall(negativePrompt);
|
||||
}),
|
||||
handlers.sdxlPositiveStylePrompt.parse(metadata).then((sdxlPositiveStylePrompt) => {
|
||||
if (!handlers.sdxlPositiveStylePrompt.recall) {
|
||||
return;
|
||||
}
|
||||
handlers.sdxlPositiveStylePrompt?.recall(sdxlPositiveStylePrompt);
|
||||
}),
|
||||
handlers.sdxlNegativeStylePrompt.parse(metadata).then((sdxlNegativeStylePrompt) => {
|
||||
if (!handlers.sdxlNegativeStylePrompt.recall) {
|
||||
return;
|
||||
}
|
||||
handlers.sdxlNegativeStylePrompt?.recall(sdxlNegativeStylePrompt);
|
||||
}),
|
||||
]);
|
||||
if (results.some((result) => result.status === 'fulfilled')) {
|
||||
parameterSetToast(t('metadata.allPrompts'));
|
||||
}
|
||||
};
|
||||
|
||||
export const parseAndRecallImageDimensions = async (metadata: unknown) => {
|
||||
const recalled = recallKeys(['width', 'height'], metadata);
|
||||
if (size(recalled) > 0) {
|
||||
const results = await Promise.allSettled([
|
||||
handlers.width.parse(metadata).then((width) => {
|
||||
if (!handlers.width.recall) {
|
||||
return;
|
||||
}
|
||||
handlers.width?.recall(width);
|
||||
}),
|
||||
handlers.height.parse(metadata).then((height) => {
|
||||
if (!handlers.height.recall) {
|
||||
return;
|
||||
}
|
||||
handlers.height?.recall(height);
|
||||
}),
|
||||
]);
|
||||
if (results.some((result) => result.status === 'fulfilled')) {
|
||||
parameterSetToast(t('metadata.imageDimensions'));
|
||||
}
|
||||
};
|
||||
@@ -413,20 +438,28 @@ export const parseAndRecallAllMetadata = async (
|
||||
toControlLayers: boolean,
|
||||
skip: (keyof typeof handlers)[] = []
|
||||
) => {
|
||||
const skipKeys = deepClone(skip);
|
||||
const skipKeys = skip ?? [];
|
||||
if (toControlLayers) {
|
||||
skipKeys.push(...TO_CONTROL_LAYERS_SKIP_KEYS);
|
||||
} else {
|
||||
skipKeys.push(...NOT_TO_CONTROL_LAYERS_SKIP_KEYS);
|
||||
}
|
||||
const results = await Promise.allSettled(
|
||||
objectKeys(handlers)
|
||||
.filter((key) => !skipKeys.includes(key))
|
||||
.map((key) => {
|
||||
const { parse, recall } = handlers[key];
|
||||
return parse(metadata).then((value) => {
|
||||
if (!recall) {
|
||||
return;
|
||||
}
|
||||
/* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */
|
||||
recall(value);
|
||||
});
|
||||
})
|
||||
);
|
||||
|
||||
// We may need to take some further action depending on what was recalled. For example, we need to disable SDXL prompt
|
||||
// concat if the negative or positive style prompt was set. Because the recalling is all async, we need to collect all
|
||||
// results
|
||||
const keysToRecall = objectKeys(handlers).filter((key) => !skipKeys.includes(key));
|
||||
const recalled = await recallKeys(keysToRecall, metadata);
|
||||
|
||||
if (size(recalled) > 0) {
|
||||
if (results.some((result) => result.status === 'fulfilled')) {
|
||||
toast({
|
||||
id: 'PARAMETER_SET',
|
||||
title: t('toast.parametersSet'),
|
||||
@@ -440,43 +473,3 @@ export const parseAndRecallAllMetadata = async (
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Recalls a set of keys from metadata.
|
||||
* Includes special handling for some metadata where recalling may have side effects. For example, recalling a "style"
|
||||
* prompt that is different from the "positive" or "negative" prompt should disable prompt concatenation.
|
||||
* @param keysToRecall An array of keys to recall.
|
||||
* @param metadata The metadata to recall from
|
||||
* @returns A promise that resolves to an object containing the recalled values.
|
||||
*/
|
||||
const recallKeys = async (keysToRecall: (keyof typeof handlers)[], metadata: unknown): Promise<RecallResults> => {
|
||||
const { dispatch } = getStore();
|
||||
const recalled: RecallResults = {};
|
||||
for (const key of keysToRecall) {
|
||||
const { parse, recall } = handlers[key];
|
||||
if (!recall) {
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
const value = await parse(metadata);
|
||||
/* @ts-expect-error The return type of parse and the input type of recall are guaranteed to be compatible. */
|
||||
await recall(value);
|
||||
recalled[key] = value;
|
||||
} catch {
|
||||
// no-op
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
(recalled['sdxlPositiveStylePrompt'] && recalled['sdxlPositiveStylePrompt'] !== recalled['positivePrompt']) ||
|
||||
(recalled['sdxlNegativeStylePrompt'] && recalled['sdxlNegativeStylePrompt'] !== recalled['negativePrompt'])
|
||||
) {
|
||||
// If we set the negative style prompt or positive style prompt, we should disable prompt concat
|
||||
dispatch(shouldConcatPromptsChanged(false));
|
||||
} else {
|
||||
// Otherwise, we should enable prompt concat
|
||||
dispatch(shouldConcatPromptsChanged(true));
|
||||
}
|
||||
|
||||
return recalled;
|
||||
};
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { getStore } from 'app/store/nanostores/store';
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { isModelIdentifier, isModelIdentifierV2 } from 'features/nodes/types/common';
|
||||
import type { ModelIdentifier } from 'features/nodes/types/v2/common';
|
||||
import { modelsApi } from 'services/api/endpoints/models';
|
||||
import type { AnyModelConfig, BaseModelType, ModelType } from 'services/api/types';
|
||||
|
||||
@@ -108,30 +107,19 @@ export const fetchModelConfigWithTypeGuard = async <T extends AnyModelConfig>(
|
||||
|
||||
/**
|
||||
* Fetches the model key from a model identifier. This includes fetching the key for MM1 format model identifiers.
|
||||
* @param modelIdentifier The model identifier. This can be a MM1 or MM2 identifier. In every case, we attempt to fetch
|
||||
* the model config from the server to ensure that the model identifier is valid and represents an installed model.
|
||||
* @param modelIdentifier The model identifier. The MM2 format `{key: string}` simply extracts the key. The MM1 format
|
||||
* `{model_name: string, base_model: BaseModelType}` must do a network request to fetch the key.
|
||||
* @param type The type of model to fetch. This is used to fetch the key for MM1 format model identifiers.
|
||||
* @param message An optional custom message to include in the error if the model identifier is invalid.
|
||||
* @returns A promise that resolves to the model key.
|
||||
* @throws {InvalidModelConfigError} If the model identifier is invalid.
|
||||
*/
|
||||
export const getModelKey = async (
|
||||
modelIdentifier: unknown | ModelIdentifierField | ModelIdentifier,
|
||||
type: ModelType,
|
||||
message?: string
|
||||
): Promise<string> => {
|
||||
export const getModelKey = async (modelIdentifier: unknown, type: ModelType, message?: string): Promise<string> => {
|
||||
if (isModelIdentifier(modelIdentifier)) {
|
||||
try {
|
||||
// Check if the model exists by key
|
||||
return (await fetchModelConfig(modelIdentifier.key)).key;
|
||||
} catch {
|
||||
// If not, fetch the model key by name and base model
|
||||
return (await fetchModelConfigByAttrs(modelIdentifier.name, modelIdentifier.base, type)).key;
|
||||
}
|
||||
} else if (isModelIdentifierV2(modelIdentifier)) {
|
||||
// Try by old-format model identifier
|
||||
return modelIdentifier.key;
|
||||
}
|
||||
if (isModelIdentifierV2(modelIdentifier)) {
|
||||
return (await fetchModelConfigByAttrs(modelIdentifier.model_name, modelIdentifier.base_model, type)).key;
|
||||
}
|
||||
// Nope, couldn't find it
|
||||
throw new InvalidModelConfigError(message || `Invalid model identifier: ${modelIdentifier}`);
|
||||
};
|
||||
|
||||
@@ -72,12 +72,10 @@ export const ModelEdit = ({ form }: Props) => {
|
||||
<FormLabel>{t('modelManager.baseModel')}</FormLabel>
|
||||
<BaseModelSelect control={form.control} />
|
||||
</FormControl>
|
||||
{data.type === 'main' && (
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={form.control} />
|
||||
</FormControl>
|
||||
)}
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
<FormLabel>{t('modelManager.variant')}</FormLabel>
|
||||
<ModelVariantSelect control={form.control} />
|
||||
</FormControl>
|
||||
{data.type === 'main' && data.format === 'checkpoint' && (
|
||||
<>
|
||||
<FormControl flexDir="column" alignItems="flex-start" gap={1}>
|
||||
|
||||
@@ -19,7 +19,7 @@ import {
|
||||
redo,
|
||||
undo,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { $flow, $needsFit } from 'features/nodes/store/reactFlowInstance';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import type { CSSProperties, MouseEvent } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
@@ -68,7 +68,6 @@ export const Flow = memo(() => {
|
||||
const nodes = useAppSelector((s) => s.nodes.present.nodes);
|
||||
const edges = useAppSelector((s) => s.nodes.present.edges);
|
||||
const viewport = useStore($viewport);
|
||||
const needsFit = useStore($needsFit);
|
||||
const mayUndo = useAppSelector((s) => s.nodes.past.length > 0);
|
||||
const mayRedo = useAppSelector((s) => s.nodes.future.length > 0);
|
||||
const shouldSnapToGrid = useAppSelector((s) => s.workflowSettings.shouldSnapToGrid);
|
||||
@@ -93,16 +92,8 @@ export const Flow = memo(() => {
|
||||
const onNodesChange: OnNodesChange = useCallback(
|
||||
(nodeChanges) => {
|
||||
dispatch(nodesChanged(nodeChanges));
|
||||
const flow = $flow.get();
|
||||
if (!flow) {
|
||||
return;
|
||||
}
|
||||
if (needsFit) {
|
||||
$needsFit.set(false);
|
||||
flow.fitView();
|
||||
}
|
||||
},
|
||||
[dispatch, needsFit]
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const onEdgesChange: OnEdgesChange = useCallback(
|
||||
|
||||
@@ -15,20 +15,27 @@ const ViewportControls = () => {
|
||||
const { t } = useTranslation();
|
||||
const { zoomIn, zoomOut, fitView } = useReactFlow();
|
||||
const dispatch = useAppDispatch();
|
||||
// const shouldShowFieldTypeLegend = useAppSelector(
|
||||
// (s) => s.nodes.present.shouldShowFieldTypeLegend
|
||||
// );
|
||||
const shouldShowMinimapPanel = useAppSelector((s) => s.workflowSettings.shouldShowMinimapPanel);
|
||||
|
||||
const handleClickedZoomIn = useCallback(() => {
|
||||
zoomIn({ duration: 300 });
|
||||
zoomIn();
|
||||
}, [zoomIn]);
|
||||
|
||||
const handleClickedZoomOut = useCallback(() => {
|
||||
zoomOut({ duration: 300 });
|
||||
zoomOut();
|
||||
}, [zoomOut]);
|
||||
|
||||
const handleClickedFitView = useCallback(() => {
|
||||
fitView({ duration: 300 });
|
||||
fitView();
|
||||
}, [fitView]);
|
||||
|
||||
// const handleClickedToggleFieldTypeLegend = useCallback(() => {
|
||||
// dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend));
|
||||
// }, [shouldShowFieldTypeLegend, dispatch]);
|
||||
|
||||
const handleClickedToggleMiniMapPanel = useCallback(() => {
|
||||
dispatch(shouldShowMinimapPanelChanged(!shouldShowMinimapPanel));
|
||||
}, [shouldShowMinimapPanel, dispatch]);
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import 'reactflow/dist/style.css';
|
||||
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import QueueControls from 'features/queue/components/QueueControls';
|
||||
import ResizeHandle from 'features/ui/components/tabs/ResizeHandle';
|
||||
import { usePanelStorage } from 'features/ui/hooks/usePanelStorage';
|
||||
@@ -19,8 +21,14 @@ import { WorkflowName } from './WorkflowName';
|
||||
|
||||
const panelGroupStyles: CSSProperties = { height: '100%', width: '100%' };
|
||||
|
||||
const selector = createMemoizedSelector(selectWorkflowSlice, (workflow) => {
|
||||
return {
|
||||
mode: workflow.mode,
|
||||
};
|
||||
});
|
||||
|
||||
const NodeEditorPanelGroup = () => {
|
||||
const mode = useAppSelector((s) => s.workflow.mode);
|
||||
const { mode } = useAppSelector(selector);
|
||||
const panelGroupRef = useRef<ImperativePanelGroupHandle>(null);
|
||||
const panelStorage = usePanelStorage();
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user