From 84d57c7a24614373d7f845af6d59d1c4e1df094d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Moura?= Date: Mon, 16 Feb 2026 13:41:03 -0800 Subject: [PATCH] Implement user input handling in Flows (#4490) * Implement user input handling in Flow class --- .../src/crewai/cli/reset_memories_command.py | 48 +- lib/crewai/src/crewai/cli/utils.py | 103 ++ .../src/crewai/events/types/flow_events.py | 46 + lib/crewai/src/crewai/flow/__init__.py | 3 + .../crewai/flow/async_feedback/providers.py | 85 +- lib/crewai/src/crewai/flow/flow.py | 226 +++- lib/crewai/src/crewai/flow/flow_config.py | 32 + lib/crewai/src/crewai/flow/flow_context.py | 4 + lib/crewai/src/crewai/flow/input_provider.py | 151 +++ lib/crewai/src/crewai/flow/types.py | 25 + lib/crewai/tests/cli/test_cli.py | 77 +- lib/crewai/tests/test_flow_ask.py | 1152 +++++++++++++++++ 12 files changed, 1929 insertions(+), 23 deletions(-) create mode 100644 lib/crewai/src/crewai/flow/input_provider.py create mode 100644 lib/crewai/tests/test_flow_ask.py diff --git a/lib/crewai/src/crewai/cli/reset_memories_command.py b/lib/crewai/src/crewai/cli/reset_memories_command.py index 5d3d73de9..85971f94f 100644 --- a/lib/crewai/src/crewai/cli/reset_memories_command.py +++ b/lib/crewai/src/crewai/cli/reset_memories_command.py @@ -2,7 +2,30 @@ import subprocess import click -from crewai.cli.utils import get_crews +from crewai.cli.utils import get_crews, get_flows +from crewai.flow import Flow + + +def _reset_flow_memory(flow: Flow) -> None: + """Reset memory for a single flow instance. + + Handles Memory, MemoryScope (both have .reset()), and MemorySlice + (delegates to the underlying ._memory). Silently succeeds when the + storage directory does not exist yet (nothing to reset). + + Args: + flow: The flow instance whose memory should be reset. 
+ """ + mem = flow.memory + if mem is None: + return + try: + if hasattr(mem, "reset"): + mem.reset() + elif hasattr(mem, "_memory") and hasattr(mem._memory, "reset"): + mem._memory.reset() + except (FileNotFoundError, OSError): + pass def reset_memories_command( @@ -12,7 +35,7 @@ def reset_memories_command( kickoff_outputs: bool, all: bool, ) -> None: - """Reset the crew memories. + """Reset the crew and flow memories. Args: memory: Whether to reset the unified memory. @@ -29,8 +52,11 @@ def reset_memories_command( return crews = get_crews() - if not crews: - raise ValueError("No crew found.") + flows = get_flows() + + if not crews and not flows: + raise ValueError("No crew or flow found.") + for crew in crews: if all: crew.reset_memories(command_type="all") @@ -59,6 +85,20 @@ def reset_memories_command( f"[Crew ({crew.name if crew.name else crew.id})] Agents knowledge has been reset." ) + for flow in flows: + flow_name = flow.name or flow.__class__.__name__ + if all: + _reset_flow_memory(flow) + click.echo( + f"[Flow ({flow_name})] Reset memories command has been completed." + ) + continue + if memory: + _reset_flow_memory(flow) + click.echo( + f"[Flow ({flow_name})] Memory has been reset." + ) + except subprocess.CalledProcessError as e: click.echo(f"An error occurred while resetting the memories: {e}", err=True) click.echo(e.output, err=True) diff --git a/lib/crewai/src/crewai/cli/utils.py b/lib/crewai/src/crewai/cli/utils.py index b73f9f76b..6ee181ea1 100644 --- a/lib/crewai/src/crewai/cli/utils.py +++ b/lib/crewai/src/crewai/cli/utils.py @@ -386,6 +386,109 @@ def fetch_crews(module_attr: Any) -> list[Crew]: return crew_instances +def get_flow_instance(module_attr: Any) -> Flow | None: + """Check if a module attribute is a user-defined Flow subclass and return an instance. + + Args: + module_attr: An attribute from a loaded module. + + Returns: + A Flow instance if the attribute is a valid user-defined Flow subclass, + None otherwise. 
+ """ + if ( + isinstance(module_attr, type) + and issubclass(module_attr, Flow) + and module_attr is not Flow + ): + try: + return module_attr() + except Exception: + return None + return None + + +_SKIP_DIRS = frozenset( + {".venv", "venv", ".git", "__pycache__", "node_modules", ".tox", ".nox"} +) + + +def get_flows(flow_path: str = "main.py") -> list[Flow]: + """Get the flow instances from project files. + + Walks the project directory looking for files matching ``flow_path`` + (default ``main.py``), loads each module, and extracts Flow subclass + instances. Directories that are clearly not user source code (virtual + environments, ``.git``, etc.) are pruned to avoid noisy import errors. + + Args: + flow_path: Filename to search for (default ``main.py``). + + Returns: + A list of discovered Flow instances. + """ + flow_instances: list[Flow] = [] + try: + current_dir = os.getcwd() + if current_dir not in sys.path: + sys.path.insert(0, current_dir) + + src_dir = os.path.join(current_dir, "src") + if os.path.isdir(src_dir) and src_dir not in sys.path: + sys.path.insert(0, src_dir) + + search_paths = [".", "src"] if os.path.isdir("src") else ["."] + + for search_path in search_paths: + for root, dirs, files in os.walk(search_path): + dirs[:] = [ + d + for d in dirs + if d not in _SKIP_DIRS and not d.startswith(".") + ] + if flow_path in files and "cli/templates" not in root: + file_os_path = os.path.join(root, flow_path) + try: + spec = importlib.util.spec_from_file_location( + "flow_module", file_os_path + ) + if not spec or not spec.loader: + continue + + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + + try: + spec.loader.exec_module(module) + + for attr_name in dir(module): + module_attr = getattr(module, attr_name) + try: + if flow_instance := get_flow_instance( + module_attr + ): + flow_instances.append(flow_instance) + except Exception: # noqa: S112 + continue + + if flow_instances: + break + + except Exception: # noqa: S112 
+ continue + + except (ImportError, AttributeError): + continue + + if flow_instances: + break + + except Exception: # noqa: S110 + pass + + return flow_instances + + def is_valid_tool(obj: Any) -> bool: from crewai.tools.base_tool import Tool diff --git a/lib/crewai/src/crewai/events/types/flow_events.py b/lib/crewai/src/crewai/events/types/flow_events.py index 826722762..3eea1bbdd 100644 --- a/lib/crewai/src/crewai/events/types/flow_events.py +++ b/lib/crewai/src/crewai/events/types/flow_events.py @@ -120,6 +120,52 @@ class FlowPlotEvent(FlowEvent): type: str = "flow_plot" +class FlowInputRequestedEvent(FlowEvent): + """Event emitted when a flow requests user input via ``Flow.ask()``. + + This event is emitted before the flow suspends waiting for user input, + allowing UI frameworks and observability tools to know when a flow + needs user interaction. + + Attributes: + flow_name: Name of the flow requesting input. + method_name: Name of the flow method that called ``ask()``. + message: The question or prompt being shown to the user. + metadata: Optional metadata sent with the question (e.g., user ID, + channel, session context). + """ + + method_name: str + message: str + metadata: dict[str, Any] | None = None + type: str = "flow_input_requested" + + +class FlowInputReceivedEvent(FlowEvent): + """Event emitted when user input is received after ``Flow.ask()``. + + This event is emitted after the user provides input (or the request + times out), allowing UI frameworks and observability tools to track + input collection. + + Attributes: + flow_name: Name of the flow that received input. + method_name: Name of the flow method that called ``ask()``. + message: The original question or prompt. + response: The user's response, or None if timed out / unavailable. + metadata: Optional metadata sent with the question. + response_metadata: Optional metadata from the provider about the + response (e.g., who responded, thread ID, timestamps). 
+ """ + + method_name: str + message: str + response: str | None = None + metadata: dict[str, Any] | None = None + response_metadata: dict[str, Any] | None = None + type: str = "flow_input_received" + + class HumanFeedbackRequestedEvent(FlowEvent): """Event emitted when human feedback is requested. diff --git a/lib/crewai/src/crewai/flow/__init__.py b/lib/crewai/src/crewai/flow/__init__.py index 2e31d9220..ec4a3ac5e 100644 --- a/lib/crewai/src/crewai/flow/__init__.py +++ b/lib/crewai/src/crewai/flow/__init__.py @@ -7,6 +7,7 @@ from crewai.flow.async_feedback import ( from crewai.flow.flow import Flow, and_, listen, or_, router, start from crewai.flow.flow_config import flow_config from crewai.flow.human_feedback import HumanFeedbackResult, human_feedback +from crewai.flow.input_provider import InputProvider, InputResponse from crewai.flow.persistence import persist from crewai.flow.visualization import ( FlowStructure, @@ -22,6 +23,8 @@ __all__ = [ "HumanFeedbackPending", "HumanFeedbackProvider", "HumanFeedbackResult", + "InputProvider", + "InputResponse", "PendingFeedbackContext", "and_", "build_flow_structure", diff --git a/lib/crewai/src/crewai/flow/async_feedback/providers.py b/lib/crewai/src/crewai/flow/async_feedback/providers.py index e86c0a747..65055d650 100644 --- a/lib/crewai/src/crewai/flow/async_feedback/providers.py +++ b/lib/crewai/src/crewai/flow/async_feedback/providers.py @@ -1,7 +1,8 @@ -"""Default provider implementations for human feedback. +"""Default provider implementations for human feedback and user input. This module provides the ConsoleProvider, which is the default synchronous -provider that collects feedback via console input. +provider that collects both feedback (for ``@human_feedback``) and user input +(for ``Flow.ask()``) via console. """ from __future__ import annotations @@ -16,20 +17,23 @@ if TYPE_CHECKING: class ConsoleProvider: - """Default synchronous console-based feedback provider. 
+ """Default synchronous console-based provider for feedback and input. This provider blocks execution and waits for console input from the user. - It displays the method output with formatting and prompts for feedback. + It serves two purposes: + + - **Feedback** (``request_feedback``): Used by ``@human_feedback`` to + display method output and collect review feedback. + - **Input** (``request_input``): Used by ``Flow.ask()`` to prompt the + user with a question and collect a response. This is the default provider used when no custom provider is specified - in the @human_feedback decorator. + in the ``@human_feedback`` decorator or on the Flow's ``input_provider``. - Example: + Example (feedback): ```python from crewai.flow.async_feedback import ConsoleProvider - - # Explicitly use console provider @human_feedback( message="Review this:", provider=ConsoleProvider(), @@ -37,9 +41,20 @@ class ConsoleProvider: def my_method(self): return "Content to review" ``` + + Example (input): + ```python + from crewai.flow import Flow, start + + class MyFlow(Flow): + @start() + def gather_info(self): + topic = self.ask("What topic should we research?") + return topic + ``` """ - def __init__(self, verbose: bool = True): + def __init__(self, verbose: bool = True) -> None: """Initialize the console provider. Args: @@ -124,3 +139,55 @@ class ConsoleProvider: finally: # Resume live updates formatter.resume_live_updates() + + def request_input( + self, + message: str, + flow: Flow[Any], + metadata: dict[str, Any] | None = None, + ) -> str | None: + """Request user input via console (blocking). + + Displays the prompt message with formatting and waits for the user + to type their response. Used by ``Flow.ask()``. + + Unlike ``request_feedback``, this method does not display an + "OUTPUT FOR REVIEW" panel or emit feedback-specific events (those + are handled by ``ask()`` itself). + + Args: + message: The question or prompt to display to the user. 
+ flow: The Flow instance requesting input. + metadata: Optional metadata from the caller. Ignored by the + console provider (console has no concept of user routing). + + Returns: + The user's input as a stripped string. Returns empty string + if user presses Enter without input. Never returns None + (console input is always available). + """ + from crewai.events.event_listener import event_listener + + # Pause live updates during human input + formatter = event_listener.formatter + formatter.pause_live_updates() + + try: + console = formatter.console + + if self.verbose: + console.print() + console.print(message, style="yellow") + console.print() + + response = input(">>> \n").strip() + else: + response = input(f"{message} ").strip() + + # Add line break after input so formatter output starts clean + console.print() + + return response + finally: + # Resume live updates + formatter.resume_live_updates() diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index c3ac1ad72..d8e74fc08 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -77,7 +77,7 @@ from crewai.flow.flow_wrappers import ( StartMethod, ) from crewai.flow.persistence.base import FlowPersistence -from crewai.flow.types import FlowExecutionData, FlowMethodName, PendingListenerKey +from crewai.flow.types import FlowExecutionData, FlowMethodName, InputHistoryEntry, PendingListenerKey from crewai.flow.utils import ( _extract_all_methods, _extract_all_methods_recursive, @@ -738,6 +738,7 @@ class Flow(Generic[T], metaclass=FlowMeta): tracing: bool | None = None stream: bool = False memory: Any = None # Memory | MemoryScope | MemorySlice | None; auto-created if not set + input_provider: Any = None # InputProvider | None; per-flow override for self.ask() def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: class _FlowGeneric(cls): # type: ignore @@ -784,6 +785,9 @@ class Flow(Generic[T], metaclass=FlowMeta): 
self._pending_feedback_context: PendingFeedbackContext | None = None self.suppress_flow_events: bool = suppress_flow_events + # User input history (for self.ask()) + self._input_history: list[InputHistoryEntry] = [] + # Initialize state with initial values self._state = self._create_initial_state() self.tracing = tracing @@ -2119,15 +2123,24 @@ class Flow(Generic[T], metaclass=FlowMeta): if future: self._event_futures.append(future) - if asyncio.iscoroutinefunction(method): - result = await method(*args, **kwargs) - else: - # Run sync methods in thread pool for isolation - # This allows Agent.kickoff() to work synchronously inside Flow methods - import contextvars + # Set method name in context so ask() can read it without + # stack inspection. Must happen before copy_context() so the + # value propagates into the thread pool for sync methods. + from crewai.flow.flow_context import current_flow_method_name - ctx = contextvars.copy_context() - result = await asyncio.to_thread(ctx.run, method, *args, **kwargs) + method_name_token = current_flow_method_name.set(method_name) + try: + if asyncio.iscoroutinefunction(method): + result = await method(*args, **kwargs) + else: + # Run sync methods in thread pool for isolation + # This allows Agent.kickoff() to work synchronously inside Flow methods + import contextvars + + ctx = contextvars.copy_context() + result = await asyncio.to_thread(ctx.run, method, *args, **kwargs) + finally: + current_flow_method_name.reset(method_name_token) # Auto-await coroutines returned from sync methods (enables AgentExecutor pattern) if asyncio.iscoroutine(result): @@ -2582,6 +2595,201 @@ class Flow(Generic[T], metaclass=FlowMeta): logger.error(f"Error executing listener {listener_name}: {e}") raise + # ── User Input (self.ask) ──────────────────────────────────────── + + def _resolve_input_provider(self) -> Any: + """Resolve the input provider using the priority chain. + + Resolution order: + 1. 
``self.input_provider`` (per-flow override)
+        2. ``flow_config.input_provider`` (global default)
+        3. ``ConsoleProvider()`` (built-in fallback)
+
+        Returns:
+            An object implementing the ``InputProvider`` protocol.
+        """
+        from crewai.flow.async_feedback.providers import ConsoleProvider
+        from crewai.flow.flow_config import flow_config
+
+        if self.input_provider is not None:
+            return self.input_provider
+        if flow_config.input_provider is not None:
+            return flow_config.input_provider
+        return ConsoleProvider()
+
+    def _checkpoint_state_for_ask(self) -> None:
+        """Auto-checkpoint flow state before waiting for user input.
+
+        If persistence is configured, saves the current state so that
+        ``self.state`` is recoverable even if the process crashes while
+        waiting for input.
+
+        This is best-effort: if persistence is not configured, this is a no-op.
+        """
+        if self._persistence is None:
+            return
+        try:
+            state_data = (
+                self._state
+                if isinstance(self._state, dict)
+                else self._state.model_dump()
+            )
+            self._persistence.save_state(
+                flow_uuid=self.flow_id,
+                method_name="_ask_checkpoint",
+                state_data=state_data,
+            )
+        except Exception:
+            logger.debug("Failed to checkpoint state before ask()", exc_info=True)
+
+    def ask(
+        self,
+        message: str,
+        timeout: float | None = None,
+        metadata: dict[str, Any] | None = None,
+    ) -> str | None:
+        """Request input from the user during flow execution.
+
+        Blocks the current thread until the user provides input or the
+        timeout expires. Works in both sync and async flow methods (the
+        flow framework runs sync methods in a thread pool via
+        ``asyncio.to_thread``, so the event loop stays free).
+
+        Timeout ensures flows always terminate. When timeout expires,
+        ``None`` is returned, enabling the pattern::
+
+            while (msg := self.ask("You: ", timeout=300)) is not None:
+                process(msg)
+
+        Before waiting for input, the current ``self.state`` is automatically
+        checkpointed to persistence (if configured) for durability. 
+ + Args: + message: The question or prompt to display to the user. + timeout: Maximum seconds to wait for input. ``None`` means + wait indefinitely. When timeout expires, returns ``None``. + Note: timeout is best-effort for the provider call -- + ``ask()`` returns ``None`` promptly, but the underlying + ``request_input()`` may continue running in a background + thread until it completes naturally. Network providers + should implement their own internal timeouts. + metadata: Optional metadata to send to the input provider, + such as user ID, channel, session context. The provider + can use this to route the question to the right recipient. + + Returns: + The user's input as a string, or ``None`` on timeout, disconnect, + or provider error. Empty string ``""`` means the user pressed + Enter without typing (intentional empty input). + + Example: + ```python + class MyFlow(Flow): + @start() + def gather_info(self): + topic = self.ask( + "What topic should we research?", + metadata={"user_id": "u123", "channel": "#research"}, + ) + if topic is None: + return "No input received" + return topic + ``` + """ + from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError + from datetime import datetime + + from crewai.events.types.flow_events import ( + FlowInputReceivedEvent, + FlowInputRequestedEvent, + ) + from crewai.flow.flow_context import current_flow_method_name + from crewai.flow.input_provider import InputResponse + + method_name = current_flow_method_name.get("unknown") + + # Emit input requested event + crewai_event_bus.emit( + self, + FlowInputRequestedEvent( + type="flow_input_requested", + flow_name=self.name or self.__class__.__name__, + method_name=method_name, + message=message, + metadata=metadata, + ), + ) + + # Auto-checkpoint state before waiting + self._checkpoint_state_for_ask() + + provider = self._resolve_input_provider() + raw: str | InputResponse | None = None + + try: + if timeout is not None: + # Manual executor 
management to avoid shutdown(wait=True) + # deadlock when the provider call outlives the timeout. + executor = ThreadPoolExecutor(max_workers=1) + future = executor.submit( + provider.request_input, message, self, metadata + ) + try: + raw = future.result(timeout=timeout) + except FuturesTimeoutError: + future.cancel() + raw = None + finally: + # wait=False so we don't block if the provider is still + # running (e.g. input() stuck waiting for user). + # cancel_futures=True cleans up any queued-but-not-started tasks. + executor.shutdown(wait=False, cancel_futures=True) + else: + raw = provider.request_input(message, self, metadata=metadata) + except KeyboardInterrupt: + raise + except Exception: + logger.debug("Input provider error in ask()", exc_info=True) + raw = None + + # Normalize provider response: str, InputResponse, or None + response: str | None = None + response_metadata: dict[str, Any] | None = None + + if isinstance(raw, InputResponse): + response = raw.text + response_metadata = raw.metadata + elif isinstance(raw, str): + response = raw + else: + response = None + + # Record in history + self._input_history.append({ + "message": message, + "response": response, + "method_name": method_name, + "timestamp": datetime.now(), + "metadata": metadata, + "response_metadata": response_metadata, + }) + + # Emit input received event + crewai_event_bus.emit( + self, + FlowInputReceivedEvent( + type="flow_input_received", + flow_name=self.name or self.__class__.__name__, + method_name=method_name, + message=message, + response=response, + metadata=metadata, + response_metadata=response_metadata, + ), + ) + + return response + def _request_human_feedback( self, message: str, diff --git a/lib/crewai/src/crewai/flow/flow_config.py b/lib/crewai/src/crewai/flow/flow_config.py index 8684cc3cf..a4a6bfbe4 100644 --- a/lib/crewai/src/crewai/flow/flow_config.py +++ b/lib/crewai/src/crewai/flow/flow_config.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any if 
TYPE_CHECKING:
     from crewai.flow.async_feedback.types import HumanFeedbackProvider
+    from crewai.flow.input_provider import InputProvider
 
 
 class FlowConfig:
@@ -20,10 +21,15 @@ class FlowConfig:
         hitl_provider: The human-in-the-loop feedback provider.
             Defaults to None (uses console input).
             Can be overridden by deployments at startup.
+        input_provider: The input provider used by ``Flow.ask()``.
+            Defaults to None (uses ``ConsoleProvider``).
+            Can be overridden by
+            deployments at startup.
     """
 
     def __init__(self) -> None:
         self._hitl_provider: HumanFeedbackProvider | None = None
+        self._input_provider: InputProvider | None = None
 
     @property
     def hitl_provider(self) -> Any:
@@ -35,6 +41,32 @@ class FlowConfig:
         """Set the HITL provider."""
         self._hitl_provider = provider
 
+    @property
+    def input_provider(self) -> Any:
+        """Get the configured input provider for ``Flow.ask()``.
+
+        Returns:
+            The configured InputProvider instance, or None if not set
+            (in which case ``ConsoleProvider`` is used as default).
+        """
+        return self._input_provider
+
+    @input_provider.setter
+    def input_provider(self, provider: Any) -> None:
+        """Set the input provider for ``Flow.ask()``.
+
+        Args:
+            provider: An object implementing the ``InputProvider`` protocol.
+
+        Example:
+            ```python
+            from crewai.flow import flow_config
+
+            flow_config.input_provider = WebSocketInputProvider(...) 
+ ``` + """ + self._input_provider = provider + # Singleton instance flow_config = FlowConfig() diff --git a/lib/crewai/src/crewai/flow/flow_context.py b/lib/crewai/src/crewai/flow/flow_context.py index ae9bd69f9..0ff6cf973 100644 --- a/lib/crewai/src/crewai/flow/flow_context.py +++ b/lib/crewai/src/crewai/flow/flow_context.py @@ -14,3 +14,7 @@ current_flow_request_id: contextvars.ContextVar[str | None] = contextvars.Contex current_flow_id: contextvars.ContextVar[str | None] = contextvars.ContextVar( "flow_id", default=None ) + +current_flow_method_name: contextvars.ContextVar[str] = contextvars.ContextVar( + "flow_method_name", default="unknown" +) diff --git a/lib/crewai/src/crewai/flow/input_provider.py b/lib/crewai/src/crewai/flow/input_provider.py new file mode 100644 index 000000000..20799abbe --- /dev/null +++ b/lib/crewai/src/crewai/flow/input_provider.py @@ -0,0 +1,151 @@ +"""Input provider protocol for Flow.ask(). + +This module provides the InputProvider protocol and InputResponse dataclass +used by Flow.ask() to request input from users during flow execution. + +The default implementation is ``ConsoleProvider`` (from +``crewai.flow.async_feedback.providers``), which serves both feedback +and input collection via console. 
+ +Example (default console input): + ```python + from crewai.flow import Flow, start + + + class MyFlow(Flow): + @start() + def gather_info(self): + topic = self.ask("What topic should we research?") + return topic + ``` + +Example (custom provider with metadata): + ```python + from crewai.flow import Flow, start + from crewai.flow.input_provider import InputProvider, InputResponse + + + class SlackProvider: + def request_input(self, message, flow, metadata=None): + channel = metadata.get("channel", "#general") if metadata else "#general" + thread = self.post_question(channel, message) + reply = self.wait_for_reply(thread) + return InputResponse( + text=reply.text, + metadata={"responded_by": reply.user_id, "thread_id": thread.id}, + ) + + + class MyFlow(Flow): + input_provider = SlackProvider() + + @start() + def gather_info(self): + topic = self.ask("What topic?", metadata={"channel": "#research"}) + return topic + ``` +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + + +if TYPE_CHECKING: + from crewai.flow.flow import Flow + + +@dataclass +class InputResponse: + """Response from an InputProvider, optionally carrying metadata. + + Simple providers can just return a string from ``request_input()``. + Providers that need to send metadata back (e.g., who responded, + thread ID, external timestamps) return an ``InputResponse`` instead. + + ``ask()`` normalizes both cases -- callers always get ``str | None``. + The response metadata is stored in ``_input_history`` and emitted + in ``FlowInputReceivedEvent``. + + Attributes: + text: The user's input text, or None if unavailable. + metadata: Optional metadata from the provider about the response + (e.g., who responded, thread ID, timestamps). 
+ + Example: + ```python + class MyProvider: + def request_input(self, message, flow, metadata=None): + response = get_response_from_external_system(message) + return InputResponse( + text=response.text, + metadata={"responded_by": response.user_id}, + ) + ``` + """ + + text: str | None + metadata: dict[str, Any] | None = field(default=None) + + +@runtime_checkable +class InputProvider(Protocol): + """Protocol for user input collection strategies. + + Implement this protocol to create custom input providers that integrate + with external systems like websockets, web UIs, Slack, or custom APIs. + + The default provider is ``ConsoleProvider``, which blocks waiting for + console input via Python's built-in ``input()`` function. + + Providers are always synchronous. The flow framework runs sync methods + in a thread pool (via ``asyncio.to_thread``), so ``ask()`` never blocks + the event loop even inside async flow methods. + + Providers can return either: + - ``str | None`` for simple cases (no response metadata) + - ``InputResponse`` when they need to send metadata back with the answer + + Example (simple): + ```python + class SimpleProvider: + def request_input(self, message: str, flow: Flow) -> str | None: + return input(message) + ``` + + Example (with metadata): + ```python + class SlackProvider: + def request_input(self, message, flow, metadata=None): + channel = metadata.get("channel") if metadata else "#general" + reply = self.post_and_wait(channel, message) + return InputResponse( + text=reply.text, + metadata={"responded_by": reply.user_id}, + ) + ``` + """ + + def request_input( + self, + message: str, + flow: Flow[Any], + metadata: dict[str, Any] | None = None, + ) -> str | InputResponse | None: + """Request input from the user. + + Args: + message: The question or prompt to display to the user. + flow: The Flow instance requesting input. Can be used to + access flow state, name, or other context. 
+ metadata: Optional metadata from the caller, such as user ID, + channel, session context, etc. Providers can use this to + route the question to the right recipient. + + Returns: + The user's input as a string, an ``InputResponse`` with text + and optional response metadata, or None if input is unavailable + (e.g., user cancelled, connection dropped). + """ + ... diff --git a/lib/crewai/src/crewai/flow/types.py b/lib/crewai/src/crewai/flow/types.py index 024de41df..65ed3a995 100644 --- a/lib/crewai/src/crewai/flow/types.py +++ b/lib/crewai/src/crewai/flow/types.py @@ -4,6 +4,7 @@ This module contains TypedDict definitions and type aliases used throughout the Flow system. """ +from datetime import datetime from typing import ( Annotated, Any, @@ -101,6 +102,30 @@ class FlowData(TypedDict): flow_methods_attributes: list[FlowMethodData] +class InputHistoryEntry(TypedDict): + """A single entry in the flow's input history from ``self.ask()``. + + Each call to ``Flow.ask()`` appends one entry recording the question, + the user's response, which method asked, and any metadata exchanged + between the caller and the input provider. + + Attributes: + message: The question or prompt that was displayed to the user. + response: The user's response, or None on timeout/error. + method_name: The flow method that called ``ask()``. + timestamp: When the input was received. + metadata: Metadata sent with the question (caller to provider). + response_metadata: Metadata received with the answer (provider to caller). + """ + + message: str + response: str | None + method_name: str + timestamp: datetime + metadata: dict[str, Any] | None + response_metadata: dict[str, Any] | None + + class FlowExecutionData(TypedDict): """Flow execution data. 
diff --git a/lib/crewai/tests/cli/test_cli.py b/lib/crewai/tests/cli/test_cli.py index 529f5ded7..ed74a6036 100644 --- a/lib/crewai/tests/cli/test_cli.py +++ b/lib/crewai/tests/cli/test_cli.py @@ -66,7 +66,9 @@ def mock_crew(): def mock_get_crews(mock_crew): with mock.patch( "crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew] - ) as mock_get_crew: + ) as mock_get_crew, mock.patch( + "crewai.cli.reset_memories_command.get_flows", return_value=[] + ): yield mock_get_crew @@ -193,6 +195,79 @@ def test_reset_memory_from_many_crews(mock_get_crews, runner): assert call_count == 2, "reset_memories should have been called twice" +@pytest.fixture +def mock_flow(): + _mock = mock.Mock() + _mock.name = "TestFlow" + _mock.memory = mock.Mock() + _mock.memory.reset = mock.Mock() + return _mock + + +@pytest.fixture +def mock_get_flows(mock_flow): + with mock.patch( + "crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow] + ) as mock_get_flow, mock.patch( + "crewai.cli.reset_memories_command.get_crews", return_value=[] + ): + yield mock_get_flow + + +def test_reset_flow_memory(mock_get_flows, mock_flow, runner): + result = runner.invoke(reset_memories, ["-m"]) + mock_flow.memory.reset.assert_called_once() + assert "[Flow (TestFlow)] Memory has been reset." in result.output + + +def test_reset_flow_all_memories(mock_get_flows, mock_flow, runner): + result = runner.invoke(reset_memories, ["-a"]) + mock_flow.memory.reset.assert_called_once() + assert "[Flow (TestFlow)] Reset memories command has been completed." 
in result.output + + +def test_reset_flow_knowledge_no_effect(mock_get_flows, mock_flow, runner): + result = runner.invoke(reset_memories, ["--knowledge"]) + mock_flow.memory.reset.assert_not_called() + assert "[Flow (TestFlow)]" not in result.output + + +def test_reset_no_crew_or_flow_found(runner): + with mock.patch( + "crewai.cli.reset_memories_command.get_crews", return_value=[] + ), mock.patch( + "crewai.cli.reset_memories_command.get_flows", return_value=[] + ): + result = runner.invoke(reset_memories, ["-m"]) + assert "No crew or flow found." in result.output + + +def test_reset_crew_and_flow_memory(mock_crew, mock_flow, runner): + with mock.patch( + "crewai.cli.reset_memories_command.get_crews", return_value=[mock_crew] + ), mock.patch( + "crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow] + ): + result = runner.invoke(reset_memories, ["-m"]) + mock_crew.reset_memories.assert_called_once_with(command_type="memory") + mock_flow.memory.reset.assert_called_once() + assert f"[Crew ({mock_crew.name})] Memory has been reset." in result.output + assert "[Flow (TestFlow)] Memory has been reset." in result.output + + +def test_reset_flow_memory_none(runner): + mock_flow = mock.Mock() + mock_flow.name = "NoMemFlow" + mock_flow.memory = None + with mock.patch( + "crewai.cli.reset_memories_command.get_crews", return_value=[] + ), mock.patch( + "crewai.cli.reset_memories_command.get_flows", return_value=[mock_flow] + ): + result = runner.invoke(reset_memories, ["-m"]) + assert "[Flow (NoMemFlow)] Memory has been reset." in result.output + + def test_reset_no_memory_flags(runner): result = runner.invoke( reset_memories, diff --git a/lib/crewai/tests/test_flow_ask.py b/lib/crewai/tests/test_flow_ask.py new file mode 100644 index 000000000..d198e261c --- /dev/null +++ b/lib/crewai/tests/test_flow_ask.py @@ -0,0 +1,1152 @@ +"""Tests for Flow.ask() user input method. 
+ +This module tests the ask() method on Flow, including basic usage, +timeout behavior, provider resolution, event emission, auto-checkpoint +durability, input history tracking, and integration with flow machinery. +""" + +from __future__ import annotations + +import time +from datetime import datetime +from typing import Any +from unittest.mock import MagicMock, patch + +from crewai.flow import Flow, flow_config, listen, start +from crewai.flow.async_feedback.providers import ConsoleProvider +from crewai.flow.flow import FlowState +from crewai.flow.input_provider import InputProvider, InputResponse + + +# ── Test helpers ───────────────────────────────────────────────── + + +class MockInputProvider: + """Mock input provider that returns pre-configured responses.""" + + def __init__(self, responses: list[str | None]) -> None: + self.responses = responses + self._call_count = 0 + self.messages: list[str] = [] + self.received_metadata: list[dict[str, Any] | None] = [] + + def request_input( + self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None + ) -> str | None: + self.messages.append(message) + self.received_metadata.append(metadata) + if self._call_count >= len(self.responses): + return None + response = self.responses[self._call_count] + self._call_count += 1 + return response + + +class SlowMockProvider: + """Mock provider that delays before returning, for timeout tests.""" + + def __init__(self, delay: float, response: str = "delayed") -> None: + self.delay = delay + self.response = response + + def request_input( + self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None + ) -> str | None: + time.sleep(self.delay) + return self.response + + +# ── Basic Functionality ────────────────────────────────────────── + + +class TestAskBasic: + """Tests for basic ask() functionality.""" + + def test_ask_returns_user_input(self) -> None: + """ask() returns the string from the input provider.""" + + class TestFlow(Flow): + 
input_provider = MockInputProvider(["hello"]) + + @start() + def my_method(self): + return self.ask("Say something:") + + flow = TestFlow() + result = flow.kickoff() + assert result == "hello" + + def test_ask_in_async_method(self) -> None: + """ask() works inside an async flow method.""" + + class TestFlow(Flow): + input_provider = MockInputProvider(["async hello"]) + + @start() + async def my_method(self): + return self.ask("Say something:") + + flow = TestFlow() + result = flow.kickoff() + assert result == "async hello" + + def test_ask_in_start_method(self) -> None: + """ask() works inside a @start() method, flow completes normally.""" + execution_log: list[str] = [] + + class TestFlow(Flow): + input_provider = MockInputProvider(["AI"]) + + @start() + def gather(self): + topic = self.ask("Topic?") + execution_log.append(f"got:{topic}") + return topic + + flow = TestFlow() + result = flow.kickoff() + assert result == "AI" + assert execution_log == ["got:AI"] + + def test_ask_in_listen_method(self) -> None: + """ask() works inside a @listen() method.""" + + class TestFlow(Flow): + input_provider = MockInputProvider(["detailed"]) + + @start() + def step1(self): + return "topic" + + @listen("step1") + def step2(self): + depth = self.ask("How deep?") + return f"researching at {depth} level" + + flow = TestFlow() + result = flow.kickoff() + assert result == "researching at detailed level" + + def test_ask_multiple_calls(self) -> None: + """Multiple ask() calls in one method return correct values in order.""" + + class TestFlow(Flow): + input_provider = MockInputProvider(["AI", "detailed", "english"]) + + @start() + def gather(self): + topic = self.ask("Topic?") + depth = self.ask("Depth?") + lang = self.ask("Language?") + return {"topic": topic, "depth": depth, "lang": lang} + + flow = TestFlow() + result = flow.kickoff() + assert result == {"topic": "AI", "depth": "detailed", "lang": "english"} + + def test_ask_conditional(self) -> None: + """ask() called 
conditionally based on previous answer.""" + + class TestFlow(Flow): + input_provider = MockInputProvider(["AI", "LLMs"]) + + @start() + def gather(self): + topic = self.ask("Topic?") + if topic == "AI": + focus = self.ask("Specific area?") + else: + focus = "general" + return {"topic": topic, "focus": focus} + + flow = TestFlow() + result = flow.kickoff() + assert result == {"topic": "AI", "focus": "LLMs"} + + def test_ask_returns_empty_string_on_enter(self) -> None: + """Empty string means user pressed Enter (intentional empty input).""" + + class TestFlow(Flow): + input_provider = MockInputProvider([""]) + + @start() + def my_method(self): + result = self.ask("Optional input:") + return result + + flow = TestFlow() + result = flow.kickoff() + assert result == "" + assert result is not None # Explicitly not None + + +# ── Timeout ────────────────────────────────────────────────────── + + +class TestAskTimeout: + """Tests for timeout behavior.""" + + def test_ask_timeout_returns_none(self) -> None: + """ask() returns None when timeout expires.""" + + class TestFlow(Flow): + input_provider = SlowMockProvider(delay=5.0) + + @start() + def my_method(self): + return self.ask("Question?", timeout=0.1) + + flow = TestFlow() + result = flow.kickoff() + assert result is None + + def test_ask_timeout_in_async_method(self) -> None: + """ask() timeout works inside an async flow method.""" + + class TestFlow(Flow): + input_provider = SlowMockProvider(delay=5.0) + + @start() + async def my_method(self): + return self.ask("Question?", timeout=0.1) + + flow = TestFlow() + result = flow.kickoff() + assert result is None + + def test_ask_loop_with_timeout_termination(self) -> None: + """while (msg := ask(...)) is not None pattern terminates on timeout.""" + messages_received: list[str] = [] + + class TestFlow(Flow): + input_provider = MockInputProvider(["hello", "world", None]) + + @start() + def chat(self): + while (msg := self.ask("You:")) is not None: + 
messages_received.append(msg) + return len(messages_received) + + flow = TestFlow() + result = flow.kickoff() + assert result == 2 + assert messages_received == ["hello", "world"] + + def test_ask_no_timeout_waits_indefinitely(self) -> None: + """ask() with no timeout blocks until provider returns.""" + + class TestFlow(Flow): + input_provider = MockInputProvider(["answer"]) + + @start() + def my_method(self): + return self.ask("Question?") # no timeout + + flow = TestFlow() + result = flow.kickoff() + assert result == "answer" + + +# ── Provider Resolution ────────────────────────────────────────── + + +class TestProviderResolution: + """Tests for provider resolution priority chain.""" + + def test_ask_uses_flow_level_provider(self) -> None: + """Per-flow input_provider is used when set.""" + provider = MockInputProvider(["from flow"]) + + class TestFlow(Flow): + input_provider = provider + + @start() + def my_method(self): + return self.ask("Q?") + + flow = TestFlow() + flow.kickoff() + assert provider.messages == ["Q?"] + + def test_ask_uses_global_config_provider(self) -> None: + """flow_config.input_provider is used as fallback.""" + provider = MockInputProvider(["from config"]) + + original = flow_config.input_provider + try: + flow_config.input_provider = provider + + class TestFlow(Flow): + @start() + def my_method(self): + return self.ask("Q?") + + flow = TestFlow() + result = flow.kickoff() + assert result == "from config" + assert provider.messages == ["Q?"] + finally: + flow_config.input_provider = original + + def test_ask_defaults_to_console_provider(self) -> None: + """When no provider configured, ConsoleProvider is used.""" + original = flow_config.input_provider + try: + flow_config.input_provider = None + + class TestFlow(Flow): + # No input_provider set + @start() + def my_method(self): + return self.ask("Q?") + + flow = TestFlow() + resolved = flow._resolve_input_provider() + assert isinstance(resolved, ConsoleProvider) + finally: + 
flow_config.input_provider = original + + def test_flow_provider_overrides_global(self) -> None: + """Per-flow provider takes precedence over global config.""" + flow_provider = MockInputProvider(["from flow"]) + global_provider = MockInputProvider(["from global"]) + + original = flow_config.input_provider + try: + flow_config.input_provider = global_provider + + class TestFlow(Flow): + input_provider = flow_provider + + @start() + def my_method(self): + return self.ask("Q?") + + flow = TestFlow() + result = flow.kickoff() + assert result == "from flow" + assert flow_provider.messages == ["Q?"] + assert global_provider.messages == [] # not called + finally: + flow_config.input_provider = original + + +# ── Events ─────────────────────────────────────────────────────── + + +class TestAskEvents: + """Tests for event emission during ask().""" + + def test_ask_emits_input_requested_event(self) -> None: + """FlowInputRequestedEvent is emitted when ask() is called.""" + from crewai.events.event_bus import crewai_event_bus + from crewai.events.types.flow_events import FlowInputRequestedEvent + + events_captured: list[FlowInputRequestedEvent] = [] + + class TestFlow(Flow): + input_provider = MockInputProvider(["answer"]) + + @start() + def my_method(self): + return self.ask("What topic?") + + flow = TestFlow() + + original_emit = crewai_event_bus.emit + + def capture_emit(source: Any, event: Any) -> Any: + if isinstance(event, FlowInputRequestedEvent): + events_captured.append(event) + return original_emit(source, event) + + with patch.object(crewai_event_bus, "emit", side_effect=capture_emit): + flow.kickoff() + + assert len(events_captured) == 1 + assert events_captured[0].message == "What topic?" 
+ assert events_captured[0].type == "flow_input_requested" + + def test_ask_emits_input_received_event(self) -> None: + """FlowInputReceivedEvent is emitted after input is received.""" + from crewai.events.event_bus import crewai_event_bus + from crewai.events.types.flow_events import FlowInputReceivedEvent + + events_captured: list[FlowInputReceivedEvent] = [] + + class TestFlow(Flow): + input_provider = MockInputProvider(["my answer"]) + + @start() + def my_method(self): + return self.ask("Question?") + + flow = TestFlow() + + original_emit = crewai_event_bus.emit + + def capture_emit(source: Any, event: Any) -> Any: + if isinstance(event, FlowInputReceivedEvent): + events_captured.append(event) + return original_emit(source, event) + + with patch.object(crewai_event_bus, "emit", side_effect=capture_emit): + flow.kickoff() + + assert len(events_captured) == 1 + assert events_captured[0].message == "Question?" + assert events_captured[0].response == "my answer" + assert events_captured[0].type == "flow_input_received" + + def test_ask_timeout_emits_received_with_none(self) -> None: + """FlowInputReceivedEvent has response=None on timeout.""" + from crewai.events.event_bus import crewai_event_bus + from crewai.events.types.flow_events import FlowInputReceivedEvent + + events_captured: list[FlowInputReceivedEvent] = [] + + class TestFlow(Flow): + input_provider = SlowMockProvider(delay=5.0) + + @start() + def my_method(self): + return self.ask("Question?", timeout=0.1) + + flow = TestFlow() + + original_emit = crewai_event_bus.emit + + def capture_emit(source: Any, event: Any) -> Any: + if isinstance(event, FlowInputReceivedEvent): + events_captured.append(event) + return original_emit(source, event) + + with patch.object(crewai_event_bus, "emit", side_effect=capture_emit): + flow.kickoff() + + assert len(events_captured) == 1 + assert events_captured[0].response is None + + +# ── Auto-checkpoint (Durability) ───────────────────────────────── + + +class 
TestAskCheckpoint: + """Tests for auto-checkpoint durability before ask() waits.""" + + def test_ask_checkpoints_state_before_waiting(self) -> None: + """State is saved to persistence before waiting for input.""" + mock_persistence = MagicMock() + mock_persistence.load_state.return_value = None + + class TestFlow(Flow): + input_provider = MockInputProvider(["answer"]) + + @start() + def my_method(self): + self.state["important"] = "data" + return self.ask("Question?") + + flow = TestFlow(persistence=mock_persistence) + flow.kickoff() + + # Find the _ask_checkpoint call among save_state calls + checkpoint_calls = [ + c for c in mock_persistence.save_state.call_args_list + if c.kwargs.get("method_name") == "_ask_checkpoint" + or (len(c.args) >= 2 and c.args[1] == "_ask_checkpoint") + ] + assert len(checkpoint_calls) >= 1 + + def test_ask_no_checkpoint_without_persistence(self) -> None: + """No error when persistence is not configured.""" + + class TestFlow(Flow): + input_provider = MockInputProvider(["answer"]) + + @start() + def my_method(self): + return self.ask("Question?") + + flow = TestFlow() # No persistence + result = flow.kickoff() + assert result == "answer" # Works fine without persistence + + def test_state_recoverable_after_checkpoint(self) -> None: + """State set before ask() is checkpointed and recoverable. + + The auto-checkpoint happens *before* the provider is called, so + state values set prior to ask() are persisted. This means if the + server crashes while waiting for input, previously gathered data + is safe. 
+ """ + mock_persistence = MagicMock() + mock_persistence.load_state.return_value = None + + class GatherFlow(Flow): + input_provider = MockInputProvider(["AI", "detailed"]) + + @start() + def gather(self): + # First ask: nothing in state yet + topic = self.ask("Topic?") + self.state["topic"] = topic + # Second ask: state now has topic, checkpoint saves it + depth = self.ask("Depth?") + self.state["depth"] = depth + return {"topic": topic, "depth": depth} + + flow = GatherFlow(persistence=mock_persistence) + result = flow.kickoff() + assert result == {"topic": "AI", "depth": "detailed"} + + # Find the checkpoint calls + checkpoint_calls = [ + c for c in mock_persistence.save_state.call_args_list + if c.kwargs.get("method_name") == "_ask_checkpoint" + or (len(c.args) >= 2 and c.args[1] == "_ask_checkpoint") + ] + assert len(checkpoint_calls) == 2 + + # The second checkpoint (before asking "Depth?") should have topic + second_checkpoint = checkpoint_calls[1] + # state_data is the third positional arg or keyword arg + if second_checkpoint.kwargs.get("state_data"): + state_data = second_checkpoint.kwargs["state_data"] + else: + state_data = second_checkpoint.args[2] + assert state_data.get("topic") == "AI" + + +# ── Input History ──────────────────────────────────────────────── + + +class TestInputHistory: + """Tests for _input_history tracking.""" + + def test_input_history_accumulated(self) -> None: + """_input_history tracks all ask/response pairs.""" + + class TestFlow(Flow): + input_provider = MockInputProvider(["AI", "detailed"]) + + @start() + def gather(self): + self.ask("Topic?") + self.ask("Depth?") + return "done" + + flow = TestFlow() + flow.kickoff() + + assert len(flow._input_history) == 2 + assert flow._input_history[0]["message"] == "Topic?" + assert flow._input_history[0]["response"] == "AI" + assert flow._input_history[1]["message"] == "Depth?" 
+ assert flow._input_history[1]["response"] == "detailed" + + def test_input_history_includes_method_name(self) -> None: + """Input history records which method called ask().""" + + class TestFlow(Flow): + input_provider = MockInputProvider(["AI"]) + + @start() + def gather_info(self): + self.ask("Topic?") + return "done" + + flow = TestFlow() + flow.kickoff() + + assert len(flow._input_history) == 1 + assert flow._input_history[0]["method_name"] == "gather_info" + + def test_input_history_includes_timestamp(self) -> None: + """Input history records timestamps.""" + + class TestFlow(Flow): + input_provider = MockInputProvider(["AI"]) + + @start() + def my_method(self): + self.ask("Topic?") + return "done" + + flow = TestFlow() + before = datetime.now() + flow.kickoff() + after = datetime.now() + + assert len(flow._input_history) == 1 + ts = flow._input_history[0]["timestamp"] + assert isinstance(ts, datetime) + assert before <= ts <= after + + def test_input_history_records_none_on_timeout(self) -> None: + """Input history records None response on timeout.""" + + class TestFlow(Flow): + input_provider = SlowMockProvider(delay=5.0) + + @start() + def my_method(self): + self.ask("Question?", timeout=0.1) + return "done" + + flow = TestFlow() + flow.kickoff() + + assert len(flow._input_history) == 1 + assert flow._input_history[0]["response"] is None + + +# ── Integration ────────────────────────────────────────────────── + + +class TestAskIntegration: + """Integration tests for ask() with other flow features.""" + + def test_ask_works_with_listen_chain(self) -> None: + """ask() in a start method, result flows to listener.""" + execution_log: list[str] = [] + + class TestFlow(Flow): + input_provider = MockInputProvider(["AI agents"]) + + @start() + def gather(self): + topic = self.ask("Topic?") + execution_log.append(f"gathered:{topic}") + return topic + + @listen("gather") + def process(self): + execution_log.append("processing") + return "processed" + + flow = 
TestFlow() + flow.kickoff() + assert "gathered:AI agents" in execution_log + assert "processing" in execution_log + + def test_ask_with_structured_state(self) -> None: + """ask() works with Pydantic-based flow state.""" + + class ResearchState(FlowState): + topic: str = "" + depth: str = "" + + class TestFlow(Flow[ResearchState]): + initial_state = ResearchState + input_provider = MockInputProvider(["AI", "detailed"]) + + @start() + def gather(self): + self.state.topic = self.ask("Topic?") + self.state.depth = self.ask("Depth?") + return {"topic": self.state.topic, "depth": self.state.depth} + + flow = TestFlow() + result = flow.kickoff() + assert result == {"topic": "AI", "depth": "detailed"} + assert flow.state.topic == "AI" + assert flow.state.depth == "detailed" + + def test_ask_in_async_method_with_listen_chain(self) -> None: + """ask() in an async start method, result flows to listener.""" + execution_log: list[str] = [] + + class TestFlow(Flow): + input_provider = MockInputProvider(["async topic"]) + + @start() + async def gather(self): + topic = self.ask("Topic?") + execution_log.append(f"gathered:{topic}") + return topic + + @listen("gather") + def process(self): + execution_log.append("processing") + return "processed" + + flow = TestFlow() + flow.kickoff() + assert "gathered:async topic" in execution_log + assert "processing" in execution_log + + def test_ask_with_state_persistence_recovery(self) -> None: + """Ask checkpoints state so previously gathered values survive.""" + mock_persistence = MagicMock() + mock_persistence.load_state.return_value = None + + class RecoverableFlow(Flow): + input_provider = MockInputProvider(["AI", "detailed"]) + + @start() + def gather(self): + if not self.state.get("topic"): + self.state["topic"] = self.ask("Topic?") + if not self.state.get("depth"): + self.state["depth"] = self.ask("Depth?") + return { + "topic": self.state["topic"], + "depth": self.state["depth"], + } + + flow = 
RecoverableFlow(persistence=mock_persistence) + result = flow.kickoff() + assert result["topic"] == "AI" + assert result["depth"] == "detailed" + + # Verify checkpoints were made + checkpoint_calls = [ + c for c in mock_persistence.save_state.call_args_list + if c.kwargs.get("method_name") == "_ask_checkpoint" + or (len(c.args) >= 2 and c.args[1] == "_ask_checkpoint") + ] + # Two ask() calls = two checkpoints + assert len(checkpoint_calls) == 2 + + def test_ask_and_human_feedback_coexist(self) -> None: + """ask() and @human_feedback can be used in the same flow.""" + from crewai.flow import human_feedback + + class TestFlow(Flow): + input_provider = MockInputProvider(["AI"]) + + @start() + def gather(self): + topic = self.ask("Topic?") + return topic + + @listen("gather") + @human_feedback(message="Review this topic:") + def review(self): + return f"Researching: {self.state.get('_last_topic', 'unknown')}" + + flow = TestFlow() + + with patch.object(flow, "_request_human_feedback", return_value="looks good"): + flow.kickoff() + + # Flow completed with both ask and human_feedback + assert flow.last_human_feedback is not None + + def test_ask_preserves_flow_lifecycle(self) -> None: + """Flow events (started, finished) still fire normally with ask().""" + from crewai.events.event_bus import crewai_event_bus + from crewai.events.types.flow_events import ( + FlowFinishedEvent, + FlowStartedEvent, + ) + + events_seen: list[str] = [] + + class TestFlow(Flow): + input_provider = MockInputProvider(["answer"]) + + @start() + def my_method(self): + return self.ask("Q?") + + flow = TestFlow() + + original_emit = crewai_event_bus.emit + + def capture_emit(source: Any, event: Any) -> Any: + if isinstance(event, FlowStartedEvent): + events_seen.append("started") + elif isinstance(event, FlowFinishedEvent): + events_seen.append("finished") + return original_emit(source, event) + + with patch.object(crewai_event_bus, "emit", side_effect=capture_emit): + flow.kickoff() + + assert 
"started" in events_seen + assert "finished" in events_seen + + +# ── Console Provider ───────────────────────────────────────────── + + +class TestConsoleProviderInput: + """Tests for ConsoleProvider.request_input() (used by Flow.ask()).""" + + def test_console_provider_pauses_live_updates(self) -> None: + """ConsoleProvider pauses and resumes formatter live updates.""" + from crewai.events.event_listener import event_listener + + mock_formatter = MagicMock() + mock_formatter.console = MagicMock() + + provider = ConsoleProvider(verbose=True) + + with ( + patch.object(event_listener, "formatter", mock_formatter), + patch("builtins.input", return_value="test input"), + ): + result = provider.request_input("Question?", MagicMock()) + + mock_formatter.pause_live_updates.assert_called_once() + mock_formatter.resume_live_updates.assert_called_once() + assert result == "test input" + + def test_console_provider_displays_message(self) -> None: + """ConsoleProvider displays the message with Rich console.""" + from crewai.events.event_listener import event_listener + + mock_formatter = MagicMock() + mock_console = MagicMock() + mock_formatter.console = mock_console + + provider = ConsoleProvider(verbose=True) + + with ( + patch.object(event_listener, "formatter", mock_formatter), + patch("builtins.input", return_value="answer"), + ): + provider.request_input("What topic?", MagicMock()) + + # Verify the message was printed + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("What topic?" 
in c for c in print_calls) + + def test_console_provider_non_verbose(self) -> None: + """ConsoleProvider in non-verbose mode uses plain input.""" + from crewai.events.event_listener import event_listener + + mock_formatter = MagicMock() + mock_formatter.console = MagicMock() + + provider = ConsoleProvider(verbose=False) + + with ( + patch.object(event_listener, "formatter", mock_formatter), + patch("builtins.input", return_value="plain answer") as mock_input, + ): + result = provider.request_input("Q?", MagicMock()) + + assert result == "plain answer" + mock_input.assert_called_once_with("Q? ") + + def test_console_provider_strips_response(self) -> None: + """ConsoleProvider strips whitespace from response.""" + from crewai.events.event_listener import event_listener + + mock_formatter = MagicMock() + mock_formatter.console = MagicMock() + + provider = ConsoleProvider(verbose=False) + + with ( + patch.object(event_listener, "formatter", mock_formatter), + patch("builtins.input", return_value=" spaced answer "), + ): + result = provider.request_input("Q?", MagicMock()) + + assert result == "spaced answer" + + def test_console_provider_implements_protocol(self) -> None: + """ConsoleProvider satisfies the InputProvider protocol.""" + provider = ConsoleProvider() + assert isinstance(provider, InputProvider) + + +# ── InputProvider Protocol ─────────────────────────────────────── + + +class TestInputProviderProtocol: + """Tests for the InputProvider protocol.""" + + def test_custom_provider_satisfies_protocol(self) -> None: + """A class with request_input satisfies the InputProvider protocol.""" + + class MyProvider: + def request_input(self, message: str, flow: Flow[Any]) -> str | None: + return "custom" + + provider = MyProvider() + assert isinstance(provider, InputProvider) + + def test_mock_provider_satisfies_protocol(self) -> None: + """MockInputProvider satisfies the InputProvider protocol.""" + provider = MockInputProvider(["test"]) + assert isinstance(provider, 
InputProvider) + + +# ── Error Handling ─────────────────────────────────────────────── + + +class TestAskErrorHandling: + """Tests for error handling in ask().""" + + def test_ask_returns_none_on_provider_error(self) -> None: + """ask() returns None if provider raises an exception.""" + + class FailingProvider: + def request_input(self, message: str, flow: Flow[Any]) -> str | None: + raise RuntimeError("Provider failed") + + class TestFlow(Flow): + input_provider = FailingProvider() + + @start() + def my_method(self): + return self.ask("Question?") + + flow = TestFlow() + result = flow.kickoff() + assert result is None + + def test_ask_in_async_method_returns_none_on_provider_error(self) -> None: + """ask() returns None if provider raises in an async method.""" + + class FailingProvider: + def request_input(self, message: str, flow: Flow[Any]) -> str | None: + raise RuntimeError("Provider failed") + + class TestFlow(Flow): + input_provider = FailingProvider() + + @start() + async def my_method(self): + return self.ask("Question?") + + flow = TestFlow() + result = flow.kickoff() + assert result is None + + +# ── Metadata ───────────────────────────────────────────────────── + + +class TestAskMetadata: + """Tests for bidirectional metadata support in ask().""" + + def test_ask_passes_metadata_to_provider(self) -> None: + """Provider receives the metadata dict from ask().""" + provider = MockInputProvider(["answer"]) + + class TestFlow(Flow): + input_provider = provider + + @start() + def my_method(self): + return self.ask("Q?", metadata={"user_id": "u123"}) + + flow = TestFlow() + flow.kickoff() + assert provider.received_metadata == [{"user_id": "u123"}] + + def test_ask_metadata_none_by_default(self) -> None: + """Provider receives None metadata when not provided.""" + provider = MockInputProvider(["answer"]) + + class TestFlow(Flow): + input_provider = provider + + @start() + def my_method(self): + return self.ask("Q?") + + flow = TestFlow() + flow.kickoff() + 
assert provider.received_metadata == [None] + + def test_ask_provider_returns_input_response(self) -> None: + """Provider returns InputResponse with response metadata.""" + + class MetadataProvider: + def request_input( + self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None + ) -> InputResponse: + return InputResponse( + text="the answer", + metadata={"responded_by": "u456", "thread_id": "t789"}, + ) + + class TestFlow(Flow): + input_provider = MetadataProvider() + + @start() + def my_method(self): + return self.ask("Q?", metadata={"user_id": "u123"}) + + flow = TestFlow() + result = flow.kickoff() + + # ask() still returns plain string + assert result == "the answer" + + # History has both metadata dicts + assert len(flow._input_history) == 1 + entry = flow._input_history[0] + assert entry["metadata"] == {"user_id": "u123"} + assert entry["response_metadata"] == {"responded_by": "u456", "thread_id": "t789"} + + def test_ask_provider_returns_string_with_metadata_sent(self) -> None: + """Provider returns plain string; history has metadata but no response_metadata.""" + + class TestFlow(Flow): + input_provider = MockInputProvider(["answer"]) + + @start() + def my_method(self): + return self.ask("Q?", metadata={"channel": "#research"}) + + flow = TestFlow() + flow.kickoff() + + entry = flow._input_history[0] + assert entry["metadata"] == {"channel": "#research"} + assert entry["response_metadata"] is None + + def test_ask_metadata_in_requested_event(self) -> None: + """FlowInputRequestedEvent carries metadata.""" + from crewai.events.event_bus import crewai_event_bus + from crewai.events.types.flow_events import FlowInputRequestedEvent + + events_captured: list[FlowInputRequestedEvent] = [] + + class TestFlow(Flow): + input_provider = MockInputProvider(["answer"]) + + @start() + def my_method(self): + return self.ask("Q?", metadata={"user_id": "u123"}) + + flow = TestFlow() + original_emit = crewai_event_bus.emit + + def capture_emit(source: 
Any, event: Any) -> Any: + if isinstance(event, FlowInputRequestedEvent): + events_captured.append(event) + return original_emit(source, event) + + with patch.object(crewai_event_bus, "emit", side_effect=capture_emit): + flow.kickoff() + + assert len(events_captured) == 1 + assert events_captured[0].metadata == {"user_id": "u123"} + + def test_ask_metadata_in_received_event(self) -> None: + """FlowInputReceivedEvent carries both metadata and response_metadata.""" + from crewai.events.event_bus import crewai_event_bus + from crewai.events.types.flow_events import FlowInputReceivedEvent + + events_captured: list[FlowInputReceivedEvent] = [] + + class MetadataProvider: + def request_input( + self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None + ) -> InputResponse: + return InputResponse(text="answer", metadata={"responded_by": "u456"}) + + class TestFlow(Flow): + input_provider = MetadataProvider() + + @start() + def my_method(self): + return self.ask("Q?", metadata={"user_id": "u123"}) + + flow = TestFlow() + original_emit = crewai_event_bus.emit + + def capture_emit(source: Any, event: Any) -> Any: + if isinstance(event, FlowInputReceivedEvent): + events_captured.append(event) + return original_emit(source, event) + + with patch.object(crewai_event_bus, "emit", side_effect=capture_emit): + flow.kickoff() + + assert len(events_captured) == 1 + assert events_captured[0].metadata == {"user_id": "u123"} + assert events_captured[0].response_metadata == {"responded_by": "u456"} + assert events_captured[0].response == "answer" + + def test_ask_input_response_with_none_text(self) -> None: + """Provider returns InputResponse with text=None.""" + + class NoneTextProvider: + def request_input( + self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None + ) -> InputResponse: + return InputResponse(text=None, metadata={"reason": "user_declined"}) + + class TestFlow(Flow): + input_provider = NoneTextProvider() + + @start() + def 
my_method(self): + return self.ask("Q?") + + flow = TestFlow() + result = flow.kickoff() + assert result is None + + entry = flow._input_history[0] + assert entry["response"] is None + assert entry["response_metadata"] == {"reason": "user_declined"} + + def test_ask_metadata_thread_safe(self) -> None: + """Concurrent ask() calls with different metadata don't cross-contaminate.""" + import threading + + call_log: list[dict[str, Any]] = [] + log_lock = threading.Lock() + + class TrackingProvider: + def request_input( + self, message: str, flow: Flow[Any], metadata: dict[str, Any] | None = None + ) -> InputResponse: + # Small delay to increase chance of interleaving + time.sleep(0.05) + with log_lock: + call_log.append({"message": message, "metadata": metadata}) + user = metadata.get("user", "unknown") if metadata else "unknown" + return InputResponse( + text=f"answer from {user}", + metadata={"responded_by": user}, + ) + + class TestFlow(Flow): + input_provider = TrackingProvider() + + @start() + def trigger(self): + return "go" + + @listen("trigger") + def listener_a(self): + return self.ask("Question A?", metadata={"user": "alice"}) + + @listen("trigger") + def listener_b(self): + return self.ask("Question B?", metadata={"user": "bob"}) + + flow = TestFlow() + flow.kickoff() + + # Both calls should have recorded their own metadata + assert len(flow._input_history) == 2 + + alice_entry = next( + (e for e in flow._input_history if e["metadata"] and e["metadata"].get("user") == "alice"), + None, + ) + bob_entry = next( + (e for e in flow._input_history if e["metadata"] and e["metadata"].get("user") == "bob"), + None, + ) + + assert alice_entry is not None + assert alice_entry["response"] == "answer from alice" + assert alice_entry["response_metadata"] == {"responded_by": "alice"} + + assert bob_entry is not None + assert bob_entry["response"] == "answer from bob" + assert bob_entry["response_metadata"] == {"responded_by": "bob"}