Compare commits

...

4 Commits

Author SHA1 Message Date
Otto
a093d57ed2 fix: address CodeRabbit review bugs
- customize_agent.py: Strip whitespace from split parts of agent_id
- edit_agent.py: Use model_config instead of deprecated class Config
- edit_agent.py: Fix undefined agent_id/changes → params.agent_id/params.changes
- find_library_agent.py: Remove docstrings per coding guidelines
- get_doc_page.py: Fix undefined path → params.path
- run_block.py: Fix undefined block_id → params.block_id
- workspace_files.py: Fix undefined include_all_sessions → params.include_all_sessions
2026-02-04 09:17:44 +00:00
Otto
6692f39cbd refactor(copilot): add Pydantic input models to all tools
Convert all CoPilot tools from kwargs.get() pattern to Pydantic models:

Tools updated:
- find_agent.py: FindAgentInput
- find_library_agent.py: FindLibraryAgentInput
- find_block.py: FindBlockInput
- search_docs.py: SearchDocsInput
- get_doc_page.py: GetDocPageInput
- create_agent.py: CreateAgentInput
- edit_agent.py: EditAgentInput
- run_block.py: RunBlockInput
- workspace_files.py: 4 input models (List/Read/Write/Delete)

Benefits:
- Type safety with automatic validation
- Consistent string stripping via field_validators
- Better IDE support and error messages
- Cleaner _execute methods using params object

Addresses ntindle review feedback about kwargs pattern.
2026-02-04 09:05:18 +00:00
Otto
aeba28266c refactor(copilot): use Pydantic models and match/case in customize_agent
Addresses review feedback from ntindle on PR #11943:

1. Use typed parameters instead of kwargs.get():
   - Added CustomizeAgentInput Pydantic model with field_validator
   - Tool now uses params = CustomizeAgentInput(**kwargs) pattern

2. Use match/case for cleaner pattern matching:
   - Extracted response handling to _handle_customization_result method
   - Uses match result_type: case 'error' | 'clarifying_questions' | _

3. Improved code organization:
   - Split monolithic _execute into smaller focused methods
   - _handle_customization_result for response type handling
   - _save_or_preview_agent for final save/preview logic
2026-02-04 08:54:27 +00:00
Otto
6d8c83c039 refactor(backend): move local imports to module level in chat service
Addresses review feedback from PRs #11937, #11856:
- Move uuid import to top level (was duplicated in 3 functions)
- Move compress_context import to top level
- Remove redundant local imports for cast and ChatCompletionMessageParam
  (already imported at module level)

Refs:
- https://github.com/Significant-Gravitas/AutoGPT/pull/11937#discussion_r2761107861
- https://github.com/Significant-Gravitas/AutoGPT/pull/11856#discussion_r2761558008
- https://github.com/Significant-Gravitas/AutoGPT/pull/11856#discussion_r2761559661
2026-02-04 03:33:15 +00:00
11 changed files with 425 additions and 215 deletions

View File

