diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py index 1a636c41f7..b7650b3cbd 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py @@ -2,6 +2,7 @@ from .core import ( AgentGeneratorNotConfiguredError, + AgentJsonValidationError, AgentSummary, DecompositionResult, DecompositionStep, @@ -28,6 +29,7 @@ from .service import is_external_service_configured __all__ = [ "AgentGeneratorNotConfiguredError", + "AgentJsonValidationError", "AgentSummary", "DecompositionResult", "DecompositionStep", diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py index 02ffa7a38b..71c5247d09 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py @@ -335,11 +335,12 @@ async def get_all_relevant_agents_for_generation( if graph_id == exclude_graph_id: continue agent = await get_library_agent_by_graph_id(user_id, graph_id) - if agent and agent.get("graph_id") not in seen_graph_ids: + agent_graph_id = agent.get("graph_id") if agent else None + if agent and agent_graph_id and agent_graph_id not in seen_graph_ids: agents.append(agent) - seen_graph_ids.add(agent.get("graph_id", "")) + seen_graph_ids.add(agent_graph_id) logger.debug( - f"Found explicitly mentioned agent: {agent.get('name', 'Unknown')}" + f"Found explicitly mentioned agent: {agent.get('name') or 'Unknown'}" ) if include_library: @@ -360,11 +361,16 @@ async def get_all_relevant_agents_for_generation( search_query=search_query, max_results=max_marketplace_results, ) - library_names = {name.lower() for a in agents if (name := a.get("name"))} + library_names: set[str] = set() + for a in agents: + name = a.get("name") + if name and isinstance(name, str): + library_names.add(name.lower()) for agent in marketplace_agents: agent_name = agent.get("name") - if agent_name and agent_name.lower() not in library_names: - agents.append(agent) + if agent_name and isinstance(agent_name, str): + if agent_name.lower() not in library_names: + agents.append(agent) return agents @@ -444,11 +450,11 @@ async def enrich_library_agents_from_steps( existing_names: set[str] = set() for agent in existing_agents: - agent_name = agent.get("name", "") - if agent_name: + agent_name = agent.get("name") + if agent_name and isinstance(agent_name, str): existing_names.add(agent_name.lower()) graph_id = agent.get("graph_id") # type: ignore[call-overload] - if graph_id: + if graph_id and isinstance(graph_id, str): existing_ids.add(graph_id) all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents) @@ -465,8 +471,8 @@ async def enrich_library_agents_from_steps( ) for agent in additional_agents: - agent_name = agent.get("name", "") - if not agent_name: + agent_name = agent.get("name") + if not agent_name or not isinstance(agent_name, str): continue agent_name_lower = agent_name.lower() @@ -479,9 +485,12 @@ async def enrich_library_agents_from_steps( all_agents.append(agent) existing_names.add(agent_name_lower) - if graph_id: + if graph_id and isinstance(graph_id, str): existing_ids.add(graph_id) + except DatabaseError: + logger.error(f"Database error searching for agents with term '{term}'") + raise except Exception as e: logger.warning( f"Failed to search for additional agents with term '{term}': {e}" @@ -557,6 +566,12 @@ async def generate_agent( return result +class AgentJsonValidationError(Exception): + """Raised when agent JSON is invalid or missing required fields.""" + + pass + + def json_to_graph(agent_json: dict[str, Any]) -> Graph: """Convert agent JSON dict to Graph model. @@ -565,25 +580,55 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph: Returns: Graph ready for saving + + Raises: + AgentJsonValidationError: If required fields are missing from nodes or links """ nodes = [] - for n in agent_json.get("nodes", []): + for idx, n in enumerate(agent_json.get("nodes", [])): + block_id = n.get("block_id") + if not block_id: + node_id = n.get("id", f"index_{idx}") + raise AgentJsonValidationError( + f"Node '{node_id}' is missing required field 'block_id'" + ) node = Node( id=n.get("id", str(uuid.uuid4())), - block_id=n["block_id"], + block_id=block_id, input_default=n.get("input_default", {}), metadata=n.get("metadata", {}), ) nodes.append(node) links = [] - for link_data in agent_json.get("links", []): + for idx, link_data in enumerate(agent_json.get("links", [])): + source_id = link_data.get("source_id") + sink_id = link_data.get("sink_id") + source_name = link_data.get("source_name") + sink_name = link_data.get("sink_name") + + missing_fields = [] + if not source_id: + missing_fields.append("source_id") + if not sink_id: + missing_fields.append("sink_id") + if not source_name: + missing_fields.append("source_name") + if not sink_name: + missing_fields.append("sink_name") + + if missing_fields: + link_id = link_data.get("id", f"index_{idx}") + raise AgentJsonValidationError( + f"Link '{link_id}' is missing required fields: {', '.join(missing_fields)}" + ) + link = Link( id=link_data.get("id", str(uuid.uuid4())), - source_id=link_data["source_id"], - sink_id=link_data["sink_id"], - source_name=link_data["source_name"], - sink_name=link_data["sink_name"], + source_id=source_id, + sink_id=sink_id, + source_name=source_name, + sink_name=sink_name, is_static=link_data.get("is_static", False), ) links.append(link)