mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-31 09:58:19 -05:00
fix(backend/chat): Include input schema in discovery and validate unknown fields (#11916)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
"""Shared agent search functionality for find_agent and find_library_agent tools."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .models import (
|
||||
@@ -14,6 +17,7 @@ from .models import (
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from .utils import fetch_graph_from_store_slug
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -54,7 +58,28 @@ async def search_agents(
|
||||
if source == "marketplace":
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
results = await store_db.get_store_agents(search_query=query, page_size=5)
|
||||
for agent in results.agents:
|
||||
|
||||
# Fetch all graphs in parallel for better performance
|
||||
async def fetch_marketplace_graph(
|
||||
creator: str, slug: str
|
||||
) -> GraphModel | None:
|
||||
try:
|
||||
graph, _ = await fetch_graph_from_store_slug(creator, slug)
|
||||
return graph
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch input schema for {creator}/{slug}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
graphs = await asyncio.gather(
|
||||
*(
|
||||
fetch_marketplace_graph(agent.creator, agent.slug)
|
||||
for agent in results.agents
|
||||
)
|
||||
)
|
||||
|
||||
for agent, graph in zip(results.agents, graphs):
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=f"{agent.creator}/{agent.slug}",
|
||||
@@ -67,6 +92,7 @@ async def search_agents(
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False,
|
||||
inputs=graph.input_schema if graph else None,
|
||||
)
|
||||
)
|
||||
else: # library
|
||||
@@ -76,7 +102,32 @@ async def search_agents(
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
)
|
||||
for agent in results.agents:
|
||||
|
||||
# Fetch all graphs in parallel for better performance
|
||||
# (list_library_agents doesn't include nodes for performance)
|
||||
async def fetch_library_graph(
|
||||
graph_id: str, graph_version: int
|
||||
) -> GraphModel | None:
|
||||
try:
|
||||
return await graph_db.get_graph(
|
||||
graph_id=graph_id,
|
||||
version=graph_version,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch input schema for graph {graph_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
graphs = await asyncio.gather(
|
||||
*(
|
||||
fetch_library_graph(agent.graph_id, agent.graph_version)
|
||||
for agent in results.agents
|
||||
)
|
||||
)
|
||||
|
||||
for agent, graph in zip(results.agents, graphs):
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent.id,
|
||||
@@ -90,6 +141,7 @@ async def search_agents(
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
inputs=graph.input_schema if graph else None,
|
||||
)
|
||||
)
|
||||
logger.info(f"Found {len(agents)} agents in {source}")
|
||||
|
||||
@@ -32,6 +32,8 @@ class ResponseType(str, Enum):
|
||||
OPERATION_STARTED = "operation_started"
|
||||
OPERATION_PENDING = "operation_pending"
|
||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||
# Input validation
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
|
||||
|
||||
# Base response model
|
||||
@@ -62,6 +64,10 @@ class AgentInfo(BaseModel):
|
||||
has_external_trigger: bool | None = None
|
||||
new_output: bool | None = None
|
||||
graph_id: str | None = None
|
||||
inputs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Input schema for the agent, including field names, types, and defaults",
|
||||
)
|
||||
|
||||
|
||||
class AgentsFoundResponse(ToolResponseBase):
|
||||
@@ -188,6 +194,20 @@ class ErrorResponse(ToolResponseBase):
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class InputValidationErrorResponse(ToolResponseBase):
|
||||
"""Response when run_agent receives unknown input fields."""
|
||||
|
||||
type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
|
||||
unrecognized_fields: list[str] = Field(
|
||||
description="List of input field names that were not recognized"
|
||||
)
|
||||
inputs: dict[str, Any] = Field(
|
||||
description="The agent's valid input schema for reference"
|
||||
)
|
||||
graph_id: str | None = None
|
||||
graph_version: int | None = None
|
||||
|
||||
|
||||
# Agent output models
|
||||
class ExecutionOutputInfo(BaseModel):
|
||||
"""Summary of a single execution's outputs."""
|
||||
|
||||
@@ -30,6 +30,7 @@ from .models import (
|
||||
ErrorResponse,
|
||||
ExecutionOptions,
|
||||
ExecutionStartedResponse,
|
||||
InputValidationErrorResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
@@ -273,6 +274,22 @@ class RunAgentTool(BaseTool):
|
||||
input_properties = graph.input_schema.get("properties", {})
|
||||
required_fields = set(graph.input_schema.get("required", []))
|
||||
provided_inputs = set(params.inputs.keys())
|
||||
valid_fields = set(input_properties.keys())
|
||||
|
||||
# Check for unknown input fields
|
||||
unrecognized_fields = provided_inputs - valid_fields
|
||||
if unrecognized_fields:
|
||||
return InputValidationErrorResponse(
|
||||
message=(
|
||||
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
|
||||
f"Agent was not executed. Please use the correct field names from the schema."
|
||||
),
|
||||
session_id=session_id,
|
||||
unrecognized_fields=sorted(unrecognized_fields),
|
||||
inputs=graph.input_schema,
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
)
|
||||
|
||||
# If agent has inputs but none were provided AND use_defaults is not set,
|
||||
# always show what's available first so user can decide
|
||||
|
||||
@@ -402,3 +402,42 @@ async def test_run_agent_schedule_without_name(setup_test_data):
|
||||
# Should return error about missing schedule_name
|
||||
assert result_data.get("type") == "error"
|
||||
assert "schedule_name" in result_data["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
|
||||
"""Test that run_agent returns input_validation_error for unknown input fields."""
|
||||
user = setup_test_data["user"]
|
||||
store_submission = setup_test_data["store_submission"]
|
||||
|
||||
tool = RunAgentTool()
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Execute with unknown input field names
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
inputs={
|
||||
"unknown_field": "some value",
|
||||
"another_unknown": "another value",
|
||||
},
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return input_validation_error type with unrecognized fields
|
||||
assert result_data.get("type") == "input_validation_error"
|
||||
assert "unrecognized_fields" in result_data
|
||||
assert set(result_data["unrecognized_fields"]) == {
|
||||
"another_unknown",
|
||||
"unknown_field",
|
||||
}
|
||||
assert "inputs" in result_data # Contains the valid schema
|
||||
assert "Agent was not executed" in result_data["message"]
|
||||
|
||||
Reference in New Issue
Block a user