mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(queue): change all execution-related events to use the queue_id as the room, also include queue_item_id in InvocationQueueItem
This allows for much simpler handling of queue items.
This commit is contained in:
@@ -14,29 +14,10 @@ class SocketIO:
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = SocketManager(app=app)
|
||||
|
||||
self.__sio.on("subscribe_session", handler=self._handle_sub_session)
|
||||
self.__sio.on("unsubscribe_session", handler=self._handle_unsub_session)
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
|
||||
|
||||
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue)
|
||||
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
|
||||
|
||||
async def _handle_session_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["graph_execution_state_id"],
|
||||
)
|
||||
|
||||
async def _handle_sub_session(self, sid, data, *args, **kwargs):
|
||||
if "session" in data:
|
||||
self.__sio.enter_room(sid, data["session"])
|
||||
|
||||
async def _handle_unsub_session(self, sid, data, *args, **kwargs):
|
||||
if "session" in data:
|
||||
self.__sio.leave_room(sid, data["session"])
|
||||
|
||||
async def _handle_queue_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
|
||||
@@ -417,12 +417,18 @@ class UIConfigBase(BaseModel):
|
||||
|
||||
|
||||
class InvocationContext:
|
||||
"""Initialized and provided to on execution of invocations."""
|
||||
|
||||
services: InvocationServices
|
||||
graph_execution_state_id: str
|
||||
queue_id: str
|
||||
queue_item_id: str
|
||||
|
||||
def __init__(self, services: InvocationServices, graph_execution_state_id: str):
|
||||
def __init__(self, services: InvocationServices, queue_id: str, queue_item_id: str, graph_execution_state_id: str):
|
||||
self.services = services
|
||||
self.graph_execution_state_id = graph_execution_state_id
|
||||
self.queue_id = queue_id
|
||||
self.queue_item_id = queue_item_id
|
||||
|
||||
|
||||
class BaseInvocationOutput(BaseModel):
|
||||
|
||||
@@ -9,23 +9,13 @@ from invokeai.app.util.misc import get_timestamp
|
||||
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
queue_event: str = "queue_event"
|
||||
processor_event: str = "processor_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
|
||||
def dispatch(self, event_name: str, payload: Any) -> None:
|
||||
pass
|
||||
|
||||
def __emit_session_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Session events are emitted to a room with the session_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.session_event,
|
||||
payload=dict(event=event_name, data=payload),
|
||||
)
|
||||
|
||||
def __emit_queue_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Queue events are emitted to a room with queue_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
@@ -34,18 +24,12 @@ class EventServiceBase:
|
||||
payload=dict(event=event_name, data=payload),
|
||||
)
|
||||
|
||||
def __emit_processor_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Processor events are emitted to a room with "processor" as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.processor_event,
|
||||
payload=dict(event=event_name, data=payload),
|
||||
)
|
||||
|
||||
# Define events here for every event in the system.
|
||||
# This will make them easier to integrate until we find a schema generator.
|
||||
def emit_generator_progress(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
@@ -55,9 +39,12 @@ class EventServiceBase:
|
||||
total_steps: int,
|
||||
) -> None:
|
||||
"""Emitted when there is generation progress"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="generator_progress",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
item_id=queue_item_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
@@ -70,15 +57,19 @@ class EventServiceBase:
|
||||
|
||||
def emit_invocation_complete(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: str,
|
||||
graph_execution_state_id: str,
|
||||
result: dict,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_complete",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
@@ -88,6 +79,8 @@ class EventServiceBase:
|
||||
|
||||
def emit_invocation_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node: dict,
|
||||
source_node_id: str,
|
||||
@@ -95,9 +88,11 @@ class EventServiceBase:
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when an invocation has completed"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_error",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
@@ -106,28 +101,36 @@ class EventServiceBase:
|
||||
),
|
||||
)
|
||||
|
||||
def emit_invocation_started(self, graph_execution_state_id: str, node: dict, source_node_id: str) -> None:
|
||||
def emit_invocation_started(
|
||||
self, queue_id: str, queue_item_id: str, graph_execution_state_id: str, node: dict, source_node_id: str
|
||||
) -> None:
|
||||
"""Emitted when an invocation has started"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_started",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_graph_execution_complete(self, graph_execution_state_id: str) -> None:
|
||||
def emit_graph_execution_complete(self, queue_id: str, queue_item_id: str, graph_execution_state_id: str) -> None:
|
||||
"""Emitted when a session has completed all invocations"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="graph_execution_state_complete",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_model_load_started(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
@@ -135,9 +138,11 @@ class EventServiceBase:
|
||||
submodel: SubModelType,
|
||||
) -> None:
|
||||
"""Emitted when a model is requested"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_started",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
@@ -148,6 +153,8 @@ class EventServiceBase:
|
||||
|
||||
def emit_model_load_completed(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: str,
|
||||
graph_execution_state_id: str,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
@@ -156,9 +163,11 @@ class EventServiceBase:
|
||||
model_info: ModelInfo,
|
||||
) -> None:
|
||||
"""Emitted when a model is correctly loaded (returns model info)"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="model_load_completed",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
@@ -172,14 +181,18 @@ class EventServiceBase:
|
||||
|
||||
def emit_session_retrieval_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: str,
|
||||
graph_execution_state_id: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when session retrieval fails"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="session_retrieval_error",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
error_type=error_type,
|
||||
error=error,
|
||||
@@ -188,15 +201,19 @@ class EventServiceBase:
|
||||
|
||||
def emit_invocation_retrieval_error(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: str,
|
||||
graph_execution_state_id: str,
|
||||
node_id: str,
|
||||
error_type: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
"""Emitted when invocation retrieval fails"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="invocation_retrieval_error",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
node_id=node_id,
|
||||
error_type=error_type,
|
||||
@@ -206,12 +223,16 @@ class EventServiceBase:
|
||||
|
||||
def emit_session_canceled(
|
||||
self,
|
||||
queue_id: str,
|
||||
queue_item_id: str,
|
||||
graph_execution_state_id: str,
|
||||
) -> None:
|
||||
"""Emitted when a session is canceled"""
|
||||
self.__emit_session_event(
|
||||
self.__emit_queue_event(
|
||||
event_name="session_canceled",
|
||||
payload=dict(
|
||||
queue_id=queue_id,
|
||||
queue_item_id=queue_item_id,
|
||||
graph_execution_state_id=graph_execution_state_id,
|
||||
),
|
||||
)
|
||||
@@ -221,11 +242,11 @@ class EventServiceBase:
|
||||
self.__emit_queue_event(
|
||||
event_name="queue_item_status_changed",
|
||||
payload=dict(
|
||||
item_id=session_queue_item.item_id,
|
||||
queue_id=session_queue_item.queue_id,
|
||||
queue_item_id=session_queue_item.item_id,
|
||||
status=session_queue_item.status,
|
||||
batch_id=session_queue_item.batch_id,
|
||||
session_id=session_queue_item.session_id,
|
||||
queue_id=session_queue_item.queue_id,
|
||||
error=session_queue_item.error,
|
||||
created_at=str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
||||
updated_at=str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
||||
@@ -238,7 +259,11 @@ class EventServiceBase:
|
||||
"""Emitted when a batch is enqueued"""
|
||||
self.__emit_queue_event(
|
||||
event_name="batch_enqueued",
|
||||
payload=enqueue_result.dict(),
|
||||
payload=dict(
|
||||
queue_id=enqueue_result.queue_id,
|
||||
batch_id=enqueue_result.batch.batch_id,
|
||||
enqueued=enqueue_result.enqueued,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_queue_cleared(self, queue_id: str) -> None:
|
||||
|
||||
@@ -253,49 +253,53 @@ class Graph(BaseModel):
|
||||
@root_validator
|
||||
def validate_nodes_and_edges(cls, values):
|
||||
"""Validates that all edges match nodes in the graph"""
|
||||
nodes = cast(dict[str, BaseInvocation], values.get("nodes"))
|
||||
edges = cast(list[Edge], values.get("edges"))
|
||||
nodes = cast(Optional[dict[str, BaseInvocation]], values.get("nodes"))
|
||||
edges = cast(Optional[list[Edge]], values.get("edges"))
|
||||
|
||||
# Validate that all node ids are unique
|
||||
node_ids = [n.id for n in nodes.values()]
|
||||
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
|
||||
if duplicate_node_ids:
|
||||
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
|
||||
if nodes is not None:
|
||||
# Validate that all node ids are unique
|
||||
node_ids = [n.id for n in nodes.values()]
|
||||
duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2])
|
||||
if duplicate_node_ids:
|
||||
raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")
|
||||
|
||||
# Validate that all node ids match the keys in the nodes dict
|
||||
for k, v in nodes.items():
|
||||
if k != v.id:
|
||||
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
|
||||
# Validate that all node ids match the keys in the nodes dict
|
||||
for k, v in nodes.items():
|
||||
if k != v.id:
|
||||
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
|
||||
|
||||
# Validate that all edges match nodes in the graph
|
||||
node_ids = set([e.source.node_id for e in edges] + [e.destination.node_id for e in edges])
|
||||
missing_node_ids = [node_id for node_id in node_ids if node_id not in nodes]
|
||||
if missing_node_ids:
|
||||
raise NodeNotFoundError(f"All edges must reference nodes in the graph, missing nodes: {missing_node_ids}")
|
||||
|
||||
# Validate that all edge fields match node fields in the graph
|
||||
for edge in edges:
|
||||
source_node = nodes.get(edge.source.node_id, None)
|
||||
if source_node is None:
|
||||
raise NodeFieldNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
|
||||
|
||||
destination_node = nodes.get(edge.destination.node_id, None)
|
||||
if destination_node is None:
|
||||
raise NodeFieldNotFoundError(
|
||||
f"Edge destination node {edge.destination.node_id} does not exist in the graph"
|
||||
if edges is not None and nodes is not None:
|
||||
# Validate that all edges match nodes in the graph
|
||||
node_ids = set([e.source.node_id for e in edges] + [e.destination.node_id for e in edges])
|
||||
missing_node_ids = [node_id for node_id in node_ids if node_id not in nodes]
|
||||
if missing_node_ids:
|
||||
raise NodeNotFoundError(
|
||||
f"All edges must reference nodes in the graph, missing nodes: {missing_node_ids}"
|
||||
)
|
||||
|
||||
# output fields are not on the node object directly, they are on the output type
|
||||
if edge.source.field not in source_node.get_output_type().__fields__:
|
||||
raise NodeFieldNotFoundError(
|
||||
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
|
||||
)
|
||||
# Validate that all edge fields match node fields in the graph
|
||||
for edge in edges:
|
||||
source_node = nodes.get(edge.source.node_id, None)
|
||||
if source_node is None:
|
||||
raise NodeFieldNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")
|
||||
|
||||
# input fields are on the node
|
||||
if edge.destination.field not in destination_node.__fields__:
|
||||
raise NodeFieldNotFoundError(
|
||||
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
|
||||
)
|
||||
destination_node = nodes.get(edge.destination.node_id, None)
|
||||
if destination_node is None:
|
||||
raise NodeFieldNotFoundError(
|
||||
f"Edge destination node {edge.destination.node_id} does not exist in the graph"
|
||||
)
|
||||
|
||||
# output fields are not on the node object directly, they are on the output type
|
||||
if edge.source.field not in source_node.get_output_type().__fields__:
|
||||
raise NodeFieldNotFoundError(
|
||||
f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
|
||||
)
|
||||
|
||||
# input fields are on the node
|
||||
if edge.destination.field not in destination_node.__fields__:
|
||||
raise NodeFieldNotFoundError(
|
||||
f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@@ -760,7 +764,6 @@ class GraphExecutionState(BaseModel):
|
||||
"""Tracks the state of a graph execution"""
|
||||
|
||||
id: str = Field(description="The id of the execution state", default_factory=uuid_string)
|
||||
|
||||
# TODO: Store a reference to the graph instead of the actual graph?
|
||||
graph: Graph = Field(description="The graph being executed")
|
||||
|
||||
|
||||
@@ -11,6 +11,10 @@ from pydantic import BaseModel, Field
|
||||
class InvocationQueueItem(BaseModel):
|
||||
graph_execution_state_id: str = Field(description="The ID of the graph execution state")
|
||||
invocation_id: str = Field(description="The ID of the node being invoked")
|
||||
session_queue_id: str = Field(description="The ID of the session queue from which this invocation queue item came")
|
||||
session_queue_item_id: str = Field(
|
||||
description="The ID of session queue item from which this invocation queue item came"
|
||||
)
|
||||
invoke_all: bool = Field(default=False)
|
||||
timestamp: float = Field(default_factory=time.time)
|
||||
|
||||
|
||||
@@ -17,7 +17,9 @@ class Invoker:
|
||||
self.services = services
|
||||
self._start()
|
||||
|
||||
def invoke(self, graph_execution_state: GraphExecutionState, invoke_all: bool = False) -> Optional[str]:
|
||||
def invoke(
|
||||
self, queue_id: str, queue_item_id: str, graph_execution_state: GraphExecutionState, invoke_all: bool = False
|
||||
) -> Optional[str]:
|
||||
"""Determines the next node to invoke and enqueues it, preparing if needed.
|
||||
Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
|
||||
|
||||
@@ -32,7 +34,8 @@ class Invoker:
|
||||
# Queue the invocation
|
||||
self.services.queue.put(
|
||||
InvocationQueueItem(
|
||||
# session_id = session.id,
|
||||
session_queue_item_id=queue_item_id,
|
||||
session_queue_id=queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
invocation_id=invocation.id,
|
||||
invoke_all=invoke_all,
|
||||
|
||||
@@ -525,7 +525,7 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
|
||||
def _emit_load_event(
|
||||
self,
|
||||
context,
|
||||
context: InvocationContext,
|
||||
model_name: str,
|
||||
base_model: BaseModelType,
|
||||
model_type: ModelType,
|
||||
@@ -537,6 +537,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
|
||||
if model_info:
|
||||
context.services.events.emit_model_load_completed(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
@@ -546,6 +548,8 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
)
|
||||
else:
|
||||
context.services.events.emit_model_load_started(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
model_name=model_name,
|
||||
base_model=base_model,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import time
|
||||
import traceback
|
||||
from threading import BoundedSemaphore, Event, Thread
|
||||
from typing import Optional
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
@@ -37,10 +38,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
try:
|
||||
self.__threadLimit.acquire()
|
||||
statistics: InvocationStatsServiceBase = self.__invoker.services.performance_statistics
|
||||
queue_item: Optional[InvocationQueueItem] = None
|
||||
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
queue_item: InvocationQueueItem = self.__invoker.services.queue.get()
|
||||
queue_item = self.__invoker.services.queue.get()
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while getting from queue:\n%s" % e)
|
||||
|
||||
@@ -48,7 +50,6 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
# do not hammer the queue
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
try:
|
||||
graph_execution_state = self.__invoker.services.graph_execution_manager.get(
|
||||
queue_item.graph_execution_state_id
|
||||
@@ -56,6 +57,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving session:\n%s" % e)
|
||||
self.__invoker.services.events.emit_session_retrieval_error(
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=queue_item.graph_execution_state_id,
|
||||
error_type=e.__class__.__name__,
|
||||
error=traceback.format_exc(),
|
||||
@@ -67,6 +70,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Exception while retrieving invocation:\n%s" % e)
|
||||
self.__invoker.services.events.emit_invocation_retrieval_error(
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=queue_item.graph_execution_state_id,
|
||||
node_id=queue_item.invocation_id,
|
||||
error_type=e.__class__.__name__,
|
||||
@@ -79,6 +84,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
|
||||
# Send starting event
|
||||
self.__invoker.services.events.emit_invocation_started(
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
@@ -89,13 +96,16 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
graph_id = graph_execution_state.id
|
||||
model_manager = self.__invoker.services.model_manager
|
||||
with statistics.collect_stats(invocation, graph_id, model_manager):
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method in
|
||||
# this accomodates nodes which require a value, but get it only from a
|
||||
# connection
|
||||
# use the internal invoke_internal(), which wraps the node's invoke() method,
|
||||
# which handles a few things:
|
||||
# - nodes that require a value, but get it only from a connection
|
||||
# - referencing the invocation cache instead of executing the node
|
||||
outputs = invocation.invoke_internal(
|
||||
InvocationContext(
|
||||
services=self.__invoker.services,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -111,6 +121,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
|
||||
# Send complete event
|
||||
self.__invoker.services.events.emit_invocation_complete(
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
@@ -138,6 +150,8 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||
# Send error event
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
@@ -155,10 +169,17 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
is_complete = graph_execution_state.is_complete()
|
||||
if queue_item.invoke_all and not is_complete:
|
||||
try:
|
||||
self.__invoker.invoke(graph_execution_state, invoke_all=True)
|
||||
self.__invoker.invoke(
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state=graph_execution_state,
|
||||
invoke_all=True,
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Error while invoking:\n%s" % e)
|
||||
self.__invoker.services.events.emit_invocation_error(
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
node=invocation.dict(),
|
||||
source_node_id=source_node_id,
|
||||
@@ -166,7 +187,11 @@ class DefaultInvocationProcessor(InvocationProcessorABC):
|
||||
error=traceback.format_exc(),
|
||||
)
|
||||
elif is_complete:
|
||||
self.__invoker.services.events.emit_graph_execution_complete(graph_execution_state.id)
|
||||
self.__invoker.services.events.emit_graph_execution_complete(
|
||||
queue_item_id=queue_item.session_queue_item_id,
|
||||
queue_id=queue_item.session_queue_id,
|
||||
graph_execution_state_id=graph_execution_state.id,
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass # Log something? KeyboardInterrupt is probably not going to be seen by the processor
|
||||
|
||||
@@ -26,20 +26,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self.__stop_event = ThreadEvent()
|
||||
self.__poll_now_event = ThreadEvent()
|
||||
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._on_session_event)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||
|
||||
self.__threadLimit = BoundedSemaphore(THREAD_LIMIT)
|
||||
self._start_thread()
|
||||
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self.__stop_event.set()
|
||||
|
||||
def _poll_now(self) -> None:
|
||||
self.__poll_now_event.set()
|
||||
|
||||
def _start_thread(self) -> None:
|
||||
# threads only live once, so we need to create a new one whenever we start the session processor
|
||||
self.__thread = Thread(
|
||||
name="session_processor",
|
||||
target=self.__process,
|
||||
@@ -49,40 +38,29 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
)
|
||||
self.__thread.start()
|
||||
|
||||
async def _on_session_event(self, event: FastAPIEvent) -> None:
|
||||
event_name = event[1]["event"]
|
||||
if event_name in [
|
||||
"graph_execution_state_complete",
|
||||
"invocation_error",
|
||||
"session_retrieval_error",
|
||||
"invocation_retrieval_error",
|
||||
] or (
|
||||
event_name == "session_canceled"
|
||||
and self.__queue_item is not None
|
||||
and self.__queue_item.session_id == event[1]["data"]["graph_execution_state_id"]
|
||||
):
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
def stop(self, *args, **kwargs) -> None:
|
||||
self.__stop_event.set()
|
||||
|
||||
def _poll_now(self) -> None:
|
||||
self.__poll_now_event.set()
|
||||
|
||||
async def _on_queue_event(self, event: FastAPIEvent) -> None:
|
||||
event_name = event[1]["event"]
|
||||
if event_name == "batch_enqueued":
|
||||
self._poll_now()
|
||||
if event_name == "queue_cleared":
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
|
||||
def _is_started(self) -> bool:
|
||||
return self.__resume_event.is_set()
|
||||
|
||||
def _is_processing(self) -> bool:
|
||||
return self.__queue_item is not None
|
||||
|
||||
def get_status(self) -> SessionProcessorStatus:
|
||||
return SessionProcessorStatus(
|
||||
is_started=self._is_started(),
|
||||
is_processing=self._is_processing(),
|
||||
)
|
||||
match event_name:
|
||||
case "graph_execution_state_complete" | "invocation_error" | "session_retrieval_error" | "invocation_retrieval_error":
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
case "session_canceled" if self.__queue_item is not None and self.__queue_item.session_id == event[1][
|
||||
"data"
|
||||
]["graph_execution_state_id"]:
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
case "batch_enqueued":
|
||||
self._poll_now()
|
||||
case "queue_cleared":
|
||||
self.__queue_item = None
|
||||
self._poll_now()
|
||||
|
||||
def resume(self) -> SessionProcessorStatus:
|
||||
if not self.__resume_event.is_set():
|
||||
@@ -94,6 +72,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self.__resume_event.clear()
|
||||
return self.get_status()
|
||||
|
||||
def get_status(self) -> SessionProcessorStatus:
|
||||
return SessionProcessorStatus(
|
||||
is_started=self.__resume_event.is_set(),
|
||||
is_processing=self.__queue_item is not None,
|
||||
)
|
||||
|
||||
def __process(
|
||||
self,
|
||||
stop_event: ThreadEvent,
|
||||
@@ -114,19 +98,23 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
queue_item = self.__invoker.services.session_queue.dequeue()
|
||||
|
||||
if queue_item is not None:
|
||||
# TODO: Why isn't the log level specified in dependencies.py working?
|
||||
# Within the thread, it is always INFO and `logger.debug()` doesn't display.
|
||||
# self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
|
||||
self.__invoker.services.logger.debug(f"Executing queue item {queue_item.item_id}")
|
||||
self.__queue_item = queue_item
|
||||
self.__invoker.services.graph_execution_manager.set(queue_item.session)
|
||||
self.__invoker.invoke(queue_item.session, invoke_all=True)
|
||||
self.__invoker.invoke(
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state=queue_item.session,
|
||||
invoke_all=True,
|
||||
)
|
||||
queue_item = None
|
||||
|
||||
if queue_item is None:
|
||||
# self.__invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||
poll_now_event.wait(POLLING_INTERVAL)
|
||||
continue
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error(f"Error in session processor: {e}")
|
||||
pass
|
||||
finally:
|
||||
stop_event.clear()
|
||||
|
||||
@@ -110,8 +110,3 @@ class SessionQueueBase(ABC):
|
||||
def get_queue_item(self, item_id: str) -> SessionQueueItem:
|
||||
"""Gets a session queue item by ID"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_queue_item_by_session_id(self, session_id: str) -> SessionQueueItem:
|
||||
"""Gets a queue item by session ID"""
|
||||
pass
|
||||
|
||||
@@ -90,7 +90,7 @@ class Batch(BaseModel):
|
||||
# Get the type of the first item in the list
|
||||
first_item_type = type(datum.items[0]) if datum.items else None
|
||||
for item in datum.items:
|
||||
if type(item) != first_item_type:
|
||||
if type(item) is not first_item_type:
|
||||
raise BatchItemsTypeError("All items in a batch must have the same type")
|
||||
return v
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__invoker = invoker
|
||||
self._set_in_progress_to_canceled()
|
||||
prune_result = self.prune(DEFAULT_QUEUE_ID)
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._on_session_event)
|
||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
|
||||
self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")
|
||||
|
||||
def __init__(self, conn: sqlite3.Connection, lock: threading.Lock) -> None:
|
||||
@@ -60,28 +60,21 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
|
||||
event_name = event[1]["event"]
|
||||
match event_name:
|
||||
# successful completion events
|
||||
case "graph_execution_state_complete":
|
||||
await self._handle_complete_event(event)
|
||||
# error events
|
||||
case "invocation_error":
|
||||
case "invocation_error" | "session_retrieval_error" | "invocation_retrieval_error":
|
||||
await self._handle_error_event(event)
|
||||
case "session_retrieval_error":
|
||||
await self._handle_error_event(event)
|
||||
case "invocation_retrieval_error":
|
||||
await self._handle_error_event(event)
|
||||
# canceled events
|
||||
case "session_canceled":
|
||||
await self._handle_cancel_event(event)
|
||||
return event
|
||||
|
||||
async def _handle_complete_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
session_id = event[1]["data"]["graph_execution_state_id"]
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
# When a queue item has an error, we get an error event, then a completed event.
|
||||
# Mark the queue item completed only if it isn't already marked completed, e.g.
|
||||
# by a previously-handled error event.
|
||||
queue_item = self.get_queue_item_by_session_id(session_id)
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ["completed", "failed", "canceled"]:
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
@@ -90,9 +83,9 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
async def _handle_error_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
session_id = event[1]["data"]["graph_execution_state_id"]
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
error = event[1]["data"]["error"]
|
||||
queue_item = self.get_queue_item_by_session_id(session_id)
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
except SessionQueueItemNotFoundError:
|
||||
@@ -100,8 +93,8 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
|
||||
async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
|
||||
try:
|
||||
session_id = event[1]["data"]["graph_execution_state_id"]
|
||||
queue_item = self.get_queue_item_by_session_id(session_id)
|
||||
item_id = event[1]["data"]["queue_item_id"]
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
except SessionQueueItemNotFoundError:
|
||||
@@ -583,7 +576,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||
self.__invoker.services.queue.cancel(queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
graph_execution_state_id=queue_item.session_id,
|
||||
)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item)
|
||||
return queue_item
|
||||
|
||||
@@ -621,7 +618,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
self.__invoker.services.queue.cancel(current_queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(current_queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
graph_execution_state_id=current_queue_item.session_id,
|
||||
)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
@@ -662,7 +663,11 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
self.__conn.commit()
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self.__invoker.services.queue.cancel(current_queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(current_queue_item.session_id)
|
||||
self.__invoker.services.events.emit_session_canceled(
|
||||
queue_item_id=current_queue_item.item_id,
|
||||
queue_id=current_queue_item.queue_id,
|
||||
graph_execution_state_id=current_queue_item.session_id,
|
||||
)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(current_queue_item)
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
@@ -692,27 +697,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
return SessionQueueItem.from_dict(dict(result))
|
||||
|
||||
def get_queue_item_by_session_id(self, session_id: str) -> SessionQueueItem:
|
||||
try:
|
||||
self.__lock.acquire()
|
||||
self.__cursor.execute(
|
||||
"""--sql
|
||||
SELECT * FROM session_queue
|
||||
WHERE
|
||||
session_id = ?
|
||||
""",
|
||||
(session_id,),
|
||||
)
|
||||
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
|
||||
except Exception:
|
||||
self.__conn.rollback()
|
||||
raise
|
||||
finally:
|
||||
self.__lock.release()
|
||||
if result is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with session id {session_id}")
|
||||
return SessionQueueItem.from_dict(dict(result))
|
||||
|
||||
def list_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
|
||||
@@ -110,6 +110,8 @@ def stable_diffusion_step_callback(
|
||||
dataURL = image_to_dataURL(image, image_format="JPEG")
|
||||
|
||||
context.services.events.emit_generator_progress(
|
||||
queue_id=context.queue_id,
|
||||
queue_item_id=context.queue_item_id,
|
||||
graph_execution_state_id=context.graph_execution_state_id,
|
||||
node=node,
|
||||
source_node_id=source_node_id,
|
||||
|
||||
@@ -12,8 +12,12 @@ export const addSocketQueueItemStatusChangedEventListener = () => {
|
||||
actionCreator: socketQueueItemStatusChanged,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const log = logger('socketio');
|
||||
const { item_id, batch_id, graph_execution_state_id, status } =
|
||||
action.payload.data;
|
||||
const {
|
||||
queue_item_id: item_id,
|
||||
batch_id,
|
||||
graph_execution_state_id,
|
||||
status,
|
||||
} = action.payload.data;
|
||||
log.debug(
|
||||
action.payload,
|
||||
`Queue item ${item_id} status updated: ${status}`
|
||||
|
||||
@@ -8,11 +8,13 @@ export type paths = {
|
||||
"/api/v1/sessions/": {
|
||||
/**
|
||||
* List Sessions
|
||||
* @deprecated
|
||||
* @description Gets a list of sessions, optionally searching
|
||||
*/
|
||||
get: operations["list_sessions"];
|
||||
/**
|
||||
* Create Session
|
||||
* @deprecated
|
||||
* @description Creates a new session, optionally initializing it with an invocation graph
|
||||
*/
|
||||
post: operations["create_session"];
|
||||
@@ -20,6 +22,7 @@ export type paths = {
|
||||
"/api/v1/sessions/{session_id}": {
|
||||
/**
|
||||
* Get Session
|
||||
* @deprecated
|
||||
* @description Gets a session
|
||||
*/
|
||||
get: operations["get_session"];
|
||||
@@ -27,6 +30,7 @@ export type paths = {
|
||||
"/api/v1/sessions/{session_id}/nodes": {
|
||||
/**
|
||||
* Add Node
|
||||
* @deprecated
|
||||
* @description Adds a node to the graph
|
||||
*/
|
||||
post: operations["add_node"];
|
||||
@@ -34,11 +38,13 @@ export type paths = {
|
||||
"/api/v1/sessions/{session_id}/nodes/{node_path}": {
|
||||
/**
|
||||
* Update Node
|
||||
* @deprecated
|
||||
* @description Updates a node in the graph and removes all linked edges
|
||||
*/
|
||||
put: operations["update_node"];
|
||||
/**
|
||||
* Delete Node
|
||||
* @deprecated
|
||||
* @description Deletes a node in the graph and removes all linked edges
|
||||
*/
|
||||
delete: operations["delete_node"];
|
||||
@@ -46,6 +52,7 @@ export type paths = {
|
||||
"/api/v1/sessions/{session_id}/edges": {
|
||||
/**
|
||||
* Add Edge
|
||||
* @deprecated
|
||||
* @description Adds an edge to the graph
|
||||
*/
|
||||
post: operations["add_edge"];
|
||||
@@ -53,6 +60,7 @@ export type paths = {
|
||||
"/api/v1/sessions/{session_id}/edges/{from_node_id}/{from_field}/{to_node_id}/{to_field}": {
|
||||
/**
|
||||
* Delete Edge
|
||||
* @deprecated
|
||||
* @description Deletes an edge from the graph
|
||||
*/
|
||||
delete: operations["delete_edge"];
|
||||
@@ -60,11 +68,13 @@ export type paths = {
|
||||
"/api/v1/sessions/{session_id}/invoke": {
|
||||
/**
|
||||
* Invoke Session
|
||||
* @deprecated
|
||||
* @description Invokes a session
|
||||
*/
|
||||
put: operations["invoke_session"];
|
||||
/**
|
||||
* Cancel Session Invoke
|
||||
* @deprecated
|
||||
* @description Invokes a session
|
||||
*/
|
||||
delete: operations["cancel_session_invoke"];
|
||||
@@ -8161,35 +8171,35 @@ export type components = {
|
||||
ui_order?: number;
|
||||
};
|
||||
/**
|
||||
* ControlNetModelFormat
|
||||
* StableDiffusionXLModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
ControlNetModelFormat: "checkpoint" | "diffusers";
|
||||
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion2ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion2ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusionXLModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusionXLModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusionOnnxModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusionOnnxModelFormat: "olive" | "onnx";
|
||||
/**
|
||||
* ControlNetModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
ControlNetModelFormat: "checkpoint" | "diffusers";
|
||||
/**
|
||||
* StableDiffusion1ModelFormat
|
||||
* @description An enumeration.
|
||||
* @enum {string}
|
||||
*/
|
||||
StableDiffusion1ModelFormat: "checkpoint" | "diffusers";
|
||||
};
|
||||
responses: never;
|
||||
parameters: never;
|
||||
@@ -8206,6 +8216,7 @@ export type operations = {
|
||||
|
||||
/**
|
||||
* List Sessions
|
||||
* @deprecated
|
||||
* @description Gets a list of sessions, optionally searching
|
||||
*/
|
||||
list_sessions: {
|
||||
@@ -8236,9 +8247,16 @@ export type operations = {
|
||||
};
|
||||
/**
|
||||
* Create Session
|
||||
* @deprecated
|
||||
* @description Creates a new session, optionally initializing it with an invocation graph
|
||||
*/
|
||||
create_session: {
|
||||
parameters: {
|
||||
query?: {
|
||||
/** @description The id of the queue to associate the session with */
|
||||
queue_id?: string;
|
||||
};
|
||||
};
|
||||
requestBody?: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["Graph"];
|
||||
@@ -8265,6 +8283,7 @@ export type operations = {
|
||||
};
|
||||
/**
|
||||
* Get Session
|
||||
* @deprecated
|
||||
* @description Gets a session
|
||||
*/
|
||||
get_session: {
|
||||
@@ -8295,6 +8314,7 @@ export type operations = {
|
||||
};
|
||||
/**
|
||||
* Add Node
|
||||
* @deprecated
|
||||
* @description Adds a node to the graph
|
||||
*/
|
||||
add_node: {
|
||||
@@ -8334,6 +8354,7 @@ export type operations = {
|
||||
};
|
||||
/**
|
||||
* Update Node
|
||||
* @deprecated
|
||||
* @description Updates a node in the graph and removes all linked edges
|
||||
*/
|
||||
update_node: {
|
||||
@@ -8375,6 +8396,7 @@ export type operations = {
|
||||
};
|
||||
/**
|
||||
* Delete Node
|
||||
* @deprecated
|
||||
* @description Deletes a node in the graph and removes all linked edges
|
||||
*/
|
||||
delete_node: {
|
||||
@@ -8411,6 +8433,7 @@ export type operations = {
|
||||
};
|
||||
/**
|
||||
* Add Edge
|
||||
* @deprecated
|
||||
* @description Adds an edge to the graph
|
||||
*/
|
||||
add_edge: {
|
||||
@@ -8450,6 +8473,7 @@ export type operations = {
|
||||
};
|
||||
/**
|
||||
* Delete Edge
|
||||
* @deprecated
|
||||
* @description Deletes an edge from the graph
|
||||
*/
|
||||
delete_edge: {
|
||||
@@ -8492,11 +8516,14 @@ export type operations = {
|
||||
};
|
||||
/**
|
||||
* Invoke Session
|
||||
* @deprecated
|
||||
* @description Invokes a session
|
||||
*/
|
||||
invoke_session: {
|
||||
parameters: {
|
||||
query?: {
|
||||
query: {
|
||||
/** @description The id of the queue to associate the session with */
|
||||
queue_id: string;
|
||||
/** @description Whether or not to invoke all remaining invocations */
|
||||
all?: boolean;
|
||||
};
|
||||
@@ -8534,6 +8561,7 @@ export type operations = {
|
||||
};
|
||||
/**
|
||||
* Cancel Session Invoke
|
||||
* @deprecated
|
||||
* @description Invokes a session
|
||||
*/
|
||||
cancel_session_invoke: {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createAsyncThunk, isAnyOf } from '@reduxjs/toolkit';
|
||||
import { $queueId } from 'features/queue/store/nanoStores';
|
||||
import { isObject } from 'lodash-es';
|
||||
import { $client } from 'services/api/client';
|
||||
import { paths } from 'services/api/schema';
|
||||
@@ -33,6 +34,7 @@ export const sessionCreated = createAsyncThunk<
|
||||
const { POST } = $client.get();
|
||||
const { data, error, response } = await POST('/api/v1/sessions/', {
|
||||
body: graph,
|
||||
params: { query: { queue_id: $queueId.get() } },
|
||||
});
|
||||
|
||||
if (error) {
|
||||
@@ -76,7 +78,10 @@ export const sessionInvoked = createAsyncThunk<
|
||||
const { error, response } = await PUT(
|
||||
'/api/v1/sessions/{session_id}/invoke',
|
||||
{
|
||||
params: { query: { all: true }, path: { session_id } },
|
||||
params: {
|
||||
query: { queue_id: $queueId.get(), all: true },
|
||||
path: { session_id },
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
@@ -33,6 +33,8 @@ export type BaseNode = {
|
||||
};
|
||||
|
||||
export type ModelLoadStartedEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: string;
|
||||
graph_execution_state_id: string;
|
||||
model_name: string;
|
||||
base_model: BaseModelType;
|
||||
@@ -41,6 +43,8 @@ export type ModelLoadStartedEvent = {
|
||||
};
|
||||
|
||||
export type ModelLoadCompletedEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: string;
|
||||
graph_execution_state_id: string;
|
||||
model_name: string;
|
||||
base_model: BaseModelType;
|
||||
@@ -57,6 +61,8 @@ export type ModelLoadCompletedEvent = {
|
||||
* @example socket.on('generator_progress', (data: GeneratorProgressEvent) => { ... }
|
||||
*/
|
||||
export type GeneratorProgressEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node: BaseNode;
|
||||
source_node_id: string;
|
||||
@@ -73,6 +79,8 @@ export type GeneratorProgressEvent = {
|
||||
* @example socket.on('invocation_complete', (data: InvocationCompleteEvent) => { ... }
|
||||
*/
|
||||
export type InvocationCompleteEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node: BaseNode;
|
||||
source_node_id: string;
|
||||
@@ -85,6 +93,8 @@ export type InvocationCompleteEvent = {
|
||||
* @example socket.on('invocation_error', (data: InvocationErrorEvent) => { ... }
|
||||
*/
|
||||
export type InvocationErrorEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node: BaseNode;
|
||||
source_node_id: string;
|
||||
@@ -98,6 +108,8 @@ export type InvocationErrorEvent = {
|
||||
* @example socket.on('invocation_started', (data: InvocationStartedEvent) => { ... }
|
||||
*/
|
||||
export type InvocationStartedEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node: BaseNode;
|
||||
source_node_id: string;
|
||||
@@ -109,6 +121,8 @@ export type InvocationStartedEvent = {
|
||||
* @example socket.on('graph_execution_state_complete', (data: GraphExecutionStateCompleteEvent) => { ... }
|
||||
*/
|
||||
export type GraphExecutionStateCompleteEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: string;
|
||||
graph_execution_state_id: string;
|
||||
};
|
||||
|
||||
@@ -118,6 +132,8 @@ export type GraphExecutionStateCompleteEvent = {
|
||||
* @example socket.on('session_retrieval_error', (data: SessionRetrievalErrorEvent) => { ... }
|
||||
*/
|
||||
export type SessionRetrievalErrorEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: string;
|
||||
graph_execution_state_id: string;
|
||||
error_type: string;
|
||||
error: string;
|
||||
@@ -129,6 +145,8 @@ export type SessionRetrievalErrorEvent = {
|
||||
* @example socket.on('invocation_retrieval_error', (data: InvocationRetrievalErrorEvent) => { ... }
|
||||
*/
|
||||
export type InvocationRetrievalErrorEvent = {
|
||||
queue_id: string;
|
||||
queue_item_id: string;
|
||||
graph_execution_state_id: string;
|
||||
node_id: string;
|
||||
error_type: string;
|
||||
@@ -141,10 +159,10 @@ export type InvocationRetrievalErrorEvent = {
|
||||
* @example socket.on('queue_item_status_changed', (data: QueueItemStatusChangedEvent) => { ... }
|
||||
*/
|
||||
export type QueueItemStatusChangedEvent = {
|
||||
item_id: string;
|
||||
queue_id: string;
|
||||
queue_item_id: string;
|
||||
batch_id: string;
|
||||
session_id: string;
|
||||
queue_id: string;
|
||||
graph_execution_state_id: string;
|
||||
status: components['schemas']['SessionQueueItemDTO']['status'];
|
||||
error: string | undefined;
|
||||
@@ -154,14 +172,6 @@ export type QueueItemStatusChangedEvent = {
|
||||
completed_at: string | undefined;
|
||||
};
|
||||
|
||||
export type ClientEmitSubscribeSession = {
|
||||
session: string;
|
||||
};
|
||||
|
||||
export type ClientEmitUnsubscribeSession = {
|
||||
session: string;
|
||||
};
|
||||
|
||||
export type ClientEmitSubscribeQueue = {
|
||||
queue_id: string;
|
||||
};
|
||||
@@ -188,8 +198,6 @@ export type ServerToClientEvents = {
|
||||
export type ClientToServerEvents = {
|
||||
connect: () => void;
|
||||
disconnect: () => void;
|
||||
subscribe_session: (payload: ClientEmitSubscribeSession) => void;
|
||||
unsubscribe_session: (payload: ClientEmitUnsubscribeSession) => void;
|
||||
subscribe_queue: (payload: ClientEmitSubscribeQueue) => void;
|
||||
unsubscribe_queue: (payload: ClientEmitUnsubscribeQueue) => void;
|
||||
};
|
||||
|
||||
@@ -156,17 +156,6 @@ export const setEventListeners = (arg: SetEventListenersArg) => {
|
||||
});
|
||||
|
||||
socket.on('queue_item_status_changed', (data) => {
|
||||
const { status, session_id } = data;
|
||||
if (status === 'in_progress') {
|
||||
socket.emit('subscribe_session', {
|
||||
session: session_id,
|
||||
});
|
||||
}
|
||||
if (['new', 'completed', 'failed', 'canceled'].includes(status)) {
|
||||
socket.emit('unsubscribe_session', {
|
||||
session: session_id,
|
||||
});
|
||||
}
|
||||
dispatch(socketQueueItemStatusChanged({ data }));
|
||||
});
|
||||
};
|
||||
|
||||
@@ -9,7 +9,6 @@ from .test_nodes import ( # isort: split
|
||||
TestEventService,
|
||||
TextToImageTestInvocation,
|
||||
)
|
||||
|
||||
import sqlite3
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
|
||||
@@ -20,6 +19,7 @@ from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from invokeai.app.services.processor import DefaultInvocationProcessor
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
|
||||
from .test_invoker import create_edge
|
||||
@@ -70,7 +70,9 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B
|
||||
return (None, None)
|
||||
|
||||
print(f"invoking {n.id}: {type(n)}")
|
||||
o = n.invoke(InvocationContext(services, "1"))
|
||||
o = n.invoke(
|
||||
InvocationContext(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, services=services, graph_execution_state_id="1")
|
||||
)
|
||||
g.complete(n.id, o)
|
||||
|
||||
return (n, o)
|
||||
|
||||
@@ -3,15 +3,8 @@ import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
|
||||
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.processor import DefaultInvocationProcessor
|
||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
|
||||
from .test_nodes import (
|
||||
# This import must happen before other invoke imports or test in other files(!!) break
|
||||
from .test_nodes import ( # isort: split
|
||||
ErrorInvocation,
|
||||
PromptTestInvocation,
|
||||
TestEventService,
|
||||
@@ -20,6 +13,15 @@ from .test_nodes import (
|
||||
wait_until,
|
||||
)
|
||||
|
||||
from invokeai.app.services.graph import Graph, GraphExecutionState, GraphInvocation, LibraryGraph
|
||||
from invokeai.app.services.invocation_queue import MemoryInvocationQueue
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats import InvocationStatsService
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.processor import DefaultInvocationProcessor
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.sqlite import SqliteItemStorage, sqlite_memory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_graph():
|
||||
@@ -96,7 +98,7 @@ def test_can_create_graph_state_from_graph(mock_invoker: Invoker, simple_graph):
|
||||
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
||||
def test_can_invoke(mock_invoker: Invoker, simple_graph):
|
||||
g = mock_invoker.create_execution_state(graph=simple_graph)
|
||||
invocation_id = mock_invoker.invoke(g)
|
||||
invocation_id = mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g)
|
||||
assert invocation_id is not None
|
||||
|
||||
def has_executed_any(g: GraphExecutionState):
|
||||
@@ -113,7 +115,9 @@ def test_can_invoke(mock_invoker: Invoker, simple_graph):
|
||||
# @pytest.mark.xfail(reason = "Requires fixing following the model manager refactor")
|
||||
def test_can_invoke_all(mock_invoker: Invoker, simple_graph):
|
||||
g = mock_invoker.create_execution_state(graph=simple_graph)
|
||||
invocation_id = mock_invoker.invoke(g, invoke_all=True)
|
||||
invocation_id = mock_invoker.invoke(
|
||||
queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True
|
||||
)
|
||||
assert invocation_id is not None
|
||||
|
||||
def has_executed_all(g: GraphExecutionState):
|
||||
@@ -132,7 +136,7 @@ def test_handles_errors(mock_invoker: Invoker):
|
||||
g = mock_invoker.create_execution_state()
|
||||
g.graph.add_node(ErrorInvocation(id="1"))
|
||||
|
||||
mock_invoker.invoke(g, invoke_all=True)
|
||||
mock_invoker.invoke(queue_item_id="1", queue_id=DEFAULT_QUEUE_ID, graph_execution_state=g, invoke_all=True)
|
||||
|
||||
def has_executed_all(g: GraphExecutionState):
|
||||
g = mock_invoker.services.graph_execution_manager.get(g.id)
|
||||
|
||||
Reference in New Issue
Block a user