Improve EventListener and TraceCollectionListener for improved event… (#4160)

* Refactor EventListener and TraceCollectionListener for improved event handling

- Removed unused threading and method branches from EventListener to simplify the code.
- Updated event handling methods in EventListener to use new formatter methods for better clarity and consistency (see the sketch after this list).
- Refactored TraceCollectionListener to eliminate unnecessary parameters in formatter calls, enhancing readability.
- Simplified ConsoleFormatter by removing outdated tree management methods and focusing on panel-based output for status updates.
- Enhanced ToolUsage to track run attempts for better tool usage metrics.
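
A minimal sketch of the handler-to-formatter delegation these notes describe: event handlers stop threading `current_crew_tree` and branch objects through every call and pass only event data, while the formatter owns its rendering state. Only the `handle_task_status` call shape is taken from this diff; `ConsoleFormatterSketch` and its plain printing are illustrative stand-ins, not crewai's actual `ConsoleFormatter`.

```python
from typing import Any


class ConsoleFormatterSketch:
    """Illustrative stand-in: the formatter owns its tree state internally."""

    def __init__(self) -> None:
        self._current_crew_tree: Any = None  # previously threaded through every caller

    def handle_task_status(
        self, task_id: Any, agent_role: str, status: str, task_name: str | None = None
    ) -> None:
        # The real formatter updates a rich tree/panel; here we just render a line.
        print(f"[{status}] task={task_name or task_id} agent={agent_role}")


formatter = ConsoleFormatterSketch()


def on_task_completed(source: Any, event: Any) -> None:
    # Mirrors the new handler shape: no tree or branch arguments, just event
    # data, with the task name falling back to its description.
    task_name = getattr(source, "name", None) or getattr(source, "description", None)
    formatter.handle_task_status(source.id, source.agent.role, "completed", task_name)
```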

* clearer knowledge retrieval output and dropped some redundancies

* Refactor EventListener and ConsoleFormatter for improved clarity and consistency

- Removed the MCPToolExecutionCompletedEvent handler from EventListener to streamline event processing.
- Updated ConsoleFormatter to enhance output formatting by adding line breaks for better readability in status content.
- Renamed status messages for MCP Tool execution to provide clearer context during tool operations.

* fix run attempt incrementation

* task name consistency

* memory events consistency

* ensure HITL works (see the pause/resume sketch below)

* linting
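
For the HITL fix, the key contract is that live console updates are paused for the duration of the blocking `input()` call and always resumed, even when the prompt raises. A minimal sketch of that pattern, assuming a formatter exposing the `pause_live_updates` / `resume_live_updates` methods shown in the diff below; the prompt text and panel styling here are illustrative.

```python
from rich.console import Console
from rich.panel import Panel

console = Console()


def ask_human_feedback(formatter) -> str:
    """Pause live rendering, show the feedback prompt, and always resume."""
    formatter.pause_live_updates()
    try:
        console.print(
            Panel(
                "Press Enter to accept the result, or type improvement requests.",
                title="💬 Human Feedback Required",
                border_style="yellow",
                padding=(1, 2),
            )
        )
        return input()
    finally:
        # Resumed even if input() raises (e.g. KeyboardInterrupt), matching the
        # try/finally in CrewAgentExecutorMixin._ask_human_input.
        formatter.resume_live_updates()
```
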
Author: Lorenze Jay
Date: 2025-12-30 11:36:31 -08:00
Committed by: GitHub
Parent: b9dd166a6b
Commit: 467ee2917e
8 changed files with 651 additions and 1987 deletions

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import time
from typing import TYPE_CHECKING
from crewai.agents.parser import AgentFinish
from crewai.events.event_listener import event_listener
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
@@ -29,7 +30,7 @@ class CrewAgentExecutorMixin:
_i18n: I18N
_printer: Printer = Printer()
def _create_short_term_memory(self, output) -> None:
def _create_short_term_memory(self, output: AgentFinish) -> None:
"""Create and save a short-term memory item if conditions are met."""
if (
self.crew
@@ -53,7 +54,7 @@ class CrewAgentExecutorMixin:
"error", f"Failed to add to short term memory: {e}"
)
def _create_external_memory(self, output) -> None:
def _create_external_memory(self, output: AgentFinish) -> None:
"""Create and save a external-term memory item if conditions are met."""
if (
self.crew
@@ -75,7 +76,7 @@ class CrewAgentExecutorMixin:
"error", f"Failed to add to external memory: {e}"
)
def _create_long_term_memory(self, output) -> None:
def _create_long_term_memory(self, output: AgentFinish) -> None:
"""Create and save long-term and entity memory items based on evaluation."""
if (
self.crew
@@ -136,40 +137,50 @@ class CrewAgentExecutorMixin:
)
def _ask_human_input(self, final_answer: str) -> str:
"""Prompt human input with mode-appropriate messaging."""
event_listener.formatter.pause_live_updates()
try:
self._printer.print(
content=f"\033[1m\033[95m ## Final Result:\033[00m \033[92m{final_answer}\033[00m"
)
"""Prompt human input with mode-appropriate messaging.
Note: The final answer is already displayed via the AgentLogsExecutionEvent
panel, so we only show the feedback prompt here.
"""
from rich.panel import Panel
from rich.text import Text
formatter = event_listener.formatter
formatter.pause_live_updates()
try:
# Training mode prompt (single iteration)
if self.crew and getattr(self.crew, "_train", False):
prompt = (
"\n\n=====\n"
"## TRAINING MODE: Provide feedback to improve the agent's performance.\n"
prompt_text = (
"TRAINING MODE: Provide feedback to improve the agent's performance.\n\n"
"This will be used to train better versions of the agent.\n"
"Please provide detailed feedback about the result quality and reasoning process.\n"
"=====\n"
"Please provide detailed feedback about the result quality and reasoning process."
)
title = "🎓 Training Feedback Required"
# Regular human-in-the-loop prompt (multiple iterations)
else:
prompt = (
"\n\n=====\n"
"## HUMAN FEEDBACK: Provide feedback on the Final Result and Agent's actions.\n"
"Please follow these guidelines:\n"
" - If you are happy with the result, simply hit Enter without typing anything.\n"
" - Otherwise, provide specific improvement requests.\n"
" - You can provide multiple rounds of feedback until satisfied.\n"
"=====\n"
prompt_text = (
"Provide feedback on the Final Result above.\n\n"
"• If you are happy with the result, simply hit Enter without typing anything.\n"
"• Otherwise, provide specific improvement requests.\n"
"• You can provide multiple rounds of feedback until satisfied."
)
title = "💬 Human Feedback Required"
content = Text()
content.append(prompt_text, style="yellow")
prompt_panel = Panel(
content,
title=title,
border_style="yellow",
padding=(1, 2),
)
formatter.console.print(prompt_panel)
self._printer.print(content=prompt, color="bold_yellow")
response = input()
if response.strip() != "":
self._printer.print(
content="\nProcessing your feedback...", color="cyan"
)
formatter.console.print("\n[cyan]Processing your feedback...[/cyan]")
return response
finally:
event_listener.formatter.resume_live_updates()
formatter.resume_live_updates()

View File

@@ -7,6 +7,7 @@ and memory management.
from __future__ import annotations
from collections.abc import Callable
import logging
from typing import TYPE_CHECKING, Any, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler
@@ -51,6 +52,8 @@ from crewai.utilities.tool_utils import (
from crewai.utilities.training_handler import CrewTrainingHandler
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.agents.tools_handler import ToolsHandler
@@ -541,7 +544,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if self.agent is None:
raise ValueError("Agent cannot be None")
crewai_event_bus.emit(
future = crewai_event_bus.emit(
self.agent,
AgentLogsExecutionEvent(
agent_role=self.agent.role,
@@ -551,6 +554,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
),
)
if future is not None:
try:
future.result(timeout=5.0)
except Exception as e:
logger.error(f"Failed to show logs for agent execution event: {e}")
def _handle_crew_training_output(
self, result: AgentFinish, human_feedback: str | None = None
) -> None:

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from io import StringIO
import threading
from typing import TYPE_CHECKING, Any
from pydantic import Field, PrivateAttr
@@ -17,8 +16,6 @@ from crewai.events.types.a2a_events import (
A2AResponseReceivedEvent,
)
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
@@ -48,7 +45,6 @@ from crewai.events.types.flow_events import (
from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
@@ -112,7 +108,6 @@ class EventListener(BaseEventListener):
text_stream: StringIO = StringIO()
knowledge_retrieval_in_progress: bool = False
knowledge_query_in_progress: bool = False
method_branches: dict[str, Any] = Field(default_factory=dict)
def __new__(cls) -> EventListener:
if cls._instance is None:
@@ -126,10 +121,8 @@ class EventListener(BaseEventListener):
self._telemetry = Telemetry()
self._telemetry.set_tracer()
self.execution_spans = {}
self.method_branches = {}
self._initialized = True
self.formatter = ConsoleFormatter(verbose=True)
self._crew_tree_lock = threading.Condition()
# Initialize trace listener with formatter for memory event handling
trace_listener = TraceCollectionListener()
@@ -140,12 +133,10 @@ class EventListener(BaseEventListener):
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
@crewai_event_bus.on(CrewKickoffStartedEvent)
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
with self._crew_tree_lock:
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
source._execution_span = self._telemetry.crew_execution_span(
source, event.inputs
)
self._crew_tree_lock.notify_all()
self.formatter.handle_crew_started(event.crew_name or "Crew", source.id)
source._execution_span = self._telemetry.crew_execution_span(
source, event.inputs
)
@crewai_event_bus.on(CrewKickoffCompletedEvent)
def on_crew_completed(source: Any, event: CrewKickoffCompletedEvent) -> None:
@@ -153,8 +144,7 @@ class EventListener(BaseEventListener):
final_string_output = event.output.raw
self._telemetry.end_crew(source, final_string_output)
self.formatter.update_crew_tree(
self.formatter.current_crew_tree,
self.formatter.handle_crew_status(
event.crew_name or "Crew",
source.id,
"completed",
@@ -163,8 +153,7 @@ class EventListener(BaseEventListener):
@crewai_event_bus.on(CrewKickoffFailedEvent)
def on_crew_failed(source: Any, event: CrewKickoffFailedEvent) -> None:
self.formatter.update_crew_tree(
self.formatter.current_crew_tree,
self.formatter.handle_crew_status(
event.crew_name or "Crew",
source.id,
"failed",
@@ -197,23 +186,22 @@ class EventListener(BaseEventListener):
# ----------- TASK EVENTS -----------
def get_task_name(source: Any) -> str | None:
return (
source.name
if hasattr(source, "name") and source.name
else source.description
if hasattr(source, "description") and source.description
else None
)
@crewai_event_bus.on(TaskStartedEvent)
def on_task_started(source: Any, event: TaskStartedEvent) -> None:
span = self._telemetry.task_started(crew=source.agent.crew, task=source)
self.execution_spans[source] = span
with self._crew_tree_lock:
self._crew_tree_lock.wait_for(
lambda: self.formatter.current_crew_tree is not None, timeout=5.0
)
if self.formatter.current_crew_tree is not None:
task_name = (
source.name if hasattr(source, "name") and source.name else None
)
self.formatter.create_task_branch(
self.formatter.current_crew_tree, source.id, task_name
)
task_name = get_task_name(source)
self.formatter.handle_task_started(source.id, task_name)
@crewai_event_bus.on(TaskCompletedEvent)
def on_task_completed(source: Any, event: TaskCompletedEvent) -> None:
@@ -224,13 +212,9 @@ class EventListener(BaseEventListener):
self.execution_spans[source] = None
# Pass task name if it exists
task_name = source.name if hasattr(source, "name") and source.name else None
self.formatter.update_task_status(
self.formatter.current_crew_tree,
source.id,
source.agent.role,
"completed",
task_name,
task_name = get_task_name(source)
self.formatter.handle_task_status(
source.id, source.agent.role, "completed", task_name
)
@crewai_event_bus.on(TaskFailedEvent)
@@ -242,37 +226,12 @@ class EventListener(BaseEventListener):
self.execution_spans[source] = None
# Pass task name if it exists
task_name = source.name if hasattr(source, "name") and source.name else None
self.formatter.update_task_status(
self.formatter.current_crew_tree,
source.id,
source.agent.role,
"failed",
task_name,
task_name = get_task_name(source)
self.formatter.handle_task_status(
source.id, source.agent.role, "failed", task_name
)
# ----------- AGENT EVENTS -----------
@crewai_event_bus.on(AgentExecutionStartedEvent)
def on_agent_execution_started(
_: Any, event: AgentExecutionStartedEvent
) -> None:
self.formatter.create_agent_branch(
self.formatter.current_task_branch,
event.agent.role,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(AgentExecutionCompletedEvent)
def on_agent_execution_completed(
_: Any, event: AgentExecutionCompletedEvent
) -> None:
self.formatter.update_agent_status(
self.formatter.current_agent_branch,
event.agent.role,
self.formatter.current_crew_tree,
)
# ----------- LITE AGENT EVENTS -----------
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
@@ -316,79 +275,61 @@ class EventListener(BaseEventListener):
self._telemetry.flow_execution_span(
event.flow_name, list(source._methods.keys())
)
tree = self.formatter.create_flow_tree(event.flow_name, str(source.flow_id))
self.formatter.current_flow_tree = tree
self.formatter.start_flow(event.flow_name, str(source.flow_id))
self.formatter.handle_flow_created(event.flow_name, str(source.flow_id))
self.formatter.handle_flow_started(event.flow_name, str(source.flow_id))
@crewai_event_bus.on(FlowFinishedEvent)
def on_flow_finished(source: Any, event: FlowFinishedEvent) -> None:
self.formatter.update_flow_status(
self.formatter.current_flow_tree, event.flow_name, source.flow_id
self.formatter.handle_flow_status(
event.flow_name,
source.flow_id,
)
@crewai_event_bus.on(MethodExecutionStartedEvent)
def on_method_execution_started(
_: Any, event: MethodExecutionStartedEvent
) -> None:
method_branch = self.method_branches.get(event.method_name)
updated_branch = self.formatter.update_method_status(
method_branch,
self.formatter.current_flow_tree,
self.formatter.handle_method_status(
event.method_name,
"running",
)
self.method_branches[event.method_name] = updated_branch
@crewai_event_bus.on(MethodExecutionFinishedEvent)
def on_method_execution_finished(
_: Any, event: MethodExecutionFinishedEvent
) -> None:
method_branch = self.method_branches.get(event.method_name)
updated_branch = self.formatter.update_method_status(
method_branch,
self.formatter.current_flow_tree,
self.formatter.handle_method_status(
event.method_name,
"completed",
)
self.method_branches[event.method_name] = updated_branch
@crewai_event_bus.on(MethodExecutionFailedEvent)
def on_method_execution_failed(
_: Any, event: MethodExecutionFailedEvent
) -> None:
method_branch = self.method_branches.get(event.method_name)
updated_branch = self.formatter.update_method_status(
method_branch,
self.formatter.current_flow_tree,
self.formatter.handle_method_status(
event.method_name,
"failed",
)
self.method_branches[event.method_name] = updated_branch
@crewai_event_bus.on(MethodExecutionPausedEvent)
def on_method_execution_paused(
_: Any, event: MethodExecutionPausedEvent
) -> None:
method_branch = self.method_branches.get(event.method_name)
updated_branch = self.formatter.update_method_status(
method_branch,
self.formatter.current_flow_tree,
self.formatter.handle_method_status(
event.method_name,
"paused",
)
self.method_branches[event.method_name] = updated_branch
@crewai_event_bus.on(FlowPausedEvent)
def on_flow_paused(_: Any, event: FlowPausedEvent) -> None:
self.formatter.update_flow_status(
self.formatter.current_flow_tree,
self.formatter.handle_flow_status(
event.flow_name,
event.flow_id,
"paused",
)
# ----------- TOOL USAGE EVENTS -----------
@crewai_event_bus.on(ToolUsageStartedEvent)
def on_tool_usage_started(source: Any, event: ToolUsageStartedEvent) -> None:
if isinstance(source, LLM):
@@ -398,9 +339,9 @@ class EventListener(BaseEventListener):
)
else:
self.formatter.handle_tool_usage_started(
self.formatter.current_agent_branch,
event.tool_name,
self.formatter.current_crew_tree,
event.tool_args,
event.run_attempts,
)
@crewai_event_bus.on(ToolUsageFinishedEvent)
@@ -409,12 +350,6 @@ class EventListener(BaseEventListener):
self.formatter.handle_llm_tool_usage_finished(
event.tool_name,
)
else:
self.formatter.handle_tool_usage_finished(
self.formatter.current_tool_branch,
event.tool_name,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(ToolUsageErrorEvent)
def on_tool_usage_error(source: Any, event: ToolUsageErrorEvent) -> None:
@@ -425,10 +360,9 @@ class EventListener(BaseEventListener):
)
else:
self.formatter.handle_tool_usage_error(
self.formatter.current_tool_branch,
event.tool_name,
event.error,
self.formatter.current_crew_tree,
event.run_attempts,
)
# ----------- LLM EVENTS -----------
@@ -437,32 +371,15 @@ class EventListener(BaseEventListener):
def on_llm_call_started(_: Any, event: LLMCallStartedEvent) -> None:
self.text_stream = StringIO()
self.next_chunk = 0
# Capture the returned tool branch and update the current_tool_branch reference
thinking_branch = self.formatter.handle_llm_call_started(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
# Update the formatter's current_tool_branch to ensure proper cleanup
if thinking_branch is not None:
self.formatter.current_tool_branch = thinking_branch
@crewai_event_bus.on(LLMCallCompletedEvent)
def on_llm_call_completed(_: Any, event: LLMCallCompletedEvent) -> None:
self.formatter.handle_llm_stream_completed()
self.formatter.handle_llm_call_completed(
self.formatter.current_tool_branch,
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(LLMCallFailedEvent)
def on_llm_call_failed(_: Any, event: LLMCallFailedEvent) -> None:
self.formatter.handle_llm_stream_completed()
self.formatter.handle_llm_call_failed(
self.formatter.current_tool_branch,
event.error,
self.formatter.current_crew_tree,
)
self.formatter.handle_llm_call_failed(event.error)
@crewai_event_bus.on(LLMStreamChunkEvent)
def on_llm_stream_chunk(_: Any, event: LLMStreamChunkEvent) -> None:
@@ -473,9 +390,7 @@ class EventListener(BaseEventListener):
accumulated_text = self.text_stream.getvalue()
self.formatter.handle_llm_stream_chunk(
event.chunk,
accumulated_text,
self.formatter.current_crew_tree,
event.call_type,
)
@@ -515,7 +430,6 @@ class EventListener(BaseEventListener):
@crewai_event_bus.on(CrewTestCompletedEvent)
def on_crew_test_completed(_: Any, event: CrewTestCompletedEvent) -> None:
self.formatter.handle_crew_test_completed(
self.formatter.current_flow_tree,
event.crew_name or "Crew",
)
@@ -532,10 +446,7 @@ class EventListener(BaseEventListener):
self.knowledge_retrieval_in_progress = True
self.formatter.handle_knowledge_retrieval_started(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
self.formatter.handle_knowledge_retrieval_started()
@crewai_event_bus.on(KnowledgeRetrievalCompletedEvent)
def on_knowledge_retrieval_completed(
@@ -546,24 +457,13 @@ class EventListener(BaseEventListener):
self.knowledge_retrieval_in_progress = False
self.formatter.handle_knowledge_retrieval_completed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.retrieved_knowledge,
event.query,
)
@crewai_event_bus.on(KnowledgeQueryStartedEvent)
def on_knowledge_query_started(
_: Any, event: KnowledgeQueryStartedEvent
) -> None:
pass
@crewai_event_bus.on(KnowledgeQueryFailedEvent)
def on_knowledge_query_failed(_: Any, event: KnowledgeQueryFailedEvent) -> None:
self.formatter.handle_knowledge_query_failed(
self.formatter.current_agent_branch,
event.error,
self.formatter.current_crew_tree,
)
self.formatter.handle_knowledge_query_failed(event.error)
@crewai_event_bus.on(KnowledgeQueryCompletedEvent)
def on_knowledge_query_completed(
@@ -575,11 +475,7 @@ class EventListener(BaseEventListener):
def on_knowledge_search_query_failed(
_: Any, event: KnowledgeSearchQueryFailedEvent
) -> None:
self.formatter.handle_knowledge_search_query_failed(
self.formatter.current_agent_branch,
event.error,
self.formatter.current_crew_tree,
)
self.formatter.handle_knowledge_search_query_failed(event.error)
# ----------- REASONING EVENTS -----------
@@ -587,11 +483,7 @@ class EventListener(BaseEventListener):
def on_agent_reasoning_started(
_: Any, event: AgentReasoningStartedEvent
) -> None:
self.formatter.handle_reasoning_started(
self.formatter.current_agent_branch,
event.attempt,
self.formatter.current_crew_tree,
)
self.formatter.handle_reasoning_started(event.attempt)
@crewai_event_bus.on(AgentReasoningCompletedEvent)
def on_agent_reasoning_completed(
@@ -600,14 +492,12 @@ class EventListener(BaseEventListener):
self.formatter.handle_reasoning_completed(
event.plan,
event.ready,
self.formatter.current_crew_tree,
)
@crewai_event_bus.on(AgentReasoningFailedEvent)
def on_agent_reasoning_failed(_: Any, event: AgentReasoningFailedEvent) -> None:
self.formatter.handle_reasoning_failed(
event.error,
self.formatter.current_crew_tree,
)
# ----------- AGENT LOGGING EVENTS -----------
@@ -734,18 +624,6 @@ class EventListener(BaseEventListener):
event.tool_args,
)
@crewai_event_bus.on(MCPToolExecutionCompletedEvent)
def on_mcp_tool_execution_completed(
_: Any, event: MCPToolExecutionCompletedEvent
) -> None:
self.formatter.handle_mcp_tool_execution_completed(
event.server_name,
event.tool_name,
event.tool_args,
event.result,
event.execution_duration_ms,
)
@crewai_event_bus.on(MCPToolExecutionFailedEvent)
def on_mcp_tool_execution_failed(
_: Any, event: MCPToolExecutionFailedEvent

View File

@@ -1,7 +1,7 @@
"""Trace collection listener for orchestrating trace collection."""
import os
from typing import Any, ClassVar
from typing import Any, ClassVar, cast
import uuid
from typing_extensions import Self
@@ -105,7 +105,7 @@ class TraceCollectionListener(BaseEventListener):
"""Create or return singleton instance."""
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
return cast(Self, cls._instance)
def __init__(
self,
@@ -319,21 +319,12 @@ class TraceCollectionListener(BaseEventListener):
source: Any, event: MemoryQueryCompletedEvent
) -> None:
self._handle_action_event("memory_query_completed", source, event)
if self.formatter and self.memory_retrieval_in_progress:
self.formatter.handle_memory_query_completed(
self.formatter.current_agent_branch,
event.source_type or "memory",
event.query_time_ms,
self.formatter.current_crew_tree,
)
@event_bus.on(MemoryQueryFailedEvent)
def on_memory_query_failed(source: Any, event: MemoryQueryFailedEvent) -> None:
self._handle_action_event("memory_query_failed", source, event)
if self.formatter and self.memory_retrieval_in_progress:
self.formatter.handle_memory_query_failed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.error,
event.source_type or "memory",
)
@@ -347,10 +338,7 @@ class TraceCollectionListener(BaseEventListener):
self.memory_save_in_progress = True
self.formatter.handle_memory_save_started(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
self.formatter.handle_memory_save_started()
@event_bus.on(MemorySaveCompletedEvent)
def on_memory_save_completed(
@@ -364,8 +352,6 @@ class TraceCollectionListener(BaseEventListener):
self.memory_save_in_progress = False
self.formatter.handle_memory_save_completed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.save_time_ms,
event.source_type or "memory",
)
@@ -375,10 +361,8 @@ class TraceCollectionListener(BaseEventListener):
self._handle_action_event("memory_save_failed", source, event)
if self.formatter and self.memory_save_in_progress:
self.formatter.handle_memory_save_failed(
self.formatter.current_agent_branch,
event.error,
event.source_type or "memory",
self.formatter.current_crew_tree,
)
@event_bus.on(MemoryRetrievalStartedEvent)
@@ -391,10 +375,7 @@ class TraceCollectionListener(BaseEventListener):
self.memory_retrieval_in_progress = True
self.formatter.handle_memory_retrieval_started(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
)
self.formatter.handle_memory_retrieval_started()
@event_bus.on(MemoryRetrievalCompletedEvent)
def on_memory_retrieval_completed(
@@ -406,8 +387,6 @@ class TraceCollectionListener(BaseEventListener):
self.memory_retrieval_in_progress = False
self.formatter.handle_memory_retrieval_completed(
self.formatter.current_agent_branch,
self.formatter.current_crew_tree,
event.memory_content,
event.retrieval_time_ms,
)

File diff suppressed because it is too large.

View File

@@ -249,6 +249,7 @@ class ToolUsage:
"tool_args": self.action.tool_input,
"tool_class": self.action.tool,
"agent": self.agent,
"run_attempts": self._run_attempts,
}
if self.agent.fingerprint: # type: ignore
@@ -435,6 +436,7 @@ class ToolUsage:
"tool_args": self.action.tool_input,
"tool_class": self.action.tool,
"agent": self.agent,
"run_attempts": self._run_attempts,
}
# TODO: Investigate fingerprint attribute availability on BaseAgent/LiteAgent

View File

@@ -7,22 +7,19 @@ from crewai.events.event_listener import event_listener
class TestFlowHumanInputIntegration:
"""Test integration between Flow execution and human input functionality."""
def test_console_formatter_pause_resume_methods(self):
"""Test that ConsoleFormatter pause/resume methods work correctly."""
def test_console_formatter_pause_resume_methods_exist(self):
"""Test that ConsoleFormatter pause/resume methods exist and are callable."""
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
# Methods should exist and be callable
assert hasattr(formatter, "pause_live_updates")
assert hasattr(formatter, "resume_live_updates")
assert callable(formatter.pause_live_updates)
assert callable(formatter.resume_live_updates)
try:
formatter._live_paused = False
formatter.pause_live_updates()
assert formatter._live_paused
formatter.resume_live_updates()
assert not formatter._live_paused
finally:
formatter._live_paused = original_paused_state
# Should not raise
formatter.pause_live_updates()
formatter.resume_live_updates()
@patch("builtins.input", return_value="")
def test_human_input_pauses_flow_updates(self, mock_input):
@@ -38,23 +35,16 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
):
result = executor._ask_human_input("Test result")
try:
formatter._live_paused = False
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
):
result = executor._ask_human_input("Test result")
mock_pause.assert_called_once()
mock_resume.assert_called_once()
mock_input.assert_called_once()
assert result == ""
finally:
formatter._live_paused = original_paused_state
mock_pause.assert_called_once()
mock_resume.assert_called_once()
mock_input.assert_called_once()
assert result == ""
@patch("builtins.input", side_effect=["feedback", ""])
def test_multiple_human_input_rounds(self, mock_input):
@@ -70,53 +60,46 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
pause_calls = []
resume_calls = []
try:
pause_calls = []
resume_calls = []
def track_pause():
pause_calls.append(True)
def track_pause():
pause_calls.append(True)
def track_resume():
resume_calls.append(True)
def track_resume():
resume_calls.append(True)
with (
patch.object(formatter, "pause_live_updates", side_effect=track_pause),
patch.object(
formatter, "resume_live_updates", side_effect=track_resume
),
):
result1 = executor._ask_human_input("Test result 1")
assert result1 == "feedback"
with (
patch.object(formatter, "pause_live_updates", side_effect=track_pause),
patch.object(
formatter, "resume_live_updates", side_effect=track_resume
),
):
result1 = executor._ask_human_input("Test result 1")
assert result1 == "feedback"
result2 = executor._ask_human_input("Test result 2")
assert result2 == ""
result2 = executor._ask_human_input("Test result 2")
assert result2 == ""
assert len(pause_calls) == 2
assert len(resume_calls) == 2
finally:
formatter._live_paused = original_paused_state
assert len(pause_calls) == 2
assert len(resume_calls) == 2
def test_pause_resume_with_no_live_session(self):
"""Test pause/resume methods handle case when no Live session exists."""
formatter = event_listener.formatter
original_live = formatter._live
original_paused_state = formatter._live_paused
original_streaming_live = formatter._streaming_live
try:
formatter._live = None
formatter._live_paused = False
formatter._streaming_live = None
# Should not raise when no session exists
formatter.pause_live_updates()
formatter.resume_live_updates()
assert not formatter._live_paused
assert formatter._streaming_live is None
finally:
formatter._live = original_live
formatter._live_paused = original_paused_state
formatter._streaming_live = original_streaming_live
def test_pause_resume_exception_handling(self):
"""Test that resume is called even if exception occurs during human input."""
@@ -131,23 +114,18 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch(
"builtins.input", side_effect=KeyboardInterrupt("Test exception")
),
):
with pytest.raises(KeyboardInterrupt):
executor._ask_human_input("Test result")
try:
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch(
"builtins.input", side_effect=KeyboardInterrupt("Test exception")
),
):
with pytest.raises(KeyboardInterrupt):
executor._ask_human_input("Test result")
mock_pause.assert_called_once()
mock_resume.assert_called_once()
finally:
formatter._live_paused = original_paused_state
mock_pause.assert_called_once()
mock_resume.assert_called_once()
def test_training_mode_human_input(self):
"""Test human input in training mode."""
@@ -162,28 +140,25 @@ class TestFlowHumanInputIntegration:
formatter = event_listener.formatter
original_paused_state = formatter._live_paused
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch.object(formatter.console, "print") as mock_console_print,
patch("builtins.input", return_value="training feedback"),
):
result = executor._ask_human_input("Test result")
try:
with (
patch.object(formatter, "pause_live_updates") as mock_pause,
patch.object(formatter, "resume_live_updates") as mock_resume,
patch("builtins.input", return_value="training feedback"),
):
result = executor._ask_human_input("Test result")
mock_pause.assert_called_once()
mock_resume.assert_called_once()
assert result == "training feedback"
mock_pause.assert_called_once()
mock_resume.assert_called_once()
assert result == "training feedback"
executor._printer.print.assert_called()
call_args = [
call[1]["content"]
for call in executor._printer.print.call_args_list
]
training_prompt_found = any(
"TRAINING MODE" in content for content in call_args
)
assert training_prompt_found
finally:
formatter._live_paused = original_paused_state
# Verify the training panel was printed via formatter's console
mock_console_print.assert_called()
# Check that a Panel with training title was printed
call_args = mock_console_print.call_args_list
training_panel_found = any(
hasattr(call[0][0], "title") and "Training" in str(call[0][0].title)
for call in call_args
if call[0]
)
assert training_panel_found

View File

@@ -1,116 +1,107 @@
from unittest.mock import MagicMock, patch
from rich.tree import Tree
from rich.live import Live
from crewai.events.utils.console_formatter import ConsoleFormatter
class TestConsoleFormatterPauseResume:
"""Test ConsoleFormatter pause/resume functionality."""
"""Test ConsoleFormatter pause/resume functionality for HITL features."""
def test_pause_live_updates_with_active_session(self):
"""Test pausing when Live session is active."""
def test_pause_stops_active_streaming_session(self):
"""Test pausing stops an active streaming Live session."""
formatter = ConsoleFormatter()
mock_live = MagicMock(spec=Live)
formatter._live = mock_live
formatter._live_paused = False
formatter._streaming_live = mock_live
formatter.pause_live_updates()
mock_live.stop.assert_called_once()
assert formatter._live_paused
assert formatter._streaming_live is None
def test_pause_live_updates_when_already_paused(self):
"""Test pausing when already paused does nothing."""
def test_pause_is_safe_when_no_session(self):
"""Test pausing when no streaming session exists doesn't error."""
formatter = ConsoleFormatter()
formatter._streaming_live = None
# Should not raise
formatter.pause_live_updates()
assert formatter._streaming_live is None
def test_multiple_pauses_are_safe(self):
"""Test calling pause multiple times is safe."""
formatter = ConsoleFormatter()
mock_live = MagicMock(spec=Live)
formatter._live = mock_live
formatter._live_paused = True
formatter._streaming_live = mock_live
formatter.pause_live_updates()
mock_live.stop.assert_called_once()
assert formatter._streaming_live is None
mock_live.stop.assert_not_called()
assert formatter._live_paused
def test_pause_live_updates_with_no_session(self):
"""Test pausing when no Live session exists."""
formatter = ConsoleFormatter()
formatter._live = None
formatter._live_paused = False
# Second pause should not error (no session to stop)
formatter.pause_live_updates()
assert formatter._live_paused
def test_resume_live_updates_when_paused(self):
"""Test resuming when paused."""
def test_resume_is_safe(self):
"""Test resume method exists and doesn't error."""
formatter = ConsoleFormatter()
formatter._live_paused = True
# Should not raise
formatter.resume_live_updates()
assert not formatter._live_paused
def test_resume_live_updates_when_not_paused(self):
"""Test resuming when not paused does nothing."""
def test_streaming_after_pause_resume_creates_new_session(self):
"""Test that streaming after pause/resume creates new Live session."""
formatter = ConsoleFormatter()
formatter.verbose = True
formatter._live_paused = False
# Simulate having an active session
mock_live = MagicMock(spec=Live)
formatter._streaming_live = mock_live
# Pause stops the session
formatter.pause_live_updates()
assert formatter._streaming_live is None
# Resume (no-op, sessions created on demand)
formatter.resume_live_updates()
assert not formatter._live_paused
# After resume, streaming should be able to start a new session
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
mock_live_instance = MagicMock()
mock_live_class.return_value = mock_live_instance
def test_print_after_resume_restarts_live_session(self):
"""Test that printing a Tree after resume creates new Live session."""
# Simulate streaming chunk (this creates a new Live session)
formatter.handle_llm_stream_chunk("test chunk", call_type=None)
mock_live_class.assert_called_once()
mock_live_instance.start.assert_called_once()
assert formatter._streaming_live == mock_live_instance
def test_pause_resume_cycle_with_streaming(self):
"""Test full pause/resume cycle during streaming."""
formatter = ConsoleFormatter()
formatter._live_paused = True
formatter._live = None
formatter.resume_live_updates()
assert not formatter._live_paused
tree = Tree("Test")
formatter.verbose = True
with patch("crewai.events.utils.console_formatter.Live") as mock_live_class:
mock_live_instance = MagicMock()
mock_live_class.return_value = mock_live_instance
formatter.print(tree)
# Start streaming
formatter.handle_llm_stream_chunk("chunk 1", call_type=None)
assert formatter._streaming_live == mock_live_instance
mock_live_class.assert_called_once()
mock_live_instance.start.assert_called_once()
assert formatter._live == mock_live_instance
# Pause should stop the session
formatter.pause_live_updates()
mock_live_instance.stop.assert_called_once()
assert formatter._streaming_live is None
def test_multiple_pause_resume_cycles(self):
"""Test multiple pause/resume cycles work correctly."""
formatter = ConsoleFormatter()
# Resume (no-op)
formatter.resume_live_updates()
mock_live = MagicMock(spec=Live)
formatter._live = mock_live
formatter._live_paused = False
# Create a new mock for the next session
mock_live_instance_2 = MagicMock()
mock_live_class.return_value = mock_live_instance_2
formatter.pause_live_updates()
assert formatter._live_paused
mock_live.stop.assert_called_once()
assert formatter._live is None # Live session should be cleared
formatter.resume_live_updates()
assert not formatter._live_paused
formatter.pause_live_updates()
assert formatter._live_paused
formatter.resume_live_updates()
assert not formatter._live_paused
def test_pause_resume_state_initialization(self):
"""Test that _live_paused is properly initialized."""
formatter = ConsoleFormatter()
assert hasattr(formatter, "_live_paused")
assert not formatter._live_paused
# Streaming again creates new session
formatter.handle_llm_stream_chunk("chunk 2", call_type=None)
assert formatter._streaming_live == mock_live_instance_2