Implement user input handling in Flows (#4490)

* Implement user input handling in Flow class
This commit is contained in:
João Moura
2026-02-16 13:41:03 -08:00
committed by GitHub
parent 4aedd58829
commit 84d57c7a24
12 changed files with 1929 additions and 23 deletions

View File

@@ -2,7 +2,30 @@ import subprocess
import click
from crewai.cli.utils import get_crews
from crewai.cli.utils import get_crews, get_flows
from crewai.flow import Flow
def _reset_flow_memory(flow: Flow) -> None:
"""Reset memory for a single flow instance.
Handles Memory, MemoryScope (both have .reset()), and MemorySlice
(delegates to the underlying ._memory). Silently succeeds when the
storage directory does not exist yet (nothing to reset).
Args:
flow: The flow instance whose memory should be reset.
"""
mem = flow.memory
if mem is None:
return
try:
if hasattr(mem, "reset"):
mem.reset()
elif hasattr(mem, "_memory") and hasattr(mem._memory, "reset"):
mem._memory.reset()
except (FileNotFoundError, OSError):
pass
def reset_memories_command(
@@ -12,7 +35,7 @@ def reset_memories_command(
kickoff_outputs: bool,
all: bool,
) -> None:
"""Reset the crew memories.
"""Reset the crew and flow memories.
Args:
memory: Whether to reset the unified memory.
@@ -29,8 +52,11 @@ def reset_memories_command(
return
crews = get_crews()
if not crews:
raise ValueError("No crew found.")
flows = get_flows()
if not crews and not flows:
raise ValueError("No crew or flow found.")
for crew in crews:
if all:
crew.reset_memories(command_type="all")
@@ -59,6 +85,20 @@ def reset_memories_command(
f"[Crew ({crew.name if crew.name else crew.id})] Agents knowledge has been reset."
)
for flow in flows:
flow_name = flow.name or flow.__class__.__name__
if all:
_reset_flow_memory(flow)
click.echo(
f"[Flow ({flow_name})] Reset memories command has been completed."
)
continue
if memory:
_reset_flow_memory(flow)
click.echo(
f"[Flow ({flow_name})] Memory has been reset."
)
except subprocess.CalledProcessError as e:
click.echo(f"An error occurred while resetting the memories: {e}", err=True)
click.echo(e.output, err=True)

View File

@@ -386,6 +386,109 @@ def fetch_crews(module_attr: Any) -> list[Crew]:
return crew_instances
def get_flow_instance(module_attr: Any) -> Flow | None:
"""Check if a module attribute is a user-defined Flow subclass and return an instance.
Args:
module_attr: An attribute from a loaded module.
Returns:
A Flow instance if the attribute is a valid user-defined Flow subclass,
None otherwise.
"""
if (
isinstance(module_attr, type)
and issubclass(module_attr, Flow)
and module_attr is not Flow
):
try:
return module_attr()
except Exception:
return None
return None
_SKIP_DIRS = frozenset(
{".venv", "venv", ".git", "__pycache__", "node_modules", ".tox", ".nox"}
)
def get_flows(flow_path: str = "main.py") -> list[Flow]:
"""Get the flow instances from project files.
Walks the project directory looking for files matching ``flow_path``
(default ``main.py``), loads each module, and extracts Flow subclass
instances. Directories that are clearly not user source code (virtual
environments, ``.git``, etc.) are pruned to avoid noisy import errors.
Args:
flow_path: Filename to search for (default ``main.py``).
Returns:
A list of discovered Flow instances.
"""
flow_instances: list[Flow] = []
try:
current_dir = os.getcwd()
if current_dir not in sys.path:
sys.path.insert(0, current_dir)
src_dir = os.path.join(current_dir, "src")
if os.path.isdir(src_dir) and src_dir not in sys.path:
sys.path.insert(0, src_dir)
search_paths = [".", "src"] if os.path.isdir("src") else ["."]
for search_path in search_paths:
for root, dirs, files in os.walk(search_path):
dirs[:] = [
d
for d in dirs
if d not in _SKIP_DIRS and not d.startswith(".")
]
if flow_path in files and "cli/templates" not in root:
file_os_path = os.path.join(root, flow_path)
try:
spec = importlib.util.spec_from_file_location(
"flow_module", file_os_path
)
if not spec or not spec.loader:
continue
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
try:
spec.loader.exec_module(module)
for attr_name in dir(module):
module_attr = getattr(module, attr_name)
try:
if flow_instance := get_flow_instance(
module_attr
):
flow_instances.append(flow_instance)
except Exception: # noqa: S112
continue
if flow_instances:
break
except Exception: # noqa: S112
continue
except (ImportError, AttributeError):
continue
if flow_instances:
break
except Exception: # noqa: S110
pass
return flow_instances
def is_valid_tool(obj: Any) -> bool:
from crewai.tools.base_tool import Tool

View File

@@ -120,6 +120,52 @@ class FlowPlotEvent(FlowEvent):
type: str = "flow_plot"
class FlowInputRequestedEvent(FlowEvent):
"""Event emitted when a flow requests user input via ``Flow.ask()``.
This event is emitted before the flow suspends waiting for user input,
allowing UI frameworks and observability tools to know when a flow
needs user interaction.
Attributes:
flow_name: Name of the flow requesting input.
method_name: Name of the flow method that called ``ask()``.
message: The question or prompt being shown to the user.
metadata: Optional metadata sent with the question (e.g., user ID,
channel, session context).
"""
method_name: str
message: str
metadata: dict[str, Any] | None = None
type: str = "flow_input_requested"
class FlowInputReceivedEvent(FlowEvent):
"""Event emitted when user input is received after ``Flow.ask()``.
This event is emitted after the user provides input (or the request
times out), allowing UI frameworks and observability tools to track
input collection.
Attributes:
flow_name: Name of the flow that received input.
method_name: Name of the flow method that called ``ask()``.
message: The original question or prompt.
response: The user's response, or None if timed out / unavailable.
metadata: Optional metadata sent with the question.
response_metadata: Optional metadata from the provider about the
response (e.g., who responded, thread ID, timestamps).
"""
method_name: str
message: str
response: str | None = None
metadata: dict[str, Any] | None = None
response_metadata: dict[str, Any] | None = None
type: str = "flow_input_received"
class HumanFeedbackRequestedEvent(FlowEvent):
"""Event emitted when human feedback is requested.

View File

@@ -7,6 +7,7 @@ from crewai.flow.async_feedback import (
from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.flow.flow_config import flow_config
from crewai.flow.human_feedback import HumanFeedbackResult, human_feedback
from crewai.flow.input_provider import InputProvider, InputResponse
from crewai.flow.persistence import persist
from crewai.flow.visualization import (
FlowStructure,
@@ -22,6 +23,8 @@ __all__ = [
"HumanFeedbackPending",
"HumanFeedbackProvider",
"HumanFeedbackResult",
"InputProvider",
"InputResponse",
"PendingFeedbackContext",
"and_",
"build_flow_structure",

View File

@@ -1,7 +1,8 @@
"""Default provider implementations for human feedback.
"""Default provider implementations for human feedback and user input.
This module provides the ConsoleProvider, which is the default synchronous
provider that collects feedback via console input.
provider that collects both feedback (for ``@human_feedback``) and user input
(for ``Flow.ask()``) via console.
"""
from __future__ import annotations
@@ -16,20 +17,23 @@ if TYPE_CHECKING:
class ConsoleProvider:
"""Default synchronous console-based feedback provider.
"""Default synchronous console-based provider for feedback and input.
This provider blocks execution and waits for console input from the user.
It displays the method output with formatting and prompts for feedback.
It serves two purposes:
- **Feedback** (``request_feedback``): Used by ``@human_feedback`` to
display method output and collect review feedback.
- **Input** (``request_input``): Used by ``Flow.ask()`` to prompt the
user with a question and collect a response.
This is the default provider used when no custom provider is specified
in the @human_feedback decorator.
in the ``@human_feedback`` decorator or on the Flow's ``input_provider``.
Example:
Example (feedback):
```python
from crewai.flow.async_feedback import ConsoleProvider
# Explicitly use console provider
@human_feedback(
message="Review this:",
provider=ConsoleProvider(),
@@ -37,9 +41,20 @@ class ConsoleProvider:
def my_method(self):
return "Content to review"
```
Example (input):
```python
from crewai.flow import Flow, start
class MyFlow(Flow):
@start()
def gather_info(self):
topic = self.ask("What topic should we research?")
return topic
```
"""
def __init__(self, verbose: bool = True):
def __init__(self, verbose: bool = True) -> None:
"""Initialize the console provider.
Args:
@@ -124,3 +139,55 @@ class ConsoleProvider:
finally:
# Resume live updates
formatter.resume_live_updates()
def request_input(
self,
message: str,
flow: Flow[Any],
metadata: dict[str, Any] | None = None,
) -> str | None:
"""Request user input via console (blocking).
Displays the prompt message with formatting and waits for the user
to type their response. Used by ``Flow.ask()``.
Unlike ``request_feedback``, this method does not display an
"OUTPUT FOR REVIEW" panel or emit feedback-specific events (those
are handled by ``ask()`` itself).
Args:
message: The question or prompt to display to the user.
flow: The Flow instance requesting input.
metadata: Optional metadata from the caller. Ignored by the
console provider (console has no concept of user routing).
Returns:
The user's input as a stripped string. Returns empty string
if user presses Enter without input. Never returns None
(console input is always available).
"""
from crewai.events.event_listener import event_listener
# Pause live updates during human input
formatter = event_listener.formatter
formatter.pause_live_updates()
try:
console = formatter.console
if self.verbose:
console.print()
console.print(message, style="yellow")
console.print()
response = input(">>> \n").strip()
else:
response = input(f"{message} ").strip()
# Add line break after input so formatter output starts clean
console.print()
return response
finally:
# Resume live updates
formatter.resume_live_updates()

View File

@@ -77,7 +77,7 @@ from crewai.flow.flow_wrappers import (
StartMethod,
)
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData, FlowMethodName, PendingListenerKey
from crewai.flow.types import FlowExecutionData, FlowMethodName, InputHistoryEntry, PendingListenerKey
from crewai.flow.utils import (
_extract_all_methods,
_extract_all_methods_recursive,
@@ -738,6 +738,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
tracing: bool | None = None
stream: bool = False
memory: Any = None # Memory | MemoryScope | MemorySlice | None; auto-created if not set
input_provider: Any = None # InputProvider | None; per-flow override for self.ask()
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
class _FlowGeneric(cls): # type: ignore
@@ -784,6 +785,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._pending_feedback_context: PendingFeedbackContext | None = None
self.suppress_flow_events: bool = suppress_flow_events
# User input history (for self.ask())
self._input_history: list[InputHistoryEntry] = []
# Initialize state with initial values
self._state = self._create_initial_state()
self.tracing = tracing
@@ -2119,15 +2123,24 @@ class Flow(Generic[T], metaclass=FlowMeta):
if future:
self._event_futures.append(future)
if asyncio.iscoroutinefunction(method):
result = await method(*args, **kwargs)
else:
# Run sync methods in thread pool for isolation
# This allows Agent.kickoff() to work synchronously inside Flow methods
import contextvars
# Set method name in context so ask() can read it without
# stack inspection. Must happen before copy_context() so the
# value propagates into the thread pool for sync methods.
from crewai.flow.flow_context import current_flow_method_name
ctx = contextvars.copy_context()
result = await asyncio.to_thread(ctx.run, method, *args, **kwargs)
method_name_token = current_flow_method_name.set(method_name)
try:
if asyncio.iscoroutinefunction(method):
result = await method(*args, **kwargs)
else:
# Run sync methods in thread pool for isolation
# This allows Agent.kickoff() to work synchronously inside Flow methods
import contextvars
ctx = contextvars.copy_context()
result = await asyncio.to_thread(ctx.run, method, *args, **kwargs)
finally:
current_flow_method_name.reset(method_name_token)
# Auto-await coroutines returned from sync methods (enables AgentExecutor pattern)
if asyncio.iscoroutine(result):
@@ -2582,6 +2595,201 @@ class Flow(Generic[T], metaclass=FlowMeta):
logger.error(f"Error executing listener {listener_name}: {e}")
raise
# ── User Input (self.ask) ────────────────────────────────────────
def _resolve_input_provider(self) -> Any:
"""Resolve the input provider using the priority chain.
Resolution order:
1. ``self.input_provider`` (per-flow override)
2. ``flow_config.input_provider`` (global default)
3. ``ConsoleInputProvider()`` (built-in fallback)
Returns:
An object implementing the ``InputProvider`` protocol.
"""
from crewai.flow.async_feedback.providers import ConsoleProvider
from crewai.flow.flow_config import flow_config
if self.input_provider is not None:
return self.input_provider
if flow_config.input_provider is not None:
return flow_config.input_provider
return ConsoleProvider()
def _checkpoint_state_for_ask(self) -> None:
"""Auto-checkpoint flow state before waiting for user input.
If persistence is configured, saves the current state so that
``self.state`` is recoverable even if the process crashes while
waiting for input.
This is best-effort: if persistence is not configured, this is a no-op.
"""
if self._persistence is None:
return
try:
state_data = (
self._state
if isinstance(self._state, dict)
else self._state.model_dump()
)
self._persistence.save_state(
flow_uuid=self.flow_id,
method_name="_ask_checkpoint",
state_data=state_data,
)
except Exception:
logger.debug("Failed to checkpoint state before ask()", exc_info=True)
def ask(
self,
message: str,
timeout: float | None = None,
metadata: dict[str, Any] | None = None,
) -> str | None:
"""Request input from the user during flow execution.
Blocks the current thread until the user provides input or the
timeout expires. Works in both sync and async flow methods (the
flow framework runs sync methods in a thread pool via
``asyncio.to_thread``, so the event loop stays free).
Timeout ensures flows always terminate. When timeout expires,
``None`` is returned, enabling the pattern::
while (msg := self.ask("You: ", timeout=300)) is not None:
process(msg)
Before waiting for input, the current ``self.state`` is automatically
checkpointed to persistence (if configured) for durability.
Args:
message: The question or prompt to display to the user.
timeout: Maximum seconds to wait for input. ``None`` means
wait indefinitely. When timeout expires, returns ``None``.
Note: timeout is best-effort for the provider call --
``ask()`` returns ``None`` promptly, but the underlying
``request_input()`` may continue running in a background
thread until it completes naturally. Network providers
should implement their own internal timeouts.
metadata: Optional metadata to send to the input provider,
such as user ID, channel, session context. The provider
can use this to route the question to the right recipient.
Returns:
The user's input as a string, or ``None`` on timeout, disconnect,
or provider error. Empty string ``""`` means the user pressed
Enter without typing (intentional empty input).
Example:
```python
class MyFlow(Flow):
@start()
def gather_info(self):
topic = self.ask(
"What topic should we research?",
metadata={"user_id": "u123", "channel": "#research"},
)
if topic is None:
return "No input received"
return topic
```
"""
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
from datetime import datetime
from crewai.events.types.flow_events import (
FlowInputReceivedEvent,
FlowInputRequestedEvent,
)
from crewai.flow.flow_context import current_flow_method_name
from crewai.flow.input_provider import InputResponse
method_name = current_flow_method_name.get("unknown")
# Emit input requested event
crewai_event_bus.emit(
self,
FlowInputRequestedEvent(
type="flow_input_requested",
flow_name=self.name or self.__class__.__name__,
method_name=method_name,
message=message,
metadata=metadata,
),
)
# Auto-checkpoint state before waiting
self._checkpoint_state_for_ask()
provider = self._resolve_input_provider()
raw: str | InputResponse | None = None
try:
if timeout is not None:
# Manual executor management to avoid shutdown(wait=True)
# deadlock when the provider call outlives the timeout.
executor = ThreadPoolExecutor(max_workers=1)
future = executor.submit(
provider.request_input, message, self, metadata
)
try:
raw = future.result(timeout=timeout)
except FuturesTimeoutError:
future.cancel()
raw = None
finally:
# wait=False so we don't block if the provider is still
# running (e.g. input() stuck waiting for user).
# cancel_futures=True cleans up any queued-but-not-started tasks.
executor.shutdown(wait=False, cancel_futures=True)
else:
raw = provider.request_input(message, self, metadata=metadata)
except KeyboardInterrupt:
raise
except Exception:
logger.debug("Input provider error in ask()", exc_info=True)
raw = None
# Normalize provider response: str, InputResponse, or None
response: str | None = None
response_metadata: dict[str, Any] | None = None
if isinstance(raw, InputResponse):
response = raw.text
response_metadata = raw.metadata
elif isinstance(raw, str):
response = raw
else:
response = None
# Record in history
self._input_history.append({
"message": message,
"response": response,
"method_name": method_name,
"timestamp": datetime.now(),
"metadata": metadata,
"response_metadata": response_metadata,
})
# Emit input received event
crewai_event_bus.emit(
self,
FlowInputReceivedEvent(
type="flow_input_received",
flow_name=self.name or self.__class__.__name__,
method_name=method_name,
message=message,
response=response,
metadata=metadata,
response_metadata=response_metadata,
),
)
return response
def _request_human_feedback(
self,
message: str,

View File

@@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from crewai.flow.async_feedback.types import HumanFeedbackProvider
from crewai.flow.input_provider import InputProvider
class FlowConfig:
@@ -20,10 +21,15 @@ class FlowConfig:
hitl_provider: The human-in-the-loop feedback provider.
Defaults to None (uses console input).
Can be overridden by deployments at startup.
input_provider: The input provider used by ``Flow.ask()``.
Defaults to None (uses ``ConsoleProvider``).
Can be overridden by
deployments at startup.
"""
def __init__(self) -> None:
self._hitl_provider: HumanFeedbackProvider | None = None
self._input_provider: InputProvider | None = None
@property
def hitl_provider(self) -> Any:
@@ -35,6 +41,32 @@ class FlowConfig:
"""Set the HITL provider."""
self._hitl_provider = provider
@property
def input_provider(self) -> Any:
"""Get the configured input provider for ``Flow.ask()``.
Returns:
The configured InputProvider instance, or None if not set
(in which case ``ConsoleInputProvider`` is used as default).
"""
return self._input_provider
@input_provider.setter
def input_provider(self, provider: Any) -> None:
"""Set the input provider for ``Flow.ask()``.
Args:
provider: An object implementing the ``InputProvider`` protocol.
Example:
```python
from crewai.flow import flow_config
flow_config.input_provider = WebSocketInputProvider(...)
```
"""
self._input_provider = provider
# Singleton instance
flow_config = FlowConfig()

View File

@@ -14,3 +14,7 @@ current_flow_request_id: contextvars.ContextVar[str | None] = contextvars.Contex
current_flow_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"flow_id", default=None
)
current_flow_method_name: contextvars.ContextVar[str] = contextvars.ContextVar(
"flow_method_name", default="unknown"
)

View File

@@ -0,0 +1,151 @@
"""Input provider protocol for Flow.ask().
This module provides the InputProvider protocol and InputResponse dataclass
used by Flow.ask() to request input from users during flow execution.
The default implementation is ``ConsoleProvider`` (from
``crewai.flow.async_feedback.providers``), which serves both feedback
and input collection via console.
Example (default console input):
```python
from crewai.flow import Flow, start
class MyFlow(Flow):
@start()
def gather_info(self):
topic = self.ask("What topic should we research?")
return topic
```
Example (custom provider with metadata):
```python
from crewai.flow import Flow, start
from crewai.flow.input_provider import InputProvider, InputResponse
class SlackProvider:
def request_input(self, message, flow, metadata=None):
channel = metadata.get("channel", "#general") if metadata else "#general"
thread = self.post_question(channel, message)
reply = self.wait_for_reply(thread)
return InputResponse(
text=reply.text,
metadata={"responded_by": reply.user_id, "thread_id": thread.id},
)
class MyFlow(Flow):
input_provider = SlackProvider()
@start()
def gather_info(self):
topic = self.ask("What topic?", metadata={"channel": "#research"})
return topic
```
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
if TYPE_CHECKING:
from crewai.flow.flow import Flow
@dataclass
class InputResponse:
"""Response from an InputProvider, optionally carrying metadata.
Simple providers can just return a string from ``request_input()``.
Providers that need to send metadata back (e.g., who responded,
thread ID, external timestamps) return an ``InputResponse`` instead.
``ask()`` normalizes both cases -- callers always get ``str | None``.
The response metadata is stored in ``_input_history`` and emitted
in ``FlowInputReceivedEvent``.
Attributes:
text: The user's input text, or None if unavailable.
metadata: Optional metadata from the provider about the response
(e.g., who responded, thread ID, timestamps).
Example:
```python
class MyProvider:
def request_input(self, message, flow, metadata=None):
response = get_response_from_external_system(message)
return InputResponse(
text=response.text,
metadata={"responded_by": response.user_id},
)
```
"""
text: str | None
metadata: dict[str, Any] | None = field(default=None)
@runtime_checkable
class InputProvider(Protocol):
"""Protocol for user input collection strategies.
Implement this protocol to create custom input providers that integrate
with external systems like websockets, web UIs, Slack, or custom APIs.
The default provider is ``ConsoleProvider``, which blocks waiting for
console input via Python's built-in ``input()`` function.
Providers are always synchronous. The flow framework runs sync methods
in a thread pool (via ``asyncio.to_thread``), so ``ask()`` never blocks
the event loop even inside async flow methods.
Providers can return either:
- ``str | None`` for simple cases (no response metadata)
- ``InputResponse`` when they need to send metadata back with the answer
Example (simple):
```python
class SimpleProvider:
def request_input(self, message: str, flow: Flow) -> str | None:
return input(message)
```
Example (with metadata):
```python
class SlackProvider:
def request_input(self, message, flow, metadata=None):
channel = metadata.get("channel") if metadata else "#general"
reply = self.post_and_wait(channel, message)
return InputResponse(
text=reply.text,
metadata={"responded_by": reply.user_id},
)
```
"""
def request_input(
self,
message: str,
flow: Flow[Any],
metadata: dict[str, Any] | None = None,
) -> str | InputResponse | None:
"""Request input from the user.
Args:
message: The question or prompt to display to the user.
flow: The Flow instance requesting input. Can be used to
access flow state, name, or other context.
metadata: Optional metadata from the caller, such as user ID,
channel, session context, etc. Providers can use this to
route the question to the right recipient.
Returns:
The user's input as a string, an ``InputResponse`` with text
and optional response metadata, or None if input is unavailable
(e.g., user cancelled, connection dropped).
"""
...

View File

@@ -4,6 +4,7 @@ This module contains TypedDict definitions and type aliases used throughout
the Flow system.
"""
from datetime import datetime
from typing import (
Annotated,
Any,
@@ -101,6 +102,30 @@ class FlowData(TypedDict):
flow_methods_attributes: list[FlowMethodData]
class InputHistoryEntry(TypedDict):
"""A single entry in the flow's input history from ``self.ask()``.
Each call to ``Flow.ask()`` appends one entry recording the question,
the user's response, which method asked, and any metadata exchanged
between the caller and the input provider.
Attributes:
message: The question or prompt that was displayed to the user.
response: The user's response, or None on timeout/error.
method_name: The flow method that called ``ask()``.
timestamp: When the input was received.
metadata: Metadata sent with the question (caller to provider).
response_metadata: Metadata received with the answer (provider to caller).
"""
message: str
response: str | None
method_name: str
timestamp: datetime
metadata: dict[str, Any] | None
response_metadata: dict[str, Any] | None
class FlowExecutionData(TypedDict):
"""Flow execution data.

View File

@@ -66,7 +66,9 @@ def mock_crew():
def mock_get_crews(mock_crew):
with mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
) as mock_get_crew:
) as mock_get_crew, mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[]
):
yield mock_get_crew
@@ -193,6 +195,79 @@ def test_reset_memory_from_many_crews(mock_get_crews, runner):
assert call_count == 2, "reset_memories should have been called twice"
@pytest.fixture
def mock_flow():
_mock = mock.Mock()
_mock.name = "TestFlow"
_mock.memory = mock.Mock()
_mock.memory.reset = mock.Mock()
return _mock
@pytest.fixture
def mock_get_flows(mock_flow):
with mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
) as mock_get_flow, mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[]
):
yield mock_get_flow
def test_reset_flow_memory(mock_get_flows, mock_flow, runner):
result = runner.invoke(reset_memories, ["-m"])
mock_flow.memory.reset.assert_called_once()
assert "[Flow (TestFlow)] Memory has been reset." in result.output
def test_reset_flow_all_memories(mock_get_flows, mock_flow, runner):
result = runner.invoke(reset_memories, ["-a"])
mock_flow.memory.reset.assert_called_once()
assert "[Flow (TestFlow)] Reset memories command has been completed." in result.output
def test_reset_flow_knowledge_no_effect(mock_get_flows, mock_flow, runner):
result = runner.invoke(reset_memories, ["--knowledge"])
mock_flow.memory.reset.assert_not_called()
assert "[Flow (TestFlow)]" not in result.output
def test_reset_no_crew_or_flow_found(runner):
with mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[]
), mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[]
):
result = runner.invoke(reset_memories, ["-m"])
assert "No crew or flow found." in result.output
def test_reset_crew_and_flow_memory(mock_crew, mock_flow, runner):
with mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew]
), mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
):
result = runner.invoke(reset_memories, ["-m"])
mock_crew.reset_memories.assert_called_once_with(command_type="memory")
mock_flow.memory.reset.assert_called_once()
assert f"[Crew ({mock_crew.name})] Memory has been reset." in result.output
assert "[Flow (TestFlow)] Memory has been reset." in result.output
def test_reset_flow_memory_none(runner):
mock_flow = mock.Mock()
mock_flow.name = "NoMemFlow"
mock_flow.memory = None
with mock.patch(
"crewai.cli.reset_memories_command.get_crews", return_value=[]
), mock.patch(
"crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow]
):
result = runner.invoke(reset_memories, ["-m"])
assert "[Flow (NoMemFlow)] Memory has been reset." in result.output
def test_reset_no_memory_flags(runner):
result = runner.invoke(
reset_memories,

File diff suppressed because it is too large Load Diff