Improve EventListener and TraceCollectionListener for improved event… (#4160)

* Refactor EventListener and TraceCollectionListener for improved event handling

- Removed unused threading and method branches from EventListener to simplify the code.
- Updated event handling methods in EventListener to use new formatter methods for better clarity and consistency.
- Refactored TraceCollectionListener to eliminate unnecessary parameters in formatter calls, enhancing readability.
- Simplified ConsoleFormatter by removing outdated tree management methods and focusing on panel-based output for status updates.
- Enhanced ToolUsage to track run attempts for better tool usage metrics.

* clearer for knowledge retrieval and dropped some reduancies

* Refactor EventListener and ConsoleFormatter for improved clarity and consistency

- Removed the MCPToolExecutionCompletedEvent handler from EventListener to streamline event processing.
- Updated ConsoleFormatter to enhance output formatting by adding line breaks for better readability in status content.
- Renamed status messages for MCP Tool execution to provide clearer context during tool operations.

* fix run attempt incrementation

* task name consistency

* memory events consistency

* ensure hitl works

* linting
This commit is contained in:
Lorenze Jay
2025-12-30 11:36:31 -08:00
committed by GitHub
parent b9dd166a6b
commit 467ee2917e
8 changed files with 651 additions and 1987 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import time import time
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from crewai.agents.parser import AgentFinish
from crewai.events.event_listener import event_listener from crewai.events.event_listener import event_listener
from crewai.memory.entity.entity_memory_item import EntityMemoryItem from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
@@ -29,7 +30,7 @@ class CrewAgentExecutorMixin:
_i18n: I18N _i18n: I18N
_printer: Printer = Printer() _printer: Printer = Printer()
def _create_short_term_memory(self, output) -> None: def _create_short_term_memory(self, output: AgentFinish) -> None:
"""Create and save a short-term memory item if conditions are met.""" """Create and save a short-term memory item if conditions are met."""
if ( if (
self.crew self.crew
@@ -53,7 +54,7 @@ class CrewAgentExecutorMixin:
"error", f"Failed to add to short term memory: {e}" "error", f"Failed to add to short term memory: {e}"
) )
def _create_external_memory(self, output) -> None: def _create_external_memory(self, output: AgentFinish) -> None:
"""Create and save a external-term memory item if conditions are met.""" """Create and save a external-term memory item if conditions are met."""
if ( if (
self.crew self.crew
@@ -75,7 +76,7 @@ class CrewAgentExecutorMixin:
"error", f"Failed to add to external memory: {e}" "error", f"Failed to add to external memory: {e}"
) )
def _create_long_term_memory(self, output) -> None: def _create_long_term_memory(self, output: AgentFinish) -> None:
"""Create and save long-term and entity memory items based on evaluation.""" """Create and save long-term and entity memory items based on evaluation."""
if ( if (
self.crew self.crew
@@ -136,40 +137,50 @@ class CrewAgentExecutorMixin:
) )
def _ask_human_input(self, final_answer: str) -> str: def _ask_human_input(self, final_answer: str) -> str:
"""Prompt human input with mode-appropriate messaging.""" """Prompt human input with mode-appropriate messaging.
event_listener.formatter.pause_live_updates()
try:
self._printer.print(
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
)
Note: The final answer is already displayed via the AgentLogsExecutionEvent
panel, so we only show the feedback prompt here.
"""
from rich.panel import Panel
from rich.text import Text
formatter = event_listener.formatter
formatter.pause_live_updates()
try:
# Training mode prompt (single iteration) # Training mode prompt (single iteration)
if self.crew and getattr(self.crew, "_train", False): if self.crew and getattr(self.crew, "_train", False):
prompt = ( prompt_text = (
"\n\n=====\n" "TRAINING MODE: Provide feedback to improve the agent's performance.\n\n"
"## TRAINING MODE: Provide feedback to improve the agent's performance.\n"
"This will be used to train better versions of the agent.\n" "This will be used to train better versions of the agent.\n"
"Please provide detailed feedback about the result quality and reasoning process.\n" "Please provide detailed feedback about the result quality and reasoning process."
"=====\n"
) )
title = "🎓 Training Feedback Required"
# Regular human-in-the-loop prompt (multiple iterations) # Regular human-in-the-loop prompt (multiple iterations)
else: else:
prompt = ( prompt_text = (
"\n\n=====\n" "Provide feedback on the Final Result above.\n\n"
"## HUMAN FEEDBACK: Provide feedback on the Final Result and Agent's actions.\n" "• If you are happy with the result, simply hit Enter without typing anything.\n"
"Please follow these guidelines:\n" "• Otherwise, provide specific improvement requests.\n"
" - If you are happy with the result, simply hit Enter without typing anything.\n" "• You can provide multiple rounds of feedback until satisfied."
" - Otherwise, provide specific improvement requests.\n"
" - You can provide multiple rounds of feedback until satisfied.\n"
"=====\n"
) )
title = "💬 Human Feedback Required"
content = Text()
content.append(prompt_text, style="yellow")
prompt_panel = Panel(
content,
title=title,
border_style="yellow",
padding=(1, 2),
)
formatter.console.print(prompt_panel)
self._printer.print(content=prompt, color="bold_yellow")
response = input() response = input()
if response.strip() != "": if response.strip() != "":
self._printer.print( formatter.console.print("\n[cyan]Processing your feedback...[/cyan]")
content="\nProcessing your feedback...", color="cyan"
)
return response return response
finally: finally:
event_listener.formatter.resume_live_updates() formatter.resume_live_updates()

View File

@@ -7,6 +7,7 @@ and memory management.
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable
import logging
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler from pydantic import BaseModel, GetCoreSchemaHandler
@@ -51,6 +52,8 @@ from crewai.utilities.tool_utils import (
from crewai.utilities.training_handler import CrewTrainingHandler from crewai.utilities.training_handler import CrewTrainingHandler
logger = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from crewai.agent import Agent from crewai.agent import Agent
from crewai.agents.tools_handler import ToolsHandler from crewai.agents.tools_handler import ToolsHandler
@@ -541,7 +544,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if self.agent is None: if self.agent is None:
raise ValueError("Agent cannot be None") raise ValueError("Agent cannot be None")
crewai_event_bus.emit( future = crewai_event_bus.emit(
self.agent, self.agent,
AgentLogsExecutionEvent( AgentLogsExecutionEvent(
agent_role=self.agent.role, agent_role=self.agent.role,
@@ -551,6 +554,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
), ),
) )
if future is not None:
try:
future.result(timeout=5.0)
except Exception as e:
logger.error(f"Failed to show logs for agent execution event: {e}")
def _handle_crew_training_output( def _handle_crew_training_output(
self, result: AgentFinish, human_feedback: str | None = None self, result: AgentFinish, human_feedback: str | None = None
) -> None: ) -> None:

View File

@@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
from io import StringIO from io import StringIO
import threading
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from pydantic import Field, PrivateAttr from pydantic import Field, PrivateAttr
@@ -17,8 +16,6 @@ from crewai.events.types.a2a_events import (
A2AResponseReceivedEvent, A2AResponseReceivedEvent,
) )
from crewai.events.types.agent_events import ( from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent, LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent, LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent, LiteAgentExecutionStartedEvent,
@@ -48,7 +45,6 @@ from crewai.events.types.flow_events import (
from crewai.events.types.knowledge_events import ( from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent, KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent, KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent, KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent, KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent, KnowledgeSearchQueryFailedEvent,
@@ -112,7 +108,6 @@ class EventListener(BaseEventListener):
text_stream: StringIO = StringIO() text_stream: StringIO = StringIO()
knowledge_retrieval_in_progress: bool = False knowledge_retrieval_in_progress: bool = False
knowledge_query_in_progress: bool = False knowledge_query_in_progress: bool = False
method_branches: dict[str, Any] = Field(default_factory=dict)
def __new__(cls) -> EventListener: def __new__(cls) -> EventListener:
if cls._instance is None: if cls._instance is None:
@@ -126,10 +121,8 @@ class EventListener(BaseEventListener):
self._telemetry = Telemetry() self._telemetry = Telemetry()
self._telemetry.set_tracer() self._telemetry.set_tracer()
self.execution_spans = {} self.execution_spans = {}
self.method_branches = {}
self._initialized = True self._initialized = True
self.formatter = ConsoleFormatter(verbose=True) self.formatter = ConsoleFormatter(verbose=True)
self._crew_tree_lock = threading.Condition()
# Initialize trace listener with formatter for memory event handling # Initialize trace listener with formatter for memory event handling
trace_listener = TraceCollectionListener() trace_listener = TraceCollectionListener()
@@ -140,12 +133,10 @@ class EventListener(BaseEventListener):
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None: def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
@crewai_event_bus.on(CrewKickoffStartedEvent) @crewai_event_bus.on(CrewKickoffStartedEvent)
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None: def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
with self._crew_tree_lock: self.formatter.handle_crew_started(event.crew_name or "Crew", source.id)
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id) source._execution_span = self._telemetry.crew_execution_span(
source._execution_span = self._telemetry.crew_execution_span( source, event.inputs
source, event.inputs )
)
self._crew_tree_lock.notify_all()
@crewai_event_bus.on(CrewKickoffCompletedEvent) @crewai_event_bus.on(CrewKickoffCompletedEvent)
def on_crew_completed(source: Any, event: CrewKickoffCompletedEvent) -> None: def on_crew_completed(source: Any, event: CrewKickoffCompletedEvent) -> None:
@@ -153,8 +144,7 @@ class EventListener(BaseEventListener):
final_string_output = event.output.raw final_string_output = event.output.raw
self._telemetry.end_crew(source, final_string_output) self._telemetry.end_crew(source, final_string_output)
self.formatter.update_crew_tree( self.formatter.handle_crew_status(
self.formatter.current_crew_tree,
event.crew_name or "Crew", event.crew_name or "Crew",
source.id, source.id,
"completed", "completed",
@@ -163,8 +153,7 @@ class EventListener(BaseEventListener):
@crewai_event_bus.on(CrewKickoffFailedEvent) @crewai_event_bus.on(CrewKickoffFailedEvent)
def on_crew_failed(source: Any, event: CrewKickoffFailedEvent) -> None: def on_crew_failed(source: Any, event: CrewKickoffFailedEvent) -> None:
self.formatter.update_crew_tree( self.formatter.handle_crew_status(
self.formatter.current_crew_tree,
event.crew_name or "Crew", event.crew_name or "Crew",
source.id, source.id,
"failed", "failed",
@@ -197,23 +186,22 @@ class EventListener(BaseEventListener):
# ----------- TASK EVENTS ----------- # ----------- TASK EVENTS -----------
def get_task_name(source: Any) -> str | None:
return (
source.name
if hasattr(source, "name") and source.name
else source.description
if hasattr(source, "description") and source.description
else None
)
@crewai_event_bus.on(TaskStartedEvent) @crewai_event_bus.on(TaskStartedEvent)
def on_task_started(source: Any, event: TaskStartedEvent) -> None: def on_task_started(source: Any, event: TaskStartedEvent) -> None:
span = self._telemetry.task_started(crew=source.agent.crew, task=source) span = self._telemetry.task_started(crew=source.agent.crew, task=source)
self.execution_spans[source] = span self.execution_spans[source] = span
with self._crew_tree_lock: task_name = get_task_name(source)
self._crew_tree_lock.wait_for( self.formatter.handle_task_started(source.id, task_name)
lambda: self.formatter.current_crew_tree is not None, timeout=5.0
)
if self.formatter.current_crew_tree is not None:
task_name = (
source.name if hasattr(source, "name") and source.name else None
)
self.formatter.create_task_branch(
self.formatter.current_crew_tree, source.id, task_name
)
@crewai_event_bus.on(TaskCompletedEvent) @crewai_event_bus.on(TaskCompletedEvent)
def on_task_completed(source: Any, event: TaskCompletedEvent) -> None: def on_task_completed(source: Any, event: TaskCompletedEvent) -> None:
@@ -224,13 +212,9 @@ class EventListener(BaseEventListener):
self.execution_spans[source] = None self.execution_spans[source] = None
# Pass task name if it exists # Pass task name if it exists
task_name = source.name if hasattr(source, "name") and source.name else None task_name = get_task_name(source)
self.formatter.update_task_status( self.formatter.handle_task_status(
self.formatter.current_crew_tree, source.id, source.agent.role, "completed", task_name
source.id,
source.agent.role,
"completed",
task_name,
) )
@crewai_event_bus.on(TaskFailedEvent) @crewai_event_bus.on(TaskFailedEvent)
@@ -242,37 +226,12 @@ class EventListener(BaseEventListener):
self.execution_spans[source] = None self.execution_spans[source] = None
# Pass task name if it exists # Pass task name if it exists
task_name = source.name if hasattr(source, "name") and source.name else None task_name = get_task_name(source)
self.formatter.update_task_status( self.formatter.handle_task_status(
self.formatter.current_crew_tree, source.id, source.agent.role, "failed", task_name
source.id,
source.agent.role,
"failed",
task_name,
) )
# ----------- AGENT EVENTS ----------- # ----------- AGENT EVENTS -----------
@crewai_event_bus.on(AgentExecutionStartedEvent)
def on_agent_execution_started(
_: Any, event: AgentExecutionStartedEvent
) -> None:
self.formatter.create_agent_branch(
self.formatter.current_task_branch,
event.agent.role,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(AgentExecutionCompletedEvent)
def on_agent_execution_completed(
_: Any, event: AgentExecutionCompletedEvent
) -> None:
self.formatter.update_agent_status(
self.formatter.current_agent_branch,
event.agent.role,
self.formatter.current_crew_tree,
)
# ----------- LITE AGENT EVENTS ----------- # ----------- LITE AGENT EVENTS -----------
@crewai_event_bus.on(LiteAgentExecutionStartedEvent) @crewai_event_bus.on(LiteAgentExecutionStartedEvent)
@@ -316,79 +275,61 @@ class EventListener(BaseEventListener):
self._telemetry.flow_execution_span( self._telemetry.flow_execution_span(
event.flow_name, list(source._methods.keys()) event.flow_name, list(source._methods.keys())
) )
tree = self.formatter.create_flow_tree(event.flow_name, str(source.flow_id)) self.formatter.handle_flow_created(event.flow_name, str(source.flow_id))
self.formatter.current_flow_tree = tree self.formatter.handle_flow_started(event.flow_name, str(source.flow_id))
self.formatter.start_flow(event.flow_name, str(source.flow_id))
@crewai_event_bus.on(FlowFinishedEvent) @crewai_event_bus.on(FlowFinishedEvent)
def on_flow_finished(source: Any, event: FlowFinishedEvent) -> None: def on_flow_finished(source: Any, event: FlowFinishedEvent) -> None:
self.formatter.update_flow_status( self.formatter.handle_flow_status(
self.formatter.current_flow_tree, event.flow_name, source.flow_id event.flow_name,
source.flow_id,
) )
@crewai_event_bus.on(MethodExecutionStartedEvent) @crewai_event_bus.on(MethodExecutionStartedEvent)
def on_method_execution_started( def on_method_execution_started(
_: Any, event: MethodExecutionStartedEvent _: Any, event: MethodExecutionStartedEvent
) -> None: ) -> None:
method_branch = self.method_branches.get(event.method_name) self.formatter.handle_method_status(
updated_branch = self.formatter.update_method_status(
method_branch,
self.formatter.current_flow_tree,
event.method_name, event.method_name,
"running", "running",
) )
self.method_branches[event.method_name] = updated_branch
@crewai_event_bus.on(MethodExecutionFinishedEvent) @crewai_event_bus.on(MethodExecutionFinishedEvent)
def on_method_execution_finished( def on_method_execution_finished(
_: Any, event: MethodExecutionFinishedEvent _: Any, event: MethodExecutionFinishedEvent
) -> None: ) -> None:
method_branch = self.method_branches.get(event.method_name) self.formatter.handle_method_status(
updated_branch = self.formatter.update_method_status(
method_branch,
self.formatter.current_flow_tree,
event.method_name, event.method_name,
"completed", "completed",
) )
self.method_branches[event.method_name] = updated_branch
@crewai_event_bus.on(MethodExecutionFailedEvent) @crewai_event_bus.on(MethodExecutionFailedEvent)
def on_method_execution_failed( def on_method_execution_failed(
_: Any, event: MethodExecutionFailedEvent _: Any, event: MethodExecutionFailedEvent
) -> None: ) -> None:
method_branch = self.method_branches.get(event.method_name) self.formatter.handle_method_status(
updated_branch = self.formatter.update_method_status(
method_branch,
self.formatter.current_flow_tree,
event.method_name, event.method_name,
"failed", "failed",
) )
self.method_branches[event.method_name] = updated_branch
@crewai_event_bus.on(MethodExecutionPausedEvent) @crewai_event_bus.on(MethodExecutionPausedEvent)
def on_method_execution_paused( def on_method_execution_paused(
_: Any, event: MethodExecutionPausedEvent _: Any, event: MethodExecutionPausedEvent
) -> None: ) -> None:
method_branch = self.method_branches.get(event.method_name) self.formatter.handle_method_status(
updated_branch = self.formatter.update_method_status(
method_branch,
self.formatter.current_flow_tree,
event.method_name, event.method_name,
"paused", "paused",
) )
self.method_branches[event.method_name] = updated_branch
@crewai_event_bus.on(FlowPausedEvent) @crewai_event_bus.on(FlowPausedEvent)
def on_flow_paused(_: Any, event: FlowPausedEvent) -> None: def on_flow_paused(_: Any, event: FlowPausedEvent) -> None:
self.formatter.update_flow_status( self.formatter.handle_flow_status(
self.formatter.current_flow_tree,
event.flow_name, event.flow_name,
event.flow_id, event.flow_id,
"paused", "paused",
) )
# ----------- TOOL USAGE EVENTS ----------- # ----------- TOOL USAGE EVENTS -----------
@crewai_event_bus.on(ToolUsageStartedEvent) @crewai_event_bus.on(ToolUsageStartedEvent)
def on_tool_usage_started(source: Any, event: ToolUsageStartedEvent) -> None: def on_tool_usage_started(source: Any, event: ToolUsageStartedEvent) -> None:
if isinstance(source, LLM): if isinstance(source, LLM):
@@ -398,9 +339,9 @@ class EventListener(BaseEventListener):
) )
else: else:
self.formatter.handle_tool_usage_started( self.formatter.handle_tool_usage_started(
self.formatter.current_agent_branch,
event.tool_name, event.tool_name,
self.formatter.current_crew_tree, event.tool_args,
event.run_attempts,
) )
@crewai_event_bus.on(ToolUsageFinishedEvent) @crewai_event_bus.on(ToolUsageFinishedEvent)
@@ -409,12 +350,6 @@ class EventListener(BaseEventListener):
self.formatter.handle_llm_tool_usage_finished( self.formatter.handle_llm_tool_usage_finished(
event.tool_name, event.tool_name,
) )
else:
self.formatter.handle_tool_usage_finished(
self.formatter.current_tool_branch,
event.tool_name,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(ToolUsageErrorEvent) @crewai_event_bus.on(ToolUsageErrorEvent)
def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None: def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None:
@@ -425,10 +360,9 @@ class EventListener(BaseEventListener):
) )
else: else:
self.formatter.handle_tool_usage_error( self.formatter.handle_tool_usage_error(
self.formatter.current_tool_branch,
event.tool_name, event.tool_name,
event.error, event.error,
self.formatter.current_crew_tree, event.run_attempts,
) )
# ----------- LLM EVENTS ----------- # ----------- LLM EVENTS -----------
@@ -437,32 +371,15 @@ class EventListener(BaseEventListener):
def on_llm_call_started(_: Any, event: LLMCallStartedEvent) -> None: def on_llm_call_started(_: Any, event: LLMCallStartedEvent) -> None:
self.text_stream = StringIO() self.text_stream = StringIO()
self.next_chunk = 0 self.next_chunk = 0
# Capture the returned tool branch and update the current_tool_branch reference
thinking_branch = self.formatter.handle_llm_call_started(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
# Update the formatter's current_tool_branch to ensure proper cleanup
if thinking_branch is not None:
self.formatter.current_tool_branch = thinking_branch
@crewai_event_bus.on(LLMCallCompletedEvent) @crewai_event_bus.on(LLMCallCompletedEvent)
def on_llm_call_completed(_: Any, event: LLMCallCompletedEvent) -> None: def on_llm_call_completed(_: Any, event: LLMCallCompletedEvent) -> None:
self.formatter.handle_llm_stream_completed() self.formatter.handle_llm_stream_completed()
self.formatter.handle_llm_call_completed(
self.formatter.current_tool_branch,
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(LLMCallFailedEvent) @crewai_event_bus.on(LLMCallFailedEvent)
def on_llm_call_failed(_: Any, event: LLMCallFailedEvent) -> None: def on_llm_call_failed(_: Any, event: LLMCallFailedEvent) -> None:
self.formatter.handle_llm_stream_completed() self.formatter.handle_llm_stream_completed()
self.formatter.handle_llm_call_failed( self.formatter.handle_llm_call_failed(event.error)
self.formatter.current_tool_branch,
event.error,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(LLMStreamChunkEvent) @crewai_event_bus.on(LLMStreamChunkEvent)
def on_llm_stream_chunk(_: Any, event: LLMStreamChunkEvent) -> None: def on_llm_stream_chunk(_: Any, event: LLMStreamChunkEvent) -> None:
@@ -473,9 +390,7 @@ class EventListener(BaseEventListener):
accumulated_text = self.text_stream.getvalue() accumulated_text = self.text_stream.getvalue()
self.formatter.handle_llm_stream_chunk( self.formatter.handle_llm_stream_chunk(
event.chunk,
accumulated_text, accumulated_text,
self.formatter.current_crew_tree,
event.call_type, event.call_type,
) )
@@ -515,7 +430,6 @@ class EventListener(BaseEventListener):
@crewai_event_bus.on(CrewTestCompletedEvent) @crewai_event_bus.on(CrewTestCompletedEvent)
def on_crew_test_completed(_: Any, event: CrewTestCompletedEvent) -> None: def on_crew_test_completed(_: Any, event: CrewTestCompletedEvent) -> None:
self.formatter.handle_crew_test_completed( self.formatter.handle_crew_test_completed(
self.formatter.current_flow_tree,
event.crew_name or "Crew", event.crew_name or "Crew",
) )
@@ -532,10 +446,7 @@ class EventListener(BaseEventListener):
self.knowledge_retrieval_in_progress = True self.knowledge_retrieval_in_progress = True
self.formatter.handle_knowledge_retrieval_started( self.formatter.handle_knowledge_retrieval_started()
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(KnowledgeRetrievalCompletedEvent) @crewai_event_bus.on(KnowledgeRetrievalCompletedEvent)
def on_knowledge_retrieval_completed( def on_knowledge_retrieval_completed(
@@ -546,24 +457,13 @@ class EventListener(BaseEventListener):
self.knowledge_retrieval_in_progress = False self.knowledge_retrieval_in_progress = False
self.formatter.handle_knowledge_retrieval_completed( self.formatter.handle_knowledge_retrieval_completed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.retrieved_knowledge, event.retrieved_knowledge,
event.query,
) )
@crewai_event_bus.on(KnowledgeQueryStartedEvent)
def on_knowledge_query_started(
_: Any, event: KnowledgeQueryStartedEvent
) -> None:
pass
@crewai_event_bus.on(KnowledgeQueryFailedEvent) @crewai_event_bus.on(KnowledgeQueryFailedEvent)
def on_knowledge_query_failed(_: Any, event: KnowledgeQueryFailedEvent) -> None: def on_knowledge_query_failed(_: Any, event: KnowledgeQueryFailedEvent) -> None:
self.formatter.handle_knowledge_query_failed( self.formatter.handle_knowledge_query_failed(event.error)
self.formatter.current_agent_branch,
event.error,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(KnowledgeQueryCompletedEvent) @crewai_event_bus.on(KnowledgeQueryCompletedEvent)
def on_knowledge_query_completed( def on_knowledge_query_completed(
@@ -575,11 +475,7 @@ class EventListener(BaseEventListener):
def on_knowledge_search_query_failed( def on_knowledge_search_query_failed(
_: Any, event: KnowledgeSearchQueryFailedEvent _: Any, event: KnowledgeSearchQueryFailedEvent
) -> None: ) -> None:
self.formatter.handle_knowledge_search_query_failed( self.formatter.handle_knowledge_search_query_failed(event.error)
self.formatter.current_agent_branch,
event.error,
self.formatter.current_crew_tree,
)
# ----------- REASONING EVENTS ----------- # ----------- REASONING EVENTS -----------
@@ -587,11 +483,7 @@ class EventListener(BaseEventListener):
def on_agent_reasoning_started( def on_agent_reasoning_started(
_: Any, event: AgentReasoningStartedEvent _: Any, event: AgentReasoningStartedEvent
) -> None: ) -> None:
self.formatter.handle_reasoning_started( self.formatter.handle_reasoning_started(event.attempt)
self.formatter.current_agent_branch,
event.attempt,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(AgentReasoningCompletedEvent) @crewai_event_bus.on(AgentReasoningCompletedEvent)
def on_agent_reasoning_completed( def on_agent_reasoning_completed(
@@ -600,14 +492,12 @@ class EventListener(BaseEventListener):
self.formatter.handle_reasoning_completed( self.formatter.handle_reasoning_completed(
event.plan, event.plan,
event.ready, event.ready,
self.formatter.current_crew_tree,
) )
@crewai_event_bus.on(AgentReasoningFailedEvent) @crewai_event_bus.on(AgentReasoningFailedEvent)
def on_agent_reasoning_failed(_: Any, event: AgentReasoningFailedEvent) -> None: def on_agent_reasoning_failed(_: Any, event: AgentReasoningFailedEvent) -> None:
self.formatter.handle_reasoning_failed( self.formatter.handle_reasoning_failed(
event.error, event.error,
self.formatter.current_crew_tree,
) )
# ----------- AGENT LOGGING EVENTS ----------- # ----------- AGENT LOGGING EVENTS -----------
@@ -734,18 +624,6 @@ class EventListener(BaseEventListener):
event.tool_args, event.tool_args,
) )
@crewai_event_bus.on(MCPToolExecutionCompletedEvent)
def on_mcp_tool_execution_completed(
_: Any, event: MCPToolExecutionCompletedEvent
) -> None:
self.formatter.handle_mcp_tool_execution_completed(
event.server_name,
event.tool_name,
event.tool_args,
event.result,
event.execution_duration_ms,
)
@crewai_event_bus.on(MCPToolExecutionFailedEvent) @crewai_event_bus.on(MCPToolExecutionFailedEvent)
def on_mcp_tool_execution_failed( def on_mcp_tool_execution_failed(
_: Any, event: MCPToolExecutionFailedEvent _: Any, event: MCPToolExecutionFailedEvent

View File

@@ -1,7 +1,7 @@
"""Trace collection listener for orchestrating trace collection.""" """Trace collection listener for orchestrating trace collection."""
import os import os
from typing import Any, ClassVar from typing import Any, ClassVar, cast
import uuid import uuid
from typing_extensions import Self from typing_extensions import Self
@@ -105,7 +105,7 @@ class TraceCollectionListener(BaseEventListener):
"""Create or return singleton instance.""" """Create or return singleton instance."""
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
return cls._instance return cast(Self, cls._instance)
def __init__( def __init__(
self, self,
@@ -319,21 +319,12 @@ class TraceCollectionListener(BaseEventListener):
source: Any, event: MemoryQueryCompletedEvent source: Any, event: MemoryQueryCompletedEvent
) -> None: ) -> None:
self._handle_action_event("memory_query_completed", source, event) self._handle_action_event("memory_query_completed", source, event)
if self.formatter and self.memory_retrieval_in_progress:
self.formatter.handle_memory_query_completed(
self.formatter.current_agent_branch,
event.source_type or "memory",
event.query_time_ms,
self.formatter.current_crew_tree,
)
@event_bus.on(MemoryQueryFailedEvent) @event_bus.on(MemoryQueryFailedEvent)
def on_memory_query_failed(source: Any, event: MemoryQueryFailedEvent) -> None: def on_memory_query_failed(source: Any, event: MemoryQueryFailedEvent) -> None:
self._handle_action_event("memory_query_failed", source, event) self._handle_action_event("memory_query_failed", source, event)
if self.formatter and self.memory_retrieval_in_progress: if self.formatter and self.memory_retrieval_in_progress:
self.formatter.handle_memory_query_failed( self.formatter.handle_memory_query_failed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.error, event.error,
event.source_type or "memory", event.source_type or "memory",
) )
@@ -347,10 +338,7 @@ class TraceCollectionListener(BaseEventListener):
self.memory_save_in_progress = True self.memory_save_in_progress = True
self.formatter.handle_memory_save_started( self.formatter.handle_memory_save_started()
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
@event_bus.on(MemorySaveCompletedEvent) @event_bus.on(MemorySaveCompletedEvent)
def on_memory_save_completed( def on_memory_save_completed(
@@ -364,8 +352,6 @@ class TraceCollectionListener(BaseEventListener):
self.memory_save_in_progress = False self.memory_save_in_progress = False
self.formatter.handle_memory_save_completed( self.formatter.handle_memory_save_completed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.save_time_ms, event.save_time_ms,
event.source_type or "memory", event.source_type or "memory",
) )
@@ -375,10 +361,8 @@ class TraceCollectionListener(BaseEventListener):
self._handle_action_event("memory_save_failed", source, event) self._handle_action_event("memory_save_failed", source, event)
if self.formatter and self.memory_save_in_progress: if self.formatter and self.memory_save_in_progress:
self.formatter.handle_memory_save_failed( self.formatter.handle_memory_save_failed(
self.formatter.current_agent_branch,
event.error, event.error,
event.source_type or "memory", event.source_type or "memory",
self.formatter.current_crew_tree,
) )
@event_bus.on(MemoryRetrievalStartedEvent) @event_bus.on(MemoryRetrievalStartedEvent)
@@ -391,10 +375,7 @@ class TraceCollectionListener(BaseEventListener):
self.memory_retrieval_in_progress = True self.memory_retrieval_in_progress = True
self.formatter.handle_memory_retrieval_started( self.formatter.handle_memory_retrieval_started()
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
@event_bus.on(MemoryRetrievalCompletedEvent) @event_bus.on(MemoryRetrievalCompletedEvent)
def on_memory_retrieval_completed( def on_memory_retrieval_completed(
@@ -406,8 +387,6 @@ class TraceCollectionListener(BaseEventListener):
self.memory_retrieval_in_progress = False self.memory_retrieval_in_progress = False
self.formatter.handle_memory_retrieval_completed( self.formatter.handle_memory_retrieval_completed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.memory_content, event.memory_content,
event.retrieval_time_ms, event.retrieval_time_ms,
) )

File diff suppressed because it is too large Load Diff

View File

@@ -249,6 +249,7 @@ class ToolUsage:
"tool_args": self.action.tool_input, "tool_args": self.action.tool_input,
"tool_class": self.action.tool, "tool_class": self.action.tool,
"agent": self.agent, "agent": self.agent,
"run_attempts": self._run_attempts,
} }
if self.agent.fingerprint: # type: ignore if self.agent.fingerprint: # type: ignore
@@ -435,6 +436,7 @@ class ToolUsage:
"tool_args": self.action.tool_input, "tool_args": self.action.tool_input,
"tool_class": self.action.tool, "tool_class": self.action.tool,
"agent": self.agent, "agent": self.agent,
"run_attempts": self._run_attempts,
} }
# TODO: Investigate fingerprint attribute availability on BaseAgent/LiteAgent # TODO: Investigate fingerprint attribute availability on BaseAgent/LiteAgent

View File

@@ -7,22 +7,19 @@ from crewai.events.event_listener import event_listener
class TestFlowHumanInputIntegration: class TestFlowHumanInputIntegration:
"""Test integration between Flow execution and human input functionality.""" """Test integration between Flow execution and human input functionality."""
def test_console_formatter_pause_resume_methods(self): def test_console_formatter_pause_resume_methods_exist(self):
"""Test that ConsoleFormatter pause/resume methods work correctly.""" """Test that ConsoleFormatter pause/resume methods exist and are callable."""
formatter = event_listener.formatter formatter = event_listener.formatter
original_paused_state = formatter._live_paused # Methods should exist and be callable
assert hasattr(formatter, "pause_live_updates")
assert hasattr(formatter, "resume_live_updates")
assert callable(formatter.pause_live_updates)
assert callable(formatter.resume_live_updates)
try: # Should not raise
formatter._live_paused = False formatter.pause_live_updates()
formatter.resume_live_updates()
formatter.pause_live_updates()
assert formatter._live_paused
formatter.resume_live_updates()
assert not formatter._live_paused
finally:
formatter._live_paused = original_paused_state
@patch("builtins.input", return_value="") @patch("builtins.input", return_value="")
def test_human_input_pauses_flow_updates(self, mock_input): def test_human_input_pauses_flow_updates(self, mock_input):
@@ -38,23 +35,16 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter formatter = event_listener.formatter
original_paused_state = formatter._live_paused with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
):
result = executor._ask_human_input("Test result")
try: mock_pause.assert_called_once()
formatter._live_paused = False mock_resume.assert_called_once()
mock_input.assert_called_once()
with ( assert result == ""
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
):
result = executor._ask_human_input("Test result")
mock_pause.assert_called_once()
mock_resume.assert_called_once()
mock_input.assert_called_once()
assert result == ""
finally:
formatter._live_paused = original_paused_state
@patch("builtins.input", side_effect=["feedback", ""]) @patch("builtins.input", side_effect=["feedback", ""])
def test_multiple_human_input_rounds(self, mock_input): def test_multiple_human_input_rounds(self, mock_input):
@@ -70,53 +60,46 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter formatter = event_listener.formatter
original_paused_state = formatter._live_paused pause_calls = []
resume_calls = []
try: def track_pause():
pause_calls = [] pause_calls.append(True)
resume_calls = []
def track_pause(): def track_resume():
pause_calls.append(True) resume_calls.append(True)
def track_resume(): with (
resume_calls.append(True) patch.object(formatter, "pause_live_updates", side_effect=track_pause),
patch.object(
formatter, "resume_live_updates", side_effect=track_resume
),
):
result1 = executor._ask_human_input("Test result 1")
assert result1 == "feedback"
with ( result2 = executor._ask_human_input("Test result 2")
patch.object(formatter, "pause_live_updates", side_effect=track_pause), assert result2 == ""
patch.object(
formatter, "resume_live_updates", side_effect=track_resume
),
):
result1 = executor._ask_human_input("Test result 1")
assert result1 == "feedback"
result2 = executor._ask_human_input("Test result 2") assert len(pause_calls) == 2
assert result2 == "" assert len(resume_calls) == 2
assert len(pause_calls) == 2
assert len(resume_calls) == 2
finally:
formatter._live_paused = original_paused_state
def test_pause_resume_with_no_live_session(self): def test_pause_resume_with_no_live_session(self):
"""Test pause/resume methods handle case when no Live session exists.""" """Test pause/resume methods handle case when no Live session exists."""
formatter = event_listener.formatter formatter = event_listener.formatter
original_live = formatter._live original_streaming_live = formatter._streaming_live
original_paused_state = formatter._live_paused
try: try:
formatter._live = None formatter._streaming_live = None
formatter._live_paused = False
# Should not raise when no session exists
formatter.pause_live_updates() formatter.pause_live_updates()
formatter.resume_live_updates() formatter.resume_live_updates()
assert not formatter._live_paused assert formatter._streaming_live is None
finally: finally:
formatter._live = original_live formatter._streaming_live = original_streaming_live
formatter._live_paused = original_paused_state
def test_pause_resume_exception_handling(self): def test_pause_resume_exception_handling(self):
"""Test that resume is called even if exception occurs during human input.""" """Test that resume is called even if exception occurs during human input."""
@@ -131,23 +114,18 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter formatter = event_listener.formatter
original_paused_state = formatter._live_paused with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch(
"builtins.input", side_effect=KeyboardInterrupt("Test exception")
),
):
with pytest.raises(KeyboardInterrupt):
executor._ask_human_input("Test result")
try: mock_pause.assert_called_once()
with ( mock_resume.assert_called_once()
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch(
"builtins.input", side_effect=KeyboardInterrupt("Test exception")
),
):
with pytest.raises(KeyboardInterrupt):
executor._ask_human_input("Test result")
mock_pause.assert_called_once()
mock_resume.assert_called_once()
finally:
formatter._live_paused = original_paused_state
def test_training_mode_human_input(self): def test_training_mode_human_input(self):
"""Test human input in training mode.""" """Test human input in training mode."""
@@ -162,28 +140,25 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter formatter = event_listener.formatter
original_paused_state = formatter._live_paused with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch.object(formatter.console, "print") as mock_console_print,
patch("builtins.input", return_value="training feedback"),
):
result = executor._ask_human_input("Test result")
try: mock_pause.assert_called_once()
with ( mock_resume.assert_called_once()
patch.object(formatter, "pause_live_updates") as mock_pause, assert result == "training feedback"
patch.object(formatter, "resume_live_updates") as mock_resume,
patch("builtins.input", return_value="training feedback"),
):
result = executor._ask_human_input("Test result")
mock_pause.assert_called_once() # Verify the training panel was printed via formatter's console
mock_resume.assert_called_once() mock_console_print.assert_called()
assert result == "training feedback" # Check that a Panel with training title was printed
call_args = mock_console_print.call_args_list
executor._printer.print.assert_called() training_panel_found = any(
call_args = [ hasattr(call[0][0], "title") and "Training" in str(call[0][0].title)
call[1]["content"] for call in call_args
for call in executor._printer.print.call_args_list if call[0]
] )
training_prompt_found = any( assert training_panel_found
"TRAINING MODE" in content for content in call_args
)
assert training_prompt_found
finally:
formatter._live_paused = original_paused_state

View File

@@ -1,116 +1,107 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from rich.tree import Tree
from rich.live import Live from rich.live import Live
from crewai.events.utils.console_formatter import ConsoleFormatter from crewai.events.utils.console_formatter import ConsoleFormatter
class TestConsoleFormatterPauseResume: class TestConsoleFormatterPauseResume:
"""Test ConsoleFormatter pause/resume functionality.""" """Test ConsoleFormatter pause/resume functionality for HITL features."""
def test_pause_live_updates_with_active_session(self): def test_pause_stops_active_streaming_session(self):
"""Test pausing when Live session is active.""" """Test pausing stops an active streaming Live session."""
formatter = ConsoleFormatter() formatter = ConsoleFormatter()
mock_live = MagicMock(spec=Live) mock_live = MagicMock(spec=Live)
formatter._live = mock_live formatter._streaming_live = mock_live
formatter._live_paused = False
formatter.pause_live_updates() formatter.pause_live_updates()
mock_live.stop.assert_called_once() mock_live.stop.assert_called_once()
assert formatter._live_paused assert formatter._streaming_live is None
def test_pause_live_updates_when_already_paused(self): def test_pause_is_safe_when_no_session(self):
"""Test pausing when already paused does nothing.""" """Test pausing when no streaming session exists doesn't error."""
formatter = ConsoleFormatter()
formatter._streaming_live = None
# Should not raise
formatter.pause_live_updates()
assert formatter._streaming_live is None
def test_multiple_pauses_are_safe(self):
"""Test calling pause multiple times is safe."""
formatter = ConsoleFormatter() formatter = ConsoleFormatter()
mock_live = MagicMock(spec=Live) mock_live = MagicMock(spec=Live)
formatter._live = mock_live formatter._streaming_live = mock_live
formatter._live_paused = True
formatter.pause_live_updates() formatter.pause_live_updates()
mock_live.stop.assert_called_once()
assert formatter._streaming_live is None
mock_live.stop.assert_not_called() # Second pause should not error (no session to stop)
assert formatter._live_paused
def test_pause_live_updates_with_no_session(self):
"""Test pausing when no Live session exists."""
formatter = ConsoleFormatter()
formatter._live = None
formatter._live_paused = False
formatter.pause_live_updates() formatter.pause_live_updates()
assert formatter._live_paused def test_resume_is_safe(self):
"""Test resume method exists and doesn't error."""
def test_resume_live_updates_when_paused(self):
"""Test resuming when paused."""
formatter = ConsoleFormatter() formatter = ConsoleFormatter()
formatter._live_paused = True # Should not raise
formatter.resume_live_updates() formatter.resume_live_updates()
assert not formatter._live_paused def test_streaming_after_pause_resume_creates_new_session(self):
"""Test that streaming after pause/resume creates new Live session."""
def test_resume_live_updates_when_not_paused(self):
"""Test resuming when not paused does nothing."""
formatter = ConsoleFormatter() formatter = ConsoleFormatter()
formatter.verbose = True
formatter._live_paused = False # Simulate having an active session
mock_live = MagicMock(spec=Live)
formatter._streaming_live = mock_live
# Pause stops the session
formatter.pause_live_updates()
assert formatter._streaming_live is None
# Resume (no-op, sessions created on demand)
formatter.resume_live_updates() formatter.resume_live_updates()
assert not formatter._live_paused # After resume, streaming should be able to start a new session
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
mock_live_instance = MagicMock()
mock_live_class.return_value = mock_live_instance
def test_print_after_resume_restarts_live_session(self): # Simulate streaming chunk (this creates a new Live session)
"""Test that printing a Tree after resume creates new Live session.""" formatter.handle_llm_stream_chunk("test chunk", call_type=None)
mock_live_class.assert_called_once()
mock_live_instance.start.assert_called_once()
assert formatter._streaming_live == mock_live_instance
def test_pause_resume_cycle_with_streaming(self):
"""Test full pause/resume cycle during streaming."""
formatter = ConsoleFormatter() formatter = ConsoleFormatter()
formatter.verbose = True
formatter._live_paused = True
formatter._live = None
formatter.resume_live_updates()
assert not formatter._live_paused
tree = Tree("Test")
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class: with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
mock_live_instance = MagicMock() mock_live_instance = MagicMock()
mock_live_class.return_value = mock_live_instance mock_live_class.return_value = mock_live_instance
formatter.print(tree) # Start streaming
formatter.handle_llm_stream_chunk("chunk 1", call_type=None)
assert formatter._streaming_live == mock_live_instance
mock_live_class.assert_called_once() # Pause should stop the session
mock_live_instance.start.assert_called_once() formatter.pause_live_updates()
assert formatter._live == mock_live_instance mock_live_instance.stop.assert_called_once()
assert formatter._streaming_live is None
def test_multiple_pause_resume_cycles(self): # Resume (no-op)
"""Test multiple pause/resume cycles work correctly.""" formatter.resume_live_updates()
formatter = ConsoleFormatter()
mock_live = MagicMock(spec=Live) # Create a new mock for the next session
formatter._live = mock_live mock_live_instance_2 = MagicMock()
formatter._live_paused = False mock_live_class.return_value = mock_live_instance_2
formatter.pause_live_updates() # Streaming again creates new session
assert formatter._live_paused formatter.handle_llm_stream_chunk("chunk 2", call_type=None)
mock_live.stop.assert_called_once() assert formatter._streaming_live == mock_live_instance_2
assert formatter._live is None # Live session should be cleared
formatter.resume_live_updates()
assert not formatter._live_paused
formatter.pause_live_updates()
assert formatter._live_paused
formatter.resume_live_updates()
assert not formatter._live_paused
def test_pause_resume_state_initialization(self):
"""Test that _live_paused is properly initialized."""
formatter = ConsoleFormatter()
assert hasattr(formatter, "_live_paused")
assert not formatter._live_paused