Compare commits

...

4 Commits

Author SHA1 Message Date
Otto
a093d57ed2 fix: address CodeRabbit review bugs
- customize_agent.py: Strip whitespace from split parts of agent_id
- edit_agent.py: Use model_config instead of deprecated class Config
- edit_agent.py: Fix undefined agent_id/changes → params.agent_id/params.changes
- find_library_agent.py: Remove docstrings per coding guidelines
- get_doc_page.py: Fix undefined path → params.path
- run_block.py: Fix undefined block_id → params.block_id
- workspace_files.py: Fix undefined include_all_sessions → params.include_all_sessions
2026-02-04 09:17:44 +00:00
Otto
6692f39cbd refactor(copilot): add Pydantic input models to all tools
Convert all CoPilot tools from kwargs.get() pattern to Pydantic models:

Tools updated:
- find_agent.py: FindAgentInput
- find_library_agent.py: FindLibraryAgentInput
- find_block.py: FindBlockInput
- search_docs.py: SearchDocsInput
- get_doc_page.py: GetDocPageInput
- create_agent.py: CreateAgentInput
- edit_agent.py: EditAgentInput
- run_block.py: RunBlockInput
- workspace_files.py: 4 input models (List/Read/Write/Delete)

Benefits:
- Type safety with automatic validation
- Consistent string stripping via field_validators
- Better IDE support and error messages
- Cleaner _execute methods using params object

Addresses ntindle review feedback about kwargs pattern.
2026-02-04 09:05:18 +00:00
Otto
aeba28266c refactor(copilot): use Pydantic models and match/case in customize_agent
Addresses review feedback from ntindle on PR #11943:

1. Use typed parameters instead of kwargs.get():
   - Added CustomizeAgentInput Pydantic model with field_validator
   - Tool now uses params = CustomizeAgentInput(**kwargs) pattern

2. Use match/case for cleaner pattern matching:
   - Extracted response handling to _handle_customization_result method
   - Uses match result_type: case 'error' | 'clarifying_questions' | _

3. Improved code organization:
   - Split monolithic _execute into smaller focused methods
   - _handle_customization_result for response type handling
   - _save_or_preview_agent for final save/preview logic
2026-02-04 08:54:27 +00:00
Otto
6d8c83c039 refactor(backend): move local imports to module level in chat service
Addresses review feedback from PRs #11937, #11856:
- Move uuid import to top level (was duplicated in 3 functions)
- Move compress_context import to top level
- Remove redundant local imports for cast and ChatCompletionMessageParam
  (already imported at module level)

Refs:
- https://github.com/Significant-Gravitas/AutoGPT/pull/11937#discussion_r2761107861
- https://github.com/Significant-Gravitas/AutoGPT/pull/11856#discussion_r2761558008
- https://github.com/Significant-Gravitas/AutoGPT/pull/11856#discussion_r2761559661
2026-02-04 03:33:15 +00:00
11 changed files with 425 additions and 215 deletions

View File

