mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-10 07:38:04 -05:00
fix(blocks): Make Smart Decision Maker tool pin handling consistent and reliable (#11363)
- Resolves #11345 ### Changes 🏗️ - Move tool use routing logic from frontend to backend: routing info was being baked into graph links by the frontend, inconsistently, causing issues - Rework tool use routing to use target node ID instead of target block name - Add a bit of magic to `NodeOutputs` component to show tool node title instead of ID DX: - Removed `build` from `.prettierignore` -> re-enable formatting for builder components ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Use SDM block in a graph; verify it works - [x] Use SDM block with agent executor block as tool; verify it works - Tests for `parse_execution_output` pass (checked by CI)
This commit is contained in:
committed by
GitHub
parent
a3e5f7fce2
commit
536e2a5ec8
@@ -18,6 +18,7 @@ from backend.data.dynamic_fields import (
|
||||
extract_base_field_name,
|
||||
get_dynamic_field_description,
|
||||
is_dynamic_field,
|
||||
is_tool_pin,
|
||||
)
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util import json
|
||||
@@ -367,8 +368,9 @@ class SmartDecisionMakerBlock(Block):
|
||||
"required": sorted(required_fields),
|
||||
}
|
||||
|
||||
# Store field mapping for later use in output processing
|
||||
# Store field mapping and node info for later use in output processing
|
||||
tool_function["_field_mapping"] = field_mapping
|
||||
tool_function["_sink_node_id"] = sink_node.id
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@@ -431,10 +433,13 @@ class SmartDecisionMakerBlock(Block):
|
||||
"strict": True,
|
||||
}
|
||||
|
||||
# Store node info for later use in output processing
|
||||
tool_function["_sink_node_id"] = sink_node.id
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
async def _create_function_signature(
|
||||
async def _create_tool_node_signatures(
|
||||
node_id: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
@@ -450,7 +455,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
tools = [
|
||||
(link, node)
|
||||
for link, node in await db_client.get_connected_output_nodes(node_id)
|
||||
if link.source_name.startswith("tools_^_") and link.source_id == node_id
|
||||
if is_tool_pin(link.source_name) and link.source_id == node_id
|
||||
]
|
||||
if not tools:
|
||||
raise ValueError("There is no next node to execute.")
|
||||
@@ -538,8 +543,14 @@ class SmartDecisionMakerBlock(Block):
|
||||
),
|
||||
None,
|
||||
)
|
||||
if tool_def is None and len(tool_functions) == 1:
|
||||
tool_def = tool_functions[0]
|
||||
if tool_def is None:
|
||||
if len(tool_functions) == 1:
|
||||
tool_def = tool_functions[0]
|
||||
else:
|
||||
validation_errors_list.append(
|
||||
f"Tool call for '{tool_name}' does not match any known "
|
||||
"tool definition."
|
||||
)
|
||||
|
||||
# Get parameters schema from tool definition
|
||||
if (
|
||||
@@ -591,7 +602,7 @@ class SmartDecisionMakerBlock(Block):
|
||||
user_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
tool_functions = await self._create_function_signature(node_id)
|
||||
tool_functions = await self._create_tool_node_signatures(node_id)
|
||||
yield "tool_functions", json.dumps(tool_functions)
|
||||
|
||||
input_data.conversation_history = input_data.conversation_history or []
|
||||
@@ -661,9 +672,9 @@ class SmartDecisionMakerBlock(Block):
|
||||
except ValueError as e:
|
||||
last_error = e
|
||||
error_feedback = (
|
||||
"Your tool call had parameter errors. Please fix the following issues and try again:\n"
|
||||
"Your tool call had errors. Please fix the following issues and try again:\n"
|
||||
+ f"- {str(e)}\n"
|
||||
+ "\nPlease make sure to use the exact parameter names as specified in the function schema."
|
||||
+ "\nPlease make sure to use the exact tool and parameter names as specified in the function schema."
|
||||
)
|
||||
current_prompt = list(current_prompt) + [
|
||||
{"role": "user", "content": error_feedback}
|
||||
@@ -690,21 +701,23 @@ class SmartDecisionMakerBlock(Block):
|
||||
),
|
||||
None,
|
||||
)
|
||||
if (
|
||||
tool_def
|
||||
and "function" in tool_def
|
||||
and "parameters" in tool_def["function"]
|
||||
):
|
||||
if not tool_def:
|
||||
# NOTE: This matches the logic in _attempt_llm_call_with_validation and
|
||||
# relies on its validation for the assumption that this is valid to use.
|
||||
if len(tool_functions) == 1:
|
||||
tool_def = tool_functions[0]
|
||||
else:
|
||||
# This should not happen due to prior validation
|
||||
continue
|
||||
|
||||
if "function" in tool_def and "parameters" in tool_def["function"]:
|
||||
expected_args = tool_def["function"]["parameters"].get("properties", {})
|
||||
else:
|
||||
expected_args = {arg: {} for arg in tool_args.keys()}
|
||||
|
||||
# Get field mapping from tool definition
|
||||
field_mapping = (
|
||||
tool_def.get("function", {}).get("_field_mapping", {})
|
||||
if tool_def
|
||||
else {}
|
||||
)
|
||||
# Get the sink node ID and field mapping from tool definition
|
||||
field_mapping = tool_def["function"].get("_field_mapping", {})
|
||||
sink_node_id = tool_def["function"]["_sink_node_id"]
|
||||
|
||||
for clean_arg_name in expected_args:
|
||||
# arg_name is now always the cleaned field name (for Anthropic API compliance)
|
||||
@@ -712,9 +725,8 @@ class SmartDecisionMakerBlock(Block):
|
||||
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
|
||||
arg_value = tool_args.get(clean_arg_name)
|
||||
|
||||
sanitized_tool_name = self.cleanup(tool_name)
|
||||
sanitized_arg_name = self.cleanup(original_field_name)
|
||||
emit_key = f"tools_^_{sanitized_tool_name}_~_{sanitized_arg_name}"
|
||||
emit_key = f"tools_^_{sink_node_id}_~_{sanitized_arg_name}"
|
||||
|
||||
logger.debug(
|
||||
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",
|
||||
|
||||
@@ -165,7 +165,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
|
||||
)
|
||||
test_graph = await create_graph(server, test_graph, test_user)
|
||||
|
||||
tool_functions = await SmartDecisionMakerBlock._create_function_signature(
|
||||
tool_functions = await SmartDecisionMakerBlock._create_tool_node_signatures(
|
||||
test_graph.nodes[0].id
|
||||
)
|
||||
assert tool_functions is not None, "Tool functions should not be None"
|
||||
@@ -215,7 +215,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
"content": "I need to think about this.",
|
||||
}
|
||||
|
||||
# Mock the _create_function_signature method to avoid database calls
|
||||
# Mock the _create_tool_node_signatures method to avoid database calls
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
with patch(
|
||||
@@ -224,7 +224,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
|
||||
return_value=mock_response,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
@@ -293,6 +293,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
},
|
||||
"required": ["query", "max_keyword_difficulty"],
|
||||
},
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
},
|
||||
}
|
||||
]
|
||||
@@ -318,7 +319,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
return_value=mock_response_with_typo,
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
@@ -375,7 +376,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
return_value=mock_response_missing_required,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
@@ -425,7 +426,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
return_value=mock_response_valid,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
@@ -450,13 +451,13 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify tool outputs were generated correctly
|
||||
assert "tools_^_search_keywords_~_query" in outputs
|
||||
assert outputs["tools_^_search_keywords_~_query"] == "test"
|
||||
assert "tools_^_search_keywords_~_max_keyword_difficulty" in outputs
|
||||
assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
|
||||
assert "tools_^_test-sink-node-id_~_query" in outputs
|
||||
assert outputs["tools_^_test-sink-node-id_~_query"] == "test"
|
||||
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
|
||||
assert outputs["tools_^_test-sink-node-id_~_max_keyword_difficulty"] == 50
|
||||
# Optional parameter should be None when not provided
|
||||
assert "tools_^_search_keywords_~_optional_param" in outputs
|
||||
assert outputs["tools_^_search_keywords_~_optional_param"] is None
|
||||
assert "tools_^_test-sink-node-id_~_optional_param" in outputs
|
||||
assert outputs["tools_^_test-sink-node-id_~_optional_param"] is None
|
||||
|
||||
# Test case 4: Valid tool call with ALL parameters (should succeed)
|
||||
mock_tool_call_all_params = MagicMock()
|
||||
@@ -479,7 +480,7 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
return_value=mock_response_all_params,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
@@ -504,9 +505,9 @@ async def test_smart_decision_maker_parameter_validation():
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify all tool outputs were generated correctly
|
||||
assert outputs["tools_^_search_keywords_~_query"] == "test"
|
||||
assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
|
||||
assert outputs["tools_^_search_keywords_~_optional_param"] == "custom_value"
|
||||
assert outputs["tools_^_test-sink-node-id_~_query"] == "test"
|
||||
assert outputs["tools_^_test-sink-node-id_~_max_keyword_difficulty"] == 50
|
||||
assert outputs["tools_^_test-sink-node-id_~_optional_param"] == "custom_value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -530,6 +531,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
"properties": {"param": {"type": "string"}},
|
||||
"required": ["param"],
|
||||
},
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
},
|
||||
}
|
||||
]
|
||||
@@ -588,7 +590,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
"backend.blocks.llm.llm_call", new_callable=AsyncMock
|
||||
) as mock_llm_call, patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_tool_functions,
|
||||
):
|
||||
@@ -617,8 +619,8 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
outputs[output_name] = output_data
|
||||
|
||||
# Verify the tool output was generated successfully
|
||||
assert "tools_^_test_tool_~_param" in outputs
|
||||
assert outputs["tools_^_test_tool_~_param"] == "test_value"
|
||||
assert "tools_^_test-sink-node-id_~_param" in outputs
|
||||
assert outputs["tools_^_test-sink-node-id_~_param"] == "test_value"
|
||||
|
||||
# Verify conversation history was properly maintained
|
||||
assert "conversations" in outputs
|
||||
@@ -656,7 +658,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
return_value=mock_response_ollama,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[], # No tools for this test
|
||||
):
|
||||
@@ -702,7 +704,7 @@ async def test_smart_decision_maker_raw_response_conversion():
|
||||
return_value=mock_response_dict,
|
||||
), patch.object(
|
||||
SmartDecisionMakerBlock,
|
||||
"_create_function_signature",
|
||||
"_create_tool_node_signatures",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
):
|
||||
|
||||
@@ -192,7 +192,7 @@ async def test_create_block_function_signature_with_object_fields():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_function_signature():
|
||||
async def test_create_tool_node_signatures():
|
||||
"""Test that the mapping between sanitized and original field names is built correctly."""
|
||||
block = SmartDecisionMakerBlock()
|
||||
|
||||
@@ -241,7 +241,7 @@ async def test_create_function_signature():
|
||||
]
|
||||
|
||||
# Call the method that builds signatures
|
||||
tool_functions = await block._create_function_signature("test_node_id")
|
||||
tool_functions = await block._create_tool_node_signatures("test_node_id")
|
||||
|
||||
# Verify we got 2 tool functions (one for dict, one for list)
|
||||
assert len(tool_functions) == 2
|
||||
@@ -310,7 +310,7 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
|
||||
# Mock the function signature creation
|
||||
with patch.object(
|
||||
block, "_create_function_signature", new_callable=AsyncMock
|
||||
block, "_create_tool_node_signatures", new_callable=AsyncMock
|
||||
) as mock_sig:
|
||||
mock_sig.return_value = [
|
||||
{
|
||||
@@ -325,6 +325,7 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
"values___email": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
},
|
||||
}
|
||||
]
|
||||
@@ -351,16 +352,16 @@ async def test_output_yielding_with_dynamic_fields():
|
||||
):
|
||||
outputs[output_name] = output_value
|
||||
|
||||
# Verify the outputs use sanitized field names (matching frontend normalizeToolName)
|
||||
assert "tools_^_createdictionaryblock_~_values___name" in outputs
|
||||
assert outputs["tools_^_createdictionaryblock_~_values___name"] == "Alice"
|
||||
# Verify the outputs use sink node ID in output keys
|
||||
assert "tools_^_test-sink-node-id_~_values___name" in outputs
|
||||
assert outputs["tools_^_test-sink-node-id_~_values___name"] == "Alice"
|
||||
|
||||
assert "tools_^_createdictionaryblock_~_values___age" in outputs
|
||||
assert outputs["tools_^_createdictionaryblock_~_values___age"] == 30
|
||||
assert "tools_^_test-sink-node-id_~_values___age" in outputs
|
||||
assert outputs["tools_^_test-sink-node-id_~_values___age"] == 30
|
||||
|
||||
assert "tools_^_createdictionaryblock_~_values___email" in outputs
|
||||
assert "tools_^_test-sink-node-id_~_values___email" in outputs
|
||||
assert (
|
||||
outputs["tools_^_createdictionaryblock_~_values___email"]
|
||||
outputs["tools_^_test-sink-node-id_~_values___email"]
|
||||
== "alice@example.com"
|
||||
)
|
||||
|
||||
@@ -488,7 +489,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
|
||||
# Mock the function signature creation
|
||||
with patch.object(
|
||||
block, "_create_function_signature", new_callable=AsyncMock
|
||||
block, "_create_tool_node_signatures", new_callable=AsyncMock
|
||||
) as mock_sig:
|
||||
mock_sig.return_value = [
|
||||
{
|
||||
@@ -505,6 +506,7 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
},
|
||||
"required": ["correct_param"],
|
||||
},
|
||||
"_sink_node_id": "test-sink-node-id",
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
@@ -92,6 +92,18 @@ def get_dynamic_field_description(field_name: str) -> str:
|
||||
return f"Value for {field_name}"
|
||||
|
||||
|
||||
def is_tool_pin(name: str) -> bool:
|
||||
"""Check if a pin name represents a tool connection."""
|
||||
return name.startswith("tools_^_") or name == "tools"
|
||||
|
||||
|
||||
def sanitize_pin_name(name: str) -> str:
|
||||
sanitized_name = extract_base_field_name(name)
|
||||
if is_tool_pin(sanitized_name):
|
||||
return "tools"
|
||||
return sanitized_name
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Dynamic field parsing and merging utilities
|
||||
# --------------------------------------------------------------------------- #
|
||||
@@ -137,30 +149,64 @@ def _tokenise(path: str) -> list[tuple[str, str]] | None:
|
||||
return tokens
|
||||
|
||||
|
||||
def parse_execution_output(output: tuple[str, Any], name: str) -> Any:
|
||||
def parse_execution_output(
|
||||
output_item: tuple[str, Any],
|
||||
link_output_selector: str,
|
||||
sink_node_id: str | None = None,
|
||||
sink_pin_name: str | None = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieve a nested value out of `output` using the flattened *name*.
|
||||
Retrieve a nested value out of `output` using the flattened `link_output_selector`.
|
||||
|
||||
On any failure (wrong name, wrong type, out-of-range, bad path)
|
||||
returns **None**.
|
||||
On any failure (wrong name, wrong type, out-of-range, bad path) returns **None**.
|
||||
|
||||
### Special Case: Tool pins
|
||||
For regular output pins, the `output_item`'s name will simply be the field name, and
|
||||
`link_output_selector` (= the `source_name` of the link) may provide a "selector"
|
||||
used to extract part of the output value and route it through the link
|
||||
to the next node.
|
||||
|
||||
However, for tool pins, it is the other way around: the `output_item`'s name
|
||||
provides the routing information (`tools_^_{sink_node_id}_~_{field_name}`),
|
||||
and the `link_output_selector` is simply `"tools"`
|
||||
(or `"tools_^_{tool_name}_~_{field_name}"` for backward compatibility).
|
||||
|
||||
Args:
|
||||
output: Tuple of (base_name, data) representing a block output entry
|
||||
name: The flattened field name to extract from the output data
|
||||
output_item: Tuple of (base_name, data) representing a block output entry.
|
||||
link_output_selector: The flattened field name to extract from the output data.
|
||||
sink_node_id: Sink node ID, used for tool use routing.
|
||||
sink_pin_name: Sink pin name, used for tool use routing.
|
||||
|
||||
Returns:
|
||||
The value at the specified path, or None if not found/invalid
|
||||
The value at the specified path, or `None` if not found/invalid.
|
||||
"""
|
||||
base_name, data = output
|
||||
output_pin_name, data = output_item
|
||||
|
||||
# Special handling for tool pins
|
||||
if is_tool_pin(link_output_selector) and ( # "tools" or "tools_^_…"
|
||||
output_pin_name.startswith("tools_^_") and "_~_" in output_pin_name
|
||||
):
|
||||
if not (sink_node_id and sink_pin_name):
|
||||
raise ValueError(
|
||||
"sink_node_id and sink_pin_name must be provided for tool pin routing"
|
||||
)
|
||||
|
||||
# Extract routing information from emit key: tools_^_{node_id}_~_{field}
|
||||
selector = output_pin_name[8:] # Remove "tools_^_" prefix
|
||||
target_node_id, target_input_pin = selector.split("_~_", 1)
|
||||
if target_node_id == sink_node_id and target_input_pin == sink_pin_name:
|
||||
return data
|
||||
else:
|
||||
return None
|
||||
|
||||
# Exact match → whole object
|
||||
if name == base_name:
|
||||
if link_output_selector == output_pin_name:
|
||||
return data
|
||||
|
||||
# Must start with the expected name
|
||||
if not name.startswith(base_name):
|
||||
if not link_output_selector.startswith(output_pin_name):
|
||||
return None
|
||||
path = name[len(base_name) :]
|
||||
path = link_output_selector[len(output_pin_name) :]
|
||||
if not path:
|
||||
return None # nothing left to parse
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import LlmModel
|
||||
from backend.data.db import prisma as db
|
||||
from backend.data.dynamic_fields import extract_base_field_name
|
||||
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
|
||||
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
|
||||
from backend.data.model import (
|
||||
CredentialsField,
|
||||
@@ -578,9 +578,9 @@ class GraphModel(Graph):
|
||||
nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
|
||||
)
|
||||
provided_inputs = set(
|
||||
[_sanitize_pin_name(name) for name in node.input_default]
|
||||
[sanitize_pin_name(name) for name in node.input_default]
|
||||
+ [
|
||||
_sanitize_pin_name(link.sink_name)
|
||||
sanitize_pin_name(link.sink_name)
|
||||
for link in input_links.get(node.id, [])
|
||||
]
|
||||
+ ([name for name in node_input_mask] if node_input_mask else [])
|
||||
@@ -696,7 +696,7 @@ class GraphModel(Graph):
|
||||
f"{prefix}, {node.block_id} is invalid block id, available blocks: {blocks}"
|
||||
)
|
||||
|
||||
sanitized_name = _sanitize_pin_name(name)
|
||||
sanitized_name = sanitize_pin_name(name)
|
||||
vals = node.input_default
|
||||
if i == 0:
|
||||
fields = (
|
||||
@@ -710,7 +710,7 @@ class GraphModel(Graph):
|
||||
if block.block_type not in [BlockType.AGENT]
|
||||
else vals.get("input_schema", {}).get("properties", {}).keys()
|
||||
)
|
||||
if sanitized_name not in fields and not _is_tool_pin(name):
|
||||
if sanitized_name not in fields and not is_tool_pin(name):
|
||||
fields_msg = f"Allowed fields: {fields}"
|
||||
raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")
|
||||
|
||||
@@ -750,17 +750,6 @@ class GraphModel(Graph):
|
||||
)
|
||||
|
||||
|
||||
def _is_tool_pin(name: str) -> bool:
|
||||
return name.startswith("tools_^_")
|
||||
|
||||
|
||||
def _sanitize_pin_name(name: str) -> str:
|
||||
sanitized_name = extract_base_field_name(name)
|
||||
if _is_tool_pin(sanitized_name):
|
||||
return "tools"
|
||||
return sanitized_name
|
||||
|
||||
|
||||
class GraphMeta(Graph):
|
||||
user_id: str
|
||||
|
||||
|
||||
@@ -322,7 +322,9 @@ async def _enqueue_next_nodes(
|
||||
next_node_id = node_link.sink_id
|
||||
|
||||
output_name, _ = output
|
||||
next_data = parse_execution_output(output, next_output_name)
|
||||
next_data = parse_execution_output(
|
||||
output, next_output_name, next_node_id, next_input_name
|
||||
)
|
||||
if next_data is None and output_name != next_output_name:
|
||||
return enqueued_executions
|
||||
next_node = await db_client.get_node(next_node_id)
|
||||
|
||||
@@ -111,6 +111,35 @@ def test_parse_execution_output():
|
||||
parse_execution_output(output, "result_@_attr_$_0_#_key") is None
|
||||
) # Should fail at @_attr
|
||||
|
||||
# Test case 7: Tool pin routing with matching node ID and pin name
|
||||
output = ("tools_^_node123_~_query", "search term")
|
||||
assert parse_execution_output(output, "tools", "node123", "query") == "search term"
|
||||
|
||||
# Test case 8: Tool pin routing with node ID mismatch
|
||||
output = ("tools_^_node123_~_query", "search term")
|
||||
assert parse_execution_output(output, "tools", "node456", "query") is None
|
||||
|
||||
# Test case 9: Tool pin routing with pin name mismatch
|
||||
output = ("tools_^_node123_~_query", "search term")
|
||||
assert parse_execution_output(output, "tools", "node123", "different_pin") is None
|
||||
|
||||
# Test case 10: Tool pin routing with complex field names
|
||||
output = ("tools_^_node789_~_nested_field", {"key": "value"})
|
||||
result = parse_execution_output(output, "tools", "node789", "nested_field")
|
||||
assert result == {"key": "value"}
|
||||
|
||||
# Test case 11: Tool pin routing missing required parameters should raise error
|
||||
output = ("tools_^_node123_~_query", "search term")
|
||||
try:
|
||||
parse_execution_output(output, "tools", "node123") # Missing sink_pin_name
|
||||
assert False, "Should have raised ValueError"
|
||||
except ValueError as e:
|
||||
assert "must be provided for tool pin routing" in str(e)
|
||||
|
||||
# Test case 12: Non-tool pin with similar pattern should use normal logic
|
||||
output = ("tools_^_node123_~_query", "search term")
|
||||
assert parse_execution_output(output, "different_name", "node123", "query") is None
|
||||
|
||||
|
||||
def test_merge_execution_input():
|
||||
# Test case for basic list extraction
|
||||
|
||||
@@ -2,7 +2,6 @@ node_modules
|
||||
pnpm-lock.yaml
|
||||
.next
|
||||
.auth
|
||||
build
|
||||
public
|
||||
Dockerfile
|
||||
.prettierignore
|
||||
|
||||
@@ -36,18 +36,22 @@ export const useGraphContent = ({
|
||||
if (!node || !node.data) {
|
||||
return "";
|
||||
}
|
||||
|
||||
|
||||
const inputs = Object.keys(node.data?.inputSchema?.properties || {});
|
||||
const outputs = Object.keys(node.data?.outputSchema?.properties || {});
|
||||
const parts = [];
|
||||
|
||||
|
||||
if (inputs.length > 0) {
|
||||
parts.push(`Inputs: ${inputs.slice(0, 3).join(", ")}${inputs.length > 3 ? "..." : ""}`);
|
||||
parts.push(
|
||||
`Inputs: ${inputs.slice(0, 3).join(", ")}${inputs.length > 3 ? "..." : ""}`,
|
||||
);
|
||||
}
|
||||
if (outputs.length > 0) {
|
||||
parts.push(`Outputs: ${outputs.slice(0, 3).join(", ")}${outputs.length > 3 ? "..." : ""}`);
|
||||
parts.push(
|
||||
`Outputs: ${outputs.slice(0, 3).join(", ")}${outputs.length > 3 ? "..." : ""}`,
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
return parts.join(" | ");
|
||||
};
|
||||
|
||||
@@ -57,4 +61,4 @@ export const useGraphContent = ({
|
||||
handleKeyDown,
|
||||
getNodeInputOutputSummary,
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
@@ -18,4 +18,4 @@ export const useGraphMenuSearchBarComponent = ({
|
||||
inputRef,
|
||||
handleClear,
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
@@ -98,7 +98,7 @@ export type CustomNodeData = {
|
||||
errors?: { [key: string]: string };
|
||||
isOutputStatic?: boolean;
|
||||
uiType: BlockUIType;
|
||||
metadata?: { [key: string]: any };
|
||||
metadata?: { customized_name?: string; [key: string]: any };
|
||||
};
|
||||
|
||||
export type CustomNode = XYNode<CustomNodeData, "custom">;
|
||||
@@ -201,7 +201,7 @@ export const CustomNode = React.memo(
|
||||
|
||||
useEffect(() => {
|
||||
isInitialSetup.current = false;
|
||||
if (data.backend_id) return; // don't auto-modify existing nodes
|
||||
if (data.backend_id) return; // don't auto-modify existing nodes
|
||||
|
||||
if (data.uiType === BlockUIType.AGENT) {
|
||||
setHardcodedValues({
|
||||
@@ -822,7 +822,7 @@ export const CustomNode = React.memo(
|
||||
{isTitleHovered && !isEditingTitle && (
|
||||
<button
|
||||
onClick={handleTitleEdit}
|
||||
className="cursor-pointer rounded p-1 opacity-0 transition-opacity hover:bg-gray-100 group-hover:opacity-100"
|
||||
className="cursor-pointer rounded p-1 opacity-0 transition-opacity group-hover:opacity-100 hover:bg-gray-100"
|
||||
aria-label="Edit title"
|
||||
>
|
||||
<Pencil1Icon className="h-4 w-4" />
|
||||
|
||||
@@ -36,11 +36,7 @@ import {
|
||||
LibraryAgent,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import {
|
||||
getTypeColor,
|
||||
findNewlyAddedBlockCoordinates,
|
||||
beautifyString,
|
||||
} from "@/lib/utils";
|
||||
import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils";
|
||||
import { history } from "../history";
|
||||
import { CustomEdge } from "../CustomEdge/CustomEdge";
|
||||
import ConnectionLine from "../ConnectionLine";
|
||||
@@ -78,6 +74,7 @@ type BuilderContextType = {
|
||||
visualizeBeads: "no" | "static" | "animate";
|
||||
setIsAnyModalOpen: (isOpen: boolean) => void;
|
||||
getNextNodeId: () => string;
|
||||
getNodeTitle: (nodeID: string) => string | null;
|
||||
};
|
||||
|
||||
export type NodeDimension = {
|
||||
@@ -148,6 +145,9 @@ const FlowEditor: React.FC<{
|
||||
flowExecutionID,
|
||||
visualizeBeads !== "no",
|
||||
);
|
||||
const [immediateNodePositions, setImmediateNodePositions] = useState<
|
||||
Record<string, { x: number; y: number }>
|
||||
>(Object.fromEntries(nodes.map((node) => [node.id, node.position])));
|
||||
|
||||
const router = useRouter();
|
||||
const pathname = usePathname();
|
||||
@@ -241,6 +241,13 @@ const FlowEditor: React.FC<{
|
||||
const oldPosition = initialPositionRef.current[node.id];
|
||||
const newPosition = node.position;
|
||||
|
||||
// Clear immediate position, because on drag end it is no longer needed
|
||||
setImmediateNodePositions((prevPositions) => {
|
||||
const updatedPositions = { ...prevPositions };
|
||||
delete updatedPositions[node.id];
|
||||
return updatedPositions;
|
||||
});
|
||||
|
||||
// Calculate the movement distance
|
||||
if (!oldPosition || !newPosition) return;
|
||||
|
||||
@@ -278,6 +285,26 @@ const FlowEditor: React.FC<{
|
||||
|
||||
const onNodesChange = useCallback(
|
||||
(nodeChanges: NodeChange<CustomNode>[]) => {
|
||||
// Intercept position changes to update immediate positions & prevent excessive node re-renders
|
||||
const draggingPosChanges = nodeChanges
|
||||
.filter((c) => c.type === "position")
|
||||
.filter((c) => c.dragging === true);
|
||||
if (draggingPosChanges.length > 0) {
|
||||
setImmediateNodePositions((prevPositions) => {
|
||||
const newPositions = { ...prevPositions };
|
||||
draggingPosChanges.forEach((change) => {
|
||||
if (change.position) newPositions[change.id] = change.position;
|
||||
});
|
||||
return newPositions;
|
||||
});
|
||||
|
||||
// Don't further process ongoing position changes
|
||||
nodeChanges = nodeChanges.filter(
|
||||
(change) => change.type !== "position" || change.dragging !== true,
|
||||
);
|
||||
if (nodeChanges.length === 0) return;
|
||||
}
|
||||
|
||||
// Persist the changes
|
||||
setNodes((prev) => applyNodeChanges(nodeChanges, prev));
|
||||
|
||||
@@ -675,6 +702,21 @@ const FlowEditor: React.FC<{
|
||||
findNodeDimensions();
|
||||
}, [nodes, findNodeDimensions]);
|
||||
|
||||
const getNodeTitle = useCallback(
|
||||
(nodeID: string) => {
|
||||
const node = nodes.find((n) => n.data.backend_id === nodeID);
|
||||
if (!node) return null;
|
||||
|
||||
return (
|
||||
node.data.metadata?.customized_name ||
|
||||
(node.data.uiType == BlockUIType.AGENT &&
|
||||
node.data.hardcodedValues.agent_name) ||
|
||||
node.data.blockType.replace(/Block$/, "")
|
||||
);
|
||||
},
|
||||
[nodes],
|
||||
);
|
||||
|
||||
const handleCopyPaste = useCopyPaste(getNextNodeId);
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
@@ -783,7 +825,7 @@ const FlowEditor: React.FC<{
|
||||
data: {
|
||||
blockType: blockName,
|
||||
blockCosts: nodeSchema.costs || [],
|
||||
title: `${beautifyString(blockName)} ${nodeId}`,
|
||||
title: `${blockName} ${nodeId}`,
|
||||
description: nodeSchema.description,
|
||||
categories: nodeSchema.categories,
|
||||
inputSchema: nodeSchema.inputSchema,
|
||||
@@ -826,13 +868,29 @@ const FlowEditor: React.FC<{
|
||||
],
|
||||
);
|
||||
|
||||
const buildContextValue: BuilderContextType = useMemo(
|
||||
() => ({
|
||||
libraryAgent,
|
||||
visualizeBeads,
|
||||
setIsAnyModalOpen,
|
||||
getNextNodeId,
|
||||
getNodeTitle,
|
||||
}),
|
||||
[libraryAgent, visualizeBeads, getNextNodeId, getNodeTitle],
|
||||
);
|
||||
|
||||
return (
|
||||
<BuilderContext.Provider
|
||||
value={{ libraryAgent, visualizeBeads, setIsAnyModalOpen, getNextNodeId }}
|
||||
>
|
||||
<BuilderContext.Provider value={buildContextValue}>
|
||||
<div className={className}>
|
||||
<ReactFlow
|
||||
nodes={nodes}
|
||||
nodes={nodes.map((node) =>
|
||||
node.id in immediateNodePositions
|
||||
? {
|
||||
...node,
|
||||
position: immediateNodePositions[node.id] || node.position,
|
||||
}
|
||||
: node,
|
||||
)}
|
||||
edges={edges}
|
||||
nodeTypes={{ custom: CustomNode }}
|
||||
edgeTypes={{ custom: CustomEdge }}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import React, { useState } from "react";
|
||||
import { ContentRenderer } from "../../../../../components/__legacy__/ui/render";
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
import React, { useContext, useState } from "react";
|
||||
import { Button } from "@/components/__legacy__/ui/button";
|
||||
import { Maximize2 } from "lucide-react";
|
||||
import { Button } from "../../../../../components/__legacy__/ui/button";
|
||||
import * as Separator from "@radix-ui/react-separator";
|
||||
import { ContentRenderer } from "@/components/__legacy__/ui/render";
|
||||
|
||||
import { beautifyString } from "@/lib/utils";
|
||||
|
||||
import { BuilderContext } from "./Flow/Flow";
|
||||
import ExpandableOutputDialog from "./ExpandableOutputDialog";
|
||||
|
||||
type NodeOutputsProps = {
|
||||
@@ -17,6 +20,8 @@ export default function NodeOutputs({
|
||||
truncateLongData,
|
||||
data,
|
||||
}: NodeOutputsProps) {
|
||||
const builderContext = useContext(BuilderContext);
|
||||
|
||||
const [expandedDialog, setExpandedDialog] = useState<{
|
||||
isOpen: boolean;
|
||||
execId: string;
|
||||
@@ -24,6 +29,26 @@ export default function NodeOutputs({
|
||||
data: any[];
|
||||
} | null>(null);
|
||||
|
||||
if (!builderContext) {
|
||||
throw new Error(
|
||||
"BuilderContext consumer must be inside FlowEditor component",
|
||||
);
|
||||
}
|
||||
|
||||
const { getNodeTitle } = builderContext;
|
||||
|
||||
const getBeautifiedPinName = (pin: string) => {
|
||||
if (!pin.startsWith("tools_^_")) {
|
||||
return beautifyString(pin);
|
||||
}
|
||||
// Special handling for tool pins: replace node ID with node title
|
||||
const toolNodeID = pin.slice(8).split("_~_")[0]; // tools_^_{node_id}_~_{field}
|
||||
const toolNodeTitle = getNodeTitle(toolNodeID);
|
||||
return toolNodeTitle
|
||||
? beautifyString(pin.replace(toolNodeID, toolNodeTitle))
|
||||
: beautifyString(pin);
|
||||
};
|
||||
|
||||
const openExpandedView = (pinName: string, pinData: any[]) => {
|
||||
setExpandedDialog({
|
||||
isOpen: true,
|
||||
@@ -44,7 +69,7 @@ export default function NodeOutputs({
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center">
|
||||
<strong className="mr-2">Pin:</strong>
|
||||
<span>{beautifyString(pin)}</span>
|
||||
<span>{getBeautifiedPinName(pin)}</span>
|
||||
</div>
|
||||
{(truncateLongData || dataArray.length > 10) && (
|
||||
<Button
|
||||
|
||||
@@ -17,7 +17,6 @@ import {
|
||||
LinkCreatable,
|
||||
NodeCreatable,
|
||||
NodeExecutionResult,
|
||||
SpecialBlockID,
|
||||
Node,
|
||||
} from "@/lib/autogpt-server-api";
|
||||
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
|
||||
@@ -276,28 +275,6 @@ export default function useAgentGraph(
|
||||
const cleanupSourceName = (sourceName: string) =>
|
||||
isToolSourceName(sourceName) ? "tools" : sourceName;
|
||||
|
||||
const getToolFuncName = useCallback(
|
||||
(nodeID: string) => {
|
||||
const sinkNode = xyNodes.find((node) => node.id === nodeID);
|
||||
if (!sinkNode) return "";
|
||||
|
||||
const sinkNodeName =
|
||||
sinkNode.data.block_id === SpecialBlockID.AGENT
|
||||
? sinkNode.data.hardcodedValues?.agent_name ||
|
||||
availableFlows.find(
|
||||
(flow) => flow.id === sinkNode.data.hardcodedValues.graph_id,
|
||||
)?.name ||
|
||||
"agentexecutorblock"
|
||||
: sinkNode.data.title.split(" ")[0];
|
||||
|
||||
return sinkNodeName;
|
||||
},
|
||||
[xyNodes, availableFlows],
|
||||
);
|
||||
|
||||
const normalizeToolName = (str: string) =>
|
||||
str.replace(/[^a-zA-Z0-9_-]/g, "_").toLowerCase(); // This normalization rule has to match with the one on smart_decision_maker.py
|
||||
|
||||
/** ------------------------------ */
|
||||
|
||||
const updateEdgeBeads = useCallback(
|
||||
@@ -607,17 +584,10 @@ export default function useAgentGraph(
|
||||
|
||||
const prepareSaveableGraph = useCallback((): GraphCreatable => {
|
||||
const links = xyEdges.map((edge): LinkCreatable => {
|
||||
let sourceName = edge.sourceHandle || "";
|
||||
const sourceName = edge.sourceHandle || "";
|
||||
const sourceNode = xyNodes.find((node) => node.id === edge.source);
|
||||
const sinkNode = xyNodes.find((node) => node.id === edge.target);
|
||||
|
||||
// Special case for SmartDecisionMakerBlock
|
||||
if (
|
||||
sourceNode?.data.block_id === SpecialBlockID.SMART_DECISION &&
|
||||
sourceName.toLowerCase() === "tools"
|
||||
) {
|
||||
sourceName = `tools_^_${normalizeToolName(getToolFuncName(edge.target))}_~_${normalizeToolName(edge.targetHandle || "")}`;
|
||||
}
|
||||
return {
|
||||
source_id: sourceNode?.data.backend_id ?? edge.source,
|
||||
sink_id: sinkNode?.data.backend_id ?? edge.target,
|
||||
@@ -650,7 +620,6 @@ export default function useAgentGraph(
|
||||
agentDescription,
|
||||
agentRecommendedScheduleCron,
|
||||
prepareNodeInputData,
|
||||
getToolFuncName,
|
||||
]);
|
||||
|
||||
const resetEdgeBeads = useCallback(() => {
|
||||
|
||||
Reference in New Issue
Block a user