mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-22 12:58:22 -05:00
Compare commits
31 Commits
bria-clone
...
psyche/fea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc78a0e699 | ||
|
|
08a42c3c03 | ||
|
|
0758e9cb9b | ||
|
|
fb93e686b2 | ||
|
|
350feeed56 | ||
|
|
169b75b2b7 | ||
|
|
c88de180e7 | ||
|
|
7d1844eaf2 | ||
|
|
a98ddedb95 | ||
|
|
6063487b20 | ||
|
|
9a4c167342 | ||
|
|
19227fe4e6 | ||
|
|
db0ef8d316 | ||
|
|
6a34176376 | ||
|
|
d6696a7b97 | ||
|
|
0e81e7b460 | ||
|
|
7652fbc2e9 | ||
|
|
a55b2f09e2 | ||
|
|
23b05344a3 | ||
|
|
80905ff3ea | ||
|
|
df5457231f | ||
|
|
d30c1ad6dc | ||
|
|
b1f819ae8d | ||
|
|
eff359625a | ||
|
|
cef1585dfb | ||
|
|
cb8e9e1c7b | ||
|
|
f7c356d142 | ||
|
|
efb069dd71 | ||
|
|
8edc25d35a | ||
|
|
82957bb826 | ||
|
|
e51a3025ea |
@@ -29,7 +29,7 @@ from ..services.model_images.model_images_default import ModelImageFileStorageDi
|
|||||||
from ..services.model_manager.model_manager_default import ModelManagerService
|
from ..services.model_manager.model_manager_default import ModelManagerService
|
||||||
from ..services.model_records import ModelRecordServiceSQL
|
from ..services.model_records import ModelRecordServiceSQL
|
||||||
from ..services.names.names_default import SimpleNameService
|
from ..services.names.names_default import SimpleNameService
|
||||||
from ..services.session_processor.session_processor_default import DefaultSessionProcessor
|
from ..services.session_processor.session_processor_default import DefaultSessionProcessor, DefaultSessionRunner
|
||||||
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
from ..services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||||
from ..services.urls.urls_default import LocalUrlService
|
from ..services.urls.urls_default import LocalUrlService
|
||||||
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
from ..services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||||
@@ -103,7 +103,8 @@ class ApiDependencies:
|
|||||||
)
|
)
|
||||||
names = SimpleNameService()
|
names = SimpleNameService()
|
||||||
performance_statistics = InvocationStatsService()
|
performance_statistics = InvocationStatsService()
|
||||||
session_processor = DefaultSessionProcessor()
|
|
||||||
|
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
|
||||||
session_queue = SqliteSessionQueue(db=db)
|
session_queue = SqliteSessionQueue(db=db)
|
||||||
urls = LocalUrlService()
|
urls = LocalUrlService()
|
||||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||||
|
|||||||
@@ -121,7 +121,8 @@ class EventServiceBase:
|
|||||||
node: dict,
|
node: dict,
|
||||||
source_node_id: str,
|
source_node_id: str,
|
||||||
error_type: str,
|
error_type: str,
|
||||||
error: str,
|
error_message: str,
|
||||||
|
error_traceback: str,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
project_id: str | None,
|
project_id: str | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -136,7 +137,8 @@ class EventServiceBase:
|
|||||||
"node": node,
|
"node": node,
|
||||||
"source_node_id": source_node_id,
|
"source_node_id": source_node_id,
|
||||||
"error_type": error_type,
|
"error_type": error_type,
|
||||||
"error": error,
|
"error_message": error_message,
|
||||||
|
"error_traceback": error_traceback,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"project_id": project_id,
|
"project_id": project_id,
|
||||||
},
|
},
|
||||||
@@ -257,7 +259,9 @@ class EventServiceBase:
|
|||||||
"status": session_queue_item.status,
|
"status": session_queue_item.status,
|
||||||
"batch_id": session_queue_item.batch_id,
|
"batch_id": session_queue_item.batch_id,
|
||||||
"session_id": session_queue_item.session_id,
|
"session_id": session_queue_item.session_id,
|
||||||
"error": session_queue_item.error,
|
"error_type": session_queue_item.error_type,
|
||||||
|
"error_message": session_queue_item.error_message,
|
||||||
|
"error_traceback": session_queue_item.error_traceback,
|
||||||
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
"created_at": str(session_queue_item.created_at) if session_queue_item.created_at else None,
|
||||||
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
"updated_at": str(session_queue_item.updated_at) if session_queue_item.updated_at else None,
|
||||||
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
"started_at": str(session_queue_item.started_at) if session_queue_item.started_at else None,
|
||||||
|
|||||||
@@ -1,6 +1,49 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from threading import Event
|
||||||
|
from typing import Optional, Protocol
|
||||||
|
|
||||||
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||||
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
from invokeai.app.services.session_processor.session_processor_common import SessionProcessorStatus
|
||||||
|
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
||||||
|
from invokeai.app.util.profiler import Profiler
|
||||||
|
|
||||||
|
|
||||||
|
class SessionRunnerBase(ABC):
|
||||||
|
"""
|
||||||
|
Base class for session runner.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
|
||||||
|
"""Starts the session runner.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
services: The invocation services.
|
||||||
|
cancel_event: The cancel event.
|
||||||
|
profiler: The profiler to use for session profiling via cProfile. Omit to disable profiling. Basic session
|
||||||
|
stats will still be recorded and logged when profiling is disabled.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run(self, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Runs a session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
queue_item: The session to run.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Run a single node in the graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
invocation: The invocation to run.
|
||||||
|
queue_item: The session queue item.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SessionProcessorBase(ABC):
|
class SessionProcessorBase(ABC):
|
||||||
@@ -26,3 +69,85 @@ class SessionProcessorBase(ABC):
|
|||||||
def get_status(self) -> SessionProcessorStatus:
|
def get_status(self) -> SessionProcessorStatus:
|
||||||
"""Gets the status of the session processor"""
|
"""Gets the status of the session processor"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class OnBeforeRunNode(Protocol):
|
||||||
|
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Callback to run before executing a node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
invocation: The invocation that will be executed.
|
||||||
|
queue_item: The session queue item.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class OnAfterRunNode(Protocol):
|
||||||
|
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None:
|
||||||
|
"""Callback to run after executing a node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
invocation: The invocation that was executed.
|
||||||
|
queue_item: The session queue item.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class OnNodeError(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
invocation: BaseInvocation,
|
||||||
|
queue_item: SessionQueueItem,
|
||||||
|
error_type: str,
|
||||||
|
error_message: str,
|
||||||
|
error_traceback: str,
|
||||||
|
) -> None:
|
||||||
|
"""Callback to run when a node has an error.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
invocation: The invocation that errored.
|
||||||
|
queue_item: The session queue item.
|
||||||
|
error_type: The type of error, e.g. "ValueError".
|
||||||
|
error_message: The error message, e.g. "Invalid value".
|
||||||
|
error_traceback: The stringified error traceback.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class OnBeforeRunSession(Protocol):
|
||||||
|
def __call__(self, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Callback to run before executing a session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
queue_item: The session queue item.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class OnAfterRunSession(Protocol):
|
||||||
|
def __call__(self, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Callback to run after executing a session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
queue_item: The session queue item.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class OnNonFatalProcessorError(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
queue_item: Optional[SessionQueueItem],
|
||||||
|
error_type: str,
|
||||||
|
error_message: str,
|
||||||
|
error_traceback: str,
|
||||||
|
) -> None:
|
||||||
|
"""Callback to run when a non-fatal error occurs in the processor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
queue_item: The session queue item, if one was being executed when the error occurred.
|
||||||
|
error_type: The type of error, e.g. "ValueError".
|
||||||
|
error_message: The error message, e.g. "Invalid value".
|
||||||
|
error_traceback: The stringified error traceback.
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|||||||
@@ -7,21 +7,305 @@ from typing import Optional
|
|||||||
from fastapi_events.handlers.local import local_handler
|
from fastapi_events.handlers.local import local_handler
|
||||||
from fastapi_events.typing import Event as FastAPIEvent
|
from fastapi_events.typing import Event as FastAPIEvent
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
|
||||||
|
from invokeai.app.services.session_processor.session_processor_base import (
|
||||||
|
OnAfterRunNode,
|
||||||
|
OnAfterRunSession,
|
||||||
|
OnBeforeRunNode,
|
||||||
|
OnBeforeRunSession,
|
||||||
|
OnNodeError,
|
||||||
|
OnNonFatalProcessorError,
|
||||||
|
)
|
||||||
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
from invokeai.app.services.session_processor.session_processor_common import CanceledException
|
||||||
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem, SessionQueueItemNotFoundError
|
||||||
|
from invokeai.app.services.shared.graph import NodeInputError
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
from invokeai.app.services.shared.invocation_context import InvocationContextData, build_invocation_context
|
||||||
from invokeai.app.util.profiler import Profiler
|
from invokeai.app.util.profiler import Profiler
|
||||||
|
|
||||||
from ..invoker import Invoker
|
from ..invoker import Invoker
|
||||||
from .session_processor_base import SessionProcessorBase
|
from .session_processor_base import InvocationServices, SessionProcessorBase, SessionRunnerBase
|
||||||
from .session_processor_common import SessionProcessorStatus
|
from .session_processor_common import SessionProcessorStatus
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultSessionRunner(SessionRunnerBase):
|
||||||
|
"""Processes a single session's invocations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
on_before_run_session_callbacks: Optional[list[OnBeforeRunSession]] = None,
|
||||||
|
on_before_run_node_callbacks: Optional[list[OnBeforeRunNode]] = None,
|
||||||
|
on_after_run_node_callbacks: Optional[list[OnAfterRunNode]] = None,
|
||||||
|
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
|
||||||
|
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
on_before_run_session_callbacks: Callbacks to run before the session starts.
|
||||||
|
on_before_run_node_callbacks: Callbacks to run before each node starts.
|
||||||
|
on_after_run_node_callbacks: Callbacks to run after each node completes.
|
||||||
|
on_node_error_callbacks: Callbacks to run when a node errors.
|
||||||
|
on_after_run_session_callbacks: Callbacks to run after the session completes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
|
||||||
|
self._on_before_run_node_callbacks = on_before_run_node_callbacks or []
|
||||||
|
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
|
||||||
|
self._on_node_error_callbacks = on_node_error_callbacks or []
|
||||||
|
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
|
||||||
|
|
||||||
|
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
||||||
|
self._services = services
|
||||||
|
self._cancel_event = cancel_event
|
||||||
|
self._profiler = profiler
|
||||||
|
|
||||||
|
def run(self, queue_item: SessionQueueItem):
|
||||||
|
# Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here.
|
||||||
|
|
||||||
|
self._on_before_run_session(queue_item=queue_item)
|
||||||
|
|
||||||
|
# Loop over invocations until the session is complete or canceled
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
invocation = queue_item.session.next()
|
||||||
|
# Anything other than a `NodeInputError` is handled as a processor error
|
||||||
|
except NodeInputError as e:
|
||||||
|
error_type = e.__class__.__name__
|
||||||
|
error_message = str(e)
|
||||||
|
error_traceback = traceback.format_exc()
|
||||||
|
self._on_node_error(
|
||||||
|
invocation=e.node,
|
||||||
|
queue_item=queue_item,
|
||||||
|
error_type=error_type,
|
||||||
|
error_message=error_message,
|
||||||
|
error_traceback=error_traceback,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
if invocation is None or self._cancel_event.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
self.run_node(invocation, queue_item)
|
||||||
|
|
||||||
|
# The session is complete if all invocations have been run or there is an error on the session.
|
||||||
|
if queue_item.session.is_complete() or self._cancel_event.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
self._on_after_run_session(queue_item=queue_item)
|
||||||
|
|
||||||
|
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||||
|
try:
|
||||||
|
# Any unhandled exception in this scope is an invocation error & will fail the graph
|
||||||
|
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||||
|
self._on_before_run_node(invocation, queue_item)
|
||||||
|
|
||||||
|
data = InvocationContextData(
|
||||||
|
invocation=invocation,
|
||||||
|
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
|
queue_item=queue_item,
|
||||||
|
)
|
||||||
|
context = build_invocation_context(
|
||||||
|
data=data,
|
||||||
|
services=self._services,
|
||||||
|
cancel_event=self._cancel_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Invoke the node
|
||||||
|
output = invocation.invoke_internal(context=context, services=self._services)
|
||||||
|
# Save output and history
|
||||||
|
queue_item.session.complete(invocation.id, output)
|
||||||
|
|
||||||
|
self._on_after_run_node(invocation, queue_item, output)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
# TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here?
|
||||||
|
pass
|
||||||
|
except CanceledException:
|
||||||
|
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||||
|
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||||
|
# be able to cancel them mid-execution.
|
||||||
|
#
|
||||||
|
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||||
|
# is executed after each step. This step callback checks if the canceled event is set,
|
||||||
|
# then raises a CanceledException to stop execution immediately.
|
||||||
|
#
|
||||||
|
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||||
|
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
error_type = e.__class__.__name__
|
||||||
|
error_message = str(e)
|
||||||
|
error_traceback = traceback.format_exc()
|
||||||
|
self._on_node_error(
|
||||||
|
invocation=invocation,
|
||||||
|
queue_item=queue_item,
|
||||||
|
error_type=error_type,
|
||||||
|
error_message=error_message,
|
||||||
|
error_traceback=error_traceback,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Run before a session is executed"""
|
||||||
|
|
||||||
|
self._services.logger.debug(
|
||||||
|
f"On before run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If profiling is enabled, start the profiler
|
||||||
|
if self._profiler is not None:
|
||||||
|
self._profiler.start(profile_id=queue_item.session_id)
|
||||||
|
|
||||||
|
for callback in self._on_before_run_session_callbacks:
|
||||||
|
callback(queue_item=queue_item)
|
||||||
|
|
||||||
|
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||||
|
"""Run after a session is executed"""
|
||||||
|
|
||||||
|
self._services.logger.debug(
|
||||||
|
f"On after run session: queue item {queue_item.item_id}, session {queue_item.session_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If we are profiling, stop the profiler and dump the profile & stats
|
||||||
|
if self._profiler is not None:
|
||||||
|
profile_path = self._profiler.stop()
|
||||||
|
stats_path = profile_path.with_suffix(".json")
|
||||||
|
self._services.performance_statistics.dump_stats(
|
||||||
|
graph_execution_state_id=queue_item.session.id, output_path=stats_path
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Update the queue item with the completed session. If the queue item has been removed from the queue,
|
||||||
|
# we'll get a SessionQueueItemNotFoundError and we can ignore it. This can happen if the queue is cleared
|
||||||
|
# while the session is running.
|
||||||
|
queue_item = self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||||
|
|
||||||
|
# TODO(psyche): This feels jumbled - we should review separation of concerns here.
|
||||||
|
# Send complete event. The events service will receive this and update the queue item's status.
|
||||||
|
self._services.events.emit_graph_execution_complete(
|
||||||
|
queue_batch_id=queue_item.batch_id,
|
||||||
|
queue_item_id=queue_item.item_id,
|
||||||
|
queue_id=queue_item.queue_id,
|
||||||
|
graph_execution_state_id=queue_item.session.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||||
|
# we don't care about that - suppress the error.
|
||||||
|
with suppress(GESStatsNotFoundError):
|
||||||
|
self._services.performance_statistics.log_stats(queue_item.session.id)
|
||||||
|
self._services.performance_statistics.reset_stats()
|
||||||
|
|
||||||
|
for callback in self._on_after_run_session_callbacks:
|
||||||
|
callback(queue_item=queue_item)
|
||||||
|
except SessionQueueItemNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||||
|
"""Run before a node is executed"""
|
||||||
|
|
||||||
|
self._services.logger.debug(
|
||||||
|
f"On before run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send starting event
|
||||||
|
self._services.events.emit_invocation_started(
|
||||||
|
queue_batch_id=queue_item.batch_id,
|
||||||
|
queue_item_id=queue_item.item_id,
|
||||||
|
queue_id=queue_item.queue_id,
|
||||||
|
graph_execution_state_id=queue_item.session_id,
|
||||||
|
node=invocation.model_dump(),
|
||||||
|
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
|
)
|
||||||
|
|
||||||
|
for callback in self._on_before_run_node_callbacks:
|
||||||
|
callback(invocation=invocation, queue_item=queue_item)
|
||||||
|
|
||||||
|
def _on_after_run_node(
|
||||||
|
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
|
||||||
|
):
|
||||||
|
"""Run after a node is executed"""
|
||||||
|
|
||||||
|
self._services.logger.debug(
|
||||||
|
f"On after run node: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send complete event on successful runs
|
||||||
|
self._services.events.emit_invocation_complete(
|
||||||
|
queue_batch_id=queue_item.batch_id,
|
||||||
|
queue_item_id=queue_item.item_id,
|
||||||
|
queue_id=queue_item.queue_id,
|
||||||
|
graph_execution_state_id=queue_item.session.id,
|
||||||
|
node=invocation.model_dump(),
|
||||||
|
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
|
result=output.model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
for callback in self._on_after_run_node_callbacks:
|
||||||
|
callback(invocation=invocation, queue_item=queue_item, output=output)
|
||||||
|
|
||||||
|
def _on_node_error(
|
||||||
|
self,
|
||||||
|
invocation: BaseInvocation,
|
||||||
|
queue_item: SessionQueueItem,
|
||||||
|
error_type: str,
|
||||||
|
error_message: str,
|
||||||
|
error_traceback: str,
|
||||||
|
):
|
||||||
|
"""Run when a node errors"""
|
||||||
|
|
||||||
|
self._services.logger.debug(
|
||||||
|
f"On node error: queue item {queue_item.item_id}, session {queue_item.session_id}, node {invocation.id} ({invocation.get_type()})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Node errors do not get the full traceback. Only the queue item gets the full traceback.
|
||||||
|
node_error = f"{error_type}: {error_message}"
|
||||||
|
queue_item.session.set_node_error(invocation.id, node_error)
|
||||||
|
self._services.logger.error(
|
||||||
|
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}"
|
||||||
|
)
|
||||||
|
self._services.logger.error(error_traceback)
|
||||||
|
|
||||||
|
# Send error event
|
||||||
|
self._services.events.emit_invocation_error(
|
||||||
|
queue_batch_id=queue_item.session_id,
|
||||||
|
queue_item_id=queue_item.item_id,
|
||||||
|
queue_id=queue_item.queue_id,
|
||||||
|
graph_execution_state_id=queue_item.session.id,
|
||||||
|
node=invocation.model_dump(),
|
||||||
|
source_node_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
|
error_type=error_type,
|
||||||
|
error_message=error_message,
|
||||||
|
error_traceback=error_traceback,
|
||||||
|
user_id=getattr(queue_item, "user_id", None),
|
||||||
|
project_id=getattr(queue_item, "project_id", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
for callback in self._on_node_error_callbacks:
|
||||||
|
callback(
|
||||||
|
invocation=invocation,
|
||||||
|
queue_item=queue_item,
|
||||||
|
error_type=error_type,
|
||||||
|
error_message=error_message,
|
||||||
|
error_traceback=error_traceback,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DefaultSessionProcessor(SessionProcessorBase):
|
class DefaultSessionProcessor(SessionProcessorBase):
|
||||||
def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
session_runner: Optional[SessionRunnerBase] = None,
|
||||||
|
on_non_fatal_processor_error_callbacks: Optional[list[OnNonFatalProcessorError]] = None,
|
||||||
|
thread_limit: int = 1,
|
||||||
|
polling_interval: int = 1,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.session_runner = session_runner if session_runner else DefaultSessionRunner()
|
||||||
|
self._on_non_fatal_processor_error_callbacks = on_non_fatal_processor_error_callbacks or []
|
||||||
|
self._thread_limit = thread_limit
|
||||||
|
self._polling_interval = polling_interval
|
||||||
|
|
||||||
|
def start(self, invoker: Invoker) -> None:
|
||||||
self._invoker: Invoker = invoker
|
self._invoker: Invoker = invoker
|
||||||
self._queue_item: Optional[SessionQueueItem] = None
|
self._queue_item: Optional[SessionQueueItem] = None
|
||||||
self._invocation: Optional[BaseInvocation] = None
|
self._invocation: Optional[BaseInvocation] = None
|
||||||
@@ -33,9 +317,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
|
|
||||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||||
|
|
||||||
self._thread_limit = thread_limit
|
self._thread_semaphore = BoundedSemaphore(self._thread_limit)
|
||||||
self._thread_semaphore = BoundedSemaphore(thread_limit)
|
|
||||||
self._polling_interval = polling_interval
|
|
||||||
|
|
||||||
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
|
# If profiling is enabled, create a profiler. The same profiler will be used for all sessions. Internally,
|
||||||
# the profiler will create a new profile for each session.
|
# the profiler will create a new profile for each session.
|
||||||
@@ -49,6 +331,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.session_runner.start(services=invoker.services, cancel_event=self._cancel_event, profiler=self._profiler)
|
||||||
self._thread = Thread(
|
self._thread = Thread(
|
||||||
name="session_processor",
|
name="session_processor",
|
||||||
target=self._process,
|
target=self._process,
|
||||||
@@ -91,6 +374,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
"failed",
|
"failed",
|
||||||
"canceled",
|
"canceled",
|
||||||
]:
|
]:
|
||||||
|
self._cancel_event.set()
|
||||||
self._poll_now()
|
self._poll_now()
|
||||||
|
|
||||||
def resume(self) -> SessionProcessorStatus:
|
def resume(self) -> SessionProcessorStatus:
|
||||||
@@ -116,8 +400,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
resume_event: ThreadEvent,
|
resume_event: ThreadEvent,
|
||||||
cancel_event: ThreadEvent,
|
cancel_event: ThreadEvent,
|
||||||
):
|
):
|
||||||
# Outermost processor try block; any unhandled exception is a fatal processor error
|
|
||||||
try:
|
try:
|
||||||
|
# Any unhandled exception in this block is a fatal processor error and will stop the processor.
|
||||||
self._thread_semaphore.acquire()
|
self._thread_semaphore.acquire()
|
||||||
stop_event.clear()
|
stop_event.clear()
|
||||||
resume_event.set()
|
resume_event.set()
|
||||||
@@ -125,8 +409,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
|
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
poll_now_event.clear()
|
poll_now_event.clear()
|
||||||
# Middle processor try block; any unhandled exception is a non-fatal processor error
|
|
||||||
try:
|
try:
|
||||||
|
# Any unhandled exception in this block is a nonfatal processor error and will be handled.
|
||||||
# If we are paused, wait for resume event
|
# If we are paused, wait for resume event
|
||||||
resume_event.wait()
|
resume_event.wait()
|
||||||
|
|
||||||
@@ -142,165 +426,62 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
|
||||||
cancel_event.clear()
|
cancel_event.clear()
|
||||||
|
|
||||||
# If profiling is enabled, start the profiler
|
# Run the graph
|
||||||
if self._profiler is not None:
|
self.session_runner.run(queue_item=self._queue_item)
|
||||||
self._profiler.start(profile_id=self._queue_item.session_id)
|
|
||||||
|
|
||||||
# Prepare invocations and take the first
|
except Exception as e:
|
||||||
self._invocation = self._queue_item.session.next()
|
error_type = e.__class__.__name__
|
||||||
|
error_message = str(e)
|
||||||
# Loop over invocations until the session is complete or canceled
|
error_traceback = traceback.format_exc()
|
||||||
while self._invocation is not None and not cancel_event.is_set():
|
self._on_non_fatal_processor_error(
|
||||||
# get the source node id to provide to clients (the prepared node id is not as useful)
|
queue_item=self._queue_item,
|
||||||
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
|
error_type=error_type,
|
||||||
|
error_message=error_message,
|
||||||
# Send starting event
|
error_traceback=error_traceback,
|
||||||
self._invoker.services.events.emit_invocation_started(
|
|
||||||
queue_batch_id=self._queue_item.batch_id,
|
|
||||||
queue_item_id=self._queue_item.item_id,
|
|
||||||
queue_id=self._queue_item.queue_id,
|
|
||||||
graph_execution_state_id=self._queue_item.session_id,
|
|
||||||
node=self._invocation.model_dump(),
|
|
||||||
source_node_id=source_invocation_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
|
|
||||||
try:
|
|
||||||
with self._invoker.services.performance_statistics.collect_stats(
|
|
||||||
self._invocation, self._queue_item.session.id
|
|
||||||
):
|
|
||||||
# Build invocation context (the node-facing API)
|
|
||||||
data = InvocationContextData(
|
|
||||||
invocation=self._invocation,
|
|
||||||
source_invocation_id=source_invocation_id,
|
|
||||||
queue_item=self._queue_item,
|
|
||||||
)
|
|
||||||
context = build_invocation_context(
|
|
||||||
data=data,
|
|
||||||
services=self._invoker.services,
|
|
||||||
cancel_event=self._cancel_event,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Invoke the node
|
|
||||||
outputs = self._invocation.invoke_internal(
|
|
||||||
context=context, services=self._invoker.services
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save outputs and history
|
|
||||||
self._queue_item.session.complete(self._invocation.id, outputs)
|
|
||||||
|
|
||||||
# Send complete event
|
|
||||||
self._invoker.services.events.emit_invocation_complete(
|
|
||||||
queue_batch_id=self._queue_item.batch_id,
|
|
||||||
queue_item_id=self._queue_item.item_id,
|
|
||||||
queue_id=self._queue_item.queue_id,
|
|
||||||
graph_execution_state_id=self._queue_item.session.id,
|
|
||||||
node=self._invocation.model_dump(),
|
|
||||||
source_node_id=source_invocation_id,
|
|
||||||
result=outputs.model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
# TODO(MM2): Create an event for this
|
|
||||||
pass
|
|
||||||
|
|
||||||
except CanceledException:
|
|
||||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
|
||||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
|
||||||
# be able to cancel them mid-execution.
|
|
||||||
#
|
|
||||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
|
||||||
# is executed after each step. This step callback checks if the canceled event is set,
|
|
||||||
# then raises a CanceledException to stop execution immediately.
|
|
||||||
#
|
|
||||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
|
||||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
|
||||||
pass
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error = traceback.format_exc()
|
|
||||||
|
|
||||||
# Save error
|
|
||||||
self._queue_item.session.set_node_error(self._invocation.id, error)
|
|
||||||
self._invoker.services.logger.error(
|
|
||||||
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
|
|
||||||
)
|
|
||||||
self._invoker.services.logger.error(error)
|
|
||||||
|
|
||||||
# Send error event
|
|
||||||
self._invoker.services.events.emit_invocation_error(
|
|
||||||
queue_batch_id=self._queue_item.session_id,
|
|
||||||
queue_item_id=self._queue_item.item_id,
|
|
||||||
queue_id=self._queue_item.queue_id,
|
|
||||||
graph_execution_state_id=self._queue_item.session.id,
|
|
||||||
node=self._invocation.model_dump(),
|
|
||||||
source_node_id=source_invocation_id,
|
|
||||||
error_type=e.__class__.__name__,
|
|
||||||
error=error,
|
|
||||||
user_id=None,
|
|
||||||
project_id=None,
|
|
||||||
)
|
|
||||||
pass
|
|
||||||
|
|
||||||
# The session is complete if the all invocations are complete or there was an error
|
|
||||||
if self._queue_item.session.is_complete() or cancel_event.is_set():
|
|
||||||
# Send complete event
|
|
||||||
self._invoker.services.session_queue.set_queue_item_session(
|
|
||||||
self._queue_item.item_id, self._queue_item.session
|
|
||||||
)
|
|
||||||
self._invoker.services.events.emit_graph_execution_complete(
|
|
||||||
queue_batch_id=self._queue_item.batch_id,
|
|
||||||
queue_item_id=self._queue_item.item_id,
|
|
||||||
queue_id=self._queue_item.queue_id,
|
|
||||||
graph_execution_state_id=self._queue_item.session.id,
|
|
||||||
)
|
|
||||||
# If we are profiling, stop the profiler and dump the profile & stats
|
|
||||||
if self._profiler:
|
|
||||||
profile_path = self._profiler.stop()
|
|
||||||
stats_path = profile_path.with_suffix(".json")
|
|
||||||
self._invoker.services.performance_statistics.dump_stats(
|
|
||||||
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
|
|
||||||
)
|
|
||||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
|
||||||
# we don't care about that - suppress the error.
|
|
||||||
with suppress(GESStatsNotFoundError):
|
|
||||||
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
|
|
||||||
self._invoker.services.performance_statistics.reset_stats()
|
|
||||||
|
|
||||||
# Set the invocation to None to prepare for the next session
|
|
||||||
self._invocation = None
|
|
||||||
else:
|
|
||||||
# Prepare the next invocation
|
|
||||||
self._invocation = self._queue_item.session.next()
|
|
||||||
else:
|
|
||||||
# The queue was empty, wait for next polling interval or event to try again
|
|
||||||
self._invoker.services.logger.debug("Waiting for next polling interval or event")
|
|
||||||
poll_now_event.wait(self._polling_interval)
|
|
||||||
continue
|
|
||||||
except Exception:
|
|
||||||
# Non-fatal error in processor
|
|
||||||
self._invoker.services.logger.error(
|
|
||||||
f"Non-fatal error in session processor:\n{traceback.format_exc()}"
|
|
||||||
)
|
)
|
||||||
# Cancel the queue item
|
# Wait for next polling interval or event to try again
|
||||||
if self._queue_item is not None:
|
|
||||||
self._invoker.services.session_queue.set_queue_item_session(
|
|
||||||
self._queue_item.item_id, self._queue_item.session
|
|
||||||
)
|
|
||||||
self._invoker.services.session_queue.cancel_queue_item(
|
|
||||||
self._queue_item.item_id, error=traceback.format_exc()
|
|
||||||
)
|
|
||||||
# Reset the invocation to None to prepare for the next session
|
|
||||||
self._invocation = None
|
|
||||||
# Immediately poll for next queue item
|
|
||||||
poll_now_event.wait(self._polling_interval)
|
poll_now_event.wait(self._polling_interval)
|
||||||
continue
|
continue
|
||||||
except Exception:
|
except Exception as e:
|
||||||
# Fatal error in processor, log and pass - we're done here
|
# Fatal error in processor, log and pass - we're done here
|
||||||
self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}")
|
error_type = e.__class__.__name__
|
||||||
|
error_message = str(e)
|
||||||
|
error_traceback = traceback.format_exc()
|
||||||
|
self._invoker.services.logger.error(f"Fatal Error in session processor {error_type}: {error_message}")
|
||||||
|
self._invoker.services.logger.error(error_traceback)
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
stop_event.clear()
|
stop_event.clear()
|
||||||
poll_now_event.clear()
|
poll_now_event.clear()
|
||||||
self._queue_item = None
|
self._queue_item = None
|
||||||
self._thread_semaphore.release()
|
self._thread_semaphore.release()
|
||||||
|
|
||||||
|
def _on_non_fatal_processor_error(
    self,
    queue_item: Optional[SessionQueueItem],
    error_type: str,
    error_message: str,
    error_traceback: str,
) -> None:
    """Handle a non-fatal error raised while the processor was working on a queue item.

    The error is logged, the queue item (if there is one) is persisted with its current session and marked as
    failed, and every registered non-fatal-error callback is then invoked.

    Args:
        queue_item: The queue item that was being processed when the error occurred, or None if no item was
            dequeued yet.
        error_type: The class name of the exception.
        error_message: The stringified exception.
        error_traceback: The formatted traceback of the exception.
    """
    services = self._invoker.services

    # Non-fatal error in processor
    services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}")
    services.logger.error(error_traceback)

    if queue_item is not None:
        # Update the queue item with the completed session
        services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
        # Fail the queue item
        services.session_queue.fail_queue_item(
            item_id=queue_item.item_id,
            error_type=error_type,
            error_message=error_message,
            error_traceback=error_traceback,
        )

    # NOTE(review): callbacks are assumed to run even when there is no queue item (they receive the possibly-None
    # `queue_item`) - confirm against the upstream file, whose indentation was lost in this view.
    for callback in self._on_non_fatal_processor_error_callbacks:
        callback(
            queue_item=queue_item,
            error_type=error_type,
            error_message=error_message,
            error_traceback=error_traceback,
        )
|
||||||
|
|||||||
@@ -74,10 +74,17 @@ class SessionQueueBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
|
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||||
"""Cancels a session queue item"""
|
"""Cancels a session queue item"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
def fail_queue_item(
    self, item_id: int, error_type: str, error_message: str, error_traceback: str
) -> SessionQueueItem:
    """Fails a session queue item.

    Args:
        item_id: The ID of the queue item to fail.
        error_type: The type (class name) of the error that caused the failure.
        error_message: The human-readable error message.
        error_traceback: The formatted traceback of the error.

    Returns:
        The queue item.
    """
    pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
|
||||||
"""Cancels all queue items with matching batch IDs"""
|
"""Cancels all queue items with matching batch IDs"""
|
||||||
|
|||||||
@@ -3,7 +3,16 @@ import json
|
|||||||
from itertools import chain, product
|
from itertools import chain, product
|
||||||
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
|
from typing import Generator, Iterable, Literal, NamedTuple, Optional, TypeAlias, Union, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, StrictStr, TypeAdapter, field_validator, model_validator
|
from pydantic import (
|
||||||
|
AliasChoices,
|
||||||
|
BaseModel,
|
||||||
|
ConfigDict,
|
||||||
|
Field,
|
||||||
|
StrictStr,
|
||||||
|
TypeAdapter,
|
||||||
|
field_validator,
|
||||||
|
model_validator,
|
||||||
|
)
|
||||||
from pydantic_core import to_jsonable_python
|
from pydantic_core import to_jsonable_python
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||||
@@ -189,7 +198,13 @@ class SessionQueueItemWithoutGraph(BaseModel):
|
|||||||
session_id: str = Field(
|
session_id: str = Field(
|
||||||
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
|
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
|
||||||
)
|
)
|
||||||
error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
|
error_type: Optional[str] = Field(default=None, description="The error type if this queue item errored")
|
||||||
|
error_message: Optional[str] = Field(default=None, description="The error message if this queue item errored")
|
||||||
|
error_traceback: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="The error traceback if this queue item errored",
|
||||||
|
validation_alias=AliasChoices("error_traceback", "error"),
|
||||||
|
)
|
||||||
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
|
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
|
||||||
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
|
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
|
||||||
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
|
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
|
||||||
|
|||||||
@@ -82,10 +82,18 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
async def _handle_error_event(self, event: FastAPIEvent) -> None:
|
async def _handle_error_event(self, event: FastAPIEvent) -> None:
|
||||||
try:
|
try:
|
||||||
item_id = event[1]["data"]["queue_item_id"]
|
item_id = event[1]["data"]["queue_item_id"]
|
||||||
error = event[1]["data"]["error"]
|
error_type = event[1]["data"]["error_type"]
|
||||||
|
error_message = event[1]["data"]["error_message"]
|
||||||
|
error_traceback = event[1]["data"]["error_traceback"]
|
||||||
queue_item = self.get_queue_item(item_id)
|
queue_item = self.get_queue_item(item_id)
|
||||||
# always set to failed if have an error, even if previously the item was marked completed or canceled
|
# always set to failed if have an error, even if previously the item was marked completed or canceled
|
||||||
queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
|
queue_item = self._set_queue_item_status(
|
||||||
|
item_id=queue_item.item_id,
|
||||||
|
status="failed",
|
||||||
|
error_type=error_type,
|
||||||
|
error_message=error_message,
|
||||||
|
error_traceback=error_traceback,
|
||||||
|
)
|
||||||
except SessionQueueItemNotFoundError:
|
except SessionQueueItemNotFoundError:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -272,17 +280,22 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
return SessionQueueItem.queue_item_from_dict(dict(result))
|
return SessionQueueItem.queue_item_from_dict(dict(result))
|
||||||
|
|
||||||
def _set_queue_item_status(
|
def _set_queue_item_status(
|
||||||
self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
|
self,
|
||||||
|
item_id: int,
|
||||||
|
status: QUEUE_ITEM_STATUS,
|
||||||
|
error_type: Optional[str] = None,
|
||||||
|
error_message: Optional[str] = None,
|
||||||
|
error_traceback: Optional[str] = None,
|
||||||
) -> SessionQueueItem:
|
) -> SessionQueueItem:
|
||||||
try:
|
try:
|
||||||
self.__lock.acquire()
|
self.__lock.acquire()
|
||||||
self.__cursor.execute(
|
self.__cursor.execute(
|
||||||
"""--sql
|
"""--sql
|
||||||
UPDATE session_queue
|
UPDATE session_queue
|
||||||
SET status = ?, error = ?
|
SET status = ?, error_type = ?, error_message = ?, error_traceback = ?
|
||||||
WHERE item_id = ?
|
WHERE item_id = ?
|
||||||
""",
|
""",
|
||||||
(status, error, item_id),
|
(status, error_type, error_message, error_traceback, item_id),
|
||||||
)
|
)
|
||||||
self.__conn.commit()
|
self.__conn.commit()
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -339,26 +352,6 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
self.__lock.release()
|
self.__lock.release()
|
||||||
return IsFullResult(is_full=is_full)
|
return IsFullResult(is_full=is_full)
|
||||||
|
|
||||||
def delete_queue_item(self, item_id: int) -> SessionQueueItem:
|
|
||||||
queue_item = self.get_queue_item(item_id=item_id)
|
|
||||||
try:
|
|
||||||
self.__lock.acquire()
|
|
||||||
self.__cursor.execute(
|
|
||||||
"""--sql
|
|
||||||
DELETE FROM session_queue
|
|
||||||
WHERE
|
|
||||||
item_id = ?
|
|
||||||
""",
|
|
||||||
(item_id,),
|
|
||||||
)
|
|
||||||
self.__conn.commit()
|
|
||||||
except Exception:
|
|
||||||
self.__conn.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self.__lock.release()
|
|
||||||
return queue_item
|
|
||||||
|
|
||||||
def clear(self, queue_id: str) -> ClearResult:
|
def clear(self, queue_id: str) -> ClearResult:
|
||||||
try:
|
try:
|
||||||
self.__lock.acquire()
|
self.__lock.acquire()
|
||||||
@@ -425,11 +418,34 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
self.__lock.release()
|
self.__lock.release()
|
||||||
return PruneResult(deleted=count)
|
return PruneResult(deleted=count)
|
||||||
|
|
||||||
def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
|
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||||
queue_item = self.get_queue_item(item_id)
|
queue_item = self.get_queue_item(item_id)
|
||||||
if queue_item.status not in ["canceled", "failed", "completed"]:
|
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||||
status = "failed" if error is not None else "canceled"
|
queue_item = self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||||
queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error) # type: ignore [arg-type] # mypy seems to not narrow the Literals here
|
self.__invoker.services.events.emit_session_canceled(
|
||||||
|
queue_item_id=queue_item.item_id,
|
||||||
|
queue_id=queue_item.queue_id,
|
||||||
|
queue_batch_id=queue_item.batch_id,
|
||||||
|
graph_execution_state_id=queue_item.session_id,
|
||||||
|
)
|
||||||
|
return queue_item
|
||||||
|
|
||||||
|
def fail_queue_item(
|
||||||
|
self,
|
||||||
|
item_id: int,
|
||||||
|
error_type: str,
|
||||||
|
error_message: str,
|
||||||
|
error_traceback: str,
|
||||||
|
) -> SessionQueueItem:
|
||||||
|
queue_item = self.get_queue_item(item_id)
|
||||||
|
if queue_item.status not in ["canceled", "failed", "completed"]:
|
||||||
|
queue_item = self._set_queue_item_status(
|
||||||
|
item_id=item_id,
|
||||||
|
status="failed",
|
||||||
|
error_type=error_type,
|
||||||
|
error_message=error_message,
|
||||||
|
error_traceback=error_traceback,
|
||||||
|
)
|
||||||
self.__invoker.services.events.emit_session_canceled(
|
self.__invoker.services.events.emit_session_canceled(
|
||||||
queue_item_id=queue_item.item_id,
|
queue_item_id=queue_item.item_id,
|
||||||
queue_id=queue_item.queue_id,
|
queue_id=queue_item.queue_id,
|
||||||
@@ -602,7 +618,9 @@ class SqliteSessionQueue(SessionQueueBase):
|
|||||||
status,
|
status,
|
||||||
priority,
|
priority,
|
||||||
field_values,
|
field_values,
|
||||||
error,
|
error_type,
|
||||||
|
error_message,
|
||||||
|
error_traceback,
|
||||||
created_at,
|
created_at,
|
||||||
updated_at,
|
updated_at,
|
||||||
completed_at,
|
completed_at,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import networkx as nx
|
|||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
GetJsonSchemaHandler,
|
GetJsonSchemaHandler,
|
||||||
|
ValidationError,
|
||||||
field_validator,
|
field_validator,
|
||||||
)
|
)
|
||||||
from pydantic.fields import Field
|
from pydantic.fields import Field
|
||||||
@@ -190,6 +191,39 @@ class UnknownGraphValidationError(ValueError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NodeInputError(ValueError):
    """Raised when a node fails preparation. This occurs when a node's inputs are being set from its incomers, but an
    input fails validation.

    Attributes:
        node: The node that failed preparation. Note: only successfully set fields will be accurate. Review the error to
            determine which field caused the failure.
        original_error: The pydantic validation error that triggered this error.
        failed_input: Dot-separated location of the first input that failed validation.
    """

    def __init__(self, node: BaseInvocation, e: ValidationError):
        self.original_error = e
        self.node = node
        # When preparing a node, we set each input one-at-a-time. We may thus safely assume that the first error
        # represents the first input that failed.
        self.failed_input = loc_to_dot_sep(e.errors()[0]["loc"])
        super().__init__(f"Node {node.id} has invalid incoming input for {self.failed_input}")
|
||||||
|
|
||||||
|
|
||||||
|
def loc_to_dot_sep(loc: tuple[Union[str, int], ...]) -> str:
    """Helper to pretty-print pydantic error locations as dot-separated strings.

    String segments are joined with dots, while any non-string segment (e.g. a list index) is rendered in square
    brackets with no preceding dot, so `("items", 0, "name")` becomes `"items[0].name"`.

    Taken from https://docs.pydantic.dev/latest/errors/errors/#customize-error-messages

    Args:
        loc: The location tuple of a pydantic error.

    Returns:
        The location formatted as a dot-separated path. An empty tuple yields an empty string.
    """
    parts: list[str] = []
    for index, segment in enumerate(loc):
        if isinstance(segment, str):
            # No leading dot for the very first segment.
            if index > 0:
                parts.append(".")
            parts.append(segment)
        else:
            parts.append(f"[{segment}]")
    return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
@invocation_output("iterate_output")
|
@invocation_output("iterate_output")
|
||||||
class IterateInvocationOutput(BaseInvocationOutput):
|
class IterateInvocationOutput(BaseInvocationOutput):
|
||||||
"""Used to connect iteration outputs. Will be expanded to a specific output."""
|
"""Used to connect iteration outputs. Will be expanded to a specific output."""
|
||||||
@@ -821,7 +855,10 @@ class GraphExecutionState(BaseModel):
|
|||||||
|
|
||||||
# Get values from edges
|
# Get values from edges
|
||||||
if next_node is not None:
|
if next_node is not None:
|
||||||
self._prepare_inputs(next_node)
|
try:
|
||||||
|
self._prepare_inputs(next_node)
|
||||||
|
except ValidationError as e:
|
||||||
|
raise NodeInputError(next_node, e)
|
||||||
|
|
||||||
# If next is still none, there's no next node, return None
|
# If next is still none, there's no next node, return None
|
||||||
return next_node
|
return next_node
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from invokeai.app.services.shared.sqlite_migrator.migrations.migration_6 import
|
|||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_7 import build_migration_7
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_8 import build_migration_8
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_9 import build_migration_9
|
||||||
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_10 import build_migration_10
|
||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
|
||||||
|
|
||||||
|
|
||||||
@@ -41,6 +42,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
|
|||||||
migrator.register_migration(build_migration_7())
|
migrator.register_migration(build_migration_7())
|
||||||
migrator.register_migration(build_migration_8(app_config=config))
|
migrator.register_migration(build_migration_8(app_config=config))
|
||||||
migrator.register_migration(build_migration_9())
|
migrator.register_migration(build_migration_9())
|
||||||
|
migrator.register_migration(build_migration_10())
|
||||||
migrator.run_migrations()
|
migrator.run_migrations()
|
||||||
|
|
||||||
return db
|
return db
|
||||||
|
|||||||
@@ -0,0 +1,35 @@
|
|||||||
|
import sqlite3
|
||||||
|
|
||||||
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||||
|
|
||||||
|
|
||||||
|
class Migration10Callback:
    """Migration callback that reshapes the error columns of the `session_queue` table."""

    def __call__(self, cursor: sqlite3.Cursor) -> None:
        """Run the migration using the provided cursor."""
        self._update_error_cols(cursor)

    def _update_error_cols(self, cursor: sqlite3.Cursor) -> None:
        """
        - Adds `error_type` and `error_message` columns to the session queue table.
        - Renames the `error` column to `error_traceback`.
        """

        cursor.execute("ALTER TABLE session_queue ADD COLUMN error_type TEXT;")
        cursor.execute("ALTER TABLE session_queue ADD COLUMN error_message TEXT;")
        cursor.execute("ALTER TABLE session_queue RENAME COLUMN error TO error_traceback;")
|
||||||
|
|
||||||
|
|
||||||
|
def build_migration_10() -> Migration:
    """
    Build the migration from database version 9 to 10.

    This migration does the following:
    - Adds `error_type` and `error_message` columns to the session queue table.
    - Renames the `error` column to `error_traceback`.
    """
    migration_10 = Migration(
        from_version=9,
        to_version=10,
        callback=Migration10Callback(),
    )

    return migration_10
|
||||||
@@ -39,13 +39,19 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
|
|||||||
actionCreator: socketInvocationError,
|
actionCreator: socketInvocationError,
|
||||||
effect: (action, { getState }) => {
|
effect: (action, { getState }) => {
|
||||||
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
|
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
|
||||||
const { source_node_id, error_type, graph_execution_state_id } = action.payload.data;
|
const { source_node_id, error_type, error_message, error_traceback, graph_execution_state_id } =
|
||||||
|
action.payload.data;
|
||||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||||
if (nes) {
|
if (nes) {
|
||||||
nes.status = zNodeStatus.enum.FAILED;
|
nes.status = zNodeStatus.enum.FAILED;
|
||||||
nes.error = action.payload.data.error;
|
|
||||||
nes.progress = null;
|
nes.progress = null;
|
||||||
nes.progressImage = null;
|
nes.progressImage = null;
|
||||||
|
|
||||||
|
nes.error = {
|
||||||
|
error_type,
|
||||||
|
error_message,
|
||||||
|
error_traceback,
|
||||||
|
};
|
||||||
upsertExecutionState(nes.nodeId, nes);
|
upsertExecutionState(nes.nodeId, nes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -70,13 +70,18 @@ export const isInvocationNodeData = (node?: AnyNodeData | null): node is Invocat
|
|||||||
|
|
||||||
// #region NodeExecutionState
|
// #region NodeExecutionState
|
||||||
export const zNodeStatus = z.enum(['PENDING', 'IN_PROGRESS', 'COMPLETED', 'FAILED']);
|
export const zNodeStatus = z.enum(['PENDING', 'IN_PROGRESS', 'COMPLETED', 'FAILED']);
|
||||||
|
const zNodeError = z.object({
|
||||||
|
error_type: z.string(),
|
||||||
|
error_message: z.string(),
|
||||||
|
error_traceback: z.string(),
|
||||||
|
});
|
||||||
const zNodeExecutionState = z.object({
|
const zNodeExecutionState = z.object({
|
||||||
nodeId: z.string().trim().min(1),
|
nodeId: z.string().trim().min(1),
|
||||||
status: zNodeStatus,
|
status: zNodeStatus,
|
||||||
progress: z.number().nullable(),
|
progress: z.number().nullable(),
|
||||||
progressImage: zProgressImage.nullable(),
|
progressImage: zProgressImage.nullable(),
|
||||||
error: z.string().nullable(),
|
|
||||||
outputs: z.array(z.any()),
|
outputs: z.array(z.any()),
|
||||||
|
error: zNodeError.nullable(),
|
||||||
});
|
});
|
||||||
export type NodeExecutionState = z.infer<typeof zNodeExecutionState>;
|
export type NodeExecutionState = z.infer<typeof zNodeExecutionState>;
|
||||||
// #endregion
|
// #endregion
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ const QueueItemComponent = ({ queueItemDTO }: Props) => {
|
|||||||
</Button>
|
</Button>
|
||||||
</ButtonGroup>
|
</ButtonGroup>
|
||||||
</Flex>
|
</Flex>
|
||||||
{queueItem?.error && (
|
{(queueItem?.error_traceback || queueItem?.error_message) && (
|
||||||
<Flex
|
<Flex
|
||||||
layerStyle="second"
|
layerStyle="second"
|
||||||
p={3}
|
p={3}
|
||||||
@@ -89,7 +89,7 @@ const QueueItemComponent = ({ queueItemDTO }: Props) => {
|
|||||||
<Heading size="sm" color="error.400">
|
<Heading size="sm" color="error.400">
|
||||||
{t('common.error')}
|
{t('common.error')}
|
||||||
</Heading>
|
</Heading>
|
||||||
<pre>{queueItem.error}</pre>
|
<pre>{queueItem?.error_traceback || queueItem?.error_message}</pre>
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
<Flex layerStyle="second" h={512} w="full" borderRadius="base" alignItems="center" justifyContent="center">
|
<Flex layerStyle="second" h={512} w="full" borderRadius="base" alignItems="center" justifyContent="center">
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ const initialSystemState: SystemState = {
|
|||||||
shouldUseWatermarker: false,
|
shouldUseWatermarker: false,
|
||||||
shouldEnableInformationalPopovers: false,
|
shouldEnableInformationalPopovers: false,
|
||||||
status: 'DISCONNECTED',
|
status: 'DISCONNECTED',
|
||||||
|
cancellations: [],
|
||||||
};
|
};
|
||||||
|
|
||||||
export const systemSlice = createSlice({
|
export const systemSlice = createSlice({
|
||||||
@@ -88,6 +89,7 @@ export const systemSlice = createSlice({
|
|||||||
* Invocation Started
|
* Invocation Started
|
||||||
*/
|
*/
|
||||||
builder.addCase(socketInvocationStarted, (state) => {
|
builder.addCase(socketInvocationStarted, (state) => {
|
||||||
|
state.cancellations = [];
|
||||||
state.denoiseProgress = null;
|
state.denoiseProgress = null;
|
||||||
state.status = 'PROCESSING';
|
state.status = 'PROCESSING';
|
||||||
});
|
});
|
||||||
@@ -105,6 +107,12 @@ export const systemSlice = createSlice({
|
|||||||
queue_batch_id: batch_id,
|
queue_batch_id: batch_id,
|
||||||
} = action.payload.data;
|
} = action.payload.data;
|
||||||
|
|
||||||
|
if (state.cancellations.includes(session_id)) {
|
||||||
|
// Do not update the progress if this session has been cancelled. This prevents a race condition where we get a
|
||||||
|
// progress update after the session has been cancelled.
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
state.denoiseProgress = {
|
state.denoiseProgress = {
|
||||||
step,
|
step,
|
||||||
total_steps,
|
total_steps,
|
||||||
@@ -146,6 +154,7 @@ export const systemSlice = createSlice({
|
|||||||
if (['completed', 'canceled', 'failed'].includes(action.payload.data.queue_item.status)) {
|
if (['completed', 'canceled', 'failed'].includes(action.payload.data.queue_item.status)) {
|
||||||
state.status = 'CONNECTED';
|
state.status = 'CONNECTED';
|
||||||
state.denoiseProgress = null;
|
state.denoiseProgress = null;
|
||||||
|
state.cancellations.push(action.payload.data.queue_item.session_id);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
@@ -177,5 +186,5 @@ export const systemPersistConfig: PersistConfig<SystemState> = {
|
|||||||
name: systemSlice.name,
|
name: systemSlice.name,
|
||||||
initialState: initialSystemState,
|
initialState: initialSystemState,
|
||||||
migrate: migrateSystemState,
|
migrate: migrateSystemState,
|
||||||
persistDenylist: ['isConnected', 'denoiseProgress', 'status'],
|
persistDenylist: ['isConnected', 'denoiseProgress', 'status', 'cancellations'],
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -55,4 +55,5 @@ export interface SystemState {
|
|||||||
shouldUseWatermarker: boolean;
|
shouldUseWatermarker: boolean;
|
||||||
status: SystemStatus;
|
status: SystemStatus;
|
||||||
shouldEnableInformationalPopovers: boolean;
|
shouldEnableInformationalPopovers: boolean;
|
||||||
|
cancellations: string[]
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -116,7 +116,8 @@ export type InvocationErrorEvent = {
|
|||||||
node: BaseNode;
|
node: BaseNode;
|
||||||
source_node_id: string;
|
source_node_id: string;
|
||||||
error_type: string;
|
error_type: string;
|
||||||
error: string;
|
error_message: string;
|
||||||
|
error_traceback: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -187,7 +188,9 @@ export type QueueItemStatusChangedEvent = {
|
|||||||
batch_id: string;
|
batch_id: string;
|
||||||
session_id: string;
|
session_id: string;
|
||||||
status: components['schemas']['SessionQueueItemDTO']['status'];
|
status: components['schemas']['SessionQueueItemDTO']['status'];
|
||||||
error: string | undefined;
|
error_type?: string | null;
|
||||||
|
error_message?: string | null;
|
||||||
|
error_traceback?: string | null;
|
||||||
created_at: string;
|
created_at: string;
|
||||||
updated_at: string;
|
updated_at: string;
|
||||||
started_at: string | undefined;
|
started_at: string | undefined;
|
||||||
|
|||||||
Reference in New Issue
Block a user