@@ -1,12 +1,15 @@
import asyncio
import logging
import time
import uuid as uuid_module
from asyncio import CancelledError
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, cast
import openai
from backend.util.prompt import compress_context
if TYPE_CHECKING:
from backend.util.prompt import CompressResult
@@ -467,8 +470,6 @@ async def stream_chat_completion(
should_retry = False
# Generate unique IDs for AI SDK protocol
import uuid as uuid_module
message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4())
@@ -826,10 +827,6 @@ async def _manage_context_window(
Returns:
CompressResult with compacted messages and metadata
"""
import openai
from backend.util.prompt import compress_context
# Convert messages to dict format
messages_dict = []
for msg in messages:
@@ -1140,8 +1137,6 @@ async def _yield_tool_call(
KeyError: If expected tool call fields are missing
TypeError: If tool call structure is invalid
"""
import uuid as uuid_module
tool_name = tool_calls[yield_idx]["function"]["name"]
tool_call_id = tool_calls[yield_idx]["id"]
@@ -1762,8 +1757,6 @@ async def _generate_llm_continuation_with_streaming(
after a tool result is saved. Chunks are published to the stream registry
so reconnecting clients can receive them.
"""
import uuid as uuid_module
try:
# Load fresh session from DB (bypass cache to get the updated tool result)
await invalidate_session_cache(session_id)
@@ -1799,10 +1792,6 @@ async def _generate_llm_continuation_with_streaming(
extra_body["session_id"] = session_id[:128]
# Make streaming LLM call (no tools - just text response)
from typing import cast
from openai.types.chat import ChatCompletionMessageParam
# Generate unique IDs for AI SDK protocol
message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4())

View File

@@ -3,6 +3,8 @@
import logging
from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
@@ -28,6 +30,26 @@ from .models import (
logger = logging.getLogger(__name__)
class CreateAgentInput(BaseModel):
"""Input parameters for the create_agent tool."""
description: str = ""
context: str = ""
save: bool = True
# Internal async processing params (passed by long-running tool handler)
_operation_id: str | None = None
_task_id: str | None = None
@field_validator("description", "context", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str:
"""Strip whitespace from string fields."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class Config:
extra = "allow" # Allow _operation_id, _task_id from kwargs
class CreateAgentTool(BaseTool):
"""Tool for creating agents from natural language descriptions."""
@@ -85,7 +107,7 @@ class CreateAgentTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs,
**kwargs: Any,
) -> ToolResponseBase:
"""Execute the create_agent tool.
@@ -94,16 +116,14 @@ class CreateAgentTool(BaseTool):
2. Generate agent JSON (external service handles fixing and validation)
3. Preview or save based on the save parameter
"""
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
params = CreateAgentInput(**kwargs)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
# Extract async processing params
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not description:
if not params.description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
error="Missing description parameter",
@@ -115,7 +135,7 @@ class CreateAgentTool(BaseTool):
try:
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=description,
search_query=params.description,
include_marketplace=True,
)
logger.debug(
@@ -126,7 +146,7 @@ class CreateAgentTool(BaseTool):
try:
decomposition_result = await decompose_goal(
description, context, library_agents
params.description, params.context, library_agents
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
@@ -142,7 +162,7 @@ class CreateAgentTool(BaseTool):
return ErrorResponse(
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
error="decomposition_failed",
details={"description": description[:100]},
details={"description": params.description[:100]},
session_id=session_id,
)
@@ -158,7 +178,7 @@ class CreateAgentTool(BaseTool):
message=user_message,
error=f"decomposition_failed:{error_type}",
details={
"description": description[:100],
"description": params.description[:100],
"service_error": error_msg,
"error_type": error_type,
},
@@ -244,7 +264,7 @@ class CreateAgentTool(BaseTool):
return ErrorResponse(
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
error="generation_failed",
details={"description": description[:100]},
details={"description": params.description[:100]},
session_id=session_id,
)
@@ -266,7 +286,7 @@ class CreateAgentTool(BaseTool):
message=user_message,
error=f"generation_failed:{error_type}",
details={
"description": description[:100],
"description": params.description[:100],
"service_error": error_msg,
"error_type": error_type,
},
@@ -291,7 +311,7 @@ class CreateAgentTool(BaseTool):
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
if not save:
if not params.save:
return AgentPreviewResponse(
message=(
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "

View File

@@ -3,6 +3,8 @@
import logging
from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError
@@ -27,6 +29,23 @@ from .models import (
logger = logging.getLogger(__name__)
class CustomizeAgentInput(BaseModel):
"""Input parameters for the customize_agent tool."""
agent_id: str = ""
modifications: str = ""
context: str = ""
save: bool = True
@field_validator("agent_id", "modifications", "context", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str:
"""Strip whitespace from string fields."""
if isinstance(v, str):
return v.strip()
return v if v is not None else ""
class CustomizeAgentTool(BaseTool):
"""Tool for customizing marketplace/template agents using natural language."""
@@ -92,7 +111,7 @@ class CustomizeAgentTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs,
**kwargs: Any,
) -> ToolResponseBase:
"""Execute the customize_agent tool.
@@ -102,20 +121,17 @@ class CustomizeAgentTool(BaseTool):
3. Call customize_template with the modification request
4. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
params = CustomizeAgentInput(**kwargs)
session_id = session.session_id if session else None
if not agent_id:
if not params.agent_id:
return ErrorResponse(
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
error="missing_agent_id",
session_id=session_id,
)
if not modifications:
if not params.modifications:
return ErrorResponse(
message="Please describe how you want to customize this agent.",
error="missing_modifications",
@@ -123,11 +139,11 @@ class CustomizeAgentTool(BaseTool):
)
# Parse agent_id in format "creator/slug"
parts = [p.strip() for p in agent_id.split("/")]
parts = [p.strip() for p in params.agent_id.split("/")]
if len(parts) != 2 or not parts[0] or not parts[1]:
return ErrorResponse(
message=(
f"Invalid agent ID format: '{agent_id}'. "
f"Invalid agent ID format: '{params.agent_id}'. "
"Expected format is 'creator/agent-name' "
"(e.g., 'autogpt/newsletter-writer')."
),
@@ -145,14 +161,14 @@ class CustomizeAgentTool(BaseTool):
except AgentNotFoundError:
return ErrorResponse(
message=(
f"Could not find marketplace agent '{agent_id}'. "
f"Could not find marketplace agent '{params.agent_id}'. "
"Please check the agent ID and try again."
),
error="agent_not_found",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
logger.error(f"Error fetching marketplace agent {params.agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the marketplace agent. Please try again.",
error="fetch_error",
@@ -162,7 +178,7 @@ class CustomizeAgentTool(BaseTool):
if not agent_details.store_listing_version_id:
return ErrorResponse(
message=(
f"The agent '{agent_id}' does not have an available version. "
f"The agent '{params.agent_id}' does not have an available version. "
"Please try a different agent."
),
error="no_version_available",
@@ -174,7 +190,7 @@ class CustomizeAgentTool(BaseTool):
graph = await store_db.get_agent(agent_details.store_listing_version_id)
template_agent = graph_to_json(graph)
except Exception as e:
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
logger.error(f"Error fetching agent graph for {params.agent_id}: {e}")
return ErrorResponse(
message="Failed to fetch the agent configuration. Please try again.",
error="graph_fetch_error",
@@ -185,8 +201,8 @@ class CustomizeAgentTool(BaseTool):
try:
result = await customize_template(
template_agent=template_agent,
modification_request=modifications,
context=context,
modification_request=params.modifications,
context=params.context,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
@@ -198,7 +214,7 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id,
)
except Exception as e:
logger.error(f"Error calling customize_template for {agent_id}: {e}")
logger.error(f"Error calling customize_template for {params.agent_id}: {e}")
return ErrorResponse(
message=(
"Failed to customize the agent due to a service error. "
@@ -219,55 +235,25 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id,
)
# Handle error response
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
# Handle response using match/case for cleaner pattern matching
return await self._handle_customization_result(
result=result,
params=params,
agent_details=agent_details,
user_id=user_id,
session_id=session_id,
)
# Handle clarifying questions
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
questions = result.get("questions") or []
if not isinstance(questions, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions)}"
)
questions = []
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
if isinstance(q, dict)
],
session_id=session_id,
)
# Result should be the customized agent JSON
async def _handle_customization_result(
self,
result: dict[str, Any],
params: CustomizeAgentInput,
agent_details: Any,
user_id: str | None,
session_id: str | None,
) -> ToolResponseBase:
"""Handle the result from customize_template using pattern matching."""
# Ensure result is a dict
if not isinstance(result, dict):
logger.error(f"Unexpected customize_template response type: {type(result)}")
return ErrorResponse(
@@ -276,8 +262,77 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id,
)
customized_agent = result
result_type = result.get("type")
match result_type:
case "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
case "clarifying_questions":
questions_data = result.get("questions") or []
if not isinstance(questions_data, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions_data)}"
)
questions_data = []
questions = [
ClarifyingQuestion(
question=q.get("question", "") if isinstance(q, dict) else "",
keyword=q.get("keyword", "") if isinstance(q, dict) else "",
example=q.get("example") if isinstance(q, dict) else None,
)
for q in questions_data
if isinstance(q, dict)
]
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=questions,
session_id=session_id,
)
case _:
# Default case: result is the customized agent JSON
return await self._save_or_preview_agent(
customized_agent=result,
params=params,
agent_details=agent_details,
user_id=user_id,
session_id=session_id,
)
async def _save_or_preview_agent(
self,
customized_agent: dict[str, Any],
params: CustomizeAgentInput,
agent_details: Any,
user_id: str | None,
session_id: str | None,
) -> ToolResponseBase:
"""Save or preview the customized agent based on params.save."""
agent_name = customized_agent.get(
"name", f"Customized {agent_details.agent_name}"
)
@@ -287,7 +342,7 @@ class CustomizeAgentTool(BaseTool):
node_count = len(nodes) if isinstance(nodes, list) else 0
link_count = len(links) if isinstance(links, list) else 0
if not save:
if not params.save:
return AgentPreviewResponse(
message=(
f"I've customized the agent '{agent_details.agent_name}'. "

View File

@@ -3,6 +3,8 @@
import logging
from typing import Any
from pydantic import BaseModel, ConfigDict, field_validator
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
@@ -27,6 +29,20 @@ from .models import (
logger = logging.getLogger(__name__)
class EditAgentInput(BaseModel):
model_config = ConfigDict(extra="allow")
agent_id: str = ""
changes: str = ""
context: str = ""
save: bool = True
@field_validator("agent_id", "changes", "context", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str:
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class EditAgentTool(BaseTool):
"""Tool for editing existing agents using natural language."""
@@ -90,7 +106,7 @@ class EditAgentTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs,
**kwargs: Any,
) -> ToolResponseBase:
"""Execute the edit_agent tool.
@@ -99,35 +115,32 @@ class EditAgentTool(BaseTool):
2. Generate updated agent (external service handles fixing and validation)
3. Preview or save based on the save parameter
"""
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
params = EditAgentInput(**kwargs)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not agent_id:
if not params.agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
error="Missing agent_id parameter",
session_id=session_id,
)
if not changes:
if not params.changes:
return ErrorResponse(
message="Please describe what changes you want to make.",
error="Missing changes parameter",
session_id=session_id,
)
current_agent = await get_agent_as_json(agent_id, user_id)
current_agent = await get_agent_as_json(params.agent_id, user_id)
if current_agent is None:
return ErrorResponse(
message=f"Could not find agent with ID '{agent_id}' in your library.",
message=f"Could not find agent '{params.agent_id}' in your library.",
error="agent_not_found",
session_id=session_id,
)
@@ -138,7 +151,7 @@ class EditAgentTool(BaseTool):
graph_id = current_agent.get("id")
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=changes,
search_query=params.changes,
exclude_graph_id=graph_id,
include_marketplace=True,
)
@@ -148,9 +161,11 @@ class EditAgentTool(BaseTool):
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
update_request = changes
if context:
update_request = f"{changes}\n\nAdditional context:\n{context}"
update_request = params.changes
if params.context:
update_request = (
f"{params.changes}\n\nAdditional context:\n{params.context}"
)
try:
result = await generate_agent_patch(
@@ -174,7 +189,7 @@ class EditAgentTool(BaseTool):
return ErrorResponse(
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
error="update_generation_failed",
details={"agent_id": agent_id, "changes": changes[:100]},
details={"agent_id": params.agent_id, "changes": params.changes[:100]},
session_id=session_id,
)
@@ -206,8 +221,8 @@ class EditAgentTool(BaseTool):
message=user_message,
error=f"update_generation_failed:{error_type}",
details={
"agent_id": agent_id,
"changes": changes[:100],
"agent_id": params.agent_id,
"changes": params.changes[:100],
"service_error": error_msg,
"error_type": error_type,
},
@@ -239,7 +254,7 @@ class EditAgentTool(BaseTool):
node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", []))
if not save:
if not params.save:
return AgentPreviewResponse(
message=(
f"I've updated the agent. "

View File

@@ -2,6 +2,8 @@
from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
@@ -9,6 +11,18 @@ from .base import BaseTool
from .models import ToolResponseBase
class FindAgentInput(BaseModel):
"""Input parameters for the find_agent tool."""
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from query."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class FindAgentTool(BaseTool):
"""Tool for discovering agents from the marketplace."""
@@ -36,10 +50,11 @@ class FindAgentTool(BaseTool):
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
self, user_id: str | None, session: ChatSession, **kwargs: Any
) -> ToolResponseBase:
params = FindAgentInput(**kwargs)
return await search_agents(
query=kwargs.get("query", "").strip(),
query=params.query,
source="marketplace",
session_id=session.session_id,
user_id=user_id,

View File

@@ -2,6 +2,7 @@ import logging
from typing import Any
from prisma.enums import ContentType
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
@@ -18,6 +19,18 @@ from backend.data.block import get_block
logger = logging.getLogger(__name__)
class FindBlockInput(BaseModel):
"""Input parameters for the find_block tool."""
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from query."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class FindBlockTool(BaseTool):
"""Tool for searching available blocks."""
@@ -59,24 +72,24 @@ class FindBlockTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs,
**kwargs: Any,
) -> ToolResponseBase:
"""Search for blocks matching the query.
Args:
user_id: User ID (required)
session: Chat session
query: Search query
**kwargs: Tool parameters
Returns:
BlockListResponse: List of matching blocks
NoResultsResponse: No blocks found
ErrorResponse: Error message
"""
query = kwargs.get("query", "").strip()
params = FindBlockInput(**kwargs)
session_id = session.session_id
if not query:
if not params.query:
return ErrorResponse(
message="Please provide a search query",
session_id=session_id,
@@ -85,7 +98,7 @@ class FindBlockTool(BaseTool):
try:
# Search for blocks using hybrid search
results, total = await unified_hybrid_search(
query=query,
query=params.query,
content_types=[ContentType.BLOCK],
page=1,
page_size=10,
@@ -93,7 +106,7 @@ class FindBlockTool(BaseTool):
if not results:
return NoResultsResponse(
message=f"No blocks found for '{query}'",
message=f"No blocks found for '{params.query}'",
suggestions=[
"Try broader keywords like 'email', 'http', 'text', 'ai'",
"Check spelling of technical terms",
@@ -165,7 +178,7 @@ class FindBlockTool(BaseTool):
if not blocks:
return NoResultsResponse(
message=f"No blocks found for '{query}'",
message=f"No blocks found for '{params.query}'",
suggestions=[
"Try broader keywords like 'email', 'http', 'text', 'ai'",
],
@@ -174,13 +187,13 @@ class FindBlockTool(BaseTool):
return BlockListResponse(
message=(
f"Found {len(blocks)} block(s) matching '{query}'. "
f"Found {len(blocks)} block(s) matching '{params.query}'. "
"To execute a block, use run_block with the block's 'id' field "
"and provide 'input_data' matching the block's input_schema."
),
blocks=blocks,
count=len(blocks),
query=query,
query=params.query,
session_id=session_id,
)

View File

@@ -2,6 +2,8 @@
from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
@@ -9,6 +11,15 @@ from .base import BaseTool
from .models import ToolResponseBase
class FindLibraryAgentInput(BaseModel):
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class FindLibraryAgentTool(BaseTool):
"""Tool for searching agents in the user's library."""
@@ -42,10 +53,11 @@ class FindLibraryAgentTool(BaseTool):
return True
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs
self, user_id: str | None, session: ChatSession, **kwargs: Any
) -> ToolResponseBase:
params = FindLibraryAgentInput(**kwargs)
return await search_agents(
query=kwargs.get("query", "").strip(),
query=params.query,
source="library",
session_id=session.session_id,
user_id=user_id,

View File

@@ -4,6 +4,8 @@ import logging
from pathlib import Path
from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
@@ -18,6 +20,18 @@ logger = logging.getLogger(__name__)
DOCS_BASE_URL = "https://docs.agpt.co"
class GetDocPageInput(BaseModel):
"""Input parameters for the get_doc_page tool."""
path: str = ""
@field_validator("path", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from path."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class GetDocPageTool(BaseTool):
"""Tool for fetching full content of a documentation page."""
@@ -75,23 +89,23 @@ class GetDocPageTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs,
**kwargs: Any,
) -> ToolResponseBase:
"""Fetch full content of a documentation page.
Args:
user_id: User ID (not required for docs)
session: Chat session
path: Path to the documentation file
**kwargs: Tool parameters
Returns:
DocPageResponse: Full document content
ErrorResponse: Error message
"""
path = kwargs.get("path", "").strip()
params = GetDocPageInput(**kwargs)
session_id = session.session_id if session else None
if not path:
if not params.path:
return ErrorResponse(
message="Please provide a documentation path.",
error="Missing path parameter",
@@ -99,7 +113,7 @@ class GetDocPageTool(BaseTool):
)
# Sanitize path to prevent directory traversal
if ".." in path or path.startswith("/"):
if ".." in params.path or params.path.startswith("/"):
return ErrorResponse(
message="Invalid documentation path.",
error="invalid_path",
@@ -107,11 +121,11 @@ class GetDocPageTool(BaseTool):
)
docs_root = self._get_docs_root()
full_path = docs_root / path
full_path = docs_root / params.path
if not full_path.exists():
return ErrorResponse(
message=f"Documentation page not found: {path}",
message=f"Documentation page not found: {params.path}",
error="not_found",
session_id=session_id,
)
@@ -128,19 +142,19 @@ class GetDocPageTool(BaseTool):
try:
content = full_path.read_text(encoding="utf-8")
title = self._extract_title(content, path)
title = self._extract_title(content, params.path)
return DocPageResponse(
message=f"Retrieved documentation page: {title}",
title=title,
path=path,
path=params.path,
content=content,
doc_url=self._make_doc_url(path),
doc_url=self._make_doc_url(params.path),
session_id=session_id,
)
except Exception as e:
logger.error(f"Failed to read documentation page {path}: {e}")
logger.error(f"Failed to read documentation page {params.path}: {e}")
return ErrorResponse(
message=f"Failed to read documentation page: {str(e)}",
error="read_failed",

View File

@@ -5,6 +5,7 @@ import uuid
from collections import defaultdict
from typing import Any
from pydantic import BaseModel, field_validator
from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession
@@ -29,6 +30,25 @@ from .utils import build_missing_credentials_from_field_info
logger = logging.getLogger(__name__)
class RunBlockInput(BaseModel):
"""Input parameters for the run_block tool."""
block_id: str = ""
input_data: dict[str, Any] = {}
@field_validator("block_id", mode="before")
@classmethod
def strip_block_id(cls, v: Any) -> str:
"""Strip whitespace from block_id."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
@field_validator("input_data", mode="before")
@classmethod
def ensure_dict(cls, v: Any) -> dict[str, Any]:
"""Ensure input_data is a dict."""
return v if isinstance(v, dict) else {}
class RunBlockTool(BaseTool):
"""Tool for executing a block and returning its outputs."""
@@ -162,37 +182,29 @@ class RunBlockTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs,
**kwargs: Any,
) -> ToolResponseBase:
"""Execute a block with the given input data.
Args:
user_id: User ID (required)
session: Chat session
block_id: Block UUID to execute
input_data: Input values for the block
**kwargs: Tool parameters
Returns:
BlockOutputResponse: Block execution outputs
SetupRequirementsResponse: Missing credentials
ErrorResponse: Error message
"""
block_id = kwargs.get("block_id", "").strip()
input_data = kwargs.get("input_data", {})
params = RunBlockInput(**kwargs)
session_id = session.session_id
if not block_id:
if not params.block_id:
return ErrorResponse(
message="Please provide a block_id",
session_id=session_id,
)
if not isinstance(input_data, dict):
return ErrorResponse(
message="input_data must be an object",
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="Authentication required",
@@ -200,23 +212,25 @@ class RunBlockTool(BaseTool):
)
# Get the block
block = get_block(block_id)
block = get_block(params.block_id)
if not block:
return ErrorResponse(
message=f"Block '{block_id}' not found",
message=f"Block '{params.block_id}' not found",
session_id=session_id,
)
if block.disabled:
return ErrorResponse(
message=f"Block '{block_id}' is disabled",
message=f"Block '{params.block_id}' is disabled",
session_id=session_id,
)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
logger.info(
f"Executing block {block.name} ({params.block_id}) for user {user_id}"
)
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block, input_data
user_id, block, params.input_data
)
if missing_credentials:
@@ -234,7 +248,7 @@ class RunBlockTool(BaseTool):
),
session_id=session_id,
setup_info=SetupInfo(
agent_id=block_id,
agent_id=params.block_id,
agent_name=block.name,
user_readiness=UserReadiness(
has_all_credentials=False,
@@ -263,7 +277,7 @@ class RunBlockTool(BaseTool):
# - node_exec_id = unique per block execution
synthetic_graph_id = f"copilot-session-{session.session_id}"
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
synthetic_node_id = f"copilot-node-{block_id}"
synthetic_node_id = f"copilot-node-{params.block_id}"
synthetic_node_exec_id = (
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
)
@@ -298,8 +312,8 @@ class RunBlockTool(BaseTool):
for field_name, cred_meta in matched_credentials.items():
# Inject metadata into input_data (for validation)
if field_name not in input_data:
input_data[field_name] = cred_meta.model_dump()
if field_name not in params.input_data:
params.input_data[field_name] = cred_meta.model_dump()
# Fetch actual credentials and pass as kwargs (for execution)
actual_credentials = await creds_manager.get(
@@ -316,14 +330,14 @@ class RunBlockTool(BaseTool):
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
input_data,
params.input_data,
**exec_kwargs,
):
outputs[output_name].append(output_data)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=block_id,
block_id=params.block_id,
block_name=block.name,
outputs=dict(outputs),
success=True,

View File

@@ -4,6 +4,7 @@ import logging
from typing import Any
from prisma.enums import ContentType
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
@@ -28,6 +29,18 @@ MAX_RESULTS = 5
SNIPPET_LENGTH = 200
class SearchDocsInput(BaseModel):
"""Input parameters for the search_docs tool."""
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from query."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class SearchDocsTool(BaseTool):
"""Tool for searching AutoGPT platform documentation."""
@@ -91,24 +104,24 @@ class SearchDocsTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs,
**kwargs: Any,
) -> ToolResponseBase:
"""Search documentation and return relevant sections.
Args:
user_id: User ID (not required for docs)
session: Chat session
query: Search query
**kwargs: Tool parameters
Returns:
DocSearchResultsResponse: List of matching documentation sections
NoResultsResponse: No results found
ErrorResponse: Error message
"""
query = kwargs.get("query", "").strip()
params = SearchDocsInput(**kwargs)
session_id = session.session_id if session else None
if not query:
if not params.query:
return ErrorResponse(
message="Please provide a search query.",
error="Missing query parameter",
@@ -118,7 +131,7 @@ class SearchDocsTool(BaseTool):
try:
# Search using hybrid search for DOCUMENTATION content type only
results, total = await unified_hybrid_search(
query=query,
query=params.query,
content_types=[ContentType.DOCUMENTATION],
page=1,
page_size=MAX_RESULTS * 2, # Fetch extra for deduplication
@@ -127,7 +140,7 @@ class SearchDocsTool(BaseTool):
if not results:
return NoResultsResponse(
message=f"No documentation found for '{query}'.",
message=f"No documentation found for '{params.query}'.",
suggestions=[
"Try different keywords",
"Use more general terms",
@@ -162,7 +175,7 @@ class SearchDocsTool(BaseTool):
if not deduplicated:
return NoResultsResponse(
message=f"No documentation found for '{query}'.",
message=f"No documentation found for '{params.query}'.",
suggestions=[
"Try different keywords",
"Use more general terms",
@@ -195,7 +208,7 @@ class SearchDocsTool(BaseTool):
message=f"Found {len(doc_results)} relevant documentation sections.",
results=doc_results,
count=len(doc_results),
query=query,
query=params.query,
session_id=session_id,
)

View File

@@ -2,9 +2,9 @@
import base64
import logging
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
@@ -78,6 +78,65 @@ class WorkspaceDeleteResponse(ToolResponseBase):
success: bool
# Input models for workspace tools
class ListWorkspaceFilesInput(BaseModel):
"""Input parameters for list_workspace_files tool."""
path_prefix: str | None = None
limit: int = 50
include_all_sessions: bool = False
@field_validator("path_prefix", mode="before")
@classmethod
def strip_path(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class ReadWorkspaceFileInput(BaseModel):
"""Input parameters for read_workspace_file tool."""
file_id: str | None = None
path: str | None = None
force_download_url: bool = False
@field_validator("file_id", "path", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class WriteWorkspaceFileInput(BaseModel):
"""Input parameters for write_workspace_file tool."""
filename: str = ""
content_base64: str = ""
path: str | None = None
mime_type: str | None = None
overwrite: bool = False
@field_validator("filename", "content_base64", mode="before")
@classmethod
def strip_required(cls, v: Any) -> str:
return v.strip() if isinstance(v, str) else (v if v is not None else "")
@field_validator("path", "mime_type", mode="before")
@classmethod
def strip_optional(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class DeleteWorkspaceFileInput(BaseModel):
"""Input parameters for delete_workspace_file tool."""
file_id: str | None = None
path: str | None = None
@field_validator("file_id", "path", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class ListWorkspaceFilesTool(BaseTool):
"""Tool for listing files in user's workspace."""
@@ -131,8 +190,9 @@ class ListWorkspaceFilesTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs,
**kwargs: Any,
) -> ToolResponseBase:
params = ListWorkspaceFilesInput(**kwargs)
session_id = session.session_id
if not user_id:
@@ -141,9 +201,7 @@ class ListWorkspaceFilesTool(BaseTool):
session_id=session_id,
)
path_prefix: Optional[str] = kwargs.get("path_prefix")
limit = min(kwargs.get("limit", 50), 100)
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
limit = min(params.limit, 100)
try:
workspace = await get_or_create_workspace(user_id)
@@ -151,13 +209,13 @@ class ListWorkspaceFilesTool(BaseTool):
manager = WorkspaceManager(user_id, workspace.id, session_id)
files = await manager.list_files(
path=path_prefix,
path=params.path_prefix,
limit=limit,
include_all_sessions=include_all_sessions,
include_all_sessions=params.include_all_sessions,
)
total = await manager.get_file_count(
path=path_prefix,
include_all_sessions=include_all_sessions,
path=params.path_prefix,
include_all_sessions=params.include_all_sessions,
)
file_infos = [
@@ -171,7 +229,9 @@ class ListWorkspaceFilesTool(BaseTool):
for f in files
]
scope_msg = "all sessions" if include_all_sessions else "current session"
scope_msg = (
"all sessions" if params.include_all_sessions else "current session"
)
return WorkspaceFileListResponse(
files=file_infos,
total_count=total,
@@ -259,8 +319,9 @@ class ReadWorkspaceFileTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs,
**kwargs: Any,
) -> ToolResponseBase:
params = ReadWorkspaceFileInput(**kwargs)
session_id = session.session_id
if not user_id:
@@ -269,11 +330,7 @@ class ReadWorkspaceFileTool(BaseTool):
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
force_download_url: bool = kwargs.get("force_download_url", False)
if not file_id and not path:
if not params.file_id and not params.path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
@@ -285,21 +342,21 @@ class ReadWorkspaceFileTool(BaseTool):
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Get file info
if file_id:
file_info = await manager.get_file_info(file_id)
if params.file_id:
file_info = await manager.get_file_info(params.file_id)
if file_info is None:
return ErrorResponse(
message=f"File not found: {file_id}",
message=f"File not found: {params.file_id}",
session_id=session_id,
)
target_file_id = file_id
target_file_id = params.file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
assert params.path is not None
file_info = await manager.get_file_info_by_path(params.path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
message=f"File not found at path: {params.path}",
session_id=session_id,
)
target_file_id = file_info.id
@@ -309,7 +366,7 @@ class ReadWorkspaceFileTool(BaseTool):
is_text_file = self._is_text_mime_type(file_info.mimeType)
# Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not force_download_url:
if is_small_file and is_text_file and not params.force_download_url:
content = await manager.read_file_by_id(target_file_id)
content_b64 = base64.b64encode(content).decode("utf-8")
@@ -429,8 +486,9 @@ class WriteWorkspaceFileTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs,
**kwargs: Any,
) -> ToolResponseBase:
params = WriteWorkspaceFileInput(**kwargs)
session_id = session.session_id
if not user_id:
@@ -439,19 +497,13 @@ class WriteWorkspaceFileTool(BaseTool):
session_id=session_id,
)
filename: str = kwargs.get("filename", "")
content_b64: str = kwargs.get("content_base64", "")
path: Optional[str] = kwargs.get("path")
mime_type: Optional[str] = kwargs.get("mime_type")
overwrite: bool = kwargs.get("overwrite", False)
if not filename:
if not params.filename:
return ErrorResponse(
message="Please provide a filename",
session_id=session_id,
)
if not content_b64:
if not params.content_base64:
return ErrorResponse(
message="Please provide content_base64",
session_id=session_id,
@@ -459,7 +511,7 @@ class WriteWorkspaceFileTool(BaseTool):
# Decode content
try:
content = base64.b64decode(content_b64)
content = base64.b64decode(params.content_base64)
except Exception:
return ErrorResponse(
message="Invalid base64-encoded content",
@@ -476,7 +528,7 @@ class WriteWorkspaceFileTool(BaseTool):
try:
# Virus scan
await scan_content_safe(content, filename=filename)
await scan_content_safe(content, filename=params.filename)
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
@@ -484,10 +536,10 @@ class WriteWorkspaceFileTool(BaseTool):
file_record = await manager.write_file(
content=content,
filename=filename,
path=path,
mime_type=mime_type,
overwrite=overwrite,
filename=params.filename,
path=params.path,
mime_type=params.mime_type,
overwrite=params.overwrite,
)
return WorkspaceWriteResponse(
@@ -557,8 +609,9 @@ class DeleteWorkspaceFileTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs,
**kwargs: Any,
) -> ToolResponseBase:
params = DeleteWorkspaceFileInput(**kwargs)
session_id = session.session_id
if not user_id:
@@ -567,10 +620,7 @@ class DeleteWorkspaceFileTool(BaseTool):
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
if not file_id and not path:
if not params.file_id and not params.path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
@@ -583,15 +633,15 @@ class DeleteWorkspaceFileTool(BaseTool):
# Determine the file_id to delete
target_file_id: str
if file_id:
target_file_id = file_id
if params.file_id:
target_file_id = params.file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
assert params.path is not None
file_info = await manager.get_file_info_by_path(params.path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
message=f"File not found at path: {params.path}",
session_id=session_id,
)
target_file_id = file_info.id