From aeba28266cb5a9754bd57e08e60da3dba2c29cbd Mon Sep 17 00:00:00 2001 From: Otto Date: Wed, 4 Feb 2026 08:54:27 +0000 Subject: [PATCH] refactor(copilot): use Pydantic models and match/case in customize_agent Addresses review feedback from ntindle on PR #11943: 1. Use typed parameters instead of kwargs.get(): - Added CustomizeAgentInput Pydantic model with field_validator - Tool now uses params = CustomizeAgentInput(**kwargs) pattern 2. Use match/case for cleaner pattern matching: - Extracted response handling to _handle_customization_result method - Uses match result_type: case 'error' | 'clarifying_questions' | _ 3. Improved code organization: - Split monolithic _execute into smaller focused methods - _handle_customization_result for response type handling - _save_or_preview_agent for final save/preview logic --- .../features/chat/tools/customize_agent.py | 187 +++++++++++------- 1 file changed, 121 insertions(+), 66 deletions(-) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py index c0568bd936..04984abeba 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py @@ -3,6 +3,8 @@ import logging from typing import Any +from pydantic import BaseModel, field_validator + from backend.api.features.chat.model import ChatSession from backend.api.features.store import db as store_db from backend.api.features.store.exceptions import AgentNotFoundError @@ -27,6 +29,23 @@ from .models import ( logger = logging.getLogger(__name__) +class CustomizeAgentInput(BaseModel): + """Input parameters for the customize_agent tool.""" + + agent_id: str = "" + modifications: str = "" + context: str = "" + save: bool = True + + @field_validator("agent_id", "modifications", "context", mode="before") + @classmethod + def strip_strings(cls, v: Any) -> str: + """Strip whitespace from string fields.""" + if isinstance(v, str): + return v.strip() + return v if v is not None else "" + + class CustomizeAgentTool(BaseTool): """Tool for customizing marketplace/template agents using natural language.""" @@ -92,7 +111,7 @@ class CustomizeAgentTool(BaseTool): self, user_id: str | None, session: ChatSession, - **kwargs, + **kwargs: Any, ) -> ToolResponseBase: """Execute the customize_agent tool. @@ -102,20 +121,17 @@ class CustomizeAgentTool(BaseTool): 3. Call customize_template with the modification request 4. Preview or save based on the save parameter """ - agent_id = kwargs.get("agent_id", "").strip() - modifications = kwargs.get("modifications", "").strip() - context = kwargs.get("context", "") - save = kwargs.get("save", True) + params = CustomizeAgentInput(**kwargs) session_id = session.session_id if session else None - if not agent_id: + if not params.agent_id: return ErrorResponse( message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').", error="missing_agent_id", session_id=session_id, ) - if not modifications: + if not params.modifications: return ErrorResponse( message="Please describe how you want to customize this agent.", error="missing_modifications", @@ -123,11 +139,11 @@ class CustomizeAgentTool(BaseTool): ) # Parse agent_id in format "creator/slug" - parts = [p.strip() for p in agent_id.split("/")] + parts = params.agent_id.split("/") if len(parts) != 2 or not parts[0] or not parts[1]: return ErrorResponse( message=( - f"Invalid agent ID format: '{agent_id}'. " + f"Invalid agent ID format: '{params.agent_id}'. " "Expected format is 'creator/agent-name' " "(e.g., 'autogpt/newsletter-writer')." ), @@ -145,14 +161,14 @@ class CustomizeAgentTool(BaseTool): except AgentNotFoundError: return ErrorResponse( message=( - f"Could not find marketplace agent '{agent_id}'. " + f"Could not find marketplace agent '{params.agent_id}'. " "Please check the agent ID and try again." ), error="agent_not_found", session_id=session_id, ) except Exception as e: - logger.error(f"Error fetching marketplace agent {agent_id}: {e}") + logger.error(f"Error fetching marketplace agent {params.agent_id}: {e}") return ErrorResponse( message="Failed to fetch the marketplace agent. Please try again.", error="fetch_error", @@ -162,7 +178,7 @@ class CustomizeAgentTool(BaseTool): if not agent_details.store_listing_version_id: return ErrorResponse( message=( - f"The agent '{agent_id}' does not have an available version. " + f"The agent '{params.agent_id}' does not have an available version. " "Please try a different agent." ), error="no_version_available", @@ -174,7 +190,7 @@ class CustomizeAgentTool(BaseTool): graph = await store_db.get_agent(agent_details.store_listing_version_id) template_agent = graph_to_json(graph) except Exception as e: - logger.error(f"Error fetching agent graph for {agent_id}: {e}") + logger.error(f"Error fetching agent graph for {params.agent_id}: {e}") return ErrorResponse( message="Failed to fetch the agent configuration. Please try again.", error="graph_fetch_error", @@ -185,8 +201,8 @@ class CustomizeAgentTool(BaseTool): try: result = await customize_template( template_agent=template_agent, - modification_request=modifications, - context=context, + modification_request=params.modifications, + context=params.context, ) except AgentGeneratorNotConfiguredError: return ErrorResponse( @@ -198,7 +214,7 @@ class CustomizeAgentTool(BaseTool): session_id=session_id, ) except Exception as e: - logger.error(f"Error calling customize_template for {agent_id}: {e}") + logger.error(f"Error calling customize_template for {params.agent_id}: {e}") return ErrorResponse( message=( "Failed to customize the agent due to a service error. " @@ -219,55 +235,25 @@ class CustomizeAgentTool(BaseTool): session_id=session_id, ) - # Handle error response - if isinstance(result, dict) and result.get("type") == "error": - error_msg = result.get("error", "Unknown error") - error_type = result.get("error_type", "unknown") - user_message = get_user_message_for_error( - error_type, - operation="customize the agent", - llm_parse_message=( - "The AI had trouble customizing the agent. " - "Please try again or simplify your request." - ), - validation_message=( - "The customized agent failed validation. " - "Please try rephrasing your request." - ), - error_details=error_msg, - ) - return ErrorResponse( - message=user_message, - error=f"customization_failed:{error_type}", - session_id=session_id, - ) + # Handle response using match/case for cleaner pattern matching + return await self._handle_customization_result( + result=result, + params=params, + agent_details=agent_details, + user_id=user_id, + session_id=session_id, + ) - # Handle clarifying questions - if isinstance(result, dict) and result.get("type") == "clarifying_questions": - questions = result.get("questions") or [] - if not isinstance(questions, list): - logger.error( - f"Unexpected clarifying questions format: {type(questions)}" - ) - questions = [] - return ClarificationNeededResponse( - message=( - "I need some more information to customize this agent. " - "Please answer the following questions:" - ), - questions=[ - ClarifyingQuestion( - question=q.get("question", ""), - keyword=q.get("keyword", ""), - example=q.get("example"), - ) - for q in questions - if isinstance(q, dict) - ], - session_id=session_id, - ) - - # Result should be the customized agent JSON + async def _handle_customization_result( + self, + result: dict[str, Any], + params: CustomizeAgentInput, + agent_details: Any, + user_id: str | None, + session_id: str | None, + ) -> ToolResponseBase: + """Handle the result from customize_template using pattern matching.""" + # Ensure result is a dict if not isinstance(result, dict): logger.error(f"Unexpected customize_template response type: {type(result)}") return ErrorResponse( @@ -276,8 +262,77 @@ class CustomizeAgentTool(BaseTool): session_id=session_id, ) - customized_agent = result + result_type = result.get("type") + match result_type: + case "error": + error_msg = result.get("error", "Unknown error") + error_type = result.get("error_type", "unknown") + user_message = get_user_message_for_error( + error_type, + operation="customize the agent", + llm_parse_message=( + "The AI had trouble customizing the agent. " + "Please try again or simplify your request." + ), + validation_message=( + "The customized agent failed validation. " + "Please try rephrasing your request." + ), + error_details=error_msg, + ) + return ErrorResponse( + message=user_message, + error=f"customization_failed:{error_type}", + session_id=session_id, + ) + + case "clarifying_questions": + questions_data = result.get("questions") or [] + if not isinstance(questions_data, list): + logger.error( + f"Unexpected clarifying questions format: {type(questions_data)}" + ) + questions_data = [] + + questions = [ + ClarifyingQuestion( + question=q.get("question", "") if isinstance(q, dict) else "", + keyword=q.get("keyword", "") if isinstance(q, dict) else "", + example=q.get("example") if isinstance(q, dict) else None, + ) + for q in questions_data + if isinstance(q, dict) + ] + + return ClarificationNeededResponse( + message=( + "I need some more information to customize this agent. " + "Please answer the following questions:" + ), + questions=questions, + session_id=session_id, + ) + + case _: + # Default case: result is the customized agent JSON + return await self._save_or_preview_agent( + customized_agent=result, + params=params, + agent_details=agent_details, + user_id=user_id, + session_id=session_id, + ) + + async def _save_or_preview_agent( + self, + customized_agent: dict[str, Any], + params: CustomizeAgentInput, + agent_details: Any, + user_id: str | None, + session_id: str | None, + ) -> ToolResponseBase: + """Save or preview the customized agent based on params.save.""" agent_name = customized_agent.get( "name", f"Customized {agent_details.agent_name}" ) @@ -287,7 +342,7 @@ class CustomizeAgentTool(BaseTool): node_count = len(nodes) if isinstance(nodes, list) else 0 link_count = len(links) if isinstance(links, list) else 0 - if not save: + if not params.save: return AgentPreviewResponse( message=( f"I've customized the agent '{agent_details.agent_name}'. "