mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-04 20:05:11 -05:00
Compare commits
1 Commits
feat/text-
...
refactor/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7f7a7067ec |
@@ -3,6 +3,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||||
@@ -27,6 +29,23 @@ from .models import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomizeAgentInput(BaseModel):
|
||||||
|
"""Input parameters for the customize_agent tool."""
|
||||||
|
|
||||||
|
agent_id: str = ""
|
||||||
|
modifications: str = ""
|
||||||
|
context: str = ""
|
||||||
|
save: bool = True
|
||||||
|
|
||||||
|
@field_validator("agent_id", "modifications", "context", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_strings(cls, v: Any) -> str:
|
||||||
|
"""Strip whitespace from string fields."""
|
||||||
|
if isinstance(v, str):
|
||||||
|
return v.strip()
|
||||||
|
return v if v is not None else ""
|
||||||
|
|
||||||
|
|
||||||
class CustomizeAgentTool(BaseTool):
|
class CustomizeAgentTool(BaseTool):
|
||||||
"""Tool for customizing marketplace/template agents using natural language."""
|
"""Tool for customizing marketplace/template agents using natural language."""
|
||||||
|
|
||||||
@@ -92,7 +111,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
"""Execute the customize_agent tool.
|
"""Execute the customize_agent tool.
|
||||||
|
|
||||||
@@ -102,20 +121,17 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
3. Call customize_template with the modification request
|
3. Call customize_template with the modification request
|
||||||
4. Preview or save based on the save parameter
|
4. Preview or save based on the save parameter
|
||||||
"""
|
"""
|
||||||
agent_id = kwargs.get("agent_id", "").strip()
|
params = CustomizeAgentInput(**kwargs)
|
||||||
modifications = kwargs.get("modifications", "").strip()
|
|
||||||
context = kwargs.get("context", "")
|
|
||||||
save = kwargs.get("save", True)
|
|
||||||
session_id = session.session_id if session else None
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
if not agent_id:
|
if not params.agent_id:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
||||||
error="missing_agent_id",
|
error="missing_agent_id",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not modifications:
|
if not params.modifications:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please describe how you want to customize this agent.",
|
message="Please describe how you want to customize this agent.",
|
||||||
error="missing_modifications",
|
error="missing_modifications",
|
||||||
@@ -123,11 +139,11 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Parse agent_id in format "creator/slug"
|
# Parse agent_id in format "creator/slug"
|
||||||
parts = [p.strip() for p in agent_id.split("/")]
|
parts = params.agent_id.split("/")
|
||||||
if len(parts) != 2 or not parts[0] or not parts[1]:
|
if len(parts) != 2 or not parts[0] or not parts[1]:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
f"Invalid agent ID format: '{agent_id}'. "
|
f"Invalid agent ID format: '{params.agent_id}'. "
|
||||||
"Expected format is 'creator/agent-name' "
|
"Expected format is 'creator/agent-name' "
|
||||||
"(e.g., 'autogpt/newsletter-writer')."
|
"(e.g., 'autogpt/newsletter-writer')."
|
||||||
),
|
),
|
||||||
@@ -145,14 +161,14 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
except AgentNotFoundError:
|
except AgentNotFoundError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
f"Could not find marketplace agent '{agent_id}'. "
|
f"Could not find marketplace agent '{params.agent_id}'. "
|
||||||
"Please check the agent ID and try again."
|
"Please check the agent ID and try again."
|
||||||
),
|
),
|
||||||
error="agent_not_found",
|
error="agent_not_found",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
|
logger.error(f"Error fetching marketplace agent {params.agent_id}: {e}")
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to fetch the marketplace agent. Please try again.",
|
message="Failed to fetch the marketplace agent. Please try again.",
|
||||||
error="fetch_error",
|
error="fetch_error",
|
||||||
@@ -162,7 +178,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
if not agent_details.store_listing_version_id:
|
if not agent_details.store_listing_version_id:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
f"The agent '{agent_id}' does not have an available version. "
|
f"The agent '{params.agent_id}' does not have an available version. "
|
||||||
"Please try a different agent."
|
"Please try a different agent."
|
||||||
),
|
),
|
||||||
error="no_version_available",
|
error="no_version_available",
|
||||||
@@ -174,7 +190,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
||||||
template_agent = graph_to_json(graph)
|
template_agent = graph_to_json(graph)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
|
logger.error(f"Error fetching agent graph for {params.agent_id}: {e}")
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to fetch the agent configuration. Please try again.",
|
message="Failed to fetch the agent configuration. Please try again.",
|
||||||
error="graph_fetch_error",
|
error="graph_fetch_error",
|
||||||
@@ -185,8 +201,8 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
try:
|
try:
|
||||||
result = await customize_template(
|
result = await customize_template(
|
||||||
template_agent=template_agent,
|
template_agent=template_agent,
|
||||||
modification_request=modifications,
|
modification_request=params.modifications,
|
||||||
context=context,
|
context=params.context,
|
||||||
)
|
)
|
||||||
except AgentGeneratorNotConfiguredError:
|
except AgentGeneratorNotConfiguredError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
@@ -198,7 +214,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error calling customize_template for {agent_id}: {e}")
|
logger.error(f"Error calling customize_template for {params.agent_id}: {e}")
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
"Failed to customize the agent due to a service error. "
|
"Failed to customize the agent due to a service error. "
|
||||||
@@ -219,8 +235,37 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle error response
|
# Handle response using match/case for cleaner pattern matching
|
||||||
if isinstance(result, dict) and result.get("type") == "error":
|
return await self._handle_customization_result(
|
||||||
|
result=result,
|
||||||
|
params=params,
|
||||||
|
agent_details=agent_details,
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_customization_result(
|
||||||
|
self,
|
||||||
|
result: dict[str, Any],
|
||||||
|
params: CustomizeAgentInput,
|
||||||
|
agent_details: Any,
|
||||||
|
user_id: str | None,
|
||||||
|
session_id: str | None,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
"""Handle the result from customize_template using pattern matching."""
|
||||||
|
# Ensure result is a dict
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Failed to customize the agent due to an unexpected response.",
|
||||||
|
error="unexpected_response_type",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
result_type = result.get("type")
|
||||||
|
|
||||||
|
match result_type:
|
||||||
|
case "error":
|
||||||
error_msg = result.get("error", "Unknown error")
|
error_msg = result.get("error", "Unknown error")
|
||||||
error_type = result.get("error_type", "unknown")
|
error_type = result.get("error_type", "unknown")
|
||||||
user_message = get_user_message_for_error(
|
user_message = get_user_message_for_error(
|
||||||
@@ -242,42 +287,52 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle clarifying questions
|
case "clarifying_questions":
|
||||||
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
questions_data = result.get("questions") or []
|
||||||
questions = result.get("questions") or []
|
if not isinstance(questions_data, list):
|
||||||
if not isinstance(questions, list):
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Unexpected clarifying questions format: {type(questions)}"
|
f"Unexpected clarifying questions format: {type(questions_data)}"
|
||||||
)
|
)
|
||||||
questions = []
|
questions_data = []
|
||||||
|
|
||||||
|
questions = [
|
||||||
|
ClarifyingQuestion(
|
||||||
|
question=q.get("question", "") if isinstance(q, dict) else "",
|
||||||
|
keyword=q.get("keyword", "") if isinstance(q, dict) else "",
|
||||||
|
example=q.get("example") if isinstance(q, dict) else None,
|
||||||
|
)
|
||||||
|
for q in questions_data
|
||||||
|
if isinstance(q, dict)
|
||||||
|
]
|
||||||
|
|
||||||
return ClarificationNeededResponse(
|
return ClarificationNeededResponse(
|
||||||
message=(
|
message=(
|
||||||
"I need some more information to customize this agent. "
|
"I need some more information to customize this agent. "
|
||||||
"Please answer the following questions:"
|
"Please answer the following questions:"
|
||||||
),
|
),
|
||||||
questions=[
|
questions=questions,
|
||||||
ClarifyingQuestion(
|
|
||||||
question=q.get("question", ""),
|
|
||||||
keyword=q.get("keyword", ""),
|
|
||||||
example=q.get("example"),
|
|
||||||
)
|
|
||||||
for q in questions
|
|
||||||
if isinstance(q, dict)
|
|
||||||
],
|
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Result should be the customized agent JSON
|
case _:
|
||||||
if not isinstance(result, dict):
|
# Default case: result is the customized agent JSON
|
||||||
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
return await self._save_or_preview_agent(
|
||||||
return ErrorResponse(
|
customized_agent=result,
|
||||||
message="Failed to customize the agent due to an unexpected response.",
|
params=params,
|
||||||
error="unexpected_response_type",
|
agent_details=agent_details,
|
||||||
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
customized_agent = result
|
async def _save_or_preview_agent(
|
||||||
|
self,
|
||||||
|
customized_agent: dict[str, Any],
|
||||||
|
params: CustomizeAgentInput,
|
||||||
|
agent_details: Any,
|
||||||
|
user_id: str | None,
|
||||||
|
session_id: str | None,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
"""Save or preview the customized agent based on params.save."""
|
||||||
agent_name = customized_agent.get(
|
agent_name = customized_agent.get(
|
||||||
"name", f"Customized {agent_details.agent_name}"
|
"name", f"Customized {agent_details.agent_name}"
|
||||||
)
|
)
|
||||||
@@ -287,7 +342,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
node_count = len(nodes) if isinstance(nodes, list) else 0
|
node_count = len(nodes) if isinstance(nodes, list) else 0
|
||||||
link_count = len(links) if isinstance(links, list) else 0
|
link_count = len(links) if isinstance(links, list) else 0
|
||||||
|
|
||||||
if not save:
|
if not params.save:
|
||||||
return AgentPreviewResponse(
|
return AgentPreviewResponse(
|
||||||
message=(
|
message=(
|
||||||
f"I've customized the agent '{agent_details.agent_name}'. "
|
f"I've customized the agent '{agent_details.agent_name}'. "
|
||||||
|
|||||||
Reference in New Issue
Block a user