diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py index 7333851a5b..8b075c8beb 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py @@ -3,6 +3,8 @@ import logging from typing import Any +from pydantic import BaseModel, field_validator + from backend.api.features.chat.model import ChatSession from .agent_generator import ( @@ -28,6 +30,26 @@ from .models import ( logger = logging.getLogger(__name__) +class CreateAgentInput(BaseModel): + """Input parameters for the create_agent tool.""" + + description: str = "" + context: str = "" + save: bool = True + # Internal async processing params (passed by long-running tool handler) + _operation_id: str | None = None + _task_id: str | None = None + + @field_validator("description", "context", mode="before") + @classmethod + def strip_strings(cls, v: Any) -> str: + """Strip whitespace from string fields.""" + return v.strip() if isinstance(v, str) else (v if v is not None else "") + + class Config: + extra = "allow" # Allow _operation_id, _task_id from kwargs + + class CreateAgentTool(BaseTool): """Tool for creating agents from natural language descriptions.""" @@ -85,7 +107,7 @@ class CreateAgentTool(BaseTool): self, user_id: str | None, session: ChatSession, - **kwargs, + **kwargs: Any, ) -> ToolResponseBase: """Execute the create_agent tool. @@ -94,16 +116,14 @@ class CreateAgentTool(BaseTool): 2. Generate agent JSON (external service handles fixing and validation) 3. Preview or save based on the save parameter """ - description = kwargs.get("description", "").strip() - context = kwargs.get("context", "") - save = kwargs.get("save", True) + params = CreateAgentInput(**kwargs) session_id = session.session_id if session else None - # Extract async processing params (passed by long-running tool handler) + # Extract async processing params operation_id = kwargs.get("_operation_id") task_id = kwargs.get("_task_id") - if not description: + if not params.description: return ErrorResponse( message="Please provide a description of what the agent should do.", error="Missing description parameter", @@ -115,7 +135,7 @@ class CreateAgentTool(BaseTool): try: library_agents = await get_all_relevant_agents_for_generation( user_id=user_id, - search_query=description, + search_query=params.description, include_marketplace=True, ) logger.debug( @@ -126,7 +146,7 @@ class CreateAgentTool(BaseTool): try: decomposition_result = await decompose_goal( - description, context, library_agents + params.description, params.context, library_agents ) except AgentGeneratorNotConfiguredError: return ErrorResponse( @@ -142,7 +162,7 @@ class CreateAgentTool(BaseTool): return ErrorResponse( message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.", error="decomposition_failed", - details={"description": description[:100]}, + details={"description": params.description[:100]}, session_id=session_id, ) @@ -158,7 +178,7 @@ class CreateAgentTool(BaseTool): message=user_message, error=f"decomposition_failed:{error_type}", details={ - "description": description[:100], + "description": params.description[:100], "service_error": error_msg, "error_type": error_type, }, @@ -244,7 +264,7 @@ class CreateAgentTool(BaseTool): return ErrorResponse( message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.", error="generation_failed", - details={"description": description[:100]}, + details={"description": params.description[:100]}, session_id=session_id, ) @@ -266,7 +286,7 @@ class CreateAgentTool(BaseTool): message=user_message, error=f"generation_failed:{error_type}", details={ - "description": description[:100], + "description": params.description[:100], "service_error": error_msg, "error_type": error_type, }, @@ -291,7 +311,7 @@ class CreateAgentTool(BaseTool): node_count = len(agent_json.get("nodes", [])) link_count = len(agent_json.get("links", [])) - if not save: + if not params.save: return AgentPreviewResponse( message=( f"I've generated an agent called '{agent_name}' with {node_count} blocks. " diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py index 3ae56407a7..018b94bd41 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py @@ -3,6 +3,8 @@ import logging from typing import Any +from pydantic import BaseModel, field_validator + from backend.api.features.chat.model import ChatSession from .agent_generator import ( @@ -27,6 +29,24 @@ from .models import ( logger = logging.getLogger(__name__) +class EditAgentInput(BaseModel): + """Input parameters for the edit_agent tool.""" + + agent_id: str = "" + changes: str = "" + context: str = "" + save: bool = True + + @field_validator("agent_id", "changes", "context", mode="before") + @classmethod + def strip_strings(cls, v: Any) -> str: + """Strip whitespace from string fields.""" + return v.strip() if isinstance(v, str) else (v if v is not None else "") + + class Config: + extra = "allow" # Allow _operation_id, _task_id + + class EditAgentTool(BaseTool): """Tool for editing existing agents using natural language.""" @@ -90,7 +110,7 @@ class EditAgentTool(BaseTool): self, user_id: str | None, session: ChatSession, - **kwargs, + **kwargs: Any, ) -> ToolResponseBase: """Execute the edit_agent tool. @@ -99,35 +119,32 @@ class EditAgentTool(BaseTool): 2. Generate updated agent (external service handles fixing and validation) 3. Preview or save based on the save parameter """ - agent_id = kwargs.get("agent_id", "").strip() - changes = kwargs.get("changes", "").strip() - context = kwargs.get("context", "") - save = kwargs.get("save", True) + params = EditAgentInput(**kwargs) session_id = session.session_id if session else None # Extract async processing params (passed by long-running tool handler) operation_id = kwargs.get("_operation_id") task_id = kwargs.get("_task_id") - if not agent_id: + if not params.agent_id: return ErrorResponse( message="Please provide the agent ID to edit.", error="Missing agent_id parameter", session_id=session_id, ) - if not changes: + if not params.changes: return ErrorResponse( message="Please describe what changes you want to make.", error="Missing changes parameter", session_id=session_id, ) - current_agent = await get_agent_as_json(agent_id, user_id) + current_agent = await get_agent_as_json(params.agent_id, user_id) if current_agent is None: return ErrorResponse( - message=f"Could not find agent with ID '{agent_id}' in your library.", + message=f"Could not find agent '{params.agent_id}' in your library.", error="agent_not_found", session_id=session_id, ) @@ -138,7 +155,7 @@ class EditAgentTool(BaseTool): graph_id = current_agent.get("id") library_agents = await get_all_relevant_agents_for_generation( user_id=user_id, - search_query=changes, + search_query=params.changes, exclude_graph_id=graph_id, include_marketplace=True, ) @@ -148,9 +165,11 @@ class EditAgentTool(BaseTool): except Exception as e: logger.warning(f"Failed to fetch library agents: {e}") - update_request = changes - if context: - update_request = f"{changes}\n\nAdditional context:\n{context}" + update_request = params.changes + if params.context: + update_request = ( + f"{params.changes}\n\nAdditional context:\n{params.context}" + ) try: result = await generate_agent_patch( @@ -239,7 +258,7 @@ class EditAgentTool(BaseTool): node_count = len(updated_agent.get("nodes", [])) link_count = len(updated_agent.get("links", [])) - if not save: + if not params.save: return AgentPreviewResponse( message=( f"I've updated the agent. " diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/find_agent.py index 477522757d..0ec3b42083 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/find_agent.py @@ -2,6 +2,8 @@ from typing import Any +from pydantic import BaseModel, field_validator + from backend.api.features.chat.model import ChatSession from .agent_search import search_agents @@ -9,6 +11,18 @@ from .base import BaseTool from .models import ToolResponseBase +class FindAgentInput(BaseModel): + """Input parameters for the find_agent tool.""" + + query: str = "" + + @field_validator("query", mode="before") + @classmethod + def strip_string(cls, v: Any) -> str: + """Strip whitespace from query.""" + return v.strip() if isinstance(v, str) else (v if v is not None else "") + + class FindAgentTool(BaseTool): """Tool for discovering agents from the marketplace.""" @@ -36,10 +50,11 @@ class FindAgentTool(BaseTool): } async def _execute( - self, user_id: str | None, session: ChatSession, **kwargs + self, user_id: str | None, session: ChatSession, **kwargs: Any ) -> ToolResponseBase: + params = FindAgentInput(**kwargs) return await search_agents( - query=kwargs.get("query", "").strip(), + query=params.query, source="marketplace", session_id=session.session_id, user_id=user_id, diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py index 7ca85961f9..66be667d32 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py @@ -2,6 +2,7 @@ import logging from typing import Any from prisma.enums import ContentType +from pydantic import BaseModel, field_validator from backend.api.features.chat.model import ChatSession from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase @@ -18,6 +19,18 @@ from backend.data.block import get_block logger = logging.getLogger(__name__) +class FindBlockInput(BaseModel): + """Input parameters for the find_block tool.""" + + query: str = "" + + @field_validator("query", mode="before") + @classmethod + def strip_string(cls, v: Any) -> str: + """Strip whitespace from query.""" + return v.strip() if isinstance(v, str) else (v if v is not None else "") + + class FindBlockTool(BaseTool): """Tool for searching available blocks.""" @@ -59,24 +72,24 @@ class FindBlockTool(BaseTool): self, user_id: str | None, session: ChatSession, - **kwargs, + **kwargs: Any, ) -> ToolResponseBase: """Search for blocks matching the query. Args: user_id: User ID (required) session: Chat session - query: Search query + **kwargs: Tool parameters Returns: BlockListResponse: List of matching blocks NoResultsResponse: No blocks found ErrorResponse: Error message """ - query = kwargs.get("query", "").strip() + params = FindBlockInput(**kwargs) session_id = session.session_id - if not query: + if not params.query: return ErrorResponse( message="Please provide a search query", session_id=session_id, @@ -85,7 +98,7 @@ class FindBlockTool(BaseTool): try: # Search for blocks using hybrid search results, total = await unified_hybrid_search( - query=query, + query=params.query, content_types=[ContentType.BLOCK], page=1, page_size=10, @@ -93,7 +106,7 @@ class FindBlockTool(BaseTool): if not results: return NoResultsResponse( - message=f"No blocks found for '{query}'", + message=f"No blocks found for '{params.query}'", suggestions=[ "Try broader keywords like 'email', 'http', 'text', 'ai'", "Check spelling of technical terms", @@ -165,7 +178,7 @@ class FindBlockTool(BaseTool): if not blocks: return NoResultsResponse( - message=f"No blocks found for '{query}'", + message=f"No blocks found for '{params.query}'", suggestions=[ "Try broader keywords like 'email', 'http', 'text', 'ai'", ], @@ -174,13 +187,13 @@ class FindBlockTool(BaseTool): return BlockListResponse( message=( - f"Found {len(blocks)} block(s) matching '{query}'. " + f"Found {len(blocks)} block(s) matching '{params.query}'. " "To execute a block, use run_block with the block's 'id' field " "and provide 'input_data' matching the block's input_schema." ), blocks=blocks, count=len(blocks), - query=query, + query=params.query, session_id=session_id, ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_library_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/find_library_agent.py index 108fba75ae..f36a977b54 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_library_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/find_library_agent.py @@ -2,6 +2,8 @@ from typing import Any +from pydantic import BaseModel, field_validator + from backend.api.features.chat.model import ChatSession from .agent_search import search_agents @@ -9,6 +11,18 @@ from .base import BaseTool from .models import ToolResponseBase +class FindLibraryAgentInput(BaseModel): + """Input parameters for the find_library_agent tool.""" + + query: str = "" + + @field_validator("query", mode="before") + @classmethod + def strip_string(cls, v: Any) -> str: + """Strip whitespace from query.""" + return v.strip() if isinstance(v, str) else (v if v is not None else "") + + class FindLibraryAgentTool(BaseTool): """Tool for searching agents in the user's library.""" @@ -42,10 +56,11 @@ class FindLibraryAgentTool(BaseTool): return True async def _execute( - self, user_id: str | None, session: ChatSession, **kwargs + self, user_id: str | None, session: ChatSession, **kwargs: Any ) -> ToolResponseBase: + params = FindLibraryAgentInput(**kwargs) return await search_agents( - query=kwargs.get("query", "").strip(), + query=params.query, source="library", session_id=session.session_id, user_id=user_id, diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/get_doc_page.py b/autogpt_platform/backend/backend/api/features/chat/tools/get_doc_page.py index 7040cd7db5..80f6150548 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/get_doc_page.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/get_doc_page.py @@ -4,6 +4,8 @@ import logging from pathlib import Path from typing import Any +from pydantic import BaseModel, field_validator + from backend.api.features.chat.model import ChatSession from backend.api.features.chat.tools.base import BaseTool from backend.api.features.chat.tools.models import ( @@ -18,6 +20,18 @@ logger = logging.getLogger(__name__) DOCS_BASE_URL = "https://docs.agpt.co" +class GetDocPageInput(BaseModel): + """Input parameters for the get_doc_page tool.""" + + path: str = "" + + @field_validator("path", mode="before") + @classmethod + def strip_string(cls, v: Any) -> str: + """Strip whitespace from path.""" + return v.strip() if isinstance(v, str) else (v if v is not None else "") + + class GetDocPageTool(BaseTool): """Tool for fetching full content of a documentation page.""" @@ -75,23 +89,23 @@ class GetDocPageTool(BaseTool): self, user_id: str | None, session: ChatSession, - **kwargs, + **kwargs: Any, ) -> ToolResponseBase: """Fetch full content of a documentation page. Args: user_id: User ID (not required for docs) session: Chat session - path: Path to the documentation file + **kwargs: Tool parameters Returns: DocPageResponse: Full document content ErrorResponse: Error message """ - path = kwargs.get("path", "").strip() + params = GetDocPageInput(**kwargs) session_id = session.session_id if session else None - if not path: + if not params.path: return ErrorResponse( message="Please provide a documentation path.", error="Missing path parameter", @@ -99,7 +113,7 @@ class GetDocPageTool(BaseTool): ) # Sanitize path to prevent directory traversal - if ".." in path or path.startswith("/"): + if ".." in params.path or params.path.startswith("/"): return ErrorResponse( message="Invalid documentation path.", error="invalid_path", @@ -107,11 +121,11 @@ class GetDocPageTool(BaseTool): ) docs_root = self._get_docs_root() - full_path = docs_root / path + full_path = docs_root / params.path if not full_path.exists(): return ErrorResponse( - message=f"Documentation page not found: {path}", + message=f"Documentation page not found: {params.path}", error="not_found", session_id=session_id, ) @@ -128,19 +142,19 @@ class GetDocPageTool(BaseTool): try: content = full_path.read_text(encoding="utf-8") - title = self._extract_title(content, path) + title = self._extract_title(content, params.path) return DocPageResponse( message=f"Retrieved documentation page: {title}", title=title, - path=path, + path=params.path, content=content, doc_url=self._make_doc_url(path), session_id=session_id, ) except Exception as e: - logger.error(f"Failed to read documentation page {path}: {e}") + logger.error(f"Failed to read documentation page {params.path}: {e}") return ErrorResponse( message=f"Failed to read documentation page: {str(e)}", error="read_failed", diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py index 51bb2c0575..08b132d4a6 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py @@ -5,6 +5,7 @@ import uuid from collections import defaultdict from typing import Any +from pydantic import BaseModel, field_validator from pydantic_core import PydanticUndefined from backend.api.features.chat.model import ChatSession @@ -29,6 +30,25 @@ from .utils import build_missing_credentials_from_field_info logger = logging.getLogger(__name__) +class RunBlockInput(BaseModel): + """Input parameters for the run_block tool.""" + + block_id: str = "" + input_data: dict[str, Any] = {} + + @field_validator("block_id", mode="before") + @classmethod + def strip_block_id(cls, v: Any) -> str: + """Strip whitespace from block_id.""" + return v.strip() if isinstance(v, str) else (v if v is not None else "") + + @field_validator("input_data", mode="before") + @classmethod + def ensure_dict(cls, v: Any) -> dict[str, Any]: + """Ensure input_data is a dict.""" + return v if isinstance(v, dict) else {} + + class RunBlockTool(BaseTool): """Tool for executing a block and returning its outputs.""" @@ -162,37 +182,29 @@ class RunBlockTool(BaseTool): self, user_id: str | None, session: ChatSession, - **kwargs, + **kwargs: Any, ) -> ToolResponseBase: """Execute a block with the given input data. Args: user_id: User ID (required) session: Chat session - block_id: Block UUID to execute - input_data: Input values for the block + **kwargs: Tool parameters Returns: BlockOutputResponse: Block execution outputs SetupRequirementsResponse: Missing credentials ErrorResponse: Error message """ - block_id = kwargs.get("block_id", "").strip() - input_data = kwargs.get("input_data", {}) + params = RunBlockInput(**kwargs) session_id = session.session_id - if not block_id: + if not params.block_id: return ErrorResponse( message="Please provide a block_id", session_id=session_id, ) - if not isinstance(input_data, dict): - return ErrorResponse( - message="input_data must be an object", - session_id=session_id, - ) - if not user_id: return ErrorResponse( message="Authentication required", @@ -200,23 +212,25 @@ class RunBlockTool(BaseTool): ) # Get the block - block = get_block(block_id) + block = get_block(params.block_id) if not block: return ErrorResponse( - message=f"Block '{block_id}' not found", + message=f"Block '{params.block_id}' not found", session_id=session_id, ) if block.disabled: return ErrorResponse( - message=f"Block '{block_id}' is disabled", + message=f"Block '{params.block_id}' is disabled", session_id=session_id, ) - logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}") + logger.info( + f"Executing block {block.name} ({params.block_id}) for user {user_id}" + ) creds_manager = IntegrationCredentialsManager() matched_credentials, missing_credentials = await self._check_block_credentials( - user_id, block, input_data + user_id, block, params.input_data ) if missing_credentials: @@ -298,8 +312,8 @@ class RunBlockTool(BaseTool): for field_name, cred_meta in matched_credentials.items(): # Inject metadata into input_data (for validation) - if field_name not in input_data: - input_data[field_name] = cred_meta.model_dump() + if field_name not in params.input_data: + params.input_data[field_name] = cred_meta.model_dump() # Fetch actual credentials and pass as kwargs (for execution) actual_credentials = await creds_manager.get( @@ -316,14 +330,14 @@ class RunBlockTool(BaseTool): # Execute the block and collect outputs outputs: dict[str, list[Any]] = defaultdict(list) async for output_name, output_data in block.execute( - input_data, + params.input_data, **exec_kwargs, ): outputs[output_name].append(output_data) return BlockOutputResponse( message=f"Block '{block.name}' executed successfully", - block_id=block_id, + block_id=params.block_id, block_name=block.name, outputs=dict(outputs), success=True, diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/search_docs.py b/autogpt_platform/backend/backend/api/features/chat/tools/search_docs.py index edb0c0de1e..c581abc35e 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/search_docs.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/search_docs.py @@ -4,6 +4,7 @@ import logging from typing import Any from prisma.enums import ContentType +from pydantic import BaseModel, field_validator from backend.api.features.chat.model import ChatSession from backend.api.features.chat.tools.base import BaseTool @@ -28,6 +29,18 @@ MAX_RESULTS = 5 SNIPPET_LENGTH = 200 +class SearchDocsInput(BaseModel): + """Input parameters for the search_docs tool.""" + + query: str = "" + + @field_validator("query", mode="before") + @classmethod + def strip_string(cls, v: Any) -> str: + """Strip whitespace from query.""" + return v.strip() if isinstance(v, str) else (v if v is not None else "") + + class SearchDocsTool(BaseTool): """Tool for searching AutoGPT platform documentation.""" @@ -91,24 +104,24 @@ class SearchDocsTool(BaseTool): self, user_id: str | None, session: ChatSession, - **kwargs, + **kwargs: Any, ) -> ToolResponseBase: """Search documentation and return relevant sections. Args: user_id: User ID (not required for docs) session: Chat session - query: Search query + **kwargs: Tool parameters Returns: DocSearchResultsResponse: List of matching documentation sections NoResultsResponse: No results found ErrorResponse: Error message """ - query = kwargs.get("query", "").strip() + params = SearchDocsInput(**kwargs) session_id = session.session_id if session else None - if not query: + if not params.query: return ErrorResponse( message="Please provide a search query.", error="Missing query parameter", @@ -118,7 +131,7 @@ class SearchDocsTool(BaseTool): try: # Search using hybrid search for DOCUMENTATION content type only results, total = await unified_hybrid_search( - query=query, + query=params.query, content_types=[ContentType.DOCUMENTATION], page=1, page_size=MAX_RESULTS * 2, # Fetch extra for deduplication @@ -127,7 +140,7 @@ class SearchDocsTool(BaseTool): if not results: return NoResultsResponse( - message=f"No documentation found for '{query}'.", + message=f"No documentation found for '{params.query}'.", suggestions=[ "Try different keywords", "Use more general terms", @@ -162,7 +175,7 @@ class SearchDocsTool(BaseTool): if not deduplicated: return NoResultsResponse( - message=f"No documentation found for '{query}'.", + message=f"No documentation found for '{params.query}'.", suggestions=[ "Try different keywords", "Use more general terms", @@ -195,7 +208,7 @@ class SearchDocsTool(BaseTool): message=f"Found {len(doc_results)} relevant documentation sections.", results=doc_results, count=len(doc_results), - query=query, + query=params.query, session_id=session_id, ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py b/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py index 03532c8fee..7012b68bf6 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py @@ -4,7 +4,7 @@ import base64 import logging from typing import Any, Optional -from pydantic import BaseModel +from pydantic import BaseModel, field_validator from backend.api.features.chat.model import ChatSession from backend.data.workspace import get_or_create_workspace @@ -78,6 +78,65 @@ class WorkspaceDeleteResponse(ToolResponseBase): success: bool +# Input models for workspace tools +class ListWorkspaceFilesInput(BaseModel): + """Input parameters for list_workspace_files tool.""" + + path_prefix: str | None = None + limit: int = 50 + include_all_sessions: bool = False + + @field_validator("path_prefix", mode="before") + @classmethod + def strip_path(cls, v: Any) -> str | None: + return v.strip() if isinstance(v, str) else None + + +class ReadWorkspaceFileInput(BaseModel): + """Input parameters for read_workspace_file tool.""" + + file_id: str | None = None + path: str | None = None + force_download_url: bool = False + + @field_validator("file_id", "path", mode="before") + @classmethod + def strip_strings(cls, v: Any) -> str | None: + return v.strip() if isinstance(v, str) else None + + +class WriteWorkspaceFileInput(BaseModel): + """Input parameters for write_workspace_file tool.""" + + filename: str = "" + content_base64: str = "" + path: str | None = None + mime_type: str | None = None + overwrite: bool = False + + @field_validator("filename", "content_base64", mode="before") + @classmethod + def strip_required(cls, v: Any) -> str: + return v.strip() if isinstance(v, str) else (v if v is not None else "") + + @field_validator("path", "mime_type", mode="before") + @classmethod + def strip_optional(cls, v: Any) -> str | None: + return v.strip() if isinstance(v, str) else None + + +class DeleteWorkspaceFileInput(BaseModel): + """Input parameters for delete_workspace_file tool.""" + + file_id: str | None = None + path: str | None = None + + @field_validator("file_id", "path", mode="before") + @classmethod + def strip_strings(cls, v: Any) -> str | None: + return v.strip() if isinstance(v, str) else None + + class ListWorkspaceFilesTool(BaseTool): """Tool for listing files in user's workspace.""" @@ -131,8 +190,9 @@ class ListWorkspaceFilesTool(BaseTool): self, user_id: str | None, session: ChatSession, - **kwargs, + **kwargs: Any, ) -> ToolResponseBase: + params = ListWorkspaceFilesInput(**kwargs) session_id = session.session_id if not user_id: @@ -141,9 +201,7 @@ class ListWorkspaceFilesTool(BaseTool): session_id=session_id, ) - path_prefix: Optional[str] = kwargs.get("path_prefix") - limit = min(kwargs.get("limit", 50), 100) - include_all_sessions: bool = kwargs.get("include_all_sessions", False) + limit = min(params.limit, 100) try: workspace = await get_or_create_workspace(user_id) @@ -151,13 +209,13 @@ class ListWorkspaceFilesTool(BaseTool): manager = WorkspaceManager(user_id, workspace.id, session_id) files = await manager.list_files( - path=path_prefix, + path=params.path_prefix, limit=limit, - include_all_sessions=include_all_sessions, + include_all_sessions=params.include_all_sessions, ) total = await manager.get_file_count( - path=path_prefix, - include_all_sessions=include_all_sessions, + path=params.path_prefix, + include_all_sessions=params.include_all_sessions, ) file_infos = [ @@ -259,8 +317,9 @@ class ReadWorkspaceFileTool(BaseTool): self, user_id: str | None, session: ChatSession, - **kwargs, + **kwargs: Any, ) -> ToolResponseBase: + params = ReadWorkspaceFileInput(**kwargs) session_id = session.session_id if not user_id: @@ -269,11 +328,7 @@ class ReadWorkspaceFileTool(BaseTool): session_id=session_id, ) - file_id: Optional[str] = kwargs.get("file_id") - path: Optional[str] = kwargs.get("path") - force_download_url: bool = kwargs.get("force_download_url", False) - - if not file_id and not path: + if not params.file_id and not params.path: return ErrorResponse( message="Please provide either file_id or path", session_id=session_id, @@ -285,21 +340,21 @@ class ReadWorkspaceFileTool(BaseTool): manager = WorkspaceManager(user_id, workspace.id, session_id) # Get file info - if file_id: - file_info = await manager.get_file_info(file_id) + if params.file_id: + file_info = await manager.get_file_info(params.file_id) if file_info is None: return ErrorResponse( - message=f"File not found: {file_id}", + message=f"File not found: {params.file_id}", session_id=session_id, ) - target_file_id = file_id + target_file_id = params.file_id else: # path is guaranteed to be non-None here due to the check above - assert path is not None - file_info = await manager.get_file_info_by_path(path) + assert params.path is not None + file_info = await manager.get_file_info_by_path(params.path) if file_info is None: return ErrorResponse( - message=f"File not found at path: {path}", + message=f"File not found at path: {params.path}", session_id=session_id, ) target_file_id = file_info.id @@ -309,7 +364,7 @@ class ReadWorkspaceFileTool(BaseTool): is_text_file = self._is_text_mime_type(file_info.mimeType) # Return inline content for small text files (unless force_download_url) - if is_small_file and is_text_file and not force_download_url: + if is_small_file and is_text_file and not params.force_download_url: content = await manager.read_file_by_id(target_file_id) content_b64 = base64.b64encode(content).decode("utf-8") @@ -429,8 +484,9 @@ class WriteWorkspaceFileTool(BaseTool): self, user_id: str | None, session: ChatSession, - **kwargs, + **kwargs: Any, ) -> ToolResponseBase: + params = WriteWorkspaceFileInput(**kwargs) session_id = session.session_id if not user_id: @@ -439,19 +495,13 @@ class WriteWorkspaceFileTool(BaseTool): session_id=session_id, ) - filename: str = kwargs.get("filename", "") - content_b64: str = kwargs.get("content_base64", "") - path: Optional[str] = kwargs.get("path") - mime_type: Optional[str] = kwargs.get("mime_type") - overwrite: bool = kwargs.get("overwrite", False) - - if not filename: + if not params.filename: return ErrorResponse( message="Please provide a filename", session_id=session_id, ) - if not content_b64: + if not params.content_base64: return ErrorResponse( message="Please provide content_base64", session_id=session_id, @@ -459,7 +509,7 @@ class WriteWorkspaceFileTool(BaseTool): # Decode content try: - content = base64.b64decode(content_b64) + content = base64.b64decode(params.content_base64) except Exception: return ErrorResponse( message="Invalid base64-encoded content", @@ -476,7 +526,7 @@ class WriteWorkspaceFileTool(BaseTool): try: # Virus scan - await scan_content_safe(content, filename=filename) + await scan_content_safe(content, filename=params.filename) workspace = await get_or_create_workspace(user_id) # Pass session_id for session-scoped file access @@ -484,10 +534,10 @@ class WriteWorkspaceFileTool(BaseTool): file_record = await manager.write_file( content=content, - filename=filename, - path=path, - mime_type=mime_type, - overwrite=overwrite, + filename=params.filename, + path=params.path, + mime_type=params.mime_type, + overwrite=params.overwrite, ) return WorkspaceWriteResponse( @@ -557,8 +607,9 @@ class DeleteWorkspaceFileTool(BaseTool): self, user_id: str | None, session: ChatSession, - **kwargs, + **kwargs: Any, ) -> ToolResponseBase: + params = DeleteWorkspaceFileInput(**kwargs) session_id = session.session_id if not user_id: @@ -567,10 +618,7 @@ class DeleteWorkspaceFileTool(BaseTool): session_id=session_id, ) - file_id: Optional[str] = kwargs.get("file_id") - path: Optional[str] = kwargs.get("path") - - if not file_id and not path: + if not params.file_id and not params.path: return ErrorResponse( message="Please provide either file_id or path", session_id=session_id, @@ -583,15 +631,15 @@ class DeleteWorkspaceFileTool(BaseTool): # Determine the file_id to delete target_file_id: str - if file_id: - target_file_id = file_id + if params.file_id: + target_file_id = params.file_id else: # path is guaranteed to be non-None here due to the check above - assert path is not None - file_info = await manager.get_file_info_by_path(path) + assert params.path is not None + file_info = await manager.get_file_info_by_path(params.path) if file_info is None: return ErrorResponse( - message=f"File not found at path: {path}", + message=f"File not found at path: {params.path}", session_id=session_id, ) target_file_id = file_info.id