feat(queue): change all execution-related events to use the queue_id as the room, also include queue_item_id in InvocationQueueItem

This allows for much simpler handling of queue items.
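For context, this is roughly how a client consumes execution events after this change: it subscribes once to the queue's room by queue_id and reads queue_item_id off every payload. A minimal sketch using the python-socketio client — the URL, queue id, and event name here are illustrative, not taken from this commit:

import asyncio

import socketio  # python-socketio async client

sio = socketio.AsyncClient()

@sio.on("invocation_complete")
async def on_invocation_complete(data: dict) -> None:
    # every execution event now carries the queue context alongside the session id
    print(data["queue_id"], data["queue_item_id"], data["graph_execution_state_id"])

async def main() -> None:
    await sio.connect("http://localhost:9090")  # hypothetical server address
    # events are emitted to a room named after the queue_id, so a single
    # subscription covers every session executed from that queue
    await sio.emit("subscribe_queue", {"queue_id": "default"})
    await sio.wait()

asyncio.run(main())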
psychedelicious
2023-09-19 13:49:24 +10:00
parent 0c5bafdeb6
commit e8ac82a492
20 changed files with 308 additions and 248 deletions

View File

@@ -14,29 +14,10 @@ class SocketIO:
def __init__(self, app: FastAPI):
self.__sio = SocketManager(app=app)
self.__sio.on("subscribe_session", handler=self._handle_sub_session)
self.__sio.on("unsubscribe_session", handler=self._handle_unsub_session)
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
async def _handle_session_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["graph_execution_state_id"],
)
async def _handle_sub_session(self, sid, data, *args, **kwargs):
if "session" in data:
self.__sio.enter_room(sid, data["session"])
async def _handle_unsub_session(self, sid, data, *args, **kwargs):
if "session" in data:
self.__sio.leave_room(sid, data["session"])
async def _handle_queue_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
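(The hunk is truncated mid-call here. Presumably the new queue handlers mirror the removed session handlers but are keyed on queue_id — a sketch of the likely shape, not the verbatim file contents:)

async def _handle_queue_event(self, event: Event):
    # emit to the room named after the queue_id carried in the event payload
    await self.__sio.emit(
        event=event[1]["event"],
        data=event[1]["data"],
        room=event[1]["data"]["queue_id"],
    )

async def _handle_sub_queue(self, sid, data, *args, **kwargs):
    if "queue_id" in data:
        self.__sio.enter_room(sid, data["queue_id"])

async def _handle_unsub_queue(self, sid, data, *args, **kwargs):
    if "queue_id" in data:
        self.__sio.leave_room(sid, data["queue_id"])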

View File

@@ -417,12 +417,18 @@ class UIConfigBase(BaseModel):
class InvocationContext:
"""Initialized and provided to on execution of invocations."""
services: InvocationServices
graph_execution_state_id: str
queue_id: str
queue_item_id: str
def __init__(self, services: InvocationServices, graph_execution_state_id: str):
def __init__(self, services: InvocationServices, queue_id: str, queue_item_id: str, graph_execution_state_id: str):
self.services = services
self.graph_execution_state_id = graph_execution_state_id
self.queue_id = queue_id
self.queue_item_id = queue_item_id
class BaseInvocationOutput(BaseModel):

View File

@@ -9,23 +9,13 @@ from invokeai.app.util.misc import get_timestamp
class EventServiceBase:
session_event: str = "session_event"
queue_event: str = "queue_event"
processor_event: str = "processor_event"
"""Basic event bus, to have an empty stand-in when not needed"""
def dispatch(self, event_name: str, payload: Any) -> None:
pass
def __emit_session_event(self, event_name: str, payload: dict) -> None:
"""Session events are emitted to a room with the session_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.session_event,
payload=dict(event=event_name, data=payload),
)
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
"""Queue events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
@@ -34,18 +24,12 @@ class EventServiceBase:
payload=dict(event=event_name, data=payload),
)
def __emit_processor_event(self, event_name: str, payload: dict) -> None:
"""Processor events are emitted to a room with "processor" as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.processor_event,
payload=dict(event=event_name, data=payload),
)
# Define events here for every event in the system.
# This will make them easier to integrate until we find a schema generator.
def emit_generator_progress(
self,
queue_id: str,
queue_item_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
@@ -55,9 +39,12 @@ class EventServiceBase:
total_steps: int,
) -> None:
"""Emitted when there is generation progress"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="generator_progress",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
item_id=queue_item_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
@@ -70,15 +57,19 @@ class EventServiceBase:
def emit_invocation_complete(
self,
queue_id: str,
queue_item_id: str,
graph_execution_state_id: str,
result: dict,
node: dict,
source_node_id: str,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="invocation_complete",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
@@ -88,6 +79,8 @@ class EventServiceBase:
def emit_invocation_error(
self,
queue_id: str,
queue_item_id: str,
graph_execution_state_id: str,
node: dict,
source_node_id: str,
@@ -95,9 +88,11 @@ class EventServiceBase:
error: str,
) -> None:
"""Emitted when an invocation has completed"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="invocation_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
@@ -106,28 +101,36 @@ class EventServiceBase:
),
)
def emit_invocation_started(self, graph_execution_state_id: str, node: dict, source_node_id: str) -> None:
def emit_invocation_started(
self, queue_id: str, queue_item_id: str, graph_execution_state_id: str, node: dict, source_node_id: str
) -> None:
"""Emitted when an invocation has started"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="invocation_started",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
graph_execution_state_id=graph_execution_state_id,
node=node,
source_node_id=source_node_id,
),
)
def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
def emit_graph_execution_complete(self, queue_id: str, queue_item_id: str, graph_execution_state_id: str) -> None:
"""Emitted when a session has completed all invocations"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="graph_execution_state_complete",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
graph_execution_state_id=graph_execution_state_id,
),
)
def emit_model_load_started(
self,
queue_id: str,
queue_item_id: str,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
@@ -135,9 +138,11 @@ class EventServiceBase:
submodel: SubModelType,
) -> None:
"""Emitted when a model is requested"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="model_load_started",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
@@ -148,6 +153,8 @@ class EventServiceBase:
def emit_model_load_completed(
self,
queue_id: str,
queue_item_id: str,
graph_execution_state_id: str,
model_name: str,
base_model: BaseModelType,
@@ -156,9 +163,11 @@ class EventServiceBase:
model_info: ModelInfo,
) -> None:
"""Emitted when a model is correctly loaded (returns model info)"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="model_load_completed",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
graph_execution_state_id=graph_execution_state_id,
model_name=model_name,
base_model=base_model,
@@ -172,14 +181,18 @@ class EventServiceBase:
def emit_session_retrieval_error(
self,
queue_id: str,
queue_item_id: str,
graph_execution_state_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when session retrieval fails"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="session_retrieval_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
graph_execution_state_id=graph_execution_state_id,
error_type=error_type,
error=error,
@@ -188,15 +201,19 @@ class EventServiceBase:
def emit_invocation_retrieval_error(
self,
queue_id: str,
queue_item_id: str,
graph_execution_state_id: str,
node_id: str,
error_type: str,
error: str,
) -> None:
"""Emitted when invocation retrieval fails"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="invocation_retrieval_error",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
graph_execution_state_id=graph_execution_state_id,
node_id=node_id,
error_type=error_type,
@@ -206,12 +223,16 @@ class EventServiceBase:
def emit_session_canceled(
self,
queue_id: str,
queue_item_id: str,
graph_execution_state_id: str,
) -> None:
"""Emitted when a session is canceled"""
self.__emit_session_event(
self.__emit_queue_event(
event_name="session_canceled",
payload=dict(
queue_id=queue_id,
queue_item_id=queue_item_id,
graph_execution_state_id=graph_execution_state_id,
),
)
@@ -221,11 +242,11 @@ class EventServiceBase:
self.__emit_queue_event(
event_name="queue_item_status_changed",
payload=dict(
item_id=session_queue_item.item_id,
queue_id=session_queue_item.queue_id,
queue_item_id=session_queue_item.item_id,
status=session_queue_item.status,
batch_id=session_queue_item.batch_id,
session_id=session_queue_item.session_id,
queue_id=session_queue_item.queue_id,
error=session_queue_item.error,
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
@@ -238,7 +259,11 @@ class EventServiceBase:
"""Emitted when a batch is enqueued"""
self.__emit_queue_event(
event_name="batch_enqueued",
payload=enqueue_result.dict(),
payload=dict(
queue_id=enqueue_result.queue_id,
batch_id=enqueue_result.batch.batch_id,
enqueued=enqueue_result.enqueued,
),
)
def emit_queue_cleared(self, queue_id: str) -> None:

View File

@@ -253,49 +253,53 @@ class Graph(BaseModel):
@root_validator
def validate_nodes_and_edges(cls, values):
"""Validates that all edges match nodes in the graph"""
nodes = cast(dict[str, BaseInvocation], values.get("nodes"))
edges = cast(list[Edge], values.get("edges"))
nodes = cast(Optional[dict[str, BaseInvocation]], values.get("nodes"))
edges = cast(Optional[list[Edge]], values.get("edges"))
# Validate that all node ids are unique
node_ids = [n.id for n in nodes.values()]
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
if duplicate_node_ids:
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
if nodes is not None:
# Validate that all node ids are unique
node_ids = [n.id for n in nodes.values()]
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
if duplicate_node_ids:
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
# Validate that all node ids match the keys in the nodes dict
for k, v in nodes.items():
if k != v.id:
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
# Validate that all node ids match the keys in the nodes dict
for k, v in nodes.items():
if k != v.id:
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
# Validate that all edges match nodes in the graph
node_ids = set([e.source.node_id for e in edges] + [e.destination.node_id for e in edges])
missing_node_ids = [node_id for node_id in node_ids if node_id not in nodes]
if missing_node_ids:
raise NodeNotFoundError(f"All edges must reference nodes in the graph, missing nodes: {missing_node_ids}")
# Validate that all edge fields match node fields in the graph
for edge in edges:
source_node = nodes.get(edge.source.node_id, None)
if source_node is None:
raise NodeFieldNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
destination_node = nodes.get(edge.destination.node_id, None)
if destination_node is None:
raise NodeFieldNotFoundError(
f"Edge destination node {edge.destination.node_id} does not exist in the graph"
if edges is not None and nodes is not None:
# Validate that all edges match nodes in the graph
node_ids = set([e.source.node_id for e in edges] + [e.destination.node_id for e in edges])
missing_node_ids = [node_id for node_id in node_ids if node_id not in nodes]
if missing_node_ids:
raise NodeNotFoundError(
f"All edges must reference nodes in the graph, missing nodes: {missing_node_ids}"
)
# output fields are not on the node object directly, they are on the output type
if edge.source.field not in source_node.get_output_type().__fields__:
raise NodeFieldNotFoundError(
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
)
# Validate that all edge fields match node fields in the graph
for edge in edges:
source_node = nodes.get(edge.source.node_id, None)
if source_node is None:
raise NodeFieldNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
# input fields are on the node
if edge.destination.field not in destination_node.__fields__:
raise NodeFieldNotFoundError(
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
)
destination_node = nodes.get(edge.destination.node_id, None)
if destination_node is None:
raise NodeFieldNotFoundError(
f"Edge destination node {edge.destination.node_id} does not exist in the graph"
)
# output fields are not on the node object directly, they are on the output type
if edge.source.field not in source_node.get_output_type().__fields__:
raise NodeFieldNotFoundError(
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
)
# input fields are on the node
if edge.destination.field not in destination_node.__fields__:
raise NodeFieldNotFoundError(
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
)
return values
@@ -760,7 +764,6 @@ class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
id: str = Field(description="The id of the execution state", default_factory=uuid_string)
# TODO: Store a reference to the graph instead of the actual graph?
graph: Graph = Field(description="The graph being executed")

View File

@@ -11,6 +11,10 @@ from pydantic import BaseModel, Field
class InvocationQueueItem(BaseModel):
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
invocation_id: str = Field(description="The ID of the node being invoked")
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
session_queue_item_id: str = Field(
description="The ID of session queue item from which this invocation queue item came"
)
invoke_all: bool = Field(default=False)
timestamp: float = Field(default_factory=time.time)

View File

@@ -17,7 +17,9 @@ class Invoker:
self.services = services
self._start()
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
def invoke(
self, queue_id: str, queue_item_id: str, graph_execution_state: GraphExecutionState, invoke_all: bool = False
) -> Optional[str]:
"""Determines the next node to invoke and enqueues it, preparing if needed.
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
@@ -32,7 +34,8 @@ class Invoker:
# Queue the invocation
self.services.queue.put(
InvocationQueueItem(
# session_id = session.id,
session_queue_item_id=queue_item_id,
session_queue_id=queue_id,
graph_execution_state_id=graph_execution_state.id,
invocation_id=invocation.id,
invoke_all=invoke_all,

View File

@@ -525,7 +525,7 @@ class ModelManagerService(ModelManagerServiceBase):
def _emit_load_event(
self,
context,
context: InvocationContext,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
@@ -537,6 +537,8 @@ class ModelManagerService(ModelManagerServiceBase):
if model_info:
context.services.events.emit_model_load_completed(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
graph_execution_state_id=context.graph_execution_state_id,
model_name=model_name,
base_model=base_model,
@@ -546,6 +548,8 @@ class ModelManagerService(ModelManagerServiceBase):
)
else:
context.services.events.emit_model_load_started(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
graph_execution_state_id=context.graph_execution_state_id,
model_name=model_name,
base_model=base_model,

View File

@@ -1,6 +1,7 @@
import time
import traceback
from threading import BoundedSemaphore, Event, Thread
from typing import Optional
import invokeai.backend.util.logging as logger
@@ -37,10 +38,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
try:
self.__threadLimit.acquire()
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
queue_item: Optional[InvocationQueueItem] = None
while not stop_event.is_set():
try:
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
queue_item = self.__invoker.services.queue.get()
except Exception as e:
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
@@ -48,7 +50,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# do not hammer the queue
time.sleep(0.5)
continue
try:
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
queue_item.graph_execution_state_id
@@ -56,6 +57,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
self.__invoker.services.events.emit_session_retrieval_error(
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=queue_item.graph_execution_state_id,
error_type=e.__class__.__name__,
error=traceback.format_exc(),
@@ -67,6 +70,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
except Exception as e:
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
self.__invoker.services.events.emit_invocation_retrieval_error(
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=queue_item.graph_execution_state_id,
node_id=queue_item.invocation_id,
error_type=e.__class__.__name__,
@@ -79,6 +84,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send starting event
self.__invoker.services.events.emit_invocation_started(
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
@@ -89,13 +96,16 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
graph_id = graph_execution_state.id
model_manager = self.__invoker.services.model_manager
with statistics.collect_stats(invocation, graph_id, model_manager):
# use the internal invoke_internal(), which wraps the node's invoke() method in
# this accommodates nodes which require a value, but get it only from a
# connection
# use the internal invoke_internal(), which wraps the node's invoke() method,
# which handles a few things:
# - nodes that require a value, but get it only from a connection
# - referencing the invocation cache instead of executing the node
outputs = invocation.invoke_internal(
InvocationContext(
services=self.__invoker.services,
graph_execution_state_id=graph_execution_state.id,
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
)
)
@@ -111,6 +121,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
# Send complete event
self.__invoker.services.events.emit_invocation_complete(
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
@@ -138,6 +150,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
# Send error event
self.__invoker.services.events.emit_invocation_error(
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
@@ -155,10 +169,17 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
is_complete = graph_execution_state.is_complete()
if queue_item.invoke_all and not is_complete:
try:
self.__invoker.invoke(graph_execution_state, invoke_all=True)
self.__invoker.invoke(
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state=graph_execution_state,
invoke_all=True,
)
except Exception as e:
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
self.__invoker.services.events.emit_invocation_error(
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
node=invocation.dict(),
source_node_id=source_node_id,
@@ -166,7 +187,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
error=traceback.format_exc(),
)
elif is_complete:
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
self.__invoker.services.events.emit_graph_execution_complete(
queue_item_id=queue_item.session_queue_item_id,
queue_id=queue_item.session_queue_id,
graph_execution_state_id=graph_execution_state.id,
)
except KeyboardInterrupt:
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor

View File

@@ -26,20 +26,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
self.__stop_event = ThreadEvent()
self.__poll_now_event = ThreadEvent()
local_handler.register(event_name=EventServiceBase.session_event, _func=self._on_session_event)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
self.__threadLimit = BoundedSemaphore(THREAD_LIMIT)
self._start_thread()
def stop(self, *args, **kwargs) -> None:
self.__stop_event.set()
def _poll_now(self) -> None:
self.__poll_now_event.set()
def _start_thread(self) -> None:
# threads only live once, so we need to create a new one whenever we start the session processor
self.__thread = Thread(
name="session_processor",
target=self.__process,
@@ -49,40 +38,29 @@ class DefaultSessionProcessor(SessionProcessorBase):
)
self.__thread.start()
async def _on_session_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
if event_name in [
"graph_execution_state_complete",
"invocation_error",
"session_retrieval_error",
"invocation_retrieval_error",
] or (
event_name == "session_canceled"
and self.__queue_item is not None
and self.__queue_item.session_id == event[1]["data"]["graph_execution_state_id"]
):
self.__queue_item = None
self._poll_now()
def stop(self, *args, **kwargs) -> None:
self.__stop_event.set()
def _poll_now(self) -> None:
self.__poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent) -> None:
event_name = event[1]["event"]
if event_name == "batch_enqueued":
self._poll_now()
if event_name == "queue_cleared":
self.__queue_item = None
self._poll_now()
def _is_started(self) -> bool:
return self.__resume_event.is_set()
def _is_processing(self) -> bool:
return self.__queue_item is not None
def get_status(self) -> SessionProcessorStatus:
return SessionProcessorStatus(
is_started=self._is_started(),
is_processing=self._is_processing(),
)
match event_name:
case "graph_execution_state_complete" | "invocation_error" | "session_retrieval_error" | "invocation_retrieval_error":
self.__queue_item = None
self._poll_now()
case "session_canceled" if self.__queue_item is not None and self.__queue_item.session_id == event[1][
"data"
]["graph_execution_state_id"]:
self.__queue_item = None
self._poll_now()
case "batch_enqueued":
self._poll_now()
case "queue_cleared":
self.__queue_item = None
self._poll_now()
def resume(self) -> SessionProcessorStatus:
if not self.__resume_event.is_set():
@@ -94,6 +72,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
self.__resume_event.clear()
return self.get_status()
def get_status(self) -> SessionProcessorStatus:
return SessionProcessorStatus(
is_started=self.__resume_event.is_set(),
is_processing=self.__queue_item is not None,
)
def __process(
self,
stop_event: ThreadEvent,
@@ -114,19 +98,23 @@ class DefaultSessionProcessor(SessionProcessorBase):
queue_item = self.__invoker.services.session_queue.dequeue()
if queue_item is not None:
# TODO: Why isn't the log level specified in dependencies.py working?
# Within the thread, it is always INFO and `logger.debug()` doesn't display.
# self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
self.__queue_item = queue_item
self.__invoker.services.graph_execution_manager.set(queue_item.session)
self.__invoker.invoke(queue_item.session, invoke_all=True)
self.__invoker.invoke(
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state=queue_item.session,
invoke_all=True,
)
queue_item = None
if queue_item is None:
# self.__invoker.services.logger.debug("Waiting for next polling interval or event")
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(POLLING_INTERVAL)
continue
except Exception:
except Exception as e:
self.__invoker.services.logger.error(f"Error in session processor: {e}")
pass
finally:
stop_event.clear()

View File

@@ -110,8 +110,3 @@ class SessionQueueBase(ABC):
def get_queue_item(self, item_id: str) -> SessionQueueItem:
"""Gets a session queue item by ID"""
pass
@abstractmethod
def get_queue_item_by_session_id(self, session_id: str) -> SessionQueueItem:
"""Gets a queue item by session ID"""
pass

View File

@@ -90,7 +90,7 @@ class Batch(BaseModel):
# Get the type of the first item in the list
first_item_type = type(datum.items[0]) if datum.items else None
for item in datum.items:
if type(item) != first_item_type:
if type(item) is not first_item_type:
raise BatchItemsTypeError("All items in a batch must have the same type")
return v

View File

@@ -42,7 +42,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.__invoker = invoker
self._set_in_progress_to_canceled()
prune_result = self.prune(DEFAULT_QUEUE_ID)
local_handler.register(event_name=EventServiceBase.session_event, _func=self._on_session_event)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
@@ -60,28 +60,21 @@ class SqliteSessionQueue(SessionQueueBase):
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
event_name = event[1]["event"]
match event_name:
# successful completion events
case "graph_execution_state_complete":
await self._handle_complete_event(event)
# error events
case "invocation_error":
case "invocation_error" | "session_retrieval_error" | "invocation_retrieval_error":
await self._handle_error_event(event)
case "session_retrieval_error":
await self._handle_error_event(event)
case "invocation_retrieval_error":
await self._handle_error_event(event)
# canceled events
case "session_canceled":
await self._handle_cancel_event(event)
return event
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
try:
session_id = event[1]["data"]["graph_execution_state_id"]
item_id = event[1]["data"]["queue_item_id"]
# When a queue item has an error, we get an error event, then a completed event.
# Mark the queue item completed only if it isn't already marked completed, e.g.
# by a previously-handled error event.
queue_item = self.get_queue_item_by_session_id(session_id)
queue_item = self.get_queue_item(item_id)
if queue_item.status not in ["completed", "failed", "canceled"]:
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
@@ -90,9 +83,9 @@ class SqliteSessionQueue(SessionQueueBase):
async def _handle_error_event(self, event: FastAPIEvent) -> None:
try:
session_id = event[1]["data"]["graph_execution_state_id"]
item_id = event[1]["data"]["queue_item_id"]
error = event[1]["data"]["error"]
queue_item = self.get_queue_item_by_session_id(session_id)
queue_item = self.get_queue_item(item_id)
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
except SessionQueueItemNotFoundError:
@@ -100,8 +93,8 @@ class SqliteSessionQueue(SessionQueueBase):
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
try:
session_id = event[1]["data"]["graph_execution_state_id"]
queue_item = self.get_queue_item_by_session_id(session_id)
item_id = event[1]["data"]["queue_item_id"]
queue_item = self.get_queue_item(item_id)
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
except SessionQueueItemNotFoundError:
@@ -583,7 +576,11 @@ class SqliteSessionQueue(SessionQueueBase):
if queue_item.status not in ["canceled", "failed", "completed"]:
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
self.__invoker.services.queue.cancel(queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
graph_execution_state_id=queue_item.session_id,
)
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
return queue_item
@@ -621,7 +618,11 @@ class SqliteSessionQueue(SessionQueueBase):
self.__conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self.__invoker.services.queue.cancel(current_queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(current_queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
graph_execution_state_id=current_queue_item.session_id,
)
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
except Exception:
self.__conn.rollback()
@@ -662,7 +663,11 @@ class SqliteSessionQueue(SessionQueueBase):
self.__conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
self.__invoker.services.queue.cancel(current_queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(current_queue_item.session_id)
self.__invoker.services.events.emit_session_canceled(
queue_item_id=current_queue_item.item_id,
queue_id=current_queue_item.queue_id,
graph_execution_state_id=current_queue_item.session_id,
)
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
except Exception:
self.__conn.rollback()
@@ -692,27 +697,6 @@ class SqliteSessionQueue(SessionQueueBase):
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.from_dict(dict(result))
def get_queue_item_by_session_id(self, session_id: str) -> SessionQueueItem:
try:
self.__lock.acquire()
self.__cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
session_id = ?
""",
(session_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
except Exception:
self.__conn.rollback()
raise
finally:
self.__lock.release()
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with session id {session_id}")
return SessionQueueItem.from_dict(dict(result))
def list_queue_items(
self,
queue_id: str,

View File

@@ -110,6 +110,8 @@ def stable_diffusion_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG")
context.services.events.emit_generator_progress(
queue_id=context.queue_id,
queue_item_id=context.queue_item_id,
graph_execution_state_id=context.graph_execution_state_id,
node=node,
source_node_id=source_node_id,

View File

@@ -12,8 +12,12 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
actionCreator: socketQueueItemStatusChanged,
effect: (action, { dispatch, getState }) => {
const log = logger('socketio');
const { item_id, batch_id, graph_execution_state_id, status } =
action.payload.data;
const {
queue_item_id: item_id,
batch_id,
graph_execution_state_id,
status,
} = action.payload.data;
log.debug(
action.payload,
`Queue item ${item_id} status updated: ${status}`

View File

@@ -8,11 +8,13 @@ export type paths = {
"/api/v1/sessions/": {
/**
* List Sessions
* @deprecated
* @description Gets a list of sessions, optionally searching
*/
get: operations["list_sessions"];
/**
* Create Session
* @deprecated
* @description Creates a new session, optionally initializing it with an invocation graph
*/
post: operations["create_session"];
@@ -20,6 +22,7 @@ export type paths = {
"/api/v1/sessions/{session_id}": {
/**
* Get Session
* @deprecated
* @description Gets a session
*/
get: operations["get_session"];
@@ -27,6 +30,7 @@ export type paths = {
"/api/v1/sessions/{session_id}/nodes": {
/**
* Add Node
* @deprecated
* @description Adds a node to the graph
*/
post: operations["add_node"];
@@ -34,11 +38,13 @@ export type paths = {
"/api/v1/sessions/{session_id}/nodes/{node_path}": {
/**
* Update Node
* @deprecated
* @description Updates a node in the graph and removes all linked edges
*/
put: operations["update_node"];
/**
* Delete Node
* @deprecated
* @description Deletes a node in the graph and removes all linked edges
*/
delete: operations["delete_node"];
@@ -46,6 +52,7 @@ export type paths = {
"/api/v1/sessions/{session_id}/edges": {
/**
* Add Edge
* @deprecated
* @description Adds an edge to the graph
*/
post: operations["add_edge"];
@@ -53,6 +60,7 @@ export type paths = {
"/api/v1/sessions/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}": {
/**
* Delete Edge
* @deprecated
* @description Deletes an edge from the graph
*/
delete: operations["delete_edge"];
@@ -60,11 +68,13 @@ export type paths = {
"/api/v1/sessions/{session_id}/invoke": {
/**
* Invoke Session
* @deprecated
* @description Invokes a session
*/
put: operations["invoke_session"];
/**
* Cancel Session Invoke
* @deprecated
* @description Invokes a session
*/
delete: operations["cancel_session_invoke"];
@@ -8161,35 +8171,35 @@ export type components = {
ui_order?: number;
};
/**
* ControlNetModelFormat
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion2ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionXLModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusionOnnxModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusionOnnxModelFormat: "olive" | "onnx";
/**
* ControlNetModelFormat
* @description An enumeration.
* @enum {string}
*/
ControlNetModelFormat: "checkpoint" | "diffusers";
/**
* StableDiffusion1ModelFormat
* @description An enumeration.
* @enum {string}
*/
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
};
responses: never;
parameters: never;
@@ -8206,6 +8216,7 @@ export type operations = {
/**
* List Sessions
* @deprecated
* @description Gets a list of sessions, optionally searching
*/
list_sessions: {
@@ -8236,9 +8247,16 @@ export type operations = {
};
/**
* Create Session
* @deprecated
* @description Creates a new session, optionally initializing it with an invocation graph
*/
create_session: {
parameters: {
query?: {
/** @description The id of the queue to associate the session with */
queue_id?: string;
};
};
requestBody?: {
content: {
"application/json": components["schemas"]["Graph"];
@@ -8265,6 +8283,7 @@ export type operations = {
};
/**
* Get Session
* @deprecated
* @description Gets a session
*/
get_session: {
@@ -8295,6 +8314,7 @@ export type operations = {
};
/**
* Add Node
* @deprecated
* @description Adds a node to the graph
*/
add_node: {
@@ -8334,6 +8354,7 @@ export type operations = {
};
/**
* Update Node
* @deprecated
* @description Updates a node in the graph and removes all linked edges
*/
update_node: {
@@ -8375,6 +8396,7 @@ export type operations = {
};
/**
* Delete Node
* @deprecated
* @description Deletes a node in the graph and removes all linked edges
*/
delete_node: {
@@ -8411,6 +8433,7 @@ export type operations = {
};
/**
* Add Edge
* @deprecated
* @description Adds an edge to the graph
*/
add_edge: {
@@ -8450,6 +8473,7 @@ export type operations = {
};
/**
* Delete Edge
* @deprecated
* @description Deletes an edge from the graph
*/
delete_edge: {
@@ -8492,11 +8516,14 @@ export type operations = {
};
/**
* Invoke Session
* @deprecated
* @description Invokes a session
*/
invoke_session: {
parameters: {
query?: {
query: {
/** @description The id of the queue to associate the session with */
queue_id: string;
/** @description Whether or not to invoke all remaining invocations */
all?: boolean;
};
@@ -8534,6 +8561,7 @@ export type operations = {
};
/**
* Cancel Session Invoke
* @deprecated
* @description Invokes a session
*/
cancel_session_invoke: {

View File

@@ -1,4 +1,5 @@
import { createAsyncThunk, isAnyOf } from '@reduxjs/toolkit';
import { $queueId } from 'features/queue/store/nanoStores';
import { isObject } from 'lodash-es';
import { $client } from 'services/api/client';
import { paths } from 'services/api/schema';
@@ -33,6 +34,7 @@ export const sessionCreated = createAsyncThunk<
const { POST } = $client.get();
const { data, error, response } = await POST('/api/v1/sessions/', {
body: graph,
params: { query: { queue_id: $queueId.get() } },
});
if (error) {
@@ -76,7 +78,10 @@ export const sessionInvoked = createAsyncThunk<
const { error, response } = await PUT(
'/api/v1/sessions/{session_id}/invoke',
{
params: { query: { all: true }, path: { session_id } },
params: {
query: { queue_id: $queueId.get(), all: true },
path: { session_id },
},
}
);

View File

@@ -33,6 +33,8 @@ export type BaseNode = {
};
export type ModelLoadStartedEvent = {
queue_id: string;
queue_item_id: string;
graph_execution_state_id: string;
model_name: string;
base_model: BaseModelType;
@@ -41,6 +43,8 @@ export type ModelLoadStartedEvent = {
};
export type ModelLoadCompletedEvent = {
queue_id: string;
queue_item_id: string;
graph_execution_state_id: string;
model_name: string;
base_model: BaseModelType;
@@ -57,6 +61,8 @@ export type ModelLoadCompletedEvent = {
* @example socket.on('generator_progress', (data: GeneratorProgressEvent) => { ... }
*/
export type GeneratorProgressEvent = {
queue_id: string;
queue_item_id: string;
graph_execution_state_id: string;
node: BaseNode;
source_node_id: string;
@@ -73,6 +79,8 @@ export type GeneratorProgressEvent = {
* @example socket.on('invocation_complete', (data: InvocationCompleteEvent) => { ... }
*/
export type InvocationCompleteEvent = {
queue_id: string;
queue_item_id: string;
graph_execution_state_id: string;
node: BaseNode;
source_node_id: string;
@@ -85,6 +93,8 @@ export type InvocationCompleteEvent = {
* @example socket.on('invocation_error', (data: InvocationErrorEvent) => { ... }
*/
export type InvocationErrorEvent = {
queue_id: string;
queue_item_id: string;
graph_execution_state_id: string;
node: BaseNode;
source_node_id: string;
@@ -98,6 +108,8 @@ export type InvocationErrorEvent = {
* @example socket.on('invocation_started', (data: InvocationStartedEvent) => { ... }
*/
export type InvocationStartedEvent = {
queue_id: string;
queue_item_id: string;
graph_execution_state_id: string;
node: BaseNode;
source_node_id: string;
@@ -109,6 +121,8 @@ export type InvocationStartedEvent = {
* @example socket.on('graph_execution_state_complete', (data: GraphExecutionStateCompleteEvent) => { ... }
*/
export type GraphExecutionStateCompleteEvent = {
queue_id: string;
queue_item_id: string;
graph_execution_state_id: string;
};
@@ -118,6 +132,8 @@ export type GraphExecutionStateCompleteEvent = {
* @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... }
*/
export type SessionRetrievalErrorEvent = {
queue_id: string;
queue_item_id: string;
graph_execution_state_id: string;
error_type: string;
error: string;
@@ -129,6 +145,8 @@ export type SessionRetrievalErrorEvent = {
* @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... }
*/
export type InvocationRetrievalErrorEvent = {
queue_id: string;
queue_item_id: string;
graph_execution_state_id: string;
node_id: string;
error_type: string;
@@ -141,10 +159,10 @@ export type InvocationRetrievalErrorEvent = {
* @example socket.on('queue_item_status_changed', (data: QueueItemStatusChangedEvent) => { ... }
*/
export type QueueItemStatusChangedEvent = {
item_id: string;
queue_id: string;
queue_item_id: string;
batch_id: string;
session_id: string;
queue_id: string;
graph_execution_state_id: string;
status: components['schemas']['SessionQueueItemDTO']['status'];
error: string | undefined;
@@ -154,14 +172,6 @@ export type QueueItemStatusChangedEvent = {
completed_at: string | undefined;
};
export type ClientEmitSubscribeSession = {
session: string;
};
export type ClientEmitUnsubscribeSession = {
session: string;
};
export type ClientEmitSubscribeQueue = {
queue_id: string;
};
@@ -188,8 +198,6 @@ export type ServerToClientEvents = {
export type ClientToServerEvents = {
connect: () => void;
disconnect: () => void;
subscribe_session: (payload: ClientEmitSubscribeSession) => void;
unsubscribe_session: (payload: ClientEmitUnsubscribeSession) => void;
subscribe_queue: (payload: ClientEmitSubscribeQueue) => void;
unsubscribe_queue: (payload: ClientEmitUnsubscribeQueue) => void;
};

View File

@@ -156,17 +156,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
});
socket.on('queue_item_status_changed', (data) => {
const { status, session_id } = data;
if (status === 'in_progress') {
socket.emit('subscribe_session', {
session: session_id,
});
}
if (['new', 'completed', 'failed', 'canceled'].includes(status)) {
socket.emit('unsubscribe_session', {
session: session_id,
});
}
dispatch(socketQueueItemStatusChanged({ data }));
});
};

View File

@@ -9,7 +9,6 @@ from .test_nodes import ( # isort: split
TestEventService,
TextToImageTestInvocation,
)
import sqlite3
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
@@ -20,6 +19,7 @@ from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from .test_invoker import create_edge
@@ -70,7 +70,9 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B
return (None, None)
print(f"invoking {n.id}: {type(n)}")
o = n.invoke(InvocationContext(services, "1"))
o = n.invoke(
InvocationContext(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, services=services, graph_execution_state_id="1")
)
g.complete(n.id, o)
return (n, o)

View File

@@ -3,15 +3,8 @@ import threading
import pytest
from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
from .test_nodes import (
# This import must happen before other invoke imports or tests in other files (!!) break
from .test_nodes import ( # isort: split
ErrorInvocation,
PromptTestInvocation,
TestEventService,
@@ -20,6 +13,15 @@ from .test_nodes import (
wait_until,
)
from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.processor import DefaultInvocationProcessor
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
@pytest.fixture
def simple_graph():
@@ -96,7 +98,7 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph)
invocation_id = mock_invoker.invoke(g)
invocation_id = mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g)
assert invocation_id is not None
def has_executed_any(g: GraphExecutionState):
@@ -113,7 +115,9 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
g = mock_invoker.create_execution_state(graph=simple_graph)
invocation_id = mock_invoker.invoke(g, invoke_all=True)
invocation_id = mock_invoker.invoke(
queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True
)
assert invocation_id is not None
def has_executed_all(g: GraphExecutionState):
@@ -132,7 +136,7 @@ def test_handles_errors(mock_invoker: Invoker):
g = mock_invoker.create_execution_state()
g.graph.add_node(ErrorInvocation(id="1"))
mock_invoker.invoke(g, invoke_all=True)
mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True)
def has_executed_all(g: GraphExecutionState):
g = mock_invoker.services.graph_execution_manager.get(g.id)