@@ -1,12 +1,15 @@
import asyncio import asyncio
import logging import logging
import time import time
import uuid as uuid_module
from asyncio import CancelledError from asyncio import CancelledError
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
import openai import openai
from backend.util.prompt import compress_context
if TYPE_CHECKING: if TYPE_CHECKING:
from backend.util.prompt import CompressResult from backend.util.prompt import CompressResult
@@ -467,8 +470,6 @@ async def stream_chat_completion(
should_retry = False should_retry = False
# Generate unique IDs for AI SDK protocol # Generate unique IDs for AI SDK protocol
import uuid as uuid_module
message_id = str(uuid_module.uuid4()) message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4()) text_block_id = str(uuid_module.uuid4())
@@ -826,10 +827,6 @@ async def _manage_context_window(
Returns: Returns:
CompressResult with compacted messages and metadata CompressResult with compacted messages and metadata
""" """
import openai
from backend.util.prompt import compress_context
# Convert messages to dict format # Convert messages to dict format
messages_dict = [] messages_dict = []
for msg in messages: for msg in messages:
@@ -1140,8 +1137,6 @@ async def _yield_tool_call(
KeyError: If expected tool call fields are missing KeyError: If expected tool call fields are missing
TypeError: If tool call structure is invalid TypeError: If tool call structure is invalid
""" """
import uuid as uuid_module
tool_name = tool_calls[yield_idx]["function"]["name"] tool_name = tool_calls[yield_idx]["function"]["name"]
tool_call_id = tool_calls[yield_idx]["id"] tool_call_id = tool_calls[yield_idx]["id"]
@@ -1762,8 +1757,6 @@ async def _generate_llm_continuation_with_streaming(
after a tool result is saved. Chunks are published to the stream registry after a tool result is saved. Chunks are published to the stream registry
so reconnecting clients can receive them. so reconnecting clients can receive them.
""" """
import uuid as uuid_module
try: try:
# Load fresh session from DB (bypass cache to get the updated tool result) # Load fresh session from DB (bypass cache to get the updated tool result)
await invalidate_session_cache(session_id) await invalidate_session_cache(session_id)
@@ -1799,10 +1792,6 @@ async def _generate_llm_continuation_with_streaming(
extra_body["session_id"] = session_id[:128] extra_body["session_id"] = session_id[:128]
# Make streaming LLM call (no tools - just text response) # Make streaming LLM call (no tools - just text response)
from typing import cast
from openai.types.chat import ChatCompletionMessageParam
# Generate unique IDs for AI SDK protocol # Generate unique IDs for AI SDK protocol
message_id = str(uuid_module.uuid4()) message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4()) text_block_id = str(uuid_module.uuid4())

View File

@@ -3,6 +3,8 @@
import logging import logging
from typing import Any from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from .agent_generator import ( from .agent_generator import (
@@ -28,6 +30,26 @@ from .models import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CreateAgentInput(BaseModel):
"""Input parameters for the create_agent tool."""
description: str = ""
context: str = ""
save: bool = True
# Internal async processing params (passed by long-running tool handler)
_operation_id: str | None = None
_task_id: str | None = None
@field_validator("description", "context", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str:
"""Strip whitespace from string fields."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class Config:
extra = "allow" # Allow _operation_id, _task_id from kwargs
class CreateAgentTool(BaseTool): class CreateAgentTool(BaseTool):
"""Tool for creating agents from natural language descriptions.""" """Tool for creating agents from natural language descriptions."""
@@ -85,7 +107,7 @@ class CreateAgentTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs, **kwargs: Any,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Execute the create_agent tool. """Execute the create_agent tool.
@@ -94,16 +116,14 @@ class CreateAgentTool(BaseTool):
2. Generate agent JSON (external service handles fixing and validation) 2. Generate agent JSON (external service handles fixing and validation)
3. Preview or save based on the save parameter 3. Preview or save based on the save parameter
""" """
description = kwargs.get("description", "").strip() params = CreateAgentInput(**kwargs)
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler) # Extract async processing params
operation_id = kwargs.get("_operation_id") operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id") task_id = kwargs.get("_task_id")
if not description: if not params.description:
return ErrorResponse( return ErrorResponse(
message="Please provide a description of what the agent should do.", message="Please provide a description of what the agent should do.",
error="Missing description parameter", error="Missing description parameter",
@@ -115,7 +135,7 @@ class CreateAgentTool(BaseTool):
try: try:
library_agents = await get_all_relevant_agents_for_generation( library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id, user_id=user_id,
search_query=description, search_query=params.description,
include_marketplace=True, include_marketplace=True,
) )
logger.debug( logger.debug(
@@ -126,7 +146,7 @@ class CreateAgentTool(BaseTool):
try: try:
decomposition_result = await decompose_goal( decomposition_result = await decompose_goal(
description, context, library_agents params.description, params.context, library_agents
) )
except AgentGeneratorNotConfiguredError: except AgentGeneratorNotConfiguredError:
return ErrorResponse( return ErrorResponse(
@@ -142,7 +162,7 @@ class CreateAgentTool(BaseTool):
return ErrorResponse( return ErrorResponse(
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.", message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
error="decomposition_failed", error="decomposition_failed",
details={"description": description[:100]}, details={"description": params.description[:100]},
session_id=session_id, session_id=session_id,
) )
@@ -158,7 +178,7 @@ class CreateAgentTool(BaseTool):
message=user_message, message=user_message,
error=f"decomposition_failed:{error_type}", error=f"decomposition_failed:{error_type}",
details={ details={
"description": description[:100], "description": params.description[:100],
"service_error": error_msg, "service_error": error_msg,
"error_type": error_type, "error_type": error_type,
}, },
@@ -244,7 +264,7 @@ class CreateAgentTool(BaseTool):
return ErrorResponse( return ErrorResponse(
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.", message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
error="generation_failed", error="generation_failed",
details={"description": description[:100]}, details={"description": params.description[:100]},
session_id=session_id, session_id=session_id,
) )
@@ -266,7 +286,7 @@ class CreateAgentTool(BaseTool):
message=user_message, message=user_message,
error=f"generation_failed:{error_type}", error=f"generation_failed:{error_type}",
details={ details={
"description": description[:100], "description": params.description[:100],
"service_error": error_msg, "service_error": error_msg,
"error_type": error_type, "error_type": error_type,
}, },
@@ -291,7 +311,7 @@ class CreateAgentTool(BaseTool):
node_count = len(agent_json.get("nodes", [])) node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", [])) link_count = len(agent_json.get("links", []))
if not save: if not params.save:
return AgentPreviewResponse( return AgentPreviewResponse(
message=( message=(
f"I've generated an agent called '{agent_name}' with {node_count} blocks. " f"I've generated an agent called '{agent_name}' with {node_count} blocks. "

View File

@@ -3,6 +3,8 @@
import logging import logging
from typing import Any from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError from backend.api.features.store.exceptions import AgentNotFoundError
@@ -27,6 +29,23 @@ from .models import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CustomizeAgentInput(BaseModel):
"""Input parameters for the customize_agent tool."""
agent_id: str = ""
modifications: str = ""
context: str = ""
save: bool = True
@field_validator("agent_id", "modifications", "context", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str:
"""Strip whitespace from string fields."""
if isinstance(v, str):
return v.strip()
return v if v is not None else ""
class CustomizeAgentTool(BaseTool): class CustomizeAgentTool(BaseTool):
"""Tool for customizing marketplace/template agents using natural language.""" """Tool for customizing marketplace/template agents using natural language."""
@@ -92,7 +111,7 @@ class CustomizeAgentTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs, **kwargs: Any,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Execute the customize_agent tool. """Execute the customize_agent tool.
@@ -102,20 +121,17 @@ class CustomizeAgentTool(BaseTool):
3. Call customize_template with the modification request 3. Call customize_template with the modification request
4. Preview or save based on the save parameter 4. Preview or save based on the save parameter
""" """
agent_id = kwargs.get("agent_id", "").strip() params = CustomizeAgentInput(**kwargs)
modifications = kwargs.get("modifications", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None session_id = session.session_id if session else None
if not agent_id: if not params.agent_id:
return ErrorResponse( return ErrorResponse(
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').", message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
error="missing_agent_id", error="missing_agent_id",
session_id=session_id, session_id=session_id,
) )
if not modifications: if not params.modifications:
return ErrorResponse( return ErrorResponse(
message="Please describe how you want to customize this agent.", message="Please describe how you want to customize this agent.",
error="missing_modifications", error="missing_modifications",
@@ -123,11 +139,11 @@ class CustomizeAgentTool(BaseTool):
) )
# Parse agent_id in format "creator/slug" # Parse agent_id in format "creator/slug"
parts = [p.strip() for p in agent_id.split("/")] parts = [p.strip() for p in params.agent_id.split("/")]
if len(parts) != 2 or not parts[0] or not parts[1]: if len(parts) != 2 or not parts[0] or not parts[1]:
return ErrorResponse( return ErrorResponse(
message=( message=(
f"Invalid agent ID format: '{agent_id}'. " f"Invalid agent ID format: '{params.agent_id}'. "
"Expected format is 'creator/agent-name' " "Expected format is 'creator/agent-name' "
"(e.g., 'autogpt/newsletter-writer')." "(e.g., 'autogpt/newsletter-writer')."
), ),
@@ -145,14 +161,14 @@ class CustomizeAgentTool(BaseTool):
except AgentNotFoundError: except AgentNotFoundError:
return ErrorResponse( return ErrorResponse(
message=( message=(
f"Could not find marketplace agent '{agent_id}'. " f"Could not find marketplace agent '{params.agent_id}'. "
"Please check the agent ID and try again." "Please check the agent ID and try again."
), ),
error="agent_not_found", error="agent_not_found",
session_id=session_id, session_id=session_id,
) )
except Exception as e: except Exception as e:
logger.error(f"Error fetching marketplace agent {agent_id}: {e}") logger.error(f"Error fetching marketplace agent {params.agent_id}: {e}")
return ErrorResponse( return ErrorResponse(
message="Failed to fetch the marketplace agent. Please try again.", message="Failed to fetch the marketplace agent. Please try again.",
error="fetch_error", error="fetch_error",
@@ -162,7 +178,7 @@ class CustomizeAgentTool(BaseTool):
if not agent_details.store_listing_version_id: if not agent_details.store_listing_version_id:
return ErrorResponse( return ErrorResponse(
message=( message=(
f"The agent '{agent_id}' does not have an available version. " f"The agent '{params.agent_id}' does not have an available version. "
"Please try a different agent." "Please try a different agent."
), ),
error="no_version_available", error="no_version_available",
@@ -174,7 +190,7 @@ class CustomizeAgentTool(BaseTool):
graph = await store_db.get_agent(agent_details.store_listing_version_id) graph = await store_db.get_agent(agent_details.store_listing_version_id)
template_agent = graph_to_json(graph) template_agent = graph_to_json(graph)
except Exception as e: except Exception as e:
logger.error(f"Error fetching agent graph for {agent_id}: {e}") logger.error(f"Error fetching agent graph for {params.agent_id}: {e}")
return ErrorResponse( return ErrorResponse(
message="Failed to fetch the agent configuration. Please try again.", message="Failed to fetch the agent configuration. Please try again.",
error="graph_fetch_error", error="graph_fetch_error",
@@ -185,8 +201,8 @@ class CustomizeAgentTool(BaseTool):
try: try:
result = await customize_template( result = await customize_template(
template_agent=template_agent, template_agent=template_agent,
modification_request=modifications, modification_request=params.modifications,
context=context, context=params.context,
) )
except AgentGeneratorNotConfiguredError: except AgentGeneratorNotConfiguredError:
return ErrorResponse( return ErrorResponse(
@@ -198,7 +214,7 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
except Exception as e: except Exception as e:
logger.error(f"Error calling customize_template for {agent_id}: {e}") logger.error(f"Error calling customize_template for {params.agent_id}: {e}")
return ErrorResponse( return ErrorResponse(
message=( message=(
"Failed to customize the agent due to a service error. " "Failed to customize the agent due to a service error. "
@@ -219,55 +235,25 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
# Handle error response # Handle response using match/case for cleaner pattern matching
if isinstance(result, dict) and result.get("type") == "error": return await self._handle_customization_result(
error_msg = result.get("error", "Unknown error") result=result,
error_type = result.get("error_type", "unknown") params=params,
user_message = get_user_message_for_error( agent_details=agent_details,
error_type, user_id=user_id,
operation="customize the agent", session_id=session_id,
llm_parse_message=( )
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
# Handle clarifying questions async def _handle_customization_result(
if isinstance(result, dict) and result.get("type") == "clarifying_questions": self,
questions = result.get("questions") or [] result: dict[str, Any],
if not isinstance(questions, list): params: CustomizeAgentInput,
logger.error( agent_details: Any,
f"Unexpected clarifying questions format: {type(questions)}" user_id: str | None,
) session_id: str | None,
questions = [] ) -> ToolResponseBase:
return ClarificationNeededResponse( """Handle the result from customize_template using pattern matching."""
message=( # Ensure result is a dict
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=[
ClarifyingQuestion(
question=q.get("question", ""),
keyword=q.get("keyword", ""),
example=q.get("example"),
)
for q in questions
if isinstance(q, dict)
],
session_id=session_id,
)
# Result should be the customized agent JSON
if not isinstance(result, dict): if not isinstance(result, dict):
logger.error(f"Unexpected customize_template response type: {type(result)}") logger.error(f"Unexpected customize_template response type: {type(result)}")
return ErrorResponse( return ErrorResponse(
@@ -276,8 +262,77 @@ class CustomizeAgentTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
customized_agent = result result_type = result.get("type")
match result_type:
case "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
user_message = get_user_message_for_error(
error_type,
operation="customize the agent",
llm_parse_message=(
"The AI had trouble customizing the agent. "
"Please try again or simplify your request."
),
validation_message=(
"The customized agent failed validation. "
"Please try rephrasing your request."
),
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
error=f"customization_failed:{error_type}",
session_id=session_id,
)
case "clarifying_questions":
questions_data = result.get("questions") or []
if not isinstance(questions_data, list):
logger.error(
f"Unexpected clarifying questions format: {type(questions_data)}"
)
questions_data = []
questions = [
ClarifyingQuestion(
question=q.get("question", "") if isinstance(q, dict) else "",
keyword=q.get("keyword", "") if isinstance(q, dict) else "",
example=q.get("example") if isinstance(q, dict) else None,
)
for q in questions_data
if isinstance(q, dict)
]
return ClarificationNeededResponse(
message=(
"I need some more information to customize this agent. "
"Please answer the following questions:"
),
questions=questions,
session_id=session_id,
)
case _:
# Default case: result is the customized agent JSON
return await self._save_or_preview_agent(
customized_agent=result,
params=params,
agent_details=agent_details,
user_id=user_id,
session_id=session_id,
)
async def _save_or_preview_agent(
self,
customized_agent: dict[str, Any],
params: CustomizeAgentInput,
agent_details: Any,
user_id: str | None,
session_id: str | None,
) -> ToolResponseBase:
"""Save or preview the customized agent based on params.save."""
agent_name = customized_agent.get( agent_name = customized_agent.get(
"name", f"Customized {agent_details.agent_name}" "name", f"Customized {agent_details.agent_name}"
) )
@@ -287,7 +342,7 @@ class CustomizeAgentTool(BaseTool):
node_count = len(nodes) if isinstance(nodes, list) else 0 node_count = len(nodes) if isinstance(nodes, list) else 0
link_count = len(links) if isinstance(links, list) else 0 link_count = len(links) if isinstance(links, list) else 0
if not save: if not params.save:
return AgentPreviewResponse( return AgentPreviewResponse(
message=( message=(
f"I've customized the agent '{agent_details.agent_name}'. " f"I've customized the agent '{agent_details.agent_name}'. "

View File

@@ -3,6 +3,8 @@
import logging import logging
from typing import Any from typing import Any
from pydantic import BaseModel, ConfigDict, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from .agent_generator import ( from .agent_generator import (
@@ -27,6 +29,20 @@ from .models import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EditAgentInput(BaseModel):
model_config = ConfigDict(extra="allow")
agent_id: str = ""
changes: str = ""
context: str = ""
save: bool = True
@field_validator("agent_id", "changes", "context", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str:
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class EditAgentTool(BaseTool): class EditAgentTool(BaseTool):
"""Tool for editing existing agents using natural language.""" """Tool for editing existing agents using natural language."""
@@ -90,7 +106,7 @@ class EditAgentTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs, **kwargs: Any,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Execute the edit_agent tool. """Execute the edit_agent tool.
@@ -99,35 +115,32 @@ class EditAgentTool(BaseTool):
2. Generate updated agent (external service handles fixing and validation) 2. Generate updated agent (external service handles fixing and validation)
3. Preview or save based on the save parameter 3. Preview or save based on the save parameter
""" """
agent_id = kwargs.get("agent_id", "").strip() params = EditAgentInput(**kwargs)
changes = kwargs.get("changes", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler) # Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id") operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id") task_id = kwargs.get("_task_id")
if not agent_id: if not params.agent_id:
return ErrorResponse( return ErrorResponse(
message="Please provide the agent ID to edit.", message="Please provide the agent ID to edit.",
error="Missing agent_id parameter", error="Missing agent_id parameter",
session_id=session_id, session_id=session_id,
) )
if not changes: if not params.changes:
return ErrorResponse( return ErrorResponse(
message="Please describe what changes you want to make.", message="Please describe what changes you want to make.",
error="Missing changes parameter", error="Missing changes parameter",
session_id=session_id, session_id=session_id,
) )
current_agent = await get_agent_as_json(agent_id, user_id) current_agent = await get_agent_as_json(params.agent_id, user_id)
if current_agent is None: if current_agent is None:
return ErrorResponse( return ErrorResponse(
message=f"Could not find agent with ID '{agent_id}' in your library.", message=f"Could not find agent '{params.agent_id}' in your library.",
error="agent_not_found", error="agent_not_found",
session_id=session_id, session_id=session_id,
) )
@@ -138,7 +151,7 @@ class EditAgentTool(BaseTool):
graph_id = current_agent.get("id") graph_id = current_agent.get("id")
library_agents = await get_all_relevant_agents_for_generation( library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id, user_id=user_id,
search_query=changes, search_query=params.changes,
exclude_graph_id=graph_id, exclude_graph_id=graph_id,
include_marketplace=True, include_marketplace=True,
) )
@@ -148,9 +161,11 @@ class EditAgentTool(BaseTool):
except Exception as e: except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}") logger.warning(f"Failed to fetch library agents: {e}")
update_request = changes update_request = params.changes
if context: if params.context:
update_request = f"{changes}\n\nAdditional context:\n{context}" update_request = (
f"{params.changes}\n\nAdditional context:\n{params.context}"
)
try: try:
result = await generate_agent_patch( result = await generate_agent_patch(
@@ -174,7 +189,7 @@ class EditAgentTool(BaseTool):
return ErrorResponse( return ErrorResponse(
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.", message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
error="update_generation_failed", error="update_generation_failed",
details={"agent_id": agent_id, "changes": changes[:100]}, details={"agent_id": params.agent_id, "changes": params.changes[:100]},
session_id=session_id, session_id=session_id,
) )
@@ -206,8 +221,8 @@ class EditAgentTool(BaseTool):
message=user_message, message=user_message,
error=f"update_generation_failed:{error_type}", error=f"update_generation_failed:{error_type}",
details={ details={
"agent_id": agent_id, "agent_id": params.agent_id,
"changes": changes[:100], "changes": params.changes[:100],
"service_error": error_msg, "service_error": error_msg,
"error_type": error_type, "error_type": error_type,
}, },
@@ -239,7 +254,7 @@ class EditAgentTool(BaseTool):
node_count = len(updated_agent.get("nodes", [])) node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", [])) link_count = len(updated_agent.get("links", []))
if not save: if not params.save:
return AgentPreviewResponse( return AgentPreviewResponse(
message=( message=(
f"I've updated the agent. " f"I've updated the agent. "

View File

@@ -2,6 +2,8 @@
from typing import Any from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents from .agent_search import search_agents
@@ -9,6 +11,18 @@ from .base import BaseTool
from .models import ToolResponseBase from .models import ToolResponseBase
class FindAgentInput(BaseModel):
"""Input parameters for the find_agent tool."""
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from query."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class FindAgentTool(BaseTool): class FindAgentTool(BaseTool):
"""Tool for discovering agents from the marketplace.""" """Tool for discovering agents from the marketplace."""
@@ -36,10 +50,11 @@ class FindAgentTool(BaseTool):
} }
async def _execute( async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs self, user_id: str | None, session: ChatSession, **kwargs: Any
) -> ToolResponseBase: ) -> ToolResponseBase:
params = FindAgentInput(**kwargs)
return await search_agents( return await search_agents(
query=kwargs.get("query", "").strip(), query=params.query,
source="marketplace", source="marketplace",
session_id=session.session_id, session_id=session.session_id,
user_id=user_id, user_id=user_id,

View File

@@ -2,6 +2,7 @@ import logging
from typing import Any from typing import Any
from prisma.enums import ContentType from prisma.enums import ContentType
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
@@ -18,6 +19,18 @@ from backend.data.block import get_block
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class FindBlockInput(BaseModel):
"""Input parameters for the find_block tool."""
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from query."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class FindBlockTool(BaseTool): class FindBlockTool(BaseTool):
"""Tool for searching available blocks.""" """Tool for searching available blocks."""
@@ -59,24 +72,24 @@ class FindBlockTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs, **kwargs: Any,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Search for blocks matching the query. """Search for blocks matching the query.
Args: Args:
user_id: User ID (required) user_id: User ID (required)
session: Chat session session: Chat session
query: Search query **kwargs: Tool parameters
Returns: Returns:
BlockListResponse: List of matching blocks BlockListResponse: List of matching blocks
NoResultsResponse: No blocks found NoResultsResponse: No blocks found
ErrorResponse: Error message ErrorResponse: Error message
""" """
query = kwargs.get("query", "").strip() params = FindBlockInput(**kwargs)
session_id = session.session_id session_id = session.session_id
if not query: if not params.query:
return ErrorResponse( return ErrorResponse(
message="Please provide a search query", message="Please provide a search query",
session_id=session_id, session_id=session_id,
@@ -85,7 +98,7 @@ class FindBlockTool(BaseTool):
try: try:
# Search for blocks using hybrid search # Search for blocks using hybrid search
results, total = await unified_hybrid_search( results, total = await unified_hybrid_search(
query=query, query=params.query,
content_types=[ContentType.BLOCK], content_types=[ContentType.BLOCK],
page=1, page=1,
page_size=10, page_size=10,
@@ -93,7 +106,7 @@ class FindBlockTool(BaseTool):
if not results: if not results:
return NoResultsResponse( return NoResultsResponse(
message=f"No blocks found for '{query}'", message=f"No blocks found for '{params.query}'",
suggestions=[ suggestions=[
"Try broader keywords like 'email', 'http', 'text', 'ai'", "Try broader keywords like 'email', 'http', 'text', 'ai'",
"Check spelling of technical terms", "Check spelling of technical terms",
@@ -165,7 +178,7 @@ class FindBlockTool(BaseTool):
if not blocks: if not blocks:
return NoResultsResponse( return NoResultsResponse(
message=f"No blocks found for '{query}'", message=f"No blocks found for '{params.query}'",
suggestions=[ suggestions=[
"Try broader keywords like 'email', 'http', 'text', 'ai'", "Try broader keywords like 'email', 'http', 'text', 'ai'",
], ],
@@ -174,13 +187,13 @@ class FindBlockTool(BaseTool):
return BlockListResponse( return BlockListResponse(
message=( message=(
f"Found {len(blocks)} block(s) matching '{query}'. " f"Found {len(blocks)} block(s) matching '{params.query}'. "
"To execute a block, use run_block with the block's 'id' field " "To execute a block, use run_block with the block's 'id' field "
"and provide 'input_data' matching the block's input_schema." "and provide 'input_data' matching the block's input_schema."
), ),
blocks=blocks, blocks=blocks,
count=len(blocks), count=len(blocks),
query=query, query=params.query,
session_id=session_id, session_id=session_id,
) )

View File

@@ -2,6 +2,8 @@
from typing import Any from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents from .agent_search import search_agents
@@ -9,6 +11,15 @@ from .base import BaseTool
from .models import ToolResponseBase from .models import ToolResponseBase
class FindLibraryAgentInput(BaseModel):
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class FindLibraryAgentTool(BaseTool): class FindLibraryAgentTool(BaseTool):
"""Tool for searching agents in the user's library.""" """Tool for searching agents in the user's library."""
@@ -42,10 +53,11 @@ class FindLibraryAgentTool(BaseTool):
return True return True
async def _execute( async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs self, user_id: str | None, session: ChatSession, **kwargs: Any
) -> ToolResponseBase: ) -> ToolResponseBase:
params = FindLibraryAgentInput(**kwargs)
return await search_agents( return await search_agents(
query=kwargs.get("query", "").strip(), query=params.query,
source="library", source="library",
session_id=session.session_id, session_id=session.session_id,
user_id=user_id, user_id=user_id,

View File

@@ -4,6 +4,8 @@ import logging
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import ( from backend.api.features.chat.tools.models import (
@@ -18,6 +20,18 @@ logger = logging.getLogger(__name__)
DOCS_BASE_URL = "https://docs.agpt.co" DOCS_BASE_URL = "https://docs.agpt.co"
class GetDocPageInput(BaseModel):
"""Input parameters for the get_doc_page tool."""
path: str = ""
@field_validator("path", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from path."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class GetDocPageTool(BaseTool): class GetDocPageTool(BaseTool):
"""Tool for fetching full content of a documentation page.""" """Tool for fetching full content of a documentation page."""
@@ -75,23 +89,23 @@ class GetDocPageTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs, **kwargs: Any,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Fetch full content of a documentation page. """Fetch full content of a documentation page.
Args: Args:
user_id: User ID (not required for docs) user_id: User ID (not required for docs)
session: Chat session session: Chat session
path: Path to the documentation file **kwargs: Tool parameters
Returns: Returns:
DocPageResponse: Full document content DocPageResponse: Full document content
ErrorResponse: Error message ErrorResponse: Error message
""" """
path = kwargs.get("path", "").strip() params = GetDocPageInput(**kwargs)
session_id = session.session_id if session else None session_id = session.session_id if session else None
if not path: if not params.path:
return ErrorResponse( return ErrorResponse(
message="Please provide a documentation path.", message="Please provide a documentation path.",
error="Missing path parameter", error="Missing path parameter",
@@ -99,7 +113,7 @@ class GetDocPageTool(BaseTool):
) )
# Sanitize path to prevent directory traversal # Sanitize path to prevent directory traversal
if ".." in path or path.startswith("/"): if ".." in params.path or params.path.startswith("/"):
return ErrorResponse( return ErrorResponse(
message="Invalid documentation path.", message="Invalid documentation path.",
error="invalid_path", error="invalid_path",
@@ -107,11 +121,11 @@ class GetDocPageTool(BaseTool):
) )
docs_root = self._get_docs_root() docs_root = self._get_docs_root()
full_path = docs_root / path full_path = docs_root / params.path
if not full_path.exists(): if not full_path.exists():
return ErrorResponse( return ErrorResponse(
message=f"Documentation page not found: {path}", message=f"Documentation page not found: {params.path}",
error="not_found", error="not_found",
session_id=session_id, session_id=session_id,
) )
@@ -128,19 +142,19 @@ class GetDocPageTool(BaseTool):
try: try:
content = full_path.read_text(encoding="utf-8") content = full_path.read_text(encoding="utf-8")
title = self._extract_title(content, path) title = self._extract_title(content, params.path)
return DocPageResponse( return DocPageResponse(
message=f"Retrieved documentation page: {title}", message=f"Retrieved documentation page: {title}",
title=title, title=title,
path=path, path=params.path,
content=content, content=content,
doc_url=self._make_doc_url(path), doc_url=self._make_doc_url(params.path),
session_id=session_id, session_id=session_id,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to read documentation page {path}: {e}") logger.error(f"Failed to read documentation page {params.path}: {e}")
return ErrorResponse( return ErrorResponse(
message=f"Failed to read documentation page: {str(e)}", message=f"Failed to read documentation page: {str(e)}",
error="read_failed", error="read_failed",

View File

@@ -5,6 +5,7 @@ import uuid
from collections import defaultdict from collections import defaultdict
from typing import Any from typing import Any
from pydantic import BaseModel, field_validator
from pydantic_core import PydanticUndefined from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
@@ -29,6 +30,25 @@ from .utils import build_missing_credentials_from_field_info
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RunBlockInput(BaseModel):
"""Input parameters for the run_block tool."""
block_id: str = ""
input_data: dict[str, Any] = {}
@field_validator("block_id", mode="before")
@classmethod
def strip_block_id(cls, v: Any) -> str:
"""Strip whitespace from block_id."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
@field_validator("input_data", mode="before")
@classmethod
def ensure_dict(cls, v: Any) -> dict[str, Any]:
"""Ensure input_data is a dict."""
return v if isinstance(v, dict) else {}
class RunBlockTool(BaseTool): class RunBlockTool(BaseTool):
"""Tool for executing a block and returning its outputs.""" """Tool for executing a block and returning its outputs."""
@@ -162,37 +182,29 @@ class RunBlockTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs, **kwargs: Any,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Execute a block with the given input data. """Execute a block with the given input data.
Args: Args:
user_id: User ID (required) user_id: User ID (required)
session: Chat session session: Chat session
block_id: Block UUID to execute **kwargs: Tool parameters
input_data: Input values for the block
Returns: Returns:
BlockOutputResponse: Block execution outputs BlockOutputResponse: Block execution outputs
SetupRequirementsResponse: Missing credentials SetupRequirementsResponse: Missing credentials
ErrorResponse: Error message ErrorResponse: Error message
""" """
block_id = kwargs.get("block_id", "").strip() params = RunBlockInput(**kwargs)
input_data = kwargs.get("input_data", {})
session_id = session.session_id session_id = session.session_id
if not block_id: if not params.block_id:
return ErrorResponse( return ErrorResponse(
message="Please provide a block_id", message="Please provide a block_id",
session_id=session_id, session_id=session_id,
) )
if not isinstance(input_data, dict):
return ErrorResponse(
message="input_data must be an object",
session_id=session_id,
)
if not user_id: if not user_id:
return ErrorResponse( return ErrorResponse(
message="Authentication required", message="Authentication required",
@@ -200,23 +212,25 @@ class RunBlockTool(BaseTool):
) )
# Get the block # Get the block
block = get_block(block_id) block = get_block(params.block_id)
if not block: if not block:
return ErrorResponse( return ErrorResponse(
message=f"Block '{block_id}' not found", message=f"Block '{params.block_id}' not found",
session_id=session_id, session_id=session_id,
) )
if block.disabled: if block.disabled:
return ErrorResponse( return ErrorResponse(
message=f"Block '{block_id}' is disabled", message=f"Block '{params.block_id}' is disabled",
session_id=session_id, session_id=session_id,
) )
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}") logger.info(
f"Executing block {block.name} ({params.block_id}) for user {user_id}"
)
creds_manager = IntegrationCredentialsManager() creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = await self._check_block_credentials( matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block, input_data user_id, block, params.input_data
) )
if missing_credentials: if missing_credentials:
@@ -234,7 +248,7 @@ class RunBlockTool(BaseTool):
), ),
session_id=session_id, session_id=session_id,
setup_info=SetupInfo( setup_info=SetupInfo(
agent_id=block_id, agent_id=params.block_id,
agent_name=block.name, agent_name=block.name,
user_readiness=UserReadiness( user_readiness=UserReadiness(
has_all_credentials=False, has_all_credentials=False,
@@ -263,7 +277,7 @@ class RunBlockTool(BaseTool):
# - node_exec_id = unique per block execution # - node_exec_id = unique per block execution
synthetic_graph_id = f"copilot-session-{session.session_id}" synthetic_graph_id = f"copilot-session-{session.session_id}"
synthetic_graph_exec_id = f"copilot-session-{session.session_id}" synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
synthetic_node_id = f"copilot-node-{block_id}" synthetic_node_id = f"copilot-node-{params.block_id}"
synthetic_node_exec_id = ( synthetic_node_exec_id = (
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}" f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
) )
@@ -298,8 +312,8 @@ class RunBlockTool(BaseTool):
for field_name, cred_meta in matched_credentials.items(): for field_name, cred_meta in matched_credentials.items():
# Inject metadata into input_data (for validation) # Inject metadata into input_data (for validation)
if field_name not in input_data: if field_name not in params.input_data:
input_data[field_name] = cred_meta.model_dump() params.input_data[field_name] = cred_meta.model_dump()
# Fetch actual credentials and pass as kwargs (for execution) # Fetch actual credentials and pass as kwargs (for execution)
actual_credentials = await creds_manager.get( actual_credentials = await creds_manager.get(
@@ -316,14 +330,14 @@ class RunBlockTool(BaseTool):
# Execute the block and collect outputs # Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list) outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute( async for output_name, output_data in block.execute(
input_data, params.input_data,
**exec_kwargs, **exec_kwargs,
): ):
outputs[output_name].append(output_data) outputs[output_name].append(output_data)
return BlockOutputResponse( return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully", message=f"Block '{block.name}' executed successfully",
block_id=block_id, block_id=params.block_id,
block_name=block.name, block_name=block.name,
outputs=dict(outputs), outputs=dict(outputs),
success=True, success=True,

View File

@@ -4,6 +4,7 @@ import logging
from typing import Any from typing import Any
from prisma.enums import ContentType from prisma.enums import ContentType
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool from backend.api.features.chat.tools.base import BaseTool
@@ -28,6 +29,18 @@ MAX_RESULTS = 5
SNIPPET_LENGTH = 200 SNIPPET_LENGTH = 200
class SearchDocsInput(BaseModel):
"""Input parameters for the search_docs tool."""
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from query."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class SearchDocsTool(BaseTool): class SearchDocsTool(BaseTool):
"""Tool for searching AutoGPT platform documentation.""" """Tool for searching AutoGPT platform documentation."""
@@ -91,24 +104,24 @@ class SearchDocsTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs, **kwargs: Any,
) -> ToolResponseBase: ) -> ToolResponseBase:
"""Search documentation and return relevant sections. """Search documentation and return relevant sections.
Args: Args:
user_id: User ID (not required for docs) user_id: User ID (not required for docs)
session: Chat session session: Chat session
query: Search query **kwargs: Tool parameters
Returns: Returns:
DocSearchResultsResponse: List of matching documentation sections DocSearchResultsResponse: List of matching documentation sections
NoResultsResponse: No results found NoResultsResponse: No results found
ErrorResponse: Error message ErrorResponse: Error message
""" """
query = kwargs.get("query", "").strip() params = SearchDocsInput(**kwargs)
session_id = session.session_id if session else None session_id = session.session_id if session else None
if not query: if not params.query:
return ErrorResponse( return ErrorResponse(
message="Please provide a search query.", message="Please provide a search query.",
error="Missing query parameter", error="Missing query parameter",
@@ -118,7 +131,7 @@ class SearchDocsTool(BaseTool):
try: try:
# Search using hybrid search for DOCUMENTATION content type only # Search using hybrid search for DOCUMENTATION content type only
results, total = await unified_hybrid_search( results, total = await unified_hybrid_search(
query=query, query=params.query,
content_types=[ContentType.DOCUMENTATION], content_types=[ContentType.DOCUMENTATION],
page=1, page=1,
page_size=MAX_RESULTS * 2, # Fetch extra for deduplication page_size=MAX_RESULTS * 2, # Fetch extra for deduplication
@@ -127,7 +140,7 @@ class SearchDocsTool(BaseTool):
if not results: if not results:
return NoResultsResponse( return NoResultsResponse(
message=f"No documentation found for '{query}'.", message=f"No documentation found for '{params.query}'.",
suggestions=[ suggestions=[
"Try different keywords", "Try different keywords",
"Use more general terms", "Use more general terms",
@@ -162,7 +175,7 @@ class SearchDocsTool(BaseTool):
if not deduplicated: if not deduplicated:
return NoResultsResponse( return NoResultsResponse(
message=f"No documentation found for '{query}'.", message=f"No documentation found for '{params.query}'.",
suggestions=[ suggestions=[
"Try different keywords", "Try different keywords",
"Use more general terms", "Use more general terms",
@@ -195,7 +208,7 @@ class SearchDocsTool(BaseTool):
message=f"Found {len(doc_results)} relevant documentation sections.", message=f"Found {len(doc_results)} relevant documentation sections.",
results=doc_results, results=doc_results,
count=len(doc_results), count=len(doc_results),
query=query, query=params.query,
session_id=session_id, session_id=session_id,
) )

View File

@@ -2,9 +2,9 @@
import base64 import base64
import logging import logging
from typing import Any, Optional from typing import Any
from pydantic import BaseModel from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace from backend.data.workspace import get_or_create_workspace
@@ -78,6 +78,65 @@ class WorkspaceDeleteResponse(ToolResponseBase):
success: bool success: bool
# Input models for workspace tools
class ListWorkspaceFilesInput(BaseModel):
"""Input parameters for list_workspace_files tool."""
path_prefix: str | None = None
limit: int = 50
include_all_sessions: bool = False
@field_validator("path_prefix", mode="before")
@classmethod
def strip_path(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class ReadWorkspaceFileInput(BaseModel):
"""Input parameters for read_workspace_file tool."""
file_id: str | None = None
path: str | None = None
force_download_url: bool = False
@field_validator("file_id", "path", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class WriteWorkspaceFileInput(BaseModel):
"""Input parameters for write_workspace_file tool."""
filename: str = ""
content_base64: str = ""
path: str | None = None
mime_type: str | None = None
overwrite: bool = False
@field_validator("filename", "content_base64", mode="before")
@classmethod
def strip_required(cls, v: Any) -> str:
return v.strip() if isinstance(v, str) else (v if v is not None else "")
@field_validator("path", "mime_type", mode="before")
@classmethod
def strip_optional(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class DeleteWorkspaceFileInput(BaseModel):
"""Input parameters for delete_workspace_file tool."""
file_id: str | None = None
path: str | None = None
@field_validator("file_id", "path", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class ListWorkspaceFilesTool(BaseTool): class ListWorkspaceFilesTool(BaseTool):
"""Tool for listing files in user's workspace.""" """Tool for listing files in user's workspace."""
@@ -131,8 +190,9 @@ class ListWorkspaceFilesTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs, **kwargs: Any,
) -> ToolResponseBase: ) -> ToolResponseBase:
params = ListWorkspaceFilesInput(**kwargs)
session_id = session.session_id session_id = session.session_id
if not user_id: if not user_id:
@@ -141,9 +201,7 @@ class ListWorkspaceFilesTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
path_prefix: Optional[str] = kwargs.get("path_prefix") limit = min(params.limit, 100)
limit = min(kwargs.get("limit", 50), 100)
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try: try:
workspace = await get_or_create_workspace(user_id) workspace = await get_or_create_workspace(user_id)
@@ -151,13 +209,13 @@ class ListWorkspaceFilesTool(BaseTool):
manager = WorkspaceManager(user_id, workspace.id, session_id) manager = WorkspaceManager(user_id, workspace.id, session_id)
files = await manager.list_files( files = await manager.list_files(
path=path_prefix, path=params.path_prefix,
limit=limit, limit=limit,
include_all_sessions=include_all_sessions, include_all_sessions=params.include_all_sessions,
) )
total = await manager.get_file_count( total = await manager.get_file_count(
path=path_prefix, path=params.path_prefix,
include_all_sessions=include_all_sessions, include_all_sessions=params.include_all_sessions,
) )
file_infos = [ file_infos = [
@@ -171,7 +229,9 @@ class ListWorkspaceFilesTool(BaseTool):
for f in files for f in files
] ]
scope_msg = "all sessions" if include_all_sessions else "current session" scope_msg = (
"all sessions" if params.include_all_sessions else "current session"
)
return WorkspaceFileListResponse( return WorkspaceFileListResponse(
files=file_infos, files=file_infos,
total_count=total, total_count=total,
@@ -259,8 +319,9 @@ class ReadWorkspaceFileTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs, **kwargs: Any,
) -> ToolResponseBase: ) -> ToolResponseBase:
params = ReadWorkspaceFileInput(**kwargs)
session_id = session.session_id session_id = session.session_id
if not user_id: if not user_id:
@@ -269,11 +330,7 @@ class ReadWorkspaceFileTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
file_id: Optional[str] = kwargs.get("file_id") if not params.file_id and not params.path:
path: Optional[str] = kwargs.get("path")
force_download_url: bool = kwargs.get("force_download_url", False)
if not file_id and not path:
return ErrorResponse( return ErrorResponse(
message="Please provide either file_id or path", message="Please provide either file_id or path",
session_id=session_id, session_id=session_id,
@@ -285,21 +342,21 @@ class ReadWorkspaceFileTool(BaseTool):
manager = WorkspaceManager(user_id, workspace.id, session_id) manager = WorkspaceManager(user_id, workspace.id, session_id)
# Get file info # Get file info
if file_id: if params.file_id:
file_info = await manager.get_file_info(file_id) file_info = await manager.get_file_info(params.file_id)
if file_info is None: if file_info is None:
return ErrorResponse( return ErrorResponse(
message=f"File not found: {file_id}", message=f"File not found: {params.file_id}",
session_id=session_id, session_id=session_id,
) )
target_file_id = file_id target_file_id = params.file_id
else: else:
# path is guaranteed to be non-None here due to the check above # path is guaranteed to be non-None here due to the check above
assert path is not None assert params.path is not None
file_info = await manager.get_file_info_by_path(path) file_info = await manager.get_file_info_by_path(params.path)
if file_info is None: if file_info is None:
return ErrorResponse( return ErrorResponse(
message=f"File not found at path: {path}", message=f"File not found at path: {params.path}",
session_id=session_id, session_id=session_id,
) )
target_file_id = file_info.id target_file_id = file_info.id
@@ -309,7 +366,7 @@ class ReadWorkspaceFileTool(BaseTool):
is_text_file = self._is_text_mime_type(file_info.mimeType) is_text_file = self._is_text_mime_type(file_info.mimeType)
# Return inline content for small text files (unless force_download_url) # Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not force_download_url: if is_small_file and is_text_file and not params.force_download_url:
content = await manager.read_file_by_id(target_file_id) content = await manager.read_file_by_id(target_file_id)
content_b64 = base64.b64encode(content).decode("utf-8") content_b64 = base64.b64encode(content).decode("utf-8")
@@ -429,8 +486,9 @@ class WriteWorkspaceFileTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs, **kwargs: Any,
) -> ToolResponseBase: ) -> ToolResponseBase:
params = WriteWorkspaceFileInput(**kwargs)
session_id = session.session_id session_id = session.session_id
if not user_id: if not user_id:
@@ -439,19 +497,13 @@ class WriteWorkspaceFileTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
filename: str = kwargs.get("filename", "") if not params.filename:
content_b64: str = kwargs.get("content_base64", "")
path: Optional[str] = kwargs.get("path")
mime_type: Optional[str] = kwargs.get("mime_type")
overwrite: bool = kwargs.get("overwrite", False)
if not filename:
return ErrorResponse( return ErrorResponse(
message="Please provide a filename", message="Please provide a filename",
session_id=session_id, session_id=session_id,
) )
if not content_b64: if not params.content_base64:
return ErrorResponse( return ErrorResponse(
message="Please provide content_base64", message="Please provide content_base64",
session_id=session_id, session_id=session_id,
@@ -459,7 +511,7 @@ class WriteWorkspaceFileTool(BaseTool):
# Decode content # Decode content
try: try:
content = base64.b64decode(content_b64) content = base64.b64decode(params.content_base64)
except Exception: except Exception:
return ErrorResponse( return ErrorResponse(
message="Invalid base64-encoded content", message="Invalid base64-encoded content",
@@ -476,7 +528,7 @@ class WriteWorkspaceFileTool(BaseTool):
try: try:
# Virus scan # Virus scan
await scan_content_safe(content, filename=filename) await scan_content_safe(content, filename=params.filename)
workspace = await get_or_create_workspace(user_id) workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access # Pass session_id for session-scoped file access
@@ -484,10 +536,10 @@ class WriteWorkspaceFileTool(BaseTool):
file_record = await manager.write_file( file_record = await manager.write_file(
content=content, content=content,
filename=filename, filename=params.filename,
path=path, path=params.path,
mime_type=mime_type, mime_type=params.mime_type,
overwrite=overwrite, overwrite=params.overwrite,
) )
return WorkspaceWriteResponse( return WorkspaceWriteResponse(
@@ -557,8 +609,9 @@ class DeleteWorkspaceFileTool(BaseTool):
self, self,
user_id: str | None, user_id: str | None,
session: ChatSession, session: ChatSession,
**kwargs, **kwargs: Any,
) -> ToolResponseBase: ) -> ToolResponseBase:
params = DeleteWorkspaceFileInput(**kwargs)
session_id = session.session_id session_id = session.session_id
if not user_id: if not user_id:
@@ -567,10 +620,7 @@ class DeleteWorkspaceFileTool(BaseTool):
session_id=session_id, session_id=session_id,
) )
file_id: Optional[str] = kwargs.get("file_id") if not params.file_id and not params.path:
path: Optional[str] = kwargs.get("path")
if not file_id and not path:
return ErrorResponse( return ErrorResponse(
message="Please provide either file_id or path", message="Please provide either file_id or path",
session_id=session_id, session_id=session_id,
@@ -583,15 +633,15 @@ class DeleteWorkspaceFileTool(BaseTool):
# Determine the file_id to delete # Determine the file_id to delete
target_file_id: str target_file_id: str
if file_id: if params.file_id:
target_file_id = file_id target_file_id = params.file_id
else: else:
# path is guaranteed to be non-None here due to the check above # path is guaranteed to be non-None here due to the check above
assert path is not None assert params.path is not None
file_info = await manager.get_file_info_by_path(path) file_info = await manager.get_file_info_by_path(params.path)
if file_info is None: if file_info is None:
return ErrorResponse( return ErrorResponse(
message=f"File not found at path: {path}", message=f"File not found at path: {params.path}",
session_id=session_id, session_id=session_id,
) )
target_file_id = file_info.id target_file_id = file_info.id