fix(blocks): Make Smart Decision Maker tool pin handling consistent and reliable (#11363)

- Resolves #11345

### Changes 🏗️

- Move tool use routing logic from frontend to backend: routing info was
being baked into graph links by the frontend, inconsistently, causing
issues
- Rework tool use routing to use target node ID instead of target block
name
- Add a bit of magic to `NodeOutputs` component to show tool node title
instead of ID

DX:
- Removed `build` from `.prettierignore` -> re-enable formatting for
builder components

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Use SDM block in a graph; verify it works
  - [x] Use SDM block with agent executor block as tool; verify it works
  - Tests for `parse_execution_output` pass (checked by CI)
This commit is contained in:
Reinier van der Leer
2025-11-12 18:55:38 +01:00
committed by GitHub
parent a3e5f7fce2
commit 536e2a5ec8
14 changed files with 276 additions and 139 deletions

View File

@@ -18,6 +18,7 @@ from backend.data.dynamic_fields import (
extract_base_field_name,
get_dynamic_field_description,
is_dynamic_field,
is_tool_pin,
)
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json
@@ -367,8 +368,9 @@ class SmartDecisionMakerBlock(Block):
"required": sorted(required_fields),
}
# Store field mapping for later use in output processing
# Store field mapping and node info for later use in output processing
tool_function["_field_mapping"] = field_mapping
tool_function["_sink_node_id"] = sink_node.id
return {"type": "function", "function": tool_function}
@@ -431,10 +433,13 @@ class SmartDecisionMakerBlock(Block):
"strict": True,
}
# Store node info for later use in output processing
tool_function["_sink_node_id"] = sink_node.id
return {"type": "function", "function": tool_function}
@staticmethod
async def _create_function_signature(
async def _create_tool_node_signatures(
node_id: str,
) -> list[dict[str, Any]]:
"""
@@ -450,7 +455,7 @@ class SmartDecisionMakerBlock(Block):
tools = [
(link, node)
for link, node in await db_client.get_connected_output_nodes(node_id)
if link.source_name.startswith("tools_^_") and link.source_id == node_id
if is_tool_pin(link.source_name) and link.source_id == node_id
]
if not tools:
raise ValueError("There is no next node to execute.")
@@ -538,8 +543,14 @@ class SmartDecisionMakerBlock(Block):
),
None,
)
if tool_def is None and len(tool_functions) == 1:
tool_def = tool_functions[0]
if tool_def is None:
if len(tool_functions) == 1:
tool_def = tool_functions[0]
else:
validation_errors_list.append(
f"Tool call for '{tool_name}' does not match any known "
"tool definition."
)
# Get parameters schema from tool definition
if (
@@ -591,7 +602,7 @@ class SmartDecisionMakerBlock(Block):
user_id: str,
**kwargs,
) -> BlockOutput:
tool_functions = await self._create_function_signature(node_id)
tool_functions = await self._create_tool_node_signatures(node_id)
yield "tool_functions", json.dumps(tool_functions)
input_data.conversation_history = input_data.conversation_history or []
@@ -661,9 +672,9 @@ class SmartDecisionMakerBlock(Block):
except ValueError as e:
last_error = e
error_feedback = (
"Your tool call had parameter errors. Please fix the following issues and try again:\n"
"Your tool call had errors. Please fix the following issues and try again:\n"
+ f"- {str(e)}\n"
+ "\nPlease make sure to use the exact parameter names as specified in the function schema."
+ "\nPlease make sure to use the exact tool and parameter names as specified in the function schema."
)
current_prompt = list(current_prompt) + [
{"role": "user", "content": error_feedback}
@@ -690,21 +701,23 @@ class SmartDecisionMakerBlock(Block):
),
None,
)
if (
tool_def
and "function" in tool_def
and "parameters" in tool_def["function"]
):
if not tool_def:
# NOTE: This matches the logic in _attempt_llm_call_with_validation and
# relies on its validation for the assumption that this is valid to use.
if len(tool_functions) == 1:
tool_def = tool_functions[0]
else:
# This should not happen due to prior validation
continue
if "function" in tool_def and "parameters" in tool_def["function"]:
expected_args = tool_def["function"]["parameters"].get("properties", {})
else:
expected_args = {arg: {} for arg in tool_args.keys()}
# Get field mapping from tool definition
field_mapping = (
tool_def.get("function", {}).get("_field_mapping", {})
if tool_def
else {}
)
# Get the sink node ID and field mapping from tool definition
field_mapping = tool_def["function"].get("_field_mapping", {})
sink_node_id = tool_def["function"]["_sink_node_id"]
for clean_arg_name in expected_args:
# arg_name is now always the cleaned field name (for Anthropic API compliance)
@@ -712,9 +725,8 @@ class SmartDecisionMakerBlock(Block):
original_field_name = field_mapping.get(clean_arg_name, clean_arg_name)
arg_value = tool_args.get(clean_arg_name)
sanitized_tool_name = self.cleanup(tool_name)
sanitized_arg_name = self.cleanup(original_field_name)
emit_key = f"tools_^_{sanitized_tool_name}_~_{sanitized_arg_name}"
emit_key = f"tools_^_{sink_node_id}_~_{sanitized_arg_name}"
logger.debug(
"[SmartDecisionMakerBlock|geid:%s|neid:%s] emit %s",

View File

@@ -165,7 +165,7 @@ async def test_smart_decision_maker_function_signature(server: SpinTestServer):
)
test_graph = await create_graph(server, test_graph, test_user)
tool_functions = await SmartDecisionMakerBlock._create_function_signature(
tool_functions = await SmartDecisionMakerBlock._create_tool_node_signatures(
test_graph.nodes[0].id
)
assert tool_functions is not None, "Tool functions should not be None"
@@ -215,7 +215,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
"content": "I need to think about this.",
}
# Mock the _create_function_signature method to avoid database calls
# Mock the _create_tool_node_signatures method to avoid database calls
from unittest.mock import AsyncMock
with patch(
@@ -224,7 +224,7 @@ async def test_smart_decision_maker_tracks_llm_stats():
return_value=mock_response,
), patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
):
@@ -293,6 +293,7 @@ async def test_smart_decision_maker_parameter_validation():
},
"required": ["query", "max_keyword_difficulty"],
},
"_sink_node_id": "test-sink-node-id",
},
}
]
@@ -318,7 +319,7 @@ async def test_smart_decision_maker_parameter_validation():
return_value=mock_response_with_typo,
) as mock_llm_call, patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
):
@@ -375,7 +376,7 @@ async def test_smart_decision_maker_parameter_validation():
return_value=mock_response_missing_required,
), patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
):
@@ -425,7 +426,7 @@ async def test_smart_decision_maker_parameter_validation():
return_value=mock_response_valid,
), patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
):
@@ -450,13 +451,13 @@ async def test_smart_decision_maker_parameter_validation():
outputs[output_name] = output_data
# Verify tool outputs were generated correctly
assert "tools_^_search_keywords_~_query" in outputs
assert outputs["tools_^_search_keywords_~_query"] == "test"
assert "tools_^_search_keywords_~_max_keyword_difficulty" in outputs
assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
assert "tools_^_test-sink-node-id_~_query" in outputs
assert outputs["tools_^_test-sink-node-id_~_query"] == "test"
assert "tools_^_test-sink-node-id_~_max_keyword_difficulty" in outputs
assert outputs["tools_^_test-sink-node-id_~_max_keyword_difficulty"] == 50
# Optional parameter should be None when not provided
assert "tools_^_search_keywords_~_optional_param" in outputs
assert outputs["tools_^_search_keywords_~_optional_param"] is None
assert "tools_^_test-sink-node-id_~_optional_param" in outputs
assert outputs["tools_^_test-sink-node-id_~_optional_param"] is None
# Test case 4: Valid tool call with ALL parameters (should succeed)
mock_tool_call_all_params = MagicMock()
@@ -479,7 +480,7 @@ async def test_smart_decision_maker_parameter_validation():
return_value=mock_response_all_params,
), patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
):
@@ -504,9 +505,9 @@ async def test_smart_decision_maker_parameter_validation():
outputs[output_name] = output_data
# Verify all tool outputs were generated correctly
assert outputs["tools_^_search_keywords_~_query"] == "test"
assert outputs["tools_^_search_keywords_~_max_keyword_difficulty"] == 50
assert outputs["tools_^_search_keywords_~_optional_param"] == "custom_value"
assert outputs["tools_^_test-sink-node-id_~_query"] == "test"
assert outputs["tools_^_test-sink-node-id_~_max_keyword_difficulty"] == 50
assert outputs["tools_^_test-sink-node-id_~_optional_param"] == "custom_value"
@pytest.mark.asyncio
@@ -530,6 +531,7 @@ async def test_smart_decision_maker_raw_response_conversion():
"properties": {"param": {"type": "string"}},
"required": ["param"],
},
"_sink_node_id": "test-sink-node-id",
},
}
]
@@ -588,7 +590,7 @@ async def test_smart_decision_maker_raw_response_conversion():
"backend.blocks.llm.llm_call", new_callable=AsyncMock
) as mock_llm_call, patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=mock_tool_functions,
):
@@ -617,8 +619,8 @@ async def test_smart_decision_maker_raw_response_conversion():
outputs[output_name] = output_data
# Verify the tool output was generated successfully
assert "tools_^_test_tool_~_param" in outputs
assert outputs["tools_^_test_tool_~_param"] == "test_value"
assert "tools_^_test-sink-node-id_~_param" in outputs
assert outputs["tools_^_test-sink-node-id_~_param"] == "test_value"
# Verify conversation history was properly maintained
assert "conversations" in outputs
@@ -656,7 +658,7 @@ async def test_smart_decision_maker_raw_response_conversion():
return_value=mock_response_ollama,
), patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[], # No tools for this test
):
@@ -702,7 +704,7 @@ async def test_smart_decision_maker_raw_response_conversion():
return_value=mock_response_dict,
), patch.object(
SmartDecisionMakerBlock,
"_create_function_signature",
"_create_tool_node_signatures",
new_callable=AsyncMock,
return_value=[],
):

View File

@@ -192,7 +192,7 @@ async def test_create_block_function_signature_with_object_fields():
@pytest.mark.asyncio
async def test_create_function_signature():
async def test_create_tool_node_signatures():
"""Test that the mapping between sanitized and original field names is built correctly."""
block = SmartDecisionMakerBlock()
@@ -241,7 +241,7 @@ async def test_create_function_signature():
]
# Call the method that builds signatures
tool_functions = await block._create_function_signature("test_node_id")
tool_functions = await block._create_tool_node_signatures("test_node_id")
# Verify we got 2 tool functions (one for dict, one for list)
assert len(tool_functions) == 2
@@ -310,7 +310,7 @@ async def test_output_yielding_with_dynamic_fields():
# Mock the function signature creation
with patch.object(
block, "_create_function_signature", new_callable=AsyncMock
block, "_create_tool_node_signatures", new_callable=AsyncMock
) as mock_sig:
mock_sig.return_value = [
{
@@ -325,6 +325,7 @@ async def test_output_yielding_with_dynamic_fields():
"values___email": {"type": "string"},
},
},
"_sink_node_id": "test-sink-node-id",
},
}
]
@@ -351,16 +352,16 @@ async def test_output_yielding_with_dynamic_fields():
):
outputs[output_name] = output_value
# Verify the outputs use sanitized field names (matching frontend normalizeToolName)
assert "tools_^_createdictionaryblock_~_values___name" in outputs
assert outputs["tools_^_createdictionaryblock_~_values___name"] == "Alice"
# Verify the outputs use sink node ID in output keys
assert "tools_^_test-sink-node-id_~_values___name" in outputs
assert outputs["tools_^_test-sink-node-id_~_values___name"] == "Alice"
assert "tools_^_createdictionaryblock_~_values___age" in outputs
assert outputs["tools_^_createdictionaryblock_~_values___age"] == 30
assert "tools_^_test-sink-node-id_~_values___age" in outputs
assert outputs["tools_^_test-sink-node-id_~_values___age"] == 30
assert "tools_^_createdictionaryblock_~_values___email" in outputs
assert "tools_^_test-sink-node-id_~_values___email" in outputs
assert (
outputs["tools_^_createdictionaryblock_~_values___email"]
outputs["tools_^_test-sink-node-id_~_values___email"]
== "alice@example.com"
)
@@ -488,7 +489,7 @@ async def test_validation_errors_dont_pollute_conversation():
# Mock the function signature creation
with patch.object(
block, "_create_function_signature", new_callable=AsyncMock
block, "_create_tool_node_signatures", new_callable=AsyncMock
) as mock_sig:
mock_sig.return_value = [
{
@@ -505,6 +506,7 @@ async def test_validation_errors_dont_pollute_conversation():
},
"required": ["correct_param"],
},
"_sink_node_id": "test-sink-node-id",
},
}
]

View File

@@ -92,6 +92,18 @@ def get_dynamic_field_description(field_name: str) -> str:
return f"Value for {field_name}"
def is_tool_pin(name: str) -> bool:
"""Check if a pin name represents a tool connection."""
return name.startswith("tools_^_") or name == "tools"
def sanitize_pin_name(name: str) -> str:
sanitized_name = extract_base_field_name(name)
if is_tool_pin(sanitized_name):
return "tools"
return sanitized_name
# --------------------------------------------------------------------------- #
# Dynamic field parsing and merging utilities
# --------------------------------------------------------------------------- #
@@ -137,30 +149,64 @@ def _tokenise(path: str) -> list[tuple[str, str]] | None:
return tokens
def parse_execution_output(output: tuple[str, Any], name: str) -> Any:
def parse_execution_output(
output_item: tuple[str, Any],
link_output_selector: str,
sink_node_id: str | None = None,
sink_pin_name: str | None = None,
) -> Any:
"""
Retrieve a nested value out of `output` using the flattened *name*.
Retrieve a nested value out of `output` using the flattened `link_output_selector`.
On any failure (wrong name, wrong type, out-of-range, bad path)
returns **None**.
On any failure (wrong name, wrong type, out-of-range, bad path) returns **None**.
### Special Case: Tool pins
For regular output pins, the `output_item`'s name will simply be the field name, and
`link_output_selector` (= the `source_name` of the link) may provide a "selector"
used to extract part of the output value and route it through the link
to the next node.
However, for tool pins, it is the other way around: the `output_item`'s name
provides the routing information (`tools_^_{sink_node_id}_~_{field_name}`),
and the `link_output_selector` is simply `"tools"`
(or `"tools_^_{tool_name}_~_{field_name}"` for backward compatibility).
Args:
output: Tuple of (base_name, data) representing a block output entry
name: The flattened field name to extract from the output data
output_item: Tuple of (base_name, data) representing a block output entry.
link_output_selector: The flattened field name to extract from the output data.
sink_node_id: Sink node ID, used for tool use routing.
sink_pin_name: Sink pin name, used for tool use routing.
Returns:
The value at the specified path, or None if not found/invalid
The value at the specified path, or `None` if not found/invalid.
"""
base_name, data = output
output_pin_name, data = output_item
# Special handling for tool pins
if is_tool_pin(link_output_selector) and ( # "tools" or "tools_^_…"
output_pin_name.startswith("tools_^_") and "_~_" in output_pin_name
):
if not (sink_node_id and sink_pin_name):
raise ValueError(
"sink_node_id and sink_pin_name must be provided for tool pin routing"
)
# Extract routing information from emit key: tools_^_{node_id}_~_{field}
selector = output_pin_name[8:] # Remove "tools_^_" prefix
target_node_id, target_input_pin = selector.split("_~_", 1)
if target_node_id == sink_node_id and target_input_pin == sink_pin_name:
return data
else:
return None
# Exact match → whole object
if name == base_name:
if link_output_selector == output_pin_name:
return data
# Must start with the expected name
if not name.startswith(base_name):
if not link_output_selector.startswith(output_pin_name):
return None
path = name[len(base_name) :]
path = link_output_selector[len(output_pin_name) :]
if not path:
return None # nothing left to parse

View File

@@ -26,7 +26,7 @@ from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import LlmModel
from backend.data.db import prisma as db
from backend.data.dynamic_fields import extract_base_field_name
from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name
from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH
from backend.data.model import (
CredentialsField,
@@ -578,9 +578,9 @@ class GraphModel(Graph):
nodes_input_masks.get(node.id, {}) if nodes_input_masks else {}
)
provided_inputs = set(
[_sanitize_pin_name(name) for name in node.input_default]
[sanitize_pin_name(name) for name in node.input_default]
+ [
_sanitize_pin_name(link.sink_name)
sanitize_pin_name(link.sink_name)
for link in input_links.get(node.id, [])
]
+ ([name for name in node_input_mask] if node_input_mask else [])
@@ -696,7 +696,7 @@ class GraphModel(Graph):
f"{prefix}, {node.block_id} is invalid block id, available blocks: {blocks}"
)
sanitized_name = _sanitize_pin_name(name)
sanitized_name = sanitize_pin_name(name)
vals = node.input_default
if i == 0:
fields = (
@@ -710,7 +710,7 @@ class GraphModel(Graph):
if block.block_type not in [BlockType.AGENT]
else vals.get("input_schema", {}).get("properties", {}).keys()
)
if sanitized_name not in fields and not _is_tool_pin(name):
if sanitized_name not in fields and not is_tool_pin(name):
fields_msg = f"Allowed fields: {fields}"
raise ValueError(f"{prefix}, `{name}` invalid, {fields_msg}")
@@ -750,17 +750,6 @@ class GraphModel(Graph):
)
def _is_tool_pin(name: str) -> bool:
return name.startswith("tools_^_")
def _sanitize_pin_name(name: str) -> str:
sanitized_name = extract_base_field_name(name)
if _is_tool_pin(sanitized_name):
return "tools"
return sanitized_name
class GraphMeta(Graph):
user_id: str

View File

@@ -322,7 +322,9 @@ async def _enqueue_next_nodes(
next_node_id = node_link.sink_id
output_name, _ = output
next_data = parse_execution_output(output, next_output_name)
next_data = parse_execution_output(
output, next_output_name, next_node_id, next_input_name
)
if next_data is None and output_name != next_output_name:
return enqueued_executions
next_node = await db_client.get_node(next_node_id)

View File

@@ -111,6 +111,35 @@ def test_parse_execution_output():
parse_execution_output(output, "result_@_attr_$_0_#_key") is None
) # Should fail at @_attr
# Test case 7: Tool pin routing with matching node ID and pin name
output = ("tools_^_node123_~_query", "search term")
assert parse_execution_output(output, "tools", "node123", "query") == "search term"
# Test case 8: Tool pin routing with node ID mismatch
output = ("tools_^_node123_~_query", "search term")
assert parse_execution_output(output, "tools", "node456", "query") is None
# Test case 9: Tool pin routing with pin name mismatch
output = ("tools_^_node123_~_query", "search term")
assert parse_execution_output(output, "tools", "node123", "different_pin") is None
# Test case 10: Tool pin routing with complex field names
output = ("tools_^_node789_~_nested_field", {"key": "value"})
result = parse_execution_output(output, "tools", "node789", "nested_field")
assert result == {"key": "value"}
# Test case 11: Tool pin routing missing required parameters should raise error
output = ("tools_^_node123_~_query", "search term")
try:
parse_execution_output(output, "tools", "node123") # Missing sink_pin_name
assert False, "Should have raised ValueError"
except ValueError as e:
assert "must be provided for tool pin routing" in str(e)
# Test case 12: Non-tool pin with similar pattern should use normal logic
output = ("tools_^_node123_~_query", "search term")
assert parse_execution_output(output, "different_name", "node123", "query") is None
def test_merge_execution_input():
# Test case for basic list extraction

View File

@@ -2,7 +2,6 @@ node_modules
pnpm-lock.yaml
.next
.auth
build
public
Dockerfile
.prettierignore

View File

@@ -36,18 +36,22 @@ export const useGraphContent = ({
if (!node || !node.data) {
return "";
}
const inputs = Object.keys(node.data?.inputSchema?.properties || {});
const outputs = Object.keys(node.data?.outputSchema?.properties || {});
const parts = [];
if (inputs.length > 0) {
parts.push(`Inputs: ${inputs.slice(0, 3).join(", ")}${inputs.length > 3 ? "..." : ""}`);
parts.push(
`Inputs: ${inputs.slice(0, 3).join(", ")}${inputs.length > 3 ? "..." : ""}`,
);
}
if (outputs.length > 0) {
parts.push(`Outputs: ${outputs.slice(0, 3).join(", ")}${outputs.length > 3 ? "..." : ""}`);
parts.push(
`Outputs: ${outputs.slice(0, 3).join(", ")}${outputs.length > 3 ? "..." : ""}`,
);
}
return parts.join(" | ");
};
@@ -57,4 +61,4 @@ export const useGraphContent = ({
handleKeyDown,
getNodeInputOutputSummary,
};
};
};

View File

@@ -18,4 +18,4 @@ export const useGraphMenuSearchBarComponent = ({
inputRef,
handleClear,
};
};
};

View File

@@ -98,7 +98,7 @@ export type CustomNodeData = {
errors?: { [key: string]: string };
isOutputStatic?: boolean;
uiType: BlockUIType;
metadata?: { [key: string]: any };
metadata?: { customized_name?: string; [key: string]: any };
};
export type CustomNode = XYNode<CustomNodeData, "custom">;
@@ -201,7 +201,7 @@ export const CustomNode = React.memo(
useEffect(() => {
isInitialSetup.current = false;
if (data.backend_id) return; // don't auto-modify existing nodes
if (data.backend_id) return; // don't auto-modify existing nodes
if (data.uiType === BlockUIType.AGENT) {
setHardcodedValues({
@@ -822,7 +822,7 @@ export const CustomNode = React.memo(
{isTitleHovered && !isEditingTitle && (
<button
onClick={handleTitleEdit}
className="cursor-pointer rounded p-1 opacity-0 transition-opacity hover:bg-gray-100 group-hover:opacity-100"
className="cursor-pointer rounded p-1 opacity-0 transition-opacity group-hover:opacity-100 hover:bg-gray-100"
aria-label="Edit title"
>
<Pencil1Icon className="h-4 w-4" />

View File

@@ -36,11 +36,7 @@ import {
LibraryAgent,
} from "@/lib/autogpt-server-api";
import { Key, storage } from "@/services/storage/local-storage";
import {
getTypeColor,
findNewlyAddedBlockCoordinates,
beautifyString,
} from "@/lib/utils";
import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils";
import { history } from "../history";
import { CustomEdge } from "../CustomEdge/CustomEdge";
import ConnectionLine from "../ConnectionLine";
@@ -78,6 +74,7 @@ type BuilderContextType = {
visualizeBeads: "no" | "static" | "animate";
setIsAnyModalOpen: (isOpen: boolean) => void;
getNextNodeId: () => string;
getNodeTitle: (nodeID: string) => string | null;
};
export type NodeDimension = {
@@ -148,6 +145,9 @@ const FlowEditor: React.FC<{
flowExecutionID,
visualizeBeads !== "no",
);
const [immediateNodePositions, setImmediateNodePositions] = useState<
Record<string, { x: number; y: number }>
>(Object.fromEntries(nodes.map((node) => [node.id, node.position])));
const router = useRouter();
const pathname = usePathname();
@@ -241,6 +241,13 @@ const FlowEditor: React.FC<{
const oldPosition = initialPositionRef.current[node.id];
const newPosition = node.position;
// Clear immediate position, because on drag end it is no longer needed
setImmediateNodePositions((prevPositions) => {
const updatedPositions = { ...prevPositions };
delete updatedPositions[node.id];
return updatedPositions;
});
// Calculate the movement distance
if (!oldPosition || !newPosition) return;
@@ -278,6 +285,26 @@ const FlowEditor: React.FC<{
const onNodesChange = useCallback(
(nodeChanges: NodeChange<CustomNode>[]) => {
// Intercept position changes to update immediate positions & prevent excessive node re-renders
const draggingPosChanges = nodeChanges
.filter((c) => c.type === "position")
.filter((c) => c.dragging === true);
if (draggingPosChanges.length > 0) {
setImmediateNodePositions((prevPositions) => {
const newPositions = { ...prevPositions };
draggingPosChanges.forEach((change) => {
if (change.position) newPositions[change.id] = change.position;
});
return newPositions;
});
// Don't further process ongoing position changes
nodeChanges = nodeChanges.filter(
(change) => change.type !== "position" || change.dragging !== true,
);
if (nodeChanges.length === 0) return;
}
// Persist the changes
setNodes((prev) => applyNodeChanges(nodeChanges, prev));
@@ -675,6 +702,21 @@ const FlowEditor: React.FC<{
findNodeDimensions();
}, [nodes, findNodeDimensions]);
const getNodeTitle = useCallback(
(nodeID: string) => {
const node = nodes.find((n) => n.data.backend_id === nodeID);
if (!node) return null;
return (
node.data.metadata?.customized_name ||
(node.data.uiType == BlockUIType.AGENT &&
node.data.hardcodedValues.agent_name) ||
node.data.blockType.replace(/Block$/, "")
);
},
[nodes],
);
const handleCopyPaste = useCopyPaste(getNextNodeId);
const handleKeyDown = useCallback(
@@ -783,7 +825,7 @@ const FlowEditor: React.FC<{
data: {
blockType: blockName,
blockCosts: nodeSchema.costs || [],
title: `${beautifyString(blockName)} ${nodeId}`,
title: `${blockName} ${nodeId}`,
description: nodeSchema.description,
categories: nodeSchema.categories,
inputSchema: nodeSchema.inputSchema,
@@ -826,13 +868,29 @@ const FlowEditor: React.FC<{
],
);
const buildContextValue: BuilderContextType = useMemo(
() => ({
libraryAgent,
visualizeBeads,
setIsAnyModalOpen,
getNextNodeId,
getNodeTitle,
}),
[libraryAgent, visualizeBeads, getNextNodeId, getNodeTitle],
);
return (
<BuilderContext.Provider
value={{ libraryAgent, visualizeBeads, setIsAnyModalOpen, getNextNodeId }}
>
<BuilderContext.Provider value={buildContextValue}>
<div className={className}>
<ReactFlow
nodes={nodes}
nodes={nodes.map((node) =>
node.id in immediateNodePositions
? {
...node,
position: immediateNodePositions[node.id] || node.position,
}
: node,
)}
edges={edges}
nodeTypes={{ custom: CustomNode }}
edgeTypes={{ custom: CustomEdge }}

View File

@@ -1,9 +1,12 @@
import React, { useState } from "react";
import { ContentRenderer } from "../../../../../components/__legacy__/ui/render";
import { beautifyString } from "@/lib/utils";
import React, { useContext, useState } from "react";
import { Button } from "@/components/__legacy__/ui/button";
import { Maximize2 } from "lucide-react";
import { Button } from "../../../../../components/__legacy__/ui/button";
import * as Separator from "@radix-ui/react-separator";
import { ContentRenderer } from "@/components/__legacy__/ui/render";
import { beautifyString } from "@/lib/utils";
import { BuilderContext } from "./Flow/Flow";
import ExpandableOutputDialog from "./ExpandableOutputDialog";
type NodeOutputsProps = {
@@ -17,6 +20,8 @@ export default function NodeOutputs({
truncateLongData,
data,
}: NodeOutputsProps) {
const builderContext = useContext(BuilderContext);
const [expandedDialog, setExpandedDialog] = useState<{
isOpen: boolean;
execId: string;
@@ -24,6 +29,26 @@ export default function NodeOutputs({
data: any[];
} | null>(null);
if (!builderContext) {
throw new Error(
"BuilderContext consumer must be inside FlowEditor component",
);
}
const { getNodeTitle } = builderContext;
const getBeautifiedPinName = (pin: string) => {
if (!pin.startsWith("tools_^_")) {
return beautifyString(pin);
}
// Special handling for tool pins: replace node ID with node title
const toolNodeID = pin.slice(8).split("_~_")[0]; // tools_^_{node_id}_~_{field}
const toolNodeTitle = getNodeTitle(toolNodeID);
return toolNodeTitle
? beautifyString(pin.replace(toolNodeID, toolNodeTitle))
: beautifyString(pin);
};
const openExpandedView = (pinName: string, pinData: any[]) => {
setExpandedDialog({
isOpen: true,
@@ -44,7 +69,7 @@ export default function NodeOutputs({
<div className="flex items-center justify-between">
<div className="flex items-center">
<strong className="mr-2">Pin:</strong>
<span>{beautifyString(pin)}</span>
<span>{getBeautifiedPinName(pin)}</span>
</div>
{(truncateLongData || dataArray.length > 10) && (
<Button

View File

@@ -17,7 +17,6 @@ import {
LinkCreatable,
NodeCreatable,
NodeExecutionResult,
SpecialBlockID,
Node,
} from "@/lib/autogpt-server-api";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
@@ -276,28 +275,6 @@ export default function useAgentGraph(
const cleanupSourceName = (sourceName: string) =>
isToolSourceName(sourceName) ? "tools" : sourceName;
const getToolFuncName = useCallback(
(nodeID: string) => {
const sinkNode = xyNodes.find((node) => node.id === nodeID);
if (!sinkNode) return "";
const sinkNodeName =
sinkNode.data.block_id === SpecialBlockID.AGENT
? sinkNode.data.hardcodedValues?.agent_name ||
availableFlows.find(
(flow) => flow.id === sinkNode.data.hardcodedValues.graph_id,
)?.name ||
"agentexecutorblock"
: sinkNode.data.title.split(" ")[0];
return sinkNodeName;
},
[xyNodes, availableFlows],
);
const normalizeToolName = (str: string) =>
str.replace(/[^a-zA-Z0-9_-]/g, "_").toLowerCase(); // This normalization rule has to match with the one on smart_decision_maker.py
/** ------------------------------ */
const updateEdgeBeads = useCallback(
@@ -607,17 +584,10 @@ export default function useAgentGraph(
const prepareSaveableGraph = useCallback((): GraphCreatable => {
const links = xyEdges.map((edge): LinkCreatable => {
let sourceName = edge.sourceHandle || "";
const sourceName = edge.sourceHandle || "";
const sourceNode = xyNodes.find((node) => node.id === edge.source);
const sinkNode = xyNodes.find((node) => node.id === edge.target);
// Special case for SmartDecisionMakerBlock
if (
sourceNode?.data.block_id === SpecialBlockID.SMART_DECISION &&
sourceName.toLowerCase() === "tools"
) {
sourceName = `tools_^_${normalizeToolName(getToolFuncName(edge.target))}_~_${normalizeToolName(edge.targetHandle || "")}`;
}
return {
source_id: sourceNode?.data.backend_id ?? edge.source,
sink_id: sinkNode?.data.backend_id ?? edge.target,
@@ -650,7 +620,6 @@ export default function useAgentGraph(
agentDescription,
agentRecommendedScheduleCron,
prepareNodeInputData,
getToolFuncName,
]);
const resetEdgeBeads = useCallback(() => {