Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-08 13:55:06 -05:00)

Compare commits (6 commits): seer/perf/ ... fix/execut
| Author | SHA1 | Date |
|---|---|---|
|  | 9b20f4cd13 |  |
|  | a3d0f9cbd2 |  |
|  | 02ddb51446 |  |
|  | 750e096f15 |  |
|  | ff5c8f324b |  |
|  | 71157bddd7 |  |
@@ -1,8 +1,11 @@
 import logging
 import re
 from collections import Counter
+from concurrent.futures import Future
 from typing import TYPE_CHECKING, Any
 
+from pydantic import BaseModel
+
 import backend.blocks.llm as llm
 from backend.blocks.agent import AgentExecutorBlock
 from backend.data.block import (
@@ -20,16 +23,41 @@ from backend.data.dynamic_fields import (
     is_dynamic_field,
     is_tool_pin,
 )
+from backend.data.execution import ExecutionContext
 from backend.data.model import NodeExecutionStats, SchemaField
 from backend.util import json
 from backend.util.clients import get_database_manager_async_client
+from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
 
 if TYPE_CHECKING:
     from backend.data.graph import Link, Node
+    from backend.executor.manager import ExecutionProcessor
 
 logger = logging.getLogger(__name__)
 
 
+class ToolInfo(BaseModel):
+    """Processed tool call information."""
+
+    tool_call: Any  # The original tool call object from LLM response
+    tool_name: str  # The function name
+    tool_def: dict[str, Any]  # The tool definition from tool_functions
+    input_data: dict[str, Any]  # Processed input data ready for tool execution
+    field_mapping: dict[str, str]  # Field name mapping for the tool
+
+
+class ExecutionParams(BaseModel):
+    """Tool execution parameters."""
+
+    user_id: str
+    graph_id: str
+    node_id: str
+    graph_version: int
+    graph_exec_id: str
+    node_exec_id: str
+    execution_context: "ExecutionContext"
+
+
 def _get_tool_requests(entry: dict[str, Any]) -> list[str]:
     """
     Return a list of tool_call_ids if the entry is a tool request.
@@ -105,6 +133,50 @@ def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
     return {"role": "tool", "tool_call_id": call_id, "content": content}
 
 
+def _combine_tool_responses(tool_outputs: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    """
+    Combine multiple Anthropic tool responses into a single user message.
+    For non-Anthropic formats, returns the original list unchanged.
+    """
+    if len(tool_outputs) <= 1:
+        return tool_outputs
+
+    # Anthropic responses have role="user", type="message", and content is a list with tool_result items
+    anthropic_responses = [
+        output
+        for output in tool_outputs
+        if (
+            output.get("role") == "user"
+            and output.get("type") == "message"
+            and isinstance(output.get("content"), list)
+            and any(
+                item.get("type") == "tool_result"
+                for item in output.get("content", [])
+                if isinstance(item, dict)
+            )
+        )
+    ]
+
+    if len(anthropic_responses) > 1:
+        combined_content = [
+            item for response in anthropic_responses for item in response["content"]
+        ]
+
+        combined_response = {
+            "role": "user",
+            "type": "message",
+            "content": combined_content,
+        }
+
+        non_anthropic_responses = [
+            output for output in tool_outputs if output not in anthropic_responses
+        ]
+
+        return [combined_response] + non_anthropic_responses
+
+    return tool_outputs
+
+
 def _convert_raw_response_to_dict(raw_response: Any) -> dict[str, Any]:
     """
     Safely convert raw_response to dictionary format for conversation history.
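The helper added above only merges messages that look like Anthropic tool results; everything else passes through untouched. A minimal usage sketch (not part of the diff; it assumes the backend package is importable, e.g. when run from the repo's backend directory):

```python
# Sketch: two Anthropic-style tool_result messages are merged into one user turn.
from backend.blocks.smart_decision_maker import _combine_tool_responses

tool_outputs = [
    {
        "role": "user",
        "type": "message",
        "content": [{"type": "tool_result", "tool_use_id": "call_1", "content": "42"}],
    },
    {
        "role": "user",
        "type": "message",
        "content": [{"type": "tool_result", "tool_use_id": "call_2", "content": "ok"}],
    },
]

combined = _combine_tool_responses(tool_outputs)
# Both tool_result items end up in a single user message, so the next LLM call
# sees one tool-result turn instead of two.
assert len(combined) == 1
assert len(combined[0]["content"]) == 2
```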
@@ -204,6 +276,17 @@ class SmartDecisionMakerBlock(Block):
             default="localhost:11434",
             description="Ollama host for local models",
         )
+        agent_mode_max_iterations: int = SchemaField(
+            title="Agent Mode Max Iterations",
+            description="Maximum iterations for agent mode. 0 = traditional mode (single LLM call, yield tool calls for external execution), -1 = infinite agent mode (loop until finished), 1+ = agent mode with max iterations limit.",
+            advanced=True,
+            default=0,
+        )
+        conversation_compaction: bool = SchemaField(
+            default=True,
+            title="Context window auto-compaction",
+            description="Automatically compact the context window once it hits the limit",
+        )
 
     @classmethod
     def get_missing_links(cls, data: BlockInput, links: list["Link"]) -> set[str]:
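The three value ranges described for `agent_mode_max_iterations` map directly onto the loop condition used later in `_execute_tools_agent_mode` (`while max_iterations < 0 or iteration < max_iterations`). A self-contained sketch of that dispatch, with a hypothetical `planned_iterations` helper written only for illustration:

```python
# Illustrative only (not in the diff): how the setting maps to looping behaviour.
def planned_iterations(agent_mode_max_iterations: int, pending_work: int) -> int:
    """Return how many LLM iterations would run for `pending_work` tool rounds."""
    if agent_mode_max_iterations == 0:
        return 1  # traditional mode: a single LLM call, tool calls are yielded
    iteration = 0
    while agent_mode_max_iterations < 0 or iteration < agent_mode_max_iterations:
        iteration += 1
        if iteration > pending_work:  # model stops calling tools -> finished
            break
    return iteration

assert planned_iterations(0, pending_work=5) == 1   # traditional mode
assert planned_iterations(3, pending_work=5) == 3   # capped agent mode
assert planned_iterations(-1, pending_work=2) == 3  # "infinite" mode stops when done
```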
@@ -506,6 +589,7 @@ class SmartDecisionMakerBlock(Block):
         Returns the response if successful, raises ValueError if validation fails.
         """
         resp = await llm.llm_call(
+            compress_prompt_to_fit=input_data.conversation_compaction,
             credentials=credentials,
             llm_model=input_data.model,
             prompt=current_prompt,
@@ -593,6 +677,291 @@ class SmartDecisionMakerBlock(Block):
 
         return resp
 
+    def _process_tool_calls(
+        self, response, tool_functions: list[dict[str, Any]]
+    ) -> list[ToolInfo]:
+        """Process tool calls and extract tool definitions, arguments, and input data.
+
+        Returns a list of tool info dicts with:
+        - tool_call: The original tool call object
+        - tool_name: The function name
+        - tool_def: The tool definition from tool_functions
+        - input_data: Processed input data dict (includes None values)
+        - field_mapping: Field name mapping for the tool
+        """
+        if not response.tool_calls:
+            return []
+
+        processed_tools = []
+        for tool_call in response.tool_calls:
+            tool_name = tool_call.function.name
+            tool_args = json.loads(tool_call.function.arguments)
+
+            tool_def = next(
+                (
+                    tool
+                    for tool in tool_functions
+                    if tool["function"]["name"] == tool_name
+                ),
+                None,
+            )
+            if not tool_def:
+                if len(tool_functions) == 1:
+                    tool_def = tool_functions[0]
+                else:
+                    continue
+
+            # Build input data for the tool
+            input_data = {}
+            field_mapping = tool_def["function"].get("_field_mapping", {})
+            if "function" in tool_def and "parameters" in tool_def["function"]:
+                expected_args = tool_def["function"]["parameters"].get("properties", {})
+                for clean_arg_name in expected_args:
+                    original_field_name = field_mapping.get(
+                        clean_arg_name, clean_arg_name
+                    )
+                    arg_value = tool_args.get(clean_arg_name)
+                    # Include all expected parameters, even if None (for backward compatibility with tests)
+                    input_data[original_field_name] = arg_value
+
+            processed_tools.append(
+                ToolInfo(
+                    tool_call=tool_call,
+                    tool_name=tool_name,
+                    tool_def=tool_def,
+                    input_data=input_data,
+                    field_mapping=field_mapping,
+                )
+            )
+
+        return processed_tools
+
+    def _update_conversation(
+        self, prompt: list[dict], response, tool_outputs: list | None = None
+    ):
+        """Update conversation history with response and tool outputs."""
+        # Don't add separate reasoning message with tool calls (breaks Anthropic's tool_use->tool_result pairing)
+        assistant_message = _convert_raw_response_to_dict(response.raw_response)
+        has_tool_calls = isinstance(assistant_message.get("content"), list) and any(
+            item.get("type") == "tool_use"
+            for item in assistant_message.get("content", [])
+        )
+
+        if response.reasoning and not has_tool_calls:
+            prompt.append(
+                {"role": "assistant", "content": f"[Reasoning]: {response.reasoning}"}
+            )
+
+        prompt.append(assistant_message)
+
+        if tool_outputs:
+            prompt.extend(tool_outputs)
+
+    async def _execute_single_tool_with_manager(
+        self,
+        tool_info: ToolInfo,
+        execution_params: ExecutionParams,
+        execution_processor: "ExecutionProcessor",
+    ) -> dict:
+        """Execute a single tool using the execution manager for proper integration."""
+        # Lazy imports to avoid circular dependencies
+        from backend.data.execution import NodeExecutionEntry
+
+        tool_call = tool_info.tool_call
+        tool_def = tool_info.tool_def
+        raw_input_data = tool_info.input_data
+
+        # Get sink node and field mapping
+        sink_node_id = tool_def["function"]["_sink_node_id"]
+
+        # Use proper database operations for tool execution
+        db_client = get_database_manager_async_client()
+
+        # Get target node
+        target_node = await db_client.get_node(sink_node_id)
+        if not target_node:
+            raise ValueError(f"Target node {sink_node_id} not found")
+
+        # Create proper node execution using upsert_execution_input
+        node_exec_result = None
+        final_input_data = None
+
+        # Add all inputs to the execution
+        if not raw_input_data:
+            raise ValueError(f"Tool call has no input data: {tool_call}")
+
+        for input_name, input_value in raw_input_data.items():
+            node_exec_result, final_input_data = await db_client.upsert_execution_input(
+                node_id=sink_node_id,
+                graph_exec_id=execution_params.graph_exec_id,
+                input_name=input_name,
+                input_data=input_value,
+            )
+
+        assert node_exec_result is not None, "node_exec_result should not be None"
+
+        # Create NodeExecutionEntry for execution manager
+        node_exec_entry = NodeExecutionEntry(
+            user_id=execution_params.user_id,
+            graph_exec_id=execution_params.graph_exec_id,
+            graph_id=execution_params.graph_id,
+            graph_version=execution_params.graph_version,
+            node_exec_id=node_exec_result.node_exec_id,
+            node_id=sink_node_id,
+            block_id=target_node.block_id,
+            inputs=final_input_data or {},
+            execution_context=execution_params.execution_context,
+        )
+
+        # Use the execution manager to execute the tool node
+        try:
+            # Get NodeExecutionProgress from the execution manager's running nodes
+            node_exec_progress = execution_processor.running_node_execution[
+                sink_node_id
+            ]
+
+            # Use the execution manager's own graph stats
+            graph_stats_pair = (
+                execution_processor.execution_stats,
+                execution_processor.execution_stats_lock,
+            )
+
+            # Create a completed future for the task tracking system
+            node_exec_future = Future()
+            node_exec_progress.add_task(
+                node_exec_id=node_exec_result.node_exec_id,
+                task=node_exec_future,
+            )
+
+            # Execute the node directly since we're in the SmartDecisionMaker context
+            node_exec_future.set_result(
+                await execution_processor.on_node_execution(
+                    node_exec=node_exec_entry,
+                    node_exec_progress=node_exec_progress,
+                    nodes_input_masks=None,
+                    graph_stats_pair=graph_stats_pair,
+                )
+            )
+
+            # Get outputs from database after execution completes using database manager client
+            node_outputs = await db_client.get_execution_outputs_by_node_exec_id(
+                node_exec_result.node_exec_id
+            )
+
+            # Create tool response
+            tool_response_content = (
+                json.dumps(node_outputs)
+                if node_outputs
+                else "Tool executed successfully"
+            )
+            return _create_tool_response(tool_call.id, tool_response_content)
+
+        except Exception as e:
+            logger.error(f"Tool execution with manager failed: {e}")
+            # Return error response
+            return _create_tool_response(
+                tool_call.id, f"Tool execution failed: {str(e)}"
+            )
+
+    async def _execute_tools_agent_mode(
+        self,
+        input_data,
+        credentials,
+        tool_functions: list[dict[str, Any]],
+        prompt: list[dict],
+        graph_exec_id: str,
+        node_id: str,
+        node_exec_id: str,
+        user_id: str,
+        graph_id: str,
+        graph_version: int,
+        execution_context: ExecutionContext,
+        execution_processor: "ExecutionProcessor",
+    ):
+        """Execute tools in agent mode with a loop until finished."""
+        max_iterations = input_data.agent_mode_max_iterations
+        iteration = 0
+
+        # Execution parameters for tool execution
+        execution_params = ExecutionParams(
+            user_id=user_id,
+            graph_id=graph_id,
+            node_id=node_id,
+            graph_version=graph_version,
+            graph_exec_id=graph_exec_id,
+            node_exec_id=node_exec_id,
+            execution_context=execution_context,
+        )
+
+        current_prompt = list(prompt)
+
+        while max_iterations < 0 or iteration < max_iterations:
+            iteration += 1
+            logger.debug(f"Agent mode iteration {iteration}")
+
+            # Prepare prompt for this iteration
+            iteration_prompt = list(current_prompt)
+
+            # On the last iteration, add a special system message to encourage completion
+            if max_iterations > 0 and iteration == max_iterations:
+                last_iteration_message = {
+                    "role": "system",
+                    "content": f"{MAIN_OBJECTIVE_PREFIX}This is your last iteration ({iteration}/{max_iterations}). "
+                    "Try to complete the task with the information you have. If you cannot fully complete it, "
+                    "provide a summary of what you've accomplished and what remains to be done. "
+                    "Prefer finishing with a clear response rather than making additional tool calls.",
+                }
+                iteration_prompt.append(last_iteration_message)
+
+            # Get LLM response
+            try:
+                response = await self._attempt_llm_call_with_validation(
+                    credentials, input_data, iteration_prompt, tool_functions
+                )
+            except Exception as e:
+                yield "error", f"LLM call failed in agent mode iteration {iteration}: {str(e)}"
+                return
+
+            # Process tool calls
+            processed_tools = self._process_tool_calls(response, tool_functions)
+
+            # If no tool calls, we're done
+            if not processed_tools:
+                yield "finished", response.response
+                self._update_conversation(current_prompt, response)
+                yield "conversations", current_prompt
+                return
+
+            # Execute tools and collect responses
+            tool_outputs = []
+            for tool_info in processed_tools:
+                try:
+                    tool_response = await self._execute_single_tool_with_manager(
+                        tool_info, execution_params, execution_processor
+                    )
+                    tool_outputs.append(tool_response)
+                except Exception as e:
+                    logger.error(f"Tool execution failed: {e}")
+                    # Create error response for the tool
+                    error_response = _create_tool_response(
+                        tool_info.tool_call.id, f"Error: {str(e)}"
+                    )
+                    tool_outputs.append(error_response)
+
+            tool_outputs = _combine_tool_responses(tool_outputs)
+
+            self._update_conversation(current_prompt, response, tool_outputs)
+
+            # Yield intermediate conversation state
+            yield "conversations", current_prompt
+
+        # If we reach max iterations, yield the current state
+        if max_iterations < 0:
+            yield "finished", f"Agent mode completed after {iteration} iterations"
+        else:
+            yield "finished", f"Agent mode completed after {max_iterations} iterations (limit reached)"
+        yield "conversations", current_prompt
+
     async def run(
         self,
         input_data: Input,
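A small, self-contained illustration of the field-mapping step inside `_process_tool_calls` above: clean argument names from the LLM are mapped back to the original (possibly dynamic `values_#_...`) field names, and every expected parameter is emitted even when the model omitted it. The concrete tool definition below is made up for the example:

```python
# Illustrative data only; mirrors the mapping loop in _process_tool_calls.
tool_def = {
    "function": {
        "name": "create_dictionary",
        "_field_mapping": {"name": "values_#_name", "age": "values_#_age"},
        "parameters": {"properties": {"name": {}, "age": {}, "email": {}}},
    }
}
tool_args = {"name": "Alice", "age": 30}  # "email" omitted by the model

field_mapping = tool_def["function"].get("_field_mapping", {})
input_data = {}
for clean_arg_name in tool_def["function"]["parameters"]["properties"]:
    original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
    # Missing arguments still produce a key, with value None.
    input_data[original_field_name] = tool_args.get(clean_arg_name)

assert input_data == {"values_#_name": "Alice", "values_#_age": 30, "email": None}
```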
@@ -603,8 +972,12 @@ class SmartDecisionMakerBlock(Block):
         graph_exec_id: str,
         node_exec_id: str,
         user_id: str,
+        graph_version: int,
+        execution_context: ExecutionContext,
+        execution_processor: "ExecutionProcessor",
         **kwargs,
     ) -> BlockOutput:
+
         tool_functions = await self._create_tool_node_signatures(node_id)
         yield "tool_functions", json.dumps(tool_functions)
 
@@ -648,24 +1021,52 @@ class SmartDecisionMakerBlock(Block):
         input_data.prompt = llm.fmt.format_string(input_data.prompt, values)
         input_data.sys_prompt = llm.fmt.format_string(input_data.sys_prompt, values)
 
-        prefix = "[Main Objective Prompt]: "
         if input_data.sys_prompt and not any(
-            p["role"] == "system" and p["content"].startswith(prefix) for p in prompt
+            p["role"] == "system" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
+            for p in prompt
         ):
-            prompt.append({"role": "system", "content": prefix + input_data.sys_prompt})
+            prompt.append(
+                {
+                    "role": "system",
+                    "content": MAIN_OBJECTIVE_PREFIX + input_data.sys_prompt,
+                }
+            )
 
         if input_data.prompt and not any(
-            p["role"] == "user" and p["content"].startswith(prefix) for p in prompt
+            p["role"] == "user" and p["content"].startswith(MAIN_OBJECTIVE_PREFIX)
+            for p in prompt
        ):
-            prompt.append({"role": "user", "content": prefix + input_data.prompt})
+            prompt.append(
+                {"role": "user", "content": MAIN_OBJECTIVE_PREFIX + input_data.prompt}
+            )
+
+        # Execute tools based on the selected mode
+        if input_data.agent_mode_max_iterations != 0:
+            # In agent mode, execute tools directly in a loop until finished
+            async for result in self._execute_tools_agent_mode(
+                input_data=input_data,
+                credentials=credentials,
+                tool_functions=tool_functions,
+                prompt=prompt,
+                graph_exec_id=graph_exec_id,
+                node_id=node_id,
+                node_exec_id=node_exec_id,
+                user_id=user_id,
+                graph_id=graph_id,
+                graph_version=graph_version,
+                execution_context=execution_context,
+                execution_processor=execution_processor,
+            ):
+                yield result
+            return
+
+        # One-off mode: single LLM call and yield tool calls for external execution
         current_prompt = list(prompt)
         max_attempts = max(1, int(input_data.retry))
         response = None
 
         last_error = None
-        for attempt in range(max_attempts):
+        for _ in range(max_attempts):
             try:
                 response = await self._attempt_llm_call_with_validation(
                     credentials, input_data, current_prompt, tool_functions
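The remaining hunks update the block's test modules to match the new `run()` signature: every `block.run(...)` call now also passes `graph_version`, an `ExecutionContext`, and an execution-processor mock. A minimal sketch of that repeated call-site change (the mock names are illustrative):

```python
# Sketch of the extra keyword arguments threaded through the tests below.
from unittest.mock import MagicMock

from backend.data.execution import ExecutionContext

mock_execution_context = ExecutionContext(safe_mode=False)
mock_execution_processor = MagicMock()

extra_run_kwargs = dict(
    graph_version=1,
    execution_context=mock_execution_context,
    execution_processor=mock_execution_processor,
)
```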
@@ -1,7 +1,11 @@
 import logging
+import threading
+from collections import defaultdict
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 
+from backend.data.execution import ExecutionContext
 from backend.data.model import ProviderName, User
 from backend.server.model import CreateGraph
 from backend.server.rest_api import AgentServer
@@ -17,10 +21,10 @@ async def create_graph(s: SpinTestServer, g, u: User):
 
 
 async def create_credentials(s: SpinTestServer, u: User):
-    import backend.blocks.llm as llm
+    import backend.blocks.llm as llm_module
 
     provider = ProviderName.OPENAI
-    credentials = llm.TEST_CREDENTIALS
+    credentials = llm_module.TEST_CREDENTIALS
     return await s.agent_server.test_create_credentials(u.id, provider, credentials)
 
 
@@ -196,8 +200,6 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
 @pytest.mark.asyncio
 async def test_smart_decision_maker_tracks_llm_stats():
     """Test that SmartDecisionMakerBlock correctly tracks LLM usage stats."""
-    from unittest.mock import MagicMock, patch
-
     import backend.blocks.llm as llm_module
     from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
 
@@ -216,7 +218,6 @@ async def test_smart_decision_maker_tracks_llm_stats():
     }
 
     # Mock the _create_tool_node_signatures method to avoid database calls
-    from unittest.mock import AsyncMock
 
     with patch(
         "backend.blocks.llm.llm_call",
@@ -234,10 +235,19 @@ async def test_smart_decision_maker_tracks_llm_stats():
             prompt="Should I continue with this task?",
             model=llm_module.LlmModel.GPT4O,
             credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+            agent_mode_max_iterations=0,
         )
 
         # Execute the block
         outputs = {}
+        # Create execution context
+
+        mock_execution_context = ExecutionContext(safe_mode=False)
+
+        # Create a mock execution processor for tests
+
+        mock_execution_processor = MagicMock()
+
         async for output_name, output_data in block.run(
             input_data,
             credentials=llm_module.TEST_CREDENTIALS,
@@ -246,6 +256,9 @@ async def test_smart_decision_maker_tracks_llm_stats():
             graph_exec_id="test-exec-id",
             node_exec_id="test-node-exec-id",
             user_id="test-user-id",
+            graph_version=1,
+            execution_context=mock_execution_context,
+            execution_processor=mock_execution_processor,
         ):
             outputs[output_name] = output_data
 
@@ -263,8 +276,6 @@ async def test_smart_decision_maker_tracks_llm_stats():
 @pytest.mark.asyncio
 async def test_smart_decision_maker_parameter_validation():
     """Test that SmartDecisionMakerBlock correctly validates tool call parameters."""
-    from unittest.mock import MagicMock, patch
-
     import backend.blocks.llm as llm_module
     from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
 
@@ -311,8 +322,6 @@ async def test_smart_decision_maker_parameter_validation():
     mock_response_with_typo.reasoning = None
     mock_response_with_typo.raw_response = {"role": "assistant", "content": None}
 
-    from unittest.mock import AsyncMock
-
     with patch(
         "backend.blocks.llm.llm_call",
         new_callable=AsyncMock,
@@ -329,8 +338,17 @@ async def test_smart_decision_maker_parameter_validation():
             model=llm_module.LlmModel.GPT4O,
             credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
             retry=2,  # Set retry to 2 for testing
+            agent_mode_max_iterations=0,
         )
 
+        # Create execution context
+
+        mock_execution_context = ExecutionContext(safe_mode=False)
+
+        # Create a mock execution processor for tests
+
+        mock_execution_processor = MagicMock()
+
         # Should raise ValueError after retries due to typo'd parameter name
         with pytest.raises(ValueError) as exc_info:
             outputs = {}
@@ -342,6 +360,9 @@ async def test_smart_decision_maker_parameter_validation():
                 graph_exec_id="test-exec-id",
                 node_exec_id="test-node-exec-id",
                 user_id="test-user-id",
+                graph_version=1,
+                execution_context=mock_execution_context,
+                execution_processor=mock_execution_processor,
             ):
                 outputs[output_name] = output_data
 
@@ -368,8 +389,6 @@ async def test_smart_decision_maker_parameter_validation():
     mock_response_missing_required.reasoning = None
     mock_response_missing_required.raw_response = {"role": "assistant", "content": None}
 
-    from unittest.mock import AsyncMock
-
     with patch(
         "backend.blocks.llm.llm_call",
         new_callable=AsyncMock,
@@ -385,8 +404,17 @@ async def test_smart_decision_maker_parameter_validation():
             prompt="Search for keywords",
             model=llm_module.LlmModel.GPT4O,
             credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+            agent_mode_max_iterations=0,
         )
 
+        # Create execution context
+
+        mock_execution_context = ExecutionContext(safe_mode=False)
+
+        # Create a mock execution processor for tests
+
+        mock_execution_processor = MagicMock()
+
         # Should raise ValueError due to missing required parameter
         with pytest.raises(ValueError) as exc_info:
             outputs = {}
@@ -398,6 +426,9 @@ async def test_smart_decision_maker_parameter_validation():
                 graph_exec_id="test-exec-id",
                 node_exec_id="test-node-exec-id",
                 user_id="test-user-id",
+                graph_version=1,
+                execution_context=mock_execution_context,
+                execution_processor=mock_execution_processor,
             ):
                 outputs[output_name] = output_data
 
@@ -418,8 +449,6 @@ async def test_smart_decision_maker_parameter_validation():
     mock_response_valid.reasoning = None
     mock_response_valid.raw_response = {"role": "assistant", "content": None}
 
-    from unittest.mock import AsyncMock
-
     with patch(
         "backend.blocks.llm.llm_call",
         new_callable=AsyncMock,
@@ -435,10 +464,19 @@ async def test_smart_decision_maker_parameter_validation():
             prompt="Search for keywords",
             model=llm_module.LlmModel.GPT4O,
             credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+            agent_mode_max_iterations=0,
         )
 
         # Should succeed - optional parameter missing is OK
         outputs = {}
+        # Create execution context
+
+        mock_execution_context = ExecutionContext(safe_mode=False)
+
+        # Create a mock execution processor for tests
+
+        mock_execution_processor = MagicMock()
+
         async for output_name, output_data in block.run(
             input_data,
             credentials=llm_module.TEST_CREDENTIALS,
@@ -447,6 +485,9 @@ async def test_smart_decision_maker_parameter_validation():
             graph_exec_id="test-exec-id",
             node_exec_id="test-node-exec-id",
             user_id="test-user-id",
+            graph_version=1,
+            execution_context=mock_execution_context,
+            execution_processor=mock_execution_processor,
         ):
             outputs[output_name] = output_data
 
@@ -472,8 +513,6 @@ async def test_smart_decision_maker_parameter_validation():
     mock_response_all_params.reasoning = None
     mock_response_all_params.raw_response = {"role": "assistant", "content": None}
 
-    from unittest.mock import AsyncMock
-
     with patch(
         "backend.blocks.llm.llm_call",
         new_callable=AsyncMock,
@@ -489,10 +528,19 @@ async def test_smart_decision_maker_parameter_validation():
             prompt="Search for keywords",
             model=llm_module.LlmModel.GPT4O,
             credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+            agent_mode_max_iterations=0,
         )
 
         # Should succeed with all parameters
         outputs = {}
+        # Create execution context
+
+        mock_execution_context = ExecutionContext(safe_mode=False)
+
+        # Create a mock execution processor for tests
+
+        mock_execution_processor = MagicMock()
+
         async for output_name, output_data in block.run(
             input_data,
             credentials=llm_module.TEST_CREDENTIALS,
@@ -501,6 +549,9 @@ async def test_smart_decision_maker_parameter_validation():
             graph_exec_id="test-exec-id",
             node_exec_id="test-node-exec-id",
             user_id="test-user-id",
+            graph_version=1,
+            execution_context=mock_execution_context,
+            execution_processor=mock_execution_processor,
         ):
             outputs[output_name] = output_data
 
@@ -513,8 +564,6 @@ async def test_smart_decision_maker_parameter_validation():
 @pytest.mark.asyncio
 async def test_smart_decision_maker_raw_response_conversion():
     """Test that SmartDecisionMaker correctly handles different raw_response types with retry mechanism."""
-    from unittest.mock import MagicMock, patch
-
     import backend.blocks.llm as llm_module
     from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
 
@@ -584,7 +633,6 @@ async def test_smart_decision_maker_raw_response_conversion():
     )
 
     # Mock llm_call to return different responses on different calls
-    from unittest.mock import AsyncMock
 
     with patch(
         "backend.blocks.llm.llm_call", new_callable=AsyncMock
@@ -603,10 +651,19 @@ async def test_smart_decision_maker_raw_response_conversion():
             model=llm_module.LlmModel.GPT4O,
             credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
             retry=2,
+            agent_mode_max_iterations=0,
         )
 
         # Should succeed after retry, demonstrating our helper function works
         outputs = {}
+        # Create execution context
+
+        mock_execution_context = ExecutionContext(safe_mode=False)
+
+        # Create a mock execution processor for tests
+
+        mock_execution_processor = MagicMock()
+
         async for output_name, output_data in block.run(
             input_data,
             credentials=llm_module.TEST_CREDENTIALS,
@@ -615,6 +672,9 @@ async def test_smart_decision_maker_raw_response_conversion():
             graph_exec_id="test-exec-id",
             node_exec_id="test-node-exec-id",
             user_id="test-user-id",
+            graph_version=1,
+            execution_context=mock_execution_context,
+            execution_processor=mock_execution_processor,
         ):
             outputs[output_name] = output_data
 
@@ -650,8 +710,6 @@ async def test_smart_decision_maker_raw_response_conversion():
         "I'll help you with that."  # Ollama returns string
     )
 
-    from unittest.mock import AsyncMock
-
     with patch(
         "backend.blocks.llm.llm_call",
         new_callable=AsyncMock,
@@ -666,9 +724,18 @@ async def test_smart_decision_maker_raw_response_conversion():
             prompt="Simple prompt",
             model=llm_module.LlmModel.GPT4O,
             credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+            agent_mode_max_iterations=0,
         )
 
         outputs = {}
+        # Create execution context
+
+        mock_execution_context = ExecutionContext(safe_mode=False)
+
+        # Create a mock execution processor for tests
+
+        mock_execution_processor = MagicMock()
+
         async for output_name, output_data in block.run(
             input_data,
             credentials=llm_module.TEST_CREDENTIALS,
@@ -677,6 +744,9 @@ async def test_smart_decision_maker_raw_response_conversion():
             graph_exec_id="test-exec-id",
             node_exec_id="test-node-exec-id",
             user_id="test-user-id",
+            graph_version=1,
+            execution_context=mock_execution_context,
+            execution_processor=mock_execution_processor,
         ):
             outputs[output_name] = output_data
 
@@ -696,8 +766,6 @@ async def test_smart_decision_maker_raw_response_conversion():
         "content": "Test response",
     }  # Dict format
 
-    from unittest.mock import AsyncMock
-
     with patch(
         "backend.blocks.llm.llm_call",
         new_callable=AsyncMock,
@@ -712,6 +780,160 @@ async def test_smart_decision_maker_raw_response_conversion():
             prompt="Another test",
             model=llm_module.LlmModel.GPT4O,
             credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+            agent_mode_max_iterations=0,
+        )
+
+        outputs = {}
+        # Create execution context
+
+        mock_execution_context = ExecutionContext(safe_mode=False)
+
+        # Create a mock execution processor for tests
+
+        mock_execution_processor = MagicMock()
+
+        async for output_name, output_data in block.run(
+            input_data,
+            credentials=llm_module.TEST_CREDENTIALS,
+            graph_id="test-graph-id",
+            node_id="test-node-id",
+            graph_exec_id="test-exec-id",
+            node_exec_id="test-node-exec-id",
+            user_id="test-user-id",
+            graph_version=1,
+            execution_context=mock_execution_context,
+            execution_processor=mock_execution_processor,
+        ):
+            outputs[output_name] = output_data
+
+        assert "finished" in outputs
+        assert outputs["finished"] == "Test response"
+
+
+@pytest.mark.asyncio
+async def test_smart_decision_maker_agent_mode():
+    """Test that agent mode executes tools directly and loops until finished."""
+    import backend.blocks.llm as llm_module
+    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
+
+    block = SmartDecisionMakerBlock()
+
+    # Mock tool call that requires multiple iterations
+    mock_tool_call_1 = MagicMock()
+    mock_tool_call_1.id = "call_1"
+    mock_tool_call_1.function.name = "search_keywords"
+    mock_tool_call_1.function.arguments = (
+        '{"query": "test", "max_keyword_difficulty": 50}'
+    )
+
+    mock_response_1 = MagicMock()
+    mock_response_1.response = None
+    mock_response_1.tool_calls = [mock_tool_call_1]
+    mock_response_1.prompt_tokens = 50
+    mock_response_1.completion_tokens = 25
+    mock_response_1.reasoning = "Using search tool"
+    mock_response_1.raw_response = {
+        "role": "assistant",
+        "content": None,
+        "tool_calls": [{"id": "call_1", "type": "function"}],
+    }
+
+    # Final response with no tool calls (finished)
+    mock_response_2 = MagicMock()
+    mock_response_2.response = "Task completed successfully"
+    mock_response_2.tool_calls = []
+    mock_response_2.prompt_tokens = 30
+    mock_response_2.completion_tokens = 15
+    mock_response_2.reasoning = None
+    mock_response_2.raw_response = {
+        "role": "assistant",
+        "content": "Task completed successfully",
+    }
+
+    # Mock the LLM call to return different responses on each iteration
+    llm_call_mock = AsyncMock()
+    llm_call_mock.side_effect = [mock_response_1, mock_response_2]
+
+    # Mock tool node signatures
+    mock_tool_signatures = [
+        {
+            "type": "function",
+            "function": {
+                "name": "search_keywords",
+                "_sink_node_id": "test-sink-node-id",
+                "_field_mapping": {},
+                "parameters": {
+                    "properties": {
+                        "query": {"type": "string"},
+                        "max_keyword_difficulty": {"type": "integer"},
+                    },
+                    "required": ["query", "max_keyword_difficulty"],
+                },
+            },
+        }
+    ]
+
+    # Mock database and execution components
+    mock_db_client = AsyncMock()
+    mock_node = MagicMock()
+    mock_node.block_id = "test-block-id"
+    mock_db_client.get_node.return_value = mock_node
+
+    # Mock upsert_execution_input to return proper NodeExecutionResult and input data
+    mock_node_exec_result = MagicMock()
+    mock_node_exec_result.node_exec_id = "test-tool-exec-id"
+    mock_input_data = {"query": "test", "max_keyword_difficulty": 50}
+    mock_db_client.upsert_execution_input.return_value = (
+        mock_node_exec_result,
+        mock_input_data,
+    )
+
+    # No longer need mock_execute_node since we use execution_processor.on_node_execution
+
+    with patch("backend.blocks.llm.llm_call", llm_call_mock), patch.object(
+        block, "_create_tool_node_signatures", return_value=mock_tool_signatures
+    ), patch(
+        "backend.blocks.smart_decision_maker.get_database_manager_async_client",
+        return_value=mock_db_client,
+    ), patch(
+        "backend.executor.manager.async_update_node_execution_status",
+        new_callable=AsyncMock,
+    ), patch(
+        "backend.integrations.creds_manager.IntegrationCredentialsManager"
+    ):
+
+        # Create a mock execution context
+
+        mock_execution_context = ExecutionContext(
+            safe_mode=False,
+        )
+
+        # Create a mock execution processor for agent mode tests
+
+        mock_execution_processor = AsyncMock()
+        # Configure the execution processor mock with required attributes
+        mock_execution_processor.running_node_execution = defaultdict(MagicMock)
+        mock_execution_processor.execution_stats = MagicMock()
+        mock_execution_processor.execution_stats_lock = threading.Lock()
+
+        # Mock the on_node_execution method to return successful stats
+        mock_node_stats = MagicMock()
+        mock_node_stats.error = None  # No error
+        mock_execution_processor.on_node_execution = AsyncMock(
+            return_value=mock_node_stats
+        )
+
+        # Mock the get_execution_outputs_by_node_exec_id method
+        mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
+            "result": {"status": "success", "data": "search completed"}
+        }
+
+        # Test agent mode with max_iterations = 3
+        input_data = SmartDecisionMakerBlock.Input(
+            prompt="Complete this task using tools",
+            model=llm_module.LlmModel.GPT4O,
+            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+            agent_mode_max_iterations=3,  # Enable agent mode with 3 max iterations
         )
 
         outputs = {}
@@ -723,8 +945,115 @@ async def test_smart_decision_maker_raw_response_conversion():
             graph_exec_id="test-exec-id",
             node_exec_id="test-node-exec-id",
             user_id="test-user-id",
+            graph_version=1,
+            execution_context=mock_execution_context,
+            execution_processor=mock_execution_processor,
         ):
             outputs[output_name] = output_data
 
+        # Verify agent mode behavior
+        assert "tool_functions" in outputs  # tool_functions is yielded in both modes
         assert "finished" in outputs
-        assert outputs["finished"] == "Test response"
+        assert outputs["finished"] == "Task completed successfully"
+        assert "conversations" in outputs
+
+        # Verify the conversation includes tool responses
+        conversations = outputs["conversations"]
+        assert len(conversations) > 2  # Should have multiple conversation entries
+
+        # Verify LLM was called twice (once for tool call, once for finish)
+        assert llm_call_mock.call_count == 2
+
+        # Verify tool was executed via execution processor
+        assert mock_execution_processor.on_node_execution.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_smart_decision_maker_traditional_mode_default():
+    """Test that default behavior (agent_mode_max_iterations=0) works as traditional mode."""
+    import backend.blocks.llm as llm_module
+    from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
+
+    block = SmartDecisionMakerBlock()
+
+    # Mock tool call
+    mock_tool_call = MagicMock()
+    mock_tool_call.function.name = "search_keywords"
+    mock_tool_call.function.arguments = (
+        '{"query": "test", "max_keyword_difficulty": 50}'
+    )
+
+    mock_response = MagicMock()
+    mock_response.response = None
+    mock_response.tool_calls = [mock_tool_call]
+    mock_response.prompt_tokens = 50
+    mock_response.completion_tokens = 25
+    mock_response.reasoning = None
+    mock_response.raw_response = {"role": "assistant", "content": None}
+
+    mock_tool_signatures = [
+        {
+            "type": "function",
+            "function": {
+                "name": "search_keywords",
+                "_sink_node_id": "test-sink-node-id",
+                "_field_mapping": {},
+                "parameters": {
+                    "properties": {
+                        "query": {"type": "string"},
+                        "max_keyword_difficulty": {"type": "integer"},
+                    },
+                    "required": ["query", "max_keyword_difficulty"],
+                },
+            },
+        }
+    ]
+
+    with patch(
+        "backend.blocks.llm.llm_call",
+        new_callable=AsyncMock,
+        return_value=mock_response,
+    ), patch.object(
+        block, "_create_tool_node_signatures", return_value=mock_tool_signatures
+    ):
+
+        # Test default behavior (traditional mode)
+        input_data = SmartDecisionMakerBlock.Input(
+            prompt="Test prompt",
+            model=llm_module.LlmModel.GPT4O,
+            credentials=llm_module.TEST_CREDENTIALS_INPUT,  # type: ignore
+            agent_mode_max_iterations=0,  # Traditional mode
+        )
+
+        # Create execution context
+
+        mock_execution_context = ExecutionContext(safe_mode=False)
+
+        # Create a mock execution processor for tests
+
+        mock_execution_processor = MagicMock()
+
+        outputs = {}
+        async for output_name, output_data in block.run(
+            input_data,
+            credentials=llm_module.TEST_CREDENTIALS,
+            graph_id="test-graph-id",
+            node_id="test-node-id",
+            graph_exec_id="test-exec-id",
+            node_exec_id="test-node-exec-id",
+            user_id="test-user-id",
+            graph_version=1,
+            execution_context=mock_execution_context,
+            execution_processor=mock_execution_processor,
+        ):
+            outputs[output_name] = output_data
+
+        # Verify traditional mode behavior
+        assert (
+            "tool_functions" in outputs
+        )  # Should yield tool_functions in traditional mode
+        assert (
+            "tools_^_test-sink-node-id_~_query" in outputs
+        )  # Should yield individual tool parameters
+        assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
+        assert "conversations" in outputs
@@ -1,7 +1,7 @@
 """Comprehensive tests for SmartDecisionMakerBlock dynamic field handling."""
 
 import json
-from unittest.mock import AsyncMock, Mock, patch
+from unittest.mock import AsyncMock, MagicMock, Mock, patch
 
 import pytest
 
@@ -308,10 +308,47 @@ async def test_output_yielding_with_dynamic_fields():
     ) as mock_llm:
         mock_llm.return_value = mock_response
 
-        # Mock the function signature creation
-        with patch.object(
+        # Mock the database manager to avoid HTTP calls during tool execution
+        with patch(
+            "backend.blocks.smart_decision_maker.get_database_manager_async_client"
+        ) as mock_db_manager, patch.object(
             block, "_create_tool_node_signatures", new_callable=AsyncMock
         ) as mock_sig:
+            # Set up the mock database manager
+            mock_db_client = AsyncMock()
+            mock_db_manager.return_value = mock_db_client
+
+            # Mock the node retrieval
+            mock_target_node = Mock()
+            mock_target_node.id = "test-sink-node-id"
+            mock_target_node.block_id = "CreateDictionaryBlock"
+            mock_target_node.block = Mock()
+            mock_target_node.block.name = "Create Dictionary"
+            mock_db_client.get_node.return_value = mock_target_node
+
+            # Mock the execution result creation
+            mock_node_exec_result = Mock()
+            mock_node_exec_result.node_exec_id = "mock-node-exec-id"
+            mock_final_input_data = {
+                "values_#_name": "Alice",
+                "values_#_age": 30,
+                "values_#_email": "alice@example.com",
+            }
+            mock_db_client.upsert_execution_input.return_value = (
+                mock_node_exec_result,
+                mock_final_input_data,
+            )
+
+            # Mock the output retrieval
+            mock_outputs = {
+                "values_#_name": "Alice",
+                "values_#_age": 30,
+                "values_#_email": "alice@example.com",
+            }
+            mock_db_client.get_execution_outputs_by_node_exec_id.return_value = (
+                mock_outputs
+            )
+
             mock_sig.return_value = [
                 {
                     "type": "function",
@@ -337,10 +374,16 @@ async def test_output_yielding_with_dynamic_fields():
                 prompt="Create a user dictionary",
                 credentials=llm.TEST_CREDENTIALS_INPUT,
                 model=llm.LlmModel.GPT4O,
+                agent_mode_max_iterations=0,  # Use traditional mode to test output yielding
            )
 
            # Run the block
            outputs = {}
+            from backend.data.execution import ExecutionContext
+
+            mock_execution_context = ExecutionContext(safe_mode=False)
+            mock_execution_processor = MagicMock()
+
            async for output_name, output_value in block.run(
                input_data,
                credentials=llm.TEST_CREDENTIALS,
@@ -349,6 +392,9 @@ async def test_output_yielding_with_dynamic_fields():
                graph_exec_id="test_exec",
                node_exec_id="test_node_exec",
                user_id="test_user",
+                graph_version=1,
+                execution_context=mock_execution_context,
+                execution_processor=mock_execution_processor,
            ):
                outputs[output_name] = output_value
 
@@ -511,6 +557,37 @@ async def test_validation_errors_dont_pollute_conversation():
         }
     ]
 
+    # Mock the database manager to avoid HTTP calls during tool execution
+    with patch(
+        "backend.blocks.smart_decision_maker.get_database_manager_async_client"
+    ) as mock_db_manager:
+        # Set up the mock database manager for agent mode
+        mock_db_client = AsyncMock()
+        mock_db_manager.return_value = mock_db_client
+
+        # Mock the node retrieval
+        mock_target_node = Mock()
+        mock_target_node.id = "test-sink-node-id"
+        mock_target_node.block_id = "TestBlock"
+        mock_target_node.block = Mock()
+        mock_target_node.block.name = "Test Block"
+        mock_db_client.get_node.return_value = mock_target_node
+
+        # Mock the execution result creation
+        mock_node_exec_result = Mock()
+        mock_node_exec_result.node_exec_id = "mock-node-exec-id"
+        mock_final_input_data = {"correct_param": "value"}
+        mock_db_client.upsert_execution_input.return_value = (
+            mock_node_exec_result,
+            mock_final_input_data,
+        )
+
+        # Mock the output retrieval
+        mock_outputs = {"correct_param": "value"}
+        mock_db_client.get_execution_outputs_by_node_exec_id.return_value = (
+            mock_outputs
+        )
+
        # Create input data
        from backend.blocks import llm
 
@@ -519,10 +596,41 @@ async def test_validation_errors_dont_pollute_conversation():
|
|||||||
credentials=llm.TEST_CREDENTIALS_INPUT,
|
credentials=llm.TEST_CREDENTIALS_INPUT,
|
||||||
model=llm.LlmModel.GPT4O,
|
model=llm.LlmModel.GPT4O,
|
||||||
retry=3, # Allow retries
|
retry=3, # Allow retries
|
||||||
|
agent_mode_max_iterations=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run the block
|
# Run the block
|
||||||
outputs = {}
|
outputs = {}
|
||||||
|
from backend.data.execution import ExecutionContext
|
||||||
|
|
||||||
|
mock_execution_context = ExecutionContext(safe_mode=False)
|
||||||
|
|
||||||
|
# Create a proper mock execution processor for agent mode
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
mock_execution_processor = AsyncMock()
|
||||||
|
mock_execution_processor.execution_stats = MagicMock()
|
||||||
|
mock_execution_processor.execution_stats_lock = MagicMock()
|
||||||
|
|
||||||
|
# Create a mock NodeExecutionProgress for the sink node
|
||||||
|
mock_node_exec_progress = MagicMock()
|
||||||
|
mock_node_exec_progress.add_task = MagicMock()
|
||||||
|
mock_node_exec_progress.pop_output = MagicMock(
|
||||||
|
return_value=None
|
||||||
|
) # No outputs to process
|
||||||
|
|
||||||
|
# Set up running_node_execution as a defaultdict that returns our mock for any key
|
||||||
|
mock_execution_processor.running_node_execution = defaultdict(
|
||||||
|
lambda: mock_node_exec_progress
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the on_node_execution method that gets called during tool execution
|
||||||
|
mock_node_stats = MagicMock()
|
||||||
|
mock_node_stats.error = None
|
||||||
|
mock_execution_processor.on_node_execution.return_value = (
|
||||||
|
mock_node_stats
|
||||||
|
)
|
||||||
|
|
||||||
async for output_name, output_value in block.run(
|
async for output_name, output_value in block.run(
|
||||||
input_data,
|
input_data,
|
||||||
credentials=llm.TEST_CREDENTIALS,
|
credentials=llm.TEST_CREDENTIALS,
|
||||||
@@ -531,16 +639,20 @@ async def test_validation_errors_dont_pollute_conversation():
|
|||||||
graph_exec_id="test_exec",
|
graph_exec_id="test_exec",
|
||||||
node_exec_id="test_node_exec",
|
node_exec_id="test_node_exec",
|
||||||
user_id="test_user",
|
user_id="test_user",
|
||||||
|
graph_version=1,
|
||||||
|
execution_context=mock_execution_context,
|
||||||
|
execution_processor=mock_execution_processor,
|
||||||
):
|
):
|
||||||
outputs[output_name] = output_value
|
outputs[output_name] = output_value
|
||||||
|
|
||||||
# Verify we had 2 LLM calls (initial + retry)
|
# Verify we had at least 1 LLM call
|
||||||
assert call_count == 2
|
assert call_count >= 1
|
||||||
|
|
||||||
# Check the final conversation output
|
# Check the final conversation output
|
||||||
final_conversation = outputs.get("conversations", [])
|
final_conversation = outputs.get("conversations", [])
|
||||||
|
|
||||||
# The final conversation should NOT contain the validation error message
|
# The final conversation should NOT contain validation error messages
|
||||||
|
# Even if retries don't happen in agent mode, we should not leak errors
|
||||||
error_messages = [
|
error_messages = [
|
||||||
msg
|
msg
|
||||||
for msg in final_conversation
|
for msg in final_conversation
|
||||||
@@ -550,6 +662,3 @@ async def test_validation_errors_dont_pollute_conversation():
|
|||||||
assert (
|
assert (
|
||||||
len(error_messages) == 0
|
len(error_messages) == 0
|
||||||
), "Validation error leaked into final conversation"
|
), "Validation error leaked into final conversation"
|
||||||
|
|
||||||
# The final conversation should only have the successful response
|
|
||||||
assert final_conversation[-1]["content"] == "valid"
|
|
||||||
|
|||||||
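The two tests above repeat the same database-manager stubbing by hand; the sketch below condenses that pattern in one place for orientation only. It assumes nothing beyond the patch target and the mocked method names visible in the hunks above (get_node, upsert_execution_input, get_execution_outputs_by_node_exec_id); the helper name make_mock_db_client is hypothetical and not part of the change set.

# Sketch of the stubbing pattern used in the tests above. make_mock_db_client is
# a hypothetical helper name; the mocked methods are the ones the tests rely on.
from unittest.mock import AsyncMock, Mock


def make_mock_db_client(outputs: dict) -> AsyncMock:
    """Build an AsyncMock satisfying the calls SmartDecisionMaker makes in agent mode."""
    node = Mock()
    node.id = "test-sink-node-id"
    node.block = Mock()

    db = AsyncMock()
    db.get_node.return_value = node
    db.upsert_execution_input.return_value = (
        Mock(node_exec_id="mock-node-exec-id"),
        outputs,
    )
    db.get_execution_outputs_by_node_exec_id.return_value = outputs
    return db


# Usage inside a test (patch target taken from the hunks above):
#   with unittest.mock.patch(
#       "backend.blocks.smart_decision_maker.get_database_manager_async_client"
#   ) as mock_db_manager:
#       mock_db_manager.return_value = make_mock_db_client({"correct_param": "value"})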
@@ -1,10 +1,10 @@
 import logging
+import queue
 from collections import defaultdict
 from datetime import datetime, timedelta, timezone
 from enum import Enum
-from multiprocessing import Manager
-from queue import Empty
 from typing import (
+    TYPE_CHECKING,
     Annotated,
     Any,
     AsyncGenerator,
@@ -65,6 +65,9 @@ from .includes import (
 )
 from .model import CredentialsMetaInput, GraphExecutionStats, NodeExecutionStats
 
+if TYPE_CHECKING:
+    pass
+
 T = TypeVar("T")
 
 logger = logging.getLogger(__name__)
@@ -836,6 +839,30 @@ async def upsert_execution_output(
     await AgentNodeExecutionInputOutput.prisma().create(data=data)
 
 
+async def get_execution_outputs_by_node_exec_id(
+    node_exec_id: str,
+) -> dict[str, Any]:
+    """
+    Get all execution outputs for a specific node execution ID.
+
+    Args:
+        node_exec_id: The node execution ID to get outputs for
+
+    Returns:
+        Dictionary mapping output names to their data values
+    """
+    outputs = await AgentNodeExecutionInputOutput.prisma().find_many(
+        where={"referencedByOutputExecId": node_exec_id}
+    )
+
+    result = {}
+    for output in outputs:
+        if output.data is not None:
+            result[output.name] = type_utils.convert(output.data, JsonValue)
+
+    return result
+
+
 async def update_graph_execution_start_time(
     graph_exec_id: str,
 ) -> GraphExecution | None:
@@ -1136,12 +1163,16 @@ class NodeExecutionEntry(BaseModel):
 
 class ExecutionQueue(Generic[T]):
     """
-    Queue for managing the execution of agents.
-    This will be shared between different processes
+    Thread-safe queue for managing node execution within a single graph execution.
+
+    Note: Uses queue.Queue (not multiprocessing.Queue) since all access is from
+    threads within the same process. If migrating back to ProcessPoolExecutor,
+    replace with multiprocessing.Manager().Queue() for cross-process safety.
     """
 
     def __init__(self):
-        self.queue = Manager().Queue()
+        # Thread-safe queue (not multiprocessing) — see class docstring
+        self.queue: queue.Queue[T] = queue.Queue()
 
     def add(self, execution: T) -> T:
         self.queue.put(execution)
@@ -1156,7 +1187,7 @@ class ExecutionQueue(Generic[T]):
     def get_or_none(self) -> T | None:
         try:
             return self.queue.get_nowait()
-        except Empty:
+        except queue.Empty:
             return None
 
 
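For orientation, the helper added above can be exercised on its own roughly as follows. This is a minimal sketch, not part of the change set: it assumes the function is importable from backend.data.execution as the hunk suggests and that the Prisma connection is already set up; the node execution ID is a placeholder.

# Minimal sketch, assuming backend.data.execution exposes the new helper as added
# above and the database connection is already established.
import asyncio

from backend.data.execution import get_execution_outputs_by_node_exec_id


async def main() -> None:
    # "some-node-exec-id" is a placeholder; use a real node execution ID.
    outputs = await get_execution_outputs_by_node_exec_id("some-node-exec-id")
    # Returns a plain dict of output name -> JSON value, e.g. {"values_#_name": "Alice"}
    print(outputs)


asyncio.run(main())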
@@ -0,0 +1,60 @@
+"""Tests for ExecutionQueue thread-safety."""
+
+import queue
+import threading
+
+import pytest
+
+from backend.data.execution import ExecutionQueue
+
+
+def test_execution_queue_uses_stdlib_queue():
+    """Verify ExecutionQueue uses queue.Queue (not multiprocessing)."""
+    q = ExecutionQueue()
+    assert isinstance(q.queue, queue.Queue)
+
+
+def test_basic_operations():
+    """Test add, get, empty, and get_or_none."""
+    q = ExecutionQueue()
+
+    assert q.empty() is True
+    assert q.get_or_none() is None
+
+    result = q.add("item1")
+    assert result == "item1"
+    assert q.empty() is False
+
+    item = q.get()
+    assert item == "item1"
+    assert q.empty() is True
+
+
+def test_thread_safety():
+    """Test concurrent access from multiple threads."""
+    q = ExecutionQueue()
+    results = []
+    num_items = 100
+
+    def producer():
+        for i in range(num_items):
+            q.add(f"item_{i}")
+
+    def consumer():
+        count = 0
+        while count < num_items:
+            item = q.get_or_none()
+            if item is not None:
+                results.append(item)
+                count += 1
+
+    producer_thread = threading.Thread(target=producer)
+    consumer_thread = threading.Thread(target=consumer)
+
+    producer_thread.start()
+    consumer_thread.start()
+
+    producer_thread.join(timeout=5)
+    consumer_thread.join(timeout=5)
+
+    assert len(results) == num_items
@@ -13,6 +13,7 @@ from backend.data.execution import (
     get_block_error_stats,
     get_child_graph_executions,
     get_execution_kv_data,
+    get_execution_outputs_by_node_exec_id,
     get_frequently_executed_graphs,
     get_graph_execution_meta,
     get_graph_executions,
@@ -147,6 +148,7 @@ class DatabaseManager(AppService):
     update_graph_execution_stats = _(update_graph_execution_stats)
     upsert_execution_input = _(upsert_execution_input)
     upsert_execution_output = _(upsert_execution_output)
+    get_execution_outputs_by_node_exec_id = _(get_execution_outputs_by_node_exec_id)
     get_execution_kv_data = _(get_execution_kv_data)
     set_execution_kv_data = _(set_execution_kv_data)
     get_block_error_stats = _(get_block_error_stats)
@@ -277,6 +279,7 @@ class DatabaseManagerAsyncClient(AppServiceClient):
     get_user_integrations = d.get_user_integrations
     upsert_execution_input = d.upsert_execution_input
     upsert_execution_output = d.upsert_execution_output
+    get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
     update_graph_execution_stats = d.update_graph_execution_stats
     update_node_execution_status = d.update_node_execution_status
     update_node_execution_status_batch = d.update_node_execution_status_batch
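Since the method is now exposed on both the service and the async client above, callers that already hold a client can reach it over RPC the same way the SmartDecisionMaker tests stub it. A rough sketch under that assumption, reusing get_database_manager_async_client from backend.util.clients; the wrapper name fetch_outputs is illustrative only:

# Rough sketch only: reading a node execution's outputs through the async client,
# mirroring the calls the tests above mock out.
from backend.util.clients import get_database_manager_async_client


async def fetch_outputs(node_exec_id: str) -> dict:
    db = get_database_manager_async_client()
    return await db.get_execution_outputs_by_node_exec_id(node_exec_id)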
@@ -133,9 +133,8 @@ def execute_graph(
     cluster_lock: ClusterLock,
 ):
     """Execute graph using thread-local ExecutionProcessor instance"""
-    return _tls.processor.on_graph_execution(
-        graph_exec_entry, cancel_event, cluster_lock
-    )
+    processor: ExecutionProcessor = _tls.processor
+    return processor.on_graph_execution(graph_exec_entry, cancel_event, cluster_lock)
 
 
 T = TypeVar("T")
@@ -143,8 +142,8 @@ T = TypeVar("T")
 
 async def execute_node(
     node: Node,
-    creds_manager: IntegrationCredentialsManager,
     data: NodeExecutionEntry,
+    execution_processor: "ExecutionProcessor",
     execution_stats: NodeExecutionStats | None = None,
     nodes_input_masks: Optional[NodesInputMasks] = None,
 ) -> BlockOutput:
@@ -169,6 +168,7 @@ async def execute_node(
     node_id = data.node_id
     node_block = node.block
     execution_context = data.execution_context
+    creds_manager = execution_processor.creds_manager
 
     log_metadata = LogMetadata(
         logger=_logger,
@@ -212,6 +212,7 @@ async def execute_node(
         "node_exec_id": node_exec_id,
         "user_id": user_id,
         "execution_context": execution_context,
+        "execution_processor": execution_processor,
     }
 
     # Last-minute fetch credentials + acquire a system-wide read-write lock to prevent
@@ -608,8 +609,8 @@ class ExecutionProcessor:
 
         async for output_name, output_data in execute_node(
             node=node,
-            creds_manager=self.creds_manager,
             data=node_exec,
+            execution_processor=self,
             execution_stats=stats,
             nodes_input_masks=nodes_input_masks,
         ):
@@ -860,12 +861,17 @@ class ExecutionProcessor:
         execution_stats_lock = threading.Lock()
 
         # State holders ----------------------------------------------------
-        running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
+        self.running_node_execution: dict[str, NodeExecutionProgress] = defaultdict(
             NodeExecutionProgress
         )
-        running_node_evaluation: dict[str, Future] = {}
+        self.running_node_evaluation: dict[str, Future] = {}
+        self.execution_stats = execution_stats
+        self.execution_stats_lock = execution_stats_lock
         execution_queue = ExecutionQueue[NodeExecutionEntry]()
+
+        running_node_execution = self.running_node_execution
+        running_node_evaluation = self.running_node_evaluation
 
         try:
             if db_client.get_credits(graph_exec.user_id) <= 0:
                 raise InsufficientBalanceError(
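execute_graph above now pulls a typed local out of thread-local storage before calling it. For readers unfamiliar with the pattern, here is a self-contained sketch of per-thread processor instances; the names _tls and ExecutionProcessor mirror the ones above, but this is illustrative and not the module's actual initialisation code.

# Illustrative sketch of the thread-local pattern used by execute_graph above.
import threading


class ExecutionProcessor:  # stand-in for the real ExecutionProcessor class
    def on_graph_execution(self, *args):
        return f"processed on {threading.current_thread().name}"


_tls = threading.local()


def execute_graph(*args):
    # Each worker thread gets (and reuses) its own processor instance.
    if not hasattr(_tls, "processor"):
        _tls.processor = ExecutionProcessor()
    processor: ExecutionProcessor = _tls.processor
    return processor.on_graph_execution(*args)


print(execute_graph())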
@@ -5,6 +5,13 @@ from tiktoken import encoding_for_model
 
 from backend.util import json
 
+# ---------------------------------------------------------------------------#
+#                                 CONSTANTS                                   #
+# ---------------------------------------------------------------------------#
+
+# Message prefixes for important system messages that should be protected during compression
+MAIN_OBJECTIVE_PREFIX = "[Main Objective Prompt]: "
+
 # ---------------------------------------------------------------------------#
 #                            INTERNAL UTILITIES                               #
 # ---------------------------------------------------------------------------#
@@ -63,6 +70,55 @@ def _msg_tokens(msg: dict, enc) -> int:
     return WRAPPER + content_tokens + tool_call_tokens
 
 
+def _is_tool_message(msg: dict) -> bool:
+    """Check if a message contains tool calls or results that should be protected."""
+    content = msg.get("content")
+
+    # Check for Anthropic-style tool messages
+    if isinstance(content, list) and any(
+        isinstance(item, dict) and item.get("type") in ("tool_use", "tool_result")
+        for item in content
+    ):
+        return True
+
+    # Check for OpenAI-style tool calls in the message
+    if "tool_calls" in msg or msg.get("role") == "tool":
+        return True
+
+    return False
+
+
+def _is_objective_message(msg: dict) -> bool:
+    """Check if a message contains objective/system prompts that should be absolutely protected."""
+    content = msg.get("content", "")
+    if isinstance(content, str):
+        # Protect any message with the main objective prefix
+        return content.startswith(MAIN_OBJECTIVE_PREFIX)
+    return False
+
+
+def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
+    """
+    Carefully truncate tool message content while preserving tool structure.
+    Only truncates tool_result content, leaves tool_use intact.
+    """
+    content = msg.get("content")
+    if not isinstance(content, list):
+        return
+
+    for item in content:
+        # Only process tool_result items, leave tool_use blocks completely intact
+        if not (isinstance(item, dict) and item.get("type") == "tool_result"):
+            continue
+
+        result_content = item.get("content", "")
+        if (
+            isinstance(result_content, str)
+            and _tok_len(result_content, enc) > max_tokens
+        ):
+            item["content"] = _truncate_middle_tokens(result_content, enc, max_tokens)
+
+
 def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
     """
     Return *text* shortened to ≈max_tok tokens by keeping the head & tail
@@ -140,13 +196,21 @@ def compress_prompt(
         return sum(_msg_tokens(m, enc) for m in msgs)
 
     original_token_count = total_tokens()
+
     if original_token_count + reserve <= target_tokens:
         return msgs
 
     # ---- STEP 0 : normalise content --------------------------------------
     # Convert non-string payloads to strings so token counting is coherent.
-    for m in msgs[1:-1]:  # keep the first & last intact
+    for i, m in enumerate(msgs):
         if not isinstance(m.get("content"), str) and m.get("content") is not None:
+            if _is_tool_message(m):
+                continue
+
+            # Keep first and last messages intact (unless they're tool messages)
+            if i == 0 or i == len(msgs) - 1:
+                continue
+
             # Reasonable 20k-char ceiling prevents pathological blobs
             content_str = json.dumps(m["content"], separators=(",", ":"))
             if len(content_str) > 20_000:
@@ -157,34 +221,45 @@ def compress_prompt(
     cap = start_cap
     while total_tokens() + reserve > target_tokens and cap >= floor_cap:
         for m in msgs[1:-1]:  # keep first & last intact
-            if _tok_len(m.get("content") or "", enc) > cap:
-                m["content"] = _truncate_middle_tokens(m["content"], enc, cap)
+            if _is_tool_message(m):
+                # For tool messages, only truncate tool result content, preserve structure
+                _truncate_tool_message_content(m, enc, cap)
+                continue
+
+            if _is_objective_message(m):
+                # Never truncate objective messages - they contain the core task
+                continue
+
+            content = m.get("content") or ""
+            if _tok_len(content, enc) > cap:
+                m["content"] = _truncate_middle_tokens(content, enc, cap)
         cap //= 2  # tighten the screw
 
     # ---- STEP 2 : middle-out deletion -----------------------------------
     while total_tokens() + reserve > target_tokens and len(msgs) > 2:
+        # Identify all deletable messages (not first/last, not tool messages, not objective messages)
+        deletable_indices = []
+        for i in range(1, len(msgs) - 1):  # Skip first and last
+            if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]):
+                deletable_indices.append(i)
+
+        if not deletable_indices:
+            break  # nothing more we can drop
+
+        # Delete from center outward - find the index closest to center
         centre = len(msgs) // 2
-        # Build a symmetrical centre-out index walk: centre, centre+1, centre-1, ...
-        order = [centre] + [
-            i
-            for pair in zip(range(centre + 1, len(msgs) - 1), range(centre - 1, 0, -1))
-            for i in pair
-        ]
-        removed = False
-        for i in order:
-            msg = msgs[i]
-            if "tool_calls" in msg or msg.get("role") == "tool":
-                continue  # protect tool shells
-            del msgs[i]
-            removed = True
-            break
-        if not removed:  # nothing more we can drop
-            break
+        to_delete = min(deletable_indices, key=lambda i: abs(i - centre))
+        del msgs[to_delete]
 
     # ---- STEP 3 : final safety-net trim on first & last ------------------
     cap = start_cap
     while total_tokens() + reserve > target_tokens and cap >= floor_cap:
         for idx in (0, -1):  # first and last
+            if _is_tool_message(msgs[idx]):
+                # For tool messages at first/last position, truncate tool result content only
+                _truncate_tool_message_content(msgs[idx], enc, cap)
+                continue
+
             text = msgs[idx].get("content") or ""
             if _tok_len(text, enc) > cap:
                 msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
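Taken together, these additions make compress_prompt treat three classes of messages differently: tool messages are only trimmed inside their tool_result content, objective messages are never touched, and everything else is truncated or dropped middle-out as before. A small sketch of the two predicates follows, assuming they remain module-private in backend.util.prompt under the names added above.

# Sketch of the protection predicates added above; assumes they stay in
# backend.util.prompt under these (private) names.
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX, _is_objective_message, _is_tool_message

objective = {"role": "user", "content": MAIN_OBJECTIVE_PREFIX + "Summarise the report"}
tool_result = {
    "role": "user",
    "content": [{"type": "tool_result", "tool_use_id": "t1", "content": "very long output ..."}],
}
chatter = {"role": "assistant", "content": "Sure, working on it."}

assert _is_objective_message(objective)  # never truncated or dropped
assert _is_tool_message(tool_result)     # only its tool_result content gets trimmed
assert not _is_tool_message(chatter)     # regular messages remain fully compressible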
@@ -3,7 +3,6 @@ import { scrollbarStyles } from "@/components/styles/scrollbars";
 import { cn } from "@/lib/utils";
 import { X } from "@phosphor-icons/react";
 import * as RXDialog from "@radix-ui/react-dialog";
-import { debounce } from "lodash";
 import {
   CSSProperties,
   PropsWithChildren,
@@ -71,21 +70,13 @@ export function DialogWrap({
       if (!el) return;
       setHasVerticalScrollbar(el.scrollHeight > el.clientHeight + 1);
     }
 
-    // Debounce the update function to prevent rapid successive state updates
-    const debouncedUpdate = debounce(update, 100);
-
-    // Initial update without debounce for immediate UI feedback
     update();
-    const ro = new ResizeObserver(debouncedUpdate);
+    const ro = new ResizeObserver(update);
     if (scrollRef.current) ro.observe(scrollRef.current);
-    window.addEventListener("resize", debouncedUpdate);
+    window.addEventListener("resize", update);
 
     return () => {
-      debouncedUpdate.cancel();
       ro.disconnect();
-      window.removeEventListener("resize", debouncedUpdate);
+      window.removeEventListener("resize", update);
     };
   }, []);
 
@@ -48,7 +48,7 @@ const mockFlags = {
   [Flag.AGENT_FAVORITING]: false,
   [Flag.MARKETPLACE_SEARCH_TERMS]: DEFAULT_SEARCH_TERMS,
   [Flag.ENABLE_PLATFORM_PAYMENT]: false,
-  [Flag.CHAT]: false,
+  [Flag.CHAT]: true,
 };
 
 export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] | null {