Compare commits

..

6 Commits

Author SHA1 Message Date
Otto
9b20f4cd13 refactor: simplify ExecutionQueue docstrings and move test file
- Trim verbose BUG FIX docstring to concise 3-line note
- Remove redundant method docstrings (add, get, empty)
- Move test file to backend/data/ with proper pytest conventions
- Add note about ProcessPoolExecutor migration for future devs

Co-authored-by: Zamil Majdy <majdyz@users.noreply.github.com>
2026-02-08 16:11:35 +00:00
Nikhil Bhagat
a3d0f9cbd2 fix(backend): format test_execution_queue.py and remove unused variable 2025-12-14 19:37:29 +05:45
Nikhil Bhagat
02ddb51446 Added test_execution_queue.py and test the execution part and the test got passed 2025-12-14 19:05:14 +05:45
Nikhil Bhagat
750e096f15 fix(backend): replace multiprocessing.Manager().Queue() with queue.Queue()
ExecutionQueue was unnecessarily using multiprocessing.Manager().Queue() which
spawns a subprocess for IPC. Since ExecutionQueue is only accessed from threads
within the same process, queue.Queue() is sufficient and more efficient.

- Eliminates unnecessary subprocess spawning per graph execution
- Removes IPC overhead for queue operations
- Prevents potential resource leaks from Manager processes
- Improves scalability for concurrent graph executions
2025-12-14 19:04:14 +05:45
Krzysztof Czerwinski
ff5c8f324b Merge branch 'master' into dev 2025-12-12 22:26:39 +09:00
Zamil Majdy
71157bddd7 feat(backend): add agent mode support to SmartDecisionMakerBlock with autonomous tool execution loops (#11547)
## Summary

<img width="2072" height="1836" alt="image"
src="https://github.com/user-attachments/assets/9d231a77-6309-46b9-bc11-befb5d8e9fcc"
/>

**🚀 Major Feature: Agent Mode Support**

Adds autonomous agent mode to SmartDecisionMakerBlock, enabling it to
execute tools directly in loops until tasks are completed, rather than
just yielding tool calls for external execution.

##  **Key New Features**

### 🤖 **Agent Mode with Tool Execution Loops**
- **New `agent_mode_max_iterations` parameter** controls execution
behavior:
  - `0` = Traditional mode (single LLM call, yield tool calls)
  - `1+` = Agent mode with iteration limit
  - `-1` = Infinite agent mode (loop until finished)

### 🔄 **Autonomous Tool Execution**  
- **Direct tool execution** instead of yielding for external handling
- **Multi-iteration loops** with conversation state management
- **Automatic completion detection** when LLM stops making tool calls
- **Iteration limit handling** with graceful completion messages

### 🏗️ **Proper Database Operations**
- **Replace manual execution ID generation** with proper
`upsert_execution_input`/`upsert_execution_output`
- **Real NodeExecutionEntry objects** from database results
- **Proper execution status management**: QUEUED → RUNNING →
COMPLETED/FAILED

### 🔧 **Enhanced Type Safety**
- **Pydantic models** replace TypedDict: `ToolInfo` and
`ExecutionParams`
- **Runtime validation** with better error messages
- **Improved developer experience** with IDE support

## 🔧 **Technical Implementation**

### Agent Mode Flow:
```python
# Agent mode enabled with iterations
if input_data.agent_mode_max_iterations != 0:
    async for result in self._execute_tools_agent_mode(...):
        yield result  # "conversations", "finished"
    return

# Traditional mode (existing behavior)  
# Single LLM call + yield tool calls for external execution
```

### Tool Execution with Database Operations:
```python
# Before: Manual execution IDs
tool_exec_id = f"{node_exec_id}_tool_{sink_node_id}_{len(input_data)}"

# After: Proper database operations
node_exec_result, final_input_data = await db_client.upsert_execution_input(
    node_id=sink_node_id,
    graph_exec_id=execution_params.graph_exec_id,
    input_name=input_name, 
    input_data=input_value,
)
```

### Type Safety with Pydantic:
```python
# Before: Dict access prone to errors
execution_params["user_id"]  

# After: Validated model access
execution_params.user_id  # Runtime validation + IDE support
```

## 🧪 **Comprehensive Test Coverage**

- **Agent mode execution tests** with multi-iteration scenarios
- **Database operation verification** 
- **Type safety validation**
- **Backward compatibility** for traditional mode
- **Enhanced dynamic fields tests**

## 📊 **Usage Examples**

### Traditional Mode (Existing Behavior):
```python
SmartDecisionMakerBlock.Input(
    prompt="Search for keywords",
    agent_mode_max_iterations=0  # Default
)
# → Yields tool calls for external execution
```

### Agent Mode (New Feature):
```python  
SmartDecisionMakerBlock.Input(
    prompt="Complete this task using available tools",
    agent_mode_max_iterations=5  # Max 5 iterations
)
# → Executes tools directly until task completion or iteration limit
```

### Infinite Agent Mode:
```python
SmartDecisionMakerBlock.Input(
    prompt="Analyze and process this data thoroughly", 
    agent_mode_max_iterations=-1  # No limit, run until finished
)
# → Executes tools autonomously until LLM indicates completion
```

##  **Backward Compatibility**

- **Zero breaking changes** to existing functionality
- **Traditional mode remains default** (`agent_mode_max_iterations=0`)
- **All existing tests pass**
- **Same API for tool definitions and execution**

This transforms the SmartDecisionMakerBlock from a simple tool call
generator into a powerful autonomous agent capable of complex multi-step
task execution! 🎯

🤖 Generated with [Claude Code](https://claude.ai/code)

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-12-12 09:58:06 +00:00
10 changed files with 1119 additions and 114 deletions

View File

@@ -1,8 +1,11 @@
import logging import logging
import re import re
from collections import Counter from collections import Counter
from concurrent.futures import Future
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
import backend.blocks.llm as llm import backend.blocks.llm as llm
from backend.blocks.agent import AgentExecutorBlock from backend.blocks.agent import AgentExecutorBlock
from backend.data.block import ( from backend.data.block import (
@@ -20,16 +23,41 @@ from backend.data.dynamic_fields import (
is_dynamic_field, is_dynamic_field,
is_tool_pin, is_tool_pin,
) )
from backend.data.execution import ExecutionContext
from backend.data.model import NodeExecutionStats, SchemaField from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json from backend.util import json
from backend.util.clients import get_database_manager_async_client from backend.util.clients import get_database_manager_async_client
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
if TYPE_CHECKING: if TYPE_CHECKING:
from backend.data.graph import Link, Node from backend.data.graph import Link, Node
from backend.executor.manager import ExecutionProcessor
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ToolInfo(BaseModel):
"""Processed tool call information."""
tool_call: Any # The original tool call object from LLM response
tool_name: str # The function name
tool_def: dict[str, Any] # The tool definition from tool_functions
input_data: dict[str, Any] # Processed input data ready for tool execution
field_mapping: dict[str, str] # Field name mapping for the tool
class ExecutionParams(BaseModel):
"""Tool execution parameters."""
user_id: str
graph_id: str
node_id: str
graph_version: int
graph_exec_id: str
node_exec_id: str
execution_context: "ExecutionContext"
def _get_tool_requests(entry: dict[str, Any]) -> list[str]: def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
""" """
Return a list of tool_call_ids if the entry is a tool request. Return a list of tool_call_ids if the entry is a tool request.
@@ -105,6 +133,50 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
return {"role": "tool", "tool_call_id": call_id, "content": content} return {"role": "tool", "tool_call_id": call_id, "content": content}
def _combine_tool_responses(tool_outputs: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Combine multiple Anthropic tool responses into a single user message.
For non-Anthropic formats, returns the original list unchanged.
"""
if len(tool_outputs) <= 1:
return tool_outputs
# Anthropic responses have role="user", type="message", and content is a list with tool_result items
anthropic_responses = [
output
for output in tool_outputs
if (
output.get("role") == "user"
and output.get("type") == "message"
and isinstance(output.get("content"), list)
and any(
item.get("type") == "tool_result"
for item in output.get("content", [])
if isinstance(item, dict)
)
)
]
if len(anthropic_responses) > 1:
combined_content = [
item for response in anthropic_responses for item in response["content"]
]
combined_response = {
"role": "user",
"type": "message",
"content": combined_content,
}
non_anthropic_responses = [
output for output in tool_outputs if output not in anthropic_responses
]
return [combined_response] + non_anthropic_responses
return tool_outputs
def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]: def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
""" """
Safely convert raw_response to dictionary format for conversation history. Safely convert raw_response to dictionary format for conversation history.
@@ -204,6 +276,17 @@ class SmartDecisionMakerBlock(Block):
default="localhost:11434", default="localhost:11434",
description="Ollama host for local models", description="Ollama host for local models",
) )
agent_mode_max_iterations: int = SchemaField(
title="Agent Mode Max Iterations",
description="Maximum iterations for agent mode. 0 = traditional mode (single LLM call, yield tool calls for external execution), -1 = infinite agent mode (loop until finished), 1+ = agent mode with max iterations limit.",
advanced=True,
default=0,
)
conversation_compaction: bool = SchemaField(
default=True,
title="Context window auto-compaction",
description="Automatically compact the context window once it hits the limit",
)
@classmethod @classmethod
def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]: def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
@@ -506,6 +589,7 @@ class SmartDecisionMakerBlock(Block):
Returns the response if successful, raises ValueError if validation fails. Returns the response if successful, raises ValueError if validation fails.
""" """
resp = await llm.llm_call( resp = await llm.llm_call(
compress_prompt_to_fit=input_data.conversation_compaction,
credentials=credentials, credentials=credentials,
llm_model=input_data.model, llm_model=input_data.model,
prompt=current_prompt, prompt=current_prompt,
@@ -593,6 +677,291 @@ class SmartDecisionMakerBlock(Block):
return resp return resp
def _process_tool_calls(
self, response, tool_functions: list[dict[str, Any]]
) -> list[ToolInfo]:
"""Process tool calls and extract tool definitions, arguments, and input data.
Returns a list of tool info dicts with:
- tool_call: The original tool call object
- tool_name: The function name
- tool_def: The tool definition from tool_functions
- input_data: Processed input data dict (includes None values)
- field_mapping: Field name mapping for the tool
"""
if not response.tool_calls:
return []
processed_tools = []
for tool_call in response.tool_calls:
tool_name = tool_call.function.name
tool_args = json.loads(tool_call.function.arguments)
tool_def = next(
(
tool
for tool in tool_functions
if tool["function"]["name"] == tool_name
),
None,
)
if not tool_def:
if len(tool_functions) == 1:
tool_def = tool_functions[0]
else:
continue
# Build input data for the tool
input_data = {}
field_mapping = tool_def["function"].get("_field_mapping", {})
if "function" in tool_def and "parameters" in tool_def["function"]:
expected_args = tool_def["function"]["parameters"].get("properties", {})
for clean_arg_name in expected_args:
original_field_name = field_mapping.get(
clean_arg_name, clean_arg_name
)
arg_value = tool_args.get(clean_arg_name)
# Include all expected parameters, even if None (for backward compatibility with tests)
input_data[original_field_name] = arg_value
processed_tools.append(
ToolInfo(
tool_call=tool_call,
tool_name=tool_name,
tool_def=tool_def,
input_data=input_data,
field_mapping=field_mapping,
)
)
return processed_tools
def _update_conversation(
self, prompt: list[dict], response, tool_outputs: list | None = None
):
"""Update conversation history with response and tool outputs."""
# Don't add separate reasoning message with tool calls (breaks Anthropic's tool_use->tool_result pairing)
assistant_message = _convert_raw_response_to_dict(response.raw_response)
has_tool_calls = isinstance(assistant_message.get("content"), list) and any(
item.get("type") == "tool_use"
for item in assistant_message.get("content", [])
)
if response.reasoning and not has_tool_calls:
prompt.append(
{"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
)
prompt.append(assistant_message)
if tool_outputs:
prompt.extend(tool_outputs)
async def _execute_single_tool_with_manager(
self,
tool_info: ToolInfo,
execution_params: ExecutionParams,
execution_processor: "ExecutionProcessor",
) -> dict:
"""Execute a single tool using the execution manager for proper integration."""
# Lazy imports to avoid circular dependencies
from backend.data.execution import NodeExecutionEntry
tool_call = tool_info.tool_call
tool_def = tool_info.tool_def
raw_input_data = tool_info.input_data
# Get sink node and field mapping
sink_node_id = tool_def["function"]["_sink_node_id"]
# Use proper database operations for tool execution
db_client = get_database_manager_async_client()
# Get target node
target_node = await db_client.get_node(sink_node_id)
if not target_node:
raise ValueError(f"Target node {sink_node_id} not found")
# Create proper node execution using upsert_execution_input
node_exec_result = None
final_input_data = None
# Add all inputs to the execution
if not raw_input_data:
raise ValueError(f"Tool call has no input data: {tool_call}")
for input_name, input_value in raw_input_data.items():
node_exec_result, final_input_data = await db_client.upsert_execution_input(
node_id=sink_node_id,
graph_exec_id=execution_params.graph_exec_id,
input_name=input_name,
input_data=input_value,
)
assert node_exec_result is not None, "node_exec_result should not be None"
# Create NodeExecutionEntry for execution manager
node_exec_entry = NodeExecutionEntry(
user_id=execution_params.user_id,
graph_exec_id=execution_params.graph_exec_id,
graph_id=execution_params.graph_id,
graph_version=execution_params.graph_version,
node_exec_id=node_exec_result.node_exec_id,
node_id=sink_node_id,
block_id=target_node.block_id,
inputs=final_input_data or {},
execution_context=execution_params.execution_context,
)
# Use the execution manager to execute the tool node
try:
# Get NodeExecutionProgress from the execution manager's running nodes
node_exec_progress = execution_processor.running_node_execution[
sink_node_id
]
# Use the execution manager's own graph stats
graph_stats_pair = (
execution_processor.execution_stats,
execution_processor.execution_stats_lock,
)
# Create a completed future for the task tracking system
node_exec_future = Future()
node_exec_progress.add_task(
node_exec_id=node_exec_result.node_exec_id,
task=node_exec_future,
)
# Execute the node directly since we're in the SmartDecisionMaker context
node_exec_future.set_result(
await execution_processor.on_node_execution(
node_exec=node_exec_entry,
node_exec_progress=node_exec_progress,
nodes_input_masks=None,
graph_stats_pair=graph_stats_pair,
)
)
# Get outputs from database after execution completes using database manager client
node_outputs = await db_client.get_execution_outputs_by_node_exec_id(
node_exec_result.node_exec_id
)
# Create tool response
tool_response_content = (
json.dumps(node_outputs)
if node_outputs
else "Tool executed successfully"
)
return _create_tool_response(tool_call.id, tool_response_content)
except Exception as e:
logger.error(f"Tool execution with manager failed: {e}")
# Return error response
return _create_tool_response(
tool_call.id, f"Tool execution failed: {str(e)}"
)
async def _execute_tools_agent_mode(
self,
input_data,
credentials,
tool_functions: list[dict[str, Any]],
prompt: list[dict],
graph_exec_id: str,
node_id: str,
node_exec_id: str,
user_id: str,
graph_id: str,
graph_version: int,
execution_context: ExecutionContext,
execution_processor: "ExecutionProcessor",
):
"""Execute tools in agent mode with a loop until finished."""
max_iterations = input_data.agent_mode_max_iterations
iteration = 0
# Execution parameters for tool execution
execution_params = ExecutionParams(
user_id=user_id,
graph_id=graph_id,
node_id=node_id,
graph_version=graph_version,
graph_exec_id=graph_exec_id,
node_exec_id=node_exec_id,
execution_context=execution_context,
)
current_prompt = list(prompt)
while max_iterations < 0 or iteration < max_iterations:
iteration += 1
logger.debug(f"Agent mode iteration {iteration}")
# Prepare prompt for this iteration
iteration_prompt = list(current_prompt)
# On the last iteration, add a special system message to encourage completion
if max_iterations > 0 and iteration == max_iterations:
last_iteration_message = {
"role": "system",
"content": f"{MAIN_OBJECTIVE_PREFIX}This is your last iteration ({iteration}/{max_iterations}). "
"Try to complete the task with the information you have. If you cannot fully complete it, "
"provide a summary of what you've accomplished and what remains to be done. "
"Prefer finishing with a clear response rather than making additional tool calls.",
}
iteration_prompt.append(last_iteration_message)
# Get LLM response
try:
response = await self._attempt_llm_call_with_validation(
credentials, input_data, iteration_prompt, tool_functions
)
except Exception as e:
yield "error", f"LLM call failed in agent mode iteration {iteration}: {str(e)}"
return
# Process tool calls
processed_tools = self._process_tool_calls(response, tool_functions)
# If no tool calls, we're done
if not processed_tools:
yield "finished", response.response
self._update_conversation(current_prompt, response)
yield "conversations", current_prompt
return
# Execute tools and collect responses
tool_outputs = []
for tool_info in processed_tools:
try:
tool_response = await self._execute_single_tool_with_manager(
tool_info, execution_params, execution_processor
)
tool_outputs.append(tool_response)
except Exception as e:
logger.error(f"Tool execution failed: {e}")
# Create error response for the tool
error_response = _create_tool_response(
tool_info.tool_call.id, f"Error: {str(e)}"
)
tool_outputs.append(error_response)
tool_outputs = _combine_tool_responses(tool_outputs)
self._update_conversation(current_prompt, response, tool_outputs)
# Yield intermediate conversation state
yield "conversations", current_prompt
# If we reach max iterations, yield the current state
if max_iterations < 0:
yield "finished", f"Agent mode completed after {iteration} iterations"
else:
yield "finished", f"Agent mode completed after {max_iterations} iterations (limit reached)"
yield "conversations", current_prompt
async def run( async def run(
self, self,
input_data: Input, input_data: Input,
@@ -603,8 +972,12 @@ class SmartDecisionMakerBlock(Block):
graph_exec_id: str, graph_exec_id: str,
node_exec_id: str, node_exec_id: str,
user_id: str, user_id: str,
graph_version: int,
execution_context: ExecutionContext,
execution_processor: "ExecutionProcessor",
**kwargs, **kwargs,
) -> BlockOutput: ) -> BlockOutput:
tool_functions = await self._create_tool_node_signatures(node_id) tool_functions = await self._create_tool_node_signatures(node_id)
yield "tool_functions", json.dumps(tool_functions) yield "tool_functions", json.dumps(tool_functions)
@@ -648,24 +1021,52 @@ class SmartDecisionMakerBlock(Block):
input_data.prompt = llm.fmt.format_string(input_data.prompt, values) input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values) input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
prefix = "[Main Objective Prompt]: "
if input_data.sys_prompt and not any( if input_data.sys_prompt and not any(
p["role"] == "system" and p["content"].startswith(prefix) for p in prompt p["role"] == "system" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
for p in prompt
): ):
prompt.append({"role": "system", "content": prefix + input_data.sys_prompt}) prompt.append(
{
"role": "system",
"content": MAIN_OBJECTIVE_PREFIX + input_data.sys_prompt,
}
)
if input_data.prompt and not any( if input_data.prompt and not any(
p["role"] == "user" and p["content"].startswith(prefix) for p in prompt p["role"] == "user" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
for p in prompt
): ):
prompt.append({"role": "user", "content": prefix + input_data.prompt}) prompt.append(
{"role": "user", "content": MAIN_OBJECTIVE_PREFIX + input_data.prompt}
)
# Execute tools based on the selected mode
if input_data.agent_mode_max_iterations != 0:
# In agent mode, execute tools directly in a loop until finished
async for result in self._execute_tools_agent_mode(
input_data=input_data,
credentials=credentials,
tool_functions=tool_functions,
prompt=prompt,
graph_exec_id=graph_exec_id,
node_id=node_id,
node_exec_id=node_exec_id,
user_id=user_id,
graph_id=graph_id,
graph_version=graph_version,
execution_context=execution_context,
execution_processor=execution_processor,
):
yield result
return
# One-off mode: single LLM call and yield tool calls for external execution
current_prompt = list(prompt) current_prompt = list(prompt)
max_attempts = max(1, int(input_data.retry)) max_attempts = max(1, int(input_data.retry))
response = None response = None
last_error = None last_error = None
for attempt in range(max_attempts): for _ in range(max_attempts):
try: try:
response = await self._attempt_llm_call_with_validation( response = await self._attempt_llm_call_with_validation(
credentials, input_data, current_prompt, tool_functions credentials, input_data, current_prompt, tool_functions

View File

@@ -1,7 +1,11 @@
import logging import logging
import threading
from collections import defaultdict
from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from backend.data.execution import ExecutionContext
from backend.data.model import ProviderName, User from backend.data.model import ProviderName, User
from backend.server.model import CreateGraph from backend.server.model import CreateGraph
from backend.server.rest_api import AgentServer from backend.server.rest_api import AgentServer
@@ -17,10 +21,10 @@ async def create_graph(s: SpinTestServer, g, u: User):
async def create_credentials(s: SpinTestServer, u: User): async def create_credentials(s: SpinTestServer, u: User):
import backend.blocks.llm as llm import backend.blocks.llm as llm_module
provider = ProviderName.OPENAI provider = ProviderName.OPENAI
credentials = llm.TEST_CREDENTIALS credentials = llm_module.TEST_CREDENTIALS
return await s.agent_server.test_create_credentials(u.id, provider, credentials) return await s.agent_server.test_create_credentials(u.id, provider, credentials)
@@ -196,8 +200,6 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_smart_decision_maker_tracks_llm_stats(): async def test_smart_decision_maker_tracks_llm_stats():
"""Test that SmartDecisionMakerBlock correctly tracks LLM usage stats.""" """Test that SmartDecisionMakerBlock correctly tracks LLM usage stats."""
from unittest.mock import MagicMock, patch
import backend.blocks.llm as llm_module import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@@ -216,7 +218,6 @@ async def test_smart_decision_maker_tracks_llm_stats():
} }
# Mock the _create_tool_node_signatures method to avoid database calls # Mock the _create_tool_node_signatures method to avoid database calls
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", "backend.blocks.llm.llm_call",
@@ -234,10 +235,19 @@ async def test_smart_decision_maker_tracks_llm_stats():
prompt="Should I continue with this task?", prompt="Should I continue with this task?",
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
) )
# Execute the block # Execute the block
outputs = {} outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run( async for output_name, output_data in block.run(
input_data, input_data,
credentials=llm_module.TEST_CREDENTIALS, credentials=llm_module.TEST_CREDENTIALS,
@@ -246,6 +256,9 @@ async def test_smart_decision_maker_tracks_llm_stats():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
@@ -263,8 +276,6 @@ async def test_smart_decision_maker_tracks_llm_stats():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_smart_decision_maker_parameter_validation(): async def test_smart_decision_maker_parameter_validation():
"""Test that SmartDecisionMakerBlock correctly validates tool call parameters.""" """Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
from unittest.mock import MagicMock, patch
import backend.blocks.llm as llm_module import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@@ -311,8 +322,6 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_with_typo.reasoning = None mock_response_with_typo.reasoning = None
mock_response_with_typo.raw_response = {"role": "assistant", "content": None} mock_response_with_typo.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", "backend.blocks.llm.llm_call",
new_callable=AsyncMock, new_callable=AsyncMock,
@@ -329,8 +338,17 @@ async def test_smart_decision_maker_parameter_validation():
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2, # Set retry to 2 for testing retry=2, # Set retry to 2 for testing
agent_mode_max_iterations=0,
) )
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
# Should raise ValueError after retries due to typo'd parameter name # Should raise ValueError after retries due to typo'd parameter name
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
outputs = {} outputs = {}
@@ -342,6 +360,9 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
@@ -368,8 +389,6 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_missing_required.reasoning = None mock_response_missing_required.reasoning = None
mock_response_missing_required.raw_response = {"role": "assistant", "content": None} mock_response_missing_required.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", "backend.blocks.llm.llm_call",
new_callable=AsyncMock, new_callable=AsyncMock,
@@ -385,8 +404,17 @@ async def test_smart_decision_maker_parameter_validation():
prompt="Search for keywords", prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
) )
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
# Should raise ValueError due to missing required parameter # Should raise ValueError due to missing required parameter
with pytest.raises(ValueError) as exc_info: with pytest.raises(ValueError) as exc_info:
outputs = {} outputs = {}
@@ -398,6 +426,9 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
@@ -418,8 +449,6 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_valid.reasoning = None mock_response_valid.reasoning = None
mock_response_valid.raw_response = {"role": "assistant", "content": None} mock_response_valid.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", "backend.blocks.llm.llm_call",
new_callable=AsyncMock, new_callable=AsyncMock,
@@ -435,10 +464,19 @@ async def test_smart_decision_maker_parameter_validation():
prompt="Search for keywords", prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
) )
# Should succeed - optional parameter missing is OK # Should succeed - optional parameter missing is OK
outputs = {} outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run( async for output_name, output_data in block.run(
input_data, input_data,
credentials=llm_module.TEST_CREDENTIALS, credentials=llm_module.TEST_CREDENTIALS,
@@ -447,6 +485,9 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
@@ -472,8 +513,6 @@ async def test_smart_decision_maker_parameter_validation():
mock_response_all_params.reasoning = None mock_response_all_params.reasoning = None
mock_response_all_params.raw_response = {"role": "assistant", "content": None} mock_response_all_params.raw_response = {"role": "assistant", "content": None}
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", "backend.blocks.llm.llm_call",
new_callable=AsyncMock, new_callable=AsyncMock,
@@ -489,10 +528,19 @@ async def test_smart_decision_maker_parameter_validation():
prompt="Search for keywords", prompt="Search for keywords",
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
) )
# Should succeed with all parameters # Should succeed with all parameters
outputs = {} outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run( async for output_name, output_data in block.run(
input_data, input_data,
credentials=llm_module.TEST_CREDENTIALS, credentials=llm_module.TEST_CREDENTIALS,
@@ -501,6 +549,9 @@ async def test_smart_decision_maker_parameter_validation():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
@@ -513,8 +564,6 @@ async def test_smart_decision_maker_parameter_validation():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_smart_decision_maker_raw_response_conversion(): async def test_smart_decision_maker_raw_response_conversion():
"""Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism.""" """Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
from unittest.mock import MagicMock, patch
import backend.blocks.llm as llm_module import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
@@ -584,7 +633,6 @@ async def test_smart_decision_maker_raw_response_conversion():
) )
# Mock llm_call to return different responses on different calls # Mock llm_call to return different responses on different calls
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", new_callable=AsyncMock "backend.blocks.llm.llm_call", new_callable=AsyncMock
@@ -603,10 +651,19 @@ async def test_smart_decision_maker_raw_response_conversion():
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
retry=2, retry=2,
agent_mode_max_iterations=0,
) )
# Should succeed after retry, demonstrating our helper function works # Should succeed after retry, demonstrating our helper function works
outputs = {} outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run( async for output_name, output_data in block.run(
input_data, input_data,
credentials=llm_module.TEST_CREDENTIALS, credentials=llm_module.TEST_CREDENTIALS,
@@ -615,6 +672,9 @@ async def test_smart_decision_maker_raw_response_conversion():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
@@ -650,8 +710,6 @@ async def test_smart_decision_maker_raw_response_conversion():
"I'll help you with that." # Ollama returns string "I'll help you with that." # Ollama returns string
) )
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", "backend.blocks.llm.llm_call",
new_callable=AsyncMock, new_callable=AsyncMock,
@@ -666,9 +724,18 @@ async def test_smart_decision_maker_raw_response_conversion():
prompt="Simple prompt", prompt="Simple prompt",
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
) )
outputs = {} outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run( async for output_name, output_data in block.run(
input_data, input_data,
credentials=llm_module.TEST_CREDENTIALS, credentials=llm_module.TEST_CREDENTIALS,
@@ -677,6 +744,9 @@ async def test_smart_decision_maker_raw_response_conversion():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
@@ -696,8 +766,6 @@ async def test_smart_decision_maker_raw_response_conversion():
"content": "Test response", "content": "Test response",
} # Dict format } # Dict format
from unittest.mock import AsyncMock
with patch( with patch(
"backend.blocks.llm.llm_call", "backend.blocks.llm.llm_call",
new_callable=AsyncMock, new_callable=AsyncMock,
@@ -712,6 +780,160 @@ async def test_smart_decision_maker_raw_response_conversion():
prompt="Another test", prompt="Another test",
model=llm_module.LlmModel.GPT4O, model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0,
)
outputs = {}
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
assert "finished" in outputs
assert outputs["finished"] == "Test response"
@pytest.mark.asyncio
async def test_smart_decision_maker_agent_mode():
"""Test that agent mode executes tools directly and loops until finished."""
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = SmartDecisionMakerBlock()
# Mock tool call that requires multiple iterations
mock_tool_call_1 = MagicMock()
mock_tool_call_1.id = "call_1"
mock_tool_call_1.function.name = "search_keywords"
mock_tool_call_1.function.arguments = (
'{"query": "test", "max_keyword_difficulty": 50}'
)
mock_response_1 = MagicMock()
mock_response_1.response = None
mock_response_1.tool_calls = [mock_tool_call_1]
mock_response_1.prompt_tokens = 50
mock_response_1.completion_tokens = 25
mock_response_1.reasoning = "Using search tool"
mock_response_1.raw_response = {
"role": "assistant",
"content": None,
"tool_calls": [{"id": "call_1", "type": "function"}],
}
# Final response with no tool calls (finished)
mock_response_2 = MagicMock()
mock_response_2.response = "Task completed successfully"
mock_response_2.tool_calls = []
mock_response_2.prompt_tokens = 30
mock_response_2.completion_tokens = 15
mock_response_2.reasoning = None
mock_response_2.raw_response = {
"role": "assistant",
"content": "Task completed successfully",
}
# Mock the LLM call to return different responses on each iteration
llm_call_mock = AsyncMock()
llm_call_mock.side_effect = [mock_response_1, mock_response_2]
# Mock tool node signatures
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "search_keywords",
"_sink_node_id": "test-sink-node-id",
"_field_mapping": {},
"parameters": {
"properties": {
"query": {"type": "string"},
"max_keyword_difficulty": {"type": "integer"},
},
"required": ["query", "max_keyword_difficulty"],
},
},
}
]
# Mock database and execution components
mock_db_client = AsyncMock()
mock_node = MagicMock()
mock_node.block_id = "test-block-id"
mock_db_client.get_node.return_value = mock_node
# Mock upsert_execution_input to return proper NodeExecutionResult and input data
mock_node_exec_result = MagicMock()
mock_node_exec_result.node_exec_id = "test-tool-exec-id"
mock_input_data = {"query": "test", "max_keyword_difficulty": 50}
mock_db_client.upsert_execution_input.return_value = (
mock_node_exec_result,
mock_input_data,
)
# No longer need mock_execute_node since we use execution_processor.on_node_execution
with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
), patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client",
return_value=mock_db_client,
), patch(
"backend.executor.manager.async_update_node_execution_status",
new_callable=AsyncMock,
), patch(
"backend.integrations.creds_manager.IntegrationCredentialsManager"
):
# Create a mock execution context
mock_execution_context = ExecutionContext(
safe_mode=False,
)
# Create a mock execution processor for agent mode tests
mock_execution_processor = AsyncMock()
# Configure the execution processor mock with required attributes
mock_execution_processor.running_node_execution = defaultdict(MagicMock)
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = threading.Lock()
# Mock the on_node_execution method to return successful stats
mock_node_stats = MagicMock()
mock_node_stats.error = None # No error
mock_execution_processor.on_node_execution = AsyncMock(
return_value=mock_node_stats
)
# Mock the get_execution_outputs_by_node_exec_id method
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
"result": {"status": "success", "data": "search completed"}
}
# Test agent mode with max_iterations = 3
input_data = SmartDecisionMakerBlock.Input(
prompt="Complete this task using tools",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=3, # Enable agent mode with 3 max iterations
) )
outputs = {} outputs = {}
@@ -723,8 +945,115 @@ async def test_smart_decision_maker_raw_response_conversion():
graph_exec_id="test-exec-id", graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id", node_exec_id="test-node-exec-id",
user_id="test-user-id", user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_data outputs[output_name] = output_data
# Verify agent mode behavior
assert "tool_functions" in outputs # tool_functions is yielded in both modes
assert "finished" in outputs assert "finished" in outputs
assert outputs["finished"] == "Test response" assert outputs["finished"] == "Task completed successfully"
assert "conversations" in outputs
# Verify the conversation includes tool responses
conversations = outputs["conversations"]
assert len(conversations) > 2 # Should have multiple conversation entries
# Verify LLM was called twice (once for tool call, once for finish)
assert llm_call_mock.call_count == 2
# Verify tool was executed via execution processor
assert mock_execution_processor.on_node_execution.call_count == 1
@pytest.mark.asyncio
async def test_smart_decision_maker_traditional_mode_default():
"""Test that default behavior (agent_mode_max_iterations=0) works as traditional mode."""
import backend.blocks.llm as llm_module
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
block = SmartDecisionMakerBlock()
# Mock tool call
mock_tool_call = MagicMock()
mock_tool_call.function.name = "search_keywords"
mock_tool_call.function.arguments = (
'{"query": "test", "max_keyword_difficulty": 50}'
)
mock_response = MagicMock()
mock_response.response = None
mock_response.tool_calls = [mock_tool_call]
mock_response.prompt_tokens = 50
mock_response.completion_tokens = 25
mock_response.reasoning = None
mock_response.raw_response = {"role": "assistant", "content": None}
mock_tool_signatures = [
{
"type": "function",
"function": {
"name": "search_keywords",
"_sink_node_id": "test-sink-node-id",
"_field_mapping": {},
"parameters": {
"properties": {
"query": {"type": "string"},
"max_keyword_difficulty": {"type": "integer"},
},
"required": ["query", "max_keyword_difficulty"],
},
},
}
]
with patch(
"backend.blocks.llm.llm_call",
new_callable=AsyncMock,
return_value=mock_response,
), patch.object(
block, "_create_tool_node_signatures", return_value=mock_tool_signatures
):
# Test default behavior (traditional mode)
input_data = SmartDecisionMakerBlock.Input(
prompt="Test prompt",
model=llm_module.LlmModel.GPT4O,
credentials=llm_module.TEST_CREDENTIALS_INPUT, # type: ignore
agent_mode_max_iterations=0, # Traditional mode
)
# Create execution context
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a mock execution processor for tests
mock_execution_processor = MagicMock()
outputs = {}
async for output_name, output_data in block.run(
input_data,
credentials=llm_module.TEST_CREDENTIALS,
graph_id="test-graph-id",
node_id="test-node-id",
graph_exec_id="test-exec-id",
node_exec_id="test-node-exec-id",
user_id="test-user-id",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_data
# Verify traditional mode behavior
assert (
"tool_functions" in outputs
) # Should yield tool_functions in traditional mode
assert (
"tools_^_test-sink-node-id_~_query" in outputs
) # Should yield individual tool parameters
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
assert "conversations" in outputs

View File

@@ -1,7 +1,7 @@
"""Comprehensive tests for SmartDecisionMakerBlock dynamic field handling.""" """Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""
import json import json
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest import pytest
@@ -308,10 +308,47 @@ async def test_output_yielding_with_dynamic_fields():
) as mock_llm: ) as mock_llm:
mock_llm.return_value = mock_response mock_llm.return_value = mock_response
# Mock the function signature creation # Mock the database manager to avoid HTTP calls during tool execution
with patch.object( with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db_manager, patch.object(
block, "_create_tool_node_signatures", new_callable=AsyncMock block, "_create_tool_node_signatures", new_callable=AsyncMock
) as mock_sig: ) as mock_sig:
# Set up the mock database manager
mock_db_client = AsyncMock()
mock_db_manager.return_value = mock_db_client
# Mock the node retrieval
mock_target_node = Mock()
mock_target_node.id = "test-sink-node-id"
mock_target_node.block_id = "CreateDictionaryBlock"
mock_target_node.block = Mock()
mock_target_node.block.name = "Create Dictionary"
mock_db_client.get_node.return_value = mock_target_node
# Mock the execution result creation
mock_node_exec_result = Mock()
mock_node_exec_result.node_exec_id = "mock-node-exec-id"
mock_final_input_data = {
"values_#_name": "Alice",
"values_#_age": 30,
"values_#_email": "alice@example.com",
}
mock_db_client.upsert_execution_input.return_value = (
mock_node_exec_result,
mock_final_input_data,
)
# Mock the output retrieval
mock_outputs = {
"values_#_name": "Alice",
"values_#_age": 30,
"values_#_email": "alice@example.com",
}
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = (
mock_outputs
)
mock_sig.return_value = [ mock_sig.return_value = [
{ {
"type": "function", "type": "function",
@@ -337,10 +374,16 @@ async def test_output_yielding_with_dynamic_fields():
prompt="Create a user dictionary", prompt="Create a user dictionary",
credentials=llm.TEST_CREDENTIALS_INPUT, credentials=llm.TEST_CREDENTIALS_INPUT,
model=llm.LlmModel.GPT4O, model=llm.LlmModel.GPT4O,
agent_mode_max_iterations=0, # Use traditional mode to test output yielding
) )
# Run the block # Run the block
outputs = {} outputs = {}
from backend.data.execution import ExecutionContext
mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()
async for output_name, output_value in block.run( async for output_name, output_value in block.run(
input_data, input_data,
credentials=llm.TEST_CREDENTIALS, credentials=llm.TEST_CREDENTIALS,
@@ -349,6 +392,9 @@ async def test_output_yielding_with_dynamic_fields():
graph_exec_id="test_exec", graph_exec_id="test_exec",
node_exec_id="test_node_exec", node_exec_id="test_node_exec",
user_id="test_user", user_id="test_user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
): ):
outputs[output_name] = output_value outputs[output_name] = output_value
@@ -511,45 +557,108 @@ async def test_validation_errors_dont_pollute_conversation():
} }
] ]
# Create input data # Mock the database manager to avoid HTTP calls during tool execution
from backend.blocks import llm with patch(
"backend.blocks.smart_decision_maker.get_database_manager_async_client"
) as mock_db_manager:
# Set up the mock database manager for agent mode
mock_db_client = AsyncMock()
mock_db_manager.return_value = mock_db_client
input_data = block.input_schema( # Mock the node retrieval
prompt="Test prompt", mock_target_node = Mock()
credentials=llm.TEST_CREDENTIALS_INPUT, mock_target_node.id = "test-sink-node-id"
model=llm.LlmModel.GPT4O, mock_target_node.block_id = "TestBlock"
retry=3, # Allow retries mock_target_node.block = Mock()
) mock_target_node.block.name = "Test Block"
mock_db_client.get_node.return_value = mock_target_node
# Run the block # Mock the execution result creation
outputs = {} mock_node_exec_result = Mock()
async for output_name, output_value in block.run( mock_node_exec_result.node_exec_id = "mock-node-exec-id"
input_data, mock_final_input_data = {"correct_param": "value"}
credentials=llm.TEST_CREDENTIALS, mock_db_client.upsert_execution_input.return_value = (
graph_id="test_graph", mock_node_exec_result,
node_id="test_node", mock_final_input_data,
graph_exec_id="test_exec", )
node_exec_id="test_node_exec",
user_id="test_user",
):
outputs[output_name] = output_value
# Verify we had 2 LLM calls (initial + retry) # Mock the output retrieval
assert call_count == 2 mock_outputs = {"correct_param": "value"}
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = (
mock_outputs
)
# Check the final conversation output # Create input data
final_conversation = outputs.get("conversations", []) from backend.blocks import llm
# The final conversation should NOT contain the validation error message input_data = block.input_schema(
error_messages = [ prompt="Test prompt",
msg credentials=llm.TEST_CREDENTIALS_INPUT,
for msg in final_conversation model=llm.LlmModel.GPT4O,
if msg.get("role") == "user" retry=3, # Allow retries
and "parameter errors" in msg.get("content", "") agent_mode_max_iterations=1,
] )
assert (
len(error_messages) == 0
), "Validation error leaked into final conversation"
# The final conversation should only have the successful response # Run the block
assert final_conversation[-1]["content"] == "valid" outputs = {}
from backend.data.execution import ExecutionContext
mock_execution_context = ExecutionContext(safe_mode=False)
# Create a proper mock execution processor for agent mode
from collections import defaultdict
mock_execution_processor = AsyncMock()
mock_execution_processor.execution_stats = MagicMock()
mock_execution_processor.execution_stats_lock = MagicMock()
# Create a mock NodeExecutionProgress for the sink node
mock_node_exec_progress = MagicMock()
mock_node_exec_progress.add_task = MagicMock()
mock_node_exec_progress.pop_output = MagicMock(
return_value=None
) # No outputs to process
# Set up running_node_execution as a defaultdict that returns our mock for any key
mock_execution_processor.running_node_execution = defaultdict(
lambda: mock_node_exec_progress
)
# Mock the on_node_execution method that gets called during tool execution
mock_node_stats = MagicMock()
mock_node_stats.error = None
mock_execution_processor.on_node_execution.return_value = (
mock_node_stats
)
async for output_name, output_value in block.run(
input_data,
credentials=llm.TEST_CREDENTIALS,
graph_id="test_graph",
node_id="test_node",
graph_exec_id="test_exec",
node_exec_id="test_node_exec",
user_id="test_user",
graph_version=1,
execution_context=mock_execution_context,
execution_processor=mock_execution_processor,
):
outputs[output_name] = output_value
# Verify we had at least 1 LLM call
assert call_count >= 1
# Check the final conversation output
final_conversation = outputs.get("conversations", [])
# The final conversation should NOT contain validation error messages
# Even if retries don't happen in agent mode, we should not leak errors
error_messages = [
msg
for msg in final_conversation
if msg.get("role") == "user"
and "parameter errors" in msg.get("content", "")
]
assert (
len(error_messages) == 0
), "Validation error leaked into final conversation"

