feat(nodes,ui): fully migrate queue to session_processor

This commit is contained in:
psychedelicious
2023-09-17 18:25:46 +10:00
parent e1b8874bc5
commit 7a1fe7548b
60 changed files with 1479 additions and 1281 deletions

View File

@@ -1,5 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import logging
import sqlite3
from logging import Logger
@@ -11,9 +12,8 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
from invokeai.app.services.images import ImageService, ImageServiceDependencies
from invokeai.app.services.resource_name import SimpleNameService
from invokeai.app.services.session_execution.session_execution_default import DefaultSessionExecutionService
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.urls import LocalUrlService
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.version.invokeai_version import __version__
@@ -59,6 +59,7 @@ class ApiDependencies:
@staticmethod
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
logger.setLevel(logging.DEBUG)
logger.info(f"InvokeAI version {__version__}")
logger.info(f"Root directory = {str(config.root_path)}")
logger.debug(f"Internet connectivity is {config.internet_available}")
@@ -70,8 +71,8 @@ class ApiDependencies:
# TODO: build a file/path manager?
db_path = config.db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
# db_location = str(db_path)
db_location = ":memory:"
db_location = str(db_path)
# db_location = ":memory:"
logger.info(f"Using database at {db_location}")
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
@@ -140,7 +141,6 @@ class ApiDependencies:
logger=logger,
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
session_processor=DefaultSessionProcessor(),
session_execution=DefaultSessionExecutionService(),
)
create_system_graphs(services.graph_library)
@@ -160,4 +160,4 @@ class ApiDependencies:
print("SHUTTING DOWN")
if ApiDependencies.invoker:
print("SHUTTING DOWN INVOKER")
ApiDependencies.invoker.stop_service()
ApiDependencies.invoker.stop()

View File

@@ -23,7 +23,7 @@ class FastAPIEventService(EventServiceBase):
super().__init__()
def stop_service(self, *args, **kwargs):
def stop(self, *args, **kwargs):
self.__stop_event.set()
self.__queue.put(None)

View File

