mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
36 Commits
dev
...
fix/orches
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f796390994 | ||
|
|
ec9e45df98 | ||
|
|
ebd476471e | ||
|
|
41578af2da | ||
|
|
d27b97cf74 | ||
|
|
0f7a47ccfa | ||
|
|
db691aff62 | ||
|
|
643b8a04b9 | ||
|
|
ebf9491dcb | ||
|
|
9ab49f7a4a | ||
|
|
091c5ebb67 | ||
|
|
f77cde318f | ||
|
|
2e69220d55 | ||
|
|
97b5fb89f6 | ||
|
|
7f435c3b02 | ||
|
|
82c1b06940 | ||
|
|
f67e05ef98 | ||
|
|
8cddf26ed6 | ||
|
|
3f5c2b93cd | ||
|
|
7e62fdae48 | ||
|
|
a866e8d709 | ||
|
|
df2cd316f8 | ||
|
|
f1b1c19612 | ||
|
|
e5d42fcb99 | ||
|
|
8b289cacdd | ||
|
|
1eeac09801 | ||
|
|
afc02697af | ||
|
|
dd0cca48b4 | ||
|
|
60d4dd8ff2 | ||
|
|
8548bfcc4e | ||
|
|
7e19a1aa68 | ||
|
|
5723c1e230 | ||
|
|
1968ecf355 | ||
|
|
f4d6bc1f5b | ||
|
|
823eb3d15a | ||
|
|
270c2f0f55 |
@@ -28,6 +28,7 @@ from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
|
||||
from backend.util.security import filter_sensitive_fields
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.graph import Link, Node
|
||||
@@ -35,6 +36,9 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Anthropic imposes a 64-character limit on tool/function names.
|
||||
MAX_TOOL_NAME_LENGTH = 64
|
||||
|
||||
|
||||
class ToolInfo(BaseModel):
|
||||
"""Processed tool call information."""
|
||||
@@ -258,6 +262,71 @@ def get_pending_tool_calls(conversation_history: list[Any] | None) -> dict[str,
|
||||
return {call_id: count for call_id, count in pending_calls.items() if count > 0}
|
||||
|
||||
|
||||
def _disambiguate_tool_names(tools: list[dict[str, Any]]) -> None:
    """Make every tool name unique, as the Anthropic API requires.

    When several nodes use the same block type they produce identical tool
    names; clashing names get a ``_1``, ``_2``, ... suffix (truncated so the
    result stays within ``MAX_TOOL_NAME_LENGTH``) and their descriptions are
    enriched with the node's hardcoded defaults so the LLM can tell them
    apart. The list is mutated in place.

    Entries missing ``function`` or ``function.name`` are silently skipped so
    the caller never crashes on unexpected input.
    """
    # Keep only well-formed entries; malformed ones are skipped (but still
    # stripped of internal metadata when possible).
    well_formed: list[dict[str, Any]] = []
    for entry in tools:
        spec = entry.get("function") if isinstance(entry, dict) else None
        if isinstance(spec, dict) and isinstance(spec.get("name"), str):
            well_formed.append(entry)
        elif isinstance(spec, dict):
            # Strip internal metadata even from malformed entries.
            spec.pop("_hardcoded_defaults", None)

    all_names = [e.get("function", {}).get("name", "") for e in well_formed]
    clashing = {n for n, c in Counter(all_names).items() if c > 1}

    # Fast path: nothing clashes, so only the internal metadata needs removal.
    if not clashing:
        for entry in well_formed:
            entry.get("function", {}).pop("_hardcoded_defaults", None)
        return

    used: set[str] = set(all_names)
    next_index: dict[str, int] = {}

    for entry in well_formed:
        spec = entry.get("function", {})
        base = spec.get("name", "")
        preset = spec.pop("_hardcoded_defaults", {})

        if base not in clashing:
            continue

        # Find the next free suffix, skipping any candidate that collides
        # with an existing (e.g. user-named) tool.
        idx = next_index.get(base, 0) + 1
        while True:
            tail = f"_{idx}"
            renamed = f"{base[: MAX_TOOL_NAME_LENGTH - len(tail)]}{tail}"
            if renamed not in used:
                break
            idx += 1
        next_index[base] = idx

        spec["name"] = renamed
        used.add(renamed)

        # Surface the node's pre-configured (non-linked) inputs in the
        # description so the LLM can distinguish same-typed tools.
        if preset and isinstance(preset, dict):
            fragments: list[str] = []
            for key, value in preset.items():
                rendered = json.dumps(value)
                if len(rendered) > 100:
                    rendered = rendered[:80] + "...<truncated>"
                fragments.append(f"{key}={rendered}")
            summary = ", ".join(fragments)
            original_desc = spec.get("description", "") or ""
            spec["description"] = f"{original_desc} [Pre-configured: {summary}]"
|
||||
|
||||
|
||||
class OrchestratorBlock(Block):
|
||||
"""
|
||||
A block that uses a language model to orchestrate tool calls, supporting both
|
||||
@@ -507,6 +576,19 @@ class OrchestratorBlock(Block):
|
||||
tool_function["_field_mapping"] = field_mapping
|
||||
tool_function["_sink_node_id"] = sink_node.id
|
||||
|
||||
# Store hardcoded defaults (non-linked inputs) for disambiguation.
|
||||
# Exclude linked fields, private fields, and credential/auth fields
|
||||
# to avoid leaking sensitive data into tool descriptions.
|
||||
defaults = sink_node.input_default
|
||||
tool_function["_hardcoded_defaults"] = (
|
||||
filter_sensitive_fields(
|
||||
defaults,
|
||||
linked_fields={link.sink_name for link in links},
|
||||
)
|
||||
if isinstance(defaults, dict)
|
||||
else {}
|
||||
)
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
@@ -581,6 +663,21 @@ class OrchestratorBlock(Block):
|
||||
tool_function["_field_mapping"] = field_mapping
|
||||
tool_function["_sink_node_id"] = sink_node.id
|
||||
|
||||
# Store hardcoded defaults (non-linked inputs) for disambiguation.
|
||||
# Exclude linked fields, private fields, agent meta fields, and
|
||||
# credential/auth fields to avoid leaking sensitive data.
|
||||
_AGENT_META_FIELDS = frozenset({"graph_id", "graph_version", "input_schema"})
|
||||
defaults = sink_node.input_default
|
||||
tool_function["_hardcoded_defaults"] = (
|
||||
filter_sensitive_fields(
|
||||
defaults,
|
||||
linked_fields={link.sink_name for link in links},
|
||||
extra_excludes=_AGENT_META_FIELDS,
|
||||
)
|
||||
if isinstance(defaults, dict)
|
||||
else {}
|
||||
)
|
||||
|
||||
return {"type": "function", "function": tool_function}
|
||||
|
||||
@staticmethod
|
||||
@@ -629,6 +726,7 @@ class OrchestratorBlock(Block):
|
||||
)
|
||||
return_tool_functions.append(tool_func)
|
||||
|
||||
_disambiguate_tool_names(return_tool_functions)
|
||||
return return_tool_functions
|
||||
|
||||
async def _attempt_llm_call_with_validation(
|
||||
@@ -996,7 +1094,10 @@ class OrchestratorBlock(Block):
|
||||
credentials, input_data, iteration_prompt, tool_functions
|
||||
)
|
||||
except Exception as e:
|
||||
yield "error", f"LLM call failed in agent mode iteration {iteration}: {str(e)}"
|
||||
yield (
|
||||
"error",
|
||||
f"LLM call failed in agent mode iteration {iteration}: {str(e)}",
|
||||
)
|
||||
return
|
||||
|
||||
# Process tool calls
|
||||
@@ -1041,7 +1142,10 @@ class OrchestratorBlock(Block):
|
||||
if max_iterations < 0:
|
||||
yield "finished", f"Agent mode completed after {iteration} iterations"
|
||||
else:
|
||||
yield "finished", f"Agent mode completed after {max_iterations} iterations (limit reached)"
|
||||
yield (
|
||||
"finished",
|
||||
f"Agent mode completed after {max_iterations} iterations (limit reached)",
|
||||
)
|
||||
yield "conversations", current_prompt
|
||||
|
||||
async def run(
|
||||
|
||||
@@ -1074,6 +1074,7 @@ async def test_orchestrator_uses_customized_name_for_blocks():
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {"customized_name": "My Custom Tool Name"}
|
||||
mock_node.block = StoreValueBlock()
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
@@ -1105,6 +1106,7 @@ async def test_orchestrator_falls_back_to_block_name():
|
||||
mock_node.block_id = StoreValueBlock().id
|
||||
mock_node.metadata = {} # No customized_name
|
||||
mock_node.block = StoreValueBlock()
|
||||
mock_node.input_default = {}
|
||||
|
||||
# Create a mock link
|
||||
mock_link = MagicMock(spec=Link)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
80
autogpt_platform/backend/backend/util/security.py
Normal file
80
autogpt_platform/backend/backend/util/security.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Shared security constants for field-level filtering.
|
||||
|
||||
Other modules (e.g. orchestrator, future blocks) import from here so the
|
||||
sensitive-field list stays in one place.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
# Substrings matched case-insensitively against field names: a field is
# treated as sensitive when any of these occurs anywhere inside the
# lowercased name (substring containment, not exact equality).
SENSITIVE_FIELD_NAMES: frozenset[str] = frozenset(
    {
        # credential containers
        "credentials",
        "api_key",
        # passwords & passphrases
        "password",
        "passphrase",
        # secrets
        "secret",
        "secret_key",
        "client_secret",
        "private_key",
        "webhook_secret",
        # tokens & auth
        "token",
        "access_token",
        "refresh_token",
        "bearer_token",
        "auth",
        "authorization",
    }
)


def is_sensitive_field(field_name: str) -> bool:
    """Report whether *field_name* looks sensitive.

    Performs a case-insensitive substring match: returns True when the
    lowercased name contains any entry of SENSITIVE_FIELD_NAMES.
    """
    haystack = field_name.lower()
    for marker in SENSITIVE_FIELD_NAMES:
        if marker in haystack:
            return True
    return False
|
||||
|
||||
|
||||
def filter_sensitive_fields(
    data: dict[str, Any],
    *,
    extra_excludes: frozenset[str] | None = None,
    linked_fields: set[str] | None = None,
) -> dict[str, Any]:
    """Return a copy of *data* without sensitive or private fields.

    Nested dicts are scanned one level deep as well, so secrets hiding
    under a benign top-level key (e.g. ``{"config": {"api_key": "..."}}``)
    are also stripped; a nested dict left empty after stripping is dropped
    entirely.

    Args:
        data: The dict to filter.
        extra_excludes: Additional exact field names to exclude (e.g.
            ``{"graph_id", "graph_version", "input_schema"}``).
        linked_fields: Field names to exclude because they are linked.
    """
    blocked = extra_excludes or frozenset()
    linked = linked_fields or set()

    kept: dict[str, Any] = {}
    for key, value in data.items():
        # Drop linked, private (underscore-prefixed), explicitly excluded,
        # and sensitive-looking keys.
        if (
            key in linked
            or key.startswith("_")
            or key in blocked
            or is_sensitive_field(key)
        ):
            continue
        # One level of nested filtering so wrapped secrets don't leak.
        if isinstance(value, dict):
            value = {
                inner_key: inner_value
                for inner_key, inner_value in value.items()
                if not is_sensitive_field(inner_key)
            }
            if not value:
                continue
        kept[key] = value
    return kept
|
||||
Reference in New Issue
Block a user