View File

@@ -1,10 +1,10 @@
import logging import logging
import queue
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from enum import Enum from enum import Enum
from multiprocessing import Manager
from queue import Empty
from typing import ( from typing import (
TYPE_CHECKING,
Annotated, Annotated,
Any, Any,
AsyncGenerator, AsyncGenerator,
@@ -65,6 +65,9 @@ from .includes import (
) )
from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
if TYPE_CHECKING:
pass
T = TypeVar("T") T = TypeVar("T")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -836,6 +839,30 @@ async def upsert_execution_output(
await AgentNodeExecutionInputOutput.prisma().create(data=data) await AgentNodeExecutionInputOutput.prisma().create(data=data)
async def get_execution_outputs_by_node_exec_id(
node_exec_id: str,
) -> dict[str, Any]:
"""
Get all execution outputs for a specific node execution ID.
Args:
node_exec_id: The node execution ID to get outputs for
Returns:
Dictionary mapping output names to their data values
"""
outputs = await AgentNodeExecutionInputOutput.prisma().find_many(
where={"referencedByOutputExecId": node_exec_id}
)
result = {}
for output in outputs:
if output.data is not None:
result[output.name] = type_utils.convert(output.data, JsonValue)
return result
async def update_graph_execution_start_time( async def update_graph_execution_start_time(
graph_exec_id: str, graph_exec_id: str,
) -> GraphExecution | None: ) -> GraphExecution | None:
@@ -1136,12 +1163,16 @@ class NodeExecutionEntry(BaseModel):
class ExecutionQueue(Generic[T]): class ExecutionQueue(Generic[T]):
""" """
Queue for managing the execution of agents. Thread-safe queue for managing node execution within a single graph execution.
This will be shared between different processes
Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
threads within the same process. If migrating back to ProcessPoolExecutor,
replace with multiprocessing.Manager().Queue() for cross-process safety.
""" """
def __init__(self): def __init__(self):
self.queue = Manager().Queue() # Thread-safe queue (not multiprocessing) — see class docstring
self.queue: queue.Queue[T] = queue.Queue()
def add(self, execution: T) -> T: def add(self, execution: T) -> T:
self.queue.put(execution) self.queue.put(execution)
@@ -1156,7 +1187,7 @@ class ExecutionQueue(Generic[T]):
def get_or_none(self) -> T | None: def get_or_none(self) -> T | None:
try: try:
return self.queue.get_nowait() return self.queue.get_nowait()
except Empty: except queue.Empty:
return None return None

View File

@@ -0,0 +1,60 @@
"""Tests for ExecutionQueue thread-safety."""
import queue
import threading
import pytest
from backend.data.execution import ExecutionQueue
def test_execution_queue_uses_stdlib_queue():
"""Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
q = ExecutionQueue()
assert isinstance(q.queue, queue.Queue)
def test_basic_operations():
"""Test add, get, empty, and get_or_none."""
q = ExecutionQueue()
assert q.empty() is True
assert q.get_or_none() is None
result = q.add("item1")
assert result == "item1"
assert q.empty() is False
item = q.get()
assert item == "item1"
assert q.empty() is True
def test_thread_safety():
"""Test concurrent access from multiple threads."""
q = ExecutionQueue()
results = []
num_items = 100
def producer():
for i in range(num_items):
q.add(f"item_{i}")
def consumer():
count = 0
while count < num_items:
item = q.get_or_none()
if item is not None:
results.append(item)
count += 1
producer_thread = threading.Thread(target=producer)
consumer_thread = threading.Thread(target=consumer)
producer_thread.start()
consumer_thread.start()
producer_thread.join(timeout=5)
consumer_thread.join(timeout=5)
assert len(results) == num_items

View File

@@ -13,6 +13,7 @@ from backend.data.execution import (
get_block_error_stats, get_block_error_stats,
get_child_graph_executions, get_child_graph_executions,
get_execution_kv_data, get_execution_kv_data,
get_execution_outputs_by_node_exec_id,
get_frequently_executed_graphs, get_frequently_executed_graphs,
get_graph_execution_meta, get_graph_execution_meta,
get_graph_executions, get_graph_executions,
@@ -147,6 +148,7 @@ class DatabaseManager(AppService):
update_graph_execution_stats = _(update_graph_execution_stats) update_graph_execution_stats = _(update_graph_execution_stats)
upsert_execution_input = _(upsert_execution_input) upsert_execution_input = _(upsert_execution_input)
upsert_execution_output = _(upsert_execution_output) upsert_execution_output = _(upsert_execution_output)
get_execution_outputs_by_node_exec_id = _(get_execution_outputs_by_node_exec_id)
get_execution_kv_data = _(get_execution_kv_data) get_execution_kv_data = _(get_execution_kv_data)
set_execution_kv_data = _(set_execution_kv_data) set_execution_kv_data = _(set_execution_kv_data)
get_block_error_stats = _(get_block_error_stats) get_block_error_stats = _(get_block_error_stats)
@@ -277,6 +279,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
get_user_integrations = d.get_user_integrations get_user_integrations = d.get_user_integrations
upsert_execution_input = d.upsert_execution_input upsert_execution_input = d.upsert_execution_input
upsert_execution_output = d.upsert_execution_output upsert_execution_output = d.upsert_execution_output
get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
update_graph_execution_stats = d.update_graph_execution_stats update_graph_execution_stats = d.update_graph_execution_stats
update_node_execution_status = d.update_node_execution_status update_node_execution_status = d.update_node_execution_status
update_node_execution_status_batch = d.update_node_execution_status_batch update_node_execution_status_batch = d.update_node_execution_status_batch

View File

@@ -133,9 +133,8 @@ def execute_graph(
cluster_lock: ClusterLock, cluster_lock: ClusterLock,
): ):
"""Execute graph using thread-local ExecutionProcessor instance""" """Execute graph using thread-local ExecutionProcessor instance"""
return _tls.processor.on_graph_execution( processor: ExecutionProcessor = _tls.processor
graph_exec_entry, cancel_event, cluster_lock return processor.on_graph_execution(graph_exec_entry, cancel_event, cluster_lock)
)
T = TypeVar("T") T = TypeVar("T")
@@ -143,8 +142,8 @@ T = TypeVar("T")
async def execute_node( async def execute_node(
node: Node, node: Node,
creds_manager: IntegrationCredentialsManager,
data: NodeExecutionEntry, data: NodeExecutionEntry,
execution_processor: "ExecutionProcessor",
execution_stats: NodeExecutionStats | None = None, execution_stats: NodeExecutionStats | None = None,
nodes_input_masks: Optional[NodesInputMasks] = None, nodes_input_masks: Optional[NodesInputMasks] = None,
) -> BlockOutput: ) -> BlockOutput:
@@ -169,6 +168,7 @@ async def execute_node(
node_id = data.node_id node_id = data.node_id
node_block = node.block node_block = node.block
execution_context = data.execution_context execution_context = data.execution_context
creds_manager = execution_processor.creds_manager
log_metadata = LogMetadata( log_metadata = LogMetadata(
logger=_logger, logger=_logger,
@@ -212,6 +212,7 @@ async def execute_node(
"node_exec_id": node_exec_id, "node_exec_id": node_exec_id,
"user_id": user_id, "user_id": user_id,
"execution_context": execution_context, "execution_context": execution_context,
"execution_processor": execution_processor,
} }
# Last-minute fetch credentials + acquire a system-wide read-write lock to prevent # Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
@@ -608,8 +609,8 @@ class ExecutionProcessor:
async for output_name, output_data in execute_node( async for output_name, output_data in execute_node(
node=node, node=node,
creds_manager=self.creds_manager,
data=node_exec, data=node_exec,
execution_processor=self,
execution_stats=stats, execution_stats=stats,
nodes_input_masks=nodes_input_masks, nodes_input_masks=nodes_input_masks,
): ):
@@ -860,12 +861,17 @@ class ExecutionProcessor:
execution_stats_lock = threading.Lock() execution_stats_lock = threading.Lock()
# State holders ---------------------------------------------------- # State holders ----------------------------------------------------
running_node_execution: dict[str, NodeExecutionProgress] = defaultdict( self.running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
NodeExecutionProgress NodeExecutionProgress
) )
running_node_evaluation: dict[str, Future] = {} self.running_node_evaluation: dict[str, Future] = {}
self.execution_stats = execution_stats
self.execution_stats_lock = execution_stats_lock
execution_queue = ExecutionQueue[NodeExecutionEntry]() execution_queue = ExecutionQueue[NodeExecutionEntry]()
running_node_execution = self.running_node_execution
running_node_evaluation = self.running_node_evaluation
try: try:
if db_client.get_credits(graph_exec.user_id) <= 0: if db_client.get_credits(graph_exec.user_id) <= 0:
raise InsufficientBalanceError( raise InsufficientBalanceError(

View File

@@ -5,6 +5,13 @@ from tiktoken import encoding_for_model
from backend.util import json from backend.util import json
# ---------------------------------------------------------------------------#
# CONSTANTS #
# ---------------------------------------------------------------------------#
# Message prefixes for important system messages that should be protected during compression
MAIN_OBJECTIVE_PREFIX = "[Main Objective Prompt]: "
# ---------------------------------------------------------------------------# # ---------------------------------------------------------------------------#
# INTERNAL UTILITIES # # INTERNAL UTILITIES #
# ---------------------------------------------------------------------------# # ---------------------------------------------------------------------------#
@@ -63,6 +70,55 @@ def _msg_tokens(msg: dict, enc) -> int:
return WRAPPER + content_tokens + tool_call_tokens return WRAPPER + content_tokens + tool_call_tokens
def _is_tool_message(msg: dict) -> bool:
"""Check if a message contains tool calls or results that should be protected."""
content = msg.get("content")
# Check for Anthropic-style tool messages
if isinstance(content, list) and any(
isinstance(item, dict) and item.get("type") in ("tool_use", "tool_result")
for item in content
):
return True
# Check for OpenAI-style tool calls in the message
if "tool_calls" in msg or msg.get("role") == "tool":
return True
return False
def _is_objective_message(msg: dict) -> bool:
"""Check if a message contains objective/system prompts that should be absolutely protected."""
content = msg.get("content", "")
if isinstance(content, str):
# Protect any message with the main objective prefix
return content.startswith(MAIN_OBJECTIVE_PREFIX)
return False
def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
"""
Carefully truncate tool message content while preserving tool structure.
Only truncates tool_result content, leaves tool_use intact.
"""
content = msg.get("content")
if not isinstance(content, list):
return
for item in content:
# Only process tool_result items, leave tool_use blocks completely intact
if not (isinstance(item, dict) and item.get("type") == "tool_result"):
continue
result_content = item.get("content", "")
if (
isinstance(result_content, str)
and _tok_len(result_content, enc) > max_tokens
):
item["content"] = _truncate_middle_tokens(result_content, enc, max_tokens)
def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str: def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
""" """
Return *text* shortened to ≈max_tok tokens by keeping the head & tail Return *text* shortened to ≈max_tok tokens by keeping the head & tail
@@ -140,13 +196,21 @@ def compress_prompt(
return sum(_msg_tokens(m, enc) for m in msgs) return sum(_msg_tokens(m, enc) for m in msgs)
original_token_count = total_tokens() original_token_count = total_tokens()
if original_token_count + reserve <= target_tokens: if original_token_count + reserve <= target_tokens:
return msgs return msgs
# ---- STEP 0 : normalise content -------------------------------------- # ---- STEP 0 : normalise content --------------------------------------
# Convert non-string payloads to strings so token counting is coherent. # Convert non-string payloads to strings so token counting is coherent.
for m in msgs[1:-1]: # keep the first & last intact for i, m in enumerate(msgs):
if not isinstance(m.get("content"), str) and m.get("content") is not None: if not isinstance(m.get("content"), str) and m.get("content") is not None:
if _is_tool_message(m):
continue
# Keep first and last messages intact (unless they're tool messages)
if i == 0 or i == len(msgs) - 1:
continue
# Reasonable 20k-char ceiling prevents pathological blobs # Reasonable 20k-char ceiling prevents pathological blobs
content_str = json.dumps(m["content"], separators=(",", ":")) content_str = json.dumps(m["content"], separators=(",", ":"))
if len(content_str) > 20_000: if len(content_str) > 20_000:
@@ -157,34 +221,45 @@ def compress_prompt(
cap = start_cap cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap: while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for m in msgs[1:-1]: # keep first & last intact for m in msgs[1:-1]: # keep first & last intact
if _tok_len(m.get("content") or "", enc) > cap: if _is_tool_message(m):
m["content"] = _truncate_middle_tokens(m["content"], enc, cap) # For tool messages, only truncate tool result content, preserve structure
_truncate_tool_message_content(m, enc, cap)
continue
if _is_objective_message(m):
# Never truncate objective messages - they contain the core task
continue
content = m.get("content") or ""
if _tok_len(content, enc) > cap:
m["content"] = _truncate_middle_tokens(content, enc, cap)
cap //= 2 # tighten the screw cap //= 2 # tighten the screw
# ---- STEP 2 : middle-out deletion ----------------------------------- # ---- STEP 2 : middle-out deletion -----------------------------------
while total_tokens() + reserve > target_tokens and len(msgs) > 2: while total_tokens() + reserve > target_tokens and len(msgs) > 2:
# Identify all deletable messages (not first/last, not tool messages, not objective messages)
deletable_indices = []
for i in range(1, len(msgs) - 1): # Skip first and last
if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]):
deletable_indices.append(i)
if not deletable_indices:
break # nothing more we can drop
# Delete from center outward - find the index closest to center
centre = len(msgs) // 2 centre = len(msgs) // 2
# Build a symmetrical centre-out index walk: centre, centre+1, centre-1, ... to_delete = min(deletable_indices, key=lambda i: abs(i - centre))
order = [centre] + [ del msgs[to_delete]
i
for pair in zip(range(centre + 1, len(msgs) - 1), range(centre - 1, 0, -1))
for i in pair
]
removed = False
for i in order:
msg = msgs[i]
if "tool_calls" in msg or msg.get("role") == "tool":
continue # protect tool shells
del msgs[i]
removed = True
break
if not removed: # nothing more we can drop
break
# ---- STEP 3 : final safety-net trim on first & last ------------------ # ---- STEP 3 : final safety-net trim on first & last ------------------
cap = start_cap cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap: while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for idx in (0, -1): # first and last for idx in (0, -1): # first and last
if _is_tool_message(msgs[idx]):
# For tool messages at first/last position, truncate tool result content only
_truncate_tool_message_content(msgs[idx], enc, cap)
continue
text = msgs[idx].get("content") or "" text = msgs[idx].get("content") or ""
if _tok_len(text, enc) > cap: if _tok_len(text, enc) > cap:
msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap) msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)

View File

@@ -3,7 +3,6 @@ import { scrollbarStyles } from "@/components/styles/scrollbars";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
import { X } from "@phosphor-icons/react"; import { X } from "@phosphor-icons/react";
import * as RXDialog from "@radix-ui/react-dialog"; import * as RXDialog from "@radix-ui/react-dialog";
import { debounce } from "lodash";
import { import {
CSSProperties, CSSProperties,
PropsWithChildren, PropsWithChildren,
@@ -71,21 +70,13 @@ export function DialogWrap({
if (!el) return; if (!el) return;
setHasVerticalScrollbar(el.scrollHeight > el.clientHeight + 1); setHasVerticalScrollbar(el.scrollHeight > el.clientHeight + 1);
} }
// Debounce the update function to prevent rapid successive state updates
const debouncedUpdate = debounce(update, 100);
// Initial update without debounce for immediate UI feedback
update(); update();
const ro = new ResizeObserver(update);
const ro = new ResizeObserver(debouncedUpdate);
if (scrollRef.current) ro.observe(scrollRef.current); if (scrollRef.current) ro.observe(scrollRef.current);
window.addEventListener("resize", debouncedUpdate); window.addEventListener("resize", update);
return () => { return () => {
debouncedUpdate.cancel();
ro.disconnect(); ro.disconnect();
window.removeEventListener("resize", debouncedUpdate); window.removeEventListener("resize", update);
}; };
}, []); }, []);

View File

@@ -48,7 +48,7 @@ const mockFlags = {
[Flag.AGENT_FAVORITING]: false, [Flag.AGENT_FAVORITING]: false,
[Flag.MARKETPLACE_SEARCH_TERMS]: DEFAULT_SEARCH_TERMS, [Flag.MARKETPLACE_SEARCH_TERMS]: DEFAULT_SEARCH_TERMS,
[Flag.ENABLE_PLATFORM_PAYMENT]: false, [Flag.ENABLE_PLATFORM_PAYMENT]: false,
[Flag.CHAT]: false, [Flag.CHAT]: true,
}; };
export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] | null { export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] | null {