Compare commits

...

1 Commits

Author SHA1 Message Date
Otto
7f7a7067ec refactor(copilot): use Pydantic models and match/case in customize_agent
Addresses review feedback from ntindle:

1. Use typed parameters instead of kwargs.get():
   - Added CustomizeAgentInput Pydantic model with field_validator for stripping strings
   - Tool now uses params = CustomizeAgentInput(**kwargs) pattern

2. Use match/case for cleaner pattern matching:
   - Extracted response handling to _handle_customization_result method
   - Uses match result_type: case 'error' | 'clarifying_questions' | _

3. Improved code organization:
   - Split monolithic _execute into smaller focused methods
   - _handle_customization_result for response type handling
   - _save_or_preview_agent for final save/preview logic
2026-02-04 08:53:02 +00:00

View File

@@ -3,6 +3,8 @@
import logging
from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError
@@ -27,6 +29,23 @@ from .models import (
logger = logging.getLogger(__name__)
class CustomizeAgentInput(BaseModel):
"""Input parameters for the customize_agent tool."""
agent_id: str = ""
modifications: str = ""
context: str = ""
save: bool = True
@field_validator("agent_id", "modifications", "context", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str:
"""Strip whitespace from string fields."""
if isinstance(v, str):
return v.strip()
return v if v is not None else ""
class CustomizeAgentTool(BaseTool):
"""Tool for customizing marketplace/template agents using natural language."""
@@ -92,7 +111,7 @@ class CustomizeAgentTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs,
**kwargs: Any,
) -> ToolResponseBase:
"""Execute the customize_agent tool.
@@ -102,20 +121,17 @@ class CustomizeAgentTool(BaseTool):
3. Call customize_template with the modification request
4. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
params = CustomizeAgentInput(**kwargs)
session_id = session.session_id if session else None
if not agent_id:
if not params.agent_id:
return ErrorResponse(
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
error="missing_agent_id",
session_id=session_id,
)
if not modifications:
if not params.modifications:
return ErrorResponse(
message="Please describe how you want to customize this agent.",
error="missing_modifications",
@@ -123,11 +139,11 @@ class CustomizeAgentTool(BaseTool):
)
# Parse agent_id in format "creator/slug"
parts = [p.strip() for p in agent_id.split("/")]
parts = params.agent_id.split("/")
if len(parts) != 2 or not parts[0] or not parts[1]:
return ErrorResponse(
message=(
f"Invalid agent ID format: '{agent_id}'. "
f"Invalid agent ID format: '{params.agent_id}'. "
"Expected format is 'creator/agent-name' "
"(e.g., 'autogpt/newsletter-writer')."
),
@@ -145,14 +161,14 @@ class CustomizeAgentTool(BaseTool):
except AgentNotFoundError:
return ErrorResponse(
message=(
f"Could not find marketplace agent '{agent_id}'. "
f"Could not find marketplace agent '{params.agent_id}'. "
"Please check the agent ID and try again."
),
error="agent_not_found",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
logger.error(f"Error fetching marketplace agent {params.agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the marketplace agent. Please try again.",
error="fetch_error",
@@ -162,7 +178,7 @@ class CustomizeAgentTool(BaseTool):
if not agent_details.store_listing_version_id:
return ErrorResponse(
message=(
f"The agent '{agent_id}' does not have an available version. "
f"The agent '{params.agent_id}' does not have an available version. "
"Please try a different agent."
),
error="no_version_available",
@@ -174,7 +190,7 @@ class CustomizeAgentTool(BaseTool):
graph = await store_db.get_agent(agent_details.store_listing_version_id)
template_agent = graph_to_json(graph)
except Exception as e:
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
logger.error(f"Error fetching agent graph for {params.agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the agent configuration. Please try again.",
error="graph_fetch_error",
@@ -185,8 +201,8 @@ class CustomizeAgentTool(BaseTool):
try:
result = await customize_template(
template_agent=template_agent,
modification_request=modifications,
context=context,
modification_request=params.modifications,
context=params.context,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
@@ -198,7 +214,7 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id,
)
except Exception as e:
logger.error(f"Error calling customize_template for {agent_id}: {e}")
logger.error(f"Error calling customize_template for {params.agent_id}: {e}")
return ErrorResponse(
message=(
"Failed to customize the agent due to a service error. "
@@ -219,55 +235,25 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id,
)
# Handle error response
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
# Handle response using match/case for cleaner pattern matching
return await self._handle_customization_result(
result=result,
params=params,
agent_details=agent_details,
user_id=user_id,
session_id=session_id,
)
# Handle clarifying questions
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
questions = result.get("questions") or []
if not isinstance(questions, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions)}"
)
questions = []
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
if isinstance(q, dict)
],
session_id=session_id,
)
# Result should be the customized agent JSON
async def _handle_customization_result(
self,
result: dict[str, Any],
params: CustomizeAgentInput,
agent_details: Any,
user_id: str | None,
session_id: str | None,
) -> ToolResponseBase:
"""Handle the result from customize_template using pattern matching."""
# Ensure result is a dict
if not isinstance(result, dict):
logger.error(f"Unexpected customize_template response type: {type(result)}")
return ErrorResponse(
@@ -276,8 +262,77 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id,
)
customized_agent = result
result_type = result.get("type")
match result_type:
case "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
case "clarifying_questions":
questions_data = result.get("questions") or []
if not isinstance(questions_data, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions_data)}"
)
questions_data = []
questions = [
ClarifyingQuestion(
question=q.get("question", "") if isinstance(q, dict) else "",
keyword=q.get("keyword", "") if isinstance(q, dict) else "",
example=q.get("example") if isinstance(q, dict) else None,
)
for q in questions_data
if isinstance(q, dict)
]
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=questions,
session_id=session_id,
)
case _:
# Default case: result is the customized agent JSON
return await self._save_or_preview_agent(
customized_agent=result,
params=params,
agent_details=agent_details,
user_id=user_id,
session_id=session_id,
)
async def _save_or_preview_agent(
self,
customized_agent: dict[str, Any],
params: CustomizeAgentInput,
agent_details: Any,
user_id: str | None,
session_id: str | None,
) -> ToolResponseBase:
"""Save or preview the customized agent based on params.save."""
agent_name = customized_agent.get(
"name", f"Customized {agent_details.agent_name}"
)
@@ -287,7 +342,7 @@ class CustomizeAgentTool(BaseTool):
node_count = len(nodes) if isinstance(nodes, list) else 0
link_count = len(links) if isinstance(links, list) else 0
if not save:
if not params.save:
return AgentPreviewResponse(
message=(
f"I've customized the agent '{agent_details.agent_name}'. "