mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(nodes,ui): fully migrate queue to session_processor
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from logging import Logger
|
||||
|
||||
@@ -11,9 +12,8 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_record_storage import SqliteImageRecordStorage
|
||||
from invokeai.app.services.images import ImageService, ImageServiceDependencies
|
||||
from invokeai.app.services.resource_name import SimpleNameService
|
||||
from invokeai.app.services.session_execution.session_execution_default import DefaultSessionExecutionService
|
||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor
|
||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from invokeai.app.services.urls import LocalUrlService
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.version.invokeai_version import __version__
|
||||
@@ -59,6 +59,7 @@ class ApiDependencies:
|
||||
|
||||
@staticmethod
|
||||
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger):
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.info(f"InvokeAI version {__version__}")
|
||||
logger.info(f"Root directory = {str(config.root_path)}")
|
||||
logger.debug(f"Internet connectivity is {config.internet_available}")
|
||||
@@ -70,8 +71,8 @@ class ApiDependencies:
|
||||
# TODO: build a file/path manager?
|
||||
db_path = config.db_path
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# db_location = str(db_path)
|
||||
db_location = ":memory:"
|
||||
db_location = str(db_path)
|
||||
# db_location = ":memory:"
|
||||
|
||||
logger.info(f"Using database at {db_location}")
|
||||
db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution
|
||||
@@ -140,7 +141,6 @@ class ApiDependencies:
|
||||
logger=logger,
|
||||
session_queue=SqliteSessionQueue(conn=db_conn, lock=lock),
|
||||
session_processor=DefaultSessionProcessor(),
|
||||
session_execution=DefaultSessionExecutionService(),
|
||||
)
|
||||
|
||||
create_system_graphs(services.graph_library)
|
||||
@@ -160,4 +160,4 @@ class ApiDependencies:
|
||||
print("SHUTTING DOWN")
|
||||
if ApiDependencies.invoker:
|
||||
print("SHUTTING DOWN INVOKER")
|
||||
ApiDependencies.invoker.stop_service()
|
||||
ApiDependencies.invoker.stop()
|
||||
|
||||
@@ -23,7 +23,7 @@ class FastAPIEventService(EventServiceBase):
|
||||
|
||||
super().__init__()
|
||||
|
||||
def stop_service(self, *args, **kwargs):
|
||||
def stop(self, *args, **kwargs):
|
||||
self.__stop_event.set()
|
||||
self.__queue.put(None)
|
||||
|
||||
|
||||
@@ -3,10 +3,11 @@ from typing import Optional
|
||||
from fastapi import Body, Path, Query
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from invokeai.app.services.session_execution.session_execution_common import SessionExecutionStatusResult
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatusResult
|
||||
from invokeai.app.services.session_queue.session_queue_common import ( # CancelByBatchIDsResult,
|
||||
QUEUE_ITEM_STATUS,
|
||||
Batch,
|
||||
BatchStatusResult,
|
||||
CancelByBatchIDsResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
@@ -24,7 +25,9 @@ from ..dependencies import ApiDependencies
|
||||
session_queue_router = APIRouter(prefix="/v1/queue", tags=["queue"])
|
||||
|
||||
|
||||
class SessionQueueAndExecutionStatusResult(SessionQueueStatusResult, SessionExecutionStatusResult):
|
||||
class SessionQueueAndProcessorStatusResult(SessionQueueStatusResult, SessionProcessorStatusResult):
|
||||
"""The overall status of session queue and processor"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -84,42 +87,25 @@ async def list_queue_items(
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/start",
|
||||
operation_id="start",
|
||||
"/{queue_id}/resume",
|
||||
operation_id="resume",
|
||||
)
|
||||
async def start(
|
||||
async def resume(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> None:
|
||||
"""Starts session queue execution"""
|
||||
return ApiDependencies.invoker.services.session_execution.start(
|
||||
queue_id=queue_id,
|
||||
)
|
||||
"""Resumes session processor"""
|
||||
return ApiDependencies.invoker.services.session_processor.resume()
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/stop",
|
||||
operation_id="stop",
|
||||
"/{queue_id}/pause",
|
||||
operation_id="pause",
|
||||
)
|
||||
async def stop(
|
||||
async def Pause(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> None:
|
||||
"""Stops session queue execution, waiting for the currently executing session to finish"""
|
||||
return ApiDependencies.invoker.services.session_execution.stop(
|
||||
queue_id=queue_id,
|
||||
)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/cancel",
|
||||
operation_id="cancel",
|
||||
)
|
||||
async def cancel(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> None:
|
||||
"""Stops session queue execution, immediately canceling the currently-executing session"""
|
||||
return ApiDependencies.invoker.services.session_execution.cancel(
|
||||
queue_id=queue_id,
|
||||
)
|
||||
"""Pauses session processor"""
|
||||
return ApiDependencies.invoker.services.session_processor.pause()
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -132,13 +118,6 @@ async def cancel_by_batch_ids(
|
||||
batch_ids: list[str] = Body(description="The list of batch_ids to cancel all queue items for", embed=True),
|
||||
) -> CancelByBatchIDsResult:
|
||||
"""Immediately cancels all queue items from the given batch ids"""
|
||||
current = ApiDependencies.invoker.services.session_execution.get_current(
|
||||
queue_id=queue_id,
|
||||
)
|
||||
if current is not None and current.batch_id in batch_ids:
|
||||
ApiDependencies.invoker.services.session_execution.cancel(
|
||||
queue_id=queue_id,
|
||||
)
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_by_batch_ids(queue_id=queue_id, batch_ids=batch_ids)
|
||||
|
||||
|
||||
@@ -153,12 +132,11 @@ async def clear(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> ClearResult:
|
||||
"""Clears the queue entirely, immediately canceling the currently-executing session"""
|
||||
ApiDependencies.invoker.services.session_execution.cancel(
|
||||
queue_id=queue_id,
|
||||
)
|
||||
return ApiDependencies.invoker.services.session_queue.clear(
|
||||
queue_id=queue_id,
|
||||
)
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
if queue_item is not None:
|
||||
ApiDependencies.invoker.services.session_queue.cancel_queue_item(queue_item.item_id)
|
||||
clear_result = ApiDependencies.invoker.services.session_queue.clear(queue_id)
|
||||
return clear_result
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -172,62 +150,78 @@ async def prune(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> PruneResult:
|
||||
"""Prunes all completed or errored queue items"""
|
||||
return ApiDependencies.invoker.services.session_queue.prune(
|
||||
queue_id=queue_id,
|
||||
)
|
||||
return ApiDependencies.invoker.services.session_queue.prune(queue_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/current",
|
||||
operation_id="current",
|
||||
operation_id="get_current_queue_item",
|
||||
responses={
|
||||
200: {"model": Optional[SessionQueueItem]},
|
||||
},
|
||||
)
|
||||
async def current(
|
||||
async def get_current_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the currently execution queue item"""
|
||||
return ApiDependencies.invoker.services.session_execution.get_current(
|
||||
queue_id=queue_id,
|
||||
)
|
||||
return ApiDependencies.invoker.services.session_queue.get_current(queue_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/peek",
|
||||
operation_id="peek",
|
||||
"/{queue_id}/next",
|
||||
operation_id="get_next_queue_item",
|
||||
responses={
|
||||
200: {"model": Optional[SessionQueueItem]},
|
||||
},
|
||||
)
|
||||
async def peek(
|
||||
async def get_next_queue_item(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the next queue item, without executing it"""
|
||||
return ApiDependencies.invoker.services.session_queue.peek(
|
||||
queue_id=queue_id,
|
||||
)
|
||||
return ApiDependencies.invoker.services.session_queue.get_next(queue_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/status",
|
||||
operation_id="get_status",
|
||||
operation_id="get_queue_status",
|
||||
responses={
|
||||
200: {"model": SessionQueueAndExecutionStatusResult},
|
||||
200: {"model": SessionQueueStatusResult},
|
||||
},
|
||||
)
|
||||
async def get_status(
|
||||
async def get_queue_status(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionQueueAndExecutionStatusResult:
|
||||
) -> SessionQueueStatusResult:
|
||||
"""Gets the status of the session queue"""
|
||||
queue_status = ApiDependencies.invoker.services.session_queue.get_status(
|
||||
queue_id=queue_id,
|
||||
)
|
||||
execution_status = ApiDependencies.invoker.services.session_execution.get_status(
|
||||
queue_id=queue_id,
|
||||
)
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_status(queue_id)
|
||||
|
||||
return SessionQueueAndExecutionStatusResult(**queue_status.dict(), **execution_status.dict())
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/processor/status",
|
||||
operation_id="get_processor_status",
|
||||
responses={
|
||||
200: {"model": SessionProcessorStatusResult},
|
||||
},
|
||||
)
|
||||
async def get_processor_status(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
) -> SessionProcessorStatusResult:
|
||||
"""Gets the status of the session queue"""
|
||||
return ApiDependencies.invoker.services.session_processor.get_status()
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
"/{queue_id}/b/{batch_id}/status",
|
||||
operation_id="get_batch_status",
|
||||
responses={
|
||||
200: {"model": BatchStatusResult},
|
||||
},
|
||||
)
|
||||
async def get_batch_status(
|
||||
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||
batch_id: str = Path(description="The batch to get the status of"),
|
||||
) -> BatchStatusResult:
|
||||
"""Gets the status of the session queue"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_batch_status(queue_id=queue_id, batch_id=batch_id)
|
||||
|
||||
|
||||
@session_queue_router.get(
|
||||
@@ -242,7 +236,7 @@ async def get_queue_item(
|
||||
item_id: str = Path(description="The queue item to get"),
|
||||
) -> SessionQueueItem:
|
||||
"""Gets a queue item"""
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_item(queue_id=queue_id, item_id=item_id)
|
||||
return ApiDependencies.invoker.services.session_queue.get_queue_item(item_id)
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
@@ -257,27 +251,5 @@ async def cancel_queue_item(
|
||||
item_id: str = Path(description="The queue item to cancel"),
|
||||
) -> SessionQueueItem:
|
||||
"""Deletes a queue item"""
|
||||
queue_item = ApiDependencies.invoker.services.session_queue.get_queue_item(queue_id=queue_id, item_id=item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
return ApiDependencies.invoker.services.session_queue.set_queue_item_status(
|
||||
queue_id=queue_id, item_id=item_id, status="canceled"
|
||||
)
|
||||
return queue_item
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/start_processor",
|
||||
operation_id="start_processor",
|
||||
)
|
||||
async def start_processor() -> None:
|
||||
"""Deletes a queue item"""
|
||||
ApiDependencies.invoker.services.session_processor.start()
|
||||
|
||||
|
||||
@session_queue_router.put(
|
||||
"/{queue_id}/stop_processor",
|
||||
operation_id="stop_processor",
|
||||
)
|
||||
async def stop_processor() -> None:
|
||||
"""Deletes a queue item"""
|
||||
ApiDependencies.invoker.services.session_processor.stop()
|
||||
return ApiDependencies.invoker.services.session_queue.cancel_queue_item(item_id)
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any, Optional
|
||||
|
||||
from invokeai.app.models.image import ProgressImage
|
||||
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
from invokeai.app.services.session_queue.session_queue_common import EnqueueBatchResult, SessionQueueItem
|
||||
from invokeai.app.util.misc import get_timestamp
|
||||
|
||||
|
||||
@@ -195,8 +195,20 @@ class EventServiceBase:
|
||||
),
|
||||
)
|
||||
|
||||
def emit_session_canceled(
|
||||
self,
|
||||
graph_execution_state_id: str,
|
||||
) -> None:
|
||||
"""Emitted when a session is canceled"""
|
||||
self.__emit_session_event(
|
||||
event_name="session_canceled",
|
||||
payload=dict(
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_queue_item_status_changed(self, session_queue_item: SessionQueueItem) -> None:
|
||||
"""Emitted when a queue item is status_changed"""
|
||||
"""Emitted when a queue item's status changes"""
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_item_status_changed",
|
||||
payload=dict(
|
||||
@@ -207,3 +219,17 @@ class EventServiceBase:
|
||||
status=session_queue_item.status,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_batch_enqueued(self, enqueue_result: EnqueueBatchResult) -> None:
|
||||
"""Emitted when a batch is enqueued"""
|
||||
self.__emit_queue_event(
|
||||
event_name="batch_enqueued",
|
||||
payload=enqueue_result.dict(),
|
||||
)
|
||||
|
||||
def emit_queue_cleared(self, queue_id: str) -> None:
|
||||
"""Emitted when the queue is cleared"""
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_cleared",
|
||||
payload=dict(queue_id=queue_id),
|
||||
)
|
||||
|
||||
@@ -18,7 +18,6 @@ if TYPE_CHECKING:
|
||||
from invokeai.app.services.item_storage import ItemStorageABC
|
||||
from invokeai.app.services.latent_storage import LatentsStorageBase
|
||||
from invokeai.app.services.model_manager_service import ModelManagerServiceBase
|
||||
from invokeai.app.services.session_execution.session_execution_base import SessionExecutionServiceBase
|
||||
from invokeai.app.services.session_processor.session_processor_base import SessionProcessorBase
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
|
||||
@@ -42,7 +41,6 @@ class InvocationServices:
|
||||
queue: "InvocationQueueABC"
|
||||
session_queue: "SessionQueueBase"
|
||||
session_processor: "SessionProcessorBase"
|
||||
session_execution: "SessionExecutionServiceBase"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -61,7 +59,6 @@ class InvocationServices:
|
||||
queue: "InvocationQueueABC",
|
||||
session_queue: "SessionQueueBase",
|
||||
session_processor: "SessionProcessorBase",
|
||||
session_execution: "SessionExecutionServiceBase",
|
||||
):
|
||||
self.board_images = board_images
|
||||
self.boards = boards
|
||||
@@ -78,4 +75,3 @@ class InvocationServices:
|
||||
self.queue = queue
|
||||
self.session_queue = session_queue
|
||||
self.session_processor = session_processor
|
||||
self.session_execution = session_execution
|
||||
|
||||
@@ -52,14 +52,14 @@ class Invoker:
|
||||
self.services.queue.cancel(graph_execution_state_id)
|
||||
|
||||
def __start_service(self, service) -> None:
|
||||
# Call start_service() method on any services that have it
|
||||
start_op = getattr(service, "start_service", None)
|
||||
# Call start() method on any services that have it
|
||||
start_op = getattr(service, "start", None)
|
||||
if callable(start_op):
|
||||
start_op(self)
|
||||
|
||||
def __stop_service(self, service) -> None:
|
||||
# Call stop_service() method on any services that have it
|
||||
stop_op = getattr(service, "stop_service", None)
|
||||
# Call stop() method on any services that have it
|
||||
stop_op = getattr(service, "stop", None)
|
||||
if callable(stop_op):
|
||||
stop_op(self)
|
||||
|
||||
@@ -68,7 +68,7 @@ class Invoker:
|
||||
for service in vars(self.services):
|
||||
self.__start_service(getattr(self.services, service))
|
||||
|
||||
def stop_service(self) -> None:
|
||||
def stop(self) -> None:
|
||||
"""Stops the invoker. A new invoker will have to be created to execute further."""
|
||||
# First stop all services
|
||||
for service in vars(self.services):
|
||||
|
||||
@@ -17,7 +17,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
__invoker: Invoker
|
||||
__threadLimit: BoundedSemaphore
|
||||
|
||||
def start_service(self, invoker) -> None:
|
||||
def start(self, invoker) -> None:
|
||||
# if we do want multithreading at some point, we could make this configurable
|
||||
self.__threadLimit = BoundedSemaphore(1)
|
||||
self.__invoker = invoker
|
||||
@@ -30,7 +30,7 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
self.__invoker_thread.daemon = True # TODO: make async and do not use threads
|
||||
self.__invoker_thread.start()
|
||||
|
||||
def stop_service(self, *args, **kwargs) -> None:
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self.__stop_event.set()
|
||||
|
||||
def __process(self, stop_event: Event):
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_execution.session_execution_common import SessionExecutionStatusResult
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
|
||||
|
||||
class SessionExecutionServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def start_service(self, invoker: Invoker) -> None:
|
||||
"""Service startup"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start(
|
||||
self,
|
||||
queue_id: str,
|
||||
) -> None:
|
||||
"""Starts session queue execution"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def stop(
|
||||
self,
|
||||
queue_id: str,
|
||||
) -> None:
|
||||
"""Stops session queue execution after the currently executing session finishes"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel(
|
||||
self,
|
||||
queue_id: str,
|
||||
) -> None:
|
||||
"""Stops session queue execution, immediately canceling the currently-executing session"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_current(
|
||||
self,
|
||||
queue_id: str,
|
||||
) -> Optional[SessionQueueItem]:
|
||||
"""Gets the currently-executing queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_status(
|
||||
self,
|
||||
queue_id: str,
|
||||
) -> SessionExecutionStatusResult:
|
||||
"""Gets the status of the session queue"""
|
||||
pass
|
||||
@@ -1,6 +0,0 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SessionExecutionStatusResult(BaseModel):
|
||||
started: bool = Field(..., description="Whether the session queue is running")
|
||||
stop_after_current: bool = Field(..., description="Whether the session queue is pending a stop")
|
||||
@@ -1,107 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_execution.session_execution_base import SessionExecutionServiceBase
|
||||
from invokeai.app.services.session_execution.session_execution_common import SessionExecutionStatusResult
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
|
||||
|
||||
class DefaultSessionExecutionService(SessionExecutionServiceBase):
|
||||
def __init__(self) -> None:
|
||||
self._invoker: Invoker
|
||||
self._current: Optional[SessionQueueItem] = None
|
||||
self._started: bool = False
|
||||
self._stop_after_current = False
|
||||
|
||||
def start_service(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._on_event)
|
||||
|
||||
async def _on_event(self, event: Event) -> Event:
|
||||
event_name = event[1]["event"]
|
||||
match event_name:
|
||||
case "graph_execution_state_complete":
|
||||
await self._handle_complete_event(event, False)
|
||||
case "invocation_error":
|
||||
await self._handle_complete_event(event, True)
|
||||
case "session_retrieval_error":
|
||||
await self._handle_complete_event(event, True)
|
||||
case "invocation_retrieval_error":
|
||||
await self._handle_complete_event(event, True)
|
||||
return event
|
||||
|
||||
async def _handle_complete_event(self, event: Event, err: bool) -> None:
|
||||
data = event[1]["data"]
|
||||
queue_item = self._invoker.services.session_queue.get_queue_item_by_session_id(data["graph_execution_state_id"])
|
||||
# Sessions are marked complete when they have an error, so we get an `invocation_error`
|
||||
# followed by a `graph_execution_state_complete`. Don't mark queue items complete if
|
||||
# they are already marked error.
|
||||
if queue_item.status != "failed":
|
||||
queue_item = self._invoker.services.session_queue.set_queue_item_status(
|
||||
queue_id=queue_item.queue_id, item_id=queue_item.item_id, status="failed" if err else "completed"
|
||||
)
|
||||
self._invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
if self._stop_after_current:
|
||||
self._stop_after_current = False
|
||||
self._started = False
|
||||
self._current = None
|
||||
if self._started:
|
||||
self.invoke_next(queue_id=queue_item.queue_id)
|
||||
|
||||
def _emit_queue_item_status(self) -> None:
|
||||
if self._current is None:
|
||||
return
|
||||
self._invoker.services.events.emit_queue_item_status_changed(self._current)
|
||||
|
||||
def invoke_next(self, queue_id: str) -> None:
|
||||
# do not invoke if already invoking
|
||||
if self._current:
|
||||
return
|
||||
|
||||
queue_item = self._invoker.services.session_queue.dequeue()
|
||||
|
||||
if queue_item is None:
|
||||
# queue empty
|
||||
self._current = None
|
||||
self._started = False
|
||||
self._stop_after_current = False
|
||||
return
|
||||
|
||||
self._current = queue_item
|
||||
# execute the session
|
||||
self._invoker.services.graph_execution_manager.set(queue_item.session)
|
||||
self._emit_queue_item_status()
|
||||
# self._invoker.invoke(self._current.session, invoke_all=True)
|
||||
|
||||
def start(
|
||||
self,
|
||||
queue_id: str,
|
||||
) -> None:
|
||||
if not self._stop_after_current:
|
||||
self._started = True
|
||||
self.invoke_next(queue_id=queue_id)
|
||||
|
||||
def stop(self, queue_id: str) -> None:
|
||||
self._started = False
|
||||
self._stop_after_current = True
|
||||
|
||||
def cancel(self, queue_id: str) -> None:
|
||||
if self._current is not None:
|
||||
self._invoker.services.queue.cancel(self._current.session_id)
|
||||
self._current = self._invoker.services.session_queue.set_queue_item_status(
|
||||
queue_id=self._current.queue_id, item_id=self._current.item_id, status="canceled"
|
||||
)
|
||||
self._emit_queue_item_status()
|
||||
self._current = None
|
||||
self._started = False
|
||||
self._stop_after_current = False
|
||||
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
return self._current
|
||||
|
||||
def get_status(self, queue_id: str) -> SessionExecutionStatusResult:
|
||||
return SessionExecutionStatusResult(started=self._started, stop_after_current=self._stop_after_current)
|
||||
@@ -1,21 +1,28 @@
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatusResult
|
||||
|
||||
|
||||
class SessionProcessorABC(ABC):
|
||||
def start(self) -> None:
|
||||
class SessionProcessorBase(ABC):
|
||||
"""
|
||||
Base class for session processor.
|
||||
|
||||
The session processor is responsible for executing sessions. It runs a simple polling loop,
|
||||
checking the session queue for new sessions to execute. It must coordinate with the
|
||||
invocation queue to ensure only one session is executing at a time.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def resume(self) -> None:
|
||||
"""Starts or resumes the session processor"""
|
||||
pass
|
||||
|
||||
def stop(self) -> None:
|
||||
@abstractmethod
|
||||
def pause(self) -> None:
|
||||
"""Pauses the session processor"""
|
||||
pass
|
||||
|
||||
def poll_now(self) -> None:
|
||||
pass
|
||||
|
||||
def get_current(self) -> Optional[SessionQueueItem]:
|
||||
pass
|
||||
|
||||
def clear_current(self) -> None:
|
||||
@abstractmethod
|
||||
def get_status(self) -> SessionProcessorStatusResult:
|
||||
"""Gets the status of the session processor"""
|
||||
pass
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
POLLING_INTERVAL = 1
|
||||
THREAD_LIMIT = 1
|
||||
FINISHED_SESSION_EVENTS = [
|
||||
"graph_execution_state_complete",
|
||||
"invocation_error",
|
||||
"session_retrieval_error",
|
||||
"invocation_retrieval_error",
|
||||
]
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SessionProcessorStatusResult(BaseModel):
|
||||
is_started: bool = Field(description="Whether the session processor is started")
|
||||
is_processing: bool = Field(description="Whether a session is being processed")
|
||||
is_stop_pending: bool = Field(description="Whether processor is pending stopping")
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from threading import BoundedSemaphore, Event as ThreadEvent, Thread
|
||||
from threading import BoundedSemaphore
|
||||
from threading import Event as ThreadEvent
|
||||
from threading import Thread
|
||||
from typing import Optional
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
@@ -8,99 +10,128 @@ from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||
|
||||
from ..invoker import Invoker
|
||||
from .session_processor_base import SessionProcessorABC
|
||||
from .session_processor_common import FINISHED_SESSION_EVENTS, POLLING_INTERVAL, THREAD_LIMIT
|
||||
from .session_processor_base import SessionProcessorBase
|
||||
from .session_processor_common import SessionProcessorStatusResult
|
||||
|
||||
POLLING_INTERVAL = 1
|
||||
THREAD_LIMIT = 1
|
||||
|
||||
|
||||
class DefaultSessionProcessor(SessionProcessorABC):
|
||||
def start_service(self, invoker: Invoker) -> None:
|
||||
class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker: Invoker = invoker
|
||||
self.__queue_item: Optional[SessionQueueItem] = None
|
||||
|
||||
self.__stop_event = ThreadEvent()
|
||||
# when a session is finished, we need to poll the queue immediately.
|
||||
# because we need to wait for current item to finish and also wait
|
||||
# if the queue is empty, need two events.
|
||||
self.__poll_now_busy_event = ThreadEvent()
|
||||
self.__poll_now_queue_event = ThreadEvent()
|
||||
self.__poll_now_event = ThreadEvent()
|
||||
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._on_session_event)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||
|
||||
self.__threadLimit = BoundedSemaphore(THREAD_LIMIT)
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._on_event)
|
||||
self._start_thread()
|
||||
|
||||
def stop_service(self, *args, **kwargs) -> None:
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self.__stop_event.set()
|
||||
|
||||
def _poll_now(self) -> None:
|
||||
self.__poll_now_busy_event.set()
|
||||
self.__poll_now_queue_event.set()
|
||||
self.__poll_now_event.set()
|
||||
|
||||
def _start_thread(self) -> None:
|
||||
# threads only live once, so we need to create a new one whenever we start the processor
|
||||
# threads only live once, so we need to create a new one whenever we start the session processor
|
||||
self.__thread = Thread(
|
||||
name="session_processor",
|
||||
target=self.__process,
|
||||
kwargs=dict(
|
||||
stop_event=self.__stop_event,
|
||||
poll_now_busy_event=self.__poll_now_busy_event,
|
||||
poll_now_queue_event=self.__poll_now_queue_event,
|
||||
poll_now_event=self.__poll_now_event,
|
||||
),
|
||||
)
|
||||
self.__thread.start()
|
||||
|
||||
async def _on_event(self, event: FastAPIEvent) -> None:
|
||||
async def _on_session_event(self, event: FastAPIEvent) -> None:
|
||||
event_name = event[1]["event"]
|
||||
if event_name in FINISHED_SESSION_EVENTS:
|
||||
if event_name in [
|
||||
"graph_execution_state_complete",
|
||||
"invocation_error",
|
||||
"session_retrieval_error",
|
||||
"invocation_retrieval_error",
|
||||
] or (
|
||||
event_name == "session_canceled"
|
||||
and self.__queue_item is not None
|
||||
and self.__queue_item.session_id == event[1]["data"]["graph_execution_state_id"]
|
||||
):
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
|
||||
def stop(self) -> None:
|
||||
self.__stop_event.set()
|
||||
async def _on_queue_event(self, event: FastAPIEvent) -> None:
|
||||
event_name = event[1]["event"]
|
||||
if event_name == "batch_enqueued":
|
||||
self._poll_now()
|
||||
if event_name == "queue_cleared":
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
|
||||
def start(self) -> None:
|
||||
if self.__thread.is_alive():
|
||||
def _is_started(self) -> bool:
|
||||
return self.__thread.is_alive()
|
||||
|
||||
def _is_processing(self) -> bool:
|
||||
return self.__queue_item is not None
|
||||
|
||||
def _is_stop_pending(self) -> bool:
|
||||
return self.__stop_event.is_set()
|
||||
|
||||
def get_status(self) -> SessionProcessorStatusResult:
|
||||
return SessionProcessorStatusResult(
|
||||
is_started=self._is_started(),
|
||||
is_processing=self._is_processing(),
|
||||
is_stop_pending=self._is_stop_pending(),
|
||||
)
|
||||
|
||||
def resume(self) -> None:
|
||||
if self._is_started():
|
||||
return
|
||||
self.__stop_event.clear()
|
||||
self._start_thread()
|
||||
|
||||
def poll_now(self) -> None:
|
||||
self._poll_now()
|
||||
|
||||
def get_current(self) -> Optional[SessionQueueItem]:
|
||||
return self.__queue_item
|
||||
|
||||
def clear_current(self) -> None:
|
||||
self.__queue_item = None
|
||||
def pause(self) -> None:
|
||||
self.__stop_event.set()
|
||||
|
||||
def __process(
|
||||
self,
|
||||
stop_event: ThreadEvent,
|
||||
poll_now_busy_event: ThreadEvent,
|
||||
poll_now_queue_event: ThreadEvent,
|
||||
poll_now_event: ThreadEvent,
|
||||
):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
|
||||
queue_item: Optional[SessionQueueItem] = None
|
||||
self.__invoker.services.logger
|
||||
while not stop_event.is_set():
|
||||
poll_now_busy_event.clear()
|
||||
poll_now_queue_event.clear()
|
||||
poll_now_event.clear()
|
||||
|
||||
# do not dequeue if there is already a session running
|
||||
if self.__queue_item is not None:
|
||||
poll_now_busy_event.wait(POLLING_INTERVAL)
|
||||
continue
|
||||
|
||||
# get next queue item
|
||||
self.__queue_item = self.__invoker.services.session_queue.dequeue()
|
||||
|
||||
if self.__queue_item is None:
|
||||
poll_now_queue_event.wait(POLLING_INTERVAL)
|
||||
continue
|
||||
queue_item = self.__invoker.services.session_queue.dequeue()
|
||||
|
||||
self.__invoker.services.graph_execution_manager.set(self.__queue_item.session)
|
||||
self.__invoker.invoke(self.__queue_item.session, invoke_all=True)
|
||||
if queue_item is not None:
|
||||
# TODO: Why isn't the log level specified in dependencies.py working?
|
||||
# Within the thread, it is always INFO and `logger.debug()` doesn't display.
|
||||
# self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
|
||||
print(f"Executing queue item {queue_item.item_id}")
|
||||
self.__queue_item = queue_item
|
||||
self.__invoker.services.graph_execution_manager.set(queue_item.session)
|
||||
self.__invoker.invoke(queue_item.session, invoke_all=True)
|
||||
queue_item = None
|
||||
|
||||
if queue_item is None:
|
||||
# self.__invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
print("Waiting for next polling interval or event")
|
||||
poll_now_event.wait(POLLING_INTERVAL)
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
stop_event.clear()
|
||||
poll_now_event.clear()
|
||||
self.__queue_item = None
|
||||
self.__threadLimit.release()
|
||||
|
||||
@@ -2,11 +2,12 @@ from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.graph import Graph
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
QUEUE_ITEM_STATUS,
|
||||
Batch,
|
||||
BatchStatusResult,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
EnqueueGraphResult,
|
||||
@@ -16,15 +17,16 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItem,
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueStatusResult,
|
||||
SetManyQueueItemStatusResult,
|
||||
)
|
||||
from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||
|
||||
|
||||
class SessionQueueBase(ABC):
|
||||
"""Base class for session queue"""
|
||||
|
||||
@abstractmethod
|
||||
def start_service(self, invoker: Invoker) -> None:
|
||||
"""Startup callback for the SessionQueue service"""
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||
"""Dequeues the next session queue item."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -38,13 +40,13 @@ class SessionQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||
"""Dequeues the next session queue item, returning it if one is available."""
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
"""Gets the currently-executing session queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def peek(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
"""Peeks at the next session queue item, returning it if one is available."""
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
"""Gets the next session queue item (does not dequeue it)"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -67,14 +69,29 @@ class SessionQueueBase(ABC):
|
||||
"""Checks if the queue is empty"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatusResult:
|
||||
"""Gets the status of the queue"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatusResult:
|
||||
"""Gets the status of a batch"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_queue_item(self, item_id: str) -> SessionQueueItem:
|
||||
"""Cancels a session queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
"""Cancels all queue items with matching batch IDs"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_status(self, queue_id: str) -> SessionQueueStatusResult:
|
||||
"""Gets the number of queue items with each status"""
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
"""Cancels all queue items with matching queue ID"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -90,7 +107,7 @@ class SessionQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_queue_item(self, queue_id: str, item_id: str) -> SessionQueueItem:
|
||||
def get_queue_item(self, item_id: str) -> SessionQueueItem:
|
||||
"""Gets a session queue item by ID"""
|
||||
pass
|
||||
|
||||
@@ -98,20 +115,3 @@ class SessionQueueBase(ABC):
|
||||
def get_queue_item_by_session_id(self, session_id: str) -> SessionQueueItem:
|
||||
"""Gets a queue item by session ID"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_queue_item_status(self, queue_id: str, item_id: str, status: QUEUE_ITEM_STATUS) -> SessionQueueItem:
|
||||
"""Sets the status of a session queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set_many_queue_item_status(
|
||||
self, queue_id: str, item_ids: list[str], status: QUEUE_ITEM_STATUS
|
||||
) -> SetManyQueueItemStatusResult:
|
||||
"""Sets the status of a session queue item"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_queue_item(self, queue_id: str, item_id: str) -> SessionQueueItem:
|
||||
"""Deletes a session queue item by ID"""
|
||||
pass
|
||||
|
||||
@@ -167,6 +167,7 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
||||
error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
|
||||
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
|
||||
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
|
||||
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
|
||||
completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
|
||||
|
||||
@classmethod
|
||||
@@ -237,15 +238,21 @@ class SessionQueueStatusResult(BaseModel):
|
||||
failed: int = Field(..., description="Number of queue items with status 'error'")
|
||||
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
|
||||
total: int = Field(..., description="Total number of queue items")
|
||||
max_queue_size: int = Field(..., description="Maximum number of queue items allowed")
|
||||
|
||||
|
||||
class SetManyQueueItemStatusResult(BaseModel):
|
||||
item_ids: list[str] = Field(..., description="The queue item IDs that were updated")
|
||||
status: QUEUE_ITEM_STATUS = Field(..., description="The new status of the queue items")
|
||||
class BatchStatusResult(BaseModel):
|
||||
queue_id: str = Field(..., description="The ID of the queue")
|
||||
batch_id: str = Field(..., description="The ID of the batch")
|
||||
pending: int = Field(..., description="Number of queue items with status 'pending'")
|
||||
in_progress: int = Field(..., description="Number of queue items with status 'in_progress'")
|
||||
completed: int = Field(..., description="Number of queue items with status 'complete'")
|
||||
failed: int = Field(..., description="Number of queue items with status 'error'")
|
||||
canceled: int = Field(..., description="Number of queue items with status 'canceled'")
|
||||
total: int = Field(..., description="Total number of queue items")
|
||||
|
||||
|
||||
class EnqueueBatchResult(BaseModel):
|
||||
queue_id: str = Field(description="The ID of the queue")
|
||||
enqueued: int = Field(description="The total number of queue items enqueued")
|
||||
requested: int = Field(description="The total number of queue items requested to be enqueued")
|
||||
batch: Batch = Field(description="The batch that was enqueued")
|
||||
@@ -273,9 +280,17 @@ class PruneResult(ClearResult):
|
||||
|
||||
|
||||
class CancelByBatchIDsResult(BaseModel):
|
||||
"""Result of canceling by list of batch ids"""
|
||||
|
||||
canceled: int = Field(..., description="Number of queue items canceled")
|
||||
|
||||
|
||||
class CancelByQueueIDResult(CancelByBatchIDsResult):
|
||||
"""Result of canceling by queue id"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class IsEmptyResult(BaseModel):
|
||||
"""Result of checking if the session queue is empty"""
|
||||
|
||||
|
||||
@@ -2,6 +2,10 @@ import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event as FastAPIEvent
|
||||
|
||||
from invokeai.app.services.events import EventServiceBase
|
||||
from invokeai.app.services.graph import Graph
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
@@ -9,7 +13,9 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
DEFAULT_QUEUE_ID,
|
||||
QUEUE_ITEM_STATUS,
|
||||
Batch,
|
||||
BatchStatusResult,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
EnqueueBatchResult,
|
||||
EnqueueGraphResult,
|
||||
@@ -20,7 +26,6 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
SessionQueueItemDTO,
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatusResult,
|
||||
SetManyQueueItemStatusResult,
|
||||
calc_session_count,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
@@ -28,119 +33,203 @@ from invokeai.app.services.shared.models import CursorPaginatedResults
|
||||
|
||||
|
||||
class SqliteSessionQueue(SessionQueueBase):
|
||||
_invoker: Invoker
|
||||
_conn: sqlite3.Connection
|
||||
_cursor: sqlite3.Cursor
|
||||
_lock: threading.Lock
|
||||
__invoker: Invoker
|
||||
__conn: sqlite3.Connection
|
||||
__cursor: sqlite3.Cursor
|
||||
__lock: threading.Lock
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
self._set_in_progress_to_canceled()
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._on_session_event)
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||
super().__init__()
|
||||
self._conn = conn
|
||||
self.__conn = conn
|
||||
# Enable row factory to get rows as dictionaries (must be done before making the cursor!)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._cursor = self._conn.cursor()
|
||||
self._lock = lock
|
||||
self.__conn.row_factory = sqlite3.Row
|
||||
self.__cursor = self.__conn.cursor()
|
||||
self.__lock = lock
|
||||
self._create_tables()
|
||||
|
||||
def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
|
||||
return event[1]["event"] in match_in
|
||||
|
||||
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
|
||||
event_name = event[1]["event"]
|
||||
match event_name:
|
||||
# successful completion events
|
||||
case "graph_execution_state_complete":
|
||||
await self._handle_complete_event(event)
|
||||
# error events
|
||||
case "invocation_error":
|
||||
await self._handle_error_event(event)
|
||||
case "session_retrieval_error":
|
||||
await self._handle_error_event(event)
|
||||
case "invocation_retrieval_error":
|
||||
await self._handle_error_event(event)
|
||||
# canceled events
|
||||
case "session_canceled":
|
||||
await self._handle_cancel_event(event)
|
||||
return event
|
||||
|
||||
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._create_tables()
|
||||
self._conn.commit()
|
||||
finally:
|
||||
self._lock.release()
|
||||
session_id = event[1]["data"]["graph_execution_state_id"]
|
||||
# When a queue item has an error, we get an error event, then a completed event.
|
||||
# Mark the queue item completed only if it isn't already marked completed, e.g.
|
||||
# by a previously-handled error event.
|
||||
queue_item = self.get_queue_item_by_session_id(session_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
async def _handle_error_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
session_id = event[1]["data"]["graph_execution_state_id"]
|
||||
error = event[1]["data"]["error"]
|
||||
queue_item = self.get_queue_item_by_session_id(session_id)
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
session_id = event[1]["data"]["graph_execution_state_id"]
|
||||
queue_item = self.get_queue_item_by_session_id(session_id)
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
except SessionQueueItemNotFoundError:
|
||||
return
|
||||
|
||||
def _create_tables(self) -> None:
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS session_queue (
|
||||
item_id TEXT NOT NULL PRIMARY KEY, -- the unique identifier of this queue item
|
||||
order_id INTEGER NOT NULL, -- used for ordering, cursor pagination
|
||||
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
|
||||
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
|
||||
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
|
||||
field_values TEXT, -- NULL if no values are associated with this queue item
|
||||
session TEXT NOT NULL, -- the session to be executed
|
||||
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'complete', 'error', 'canceled'
|
||||
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
|
||||
error TEXT, -- any errors associated with this queue item
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
|
||||
completed_at DATETIME -- completed items are cleaned up on application startup
|
||||
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
|
||||
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
"""Creates the session queue tables, indicies, and triggers"""
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TABLE IF NOT EXISTS session_queue (
|
||||
item_id TEXT NOT NULL PRIMARY KEY, -- the unique identifier of this queue item
|
||||
order_id INTEGER NOT NULL, -- used for ordering, cursor pagination
|
||||
batch_id TEXT NOT NULL, -- identifier of the batch this queue item belongs to
|
||||
queue_id TEXT NOT NULL, -- identifier of the queue this queue item belongs to
|
||||
session_id TEXT NOT NULL UNIQUE, -- duplicated data from the session column, for ease of access
|
||||
field_values TEXT, -- NULL if no values are associated with this queue item
|
||||
session TEXT NOT NULL, -- the session to be executed
|
||||
status TEXT NOT NULL DEFAULT 'pending', -- the status of the queue item, one of 'pending', 'in_progress', 'complete', 'error', 'canceled'
|
||||
priority INTEGER NOT NULL DEFAULT 0, -- the priority, higher is more important
|
||||
error TEXT, -- any errors associated with this queue item
|
||||
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
|
||||
updated_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- updated via trigger
|
||||
started_at DATETIME, -- updated via trigger
|
||||
completed_at DATETIME -- updated via trigger, completed items are cleaned up on application startup
|
||||
-- Ideally this is a FK, but graph_executions uses INSERT OR REPLACE, and REPLACE triggers the ON DELETE CASCADE...
|
||||
-- FOREIGN KEY (session_id) REFERENCES graph_executions (id) ON DELETE CASCADE
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
|
||||
"""
|
||||
)
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_item_id ON session_queue(item_id);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_order_id ON session_queue(order_id);
|
||||
"""
|
||||
)
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_order_id ON session_queue(order_id);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
|
||||
"""
|
||||
)
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_session_queue_session_id ON session_queue(session_id);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
|
||||
"""
|
||||
)
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_batch_id ON session_queue(batch_id);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
|
||||
"""
|
||||
)
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_created_priority ON session_queue(priority);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
|
||||
"""
|
||||
)
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE INDEX IF NOT EXISTS idx_session_queue_created_status ON session_queue(status);
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
|
||||
AFTER UPDATE OF status ON session_queue
|
||||
FOR EACH ROW
|
||||
WHEN
|
||||
NEW.status = 'completed'
|
||||
OR NEW.status = 'failed'
|
||||
OR NEW.status = 'canceled'
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = NEW.item_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_completed_at
|
||||
AFTER UPDATE OF status ON session_queue
|
||||
FOR EACH ROW
|
||||
WHEN
|
||||
NEW.status = 'completed'
|
||||
OR NEW.status = 'failed'
|
||||
OR NEW.status = 'canceled'
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET completed_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = NEW.item_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self._cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
|
||||
AFTER UPDATE
|
||||
ON session_queue FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = old.item_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_started_at
|
||||
AFTER UPDATE OF status ON session_queue
|
||||
FOR EACH ROW
|
||||
WHEN
|
||||
NEW.status = 'in_progress'
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET started_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = NEW.item_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
CREATE TRIGGER IF NOT EXISTS tg_session_queue_updated_at
|
||||
AFTER UPDATE
|
||||
ON session_queue FOR EACH ROW
|
||||
BEGIN
|
||||
UPDATE session_queue
|
||||
SET updated_at = STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')
|
||||
WHERE item_id = old.item_id;
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""
|
||||
Sets all in_progress queue items to canceled.
|
||||
This is necessary because the invoker may have been killed while processing a queue item.
|
||||
"""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
@@ -148,13 +237,14 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
"""
|
||||
)
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
self.__lock.release()
|
||||
|
||||
def _get_current_queue_size(self, queue_id: str) -> int:
|
||||
self._cursor.execute(
|
||||
"""Gets the current number of pending queue items"""
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
@@ -164,10 +254,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(int, self._cursor.fetchone()[0])
|
||||
return cast(int, self.__cursor.fetchone()[0])
|
||||
|
||||
def _get_highest_priority(self, queue_id: str) -> int:
|
||||
self._cursor.execute(
|
||||
"""Gets the highest priority value in the queue"""
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT MAX(priority)
|
||||
FROM session_queue
|
||||
@@ -177,19 +268,13 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
return cast(Union[int, None], self._cursor.fetchone()[0]) or 0
|
||||
|
||||
def start_service(self, invoker: Invoker) -> None:
|
||||
self._invoker = invoker
|
||||
self._set_in_progress_to_canceled()
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
self._invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0
|
||||
|
||||
def enqueue_graph(self, queue_id: str, graph: Graph, prepend: bool) -> EnqueueGraphResult:
|
||||
enqueue_result = self.enqueue_batch(queue_id=queue_id, batch=Batch(graph=graph), prepend=prepend)
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
@@ -198,12 +283,12 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id, enqueue_result.batch.batch_id),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with batch id {enqueue_result.batch.batch_id}")
|
||||
return EnqueueGraphResult(
|
||||
@@ -213,24 +298,24 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self.__lock.acquire()
|
||||
|
||||
# TODO: how does this work in a multi-user scenario?
|
||||
current_queue_size = self._get_current_queue_size(queue_id=queue_id)
|
||||
max_queue_size = self._invoker.services.configuration.get_config().max_queue_size
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.get_config().max_queue_size
|
||||
max_new_queue_items = max_queue_size - current_queue_size
|
||||
|
||||
priority = 0
|
||||
if prepend:
|
||||
priority = self._get_highest_priority(queue_id=queue_id) + 1
|
||||
priority = self._get_highest_priority(queue_id) + 1
|
||||
|
||||
self._cursor.execute(
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT MAX(order_id)
|
||||
FROM session_queue
|
||||
"""
|
||||
)
|
||||
max_order_id = cast(Optional[int], self._cursor.fetchone()[0]) or 0
|
||||
max_order_id = cast(Optional[int], self.__cursor.fetchone()[0]) or 0
|
||||
|
||||
requested_count = calc_session_count(batch)
|
||||
values_to_insert = prepare_values_to_insert(
|
||||
@@ -245,30 +330,33 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
if requested_count > enqueued_count:
|
||||
values_to_insert = values_to_insert[:max_new_queue_items]
|
||||
|
||||
self._cursor.executemany(
|
||||
self.__cursor.executemany(
|
||||
"""--sql
|
||||
INSERT INTO session_queue (item_id, queue_id, session, session_id, batch_id, field_values, priority, order_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
values_to_insert,
|
||||
)
|
||||
self._conn.commit()
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return EnqueueBatchResult(
|
||||
self.__lock.release()
|
||||
enqueue_result = EnqueueBatchResult(
|
||||
queue_id=queue_id,
|
||||
requested=requested_count,
|
||||
enqueued=enqueued_count,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
)
|
||||
self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
|
||||
return enqueue_result
|
||||
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
@@ -279,23 +367,23 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
LIMIT 1
|
||||
"""
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
return None
|
||||
queue_item = SessionQueueItem.from_dict(dict(result))
|
||||
return self.set_queue_item_status(
|
||||
queue_id=queue_item.queue_id, item_id=queue_item.item_id, status="in_progress"
|
||||
)
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
return queue_item
|
||||
|
||||
def peek(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
@@ -309,82 +397,66 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.from_dict(dict(result))
|
||||
|
||||
def set_queue_item_status(self, queue_id: str, item_id: str, status: QUEUE_ITEM_STATUS) -> SessionQueueItem:
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = ?
|
||||
SELECT *
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND item_id = ?
|
||||
AND status = 'in_progress'
|
||||
LIMIT 1
|
||||
""",
|
||||
(status, queue_id, item_id),
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
return self.get_queue_item(queue_id=queue_id, item_id=item_id)
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
return None
|
||||
return SessionQueueItem.from_dict(dict(result))
|
||||
|
||||
def set_many_queue_item_status(
|
||||
self, queue_id: str, item_ids: list[str], status: QUEUE_ITEM_STATUS
|
||||
) -> SetManyQueueItemStatusResult:
|
||||
def _set_queue_item_status(
|
||||
self, item_id: str, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
|
||||
) -> SessionQueueItem:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
# update the queue items
|
||||
placeholders = ", ".join(["?" for _ in item_ids])
|
||||
|
||||
update_query = f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = ?
|
||||
WHERE
|
||||
queue_id in ?
|
||||
AND item_id IN ({placeholders})
|
||||
"""
|
||||
|
||||
self._cursor.execute(update_query, [queue_id, status] + item_ids)
|
||||
self._conn.commit()
|
||||
|
||||
# get queue items from list which were set to the status successfully
|
||||
fetch_query = f"""--sql
|
||||
SELECT item_id
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND status = ?
|
||||
AND item_id IN ({placeholders})
|
||||
"""
|
||||
|
||||
self._cursor.execute(fetch_query, [queue_id, status] + item_ids)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = ?, error = ?
|
||||
WHERE
|
||||
item_id = ?
|
||||
""",
|
||||
(status, error, item_id),
|
||||
)
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
updated_ids = [row[0] for row in result]
|
||||
return SetManyQueueItemStatusResult(item_ids=updated_ids, status=status)
|
||||
self.__lock.release()
|
||||
return self.get_queue_item(item_id)
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
@@ -392,18 +464,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
is_empty = cast(int, self._cursor.fetchone()[0]) == 0
|
||||
is_empty = cast(int, self.__cursor.fetchone()[0]) == 0
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
self.__lock.release()
|
||||
return IsEmptyResult(is_empty=is_empty)
|
||||
|
||||
def is_full(self, queue_id: str) -> IsFullResult:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT count(*)
|
||||
FROM session_queue
|
||||
@@ -411,40 +483,39 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
max_queue_size = self._invoker.services.configuration.max_queue_size
|
||||
is_full = cast(int, self._cursor.fetchone()[0]) >= max_queue_size
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
is_full = cast(int, self.__cursor.fetchone()[0]) >= max_queue_size
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
self.__lock.release()
|
||||
return IsFullResult(is_full=is_full)
|
||||
|
||||
def delete_queue_item(self, queue_id: str, item_id: str) -> SessionQueueItem:
|
||||
queue_item = self.get_queue_item(queue_id=queue_id, item_id=item_id)
|
||||
def delete_queue_item(self, item_id: str) -> SessionQueueItem:
|
||||
queue_item = self.get_queue_item(item_id=item_id)
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
DELETE FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND item_id = ?
|
||||
item_id = ?
|
||||
""",
|
||||
(queue_id, item_id),
|
||||
(item_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
self.__lock.release()
|
||||
return queue_item
|
||||
|
||||
def clear(self, queue_id: str) -> ClearResult:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -452,8 +523,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
@@ -461,12 +532,13 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
self.__lock.release()
|
||||
self.__invoker.services.events.emit_queue_cleared(queue_id)
|
||||
return ClearResult(deleted=count)
|
||||
|
||||
def prune(self, queue_id: str) -> PruneResult:
|
||||
@@ -480,8 +552,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
OR status = 'canceled'
|
||||
)
|
||||
"""
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -489,8 +561,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
DELETE
|
||||
FROM session_queue
|
||||
@@ -498,27 +570,38 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
self._conn.commit()
|
||||
self.__conn.commit()
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
self.__lock.release()
|
||||
return PruneResult(deleted=count)
|
||||
|
||||
def cancel_queue_item(self, item_id: str) -> SessionQueueItem:
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||
self.__invoker.services.queue.cancel(queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(queue_item.session_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
return queue_item
|
||||
|
||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
placeholders = ", ".join(["?" for _ in batch_ids])
|
||||
where = f"""--sql
|
||||
WHERE
|
||||
queue_id = ?
|
||||
queue_id == ?
|
||||
AND batch_id IN ({placeholders})
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = [queue_id] + batch_ids
|
||||
self._cursor.execute(
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
@@ -526,8 +609,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
count = self._cursor.fetchone()[0]
|
||||
self._cursor.execute(
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
@@ -535,40 +618,84 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
self._conn.commit()
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self.__invoker.services.queue.cancel(current_queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(current_queue_item.session_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
self.__lock.release()
|
||||
return CancelByBatchIDsResult(canceled=count)
|
||||
|
||||
def get_queue_item(self, queue_id: str, item_id: str) -> SessionQueueItem:
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
self.__lock.acquire()
|
||||
where = """--sql
|
||||
WHERE
|
||||
queue_id is ?
|
||||
AND status != 'canceled'
|
||||
AND status != 'completed'
|
||||
AND status != 'failed'
|
||||
"""
|
||||
params = [queue_id]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM session_queue
|
||||
{where};
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
count = self.__cursor.fetchone()[0]
|
||||
self.__cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE session_queue
|
||||
SET status = 'canceled'
|
||||
{where};
|
||||
""",
|
||||
tuple(params),
|
||||
)
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self.__invoker.services.queue.cancel(current_queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(current_queue_item.session_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
return CancelByQueueIDResult(canceled=count)
|
||||
|
||||
def get_queue_item(self, item_id: str) -> SessionQueueItem:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND item_id = ?
|
||||
item_id = ?
|
||||
""",
|
||||
(queue_id, item_id),
|
||||
(item_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
return SessionQueueItem.from_dict(dict(result))
|
||||
|
||||
def get_queue_item_by_session_id(self, session_id: str) -> SessionQueueItem:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
@@ -576,12 +703,12 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(session_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self._cursor.fetchone())
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with session id {session_id}")
|
||||
return SessionQueueItem.from_dict(dict(result))
|
||||
@@ -595,7 +722,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItemDTO]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self.__lock.acquire()
|
||||
query = """--sql
|
||||
SELECT item_id,
|
||||
order_id,
|
||||
@@ -633,8 +760,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
LIMIT ?
|
||||
"""
|
||||
params.append(limit + 1)
|
||||
self._cursor.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
self.__cursor.execute(query, params)
|
||||
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
items = [SessionQueueItemDTO.from_dict(dict(result)) for result in results]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
@@ -642,16 +769,16 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
items.pop()
|
||||
has_more = True
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
self.__lock.release()
|
||||
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
|
||||
|
||||
def get_status(self, queue_id: str) -> SessionQueueStatusResult:
|
||||
def get_queue_status(self, queue_id: str) -> SessionQueueStatusResult:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
self._cursor.execute(
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
@@ -660,14 +787,14 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
""",
|
||||
(queue_id,),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
total = sum(row[1] for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
except Exception:
|
||||
self._conn.rollback()
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self._lock.release()
|
||||
self.__lock.release()
|
||||
|
||||
return SessionQueueStatusResult(
|
||||
queue_id=queue_id,
|
||||
@@ -677,5 +804,38 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
failed=counts.get("failed", 0),
|
||||
canceled=counts.get("canceled", 0),
|
||||
total=total,
|
||||
max_queue_size=self._invoker.services.configuration.get_config().max_queue_size,
|
||||
)
|
||||
|
||||
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatusResult:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT status, count(*)
|
||||
FROM session_queue
|
||||
WHERE
|
||||
queue_id = ?
|
||||
AND batch_id = ?
|
||||
GROUP BY status
|
||||
""",
|
||||
(queue_id, batch_id),
|
||||
)
|
||||
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
|
||||
total = sum(row[1] for row in result)
|
||||
counts: dict[str, int] = {row[0]: row[1] for row in result}
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
|
||||
return BatchStatusResult(
|
||||
batch_id=batch_id,
|
||||
queue_id=queue_id,
|
||||
pending=counts.get("pending", 0),
|
||||
in_progress=counts.get("in_progress", 0),
|
||||
completed=counts.get("completed", 0),
|
||||
failed=counts.get("failed", 0),
|
||||
canceled=counts.get("canceled", 0),
|
||||
total=total,
|
||||
)
|
||||
|
||||
@@ -229,6 +229,7 @@
|
||||
"clearTooltip": "Cancel and Clear All Items",
|
||||
"clearSucceeded": "Queue Cleared",
|
||||
"clearFailed": "Problem Clearing Queue",
|
||||
"cancelBatch": "Cancel Batch",
|
||||
"cancelItem": "Cancel Item",
|
||||
"cancelByBatchIdsSucceeded": "Batches Canceled",
|
||||
"cancelByBatchIdsFailed": "Problem Canceling Items by Batch IDs",
|
||||
|
||||
@@ -86,7 +86,6 @@ import { addEnqueueRequestedLinear } from './listeners/enqueueRequestedLinear';
|
||||
import { addWorkflowLoadedListener } from './listeners/workflowLoaded';
|
||||
import { addDynamicPromptsListener } from './listeners/promptChanged';
|
||||
import { addSocketQueueItemStatusChangedEventListener } from './listeners/socketio/socketQueueItemStatusChanged';
|
||||
import { addSocketQueueStatusChangedEventListener } from './listeners/socketio/socketQueueStatusChanged';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@@ -175,7 +174,6 @@ addModelLoadEventListener();
|
||||
addSessionRetrievalErrorEventListener();
|
||||
addInvocationRetrievalErrorEventListener();
|
||||
addSocketQueueItemStatusChangedEventListener();
|
||||
addSocketQueueStatusChangedEventListener();
|
||||
|
||||
// Session Created
|
||||
addSessionCreatedPendingListener();
|
||||
|
||||
@@ -24,7 +24,7 @@ export const addBatchEnqueuedListener = () => {
|
||||
req.reset();
|
||||
|
||||
dispatch(
|
||||
queueApi.endpoints.startQueueExecution.initiate(undefined, {
|
||||
queueApi.endpoints.resumeProcessor.initiate(undefined, {
|
||||
fixedCacheKey: 'startQueue',
|
||||
})
|
||||
);
|
||||
|
||||
@@ -47,7 +47,7 @@ export const addControlNetImageProcessedListener = () => {
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
dispatch(
|
||||
queueApi.endpoints.startQueueExecution.initiate(undefined, {
|
||||
queueApi.endpoints.resumeProcessor.initiate(undefined, {
|
||||
fixedCacheKey: 'startQueue',
|
||||
})
|
||||
);
|
||||
|
||||
@@ -140,7 +140,7 @@ export const addEnqueueRequestedCanvasListener = () => {
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
dispatch(
|
||||
queueApi.endpoints.startQueueExecution.initiate(undefined, {
|
||||
queueApi.endpoints.resumeProcessor.initiate(undefined, {
|
||||
fixedCacheKey: 'startQueue',
|
||||
})
|
||||
);
|
||||
|
||||
@@ -51,7 +51,7 @@ export const addEnqueueRequestedLinear = () => {
|
||||
req.reset();
|
||||
|
||||
dispatch(
|
||||
queueApi.endpoints.startQueueExecution.initiate(undefined, {
|
||||
queueApi.endpoints.resumeProcessor.initiate(undefined, {
|
||||
fixedCacheKey: 'startQueue',
|
||||
})
|
||||
);
|
||||
|
||||
@@ -35,7 +35,7 @@ export const addEnqueueRequestedNodes = () => {
|
||||
req.reset();
|
||||
|
||||
dispatch(
|
||||
queueApi.endpoints.startQueueExecution.initiate(undefined, {
|
||||
queueApi.endpoints.resumeProcessor.initiate(undefined, {
|
||||
fixedCacheKey: 'startQueue',
|
||||
})
|
||||
);
|
||||
|
||||
@@ -12,7 +12,12 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
|
||||
actionCreator: socketQueueItemStatusChanged,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('socketio');
|
||||
const { item_id, status: newStatus } = action.payload.data;
|
||||
const {
|
||||
item_id,
|
||||
batch_id,
|
||||
graph_execution_state_id,
|
||||
status: newStatus,
|
||||
} = action.payload.data;
|
||||
log.debug(
|
||||
action.payload,
|
||||
`Queue item ${item_id} status updated: ${newStatus}`
|
||||
@@ -29,7 +34,6 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
|
||||
);
|
||||
|
||||
const state = getState();
|
||||
const { batch_id, graph_execution_state_id } = action.payload.data;
|
||||
if (state.canvas.batchIds.includes(batch_id)) {
|
||||
dispatch(canvasSessionIdAdded(graph_execution_state_id));
|
||||
}
|
||||
@@ -39,6 +43,10 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
|
||||
'CurrentSessionQueueItem',
|
||||
'NextSessionQueueItem',
|
||||
'SessionQueueStatus',
|
||||
'SessionProcessorStatus',
|
||||
{ type: 'SessionQueueItem', id: item_id },
|
||||
{ type: 'SessionQueueItemDTO', id: item_id },
|
||||
{ type: 'BatchStatus', id: batch_id },
|
||||
])
|
||||
);
|
||||
},
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import {
|
||||
appSocketQueueStatusChanged,
|
||||
socketQueueStatusChanged,
|
||||
} from 'services/events/actions';
|
||||
import { startAppListening } from '../..';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { t } from 'i18next';
|
||||
|
||||
export const addSocketQueueStatusChangedEventListener = () => {
|
||||
startAppListening({
|
||||
actionCreator: socketQueueStatusChanged,
|
||||
effect: (action, { dispatch, getOriginalState }) => {
|
||||
const log = logger('socketio');
|
||||
log.debug(action.payload, `Queue status updated`);
|
||||
// pass along the socket event as an application action
|
||||
dispatch(appSocketQueueStatusChanged(action.payload));
|
||||
|
||||
const { data: oldQueueStatus } =
|
||||
queueApi.endpoints.getQueueStatus.select()(getOriginalState());
|
||||
|
||||
if (
|
||||
oldQueueStatus?.started === false &&
|
||||
oldQueueStatus?.stop_after_current === true &&
|
||||
action.payload.data.started === false &&
|
||||
action.payload.data.stop_after_current === false
|
||||
) {
|
||||
dispatch(
|
||||
addToast({ title: t('queue.stopSucceeded'), status: 'success' })
|
||||
);
|
||||
}
|
||||
|
||||
dispatch(queueApi.util.invalidateTags(['SessionQueueStatus']));
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -38,7 +38,7 @@ export const addUpscaleRequestedListener = () => {
|
||||
const enqueueResult = await req.unwrap();
|
||||
req.reset();
|
||||
dispatch(
|
||||
queueApi.endpoints.startQueueExecution.initiate(undefined, {
|
||||
queueApi.endpoints.resumeProcessor.initiate(undefined, {
|
||||
fixedCacheKey: 'startQueue',
|
||||
})
|
||||
);
|
||||
|
||||
@@ -23,7 +23,7 @@ export const enqueueBatch = async (
|
||||
req.reset();
|
||||
|
||||
dispatch(
|
||||
queueApi.endpoints.startQueueExecution.initiate(undefined, {
|
||||
queueApi.endpoints.resumeProcessor.initiate(undefined, {
|
||||
fixedCacheKey: 'startQueue',
|
||||
})
|
||||
);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import {
|
||||
As,
|
||||
ChakraProps,
|
||||
Flex,
|
||||
FlexProps,
|
||||
Icon,
|
||||
Skeleton,
|
||||
Spinner,
|
||||
@@ -47,15 +47,14 @@ export const IAILoadingImageFallback = (props: Props) => {
|
||||
);
|
||||
};
|
||||
|
||||
type IAINoImageFallbackProps = {
|
||||
type IAINoImageFallbackProps = FlexProps & {
|
||||
label?: string;
|
||||
icon?: As | null;
|
||||
boxSize?: StyleProps['boxSize'];
|
||||
sx?: ChakraProps['sx'];
|
||||
};
|
||||
|
||||
export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
|
||||
const { icon = FaImage, boxSize = 16 } = props;
|
||||
const { icon = FaImage, boxSize = 16, sx, ...rest } = props;
|
||||
|
||||
return (
|
||||
<Flex
|
||||
@@ -73,8 +72,9 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => {
|
||||
_dark: {
|
||||
color: 'base.500',
|
||||
},
|
||||
...props.sx,
|
||||
...sx,
|
||||
}}
|
||||
{...rest}
|
||||
>
|
||||
{icon && <Icon as={icon} boxSize={boxSize} opacity={0.7} />}
|
||||
{props.label && <Text textAlign="center">{props.label}</Text>}
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAINumberInput from 'common/components/IAINumberInput';
|
||||
import IAISlider from 'common/components/IAISlider';
|
||||
import { setIterations } from 'features/parameters/store/generationSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector],
|
||||
(state) => {
|
||||
const { initial, min, sliderMax, inputMax, fineStep, coarseStep } =
|
||||
state.config.sd.iterations;
|
||||
const { iterations } = state.generation;
|
||||
const { shouldUseSliders } = state.ui;
|
||||
|
||||
const step = state.hotkeys.shift ? fineStep : coarseStep;
|
||||
|
||||
return {
|
||||
iterations,
|
||||
initial,
|
||||
min,
|
||||
sliderMax,
|
||||
inputMax,
|
||||
step,
|
||||
shouldUseSliders,
|
||||
};
|
||||
},
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const ParamIterations = () => {
|
||||
const {
|
||||
iterations,
|
||||
initial,
|
||||
min,
|
||||
sliderMax,
|
||||
inputMax,
|
||||
step,
|
||||
shouldUseSliders,
|
||||
} = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const handleChange = useCallback(
|
||||
(v: number) => {
|
||||
dispatch(setIterations(v));
|
||||
},
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
const handleReset = useCallback(() => {
|
||||
dispatch(setIterations(initial));
|
||||
}, [dispatch, initial]);
|
||||
|
||||
return (
|
||||
<IAINumberInput
|
||||
// label={t('parameters.runs')}
|
||||
step={step}
|
||||
min={min}
|
||||
max={inputMax}
|
||||
onChange={handleChange}
|
||||
value={iterations}
|
||||
numberInputFieldProps={{ textAlign: 'center' }}
|
||||
formControlProps={{ w: 36 }}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ParamIterations);
|
||||
@@ -1,31 +1,34 @@
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaTimes } from 'react-icons/fa';
|
||||
import {
|
||||
useCancelQueueExecutionMutation,
|
||||
useGetQueueStatusQuery,
|
||||
useCancelQueueItemMutation,
|
||||
useGetCurrentQueueItemQuery,
|
||||
} from 'services/api/endpoints/queue';
|
||||
import QueueButton from './common/QueueButton';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { useIsQueueMutationInProgress } from '../hooks/useIsQueueMutationInProgress';
|
||||
import QueueButton from './common/QueueButton';
|
||||
|
||||
type Props = {
|
||||
asIconButton?: boolean;
|
||||
};
|
||||
|
||||
const CancelQueueButton = ({ asIconButton }: Props) => {
|
||||
const CancelCurrentQueueItemButton = ({ asIconButton }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const { data: queueStatusData } = useGetQueueStatusQuery();
|
||||
const [cancelQueue] = useCancelQueueExecutionMutation({
|
||||
fixedCacheKey: 'cancelQueue',
|
||||
const { data: currentQueueItem } = useGetCurrentQueueItemQuery();
|
||||
const [cancelQueueItem] = useCancelQueueItemMutation({
|
||||
fixedCacheKey: 'cancelQueueItem',
|
||||
});
|
||||
const isQueueMutationInProgress = useIsQueueMutationInProgress();
|
||||
|
||||
const handleClick = useCallback(async () => {
|
||||
if (!currentQueueItem) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
await cancelQueue().unwrap();
|
||||
await cancelQueueItem(currentQueueItem.item_id).unwrap();
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.cancelSucceeded'),
|
||||
@@ -40,22 +43,19 @@ const CancelQueueButton = ({ asIconButton }: Props) => {
|
||||
})
|
||||
);
|
||||
}
|
||||
}, [cancelQueue, dispatch, t]);
|
||||
}, [cancelQueueItem, currentQueueItem, dispatch, t]);
|
||||
|
||||
return (
|
||||
<QueueButton
|
||||
asIconButton={asIconButton}
|
||||
label={t('queue.cancel')}
|
||||
tooltip={t('queue.cancelTooltip')}
|
||||
isDisabled={
|
||||
!(queueStatusData?.started || queueStatusData?.stop_after_current) ||
|
||||
isQueueMutationInProgress
|
||||
}
|
||||
isDisabled={!currentQueueItem || isQueueMutationInProgress}
|
||||
icon={<FaTimes />}
|
||||
onClick={handleClick}
|
||||
colorScheme="orange"
|
||||
colorScheme="error"
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(CancelQueueButton);
|
||||
export default memo(CancelCurrentQueueItemButton);
|
||||
@@ -1,11 +1,11 @@
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { usePeekNextQueueItemQuery } from 'services/api/endpoints/queue';
|
||||
import { useGetNextQueueItemQuery } from 'services/api/endpoints/queue';
|
||||
import QueueItemCard from './common/QueueItemCard';
|
||||
|
||||
const NextQueueItemCard = () => {
|
||||
const { t } = useTranslation();
|
||||
const { data: nextQueueItemData } = usePeekNextQueueItemQuery();
|
||||
const { data: nextQueueItemData } = useGetNextQueueItemQuery();
|
||||
|
||||
return (
|
||||
<QueueItemCard
|
||||
|
||||
@@ -34,11 +34,10 @@ const QueueBackButton = () => {
|
||||
);
|
||||
return (
|
||||
<IAIButton
|
||||
isDisabled={!isReady}
|
||||
isDisabled={!isReady || isQueueMutationInProgress}
|
||||
colorScheme="accent"
|
||||
onClick={handleEnqueue}
|
||||
tooltip={<EnqueueButtonTooltip />}
|
||||
isLoading={isQueueMutationInProgress}
|
||||
flexGrow={3}
|
||||
minW={44}
|
||||
>
|
||||
|
||||
@@ -8,7 +8,6 @@ import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useEnqueueBatchMutation } from 'services/api/endpoints/queue';
|
||||
import { useBoardName } from 'services/api/hooks/useBoardName';
|
||||
import { usePredictedQueueCounts } from '../hooks/usePredictedQueueCounts';
|
||||
|
||||
const tooltipSelector = createSelector(
|
||||
[stateSelector],
|
||||
@@ -33,7 +32,6 @@ const QueueButtonTooltipContent = ({ prepend = false }: Props) => {
|
||||
const [_, { isLoading }] = useEnqueueBatchMutation({
|
||||
fixedCacheKey: 'enqueueBatch',
|
||||
});
|
||||
const counts = usePredictedQueueCounts();
|
||||
|
||||
const label = useMemo(() => {
|
||||
if (isLoading) {
|
||||
@@ -41,12 +39,12 @@ const QueueButtonTooltipContent = ({ prepend = false }: Props) => {
|
||||
}
|
||||
if (isReady) {
|
||||
if (prepend) {
|
||||
return t('queue.queueFront', { predicted: counts?.predicted ?? '?' });
|
||||
return t('queue.queueFront');
|
||||
}
|
||||
return t('queue.queueBack', { predicted: counts?.predicted ?? '?' });
|
||||
return t('queue.queueBack');
|
||||
}
|
||||
return t('queue.notReady');
|
||||
}, [counts?.predicted, isLoading, isReady, prepend, t]);
|
||||
}, [isLoading, isReady, prepend, t]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={1}>
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import { ButtonGroup, Flex, Text } from '@chakra-ui/react';
|
||||
import ParamRuns from 'features/parameters/components/Parameters/Core/ParamRuns';
|
||||
import CancelQueueButton from 'features/queue/components/CancelQueueButton';
|
||||
import { ButtonGroup, Flex, Spacer, Text } from '@chakra-ui/react';
|
||||
import CancelCurrentQueueItemButton from 'features/queue/components/CancelCurrentQueueItemButton';
|
||||
import ClearQueueButton from 'features/queue/components/ClearQueueButton';
|
||||
import QueueBackButton from 'features/queue/components/QueueBackButton';
|
||||
import QueueFrontButton from 'features/queue/components/QueueFrontButton';
|
||||
import StartQueueButton from 'features/queue/components/StartQueueButton';
|
||||
import StopQueueButton from 'features/queue/components/StopQueueButton';
|
||||
import { usePredictedQueueCounts } from 'features/queue/hooks/usePredictedQueueCounts';
|
||||
import ResumeProcessorButton from 'features/queue/components/StartQueueButton';
|
||||
import PauseProcessorButton from 'features/queue/components/StopQueueButton';
|
||||
import ProgressBar from 'features/system/components/ProgressBar';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -29,11 +27,11 @@ const QueueControls = () => {
|
||||
<ButtonGroup isAttached flexGrow={2}>
|
||||
<QueueBackButton />
|
||||
<QueueFrontButton />
|
||||
<CancelCurrentQueueItemButton asIconButton />
|
||||
</ButtonGroup>
|
||||
<ButtonGroup isAttached>
|
||||
<StartQueueButton asIconButton />
|
||||
<StopQueueButton asIconButton />
|
||||
<CancelQueueButton asIconButton />
|
||||
<ResumeProcessorButton asIconButton />
|
||||
<PauseProcessorButton asIconButton />
|
||||
<ClearQueueButton asIconButton />
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
@@ -48,49 +46,35 @@ const QueueControls = () => {
|
||||
export default memo(QueueControls);
|
||||
|
||||
const QueueCounts = memo(() => {
|
||||
const counts = usePredictedQueueCounts();
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const { hasItems, pending } = useGetQueueStatusQuery(undefined, {
|
||||
selectFromResult: ({ data }) => {
|
||||
if (!data) {
|
||||
return {
|
||||
hasItems: false,
|
||||
pending: 0,
|
||||
};
|
||||
}
|
||||
|
||||
const { pending, in_progress } = data;
|
||||
|
||||
return {
|
||||
hasItems: pending + in_progress > 0,
|
||||
pending,
|
||||
};
|
||||
},
|
||||
});
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!counts || !queueStatus) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { requested, predicted, max_queue_size } = counts;
|
||||
const { pending, in_progress } = queueStatus;
|
||||
return (
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<ParamRuns />
|
||||
{/* <Tooltip
|
||||
label={
|
||||
requested > predicted &&
|
||||
t('queue.queueMaxExceeded', {
|
||||
requested,
|
||||
skip: requested - predicted,
|
||||
max_queue_size,
|
||||
})
|
||||
}
|
||||
>
|
||||
<Text
|
||||
variant="subtext"
|
||||
fontSize="sm"
|
||||
fontWeight={400}
|
||||
fontStyle="oblique 10deg"
|
||||
opacity={0.7}
|
||||
color={requested > predicted ? 'warning.500' : undefined}
|
||||
>
|
||||
{t('queue.queueCountPrediction', { predicted })}
|
||||
</Text>
|
||||
</Tooltip> */}
|
||||
<Spacer />
|
||||
<Text
|
||||
variant="subtext"
|
||||
fontSize="sm"
|
||||
fontWeight={400}
|
||||
fontStyle="oblique 10deg"
|
||||
opacity={0.7}
|
||||
pe={1}
|
||||
>
|
||||
{pending + in_progress > 0
|
||||
{hasItems
|
||||
? t('queue.queuedCount', {
|
||||
pending,
|
||||
})
|
||||
|
||||
@@ -36,10 +36,9 @@ const QueueFrontButton = () => {
|
||||
<IAIIconButton
|
||||
colorScheme="base"
|
||||
aria-label={t('queue.queueFront')}
|
||||
isDisabled={!isReady}
|
||||
isDisabled={!isReady || isQueueMutationInProgress}
|
||||
onClick={handleEnqueue}
|
||||
tooltip={<EnqueueButtonTooltip prepend />}
|
||||
isLoading={isQueueMutationInProgress}
|
||||
icon={<FaBoltLightning />}
|
||||
/>
|
||||
);
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
import {
|
||||
ButtonGroup,
|
||||
ChakraProps,
|
||||
Collapse,
|
||||
Flex,
|
||||
Text,
|
||||
} from '@chakra-ui/react';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { MouseEvent, memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaTimes } from 'react-icons/fa';
|
||||
import { useCancelQueueItemMutation } from 'services/api/endpoints/queue';
|
||||
import { SessionQueueItemDTO } from 'services/api/types';
|
||||
import QueueStatusBadge from '../common/QueueStatusBadge';
|
||||
import QueueItemDetail from './QueueItemDetail';
|
||||
import { COLUMN_WIDTHS } from './constants';
|
||||
import { ListContext } from './types';
|
||||
|
||||
const selectedStyles = { bg: 'base.300', _dark: { bg: 'base.750' } };
|
||||
|
||||
type InnerItemProps = {
|
||||
index: number;
|
||||
item: SessionQueueItemDTO;
|
||||
context: ListContext;
|
||||
};
|
||||
|
||||
const sx: ChakraProps['sx'] = {
|
||||
_hover: selectedStyles,
|
||||
"&[aria-selected='true']": selectedStyles,
|
||||
};
|
||||
|
||||
const QueueItemComponent = ({ index, item, context }: InnerItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
const handleToggle = useCallback(() => {
|
||||
context.toggleQueueItem(item.item_id);
|
||||
}, [context, item.item_id]);
|
||||
const [cancelQueueItem, { isLoading: isLoadingCancelQueueItem }] =
|
||||
useCancelQueueItemMutation();
|
||||
const handleCancelQueueItem = useCallback(
|
||||
(e: MouseEvent<HTMLButtonElement>) => {
|
||||
e.stopPropagation();
|
||||
cancelQueueItem(item.item_id);
|
||||
},
|
||||
[cancelQueueItem, item.item_id]
|
||||
);
|
||||
const isOpen = useMemo(
|
||||
() => context.openQueueItems.includes(item.item_id),
|
||||
[context.openQueueItems, item.item_id]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
flexDir="column"
|
||||
borderRadius="base"
|
||||
aria-selected={isOpen}
|
||||
fontSize="sm"
|
||||
justifyContent="center"
|
||||
sx={sx}
|
||||
>
|
||||
<Flex
|
||||
alignItems="center"
|
||||
gap={4}
|
||||
p={1.5}
|
||||
cursor="pointer"
|
||||
onClick={handleToggle}
|
||||
>
|
||||
<Flex
|
||||
w={COLUMN_WIDTHS.number}
|
||||
justifyContent="flex-end"
|
||||
alignItems="center"
|
||||
>
|
||||
<Text variant="subtext">{index + 1}</Text>
|
||||
</Flex>
|
||||
<Flex w={COLUMN_WIDTHS.statusBadge} alignItems="center">
|
||||
<QueueStatusBadge status={item.status} />
|
||||
</Flex>
|
||||
<Flex w={COLUMN_WIDTHS.batchId}>
|
||||
<Text
|
||||
overflow="hidden"
|
||||
textOverflow="ellipsis"
|
||||
whiteSpace="nowrap"
|
||||
alignItems="center"
|
||||
>
|
||||
{item.batch_id}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Flex alignItems="center" flexGrow={1}>
|
||||
{item.field_values && (
|
||||
<Flex gap={2}>
|
||||
{item.field_values
|
||||
.filter((v) => v.node_path !== 'metadata_accumulator')
|
||||
.map(({ node_path, field_name, value }) => (
|
||||
<Text
|
||||
key={`${item.item_id}.${node_path}.${field_name}.${value}`}
|
||||
whiteSpace="nowrap"
|
||||
textOverflow="ellipsis"
|
||||
overflow="hidden"
|
||||
>
|
||||
<Text as="span" fontWeight={600}>
|
||||
{node_path}.{field_name}
|
||||
</Text>
|
||||
: {value}
|
||||
</Text>
|
||||
))}
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
<Flex alignItems="center" w={COLUMN_WIDTHS.actions} pe={3}>
|
||||
<ButtonGroup size="xs" variant="ghost">
|
||||
<IAIIconButton
|
||||
tooltip={t('queue.cancelItem')}
|
||||
onClick={handleCancelQueueItem}
|
||||
isLoading={isLoadingCancelQueueItem}
|
||||
isDisabled={['canceled', 'completed', 'failed'].includes(
|
||||
item.status
|
||||
)}
|
||||
aria-label={t('queue.cancelItem')}
|
||||
icon={<FaTimes />}
|
||||
/>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
<Collapse
|
||||
in={isOpen}
|
||||
transition={{ enter: { duration: 0.1 }, exit: { duration: 0.1 } }}
|
||||
unmountOnExit={true}
|
||||
>
|
||||
<QueueItemDetail queueItemDTO={item} />
|
||||
</Collapse>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(QueueItemComponent);
|
||||
@@ -0,0 +1,166 @@
|
||||
import { ButtonGroup, Flex, Heading, Spinner, Text } from '@chakra-ui/react';
|
||||
import IAIButton from 'common/components/IAIButton';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import ScrollableContent from 'features/nodes/components/sidePanel/ScrollableContent';
|
||||
import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress';
|
||||
import { MouseEvent, memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaTimes } from 'react-icons/fa';
|
||||
import {
|
||||
useCancelByBatchIdsMutation,
|
||||
useCancelQueueItemMutation,
|
||||
useGetBatchStatusQuery,
|
||||
useGetQueueItemQuery,
|
||||
} from 'services/api/endpoints/queue';
|
||||
import { SessionQueueItemDTO } from 'services/api/types';
|
||||
|
||||
type Props = {
|
||||
queueItemDTO: SessionQueueItemDTO;
|
||||
};
|
||||
|
||||
const QueueItemComponent = ({ queueItemDTO }: Props) => {
|
||||
const {
|
||||
batch_id,
|
||||
completed_at,
|
||||
error,
|
||||
item_id,
|
||||
session_id,
|
||||
started_at,
|
||||
status,
|
||||
} = queueItemDTO;
|
||||
const { t } = useTranslation();
|
||||
const isQueueMutationInProgress = useIsQueueMutationInProgress();
|
||||
const [cancelQueueItem, { isLoading: isLoadingCancelQueueItem }] =
|
||||
useCancelQueueItemMutation();
|
||||
const [cancelByBatchIds, { isLoading: isLoadingCancelByBatchIds }] =
|
||||
useCancelByBatchIdsMutation();
|
||||
const { isCanceled } = useGetBatchStatusQuery(
|
||||
{ batch_id: queueItemDTO.batch_id },
|
||||
{
|
||||
selectFromResult: ({ data }) => {
|
||||
if (!data) {
|
||||
return { isCanceled: true };
|
||||
}
|
||||
|
||||
return {
|
||||
isCanceled: data?.in_progress === 0 && data?.pending === 0,
|
||||
};
|
||||
},
|
||||
}
|
||||
);
|
||||
const handleCancelQueueItem = useCallback(
|
||||
(e: MouseEvent<HTMLButtonElement>) => {
|
||||
e.stopPropagation();
|
||||
cancelQueueItem(item_id);
|
||||
},
|
||||
[cancelQueueItem, item_id]
|
||||
);
|
||||
const handleCancelBatch = useCallback(
|
||||
(e: MouseEvent<HTMLButtonElement>) => {
|
||||
e.stopPropagation();
|
||||
cancelByBatchIds({ batch_ids: [batch_id] });
|
||||
},
|
||||
[cancelByBatchIds, batch_id]
|
||||
);
|
||||
const { data: queueItem } = useGetQueueItemQuery(item_id);
|
||||
const executionTime = useMemo(() => {
|
||||
if (!completed_at || !started_at) {
|
||||
return 'n/a';
|
||||
}
|
||||
return String(
|
||||
((Date.parse(completed_at) - Date.parse(started_at)) / 1000).toFixed(2)
|
||||
);
|
||||
}, [completed_at, started_at]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
layerStyle="third"
|
||||
flexDir="column"
|
||||
p={2}
|
||||
pt={0}
|
||||
borderRadius="base"
|
||||
gap={2}
|
||||
>
|
||||
<Flex
|
||||
layerStyle="second"
|
||||
p={2}
|
||||
gap={2}
|
||||
justifyContent="space-between"
|
||||
alignItems="center"
|
||||
borderRadius="base"
|
||||
>
|
||||
<QueueItemData label="Item ID" data={item_id} />
|
||||
<QueueItemData label="Batch ID" data={batch_id} />
|
||||
<QueueItemData label="Session ID" data={session_id} />
|
||||
<QueueItemData label="Execution Time" data={executionTime} />
|
||||
<ButtonGroup size="xs" orientation="vertical">
|
||||
<IAIButton
|
||||
onClick={handleCancelQueueItem}
|
||||
isLoading={isLoadingCancelQueueItem}
|
||||
isDisabled={['canceled', 'completed', 'failed'].includes(status)}
|
||||
aria-label={t('queue.cancelItem')}
|
||||
icon={<FaTimes />}
|
||||
colorScheme="error"
|
||||
>
|
||||
{t('queue.cancelItem')}
|
||||
</IAIButton>
|
||||
<IAIButton
|
||||
onClick={handleCancelBatch}
|
||||
isLoading={isLoadingCancelByBatchIds || isQueueMutationInProgress}
|
||||
isDisabled={isCanceled}
|
||||
aria-label={t('queue.cancelBatch')}
|
||||
icon={<FaTimes />}
|
||||
colorScheme="error"
|
||||
>
|
||||
{t('queue.cancelBatch')}
|
||||
</IAIButton>
|
||||
</ButtonGroup>
|
||||
</Flex>
|
||||
{error && (
|
||||
<Flex
|
||||
layerStyle="second"
|
||||
p={3}
|
||||
gap={1}
|
||||
justifyContent="space-between"
|
||||
alignItems="flex-start"
|
||||
borderRadius="base"
|
||||
flexDir="column"
|
||||
>
|
||||
<Heading size="sm" color="error.500" _dark={{ color: 'error.400' }}>
|
||||
Error
|
||||
</Heading>
|
||||
<pre>{error}</pre>
|
||||
</Flex>
|
||||
)}
|
||||
<Flex
|
||||
layerStyle="second"
|
||||
h={512}
|
||||
w="full"
|
||||
borderRadius="base"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
>
|
||||
{queueItem ? (
|
||||
<ScrollableContent>
|
||||
<DataViewer label="Queue Item" data={queueItem} />
|
||||
</ScrollableContent>
|
||||
) : (
|
||||
<Spinner opacity={0.5} />
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(QueueItemComponent);
|
||||
|
||||
type QueueItemDataProps = { label: string; data: string };
|
||||
|
||||
const QueueItemData = ({ label, data }: QueueItemDataProps) => {
|
||||
return (
|
||||
<Flex flexDir="column" p={1} gap={1}>
|
||||
<Heading size="sm">{label}</Heading>
|
||||
<Text>{data}</Text>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
@@ -1,20 +1,8 @@
|
||||
import {
|
||||
Box,
|
||||
ChakraProps,
|
||||
Collapse,
|
||||
Flex,
|
||||
Text,
|
||||
forwardRef,
|
||||
} from '@chakra-ui/react';
|
||||
import { Flex, Heading } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/dist/query';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
|
||||
import IAIIconButton from 'common/components/IAIIconButton';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import ScrollableContent from 'features/nodes/components/sidePanel/ScrollableContent';
|
||||
import {
|
||||
listCursorChanged,
|
||||
listPriorityChanged,
|
||||
@@ -23,26 +11,18 @@ import {
|
||||
UseOverlayScrollbarsParams,
|
||||
useOverlayScrollbars,
|
||||
} from 'overlayscrollbars-react';
|
||||
import {
|
||||
MouseEvent,
|
||||
memo,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaTimes } from 'react-icons/fa';
|
||||
import { Components, ItemContent, Virtuoso } from 'react-virtuoso';
|
||||
import {
|
||||
queueItemsAdapter,
|
||||
useCancelQueueItemMutation,
|
||||
useGetQueueItemQuery,
|
||||
useListQueueItemsQuery,
|
||||
} from 'services/api/endpoints/queue';
|
||||
import { SessionQueueItemDTO } from 'services/api/types';
|
||||
import QueueStatusBadge from '../common/QueueStatusBadge';
|
||||
import QueueItemComponent from './QueueItemComponent';
|
||||
import QueueListComponent from './QueueListComponent';
|
||||
import QueueListHeader from './QueueListHeader';
|
||||
import { ListContext } from './types';
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
type TableVirtuosoScrollerRef = (ref: HTMLElement | Window | null) => any;
|
||||
@@ -69,39 +49,19 @@ const selector = createSelector(
|
||||
defaultSelectorOptions
|
||||
);
|
||||
|
||||
const COLUMN_WIDTHS = {
|
||||
number: '3rem',
|
||||
statusBadge: '5.7rem',
|
||||
batchId: '5rem',
|
||||
fieldValues: 'auto',
|
||||
actions: 'auto',
|
||||
};
|
||||
|
||||
const computeItemKey = (index: number, item: SessionQueueItemDTO): string =>
|
||||
item.item_id;
|
||||
|
||||
type ListContext = {
|
||||
openQueueItems: string[];
|
||||
toggleQueueItem: (item_id: string) => void;
|
||||
};
|
||||
|
||||
const ListComponent: Components<SessionQueueItemDTO, ListContext>['List'] =
|
||||
memo(
|
||||
forwardRef((props, ref) => {
|
||||
return (
|
||||
<Flex {...props} ref={ref} flexDirection="column">
|
||||
{props.children}
|
||||
</Flex>
|
||||
);
|
||||
})
|
||||
);
|
||||
|
||||
ListComponent.displayName = 'ListComponent';
|
||||
|
||||
const components: Components<SessionQueueItemDTO, ListContext> = {
|
||||
List: ListComponent,
|
||||
List: QueueListComponent,
|
||||
};
|
||||
|
||||
const itemContent: ItemContent<SessionQueueItemDTO, ListContext> = (
|
||||
index,
|
||||
item,
|
||||
context
|
||||
) => <QueueItemComponent index={index} item={item} context={context} />;
|
||||
|
||||
const QueueList = () => {
|
||||
const { listCursor, listPriority } = useAppSelector(selector);
|
||||
const dispatch = useAppDispatch();
|
||||
@@ -110,6 +70,7 @@ const QueueList = () => {
|
||||
const [initialize, osInstance] = useOverlayScrollbars(
|
||||
overlayScrollbarsConfig
|
||||
);
|
||||
const { t } = useTranslation();
|
||||
|
||||
useEffect(() => {
|
||||
const { current: root } = rootRef;
|
||||
@@ -165,36 +126,16 @@ const QueueList = () => {
|
||||
);
|
||||
|
||||
return (
|
||||
<Box w="full" h="full">
|
||||
<Flex w="full" h="full" flexDir="column">
|
||||
<QueueListHeader />
|
||||
<Flex
|
||||
ref={rootRef}
|
||||
w="full"
|
||||
h="full"
|
||||
alignItems="center"
|
||||
gap={4}
|
||||
p={1}
|
||||
pb={2}
|
||||
textTransform="uppercase"
|
||||
fontWeight={700}
|
||||
fontSize="xs"
|
||||
letterSpacing={1}
|
||||
justifyContent="center"
|
||||
>
|
||||
<Flex
|
||||
w={COLUMN_WIDTHS.number}
|
||||
justifyContent="flex-end"
|
||||
alignItems="center"
|
||||
>
|
||||
<Text variant="subtext">#</Text>
|
||||
</Flex>
|
||||
<Flex w={COLUMN_WIDTHS.statusBadge} alignItems="center">
|
||||
<Text variant="subtext">status</Text>
|
||||
</Flex>
|
||||
<Flex w={COLUMN_WIDTHS.batchId} alignItems="center">
|
||||
<Text variant="subtext">batch</Text>
|
||||
</Flex>
|
||||
<Flex alignItems="center" w={COLUMN_WIDTHS.fieldValues}>
|
||||
<Text variant="subtext">batch field values</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Box ref={rootRef} w="full" h="full">
|
||||
{listQueueItemsData && (
|
||||
{queueItems.length ? (
|
||||
<Virtuoso<SessionQueueItemDTO, ListContext>
|
||||
data={queueItems}
|
||||
endReached={handleLoadMore}
|
||||
@@ -203,144 +144,15 @@ const QueueList = () => {
|
||||
computeItemKey={computeItemKey}
|
||||
components={components}
|
||||
context={context}
|
||||
style={{ display: 'flex', flexDirection: 'column', gap: '2px' }}
|
||||
/>
|
||||
) : (
|
||||
<Heading color="base.400" _dark={{ color: 'base.500' }}>
|
||||
{t('queue.queueEmpty')}
|
||||
</Heading>
|
||||
)}
|
||||
</Box>
|
||||
</Box>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(QueueList);
|
||||
|
||||
const selectedStyles = { bg: 'base.300', _dark: { bg: 'base.750' } };
|
||||
|
||||
const itemContent: ItemContent<SessionQueueItemDTO, ListContext> = (
|
||||
index,
|
||||
item,
|
||||
context
|
||||
) => <InnerItem index={index} item={item} context={context} />;
|
||||
|
||||
type InnerItemProps = {
|
||||
index: number;
|
||||
item: SessionQueueItemDTO;
|
||||
context: ListContext;
|
||||
};
|
||||
|
||||
const sx: ChakraProps['sx'] = {
|
||||
_hover: selectedStyles,
|
||||
"&[aria-selected='true']": selectedStyles,
|
||||
};
|
||||
|
||||
const InnerItem = memo(({ index, item, context }: InnerItemProps) => {
|
||||
const { t } = useTranslation();
|
||||
const handleToggle = useCallback(() => {
|
||||
context.toggleQueueItem(item.item_id);
|
||||
}, [context, item.item_id]);
|
||||
const [cancelQueueItem, { isLoading }] = useCancelQueueItemMutation();
|
||||
const handleCancel = useCallback(
|
||||
(e: MouseEvent<HTMLButtonElement>) => {
|
||||
e.stopPropagation();
|
||||
cancelQueueItem(item.item_id);
|
||||
},
|
||||
[cancelQueueItem, item.item_id]
|
||||
);
|
||||
const isOpen = useMemo(
|
||||
() => context.openQueueItems.includes(item.item_id),
|
||||
[context.openQueueItems, item.item_id]
|
||||
);
|
||||
const { data: queueItem } = useGetQueueItemQuery(
|
||||
isOpen ? item.item_id : skipToken
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
flexDir="column"
|
||||
borderRadius="base"
|
||||
aria-selected={isOpen}
|
||||
fontSize="sm"
|
||||
justifyContent="center"
|
||||
sx={sx}
|
||||
>
|
||||
<Flex
|
||||
alignItems="center"
|
||||
gap={4}
|
||||
p={1}
|
||||
cursor="pointer"
|
||||
onClick={handleToggle}
|
||||
>
|
||||
<Flex
|
||||
w={COLUMN_WIDTHS.number}
|
||||
justifyContent="flex-end"
|
||||
alignItems="center"
|
||||
>
|
||||
<Text variant="subtext">{index + 1}</Text>
|
||||
</Flex>
|
||||
<Flex w={COLUMN_WIDTHS.statusBadge} alignItems="center">
|
||||
<QueueStatusBadge status={item.status} />
|
||||
</Flex>
|
||||
<Flex w={COLUMN_WIDTHS.batchId}>
|
||||
<Text
|
||||
overflow="hidden"
|
||||
textOverflow="ellipsis"
|
||||
whiteSpace="nowrap"
|
||||
alignItems="center"
|
||||
>
|
||||
{item.batch_id}
|
||||
</Text>
|
||||
</Flex>
|
||||
<Flex alignItems="center" flexGrow={1}>
|
||||
{item.field_values && (
|
||||
<Flex gap={2}>
|
||||
{item.field_values
|
||||
.filter((v) => v.node_path !== 'metadata_accumulator')
|
||||
.map(({ node_path, field_name, value }) => (
|
||||
<Text
|
||||
key={`${item.item_id}.${node_path}.${field_name}.${value}`}
|
||||
whiteSpace="nowrap"
|
||||
textOverflow="ellipsis"
|
||||
overflow="hidden"
|
||||
>
|
||||
<Text as="span" fontWeight={600}>
|
||||
{node_path}.{field_name}
|
||||
</Text>
|
||||
: {value}
|
||||
</Text>
|
||||
))}
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
<Flex alignItems="center" w={COLUMN_WIDTHS.actions}>
|
||||
<IAIIconButton
|
||||
tooltip={t('queue.cancelItem')}
|
||||
onClick={handleCancel}
|
||||
isLoading={isLoading}
|
||||
isDisabled={['canceled', 'completed', 'failed'].includes(
|
||||
item.status
|
||||
)}
|
||||
aria-label={t('queue.cancelItem')}
|
||||
size="xs"
|
||||
variant="ghost"
|
||||
icon={<FaTimes />}
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
|
||||
<Collapse in={isOpen}>
|
||||
<Flex layerStyle="third" p={2} pt={0} borderRadius="base">
|
||||
<Flex h={512} w="full" pos="relative">
|
||||
{queueItem ? (
|
||||
<ScrollableContent>
|
||||
<DataViewer label="Queue Item" data={queueItem} />
|
||||
</ScrollableContent>
|
||||
) : (
|
||||
<IAINoContentFallback label="Loading" icon={null} />
|
||||
)}
|
||||
</Flex>
|
||||
</Flex>
|
||||
</Collapse>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
InnerItem.displayName = 'InnerItem';
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
import { Flex, forwardRef } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { Components } from 'react-virtuoso';
|
||||
import { SessionQueueItemDTO } from 'services/api/types';
|
||||
import { ListContext } from './types';
|
||||
|
||||
const QueueListComponent: Components<SessionQueueItemDTO, ListContext>['List'] =
|
||||
memo(
|
||||
forwardRef((props, ref) => {
|
||||
return (
|
||||
<Flex {...props} ref={ref} flexDirection="column" gap={0.5}>
|
||||
{props.children}
|
||||
</Flex>
|
||||
);
|
||||
})
|
||||
);
|
||||
|
||||
export default memo(QueueListComponent);
|
||||
@@ -0,0 +1,37 @@
|
||||
import { Flex, Text } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import { COLUMN_WIDTHS } from './constants';
|
||||
|
||||
const QueueListHeader = () => {
|
||||
return (
|
||||
<Flex
|
||||
alignItems="center"
|
||||
gap={4}
|
||||
p={1}
|
||||
pb={2}
|
||||
textTransform="uppercase"
|
||||
fontWeight={700}
|
||||
fontSize="xs"
|
||||
letterSpacing={1}
|
||||
>
|
||||
<Flex
|
||||
w={COLUMN_WIDTHS.number}
|
||||
justifyContent="flex-end"
|
||||
alignItems="center"
|
||||
>
|
||||
<Text variant="subtext">#</Text>
|
||||
</Flex>
|
||||
<Flex w={COLUMN_WIDTHS.statusBadge} alignItems="center">
|
||||
<Text variant="subtext">status</Text>
|
||||
</Flex>
|
||||
<Flex w={COLUMN_WIDTHS.batchId} alignItems="center">
|
||||
<Text variant="subtext">batch</Text>
|
||||
</Flex>
|
||||
<Flex alignItems="center" w={COLUMN_WIDTHS.fieldValues}>
|
||||
<Text variant="subtext">batch field values</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(QueueListHeader);
|
||||
@@ -0,0 +1,7 @@
|
||||
export const COLUMN_WIDTHS = {
|
||||
number: '3rem',
|
||||
statusBadge: '5.7rem',
|
||||
batchId: '5rem',
|
||||
fieldValues: 'auto',
|
||||
actions: 'auto',
|
||||
};
|
||||
@@ -0,0 +1,4 @@
|
||||
export type ListContext = {
|
||||
openQueueItems: string[];
|
||||
toggleQueueItem: (item_id: string) => void;
|
||||
};
|
||||
@@ -4,8 +4,8 @@ import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaPlay } from 'react-icons/fa';
|
||||
import {
|
||||
useGetQueueStatusQuery,
|
||||
useStartQueueExecutionMutation,
|
||||
useGetProcessorStatusQuery,
|
||||
useResumeProcessorMutation,
|
||||
} from 'services/api/endpoints/queue';
|
||||
import { useIsQueueMutationInProgress } from '../hooks/useIsQueueMutationInProgress';
|
||||
import QueueButton from './common/QueueButton';
|
||||
@@ -14,18 +14,18 @@ type Props = {
|
||||
asIconButton?: boolean;
|
||||
};
|
||||
|
||||
const StartQueueButton = ({ asIconButton }: Props) => {
|
||||
const { data: queueStatusData } = useGetQueueStatusQuery();
|
||||
const ResumeProcessorButton = ({ asIconButton }: Props) => {
|
||||
const { data: processorStatus } = useGetProcessorStatusQuery();
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const [startQueue] = useStartQueueExecutionMutation({
|
||||
fixedCacheKey: 'startQueue',
|
||||
const [resumeProcessor] = useResumeProcessorMutation({
|
||||
fixedCacheKey: 'resumeProcessor',
|
||||
});
|
||||
const isQueueMutationInProgress = useIsQueueMutationInProgress();
|
||||
|
||||
const handleClick = useCallback(async () => {
|
||||
try {
|
||||
await startQueue().unwrap();
|
||||
await resumeProcessor().unwrap();
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.startSucceeded'),
|
||||
@@ -40,7 +40,7 @@ const StartQueueButton = ({ asIconButton }: Props) => {
|
||||
})
|
||||
);
|
||||
}
|
||||
}, [dispatch, startQueue, t]);
|
||||
}, [dispatch, resumeProcessor, t]);
|
||||
|
||||
return (
|
||||
<QueueButton
|
||||
@@ -48,9 +48,8 @@ const StartQueueButton = ({ asIconButton }: Props) => {
|
||||
label={t('queue.start')}
|
||||
tooltip={t('queue.startTooltip')}
|
||||
isDisabled={
|
||||
queueStatusData?.started ||
|
||||
queueStatusData?.stop_after_current ||
|
||||
queueStatusData?.pending === 0 ||
|
||||
processorStatus?.is_started ||
|
||||
processorStatus?.is_processing ||
|
||||
isQueueMutationInProgress
|
||||
}
|
||||
icon={<FaPlay />}
|
||||
@@ -60,4 +59,4 @@ const StartQueueButton = ({ asIconButton }: Props) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(StartQueueButton);
|
||||
export default memo(ResumeProcessorButton);
|
||||
|
||||
@@ -4,8 +4,8 @@ import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { FaStop } from 'react-icons/fa';
|
||||
import {
|
||||
useGetQueueStatusQuery,
|
||||
useStopQueueExecutionMutation,
|
||||
useGetProcessorStatusQuery,
|
||||
usePauseProcessorMutation,
|
||||
} from 'services/api/endpoints/queue';
|
||||
import { useIsQueueMutationInProgress } from '../hooks/useIsQueueMutationInProgress';
|
||||
import QueueButton from './common/QueueButton';
|
||||
@@ -14,18 +14,18 @@ type Props = {
|
||||
asIconButton?: boolean;
|
||||
};
|
||||
|
||||
const StopQueueButton = ({ asIconButton }: Props) => {
|
||||
const { data: queueStatusData } = useGetQueueStatusQuery();
|
||||
const PauseProcessorButton = ({ asIconButton }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const [stopQueue] = useStopQueueExecutionMutation({
|
||||
fixedCacheKey: 'stopQueue',
|
||||
const { data: processorStatus } = useGetProcessorStatusQuery();
|
||||
const [pauseProcessor] = usePauseProcessorMutation({
|
||||
fixedCacheKey: 'pauseProcessor',
|
||||
});
|
||||
const isQueueMutationInProgress = useIsQueueMutationInProgress();
|
||||
|
||||
const handleClick = useCallback(async () => {
|
||||
try {
|
||||
await stopQueue().unwrap();
|
||||
await pauseProcessor().unwrap();
|
||||
dispatch(
|
||||
addToast({
|
||||
title: t('queue.stopRequested'),
|
||||
@@ -40,14 +40,15 @@ const StopQueueButton = ({ asIconButton }: Props) => {
|
||||
})
|
||||
);
|
||||
}
|
||||
}, [dispatch, stopQueue, t]);
|
||||
}, [dispatch, pauseProcessor, t]);
|
||||
|
||||
return (
|
||||
<QueueButton
|
||||
asIconButton={asIconButton}
|
||||
label={t('queue.stop')}
|
||||
tooltip={t('queue.stopTooltip')}
|
||||
isDisabled={!queueStatusData?.started || isQueueMutationInProgress}
|
||||
isDisabled={!processorStatus?.is_started || isQueueMutationInProgress}
|
||||
isLoading={processorStatus?.is_stop_pending}
|
||||
icon={<FaStop />}
|
||||
onClick={handleClick}
|
||||
colorScheme="gold"
|
||||
@@ -55,4 +56,4 @@ const StopQueueButton = ({ asIconButton }: Props) => {
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(StopQueueButton);
|
||||
export default memo(PauseProcessorButton);
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import { ButtonGroup, ButtonGroupProps, Flex } from '@chakra-ui/react';
|
||||
import { memo } from 'react';
|
||||
import CancelQueueButton from './CancelQueueButton';
|
||||
import CancelCurrentQueueItemButton from './CancelCurrentQueueItemButton';
|
||||
import ClearQueueButton from './ClearQueueButton';
|
||||
import PruneQueueButton from './PruneQueueButton';
|
||||
import StartQueueButton from './StartQueueButton';
|
||||
import StopQueueButton from './StopQueueButton';
|
||||
import ResumeProcessorButton from './StartQueueButton';
|
||||
import PauseProcessorButton from './StopQueueButton';
|
||||
|
||||
type Props = ButtonGroupProps & {
|
||||
asIconButtons?: boolean;
|
||||
@@ -14,9 +14,9 @@ const VerticalQueueControls = ({ asIconButtons, ...rest }: Props) => {
|
||||
return (
|
||||
<Flex flexDir="column" gap={2}>
|
||||
<ButtonGroup w="full" isAttached {...rest}>
|
||||
<StartQueueButton asIconButton={asIconButtons} />
|
||||
<StopQueueButton asIconButton={asIconButtons} />
|
||||
<CancelQueueButton asIconButton={asIconButtons} />
|
||||
<ResumeProcessorButton asIconButton={asIconButtons} />
|
||||
<PauseProcessorButton asIconButton={asIconButtons} />
|
||||
<CancelCurrentQueueItemButton asIconButton={asIconButtons} />
|
||||
</ButtonGroup>
|
||||
<ButtonGroup w="full" isAttached {...rest}>
|
||||
<PruneQueueButton asIconButton={asIconButtons} />
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import {
|
||||
useCancelByBatchIdsMutation,
|
||||
useCancelQueueExecutionMutation,
|
||||
useCancelQueueItemMutation,
|
||||
// useCancelByBatchIdsMutation,
|
||||
useClearQueueMutation,
|
||||
useEnqueueBatchMutation,
|
||||
useEnqueueGraphMutation,
|
||||
usePruneQueueMutation,
|
||||
useStartQueueExecutionMutation,
|
||||
useStopQueueExecutionMutation,
|
||||
useResumeProcessorMutation,
|
||||
usePauseProcessorMutation,
|
||||
} from 'services/api/endpoints/queue';
|
||||
|
||||
export const useIsQueueMutationInProgress = () => {
|
||||
@@ -18,17 +18,17 @@ export const useIsQueueMutationInProgress = () => {
|
||||
useEnqueueGraphMutation({
|
||||
fixedCacheKey: 'enqueueGraph',
|
||||
});
|
||||
const [_triggerStartQueue, { isLoading: isLoadingStartQueue }] =
|
||||
useStartQueueExecutionMutation({
|
||||
fixedCacheKey: 'startQueue',
|
||||
const [_triggerResumeProcessor, { isLoading: isLoadingResumeProcessor }] =
|
||||
useResumeProcessorMutation({
|
||||
fixedCacheKey: 'resumeProcessor',
|
||||
});
|
||||
const [_triggerStopQueue, { isLoading: isLoadingStopQueue }] =
|
||||
useStopQueueExecutionMutation({
|
||||
fixedCacheKey: 'stopQueue',
|
||||
const [_triggerPauseProcessor, { isLoading: isLoadingPauseProcessor }] =
|
||||
usePauseProcessorMutation({
|
||||
fixedCacheKey: 'pauseProcessor',
|
||||
});
|
||||
const [_triggerCancelQueue, { isLoading: isLoadingCancelQueue }] =
|
||||
useCancelQueueExecutionMutation({
|
||||
fixedCacheKey: 'cancelQueue',
|
||||
useCancelQueueItemMutation({
|
||||
fixedCacheKey: 'cancelQueueItem',
|
||||
});
|
||||
const [_triggerClearQueue, { isLoading: isLoadingClearQueue }] =
|
||||
useClearQueueMutation({
|
||||
@@ -38,18 +38,18 @@ export const useIsQueueMutationInProgress = () => {
|
||||
usePruneQueueMutation({
|
||||
fixedCacheKey: 'pruneQueue',
|
||||
});
|
||||
const [_triggerCancelByBatchIds, { isLoading: isLoadingCancelByBatchIds }] =
|
||||
useCancelByBatchIdsMutation({
|
||||
fixedCacheKey: 'cancelByBatchIds',
|
||||
});
|
||||
// const [_triggerCancelByBatchIds, { isLoading: isLoadingCancelByBatchIds }] =
|
||||
// useCancelByBatchIdsMutation({
|
||||
// fixedCacheKey: 'cancelByBatchIds',
|
||||
// });
|
||||
return (
|
||||
isLoadingEnqueueBatch ||
|
||||
isLoadingEnqueueGraph ||
|
||||
isLoadingStartQueue ||
|
||||
isLoadingStopQueue ||
|
||||
isLoadingResumeProcessor ||
|
||||
isLoadingPauseProcessor ||
|
||||
isLoadingCancelQueue ||
|
||||
isLoadingClearQueue ||
|
||||
isLoadingPruneQueue ||
|
||||
isLoadingCancelByBatchIds
|
||||
isLoadingPruneQueue
|
||||
// isLoadingCancelByBatchIds
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
import { useGetProcessorStatusQuery } from 'services/api/endpoints/queue';
|
||||
|
||||
export const useIsQueueStarted = () => {
|
||||
const { isStarted } = useGetQueueStatusQuery(undefined, {
|
||||
const { isStarted } = useGetProcessorStatusQuery(undefined, {
|
||||
selectFromResult: ({ data }) => {
|
||||
if (!data) {
|
||||
return { isStarted: false };
|
||||
}
|
||||
|
||||
return { isStarted: data.started || data.stop_after_current };
|
||||
return { isStarted: data.is_started || data.is_processing };
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { stateSelector } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
|
||||
import { useMemo } from 'react';
|
||||
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
|
||||
|
||||
const selector = createSelector(
|
||||
[stateSelector, activeTabNameSelector],
|
||||
({ dynamicPrompts, generation }, activeTabName) => {
|
||||
if (activeTabName === 'nodes') {
|
||||
return generation.iterations;
|
||||
}
|
||||
return dynamicPrompts.prompts.length * generation.iterations;
|
||||
}
|
||||
);
|
||||
|
||||
export const usePredictedQueueCounts = () => {
|
||||
const { data: queueStatus } = useGetQueueStatusQuery();
|
||||
const requested = useAppSelector(selector);
|
||||
const counts = useMemo(() => {
|
||||
if (!queueStatus) {
|
||||
return;
|
||||
}
|
||||
const { max_queue_size, pending } = queueStatus;
|
||||
const maxNew = max_queue_size - pending;
|
||||
return {
|
||||
requested,
|
||||
max_queue_size,
|
||||
predicted: Math.min(requested, maxNew),
|
||||
};
|
||||
}, [queueStatus, requested]);
|
||||
return counts;
|
||||
};
|
||||
@@ -41,6 +41,7 @@ const SDXLImageToImageTabCoreParameters = () => {
|
||||
>
|
||||
{shouldUseSliders ? (
|
||||
<>
|
||||
<ParamIterations />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
<ParamModelandVAEandScheduler />
|
||||
@@ -52,6 +53,7 @@ const SDXLImageToImageTabCoreParameters = () => {
|
||||
) : (
|
||||
<>
|
||||
<Flex gap={3}>
|
||||
<ParamIterations />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
</Flex>
|
||||
|
||||
@@ -39,6 +39,7 @@ const SDXLUnifiedCanvasTabCoreParameters = () => {
|
||||
>
|
||||
{shouldUseSliders ? (
|
||||
<>
|
||||
<ParamIterations />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
<ParamModelandVAEandScheduler />
|
||||
@@ -50,6 +51,7 @@ const SDXLUnifiedCanvasTabCoreParameters = () => {
|
||||
) : (
|
||||
<>
|
||||
<Flex gap={3}>
|
||||
<ParamIterations />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
</Flex>
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { Progress } from '@chakra-ui/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useIsQueueStarted } from 'features/queue/hooks/useIsQueueStarted';
|
||||
import { SystemState } from 'features/system/store/systemSlice';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useGetProcessorStatusQuery } from 'services/api/endpoints/queue';
|
||||
import { systemSelector } from '../store/systemSelectors';
|
||||
|
||||
const progressBarSelector = createSelector(
|
||||
@@ -24,22 +24,22 @@ const progressBarSelector = createSelector(
|
||||
|
||||
const ProgressBar = () => {
|
||||
const { t } = useTranslation();
|
||||
const isStarted = useIsQueueStarted();
|
||||
const { data: processorStatus } = useGetProcessorStatusQuery();
|
||||
const { currentStep, totalSteps, currentStatusHasSteps } =
|
||||
useAppSelector(progressBarSelector);
|
||||
|
||||
const value = useMemo(() => {
|
||||
if (currentStep && isStarted) {
|
||||
if (currentStep && processorStatus?.is_processing) {
|
||||
return Math.round((currentStep * 100) / totalSteps);
|
||||
}
|
||||
return 0;
|
||||
}, [currentStep, isStarted, totalSteps]);
|
||||
}, [currentStep, processorStatus?.is_processing, totalSteps]);
|
||||
|
||||
return (
|
||||
<Progress
|
||||
value={value}
|
||||
aria-label={t('accessibility.invokeProgressBar')}
|
||||
isIndeterminate={isStarted && !currentStatusHasSteps}
|
||||
isIndeterminate={processorStatus?.is_processing && !currentStatusHasSteps}
|
||||
h="full"
|
||||
w="full"
|
||||
borderRadius={2}
|
||||
|
||||
@@ -441,6 +441,5 @@ const isAnyServerError = isAnyOf(
|
||||
|
||||
const isAnyCancelQueueItem = isAnyOf(
|
||||
queueApi.endpoints.cancelQueueItem.matchFulfilled,
|
||||
queueApi.endpoints.cancelQueueExecution.matchFulfilled,
|
||||
queueApi.endpoints.clearQueue.matchFulfilled
|
||||
);
|
||||
|
||||
@@ -41,6 +41,7 @@ const ImageToImageTabCoreParameters = () => {
|
||||
>
|
||||
{shouldUseSliders ? (
|
||||
<>
|
||||
<ParamIterations />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
<ParamModelandVAEandScheduler />
|
||||
@@ -52,6 +53,7 @@ const ImageToImageTabCoreParameters = () => {
|
||||
) : (
|
||||
<>
|
||||
<Flex gap={3}>
|
||||
<ParamIterations />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
</Flex>
|
||||
|
||||
@@ -42,6 +42,7 @@ const TextToImageTabCoreParameters = () => {
|
||||
>
|
||||
{shouldUseSliders ? (
|
||||
<>
|
||||
<ParamIterations />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
<ParamModelandVAEandScheduler />
|
||||
@@ -53,6 +54,7 @@ const TextToImageTabCoreParameters = () => {
|
||||
) : (
|
||||
<>
|
||||
<Flex gap={3}>
|
||||
<ParamIterations />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
</Flex>
|
||||
|
||||
@@ -39,6 +39,7 @@ const UnifiedCanvasCoreParameters = () => {
|
||||
>
|
||||
{shouldUseSliders ? (
|
||||
<>
|
||||
<ParamIterations />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
<ParamModelandVAEandScheduler />
|
||||
@@ -50,6 +51,7 @@ const UnifiedCanvasCoreParameters = () => {
|
||||
) : (
|
||||
<>
|
||||
<Flex gap={3}>
|
||||
<ParamIterations />
|
||||
<ParamSteps />
|
||||
<ParamCFGScale />
|
||||
</Flex>
|
||||
|
||||
@@ -70,6 +70,7 @@ export const queueApi = api.injectEndpoints({
|
||||
}),
|
||||
invalidatesTags: [
|
||||
'SessionQueueStatus',
|
||||
'SessionProcessorStatus',
|
||||
'CurrentSessionQueueItem',
|
||||
'NextSessionQueueItem',
|
||||
],
|
||||
@@ -94,6 +95,7 @@ export const queueApi = api.injectEndpoints({
|
||||
}),
|
||||
invalidatesTags: [
|
||||
'SessionQueueStatus',
|
||||
'SessionProcessorStatus',
|
||||
'CurrentSessionQueueItem',
|
||||
'NextSessionQueueItem',
|
||||
],
|
||||
@@ -107,31 +109,19 @@ export const queueApi = api.injectEndpoints({
|
||||
}
|
||||
},
|
||||
}),
|
||||
startQueueExecution: build.mutation<void, void>({
|
||||
resumeProcessor: build.mutation<void, void>({
|
||||
query: () => ({
|
||||
url: `queue/${$queueId.get()}/start`,
|
||||
url: `queue/${$queueId.get()}/resume`,
|
||||
method: 'PUT',
|
||||
}),
|
||||
invalidatesTags: ['SessionQueueStatus', 'CurrentSessionQueueItem'],
|
||||
invalidatesTags: ['SessionProcessorStatus', 'CurrentSessionQueueItem'],
|
||||
}),
|
||||
stopQueueExecution: build.mutation<void, void>({
|
||||
pauseProcessor: build.mutation<void, void>({
|
||||
query: () => ({
|
||||
url: `queue/${$queueId.get()}/stop`,
|
||||
url: `queue/${$queueId.get()}/pause`,
|
||||
method: 'PUT',
|
||||
}),
|
||||
invalidatesTags: ['SessionQueueStatus', 'CurrentSessionQueueItem'],
|
||||
}),
|
||||
cancelQueueExecution: build.mutation<void, void>({
|
||||
query: () => ({
|
||||
url: `queue/${$queueId.get()}/cancel`,
|
||||
method: 'PUT',
|
||||
}),
|
||||
invalidatesTags: [
|
||||
'SessionQueueStatus',
|
||||
'SessionQueueItem',
|
||||
'SessionQueueItemDTO',
|
||||
'CurrentSessionQueueItem',
|
||||
],
|
||||
invalidatesTags: ['SessionProcessorStatus', 'CurrentSessionQueueItem'],
|
||||
}),
|
||||
pruneQueue: build.mutation<
|
||||
paths['/api/v1/queue/{queue_id}/prune']['put']['responses']['200']['content']['application/json'],
|
||||
@@ -143,6 +133,8 @@ export const queueApi = api.injectEndpoints({
|
||||
}),
|
||||
invalidatesTags: [
|
||||
'SessionQueueStatus',
|
||||
'SessionProcessorStatus',
|
||||
'BatchStatus',
|
||||
'SessionQueueItem',
|
||||
'SessionQueueItemDTO',
|
||||
],
|
||||
@@ -166,6 +158,8 @@ export const queueApi = api.injectEndpoints({
|
||||
}),
|
||||
invalidatesTags: [
|
||||
'SessionQueueStatus',
|
||||
'SessionProcessorStatus',
|
||||
'BatchStatus',
|
||||
'CurrentSessionQueueItem',
|
||||
'NextSessionQueueItem',
|
||||
'SessionQueueItem',
|
||||
@@ -197,12 +191,12 @@ export const queueApi = api.injectEndpoints({
|
||||
return tags;
|
||||
},
|
||||
}),
|
||||
peekNextQueueItem: build.query<
|
||||
paths['/api/v1/queue/{queue_id}/peek']['get']['responses']['200']['content']['application/json'],
|
||||
getNextQueueItem: build.query<
|
||||
paths['/api/v1/queue/{queue_id}/next']['get']['responses']['200']['content']['application/json'],
|
||||
void
|
||||
>({
|
||||
query: () => ({
|
||||
url: `queue/${$queueId.get()}/peek`,
|
||||
url: `queue/${$queueId.get()}/next`,
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: (result) => {
|
||||
@@ -223,6 +217,31 @@ export const queueApi = api.injectEndpoints({
|
||||
}),
|
||||
providesTags: ['SessionQueueStatus'],
|
||||
}),
|
||||
getProcessorStatus: build.query<
|
||||
paths['/api/v1/queue/{queue_id}/processor/status']['get']['responses']['200']['content']['application/json'],
|
||||
void
|
||||
>({
|
||||
query: () => ({
|
||||
url: `queue/${$queueId.get()}/processor/status`,
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: ['SessionProcessorStatus'],
|
||||
}),
|
||||
getBatchStatus: build.query<
|
||||
paths['/api/v1/queue/{queue_id}/b/{batch_id}/status']['get']['responses']['200']['content']['application/json'],
|
||||
{ batch_id: string }
|
||||
>({
|
||||
query: ({ batch_id }) => ({
|
||||
url: `queue/${$queueId.get()}/b/${batch_id}/status`,
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: (result) => {
|
||||
if (!result) {
|
||||
return [];
|
||||
}
|
||||
return [{ type: 'BatchStatus', id: result.batch_id }];
|
||||
},
|
||||
}),
|
||||
getQueueItem: build.query<
|
||||
paths['/api/v1/queue/{queue_id}/i/{item_id}']['get']['responses']['200']['content']['application/json'],
|
||||
string
|
||||
@@ -272,6 +291,7 @@ export const queueApi = api.injectEndpoints({
|
||||
return [
|
||||
{ type: 'SessionQueueItem', id: result.item_id },
|
||||
{ type: 'SessionQueueItemDTO', id: result.item_id },
|
||||
{ type: 'BatchStatus', id: result.batch_id },
|
||||
];
|
||||
},
|
||||
}),
|
||||
@@ -293,7 +313,11 @@ export const queueApi = api.injectEndpoints({
|
||||
// no-op
|
||||
}
|
||||
},
|
||||
invalidatesTags: ['SessionQueueItem', 'SessionQueueItemDTO'],
|
||||
invalidatesTags: [
|
||||
'SessionQueueItem',
|
||||
'SessionQueueItemDTO',
|
||||
'BatchStatus',
|
||||
],
|
||||
}),
|
||||
listQueueItems: build.query<
|
||||
EntityState<components['schemas']['SessionQueueItemDTO']> & {
|
||||
@@ -334,17 +358,18 @@ export const {
|
||||
useCancelByBatchIdsMutation,
|
||||
useEnqueueGraphMutation,
|
||||
useEnqueueBatchMutation,
|
||||
useCancelQueueExecutionMutation,
|
||||
useStopQueueExecutionMutation,
|
||||
useStartQueueExecutionMutation,
|
||||
usePauseProcessorMutation,
|
||||
useResumeProcessorMutation,
|
||||
useClearQueueMutation,
|
||||
usePruneQueueMutation,
|
||||
useGetCurrentQueueItemQuery,
|
||||
useGetQueueStatusQuery,
|
||||
useGetQueueItemQuery,
|
||||
usePeekNextQueueItemQuery,
|
||||
useGetNextQueueItemQuery,
|
||||
useListQueueItemsQuery,
|
||||
useCancelQueueItemMutation,
|
||||
useGetProcessorStatusQuery,
|
||||
useGetBatchStatusQuery,
|
||||
} = queueApi;
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
|
||||
@@ -22,6 +22,8 @@ export const tagTypes = [
|
||||
'SessionQueueItemDTO',
|
||||
'SessionQueueItemDTOList',
|
||||
'SessionQueueStatus',
|
||||
'SessionProcessorStatus',
|
||||
'BatchStatus',
|
||||
];
|
||||
export type ApiTagDescription = TagDescription<(typeof tagTypes)[number]>;
|
||||
export const LIST_TAG = 'LIST';
|
||||
|
||||
337
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
337
invokeai/frontend/web/src/services/api/schema.d.ts
vendored
@@ -328,26 +328,19 @@ export type paths = {
|
||||
*/
|
||||
get: operations["list_queue_items"];
|
||||
};
|
||||
"/api/v1/queue/{queue_id}/start": {
|
||||
"/api/v1/queue/{queue_id}/resume": {
|
||||
/**
|
||||
* Start
|
||||
* @description Starts session queue execution
|
||||
* Resume
|
||||
* @description Resumes session processor
|
||||
*/
|
||||
put: operations["start"];
|
||||
put: operations["resume"];
|
||||
};
|
||||
"/api/v1/queue/{queue_id}/stop": {
|
||||
"/api/v1/queue/{queue_id}/pause": {
|
||||
/**
|
||||
* Stop
|
||||
* @description Stops session queue execution, waiting for the currently executing session to finish
|
||||
* Pause
|
||||
* @description Pauses session processor
|
||||
*/
|
||||
put: operations["stop"];
|
||||
};
|
||||
"/api/v1/queue/{queue_id}/cancel": {
|
||||
/**
|
||||
* Cancel
|
||||
* @description Stops session queue execution, immediately canceling the currently-executing session
|
||||
*/
|
||||
put: operations["cancel"];
|
||||
put: operations["pause"];
|
||||
};
|
||||
"/api/v1/queue/{queue_id}/cancel_by_batch_ids": {
|
||||
/**
|
||||
@@ -372,24 +365,38 @@ export type paths = {
|
||||
};
|
||||
"/api/v1/queue/{queue_id}/current": {
|
||||
/**
|
||||
* Current
|
||||
* Get Current Queue Item
|
||||
* @description Gets the currently execution queue item
|
||||
*/
|
||||
get: operations["current"];
|
||||
get: operations["get_current_queue_item"];
|
||||
};
|
||||
"/api/v1/queue/{queue_id}/peek": {
|
||||
"/api/v1/queue/{queue_id}/next": {
|
||||
/**
|
||||
* Peek
|
||||
* Get Next Queue Item
|
||||
* @description Gets the next queue item, without executing it
|
||||
*/
|
||||
get: operations["peek"];
|
||||
get: operations["get_next_queue_item"];
|
||||
};
|
||||
"/api/v1/queue/{queue_id}/status": {
|
||||
/**
|
||||
* Get Status
|
||||
* Get Queue Status
|
||||
* @description Gets the status of the session queue
|
||||
*/
|
||||
get: operations["get_status"];
|
||||
get: operations["get_queue_status"];
|
||||
};
|
||||
"/api/v1/queue/{queue_id}/processor/status": {
|
||||
/**
|
||||
* Get Processor Status
|
||||
* @description Gets the status of the session queue
|
||||
*/
|
||||
get: operations["get_processor_status"];
|
||||
};
|
||||
"/api/v1/queue/{queue_id}/b/{batch_id}/status": {
|
||||
/**
|
||||
* Get Batch Status
|
||||
* @description Gets the status of the session queue
|
||||
*/
|
||||
get: operations["get_batch_status"];
|
||||
};
|
||||
"/api/v1/queue/{queue_id}/i/{item_id}": {
|
||||
/**
|
||||
@@ -549,6 +556,49 @@ export type components = {
|
||||
*/
|
||||
items?: (string | number)[];
|
||||
};
|
||||
/** BatchStatusResult */
|
||||
BatchStatusResult: {
|
||||
/**
|
||||
* Queue Id
|
||||
* @description The ID of the queue
|
||||
*/
|
||||
queue_id: string;
|
||||
/**
|
||||
* Batch Id
|
||||
* @description The ID of the batch
|
||||
*/
|
||||
batch_id: string;
|
||||
/**
|
||||
* Pending
|
||||
* @description Number of queue items with status 'pending'
|
||||
*/
|
||||
pending: number;
|
||||
/**
|
||||
* In Progress
|
||||
* @description Number of queue items with status 'in_progress'
|
||||
*/
|
||||
in_progress: number;
|
||||
/**
|
||||
* Completed
|
||||
* @description Number of queue items with status 'complete'
|
||||
*/
|
||||
completed: number;
|
||||
/**
|
||||
* Failed
|
||||
* @description Number of queue items with status 'error'
|
||||
*/
|
||||
failed: number;
|
||||
/**
|
||||
* Canceled
|
||||
* @description Number of queue items with status 'canceled'
|
||||
*/
|
||||
canceled: number;
|
||||
/**
|
||||
* Total
|
||||
* @description Total number of queue items
|
||||
*/
|
||||
total: number;
|
||||
};
|
||||
/**
|
||||
* Blank Image
|
||||
* @description Creates a blank image and forwards it to the pipeline
|
||||
@@ -1016,7 +1066,10 @@ export type components = {
|
||||
*/
|
||||
type: "infill_cv2";
|
||||
};
|
||||
/** CancelByBatchIDsResult */
|
||||
/**
|
||||
* CancelByBatchIDsResult
|
||||
* @description Result of canceling by list of batch ids
|
||||
*/
|
||||
CancelByBatchIDsResult: {
|
||||
/**
|
||||
* Canceled
|
||||
@@ -1124,11 +1177,6 @@ export type components = {
|
||||
* @description The workflow to save with the image
|
||||
*/
|
||||
workflow?: string;
|
||||
/**
|
||||
* CLIP
|
||||
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
|
||||
*/
|
||||
clip?: components["schemas"]["ClipField"];
|
||||
/**
|
||||
* Skipped Layers
|
||||
* @description Number of layers to skip in text encoder
|
||||
@@ -1141,6 +1189,11 @@ export type components = {
|
||||
* @enum {string}
|
||||
*/
|
||||
type: "clip_skip";
|
||||
/**
|
||||
* CLIP
|
||||
* @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count
|
||||
*/
|
||||
clip?: components["schemas"]["ClipField"];
|
||||
};
|
||||
/**
|
||||
* ClipSkipInvocationOutput
|
||||
@@ -2338,6 +2391,11 @@ export type components = {
|
||||
};
|
||||
/** EnqueueBatchResult */
|
||||
EnqueueBatchResult: {
|
||||
/**
|
||||
* Queue Id
|
||||
* @description The ID of the queue
|
||||
*/
|
||||
queue_id: string;
|
||||
/**
|
||||
* Enqueued
|
||||
* @description The total number of queue items enqueued
|
||||
@@ -6890,58 +6948,23 @@ export type components = {
|
||||
*/
|
||||
type: "segment_anything_processor";
|
||||
};
|
||||
/** SessionQueueAndExecutionStatusResult */
|
||||
SessionQueueAndExecutionStatusResult: {
|
||||
/** SessionProcessorStatusResult */
|
||||
SessionProcessorStatusResult: {
|
||||
/**
|
||||
* Started
|
||||
* @description Whether the session queue is running
|
||||
* Is Started
|
||||
* @description Whether the session processor is started
|
||||
*/
|
||||
started: boolean;
|
||||
is_started: boolean;
|
||||
/**
|
||||
* Stop After Current
|
||||
* @description Whether the session queue is pending a stop
|
||||
* Is Processing
|
||||
* @description Whether a session is being processed
|
||||
*/
|
||||
stop_after_current: boolean;
|
||||
is_processing: boolean;
|
||||
/**
|
||||
* Queue Id
|
||||
* @description The ID of the queue
|
||||
* Is Stop Pending
|
||||
* @description Whether processor is pending stopping
|
||||
*/
|
||||
queue_id: string;
|
||||
/**
|
||||
* Pending
|
||||
* @description Number of queue items with status 'pending'
|
||||
*/
|
||||
pending: number;
|
||||
/**
|
||||
* In Progress
|
||||
* @description Number of queue items with status 'in_progress'
|
||||
*/
|
||||
in_progress: number;
|
||||
/**
|
||||
* Completed
|
||||
* @description Number of queue items with status 'complete'
|
||||
*/
|
||||
completed: number;
|
||||
/**
|
||||
* Failed
|
||||
* @description Number of queue items with status 'error'
|
||||
*/
|
||||
failed: number;
|
||||
/**
|
||||
* Canceled
|
||||
* @description Number of queue items with status 'canceled'
|
||||
*/
|
||||
canceled: number;
|
||||
/**
|
||||
* Total
|
||||
* @description Total number of queue items
|
||||
*/
|
||||
total: number;
|
||||
/**
|
||||
* Max Queue Size
|
||||
* @description Maximum number of queue items allowed
|
||||
*/
|
||||
max_queue_size: number;
|
||||
is_stop_pending: boolean;
|
||||
};
|
||||
/**
|
||||
* SessionQueueItem
|
||||
@@ -7006,6 +7029,11 @@ export type components = {
|
||||
* @description When this queue item was updated
|
||||
*/
|
||||
updated_at: string;
|
||||
/**
|
||||
* Started At
|
||||
* @description When this queue item was started
|
||||
*/
|
||||
started_at?: string;
|
||||
/**
|
||||
* Completed At
|
||||
* @description When this queue item was completed
|
||||
@@ -7080,12 +7108,55 @@ export type components = {
|
||||
* @description When this queue item was updated
|
||||
*/
|
||||
updated_at: string;
|
||||
/**
|
||||
* Started At
|
||||
* @description When this queue item was started
|
||||
*/
|
||||
started_at?: string;
|
||||
/**
|
||||
* Completed At
|
||||
* @description When this queue item was completed
|
||||
*/
|
||||
completed_at?: string;
|
||||
};
|
||||
/** SessionQueueStatusResult */
|
||||
SessionQueueStatusResult: {
|
||||
/**
|
||||
* Queue Id
|
||||
* @description The ID of the queue
|
||||
*/
|
||||
queue_id: string;
|
||||
/**
|
||||
* Pending
|
||||
* @description Number of queue items with status 'pending'
|
||||
*/
|
||||
pending: number;
|
||||
/**
|
||||
* In Progress
|
||||
* @description Number of queue items with status 'in_progress'
|
||||
*/
|
||||
in_progress: number;
|
||||
/**
|
||||
* Completed
|
||||
* @description Number of queue items with status 'complete'
|
||||
*/
|
||||
completed: number;
|
||||
/**
|
||||
* Failed
|
||||
* @description Number of queue items with status 'error'
|
||||
*/
|
||||
failed: number;
|
||||
/**
|
||||
* Canceled
|
||||
* @description Number of queue items with status 'canceled'
|
||||
*/
|
||||
canceled: number;
|
||||
/**
|
||||
* Total
|
||||
* @description Total number of queue items
|
||||
*/
|
||||
total: number;
|
||||
};
|
||||
/**
|
||||
* Show Image
|
||||
* @description Displays a provided image using the OS image viewer, and passes it forward in the pipeline.
|
||||
@@ -8078,6 +8149,12 @@ export type components = {
|
||||
/** Ui Order */
|
||||
ui_order?: number;
|
||||
};
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusionOnnxModelFormat
|
||||
* @description An enumeration.
|
||||
@@ -8102,12 +8179,6 @@ export type components = {
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
};
|
||||
responses: never;
|
||||
parameters: never;
|
||||
@@ -9622,10 +9693,10 @@ export type operations = {
|
||||
};
|
||||
};
|
||||
/**
|
||||
* Start
|
||||
* @description Starts session queue execution
|
||||
* Resume
|
||||
* @description Resumes session processor
|
||||
*/
|
||||
start: {
|
||||
resume: {
|
||||
parameters: {
|
||||
path: {
|
||||
/** @description The queue id to perform this operation on */
|
||||
@@ -9648,36 +9719,10 @@ export type operations = {
|
||||
};
|
||||
};
|
||||
/**
|
||||
* Stop
|
||||
* @description Stops session queue execution, waiting for the currently executing session to finish
|
||||
* Pause
|
||||
* @description Pauses session processor
|
||||
*/
|
||||
stop: {
|
||||
parameters: {
|
||||
path: {
|
||||
/** @description The queue id to perform this operation on */
|
||||
queue_id: string;
|
||||
};
|
||||
};
|
||||
responses: {
|
||||
/** @description Successful Response */
|
||||
200: {
|
||||
content: {
|
||||
"application/json": unknown;
|
||||
};
|
||||
};
|
||||
/** @description Validation Error */
|
||||
422: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["HTTPValidationError"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
/**
|
||||
* Cancel
|
||||
* @description Stops session queue execution, immediately canceling the currently-executing session
|
||||
*/
|
||||
cancel: {
|
||||
pause: {
|
||||
parameters: {
|
||||
path: {
|
||||
/** @description The queue id to perform this operation on */
|
||||
@@ -9783,10 +9828,10 @@ export type operations = {
|
||||
};
|
||||
};
|
||||
/**
|
||||
* Current
|
||||
* Get Current Queue Item
|
||||
* @description Gets the currently execution queue item
|
||||
*/
|
||||
current: {
|
||||
get_current_queue_item: {
|
||||
parameters: {
|
||||
path: {
|
||||
/** @description The queue id to perform this operation on */
|
||||
@@ -9809,10 +9854,10 @@ export type operations = {
|
||||
};
|
||||
};
|
||||
/**
|
||||
* Peek
|
||||
* Get Next Queue Item
|
||||
* @description Gets the next queue item, without executing it
|
||||
*/
|
||||
peek: {
|
||||
get_next_queue_item: {
|
||||
parameters: {
|
||||
path: {
|
||||
/** @description The queue id to perform this operation on */
|
||||
@@ -9835,10 +9880,10 @@ export type operations = {
|
||||
};
|
||||
};
|
||||
/**
|
||||
* Get Status
|
||||
* Get Queue Status
|
||||
* @description Gets the status of the session queue
|
||||
*/
|
||||
get_status: {
|
||||
get_queue_status: {
|
||||
parameters: {
|
||||
path: {
|
||||
/** @description The queue id to perform this operation on */
|
||||
@@ -9849,7 +9894,61 @@ export type operations = {
|
||||
/** @description Successful Response */
|
||||
200: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["SessionQueueAndExecutionStatusResult"];
|
||||
"application/json": components["schemas"]["SessionQueueStatusResult"];
|
||||
};
|
||||
};
|
||||
/** @description Validation Error */
|
||||
422: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["HTTPValidationError"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
/**
|
||||
* Get Processor Status
|
||||
* @description Gets the status of the session queue
|
||||
*/
|
||||
get_processor_status: {
|
||||
parameters: {
|
||||
path: {
|
||||
/** @description The queue id to perform this operation on */
|
||||
queue_id: string;
|
||||
};
|
||||
};
|
||||
responses: {
|
||||
/** @description Successful Response */
|
||||
200: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["SessionProcessorStatusResult"];
|
||||
};
|
||||
};
|
||||
/** @description Validation Error */
|
||||
422: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["HTTPValidationError"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
/**
|
||||
* Get Batch Status
|
||||
* @description Gets the status of the session queue
|
||||
*/
|
||||
get_batch_status: {
|
||||
parameters: {
|
||||
path: {
|
||||
/** @description The queue id to perform this operation on */
|
||||
queue_id: string;
|
||||
/** @description The batch to get the status of */
|
||||
batch_id: string;
|
||||
};
|
||||
};
|
||||
responses: {
|
||||
/** @description Successful Response */
|
||||
200: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["BatchStatusResult"];
|
||||
};
|
||||
};
|
||||
/** @description Validation Error */
|
||||
|
||||
@@ -78,7 +78,7 @@ def mock_invoker(mock_services: InvocationServices) -> Invoker:
|
||||
|
||||
def test_can_create_graph_state(mock_invoker: Invoker):
|
||||
g = mock_invoker.create_execution_state()
|
||||
mock_invoker.stop_service()
|
||||
mock_invoker.stop()
|
||||
|
||||
assert g is not None
|
||||
assert isinstance(g, GraphExecutionState)
|
||||
@@ -86,7 +86,7 @@ def test_can_create_graph_state(mock_invoker: Invoker):
|
||||
|
||||
def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
|
||||
g = mock_invoker.create_execution_state(graph=simple_graph)
|
||||
mock_invoker.stop_service()
|
||||
mock_invoker.stop()
|
||||
|
||||
assert g is not None
|
||||
assert isinstance(g, GraphExecutionState)
|
||||
@@ -104,7 +104,7 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
|
||||
return len(g.executed) > 0
|
||||
|
||||
wait_until(lambda: has_executed_any(g), timeout=5, interval=1)
|
||||
mock_invoker.stop_service()
|
||||
mock_invoker.stop()
|
||||
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert len(g.executed) > 0
|
||||
@@ -121,7 +121,7 @@ def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
|
||||
return g.is_complete()
|
||||
|
||||
wait_until(lambda: has_executed_all(g), timeout=5, interval=1)
|
||||
mock_invoker.stop_service()
|
||||
mock_invoker.stop()
|
||||
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert g.is_complete()
|
||||
@@ -139,7 +139,7 @@ def test_handles_errors(mock_invoker: Invoker):
|
||||
return g.is_complete()
|
||||
|
||||
wait_until(lambda: has_executed_all(g), timeout=5, interval=1)
|
||||
mock_invoker.stop_service()
|
||||
mock_invoker.stop()
|
||||
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
assert g.has_error()
|
||||
|
||||
Reference in New Issue
Block a user