diff --git a/invokeai/app/services/session_processor/session_processor_base.py b/invokeai/app/services/session_processor/session_processor_base.py index bfae74e5fe..1436627a9e 100644 --- a/invokeai/app/services/session_processor/session_processor_base.py +++ b/invokeai/app/services/session_processor/session_processor_base.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from threading import Event -from types import TracebackType from typing import Optional, Protocol from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput @@ -71,9 +70,9 @@ class OnNodeError(Protocol): self, invocation: BaseInvocation, queue_item: SessionQueueItem, - exc_type: type, - exc_value: BaseException, - exc_traceback: TracebackType, + error_type: str, + error_message: str, + error_traceback: str, ) -> bool: ... @@ -88,8 +87,8 @@ class OnAfterRunSession(Protocol): class OnNonFatalProcessorError(Protocol): def __call__( self, - exc_type: type, - exc_value: BaseException, - exc_traceback: TracebackType, - queue_item: Optional[SessionQueueItem] = None, + queue_item: Optional[SessionQueueItem], + error_type: str, + error_message: str, + error_traceback: str, ) -> bool: ... diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index cddb7cdc03..49277a105d 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -2,7 +2,6 @@ import traceback from contextlib import suppress from threading import BoundedSemaphore, Thread from threading import Event as ThreadEvent -from types import TracebackType from typing import Optional from fastapi_events.handlers.local import local_handler @@ -30,12 +29,6 @@ from .session_processor_base import InvocationServices, SessionProcessorBase, Se from .session_processor_common import SessionProcessorStatus -def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str: - """Formats a stacktrace as a string""" - - return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) - - class DefaultSessionRunner(SessionRunnerBase): """Processes a single session's invocations""" @@ -71,10 +64,16 @@ class DefaultSessionRunner(SessionRunnerBase): invocation = queue_item.session.next() # Anything other than a `NodeInputError` is handled as a processor error except NodeInputError as e: - # Must extract the exception traceback here to not lose its stacktrace when we change scope - traceback = e.__traceback__ - assert traceback is not None - self._on_node_error(e.node, queue_item, type(e), e, traceback) + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._on_node_error( + invocation=e.node, + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) break if invocation is None or self._cancel_event.is_set(): @@ -126,10 +125,16 @@ class DefaultSessionRunner(SessionRunnerBase): # loop go to its next iteration, and the cancel event will be handled correctly. pass except Exception as e: - # Must extract the exception traceback here to not lose its stacktrace when we change scope - exc_traceback = e.__traceback__ - assert exc_traceback is not None - self._on_node_error(invocation, queue_item, type(e), e, exc_traceback) + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._on_node_error( + invocation=invocation, + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: # If profiling is enabled, start the profiler @@ -166,7 +171,7 @@ class DefaultSessionRunner(SessionRunnerBase): self._services.performance_statistics.reset_stats() for callback in self._on_after_run_session_callbacks: - callback(queue_item) + callback(queue_item=queue_item) def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): """Run before a node is executed""" @@ -181,7 +186,7 @@ class DefaultSessionRunner(SessionRunnerBase): ) for callback in self._on_before_run_node_callbacks: - callback(invocation, queue_item) + callback(invocation=invocation, queue_item=queue_item) def _on_after_run_node( self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput @@ -199,23 +204,23 @@ class DefaultSessionRunner(SessionRunnerBase): ) for callback in self._on_after_run_node_callbacks: - callback(invocation, queue_item, output) + callback(invocation=invocation, queue_item=queue_item, output=output) def _on_node_error( self, invocation: BaseInvocation, queue_item: SessionQueueItem, - exc_type: type, - exc_value: BaseException, - exc_traceback: TracebackType, + error_type: str, + error_message: str, + error_traceback: str, ): - stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback) - - queue_item.session.set_node_error(invocation.id, stacktrace) + # Node errors do not get the full traceback. Only the queue item gets the full traceback. + node_error = f"{error_type}: {error_message}" + queue_item.session.set_node_error(invocation.id, node_error) self._services.logger.error( - f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {exc_type.__name__}" + f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}): {error_message}" ) - self._services.logger.error(stacktrace) + self._services.logger.error(error_traceback) # Send error event self._services.events.emit_invocation_error( @@ -225,14 +230,21 @@ class DefaultSessionRunner(SessionRunnerBase): graph_execution_state_id=queue_item.session.id, node=invocation.model_dump(), source_node_id=queue_item.session.prepared_source_mapping[invocation.id], - error_type=exc_type.__name__, - error=stacktrace, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, user_id=getattr(queue_item, "user_id", None), project_id=getattr(queue_item, "project_id", None), ) for callback in self._on_node_error_callbacks: - callback(invocation, queue_item, exc_type, exc_value, exc_traceback) + callback( + invocation=invocation, + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) class DefaultSessionProcessor(SessionProcessorBase): @@ -374,16 +386,25 @@ class DefaultSessionProcessor(SessionProcessorBase): self.session_runner.run(queue_item=self._queue_item) except Exception as e: - # Must extract the exception traceback here to not lose its stacktrace when we change scope - exc_traceback = e.__traceback__ - assert exc_traceback is not None - self._on_non_fatal_processor_error(self._queue_item, type(e), e, exc_traceback) - # Immediately poll for next queue item + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._on_non_fatal_processor_error( + queue_item=self._queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) + # Wait for next polling interval or event to try again poll_now_event.wait(self._polling_interval) continue - except Exception: + except Exception as e: # Fatal error in processor, log and pass - we're done here - self._invoker.services.logger.error(f"Fatal Error in session processor:\n{traceback.format_exc()}") + error_type = e.__class__.__name__ + error_message = str(e) + error_traceback = traceback.format_exc() + self._invoker.services.logger.error(f"Fatal Error in session processor {error_type}: {error_message}") + self._invoker.services.logger.error(error_traceback) pass finally: stop_event.clear() @@ -394,19 +415,29 @@ class DefaultSessionProcessor(SessionProcessorBase): def _on_non_fatal_processor_error( self, queue_item: Optional[SessionQueueItem], - exc_type: type, - exc_value: BaseException, - exc_traceback: TracebackType, + error_type: str, + error_message: str, + error_traceback: str, ) -> None: - stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback) # Non-fatal error in processor - self._invoker.services.logger.error(f"Non-fatal error in session processor: {exc_type.__name__}") - self._invoker.services.logger.error(stacktrace) + self._invoker.services.logger.error(f"Non-fatal error in session processor {error_type}: {error_message}") + self._invoker.services.logger.error(error_traceback) + if queue_item is not None: # Update the queue item with the completed session self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session) - # And cancel the queue item with an error - self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace) + # Fail the queue item + self._invoker.services.session_queue.fail_queue_item( + item_id=queue_item.item_id, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + ) for callback in self._on_non_fatal_processor_error_callbacks: - callback(exc_type, exc_value, exc_traceback, queue_item) + callback( + queue_item=queue_item, + error_type=error_type, + error_message=error_message, + error_traceback=error_traceback, + )