fix: add defensive null checks and validation to agent generator

- Add isinstance(str) checks to prevent AttributeError when agent name is None
- Add AgentJsonValidationError for better error messages on invalid agent JSON
- Add validation in json_to_graph for missing required fields (block_id, source_id, etc)
- Remove try-except that was swallowing DatabaseError in get_all_relevant_agents_for_generation
- Add DatabaseError handler in enrich_library_agents_from_steps to properly re-raise
- Export AgentJsonValidationError from __init__.py
This commit is contained in:
Zamil Majdy
2026-01-30 13:55:29 -06:00
parent 82a9612192
commit a042c3285d
2 changed files with 66 additions and 19 deletions

View File

@@ -2,6 +2,7 @@
from .core import (
AgentGeneratorNotConfiguredError,
AgentJsonValidationError,
AgentSummary,
DecompositionResult,
DecompositionStep,
@@ -28,6 +29,7 @@ from .service import is_external_service_configured
__all__ = [
"AgentGeneratorNotConfiguredError",
"AgentJsonValidationError",
"AgentSummary",
"DecompositionResult",
"DecompositionStep",

View File

@@ -335,11 +335,12 @@ async def get_all_relevant_agents_for_generation(
if graph_id == exclude_graph_id:
continue
agent = await get_library_agent_by_graph_id(user_id, graph_id)
if agent and agent.get("graph_id") not in seen_graph_ids:
agent_graph_id = agent.get("graph_id") if agent else None
if agent and agent_graph_id and agent_graph_id not in seen_graph_ids:
agents.append(agent)
seen_graph_ids.add(agent.get("graph_id", ""))
seen_graph_ids.add(agent_graph_id)
logger.debug(
f"Found explicitly mentioned agent: {agent.get('name', 'Unknown')}"
f"Found explicitly mentioned agent: {agent.get('name') or 'Unknown'}"
)
if include_library:
@@ -360,11 +361,16 @@ async def get_all_relevant_agents_for_generation(
search_query=search_query,
max_results=max_marketplace_results,
)
library_names = {name.lower() for a in agents if (name := a.get("name"))}
library_names: set[str] = set()
for a in agents:
name = a.get("name")
if name and isinstance(name, str):
library_names.add(name.lower())
for agent in marketplace_agents:
agent_name = agent.get("name")
if agent_name and agent_name.lower() not in library_names:
agents.append(agent)
if agent_name and isinstance(agent_name, str):
if agent_name.lower() not in library_names:
agents.append(agent)
return agents
@@ -444,11 +450,11 @@ async def enrich_library_agents_from_steps(
existing_names: set[str] = set()
for agent in existing_agents:
agent_name = agent.get("name", "")
if agent_name:
agent_name = agent.get("name")
if agent_name and isinstance(agent_name, str):
existing_names.add(agent_name.lower())
graph_id = agent.get("graph_id") # type: ignore[call-overload]
if graph_id:
if graph_id and isinstance(graph_id, str):
existing_ids.add(graph_id)
all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents)
@@ -465,8 +471,8 @@ async def enrich_library_agents_from_steps(
)
for agent in additional_agents:
agent_name = agent.get("name", "")
if not agent_name:
agent_name = agent.get("name")
if not agent_name or not isinstance(agent_name, str):
continue
agent_name_lower = agent_name.lower()
@@ -479,9 +485,12 @@ async def enrich_library_agents_from_steps(
all_agents.append(agent)
existing_names.add(agent_name_lower)
if graph_id:
if graph_id and isinstance(graph_id, str):
existing_ids.add(graph_id)
except DatabaseError:
logger.error(f"Database error searching for agents with term '{term}'")
raise
except Exception as e:
logger.warning(
f"Failed to search for additional agents with term '{term}': {e}"
@@ -557,6 +566,12 @@ async def generate_agent(
return result
class AgentJsonValidationError(Exception):
"""Raised when agent JSON is invalid or missing required fields."""
pass
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
"""Convert agent JSON dict to Graph model.
@@ -565,25 +580,55 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
Returns:
Graph ready for saving
Raises:
AgentJsonValidationError: If required fields are missing from nodes or links
"""
nodes = []
for n in agent_json.get("nodes", []):
for idx, n in enumerate(agent_json.get("nodes", [])):
block_id = n.get("block_id")
if not block_id:
node_id = n.get("id", f"index_{idx}")
raise AgentJsonValidationError(
f"Node '{node_id}' is missing required field 'block_id'"
)
node = Node(
id=n.get("id", str(uuid.uuid4())),
block_id=n["block_id"],
block_id=block_id,
input_default=n.get("input_default", {}),
metadata=n.get("metadata", {}),
)
nodes.append(node)
links = []
for link_data in agent_json.get("links", []):
for idx, link_data in enumerate(agent_json.get("links", [])):
source_id = link_data.get("source_id")
sink_id = link_data.get("sink_id")
source_name = link_data.get("source_name")
sink_name = link_data.get("sink_name")
missing_fields = []
if not source_id:
missing_fields.append("source_id")
if not sink_id:
missing_fields.append("sink_id")
if not source_name:
missing_fields.append("source_name")
if not sink_name:
missing_fields.append("sink_name")
if missing_fields:
link_id = link_data.get("id", f"index_{idx}")
raise AgentJsonValidationError(
f"Link '{link_id}' is missing required fields: {', '.join(missing_fields)}"
)
link = Link(
id=link_data.get("id", str(uuid.uuid4())),
source_id=link_data["source_id"],
sink_id=link_data["sink_id"],
source_name=link_data["source_name"],
sink_name=link_data["sink_name"],
source_id=source_id,
sink_id=sink_id,
source_name=source_name,
sink_name=sink_name,
is_static=link_data.get("is_static", False),
)
links.append(link)