@@ -3,10 +3,11 @@ from typing import Optional
from fastapi import Body, Path, Query
from fastapi.routing import APIRouter
from invokeai.app.services.session_execution.session_execution_common import SessionExecutionStatusResult
from invokeai.app.services.session_queue.session_queue_common import (
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatusResult
from invokeai.app.services.session_queue.session_queue_common import ( # CancelByBatchIDsResult,
QUEUE_ITEM_STATUS,
Batch,
BatchStatusResult,
CancelByBatchIDsResult,
ClearResult,
EnqueueBatchResult,
@@ -24,7 +25,9 @@ from ..dependencies import ApiDependencies
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
class SessionQueueAndExecutionStatusResult(SessionQueueStatusResult, SessionExecutionStatusResult):
class SessionQueueAndProcessorStatusResult(SessionQueueStatusResult, SessionProcessorStatusResult):
"""The overall status of session queue and processor"""
pass
@@ -84,42 +87,25 @@ async def list_queue_items(
@session_queue_router.put(
"/{queue_id}/start",
operation_id="start",
"/{queue_id}/resume",
operation_id="resume",
)
async def start(
async def resume(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> None:
"""Starts session queue execution"""
return ApiDependencies.invoker.services.session_execution.start(
queue_id=queue_id,
)
"""Resumes session processor"""
return ApiDependencies.invoker.services.session_processor.resume()
@session_queue_router.put(
"/{queue_id}/stop",
operation_id="stop",
"/{queue_id}/pause",
operation_id="pause",
)
async def stop(
async def pause(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> None:
"""Stops session queue execution, waiting for the currently executing session to finish"""
return ApiDependencies.invoker.services.session_execution.stop(
queue_id=queue_id,
)
@session_queue_router.put(
"/{queue_id}/cancel",
operation_id="cancel",
)
async def cancel(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> None:
"""Stops session queue execution, immediately canceling the currently-executing session"""
return ApiDependencies.invoker.services.session_execution.cancel(
queue_id=queue_id,
)
"""Pauses session processor"""
return ApiDependencies.invoker.services.session_processor.pause()
@session_queue_router.put(
@@ -132,13 +118,6 @@ async def cancel_by_batch_ids(
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
) -> CancelByBatchIDsResult:
"""Immediately cancels all queue items from the given batch ids"""
current = ApiDependencies.invoker.services.session_execution.get_current(
queue_id=queue_id,
)
if current is not None and current.batch_id in batch_ids:
ApiDependencies.invoker.services.session_execution.cancel(
queue_id=queue_id,
)
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
@@ -153,12 +132,11 @@ async def clear(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> ClearResult:
"""Clears the queue entirely, immediately canceling the currently-executing session"""
ApiDependencies.invoker.services.session_execution.cancel(
queue_id=queue_id,
)
return ApiDependencies.invoker.services.session_queue.clear(
queue_id=queue_id,
)
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
if queue_item is not None:
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
return clear_result
@session_queue_router.put(
@@ -172,62 +150,78 @@ async def prune(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> PruneResult:
"""Prunes all completed or errored queue items"""
return ApiDependencies.invoker.services.session_queue.prune(
queue_id=queue_id,
)
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
@session_queue_router.get(
"/{queue_id}/current",
operation_id="current",
operation_id="get_current_queue_item",
responses={
200: {"model": Optional[SessionQueueItem]},
},
)
async def current(
async def get_current_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the currently execution queue item"""
return ApiDependencies.invoker.services.session_execution.get_current(
queue_id=queue_id,
)
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
@session_queue_router.get(
"/{queue_id}/peek",
operation_id="peek",
"/{queue_id}/next",
operation_id="get_next_queue_item",
responses={
200: {"model": Optional[SessionQueueItem]},
},
)
async def peek(
async def get_next_queue_item(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> Optional[SessionQueueItem]:
"""Gets the next queue item, without executing it"""
return ApiDependencies.invoker.services.session_queue.peek(
queue_id=queue_id,
)
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
@session_queue_router.get(
"/{queue_id}/status",
operation_id="get_status",
operation_id="get_queue_status",
responses={
200: {"model": SessionQueueAndExecutionStatusResult},
200: {"model": SessionQueueStatusResult},
},
)
async def get_status(
async def get_queue_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionQueueAndExecutionStatusResult:
) -> SessionQueueStatusResult:
"""Gets the status of the session queue"""
queue_status = ApiDependencies.invoker.services.session_queue.get_status(
queue_id=queue_id,
)
execution_status = ApiDependencies.invoker.services.session_execution.get_status(
queue_id=queue_id,
)
return ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
return SessionQueueAndExecutionStatusResult(**queue_status.dict(), **execution_status.dict())
@session_queue_router.get(
"/{queue_id}/processor/status",
operation_id="get_processor_status",
responses={
200: {"model": SessionProcessorStatusResult},
},
)
async def get_processor_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
) -> SessionProcessorStatusResult:
"""Gets the status of the session queue"""
return ApiDependencies.invoker.services.session_processor.get_status()
@session_queue_router.get(
"/{queue_id}/b/{batch_id}/status",
operation_id="get_batch_status",
responses={
200: {"model": BatchStatusResult},
},
)
async def get_batch_status(
queue_id: str = Path(description="The queue id to perform this operation on"),
batch_id: str = Path(description="The batch to get the status of"),
) -> BatchStatusResult:
"""Gets the status of the session queue"""
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
@session_queue_router.get(
@@ -242,7 +236,7 @@ async def get_queue_item(
item_id: str = Path(description="The queue item to get"),
) -> SessionQueueItem:
"""Gets a queue item"""
return ApiDependencies.invoker.services.session_queue.get_queue_item(queue_id=queue_id, item_id=item_id)
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
@session_queue_router.put(
@@ -257,27 +251,5 @@ async def cancel_queue_item(
item_id: str = Path(description="The queue item to cancel"),
) -> SessionQueueItem:
"""Deletes a queue item"""
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(queue_id=queue_id, item_id=item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
return ApiDependencies.invoker.services.session_queue.set_queue_item_status(
queue_id=queue_id, item_id=item_id, status="canceled"
)
return queue_item
@session_queue_router.put(
"/{queue_id}/start_processor",
operation_id="start_processor",
)
async def start_processor() -> None:
"""Deletes a queue item"""
ApiDependencies.invoker.services.session_processor.start()
@session_queue_router.put(
"/{queue_id}/stop_processor",
operation_id="stop_processor",
)
async def stop_processor() -> None:
"""Deletes a queue item"""
ApiDependencies.invoker.services.session_processor.stop()
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)

View File

@@ -4,7 +4,7 @@ from typing import Any, Optional
from invokeai.app.models.image import ProgressImage
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem
from invokeai.app.util.misc import get_timestamp
@@ -195,8 +195,20 @@ class EventServiceBase:
),
)
def emit_session_canceled(
self,
graph_execution_state_id: str,
) -> None:
"""Emitted when a session is canceled"""
self.__emit_session_event(
event_name="session_canceled",
payload=dict(
graph_execution_state_id=graph_execution_state_id,
),
)
def emit_queue_item_status_changed(self, session_queue_item: SessionQueueItem) -> None:
"""Emitted when a queue item is status_changed"""
"""Emitted when a queue item's status changes"""
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload=dict(
@@ -207,3 +219,17 @@ class EventServiceBase:
status=session_queue_item.status,
),
)
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
"""Emitted when a batch is enqueued"""
self.__emit_queue_event(
event_name="batch_enqueued",
payload=enqueue_result.dict(),
)
def emit_queue_cleared(self, queue_id: str) -> None:
"""Emitted when the queue is cleared"""
self.__emit_queue_event(
event_name="queue_cleared",
payload=dict(queue_id=queue_id),
)

View File

@@ -18,7 +18,6 @@ if TYPE_CHECKING:
from invokeai.app.services.item_storage import ItemStorageABC
from invokeai.app.services.latent_storage import LatentsStorageBase
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
from invokeai.app.services.session_execution.session_execution_base import SessionExecutionServiceBase
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
@@ -42,7 +41,6 @@ class InvocationServices:
queue: "InvocationQueueABC"
session_queue: "SessionQueueBase"
session_processor: "SessionProcessorBase"
session_execution: "SessionExecutionServiceBase"
def __init__(
self,
@@ -61,7 +59,6 @@ class InvocationServices:
queue: "InvocationQueueABC",
session_queue: "SessionQueueBase",
session_processor: "SessionProcessorBase",
session_execution: "SessionExecutionServiceBase",
):
self.board_images = board_images
self.boards = boards
@@ -78,4 +75,3 @@ class InvocationServices:
self.queue = queue
self.session_queue = session_queue
self.session_processor = session_processor
self.session_execution = session_execution

View File

@@ -52,14 +52,14 @@ class Invoker:
self.services.queue.cancel(graph_execution_state_id)
def __start_service(self, service) -> None:
# Call start_service() method on any services that have it
start_op = getattr(service, "start_service", None)
# Call start() method on any services that have it
start_op = getattr(service, "start", None)
if callable(start_op):
start_op(self)
def __stop_service(self, service) -> None:
# Call stop_service() method on any services that have it
stop_op = getattr(service, "stop_service", None)
# Call stop() method on any services that have it
stop_op = getattr(service, "stop", None)
if callable(stop_op):
stop_op(self)
@@ -68,7 +68,7 @@ class Invoker:
for service in vars(self.services):
self.__start_service(getattr(self.services, service))
def stop_service(self) -> None:
def stop(self) -> None:
"""Stops the invoker. A new invoker will have to be created to execute further."""
# First stop all services
for service in vars(self.services):

View File

@@ -17,7 +17,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
__invoker: Invoker
__threadLimit: BoundedSemaphore
def start_service(self, invoker) -> None:
def start(self, invoker) -> None:
# if we do want multithreading at some point, we could make this configurable
self.__threadLimit = BoundedSemaphore(1)
self.__invoker = invoker
@@ -30,7 +30,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
self.__invoker_thread.start()
def stop_service(self, *args, **kwargs) -> None:
def stop(self, *args, **kwargs) -> None:
self.__stop_event.set()
def __process(self, stop_event: Event):

View File

@@ -1,53 +0,0 @@
from abc import ABC, abstractmethod
from typing import Optional
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_execution.session_execution_common import SessionExecutionStatusResult
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
class SessionExecutionServiceBase(ABC):
    """Abstract interface for driving execution of the session queue."""

    @abstractmethod
    def start_service(self, invoker: Invoker) -> None:
        """Service startup callback; receives the application invoker."""
        pass

    @abstractmethod
    def start(self, queue_id: str) -> None:
        """Begin executing sessions from the queue."""
        pass

    @abstractmethod
    def stop(self, queue_id: str) -> None:
        """Halt execution once the in-flight session has completed."""
        pass

    @abstractmethod
    def cancel(self, queue_id: str) -> None:
        """Halt execution immediately, canceling the in-flight session."""
        pass

    @abstractmethod
    def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
        """Return the queue item currently being executed, if any."""
        pass

    @abstractmethod
    def get_status(self, queue_id: str) -> SessionExecutionStatusResult:
        """Return the execution service's status."""
        pass

View File

@@ -1,6 +0,0 @@
from pydantic import BaseModel, Field
class SessionExecutionStatusResult(BaseModel):
    """Status snapshot of the session execution service."""

    started: bool = Field(..., description="Whether the session queue is running")
    stop_after_current: bool = Field(..., description="Whether the session queue is pending a stop")

View File

@@ -1,107 +0,0 @@
from typing import Optional
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_execution.session_execution_base import SessionExecutionServiceBase
from invokeai.app.services.session_execution.session_execution_common import SessionExecutionStatusResult
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
# Implementation of SessionExecutionServiceBase (deleted by this commit): drives
# queue execution by subscribing to session events via fastapi-events.
# NOTE(review): indentation was lost when this diff was captured; code lines
# below are reproduced verbatim (flattened) — block nesting must be confirmed
# against the original file before reuse.
class DefaultSessionExecutionService(SessionExecutionServiceBase):
# Execution state: the invoker handle, the in-flight queue item, whether
# execution is running, and whether a stop is pending after the current item.
def __init__(self) -> None:
self._invoker: Invoker
self._current: Optional[SessionQueueItem] = None
self._started: bool = False
self._stop_after_current = False
# Startup callback: keep the invoker and subscribe to session events.
def start_service(self, invoker: Invoker) -> None:
self._invoker = invoker
local_handler.register(event_name=EventServiceBase.session_event, _func=self._on_event)
# Event dispatcher: a completed graph or any of the three error events
# finalizes the current queue item (the bool argument is the error flag).
async def _on_event(self, event: Event) -> Event:
event_name = event[1]["event"]
match event_name:
case "graph_execution_state_complete":
await self._handle_complete_event(event, False)
case "invocation_error":
await self._handle_complete_event(event, True)
case "session_retrieval_error":
await self._handle_complete_event(event, True)
case "invocation_retrieval_error":
await self._handle_complete_event(event, True)
return event
# Finalize a finished session: look up its queue item by session id, mark it
# failed/completed (unless already failed), emit the status change, honor a
# pending stop, and — if still started — move on to the next queue item.
async def _handle_complete_event(self, event: Event, err: bool) -> None:
data = event[1]["data"]
queue_item = self._invoker.services.session_queue.get_queue_item_by_session_id(data["graph_execution_state_id"])
# Sessions are marked complete when they have an error, so we get an `invocation_error`
# followed by a `graph_execution_state_complete`. Don't mark queue items complete if
# they are already marked error.
if queue_item.status != "failed":
queue_item = self._invoker.services.session_queue.set_queue_item_status(
queue_id=queue_item.queue_id, item_id=queue_item.item_id, status="failed" if err else "completed"
)
self._invoker.services.events.emit_queue_item_status_changed(queue_item)
if self._stop_after_current:
self._stop_after_current = False
self._started = False
self._current = None
if self._started:
self.invoke_next(queue_id=queue_item.queue_id)
# Emit a status-changed event for the in-flight item, if there is one.
def _emit_queue_item_status(self) -> None:
if self._current is None:
return
self._invoker.services.events.emit_queue_item_status_changed(self._current)
# Dequeue and stage the next session; resets execution state when the
# queue is empty.
def invoke_next(self, queue_id: str) -> None:
# do not invoke if already invoking
if self._current:
return
queue_item = self._invoker.services.session_queue.dequeue()
if queue_item is None:
# queue empty
self._current = None
self._started = False
self._stop_after_current = False
return
self._current = queue_item
# execute the session
self._invoker.services.graph_execution_manager.set(queue_item.session)
self._emit_queue_item_status()
# NOTE(review): the actual invocation call is commented out as captured —
# invoke_next only stages the session here; confirm against file history.
# self._invoker.invoke(self._current.session, invoke_all=True)
# Begin execution unless a stop is already pending.
def start(
self,
queue_id: str,
) -> None:
if not self._stop_after_current:
self._started = True
self.invoke_next(queue_id=queue_id)
# Graceful stop: let the current session finish, then idle.
def stop(self, queue_id: str) -> None:
self._started = False
self._stop_after_current = True
# Hard stop: cancel the running invocation, mark the item canceled, emit
# the change, and reset all execution state.
def cancel(self, queue_id: str) -> None:
if self._current is not None:
self._invoker.services.queue.cancel(self._current.session_id)
self._current = self._invoker.services.session_queue.set_queue_item_status(
queue_id=self._current.queue_id, item_id=self._current.item_id, status="canceled"
)
self._emit_queue_item_status()
self._current = None
self._started = False
self._stop_after_current = False
# Accessors: the in-flight item and the started / stop-pending flags.
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
return self._current
def get_status(self, queue_id: str) -> SessionExecutionStatusResult:
return SessionExecutionStatusResult(started=self._started, stop_after_current=self._stop_after_current)

View File

@@ -1,21 +1,28 @@
from abc import ABC
from typing import Optional
from abc import ABC, abstractmethod
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatusResult
class SessionProcessorABC(ABC):
def start(self) -> None:
class SessionProcessorBase(ABC):
"""
Base class for session processor.
The session processor is responsible for executing sessions. It runs a simple polling loop,
checking the session queue for new sessions to execute. It must coordinate with the
invocation queue to ensure only one session is executing at a time.
"""
@abstractmethod
def resume(self) -> None:
"""Starts or resumes the session processor"""
pass
def stop(self) -> None:
@abstractmethod
def pause(self) -> None:
"""Pauses the session processor"""
pass
def poll_now(self) -> None:
pass
def get_current(self) -> Optional[SessionQueueItem]:
pass
def clear_current(self) -> None:
@abstractmethod
def get_status(self) -> SessionProcessorStatusResult:
"""Gets the status of the session processor"""
pass

View File

@@ -1,8 +1,7 @@
POLLING_INTERVAL = 1
THREAD_LIMIT = 1
FINISHED_SESSION_EVENTS = [
"graph_execution_state_complete",
"invocation_error",
"session_retrieval_error",
"invocation_retrieval_error",
]
from pydantic import BaseModel, Field
class SessionProcessorStatusResult(BaseModel):
    """Status snapshot of the session processor."""

    is_started: bool = Field(description="Whether the session processor is started")
    is_processing: bool = Field(description="Whether a session is being processed")
    # Fixed description: was "Whether processor is pending stopping" — ungrammatical
    # and inconsistent with the sibling model's "pending a stop" phrasing.
    is_stop_pending: bool = Field(description="Whether the session processor is pending a stop")

View File

@@ -1,4 +1,6 @@
from threading import BoundedSemaphore, Event as ThreadEvent, Thread
from threading import BoundedSemaphore
from threading import Event as ThreadEvent
from threading import Thread
from typing import Optional
from fastapi_events.handlers.local import local_handler
@@ -8,99 +10,128 @@ from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
from ..invoker import Invoker
from .session_processor_base import SessionProcessorABC
from .session_processor_common import FINISHED_SESSION_EVENTS, POLLING_INTERVAL, THREAD_LIMIT
from .session_processor_base import SessionProcessorBase
from .session_processor_common import SessionProcessorStatusResult
POLLING_INTERVAL = 1
THREAD_LIMIT = 1
class DefaultSessionProcessor(SessionProcessorABC):
def start_service(self, invoker: Invoker) -> None:
class DefaultSessionProcessor(SessionProcessorBase):
def start(self, invoker: Invoker) -> None:
self.__invoker: Invoker = invoker
self.__queue_item: Optional[SessionQueueItem] = None
self.__stop_event = ThreadEvent()
# When a session finishes, we need to poll the queue immediately.
# Because we must wait both for the current item to finish and for the
# queue to become non-empty, two separate events are needed.
self.__poll_now_busy_event = ThreadEvent()
self.__poll_now_queue_event = ThreadEvent()
self.__poll_now_event = ThreadEvent()
local_handler.register(event_name=EventServiceBase.session_event, _func=self._on_session_event)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
self.__threadLimit = BoundedSemaphore(THREAD_LIMIT)
local_handler.register(event_name=EventServiceBase.session_event, _func=self._on_event)
self._start_thread()
def stop_service(self, *args, **kwargs) -> None:
def stop(self, *args, **kwargs) -> None:
self.__stop_event.set()
def _poll_now(self) -> None:
self.__poll_now_busy_event.set()
self.__poll_now_queue_event.set()
self.__poll_now_event.set()
def _start_thread(self) -> None:
# threads only live once, so we need to create a new one whenever we start the processor
# threads only live once, so we need to create a new one whenever we start the session processor
self.__thread = Thread(
name="session_processor",
target=self.__process,
kwargs=dict(
stop_event=self.__stop_event,
poll_now_busy_event=self.__poll_now_busy_event,
poll_now_queue_event=self.__poll_now_queue_event,
poll_now_event=self.__poll_now_event,
),
)
self.__thread.start()
async def _on_event(self, event: FastAPIEvent) -> None:
async def _on_session_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
if event_name in FINISHED_SESSION_EVENTS:
if event_name in [
"graph_execution_state_complete",
"invocation_error",
"session_retrieval_error",
"invocation_retrieval_error",
] or (
event_name == "session_canceled"
and self.__queue_item is not None
and self.__queue_item.session_id == event[1]["data"]["graph_execution_state_id"]
):
self.__queue_item = None
self._poll_now()
def stop(self) -> None:
self.__stop_event.set()
async def _on_queue_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
if event_name == "batch_enqueued":
self._poll_now()
if event_name == "queue_cleared":
self.__queue_item = None
self._poll_now()
def start(self) -> None:
if self.__thread.is_alive():
def _is_started(self) -> bool:
return self.__thread.is_alive()
def _is_processing(self) -> bool:
return self.__queue_item is not None
def _is_stop_pending(self) -> bool:
return self.__stop_event.is_set()
def get_status(self) -> SessionProcessorStatusResult:
return SessionProcessorStatusResult(
is_started=self._is_started(),
is_processing=self._is_processing(),
is_stop_pending=self._is_stop_pending(),
)
def resume(self) -> None:
if self._is_started():
return
self.__stop_event.clear()
self._start_thread()
def poll_now(self) -> None:
self._poll_now()
def get_current(self) -> Optional[SessionQueueItem]:
return self.__queue_item
def clear_current(self) -> None:
self.__queue_item = None
def pause(self) -> None:
self.__stop_event.set()
def __process(
self,
stop_event: ThreadEvent,
poll_now_busy_event: ThreadEvent,
poll_now_queue_event: ThreadEvent,
poll_now_event: ThreadEvent,
):
try:
self.__threadLimit.acquire()
queue_item: Optional[SessionQueueItem] = None
self.__invoker.services.logger
while not stop_event.is_set():
poll_now_busy_event.clear()
poll_now_queue_event.clear()
poll_now_event.clear()
# do not dequeue if there is already a session running
if self.__queue_item is not None:
poll_now_busy_event.wait(POLLING_INTERVAL)
continue
# get next queue item
self.__queue_item = self.__invoker.services.session_queue.dequeue()
if self.__queue_item is None:
poll_now_queue_event.wait(POLLING_INTERVAL)
continue
queue_item = self.__invoker.services.session_queue.dequeue()
self.__invoker.services.graph_execution_manager.set(self.__queue_item.session)
self.__invoker.invoke(self.__queue_item.session, invoke_all=True)
if queue_item is not None:
# TODO: Why isn't the log level specified in dependencies.py working?
# Within the thread, it is always INFO and `logger.debug()` doesn't display.
# self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
print(f"Executing queue item {queue_item.item_id}")
self.__queue_item = queue_item
self.__invoker.services.graph_execution_manager.set(queue_item.session)
self.__invoker.invoke(queue_item.session, invoke_all=True)
queue_item = None
if queue_item is None:
# self.__invoker.services.logger.debug("Waiting for next polling interval or event")
print("Waiting for next polling interval or event")
poll_now_event.wait(POLLING_INTERVAL)
continue
except Exception:
pass
finally:
stop_event.clear()
poll_now_event.clear()
self.__queue_item = None
self.__threadLimit.release()

View File

@@ -2,11 +2,12 @@ from abc import ABC, abstractmethod
from typing import Optional
from invokeai.app.services.graph import Graph
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_common import (
QUEUE_ITEM_STATUS,
Batch,
BatchStatusResult,
CancelByBatchIDsResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
EnqueueGraphResult,
@@ -16,15 +17,16 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItem,
SessionQueueItemDTO,
SessionQueueStatusResult,
SetManyQueueItemStatusResult,
)
from invokeai.app.services.shared.models import CursorPaginatedResults
class SessionQueueBase(ABC):
"""Base class for session queue"""
@abstractmethod
def start_service(self, invoker: Invoker) -> None:
"""Startup callback for the SessionQueue service"""
def dequeue(self) -> Optional[SessionQueueItem]:
"""Dequeues the next session queue item."""
pass
@abstractmethod
@@ -38,13 +40,13 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def dequeue(self) -> Optional[SessionQueueItem]:
"""Dequeues the next session queue item, returning it if one is available."""
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
"""Gets the currently-executing session queue item"""
pass
@abstractmethod
def peek(self, queue_id: str) -> Optional[SessionQueueItem]:
"""Peeks at the next session queue item, returning it if one is available."""
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
"""Gets the next session queue item (does not dequeue it)"""
pass
@abstractmethod
@@ -67,14 +69,29 @@ class SessionQueueBase(ABC):
"""Checks if the queue is empty"""
pass
@abstractmethod
def get_queue_status(self, queue_id: str) -> SessionQueueStatusResult:
"""Gets the status of the queue"""
pass
@abstractmethod
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatusResult:
"""Gets the status of a batch"""
pass
@abstractmethod
def cancel_queue_item(self, item_id: str) -> SessionQueueItem:
"""Cancels a session queue item"""
pass
@abstractmethod
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
"""Cancels all queue items with matching batch IDs"""
pass
@abstractmethod
def get_status(self, queue_id: str) -> SessionQueueStatusResult:
"""Gets the number of queue items with each status"""
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
"""Cancels all queue items with matching queue ID"""
pass
@abstractmethod
@@ -90,7 +107,7 @@ class SessionQueueBase(ABC):
pass
@abstractmethod
def get_queue_item(self, queue_id: str, item_id: str) -> SessionQueueItem:
def get_queue_item(self, item_id: str) -> SessionQueueItem:
"""Gets a session queue item by ID"""
pass
@@ -98,20 +115,3 @@ class SessionQueueBase(ABC):
def get_queue_item_by_session_id(self, session_id: str) -> SessionQueueItem:
"""Gets a queue item by session ID"""
pass
@abstractmethod
def set_queue_item_status(self, queue_id: str, item_id: str, status: QUEUE_ITEM_STATUS) -> SessionQueueItem:
"""Sets the status of a session queue item"""
pass
@abstractmethod
def set_many_queue_item_status(
self, queue_id: str, item_ids: list[str], status: QUEUE_ITEM_STATUS
) -> SetManyQueueItemStatusResult:
"""Sets the status of a session queue item"""
pass
@abstractmethod
def delete_queue_item(self, queue_id: str, item_id: str) -> SessionQueueItem:
"""Deletes a session queue item by ID"""
pass

View File

@@ -167,6 +167,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
@classmethod
@@ -237,15 +238,21 @@ class SessionQueueStatusResult(BaseModel):
failed: int = Field(..., description="Number of queue items with status 'error'")
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
total: int = Field(..., description="Total number of queue items")
max_queue_size: int = Field(..., description="Maximum number of queue items allowed")
class SetManyQueueItemStatusResult(BaseModel):
item_ids: list[str] = Field(..., description="The queue item IDs that were updated")
status: QUEUE_ITEM_STATUS = Field(..., description="The new status of the queue items")
class BatchStatusResult(BaseModel):
queue_id: str = Field(..., description="The ID of the queue")
batch_id: str = Field(..., description="The ID of the batch")
pending: int = Field(..., description="Number of queue items with status 'pending'")
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
completed: int = Field(..., description="Number of queue items with status 'complete'")
failed: int = Field(..., description="Number of queue items with status 'error'")
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
total: int = Field(..., description="Total number of queue items")
class EnqueueBatchResult(BaseModel):
queue_id: str = Field(description="The ID of the queue")
enqueued: int = Field(description="The total number of queue items enqueued")
requested: int = Field(description="The total number of queue items requested to be enqueued")
batch: Batch = Field(description="The batch that was enqueued")
@@ -273,9 +280,17 @@ class PruneResult(ClearResult):
class CancelByBatchIDsResult(BaseModel):
"""Result of canceling by list of batch ids"""
canceled: int = Field(..., description="Number of queue items canceled")
class CancelByQueueIDResult(CancelByBatchIDsResult):
"""Result of canceling by queue id"""
pass
class IsEmptyResult(BaseModel):
"""Result of checking if the session queue is empty"""

View File

@@ -2,6 +2,10 @@ import sqlite3
import threading
from typing import Optional, Union, cast
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event as FastAPIEvent
from invokeai.app.services.events import EventServiceBase
from invokeai.app.services.graph import Graph
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
@@ -9,7 +13,9 @@ from invokeai.app.services.session_queue.session_queue_common import (
DEFAULT_QUEUE_ID,
QUEUE_ITEM_STATUS,
Batch,
BatchStatusResult,
CancelByBatchIDsResult,
CancelByQueueIDResult,
ClearResult,
EnqueueBatchResult,
EnqueueGraphResult,
@@ -20,7 +26,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
SessionQueueItemDTO,
SessionQueueItemNotFoundError,
SessionQueueStatusResult,
SetManyQueueItemStatusResult,
calc_session_count,
prepare_values_to_insert,
)
@@ -28,119 +33,203 @@ from invokeai.app.services.shared.models import CursorPaginatedResults
class SqliteSessionQueue(SessionQueueBase):
_invoker: Invoker
_conn: sqlite3.Connection
_cursor: sqlite3.Cursor
_lock: threading.Lock
__invoker: Invoker
__conn: sqlite3.Connection
__cursor: sqlite3.Cursor
__lock: threading.Lock
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)
local_handler.register(event_name=EventServiceBase.session_event, _func=self._on_session_event)
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
super().__init__()
self._conn = conn
self.__conn = conn
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
self._conn.row_factory = sqlite3.Row
self._cursor = self._conn.cursor()
self._lock = lock
self.__conn.row_factory = sqlite3.Row
self.__cursor = self.__conn.cursor()
self.__lock = lock
self._create_tables()
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
return event[1]["event"] in match_in
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
event_name = event[1]["event"]
match event_name:
# successful completion events
case "graph_execution_state_complete":
await self._handle_complete_event(event)
# error events
case "invocation_error":
await self._handle_error_event(event)
case "session_retrieval_error":
await self._handle_error_event(event)
case "invocation_retrieval_error":
await self._handle_error_event(event)
# canceled events
case "session_canceled":
await self._handle_cancel_event(event)
return event
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
try:
self._lock.acquire()
self._create_tables()
self._conn.commit()
finally:
self._lock.release()
session_id = event[1]["data"]["graph_execution_state_id"]
# When a queue item has an error, we get an error event, then a completed event.
# Mark the queue item completed only if it isn't already marked completed, e.g.
# by a previously-handled error event.
queue_item = self.get_queue_item_by_session_id(session_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
except SessionQueueItemNotFoundError:
return
async def _handle_error_event(self, event: FastAPIEvent) -> None:
try:
session_id = event[1]["data"]["graph_execution_state_id"]
error = event[1]["data"]["error"]
queue_item = self.get_queue_item_by_session_id(session_id)
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
except SessionQueueItemNotFoundError:
return
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
try:
session_id = event[1]["data"]["graph_execution_state_id"]
queue_item = self.get_queue_item_by_session_id(session_id)
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
except SessionQueueItemNotFoundError:
return
def _create_tables(self) -> None:
self._cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
item_id TEXT NOT NULL PRIMARY KEY, -- the unique identifier of this queue item
order_id INTEGER NOT NULL, -- used for ordering, cursor pagination
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
field_values TEXT, -- NULL if no values are associated with this queue item
session TEXT NOT NULL, -- the session to be executed
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'complete', 'error', 'canceled'
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
error TEXT, -- any errors associated with this queue item
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
completed_at DATETIME -- completed items are cleaned up on application startup
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
);
"""
)
"""Creates the session queue tables, indicies, and triggers"""
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
CREATE TABLE IF NOT EXISTS session_queue (
item_id TEXT NOT NULL PRIMARY KEY, -- the unique identifier of this queue item
order_id INTEGER NOT NULL, -- used for ordering, cursor pagination
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
field_values TEXT, -- NULL if no values are associated with this queue item
session TEXT NOT NULL, -- the session to be executed
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'complete', 'error', 'canceled'
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
error TEXT, -- any errors associated with this queue item
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
started_at DATETIME, -- updated via trigger
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
);
"""
)
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
"""
)
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_order_id ON session_queue(order_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_order_id ON session_queue(order_id);
"""
)
self._cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
"""
)
self._cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
"""
)
self.__cursor.execute(
"""--sql
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
"""
)
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'completed'
OR NEW.status = 'failed'
OR NEW.status = 'canceled'
BEGIN
UPDATE session_queue
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'completed'
OR NEW.status = 'failed'
OR NEW.status = 'canceled'
BEGIN
UPDATE session_queue
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self._cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
AFTER UPDATE
ON session_queue FOR EACH ROW
BEGIN
UPDATE session_queue
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = old.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
AFTER UPDATE OF status ON session_queue
FOR EACH ROW
WHEN
NEW.status = 'in_progress'
BEGIN
UPDATE session_queue
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = NEW.item_id;
END;
"""
)
self.__cursor.execute(
"""--sql
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
AFTER UPDATE
ON session_queue FOR EACH ROW
BEGIN
UPDATE session_queue
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
WHERE item_id = old.item_id;
END;
"""
)
self.__conn.commit()
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled.
This is necessary because the invoker may have been killed while processing a queue item.
"""
try:
self._lock.acquire()
self._cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -148,13 +237,14 @@ class SqliteSessionQueue(SessionQueueBase):
"""
)
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
self.__lock.release()
def _get_current_queue_size(self, queue_id: str) -> int:
self._cursor.execute(
"""Gets the current number of pending queue items"""
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
@@ -164,10 +254,11 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
return cast(int, self._cursor.fetchone()[0])
return cast(int, self.__cursor.fetchone()[0])
def _get_highest_priority(self, queue_id: str) -> int:
self._cursor.execute(
"""Gets the highest priority value in the queue"""
self.__cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
@@ -177,19 +268,13 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
return cast(Union[int, None], self._cursor.fetchone()[0]) or 0
def start_service(self, invoker: Invoker) -> None:
self._invoker = invoker
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)
self._invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0
def enqueue_graph(self, queue_id: str, graph: Graph, prepend: bool) -> EnqueueGraphResult:
enqueue_result = self.enqueue_batch(queue_id=queue_id, batch=Batch(graph=graph), prepend=prepend)
try:
self._lock.acquire()
self._cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
@@ -198,12 +283,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id, enqueue_result.batch.batch_id),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
self.__lock.release()
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with batch id {enqueue_result.batch.batch_id}")
return EnqueueGraphResult(
@@ -213,24 +298,24 @@ class SqliteSessionQueue(SessionQueueBase):
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
self._lock.acquire()
self.__lock.acquire()
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id=queue_id)
max_queue_size = self._invoker.services.configuration.get_config().max_queue_size
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.get_config().max_queue_size
max_new_queue_items = max_queue_size - current_queue_size
priority = 0
if prepend:
priority = self._get_highest_priority(queue_id=queue_id) + 1
priority = self._get_highest_priority(queue_id) + 1
self._cursor.execute(
self.__cursor.execute(
"""--sql
SELECT MAX(order_id)
FROM session_queue
"""
)
max_order_id = cast(Optional[int], self._cursor.fetchone()[0]) or 0
max_order_id = cast(Optional[int], self.__cursor.fetchone()[0]) or 0
requested_count = calc_session_count(batch)
values_to_insert = prepare_values_to_insert(
@@ -245,30 +330,33 @@ class SqliteSessionQueue(SessionQueueBase):
if requested_count > enqueued_count:
values_to_insert = values_to_insert[:max_new_queue_items]
self._cursor.executemany(
self.__cursor.executemany(
"""--sql
INSERT INTO session_queue (item_id, queue_id, session, session_id, batch_id, field_values, priority, order_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
return EnqueueBatchResult(
self.__lock.release()
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
requested=requested_count,
enqueued=enqueued_count,
batch=batch,
priority=priority,
)
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
try:
self._lock.acquire()
self._cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
@@ -279,23 +367,23 @@ class SqliteSessionQueue(SessionQueueBase):
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
self.__lock.release()
if result is None:
return None
queue_item = SessionQueueItem.from_dict(dict(result))
return self.set_queue_item_status(
queue_id=queue_item.queue_id, item_id=queue_item.item_id, status="in_progress"
)
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
return queue_item
def peek(self, queue_id: str) -> Optional[SessionQueueItem]:
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
try:
self._lock.acquire()
self._cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT *
FROM session_queue
@@ -309,82 +397,66 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
self.__lock.release()
if result is None:
return None
return SessionQueueItem.from_dict(dict(result))
def set_queue_item_status(self, queue_id: str, item_id: str, status: QUEUE_ITEM_STATUS) -> SessionQueueItem:
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
try:
self._lock.acquire()
self._cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET status = ?
SELECT *
FROM session_queue
WHERE
queue_id = ?
AND item_id = ?
AND status = 'in_progress'
LIMIT 1
""",
(status, queue_id, item_id),
(queue_id,),
)
self._conn.commit()
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
return self.get_queue_item(queue_id=queue_id, item_id=item_id)
self.__lock.release()
if result is None:
return None
return SessionQueueItem.from_dict(dict(result))
def set_many_queue_item_status(
self, queue_id: str, item_ids: list[str], status: QUEUE_ITEM_STATUS
) -> SetManyQueueItemStatusResult:
def _set_queue_item_status(
self, item_id: str, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
) -> SessionQueueItem:
try:
self._lock.acquire()
# update the queue items
placeholders = ", ".join(["?" for _ in item_ids])
update_query = f"""--sql
UPDATE session_queue
SET status = ?
WHERE
queue_id in ?
AND item_id IN ({placeholders})
"""
self._cursor.execute(update_query, [queue_id, status] + item_ids)
self._conn.commit()
# get queue items from list which were set to the status successfully
fetch_query = f"""--sql
SELECT item_id
FROM session_queue
WHERE
queue_id = ?
AND status = ?
AND item_id IN ({placeholders})
"""
self._cursor.execute(fetch_query, [queue_id, status] + item_ids)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
self.__lock.acquire()
self.__cursor.execute(
"""--sql
UPDATE session_queue
SET status = ?, error = ?
WHERE
item_id = ?
""",
(status, error, item_id),
)
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
updated_ids = [row[0] for row in result]
return SetManyQueueItemStatusResult(item_ids=updated_ids, status=status)
self.__lock.release()
return self.get_queue_item(item_id)
def is_empty(self, queue_id: str) -> IsEmptyResult:
try:
self._lock.acquire()
self._cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
@@ -392,18 +464,18 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
is_empty = cast(int, self._cursor.fetchone()[0]) == 0
is_empty = cast(int, self.__cursor.fetchone()[0]) == 0
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
self.__lock.release()
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
try:
self._lock.acquire()
self._cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
@@ -411,40 +483,39 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
max_queue_size = self._invoker.services.configuration.max_queue_size
is_full = cast(int, self._cursor.fetchone()[0]) >= max_queue_size
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, self.__cursor.fetchone()[0]) >= max_queue_size
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
self.__lock.release()
return IsFullResult(is_full=is_full)
def delete_queue_item(self, queue_id: str, item_id: str) -> SessionQueueItem:
queue_item = self.get_queue_item(queue_id=queue_id, item_id=item_id)
def delete_queue_item(self, item_id: str) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id=item_id)
try:
self._lock.acquire()
self._cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
DELETE FROM session_queue
WHERE
queue_id = ?
AND item_id = ?
item_id = ?
""",
(queue_id, item_id),
(item_id,),
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
self.__lock.release()
return queue_item
def clear(self, queue_id: str) -> ClearResult:
try:
self._lock.acquire()
self._cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -452,8 +523,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
count = self._cursor.fetchone()[0]
self._cursor.execute(
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
"""--sql
DELETE
FROM session_queue
@@ -461,12 +532,13 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
self.__lock.release()
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
@@ -480,8 +552,8 @@ class SqliteSessionQueue(SessionQueueBase):
OR status = 'canceled'
)
"""
self._lock.acquire()
self._cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -489,8 +561,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
count = self._cursor.fetchone()[0]
self._cursor.execute(
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
DELETE
FROM session_queue
@@ -498,27 +570,38 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self._conn.commit()
self.__conn.commit()
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
self.__lock.release()
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: str) -> SessionQueueItem:
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["canceled", "failed", "completed"]:
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
self.__invoker.services.queue.cancel(queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(queue_item.session_id)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
try:
self._lock.acquire()
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
WHERE
queue_id = ?
queue_id == ?
AND batch_id IN ({placeholders})
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
"""
params = [queue_id] + batch_ids
self._cursor.execute(
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -526,8 +609,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
count = self._cursor.fetchone()[0]
self._cursor.execute(
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -535,40 +618,84 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self._conn.commit()
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self.__invoker.services.queue.cancel(current_queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(current_queue_item.session_id)
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
self.__lock.release()
return CancelByBatchIDsResult(canceled=count)
def get_queue_item(self, queue_id: str, item_id: str) -> SessionQueueItem:
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
self._lock.acquire()
self._cursor.execute(
current_queue_item = self.get_current(queue_id)
self.__lock.acquire()
where = """--sql
WHERE
queue_id is ?
AND status != 'canceled'
AND status != 'completed'
AND status != 'failed'
"""
params = [queue_id]
self.__cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
{where};
""",
tuple(params),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
{where};
""",
tuple(params),
)
self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self.__invoker.services.queue.cancel(current_queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(current_queue_item.session_id)
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return CancelByQueueIDResult(canceled=count)
def get_queue_item(self, item_id: str) -> SessionQueueItem:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
queue_id = ?
AND item_id = ?
item_id = ?
""",
(queue_id, item_id),
(item_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
self.__lock.release()
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.from_dict(dict(result))
def get_queue_item_by_session_id(self, session_id: str) -> SessionQueueItem:
try:
self._lock.acquire()
self._cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
@@ -576,12 +703,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(session_id,),
)
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
self.__lock.release()
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with session id {session_id}")
return SessionQueueItem.from_dict(dict(result))
@@ -595,7 +722,7 @@ class SqliteSessionQueue(SessionQueueBase):
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
try:
self._lock.acquire()
self.__lock.acquire()
query = """--sql
SELECT item_id,
order_id,
@@ -633,8 +760,8 @@ class SqliteSessionQueue(SessionQueueBase):
LIMIT ?
"""
params.append(limit + 1)
self._cursor.execute(query, params)
results = cast(list[sqlite3.Row], self._cursor.fetchall())
self.__cursor.execute(query, params)
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
items = [SessionQueueItemDTO.from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
@@ -642,16 +769,16 @@ class SqliteSessionQueue(SessionQueueBase):
items.pop()
has_more = True
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
self.__lock.release()
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
def get_status(self, queue_id: str) -> SessionQueueStatusResult:
def get_queue_status(self, queue_id: str) -> SessionQueueStatusResult:
try:
self._lock.acquire()
self._cursor.execute(
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
@@ -660,14 +787,14 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
except Exception:
self._conn.rollback()
self.__conn.rollback()
raise
finally:
self._lock.release()
self.__lock.release()
return SessionQueueStatusResult(
queue_id=queue_id,
@@ -677,5 +804,38 @@ class SqliteSessionQueue(SessionQueueBase):
failed=counts.get("failed", 0),
canceled=counts.get("canceled", 0),
total=total,
max_queue_size=self._invoker.services.configuration.get_config().max_queue_size,
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatusResult:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
WHERE
queue_id = ?
AND batch_id = ?
GROUP BY status
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
return BatchStatusResult(
batch_id=batch_id,
queue_id=queue_id,
pending=counts.get("pending", 0),
in_progress=counts.get("in_progress", 0),
completed=counts.get("completed", 0),
failed=counts.get("failed", 0),
canceled=counts.get("canceled", 0),
total=total,
)

View File

@@ -229,6 +229,7 @@
"clearTooltip": "Cancel and Clear All Items",
"clearSucceeded": "Queue Cleared",
"clearFailed": "Problem Clearing Queue",
"cancelBatch": "Cancel Batch",
"cancelItem": "Cancel Item",
"cancelByBatchIdsSucceeded": "Batches Canceled",
"cancelByBatchIdsFailed": "Problem Canceling Items by Batch IDs",

View File

@@ -86,7 +86,6 @@ import { addEnqueueRequestedLinear } from './listeners/enqueueRequestedLinear';
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
import { addDynamicPromptsListener } from './listeners/promptChanged';
import { addSocketQueueItemStatusChangedEventListener } from './listeners/socketio/socketQueueItemStatusChanged';
import { addSocketQueueStatusChangedEventListener } from './listeners/socketio/socketQueueStatusChanged';
export const listenerMiddleware = createListenerMiddleware();
@@ -175,7 +174,6 @@ addModelLoadEventListener();
addSessionRetrievalErrorEventListener();
addInvocationRetrievalErrorEventListener();
addSocketQueueItemStatusChangedEventListener();
addSocketQueueStatusChangedEventListener();
// Session Created
addSessionCreatedPendingListener();

View File

@@ -24,7 +24,7 @@ export const addBatchEnqueuedListener = () => {
req.reset();
dispatch(
queueApi.endpoints.startQueueExecution.initiate(undefined, {
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'startQueue',
})
);

View File

@@ -47,7 +47,7 @@ export const addControlNetImageProcessedListener = () => {
const enqueueResult = await req.unwrap();
req.reset();
dispatch(
queueApi.endpoints.startQueueExecution.initiate(undefined, {
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'startQueue',
})
);

View File

@@ -140,7 +140,7 @@ export const addEnqueueRequestedCanvasListener = () => {
const enqueueResult = await req.unwrap();
req.reset();
dispatch(
queueApi.endpoints.startQueueExecution.initiate(undefined, {
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'startQueue',
})
);

View File

@@ -51,7 +51,7 @@ export const addEnqueueRequestedLinear = () => {
req.reset();
dispatch(
queueApi.endpoints.startQueueExecution.initiate(undefined, {
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'startQueue',
})
);

View File

@@ -35,7 +35,7 @@ export const addEnqueueRequestedNodes = () => {
req.reset();
dispatch(
queueApi.endpoints.startQueueExecution.initiate(undefined, {
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'startQueue',
})
);

View File

@@ -12,7 +12,12 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
actionCreator: socketQueueItemStatusChanged,
effect: (action, { dispatch, getState }) => {
const log = logger('socketio');
const { item_id, status: newStatus } = action.payload.data;
const {
item_id,
batch_id,
graph_execution_state_id,
status: newStatus,
} = action.payload.data;
log.debug(
action.payload,
`Queue item ${item_id} status updated: ${newStatus}`
@@ -29,7 +34,6 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
);
const state = getState();
const { batch_id, graph_execution_state_id } = action.payload.data;
if (state.canvas.batchIds.includes(batch_id)) {
dispatch(canvasSessionIdAdded(graph_execution_state_id));
}
@@ -39,6 +43,10 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'SessionQueueStatus',
'SessionProcessorStatus',
{ type: 'SessionQueueItem', id: item_id },
{ type: 'SessionQueueItemDTO', id: item_id },
{ type: 'BatchStatus', id: batch_id },
])
);
},

View File

@@ -1,37 +0,0 @@
import { logger } from 'app/logging/logger';
import { queueApi } from 'services/api/endpoints/queue';
import {
appSocketQueueStatusChanged,
socketQueueStatusChanged,
} from 'services/events/actions';
import { startAppListening } from '../..';
import { addToast } from 'features/system/store/systemSlice';
import { t } from 'i18next';
export const addSocketQueueStatusChangedEventListener = () => {
startAppListening({
actionCreator: socketQueueStatusChanged,
effect: (action, { dispatch, getOriginalState }) => {
const log = logger('socketio');
log.debug(action.payload, `Queue status updated`);
// pass along the socket event as an application action
dispatch(appSocketQueueStatusChanged(action.payload));
const { data: oldQueueStatus } =
queueApi.endpoints.getQueueStatus.select()(getOriginalState());
if (
oldQueueStatus?.started === false &&
oldQueueStatus?.stop_after_current === true &&
action.payload.data.started === false &&
action.payload.data.stop_after_current === false
) {
dispatch(
addToast({ title: t('queue.stopSucceeded'), status: 'success' })
);
}
dispatch(queueApi.util.invalidateTags(['SessionQueueStatus']));
},
});
};

View File

@@ -38,7 +38,7 @@ export const addUpscaleRequestedListener = () => {
const enqueueResult = await req.unwrap();
req.reset();
dispatch(
queueApi.endpoints.startQueueExecution.initiate(undefined, {
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'startQueue',
})
);

View File

@@ -23,7 +23,7 @@ export const enqueueBatch = async (
req.reset();
dispatch(
queueApi.endpoints.startQueueExecution.initiate(undefined, {
queueApi.endpoints.resumeProcessor.initiate(undefined, {
fixedCacheKey: 'startQueue',
})
);

View File

@@ -1,7 +1,7 @@
import {
As,
ChakraProps,
Flex,
FlexProps,
Icon,
Skeleton,
Spinner,
@@ -47,15 +47,14 @@ export const IAILoadingImageFallback = (props: Props) => {
);
};
type IAINoImageFallbackProps = {
type IAINoImageFallbackProps = FlexProps & {
label?: string;
icon?: As | null;
boxSize?: StyleProps['boxSize'];
sx?: ChakraProps['sx'];
};
export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
const { icon = FaImage, boxSize = 16 } = props;
const { icon = FaImage, boxSize = 16, sx, ...rest } = props;
return (
<Flex
@@ -73,8 +72,9 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
_dark: {
color: 'base.500',
},
...props.sx,
...sx,
}}
{...rest}
>
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
{props.label && <Text textAlign="center">{props.label}</Text>}

View File

@@ -1,72 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAINumberInput from 'common/components/IAINumberInput';
import IAISlider from 'common/components/IAISlider';
import { setIterations } from 'features/parameters/store/generationSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createSelector(
[stateSelector],
(state) => {
const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
state.config.sd.iterations;
const { iterations } = state.generation;
const { shouldUseSliders } = state.ui;
const step = state.hotkeys.shift ? fineStep : coarseStep;
return {
iterations,
initial,
min,
sliderMax,
inputMax,
step,
shouldUseSliders,
};
},
defaultSelectorOptions
);
const ParamIterations = () => {
const {
iterations,
initial,
min,
sliderMax,
inputMax,
step,
shouldUseSliders,
} = useAppSelector(selector);
const dispatch = useAppDispatch();
const { t } = useTranslation();
const handleChange = useCallback(
(v: number) => {
dispatch(setIterations(v));
},
[dispatch]
);
const handleReset = useCallback(() => {
dispatch(setIterations(initial));
}, [dispatch, initial]);
return (
<IAINumberInput
// label={t('parameters.runs')}
step={step}
min={min}
max={inputMax}
onChange={handleChange}
value={iterations}
numberInputFieldProps={{ textAlign: 'center' }}
formControlProps={{ w: 36 }}
/>
);
};
export default memo(ParamIterations);

View File

@@ -1,31 +1,34 @@
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaTimes } from 'react-icons/fa';
import {
useCancelQueueExecutionMutation,
useGetQueueStatusQuery,
useCancelQueueItemMutation,
useGetCurrentQueueItemQuery,
} from 'services/api/endpoints/queue';
import QueueButton from './common/QueueButton';
import { addToast } from 'features/system/store/systemSlice';
import { useIsQueueMutationInProgress } from '../hooks/useIsQueueMutationInProgress';
import QueueButton from './common/QueueButton';
type Props = {
asIconButton?: boolean;
};
const CancelQueueButton = ({ asIconButton }: Props) => {
const CancelCurrentQueueItemButton = ({ asIconButton }: Props) => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const { data: queueStatusData } = useGetQueueStatusQuery();
const [cancelQueue] = useCancelQueueExecutionMutation({
fixedCacheKey: 'cancelQueue',
const { data: currentQueueItem } = useGetCurrentQueueItemQuery();
const [cancelQueueItem] = useCancelQueueItemMutation({
fixedCacheKey: 'cancelQueueItem',
});
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const handleClick = useCallback(async () => {
if (!currentQueueItem) {
return;
}
try {
await cancelQueue().unwrap();
await cancelQueueItem(currentQueueItem.item_id).unwrap();
dispatch(
addToast({
title: t('queue.cancelSucceeded'),
@@ -40,22 +43,19 @@ const CancelQueueButton = ({ asIconButton }: Props) => {
})
);
}
}, [cancelQueue, dispatch, t]);
}, [cancelQueueItem, currentQueueItem, dispatch, t]);
return (
<QueueButton
asIconButton={asIconButton}
label={t('queue.cancel')}
tooltip={t('queue.cancelTooltip')}
isDisabled={
!(queueStatusData?.started || queueStatusData?.stop_after_current) ||
isQueueMutationInProgress
}
isDisabled={!currentQueueItem || isQueueMutationInProgress}
icon={<FaTimes />}
onClick={handleClick}
colorScheme="orange"
colorScheme="error"
/>
);
};
export default memo(CancelQueueButton);
export default memo(CancelCurrentQueueItemButton);

View File

@@ -1,11 +1,11 @@
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { usePeekNextQueueItemQuery } from 'services/api/endpoints/queue';
import { useGetNextQueueItemQuery } from 'services/api/endpoints/queue';
import QueueItemCard from './common/QueueItemCard';
const NextQueueItemCard = () => {
const { t } = useTranslation();
const { data: nextQueueItemData } = usePeekNextQueueItemQuery();
const { data: nextQueueItemData } = useGetNextQueueItemQuery();
return (
<QueueItemCard

View File

@@ -34,11 +34,10 @@ const QueueBackButton = () => {
);
return (
<IAIButton
isDisabled={!isReady}
isDisabled={!isReady || isQueueMutationInProgress}
colorScheme="accent"
onClick={handleEnqueue}
tooltip={<EnqueueButtonTooltip />}
isLoading={isQueueMutationInProgress}
flexGrow={3}
minW={44}
>

View File

@@ -8,7 +8,6 @@ import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useEnqueueBatchMutation } from 'services/api/endpoints/queue';
import { useBoardName } from 'services/api/hooks/useBoardName';
import { usePredictedQueueCounts } from '../hooks/usePredictedQueueCounts';
const tooltipSelector = createSelector(
[stateSelector],
@@ -33,7 +32,6 @@ const QueueButtonTooltipContent = ({ prepend = false }: Props) => {
const [_, { isLoading }] = useEnqueueBatchMutation({
fixedCacheKey: 'enqueueBatch',
});
const counts = usePredictedQueueCounts();
const label = useMemo(() => {
if (isLoading) {
@@ -41,12 +39,12 @@ const QueueButtonTooltipContent = ({ prepend = false }: Props) => {
}
if (isReady) {
if (prepend) {
return t('queue.queueFront', { predicted: counts?.predicted ?? '?' });
return t('queue.queueFront');
}
return t('queue.queueBack', { predicted: counts?.predicted ?? '?' });
return t('queue.queueBack');
}
return t('queue.notReady');
}, [counts?.predicted, isLoading, isReady, prepend, t]);
}, [isLoading, isReady, prepend, t]);
return (
<Flex flexDir="column" gap={1}>

View File

@@ -1,12 +1,10 @@
import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
import ParamRuns from 'features/parameters/components/Parameters/Core/ParamRuns';
import CancelQueueButton from 'features/queue/components/CancelQueueButton';
import { ButtonGroup, Flex, Spacer, Text } from '@chakra-ui/react';
import CancelCurrentQueueItemButton from 'features/queue/components/CancelCurrentQueueItemButton';
import ClearQueueButton from 'features/queue/components/ClearQueueButton';
import QueueBackButton from 'features/queue/components/QueueBackButton';
import QueueFrontButton from 'features/queue/components/QueueFrontButton';
import StartQueueButton from 'features/queue/components/StartQueueButton';
import StopQueueButton from 'features/queue/components/StopQueueButton';
import { usePredictedQueueCounts } from 'features/queue/hooks/usePredictedQueueCounts';
import ResumeProcessorButton from 'features/queue/components/StartQueueButton';
import PauseProcessorButton from 'features/queue/components/StopQueueButton';
import ProgressBar from 'features/system/components/ProgressBar';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -29,11 +27,11 @@ const QueueControls = () => {
<ButtonGroup isAttached flexGrow={2}>
<QueueBackButton />
<QueueFrontButton />
<CancelCurrentQueueItemButton asIconButton />
</ButtonGroup>
<ButtonGroup isAttached>
<StartQueueButton asIconButton />
<StopQueueButton asIconButton />
<CancelQueueButton asIconButton />
<ResumeProcessorButton asIconButton />
<PauseProcessorButton asIconButton />
<ClearQueueButton asIconButton />
</ButtonGroup>
</Flex>
@@ -48,49 +46,35 @@ const QueueControls = () => {
export default memo(QueueControls);
const QueueCounts = memo(() => {
const counts = usePredictedQueueCounts();
const { data: queueStatus } = useGetQueueStatusQuery();
const { hasItems, pending } = useGetQueueStatusQuery(undefined, {
selectFromResult: ({ data }) => {
if (!data) {
return {
hasItems: false,
pending: 0,
};
}
const { pending, in_progress } = data;
return {
hasItems: pending + in_progress > 0,
pending,
};
},
});
const { t } = useTranslation();
if (!counts || !queueStatus) {
return null;
}
const { requested, predicted, max_queue_size } = counts;
const { pending, in_progress } = queueStatus;
return (
<Flex justifyContent="space-between" alignItems="center">
<ParamRuns />
{/* <Tooltip
label={
requested > predicted &&
t('queue.queueMaxExceeded', {
requested,
skip: requested - predicted,
max_queue_size,
})
}
>
<Text
variant="subtext"
fontSize="sm"
fontWeight={400}
fontStyle="oblique 10deg"
opacity={0.7}
color={requested > predicted ? 'warning.500' : undefined}
>
{t('queue.queueCountPrediction', { predicted })}
</Text>
</Tooltip> */}
<Spacer />
<Text
variant="subtext"
fontSize="sm"
fontWeight={400}
fontStyle="oblique 10deg"
opacity={0.7}
pe={1}
>
{pending + in_progress > 0
{hasItems
? t('queue.queuedCount', {
pending,
})

View File

@@ -36,10 +36,9 @@ const QueueFrontButton = () => {
<IAIIconButton
colorScheme="base"
aria-label={t('queue.queueFront')}
isDisabled={!isReady}
isDisabled={!isReady || isQueueMutationInProgress}
onClick={handleEnqueue}
tooltip={<EnqueueButtonTooltip prepend />}
isLoading={isQueueMutationInProgress}
icon={<FaBoltLightning />}
/>
);

View File

@@ -0,0 +1,135 @@
import {
ButtonGroup,
ChakraProps,
Collapse,
Flex,
Text,
} from '@chakra-ui/react';
import IAIIconButton from 'common/components/IAIIconButton';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaTimes } from 'react-icons/fa';
import { useCancelQueueItemMutation } from 'services/api/endpoints/queue';
import { SessionQueueItemDTO } from 'services/api/types';
import QueueStatusBadge from '../common/QueueStatusBadge';
import QueueItemDetail from './QueueItemDetail';
import { COLUMN_WIDTHS } from './constants';
import { ListContext } from './types';
// Highlight style applied both on hover and while a row is expanded
// (the row sets aria-selected={isOpen}).
const selectedStyles = { bg: 'base.300', _dark: { bg: 'base.750' } };

// Props passed to each virtualized row by react-virtuoso's itemContent.
type InnerItemProps = {
  index: number;
  item: SessionQueueItemDTO;
  context: ListContext;
};

const sx: ChakraProps['sx'] = {
  _hover: selectedStyles,
  "&[aria-selected='true']": selectedStyles,
};
/**
 * One collapsible row in the session queue list.
 *
 * Shows the item's 1-based position, status badge, batch id and any batch
 * field values, plus a cancel button. Clicking the row toggles an inline
 * QueueItemDetail panel; open/closed state lives in the shared ListContext.
 */
const QueueItemComponent = ({ index, item, context }: InnerItemProps) => {
  const { t } = useTranslation();

  // Expand/collapse this row via the list-wide context.
  const handleToggle = useCallback(() => {
    context.toggleQueueItem(item.item_id);
  }, [context, item.item_id]);

  const [cancelQueueItem, { isLoading: isLoadingCancelQueueItem }] =
    useCancelQueueItemMutation();
  const handleCancelQueueItem = useCallback(
    (e: MouseEvent<HTMLButtonElement>) => {
      // Prevent the click from also toggling the row open/closed.
      e.stopPropagation();
      cancelQueueItem(item.item_id);
    },
    [cancelQueueItem, item.item_id]
  );

  const isOpen = useMemo(
    () => context.openQueueItems.includes(item.item_id),
    [context.openQueueItems, item.item_id]
  );

  return (
    <Flex
      flexDir="column"
      borderRadius="base"
      aria-selected={isOpen}
      fontSize="sm"
      justifyContent="center"
      sx={sx}
    >
      <Flex
        alignItems="center"
        gap={4}
        p={1.5}
        cursor="pointer"
        onClick={handleToggle}
      >
        <Flex
          w={COLUMN_WIDTHS.number}
          justifyContent="flex-end"
          alignItems="center"
        >
          <Text variant="subtext">{index + 1}</Text>
        </Flex>
        <Flex w={COLUMN_WIDTHS.statusBadge} alignItems="center">
          <QueueStatusBadge status={item.status} />
        </Flex>
        <Flex w={COLUMN_WIDTHS.batchId}>
          <Text
            overflow="hidden"
            textOverflow="ellipsis"
            whiteSpace="nowrap"
            alignItems="center"
          >
            {item.batch_id}
          </Text>
        </Flex>
        <Flex alignItems="center" flexGrow={1}>
          {item.field_values && (
            <Flex gap={2}>
              {item.field_values
                .filter((v) => v.node_path !== 'metadata_accumulator')
                .map(({ node_path, field_name, value }) => (
                  <Text
                    key={`${item.item_id}.${node_path}.${field_name}.${value}`}
                    whiteSpace="nowrap"
                    textOverflow="ellipsis"
                    overflow="hidden"
                  >
                    <Text as="span" fontWeight={600}>
                      {node_path}.{field_name}
                    </Text>
                    : {value}
                  </Text>
                ))}
            </Flex>
          )}
        </Flex>
        <Flex alignItems="center" w={COLUMN_WIDTHS.actions} pe={3}>
          <ButtonGroup size="xs" variant="ghost">
            <IAIIconButton
              tooltip={t('queue.cancelItem')}
              onClick={handleCancelQueueItem}
              isLoading={isLoadingCancelQueueItem}
              isDisabled={['canceled', 'completed', 'failed'].includes(
                item.status
              )}
              aria-label={t('queue.cancelItem')}
              icon={<FaTimes />}
            />
          </ButtonGroup>
        </Flex>
      </Flex>
      <Collapse
        in={isOpen}
        transition={{ enter: { duration: 0.1 }, exit: { duration: 0.1 } }}
        unmountOnExit={true}
      >
        <QueueItemDetail queueItemDTO={item} />
      </Collapse>
    </Flex>
  );
};

export default memo(QueueItemComponent);

View File

@@ -0,0 +1,166 @@
import { ButtonGroup, Flex, Heading, Spinner, Text } from '@chakra-ui/react';
import IAIButton from 'common/components/IAIButton';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import ScrollableContent from 'features/nodes/components/sidePanel/ScrollableContent';
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
import { MouseEvent, memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { FaTimes } from 'react-icons/fa';
import {
useCancelByBatchIdsMutation,
useCancelQueueItemMutation,
useGetBatchStatusQuery,
useGetQueueItemQuery,
} from 'services/api/endpoints/queue';
import { SessionQueueItemDTO } from 'services/api/types';
type Props = {
  queueItemDTO: SessionQueueItemDTO;
};

/**
 * Expanded detail panel for a single queue item.
 *
 * Shows item/batch/session ids and execution time, offers "cancel item" and
 * "cancel batch" actions, and renders the full queue item JSON once fetched.
 *
 * Renamed internally from QueueItemComponent to QueueItemDetail to match the
 * module name (it is imported as QueueItemDetail) and to avoid clashing with
 * the sibling row component of the same name. Default export is unchanged.
 */
const QueueItemDetail = ({ queueItemDTO }: Props) => {
  const {
    batch_id,
    completed_at,
    error,
    item_id,
    session_id,
    started_at,
    status,
  } = queueItemDTO;

  const { t } = useTranslation();

  const isQueueMutationInProgress = useIsQueueMutationInProgress();
  const [cancelQueueItem, { isLoading: isLoadingCancelQueueItem }] =
    useCancelQueueItemMutation();
  const [cancelByBatchIds, { isLoading: isLoadingCancelByBatchIds }] =
    useCancelByBatchIdsMutation();

  // "Canceled" here means there is nothing left to cancel: no pending or
  // in-progress items. Defaults to true until the batch status query resolves
  // so the cancel-batch button stays disabled while loading.
  const { isCanceled } = useGetBatchStatusQuery(
    { batch_id },
    {
      selectFromResult: ({ data }) => {
        if (!data) {
          return { isCanceled: true };
        }
        return {
          isCanceled: data?.in_progress === 0 && data?.pending === 0,
        };
      },
    }
  );

  const handleCancelQueueItem = useCallback(
    (e: MouseEvent<HTMLButtonElement>) => {
      // Don't let the click bubble up to the row toggle.
      e.stopPropagation();
      cancelQueueItem(item_id);
    },
    [cancelQueueItem, item_id]
  );

  const handleCancelBatch = useCallback(
    (e: MouseEvent<HTMLButtonElement>) => {
      e.stopPropagation();
      cancelByBatchIds({ batch_ids: [batch_id] });
    },
    [cancelByBatchIds, batch_id]
  );

  const { data: queueItem } = useGetQueueItemQuery(item_id);

  // Wall-clock execution time in seconds, to 2 decimal places; 'n/a' until
  // the item has both started and completed.
  const executionTime = useMemo(() => {
    if (!completed_at || !started_at) {
      return 'n/a';
    }
    // toFixed already returns a string; the previous String() wrapper was
    // redundant.
    return ((Date.parse(completed_at) - Date.parse(started_at)) / 1000).toFixed(
      2
    );
  }, [completed_at, started_at]);

  return (
    <Flex
      layerStyle="third"
      flexDir="column"
      p={2}
      pt={0}
      borderRadius="base"
      gap={2}
    >
      <Flex
        layerStyle="second"
        p={2}
        gap={2}
        justifyContent="space-between"
        alignItems="center"
        borderRadius="base"
      >
        <QueueItemData label="Item ID" data={item_id} />
        <QueueItemData label="Batch ID" data={batch_id} />
        <QueueItemData label="Session ID" data={session_id} />
        <QueueItemData label="Execution Time" data={executionTime} />
        <ButtonGroup size="xs" orientation="vertical">
          <IAIButton
            onClick={handleCancelQueueItem}
            isLoading={isLoadingCancelQueueItem}
            isDisabled={['canceled', 'completed', 'failed'].includes(status)}
            aria-label={t('queue.cancelItem')}
            icon={<FaTimes />}
            colorScheme="error"
          >
            {t('queue.cancelItem')}
          </IAIButton>
          <IAIButton
            onClick={handleCancelBatch}
            isLoading={isLoadingCancelByBatchIds || isQueueMutationInProgress}
            isDisabled={isCanceled}
            aria-label={t('queue.cancelBatch')}
            icon={<FaTimes />}
            colorScheme="error"
          >
            {t('queue.cancelBatch')}
          </IAIButton>
        </ButtonGroup>
      </Flex>
      {error && (
        <Flex
          layerStyle="second"
          p={3}
          gap={1}
          justifyContent="space-between"
          alignItems="flex-start"
          borderRadius="base"
          flexDir="column"
        >
          <Heading size="sm" color="error.500" _dark={{ color: 'error.400' }}>
            Error
          </Heading>
          <pre>{error}</pre>
        </Flex>
      )}
      <Flex
        layerStyle="second"
        h={512}
        w="full"
        borderRadius="base"
        alignItems="center"
        justifyContent="center"
      >
        {queueItem ? (
          <ScrollableContent>
            <DataViewer label="Queue Item" data={queueItem} />
          </ScrollableContent>
        ) : (
          <Spinner opacity={0.5} />
        )}
      </Flex>
    </Flex>
  );
};

export default memo(QueueItemDetail);
type QueueItemDataProps = { label: string; data: string };
const QueueItemData = ({ label, data }: QueueItemDataProps) => {
return (
<Flex flexDir="column" p={1} gap={1}>
<Heading size="sm">{label}</Heading>
<Text>{data}</Text>
</Flex>
);
};

View File

@@ -1,20 +1,8 @@
import {
Box,
ChakraProps,
Collapse,
Flex,
Text,
forwardRef,
} from '@chakra-ui/react';
import { Flex, Heading } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { skipToken } from '@reduxjs/toolkit/dist/query';
import { stateSelector } from 'app/store/store';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAIIconButton from 'common/components/IAIIconButton';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import ScrollableContent from 'features/nodes/components/sidePanel/ScrollableContent';
import {
listCursorChanged,
listPriorityChanged,
@@ -23,26 +11,18 @@ import {
UseOverlayScrollbarsParams,
useOverlayScrollbars,
} from 'overlayscrollbars-react';
import {
MouseEvent,
memo,
useCallback,
useEffect,
useMemo,
useRef,
useState,
} from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { FaTimes } from 'react-icons/fa';
import { Components, ItemContent, Virtuoso } from 'react-virtuoso';
import {
queueItemsAdapter,
useCancelQueueItemMutation,
useGetQueueItemQuery,
useListQueueItemsQuery,
} from 'services/api/endpoints/queue';
import { SessionQueueItemDTO } from 'services/api/types';
import QueueStatusBadge from '../common/QueueStatusBadge';
import QueueItemComponent from './QueueItemComponent';
import QueueListComponent from './QueueListComponent';
import QueueListHeader from './QueueListHeader';
import { ListContext } from './types';
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type TableVirtuosoScrollerRef = (ref: HTMLElement | Window | null) => any;
@@ -69,39 +49,19 @@ const selector = createSelector(
defaultSelectorOptions
);
const COLUMN_WIDTHS = {
number: '3rem',
statusBadge: '5.7rem',
batchId: '5rem',
fieldValues: 'auto',
actions: 'auto',
};
const computeItemKey = (index: number, item: SessionQueueItemDTO): string =>
item.item_id;
type ListContext = {
openQueueItems: string[];
toggleQueueItem: (item_id: string) => void;
};
const ListComponent: Components<SessionQueueItemDTO, ListContext>['List'] =
memo(
forwardRef((props, ref) => {
return (
<Flex {...props} ref={ref} flexDirection="column">
{props.children}
</Flex>
);
})
);
ListComponent.displayName = 'ListComponent';
const components: Components<SessionQueueItemDTO, ListContext> = {
List: ListComponent,
List: QueueListComponent,
};
const itemContent: ItemContent<SessionQueueItemDTO, ListContext> = (
index,
item,
context
) => <QueueItemComponent index={index} item={item} context={context} />;
const QueueList = () => {
const { listCursor, listPriority } = useAppSelector(selector);
const dispatch = useAppDispatch();
@@ -110,6 +70,7 @@ const QueueList = () => {
const [initialize, osInstance] = useOverlayScrollbars(
overlayScrollbarsConfig
);
const { t } = useTranslation();
useEffect(() => {
const { current: root } = rootRef;
@@ -165,36 +126,16 @@ const QueueList = () => {
);
return (
<Box w="full" h="full">
<Flex w="full" h="full" flexDir="column">
<QueueListHeader />
<Flex
ref={rootRef}
w="full"
h="full"
alignItems="center"
gap={4}
p={1}
pb={2}
textTransform="uppercase"
fontWeight={700}
fontSize="xs"
letterSpacing={1}
justifyContent="center"
>
<Flex
w={COLUMN_WIDTHS.number}
justifyContent="flex-end"
alignItems="center"
>
<Text variant="subtext">#</Text>
</Flex>
<Flex w={COLUMN_WIDTHS.statusBadge} alignItems="center">
<Text variant="subtext">status</Text>
</Flex>
<Flex w={COLUMN_WIDTHS.batchId} alignItems="center">
<Text variant="subtext">batch</Text>
</Flex>
<Flex alignItems="center" w={COLUMN_WIDTHS.fieldValues}>
<Text variant="subtext">batch field values</Text>
</Flex>
</Flex>
<Box ref={rootRef} w="full" h="full">
{listQueueItemsData && (
{queueItems.length ? (
<Virtuoso<SessionQueueItemDTO, ListContext>
data={queueItems}
endReached={handleLoadMore}
@@ -203,144 +144,15 @@ const QueueList = () => {
computeItemKey={computeItemKey}
components={components}
context={context}
style={{ display: 'flex', flexDirection: 'column', gap: '2px' }}
/>
) : (
<Heading color="base.400" _dark={{ color: 'base.500' }}>
{t('queue.queueEmpty')}
</Heading>
)}
</Box>
</Box>
</Flex>
</Flex>
);
};
export default memo(QueueList);
const selectedStyles = { bg: 'base.300', _dark: { bg: 'base.750' } };
const itemContent: ItemContent<SessionQueueItemDTO, ListContext> = (
index,
item,
context
) => <InnerItem index={index} item={item} context={context} />;
type InnerItemProps = {
index: number;
item: SessionQueueItemDTO;
context: ListContext;
};
const sx: ChakraProps['sx'] = {
_hover: selectedStyles,
"&[aria-selected='true']": selectedStyles,
};
const InnerItem = memo(({ index, item, context }: InnerItemProps) => {
const { t } = useTranslation();
const handleToggle = useCallback(() => {
context.toggleQueueItem(item.item_id);
}, [context, item.item_id]);
const [cancelQueueItem, { isLoading }] = useCancelQueueItemMutation();
const handleCancel = useCallback(
(e: MouseEvent<HTMLButtonElement>) => {
e.stopPropagation();
cancelQueueItem(item.item_id);
},
[cancelQueueItem, item.item_id]
);
const isOpen = useMemo(
() => context.openQueueItems.includes(item.item_id),
[context.openQueueItems, item.item_id]
);
const { data: queueItem } = useGetQueueItemQuery(
isOpen ? item.item_id : skipToken
);
return (
<Flex
flexDir="column"
borderRadius="base"
aria-selected={isOpen}
fontSize="sm"
justifyContent="center"
sx={sx}
>
<Flex
alignItems="center"
gap={4}
p={1}
cursor="pointer"
onClick={handleToggle}
>
<Flex
w={COLUMN_WIDTHS.number}
justifyContent="flex-end"
alignItems="center"
>
<Text variant="subtext">{index + 1}</Text>
</Flex>
<Flex w={COLUMN_WIDTHS.statusBadge} alignItems="center">
<QueueStatusBadge status={item.status} />
</Flex>
<Flex w={COLUMN_WIDTHS.batchId}>
<Text
overflow="hidden"
textOverflow="ellipsis"
whiteSpace="nowrap"
alignItems="center"
>
{item.batch_id}
</Text>
</Flex>
<Flex alignItems="center" flexGrow={1}>
{item.field_values && (
<Flex gap={2}>
{item.field_values
.filter((v) => v.node_path !== 'metadata_accumulator')
.map(({ node_path, field_name, value }) => (
<Text
key={`${item.item_id}.${node_path}.${field_name}.${value}`}
whiteSpace="nowrap"
textOverflow="ellipsis"
overflow="hidden"
>
<Text as="span" fontWeight={600}>
{node_path}.{field_name}
</Text>
: {value}
</Text>
))}
</Flex>
)}
</Flex>
<Flex alignItems="center" w={COLUMN_WIDTHS.actions}>
<IAIIconButton
tooltip={t('queue.cancelItem')}
onClick={handleCancel}
isLoading={isLoading}
isDisabled={['canceled', 'completed', 'failed'].includes(
item.status
)}
aria-label={t('queue.cancelItem')}
size="xs"
variant="ghost"
icon={<FaTimes />}
/>
</Flex>
</Flex>
<Collapse in={isOpen}>
<Flex layerStyle="third" p={2} pt={0} borderRadius="base">
<Flex h={512} w="full" pos="relative">
{queueItem ? (
<ScrollableContent>
<DataViewer label="Queue Item" data={queueItem} />
</ScrollableContent>
) : (
<IAINoContentFallback label="Loading" icon={null} />
)}
</Flex>
</Flex>
</Collapse>
</Flex>
);
});
InnerItem.displayName = 'InnerItem';

View File

@@ -0,0 +1,18 @@
import { Flex, forwardRef } from '@chakra-ui/react';
import { memo } from 'react';
import { Components } from 'react-virtuoso';
import { SessionQueueItemDTO } from 'services/api/types';
import { ListContext } from './types';
/**
 * Flex-column List container for react-virtuoso's `components` prop.
 *
 * Virtuoso passes layout props and a ref that must be forwarded to the DOM
 * container. The component is memoized once here; the previous default export
 * wrapped it in a second, redundant `memo`. A displayName is set so the
 * memo(forwardRef(...)) component is identifiable in React DevTools
 * (the pre-refactor ListComponent did the same).
 */
const QueueListComponent: Components<SessionQueueItemDTO, ListContext>['List'] =
  memo(
    forwardRef((props, ref) => {
      return (
        <Flex {...props} ref={ref} flexDirection="column" gap={0.5}>
          {props.children}
        </Flex>
      );
    })
  );

QueueListComponent.displayName = 'QueueListComponent';

export default QueueListComponent;

View File

@@ -0,0 +1,37 @@
import { Flex, Text } from '@chakra-ui/react';
import { memo } from 'react';
import { COLUMN_WIDTHS } from './constants';
/**
 * Static header row for the queue list.
 *
 * Column widths come from the shared COLUMN_WIDTHS constants so the header
 * stays aligned with the virtualized rows rendered below it.
 */
const QueueListHeader = () => {
  return (
    <Flex
      alignItems="center"
      gap={4}
      p={1}
      pb={2}
      textTransform="uppercase"
      fontWeight={700}
      fontSize="xs"
      letterSpacing={1}
    >
      {/* item position */}
      <Flex
        w={COLUMN_WIDTHS.number}
        justifyContent="flex-end"
        alignItems="center"
      >
        <Text variant="subtext">#</Text>
      </Flex>
      <Flex w={COLUMN_WIDTHS.statusBadge} alignItems="center">
        <Text variant="subtext">status</Text>
      </Flex>
      <Flex w={COLUMN_WIDTHS.batchId} alignItems="center">
        <Text variant="subtext">batch</Text>
      </Flex>
      <Flex alignItems="center" w={COLUMN_WIDTHS.fieldValues}>
        <Text variant="subtext">batch field values</Text>
      </Flex>
    </Flex>
  );
};

export default memo(QueueListHeader);

View File

@@ -0,0 +1,7 @@
// Shared column widths for the queue list, used by both QueueListHeader and
// the row component so the two stay visually aligned.
export const COLUMN_WIDTHS = {
  number: '3rem', // item position (#)
  statusBadge: '5.7rem', // status badge
  batchId: '5rem', // truncated batch id
  fieldValues: 'auto', // batch field values
  actions: 'auto', // row action buttons (e.g. cancel)
};

View File

@@ -0,0 +1,4 @@
// Context shared between the virtualized queue list and its row components.
export type ListContext = {
  // item_ids of rows whose detail panel is currently expanded.
  openQueueItems: string[];
  // Toggles the expanded state of the row for the given item_id.
  toggleQueueItem: (item_id: string) => void;
};

View File

@@ -4,8 +4,8 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaPlay } from 'react-icons/fa';
import {
useGetQueueStatusQuery,
useStartQueueExecutionMutation,
useGetProcessorStatusQuery,
useResumeProcessorMutation,
} from 'services/api/endpoints/queue';
import { useIsQueueMutationInProgress } from '../hooks/useIsQueueMutationInProgress';
import QueueButton from './common/QueueButton';
@@ -14,18 +14,18 @@ type Props = {
asIconButton?: boolean;
};
const StartQueueButton = ({ asIconButton }: Props) => {
const { data: queueStatusData } = useGetQueueStatusQuery();
const ResumeProcessorButton = ({ asIconButton }: Props) => {
const { data: processorStatus } = useGetProcessorStatusQuery();
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [startQueue] = useStartQueueExecutionMutation({
fixedCacheKey: 'startQueue',
const [resumeProcessor] = useResumeProcessorMutation({
fixedCacheKey: 'resumeProcessor',
});
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const handleClick = useCallback(async () => {
try {
await startQueue().unwrap();
await resumeProcessor().unwrap();
dispatch(
addToast({
title: t('queue.startSucceeded'),
@@ -40,7 +40,7 @@ const StartQueueButton = ({ asIconButton }: Props) => {
})
);
}
}, [dispatch, startQueue, t]);
}, [dispatch, resumeProcessor, t]);
return (
<QueueButton
@@ -48,9 +48,8 @@ const StartQueueButton = ({ asIconButton }: Props) => {
label={t('queue.start')}
tooltip={t('queue.startTooltip')}
isDisabled={
queueStatusData?.started ||
queueStatusData?.stop_after_current ||
queueStatusData?.pending === 0 ||
processorStatus?.is_started ||
processorStatus?.is_processing ||
isQueueMutationInProgress
}
icon={<FaPlay />}
@@ -60,4 +59,4 @@ const StartQueueButton = ({ asIconButton }: Props) => {
);
};
export default memo(StartQueueButton);
export default memo(ResumeProcessorButton);

View File

@@ -4,8 +4,8 @@ import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { FaStop } from 'react-icons/fa';
import {
useGetQueueStatusQuery,
useStopQueueExecutionMutation,
useGetProcessorStatusQuery,
usePauseProcessorMutation,
} from 'services/api/endpoints/queue';
import { useIsQueueMutationInProgress } from '../hooks/useIsQueueMutationInProgress';
import QueueButton from './common/QueueButton';
@@ -14,18 +14,18 @@ type Props = {
asIconButton?: boolean;
};
const StopQueueButton = ({ asIconButton }: Props) => {
const { data: queueStatusData } = useGetQueueStatusQuery();
const PauseProcessorButton = ({ asIconButton }: Props) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const [stopQueue] = useStopQueueExecutionMutation({
fixedCacheKey: 'stopQueue',
const { data: processorStatus } = useGetProcessorStatusQuery();
const [pauseProcessor] = usePauseProcessorMutation({
fixedCacheKey: 'pauseProcessor',
});
const isQueueMutationInProgress = useIsQueueMutationInProgress();
const handleClick = useCallback(async () => {
try {
await stopQueue().unwrap();
await pauseProcessor().unwrap();
dispatch(
addToast({
title: t('queue.stopRequested'),
@@ -40,14 +40,15 @@ const StopQueueButton = ({ asIconButton }: Props) => {
})
);
}
}, [dispatch, stopQueue, t]);
}, [dispatch, pauseProcessor, t]);
return (
<QueueButton
asIconButton={asIconButton}
label={t('queue.stop')}
tooltip={t('queue.stopTooltip')}
isDisabled={!queueStatusData?.started || isQueueMutationInProgress}
isDisabled={!processorStatus?.is_started || isQueueMutationInProgress}
isLoading={processorStatus?.is_stop_pending}
icon={<FaStop />}
onClick={handleClick}
colorScheme="gold"
@@ -55,4 +56,4 @@ const StopQueueButton = ({ asIconButton }: Props) => {
);
};
export default memo(StopQueueButton);
export default memo(PauseProcessorButton);

View File

@@ -1,10 +1,10 @@
import { ButtonGroup, ButtonGroupProps, Flex } from '@chakra-ui/react';
import { memo } from 'react';
import CancelQueueButton from './CancelQueueButton';
import CancelCurrentQueueItemButton from './CancelCurrentQueueItemButton';
import ClearQueueButton from './ClearQueueButton';
import PruneQueueButton from './PruneQueueButton';
import StartQueueButton from './StartQueueButton';
import StopQueueButton from './StopQueueButton';
import ResumeProcessorButton from './StartQueueButton';
import PauseProcessorButton from './StopQueueButton';
type Props = ButtonGroupProps & {
asIconButtons?: boolean;
@@ -14,9 +14,9 @@ const VerticalQueueControls = ({ asIconButtons, ...rest }: Props) => {
return (
<Flex flexDir="column" gap={2}>
<ButtonGroup w="full" isAttached {...rest}>
<StartQueueButton asIconButton={asIconButtons} />
<StopQueueButton asIconButton={asIconButtons} />
<CancelQueueButton asIconButton={asIconButtons} />
<ResumeProcessorButton asIconButton={asIconButtons} />
<PauseProcessorButton asIconButton={asIconButtons} />
<CancelCurrentQueueItemButton asIconButton={asIconButtons} />
</ButtonGroup>
<ButtonGroup w="full" isAttached {...rest}>
<PruneQueueButton asIconButton={asIconButtons} />

View File

@@ -1,12 +1,12 @@
import {
useCancelByBatchIdsMutation,
useCancelQueueExecutionMutation,
useCancelQueueItemMutation,
// useCancelByBatchIdsMutation,
useClearQueueMutation,
useEnqueueBatchMutation,
useEnqueueGraphMutation,
usePruneQueueMutation,
useStartQueueExecutionMutation,
useStopQueueExecutionMutation,
useResumeProcessorMutation,
usePauseProcessorMutation,
} from 'services/api/endpoints/queue';
export const useIsQueueMutationInProgress = () => {
@@ -18,17 +18,17 @@ export const useIsQueueMutationInProgress = () => {
useEnqueueGraphMutation({
fixedCacheKey: 'enqueueGraph',
});
const [_triggerStartQueue, { isLoading: isLoadingStartQueue }] =
useStartQueueExecutionMutation({
fixedCacheKey: 'startQueue',
const [_triggerResumeProcessor, { isLoading: isLoadingResumeProcessor }] =
useResumeProcessorMutation({
fixedCacheKey: 'resumeProcessor',
});
const [_triggerStopQueue, { isLoading: isLoadingStopQueue }] =
useStopQueueExecutionMutation({
fixedCacheKey: 'stopQueue',
const [_triggerPauseProcessor, { isLoading: isLoadingPauseProcessor }] =
usePauseProcessorMutation({
fixedCacheKey: 'pauseProcessor',
});
const [_triggerCancelQueue, { isLoading: isLoadingCancelQueue }] =
useCancelQueueExecutionMutation({
fixedCacheKey: 'cancelQueue',
useCancelQueueItemMutation({
fixedCacheKey: 'cancelQueueItem',
});
const [_triggerClearQueue, { isLoading: isLoadingClearQueue }] =
useClearQueueMutation({
@@ -38,18 +38,18 @@ export const useIsQueueMutationInProgress = () => {
usePruneQueueMutation({
fixedCacheKey: 'pruneQueue',
});
const [_triggerCancelByBatchIds, { isLoading: isLoadingCancelByBatchIds }] =
useCancelByBatchIdsMutation({
fixedCacheKey: 'cancelByBatchIds',
});
// const [_triggerCancelByBatchIds, { isLoading: isLoadingCancelByBatchIds }] =
// useCancelByBatchIdsMutation({
// fixedCacheKey: 'cancelByBatchIds',
// });
return (
isLoadingEnqueueBatch ||
isLoadingEnqueueGraph ||
isLoadingStartQueue ||
isLoadingStopQueue ||
isLoadingResumeProcessor ||
isLoadingPauseProcessor ||
isLoadingCancelQueue ||
isLoadingClearQueue ||
isLoadingPruneQueue ||
isLoadingCancelByBatchIds
isLoadingPruneQueue
// isLoadingCancelByBatchIds
);
};

View File

@@ -1,13 +1,13 @@
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
import { useGetProcessorStatusQuery } from 'services/api/endpoints/queue';
export const useIsQueueStarted = () => {
const { isStarted } = useGetQueueStatusQuery(undefined, {
const { isStarted } = useGetProcessorStatusQuery(undefined, {
selectFromResult: ({ data }) => {
if (!data) {
return { isStarted: false };
}
return { isStarted: data.started || data.stop_after_current };
return { isStarted: data.is_started || data.is_processing };
},
});

View File

@@ -1,34 +0,0 @@
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import { useMemo } from 'react';
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
const selector = createSelector(
[stateSelector, activeTabNameSelector],
({ dynamicPrompts, generation }, activeTabName) => {
if (activeTabName === 'nodes') {
return generation.iterations;
}
return dynamicPrompts.prompts.length * generation.iterations;
}
);
export const usePredictedQueueCounts = () => {
const { data: queueStatus } = useGetQueueStatusQuery();
const requested = useAppSelector(selector);
const counts = useMemo(() => {
if (!queueStatus) {
return;
}
const { max_queue_size, pending } = queueStatus;
const maxNew = max_queue_size - pending;
return {
requested,
max_queue_size,
predicted: Math.min(requested, maxNew),
};
}, [queueStatus, requested]);
return counts;
};

View File

@@ -41,6 +41,7 @@ const SDXLImageToImageTabCoreParameters = () => {
>
{shouldUseSliders ? (
<>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
<ParamModelandVAEandScheduler />
@@ -52,6 +53,7 @@ const SDXLImageToImageTabCoreParameters = () => {
) : (
<>
<Flex gap={3}>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
</Flex>

View File

@@ -39,6 +39,7 @@ const SDXLUnifiedCanvasTabCoreParameters = () => {
>
{shouldUseSliders ? (
<>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
<ParamModelandVAEandScheduler />
@@ -50,6 +51,7 @@ const SDXLUnifiedCanvasTabCoreParameters = () => {
) : (
<>
<Flex gap={3}>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
</Flex>

View File

@@ -1,11 +1,11 @@
import { Progress } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { useIsQueueStarted } from 'features/queue/hooks/useIsQueueStarted';
import { SystemState } from 'features/system/store/systemSlice';
import { isEqual } from 'lodash-es';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetProcessorStatusQuery } from 'services/api/endpoints/queue';
import { systemSelector } from '../store/systemSelectors';
const progressBarSelector = createSelector(
@@ -24,22 +24,22 @@ const progressBarSelector = createSelector(
const ProgressBar = () => {
const { t } = useTranslation();
const isStarted = useIsQueueStarted();
const { data: processorStatus } = useGetProcessorStatusQuery();
const { currentStep, totalSteps, currentStatusHasSteps } =
useAppSelector(progressBarSelector);
const value = useMemo(() => {
if (currentStep && isStarted) {
if (currentStep && processorStatus?.is_processing) {
return Math.round((currentStep * 100) / totalSteps);
}
return 0;
}, [currentStep, isStarted, totalSteps]);
}, [currentStep, processorStatus?.is_processing, totalSteps]);
return (
<Progress
value={value}
aria-label={t('accessibility.invokeProgressBar')}
isIndeterminate={isStarted && !currentStatusHasSteps}
isIndeterminate={processorStatus?.is_processing && !currentStatusHasSteps}
h="full"
w="full"
borderRadius={2}

View File

@@ -441,6 +441,5 @@ const isAnyServerError = isAnyOf(
const isAnyCancelQueueItem = isAnyOf(
queueApi.endpoints.cancelQueueItem.matchFulfilled,
queueApi.endpoints.cancelQueueExecution.matchFulfilled,
queueApi.endpoints.clearQueue.matchFulfilled
);

View File

@@ -41,6 +41,7 @@ const ImageToImageTabCoreParameters = () => {
>
{shouldUseSliders ? (
<>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
<ParamModelandVAEandScheduler />
@@ -52,6 +53,7 @@ const ImageToImageTabCoreParameters = () => {
) : (
<>
<Flex gap={3}>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
</Flex>

View File

@@ -42,6 +42,7 @@ const TextToImageTabCoreParameters = () => {
>
{shouldUseSliders ? (
<>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
<ParamModelandVAEandScheduler />
@@ -53,6 +54,7 @@ const TextToImageTabCoreParameters = () => {
) : (
<>
<Flex gap={3}>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
</Flex>

View File

@@ -39,6 +39,7 @@ const UnifiedCanvasCoreParameters = () => {
>
{shouldUseSliders ? (
<>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
<ParamModelandVAEandScheduler />
@@ -50,6 +51,7 @@ const UnifiedCanvasCoreParameters = () => {
) : (
<>
<Flex gap={3}>
<ParamIterations />
<ParamSteps />
<ParamCFGScale />
</Flex>

View File

@@ -70,6 +70,7 @@ export const queueApi = api.injectEndpoints({
}),
invalidatesTags: [
'SessionQueueStatus',
'SessionProcessorStatus',
'CurrentSessionQueueItem',
'NextSessionQueueItem',
],
@@ -94,6 +95,7 @@ export const queueApi = api.injectEndpoints({
}),
invalidatesTags: [
'SessionQueueStatus',
'SessionProcessorStatus',
'CurrentSessionQueueItem',
'NextSessionQueueItem',
],
@@ -107,31 +109,19 @@ export const queueApi = api.injectEndpoints({
}
},
}),
startQueueExecution: build.mutation<void, void>({
resumeProcessor: build.mutation<void, void>({
query: () => ({
url: `queue/${$queueId.get()}/start`,
url: `queue/${$queueId.get()}/resume`,
method: 'PUT',
}),
invalidatesTags: ['SessionQueueStatus', 'CurrentSessionQueueItem'],
invalidatesTags: ['SessionProcessorStatus', 'CurrentSessionQueueItem'],
}),
stopQueueExecution: build.mutation<void, void>({
pauseProcessor: build.mutation<void, void>({
query: () => ({
url: `queue/${$queueId.get()}/stop`,
url: `queue/${$queueId.get()}/pause`,
method: 'PUT',
}),
invalidatesTags: ['SessionQueueStatus', 'CurrentSessionQueueItem'],
}),
cancelQueueExecution: build.mutation<void, void>({
query: () => ({
url: `queue/${$queueId.get()}/cancel`,
method: 'PUT',
}),
invalidatesTags: [
'SessionQueueStatus',
'SessionQueueItem',
'SessionQueueItemDTO',
'CurrentSessionQueueItem',
],
invalidatesTags: ['SessionProcessorStatus', 'CurrentSessionQueueItem'],
}),
pruneQueue: build.mutation<
paths['/api/v1/queue/{queue_id}/prune']['put']['responses']['200']['content']['application/json'],
@@ -143,6 +133,8 @@ export const queueApi = api.injectEndpoints({
}),
invalidatesTags: [
'SessionQueueStatus',
'SessionProcessorStatus',
'BatchStatus',
'SessionQueueItem',
'SessionQueueItemDTO',
],
@@ -166,6 +158,8 @@ export const queueApi = api.injectEndpoints({
}),
invalidatesTags: [
'SessionQueueStatus',
'SessionProcessorStatus',
'BatchStatus',
'CurrentSessionQueueItem',
'NextSessionQueueItem',
'SessionQueueItem',
@@ -197,12 +191,12 @@ export const queueApi = api.injectEndpoints({
return tags;
},
}),
peekNextQueueItem: build.query<
paths['/api/v1/queue/{queue_id}/peek']['get']['responses']['200']['content']['application/json'],
getNextQueueItem: build.query<
paths['/api/v1/queue/{queue_id}/next']['get']['responses']['200']['content']['application/json'],
void
>({
query: () => ({
url: `queue/${$queueId.get()}/peek`,
url: `queue/${$queueId.get()}/next`,
method: 'GET',
}),
providesTags: (result) => {
@@ -223,6 +217,31 @@ export const queueApi = api.injectEndpoints({
}),
providesTags: ['SessionQueueStatus'],
}),
getProcessorStatus: build.query<
paths['/api/v1/queue/{queue_id}/processor/status']['get']['responses']['200']['content']['application/json'],
void
>({
query: () => ({
url: `queue/${$queueId.get()}/processor/status`,
method: 'GET',
}),
providesTags: ['SessionProcessorStatus'],
}),
getBatchStatus: build.query<
paths['/api/v1/queue/{queue_id}/b/{batch_id}/status']['get']['responses']['200']['content']['application/json'],
{ batch_id: string }
>({
query: ({ batch_id }) => ({
url: `queue/${$queueId.get()}/b/${batch_id}/status`,
method: 'GET',
}),
providesTags: (result) => {
if (!result) {
return [];
}
return [{ type: 'BatchStatus', id: result.batch_id }];
},
}),
getQueueItem: build.query<
paths['/api/v1/queue/{queue_id}/i/{item_id}']['get']['responses']['200']['content']['application/json'],
string
@@ -272,6 +291,7 @@ export const queueApi = api.injectEndpoints({
return [
{ type: 'SessionQueueItem', id: result.item_id },
{ type: 'SessionQueueItemDTO', id: result.item_id },
{ type: 'BatchStatus', id: result.batch_id },
];
},
}),
@@ -293,7 +313,11 @@ export const queueApi = api.injectEndpoints({
// no-op
}
},
invalidatesTags: ['SessionQueueItem', 'SessionQueueItemDTO'],
invalidatesTags: [
'SessionQueueItem',
'SessionQueueItemDTO',
'BatchStatus',
],
}),
listQueueItems: build.query<
EntityState<components['schemas']['SessionQueueItemDTO']> & {
@@ -334,17 +358,18 @@ export const {
useCancelByBatchIdsMutation,
useEnqueueGraphMutation,
useEnqueueBatchMutation,
useCancelQueueExecutionMutation,
useStopQueueExecutionMutation,
useStartQueueExecutionMutation,
usePauseProcessorMutation,
useResumeProcessorMutation,
useClearQueueMutation,
usePruneQueueMutation,
useGetCurrentQueueItemQuery,
useGetQueueStatusQuery,
useGetQueueItemQuery,
usePeekNextQueueItemQuery,
useGetNextQueueItemQuery,
useListQueueItemsQuery,
useCancelQueueItemMutation,
useGetProcessorStatusQuery,
useGetBatchStatusQuery,
} = queueApi;
// eslint-disable-next-line @typescript-eslint/no-explicit-any

View File

@@ -22,6 +22,8 @@ export const tagTypes = [
'SessionQueueItemDTO',
'SessionQueueItemDTOList',
'SessionQueueStatus',
'SessionProcessorStatus',
'BatchStatus',
];
export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>;
export const LIST_TAG = 'LIST';

View File

@@ -328,26 +328,19 @@ export type paths = {
*/
get: operations["list_queue_items"];
};
"/api/v1/queue/{queue_id}/start": {
"/api/v1/queue/{queue_id}/resume": {
/**
* Start
* @description Starts session queue execution
* Resume
* @description Resumes session processor
*/
put: operations["start"];
put: operations["resume"];
};
"/api/v1/queue/{queue_id}/stop": {
"/api/v1/queue/{queue_id}/pause": {
/**
* Stop
* @description Stops session queue execution, waiting for the currently executing session to finish
* Pause
* @description Pauses session processor
*/
put: operations["stop"];
};
"/api/v1/queue/{queue_id}/cancel": {
/**
* Cancel
* @description Stops session queue execution, immediately canceling the currently-executing session
*/
put: operations["cancel"];
put: operations["pause"];
};
"/api/v1/queue/{queue_id}/cancel_by_batch_ids": {
/**
@@ -372,24 +365,38 @@ export type paths = {
};
"/api/v1/queue/{queue_id}/current": {
/**
* Current
* Get Current Queue Item
* @description Gets the currently execution queue item
*/
get: operations["current"];
get: operations["get_current_queue_item"];
};
"/api/v1/queue/{queue_id}/peek": {
"/api/v1/queue/{queue_id}/next": {
/**
* Peek
* Get Next Queue Item
* @description Gets the next queue item, without executing it
*/
get: operations["peek"];
get: operations["get_next_queue_item"];
};
"/api/v1/queue/{queue_id}/status": {
/**
* Get Status
* Get Queue Status
* @description Gets the status of the session queue
*/
get: operations["get_status"];
get: operations["get_queue_status"];
};
"/api/v1/queue/{queue_id}/processor/status": {
/**
* Get Processor Status
* @description Gets the status of the session queue
*/
get: operations["get_processor_status"];
};
"/api/v1/queue/{queue_id}/b/{batch_id}/status": {
/**
* Get Batch Status
* @description Gets the status of the session queue
*/
get: operations["get_batch_status"];
};
"/api/v1/queue/{queue_id}/i/{item_id}": {
/**
@@ -549,6 +556,49 @@ export type components = {
*/
items?: (string | number)[];
};
/** BatchStatusResult */
BatchStatusResult: {
/**
* Queue Id
* @description The ID of the queue
*/
queue_id: string;
/**
* Batch Id
* @description The ID of the batch
*/
batch_id: string;
/**
* Pending
* @description Number of queue items with status 'pending'
*/
pending: number;
/**
* In Progress
* @description Number of queue items with status 'in_progress'
*/
in_progress: number;
/**
* Completed
* @description Number of queue items with status 'complete'
*/
completed: number;
/**
* Failed
* @description Number of queue items with status 'error'
*/
failed: number;
/**
* Canceled
* @description Number of queue items with status 'canceled'
*/
canceled: number;
/**
* Total
* @description Total number of queue items
*/
total: number;
};
/**
* Blank Image
* @description Creates a blank image and forwards it to the pipeline
@@ -1016,7 +1066,10 @@ export type components = {
*/
type: "infill_cv2";
};
/** CancelByBatchIDsResult */
/**
* CancelByBatchIDsResult
* @description Result of canceling by list of batch ids
*/
CancelByBatchIDsResult: {
/**
* Canceled
@@ -1124,11 +1177,6 @@ export type components = {
* @description The workflow to save with the image
*/
workflow?: string;
/**
* CLIP
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
/**
* Skipped Layers
* @description Number of layers to skip in text encoder
@@ -1141,6 +1189,11 @@ export type components = {
* @enum {string}
*/
type: "clip_skip";
/**
* CLIP
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
*/
clip?: components["schemas"]["ClipField"];
};
/**
* ClipSkipInvocationOutput
@@ -2338,6 +2391,11 @@ export type components = {
};
/** EnqueueBatchResult */
EnqueueBatchResult: {
/**
* Queue Id
* @description The ID of the queue
*/
queue_id: string;
/**
* Enqueued
* @description The total number of queue items enqueued
@@ -6890,58 +6948,23 @@ export type components = {
*/
type: "segment_anything_processor";
};
/** SessionQueueAndExecutionStatusResult */
SessionQueueAndExecutionStatusResult: {
/** SessionProcessorStatusResult */
SessionProcessorStatusResult: {
/**
* Started
* @description Whether the session queue is running
* Is Started
* @description Whether the session processor is started
*/
started: boolean;
is_started: boolean;
/**
* Stop After Current
* @description Whether the session queue is pending a stop
* Is Processing
* @description Whether a session is being processed
*/
stop_after_current: boolean;
is_processing: boolean;
/**
* Queue Id
* @description The ID of the queue
* Is Stop Pending
* @description Whether processor is pending stopping
*/
queue_id: string;
/**
* Pending
* @description Number of queue items with status 'pending'
*/
pending: number;
/**
* In Progress
* @description Number of queue items with status 'in_progress'
*/
in_progress: number;
/**
* Completed
* @description Number of queue items with status 'complete'
*/
completed: number;
/**
* Failed
* @description Number of queue items with status 'error'
*/
failed: number;
/**
* Canceled
* @description Number of queue items with status 'canceled'
*/
canceled: number;
/**
* Total
* @description Total number of queue items
*/
total: number;
/**
* Max Queue Size
* @description Maximum number of queue items allowed
*/
max_queue_size: number;
is_stop_pending: boolean;
};
/**
* SessionQueueItem
@@ -7006,6 +7029,11 @@ export type components = {
* @description When this queue item was updated
*/
updated_at: string;
/**
* Started At
* @description When this queue item was started
*/
started_at?: string;
/**
* Completed At
* @description When this queue item was completed
@@ -7080,12 +7108,55 @@ export type components = {
* @description When this queue item was updated
*/
updated_at: string;
/**
* Started At
* @description When this queue item was started
*/
started_at?: string;
/**
* Completed At
* @description When this queue item was completed
*/
completed_at?: string;
};
/** SessionQueueStatusResult */
SessionQueueStatusResult: {
/**
* Queue Id
* @description The ID of the queue
*/
queue_id: string;
/**
* Pending
* @description Number of queue items with status 'pending'
*/
pending: number;
/**
* In Progress
* @description Number of queue items with status 'in_progress'
*/
in_progress: number;
/**
* Completed
* @description Number of queue items with status 'complete'
*/
completed: number;
/**
* Failed
* @description Number of queue items with status 'error'
*/
failed: number;
/**
* Canceled
* @description Number of queue items with status 'canceled'
*/
canceled: number;
/**
* Total
* @description Total number of queue items
*/
total: number;
};
/**
* Show Image
* @description Displays a provided image using the OS image viewer, and passes it forward in the pipeline.
@@ -8078,6 +8149,12 @@ export type components = {
/** Ui Order */
ui_order?: number;
};
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
@@ -8102,12 +8179,6 @@ export type components = {
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;
@@ -9622,10 +9693,10 @@ export type operations = {
};
};
/**
* Start
* @description Starts session queue execution
* Resume
* @description Resumes session processor
*/
start: {
resume: {
parameters: {
path: {
/** @description The queue id to perform this operation on */
@@ -9648,36 +9719,10 @@ export type operations = {
};
};
/**
* Stop
* @description Stops session queue execution, waiting for the currently executing session to finish
* Pause
* @description Pauses session processor
*/
stop: {
parameters: {
path: {
/** @description The queue id to perform this operation on */
queue_id: string;
};
};
responses: {
/** @description Successful Response */
200: {
content: {
"application/json": unknown;
};
};
/** @description Validation Error */
422: {
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
/**
* Cancel
* @description Stops session queue execution, immediately canceling the currently-executing session
*/
cancel: {
pause: {
parameters: {
path: {
/** @description The queue id to perform this operation on */
@@ -9783,10 +9828,10 @@ export type operations = {
};
};
/**
* Current
* Get Current Queue Item
* @description Gets the currently execution queue item
*/
current: {
get_current_queue_item: {
parameters: {
path: {
/** @description The queue id to perform this operation on */
@@ -9809,10 +9854,10 @@ export type operations = {
};
};
/**
* Peek
* Get Next Queue Item
* @description Gets the next queue item, without executing it
*/
peek: {
get_next_queue_item: {
parameters: {
path: {
/** @description The queue id to perform this operation on */
@@ -9835,10 +9880,10 @@ export type operations = {
};
};
/**
* Get Status
* Get Queue Status
* @description Gets the status of the session queue
*/
get_status: {
get_queue_status: {
parameters: {
path: {
/** @description The queue id to perform this operation on */
@@ -9849,7 +9894,61 @@ export type operations = {
/** @description Successful Response */
200: {
content: {
"application/json": components["schemas"]["SessionQueueAndExecutionStatusResult"];
"application/json": components["schemas"]["SessionQueueStatusResult"];
};
};
/** @description Validation Error */
422: {
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
/**
* Get Processor Status
* @description Gets the status of the session queue
*/
get_processor_status: {
parameters: {
path: {
/** @description The queue id to perform this operation on */
queue_id: string;
};
};
responses: {
/** @description Successful Response */
200: {
content: {
"application/json": components["schemas"]["SessionProcessorStatusResult"];
};
};
/** @description Validation Error */
422: {
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
/**
* Get Batch Status
* @description Gets the status of the session queue
*/
get_batch_status: {
parameters: {
path: {
/** @description The queue id to perform this operation on */
queue_id: string;
/** @description The batch to get the status of */
batch_id: string;
};
};
responses: {
/** @description Successful Response */
200: {
content: {
"application/json": components["schemas"]["BatchStatusResult"];
};
};
/** @description Validation Error */

View File

@@ -78,7 +78,7 @@ def mock_invoker(mock_services: InvocationServices) -> Invoker:
def test_can_create_graph_state(mock_invoker: Invoker):
g = mock_invoker.create_execution_state()
mock_invoker.stop_service()
mock_invoker.stop()
assert g is not None
assert isinstance(g, GraphExecutionState)
@@ -86,7 +86,7 @@ def test_can_create_graph_state(mock_invoker: Invoker):
def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph)
mock_invoker.stop_service()
mock_invoker.stop()
assert g is not None
assert isinstance(g, GraphExecutionState)
@@ -104,7 +104,7 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
return len(g.executed) > 0
wait_until(lambda: has_executed_any(g), timeout=5, interval=1)
mock_invoker.stop_service()
mock_invoker.stop()
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert len(g.executed) > 0
@@ -121,7 +121,7 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
return g.is_complete()
wait_until(lambda: has_executed_all(g), timeout=5, interval=1)
mock_invoker.stop_service()
mock_invoker.stop()
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert g.is_complete()
@@ -139,7 +139,7 @@ def test_handles_errors(mock_invoker: Invoker):
return g.is_complete()
wait_until(lambda: has_executed_all(g), timeout=5, interval=1)
mock_invoker.stop_service()
mock_invoker.stop()
g = mock_invoker.services.graph_execution_manager.get(g.id)
assert g.has_error()