mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-30 17:38:17 -05:00
fix: add defensive null checks and validation to agent generator
- Add isinstance(str) checks to prevent AttributeError when agent name is None - Add AgentJsonValidationError for better error messages on invalid agent JSON - Add validation in json_to_graph for missing required fields (block_id, source_id, etc) - Remove try-except that was swallowing DatabaseError in get_all_relevant_agents_for_generation - Add DatabaseError handler in enrich_library_agents_from_steps to properly re-raise - Export AgentJsonValidationError from __init__.py
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
from .core import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
AgentJsonValidationError,
|
||||
AgentSummary,
|
||||
DecompositionResult,
|
||||
DecompositionStep,
|
||||
@@ -28,6 +29,7 @@ from .service import is_external_service_configured
|
||||
|
||||
__all__ = [
|
||||
"AgentGeneratorNotConfiguredError",
|
||||
"AgentJsonValidationError",
|
||||
"AgentSummary",
|
||||
"DecompositionResult",
|
||||
"DecompositionStep",
|
||||
|
||||
@@ -335,11 +335,12 @@ async def get_all_relevant_agents_for_generation(
|
||||
if graph_id == exclude_graph_id:
|
||||
continue
|
||||
agent = await get_library_agent_by_graph_id(user_id, graph_id)
|
||||
if agent and agent.get("graph_id") not in seen_graph_ids:
|
||||
agent_graph_id = agent.get("graph_id") if agent else None
|
||||
if agent and agent_graph_id and agent_graph_id not in seen_graph_ids:
|
||||
agents.append(agent)
|
||||
seen_graph_ids.add(agent.get("graph_id", ""))
|
||||
seen_graph_ids.add(agent_graph_id)
|
||||
logger.debug(
|
||||
f"Found explicitly mentioned agent: {agent.get('name', 'Unknown')}"
|
||||
f"Found explicitly mentioned agent: {agent.get('name') or 'Unknown'}"
|
||||
)
|
||||
|
||||
if include_library:
|
||||
@@ -360,11 +361,16 @@ async def get_all_relevant_agents_for_generation(
|
||||
search_query=search_query,
|
||||
max_results=max_marketplace_results,
|
||||
)
|
||||
library_names = {name.lower() for a in agents if (name := a.get("name"))}
|
||||
library_names: set[str] = set()
|
||||
for a in agents:
|
||||
name = a.get("name")
|
||||
if name and isinstance(name, str):
|
||||
library_names.add(name.lower())
|
||||
for agent in marketplace_agents:
|
||||
agent_name = agent.get("name")
|
||||
if agent_name and agent_name.lower() not in library_names:
|
||||
agents.append(agent)
|
||||
if agent_name and isinstance(agent_name, str):
|
||||
if agent_name.lower() not in library_names:
|
||||
agents.append(agent)
|
||||
|
||||
return agents
|
||||
|
||||
@@ -444,11 +450,11 @@ async def enrich_library_agents_from_steps(
|
||||
existing_names: set[str] = set()
|
||||
|
||||
for agent in existing_agents:
|
||||
agent_name = agent.get("name", "")
|
||||
if agent_name:
|
||||
agent_name = agent.get("name")
|
||||
if agent_name and isinstance(agent_name, str):
|
||||
existing_names.add(agent_name.lower())
|
||||
graph_id = agent.get("graph_id") # type: ignore[call-overload]
|
||||
if graph_id:
|
||||
if graph_id and isinstance(graph_id, str):
|
||||
existing_ids.add(graph_id)
|
||||
|
||||
all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents)
|
||||
@@ -465,8 +471,8 @@ async def enrich_library_agents_from_steps(
|
||||
)
|
||||
|
||||
for agent in additional_agents:
|
||||
agent_name = agent.get("name", "")
|
||||
if not agent_name:
|
||||
agent_name = agent.get("name")
|
||||
if not agent_name or not isinstance(agent_name, str):
|
||||
continue
|
||||
agent_name_lower = agent_name.lower()
|
||||
|
||||
@@ -479,9 +485,12 @@ async def enrich_library_agents_from_steps(
|
||||
|
||||
all_agents.append(agent)
|
||||
existing_names.add(agent_name_lower)
|
||||
if graph_id:
|
||||
if graph_id and isinstance(graph_id, str):
|
||||
existing_ids.add(graph_id)
|
||||
|
||||
except DatabaseError:
|
||||
logger.error(f"Database error searching for agents with term '{term}'")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to search for additional agents with term '{term}': {e}"
|
||||
@@ -557,6 +566,12 @@ async def generate_agent(
|
||||
return result
|
||||
|
||||
|
||||
class AgentJsonValidationError(Exception):
|
||||
"""Raised when agent JSON is invalid or missing required fields."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
"""Convert agent JSON dict to Graph model.
|
||||
|
||||
@@ -565,25 +580,55 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
|
||||
Returns:
|
||||
Graph ready for saving
|
||||
|
||||
Raises:
|
||||
AgentJsonValidationError: If required fields are missing from nodes or links
|
||||
"""
|
||||
nodes = []
|
||||
for n in agent_json.get("nodes", []):
|
||||
for idx, n in enumerate(agent_json.get("nodes", [])):
|
||||
block_id = n.get("block_id")
|
||||
if not block_id:
|
||||
node_id = n.get("id", f"index_{idx}")
|
||||
raise AgentJsonValidationError(
|
||||
f"Node '{node_id}' is missing required field 'block_id'"
|
||||
)
|
||||
node = Node(
|
||||
id=n.get("id", str(uuid.uuid4())),
|
||||
block_id=n["block_id"],
|
||||
block_id=block_id,
|
||||
input_default=n.get("input_default", {}),
|
||||
metadata=n.get("metadata", {}),
|
||||
)
|
||||
nodes.append(node)
|
||||
|
||||
links = []
|
||||
for link_data in agent_json.get("links", []):
|
||||
for idx, link_data in enumerate(agent_json.get("links", [])):
|
||||
source_id = link_data.get("source_id")
|
||||
sink_id = link_data.get("sink_id")
|
||||
source_name = link_data.get("source_name")
|
||||
sink_name = link_data.get("sink_name")
|
||||
|
||||
missing_fields = []
|
||||
if not source_id:
|
||||
missing_fields.append("source_id")
|
||||
if not sink_id:
|
||||
missing_fields.append("sink_id")
|
||||
if not source_name:
|
||||
missing_fields.append("source_name")
|
||||
if not sink_name:
|
||||
missing_fields.append("sink_name")
|
||||
|
||||
if missing_fields:
|
||||
link_id = link_data.get("id", f"index_{idx}")
|
||||
raise AgentJsonValidationError(
|
||||
f"Link '{link_id}' is missing required fields: {', '.join(missing_fields)}"
|
||||
)
|
||||
|
||||
link = Link(
|
||||
id=link_data.get("id", str(uuid.uuid4())),
|
||||
source_id=link_data["source_id"],
|
||||
sink_id=link_data["sink_id"],
|
||||
source_name=link_data["source_name"],
|
||||
sink_name=link_data["sink_name"],
|
||||
source_id=source_id,
|
||||
sink_id=sink_id,
|
||||
source_name=source_name,
|
||||
sink_name=sink_name,
|
||||
is_static=link_data.get("is_static", False),
|
||||
)
|
||||
links.append(link)
|
||||
|
||||
Reference in New Issue
Block a user