Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-04 11:55:11 -05:00)

Compare commits (4 commits): release-v0 ... fix/code-r

| Author | SHA1 | Date |
|---|---|---|
|  | a093d57ed2 |  |
|  | 6692f39cbd |  |
|  | aeba28266c |  |
|  | 6d8c83c039 |  |
@@ -1,12 +1,15 @@
import asyncio
import logging
import time
import uuid as uuid_module
from asyncio import CancelledError
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, cast

import openai

from backend.util.prompt import compress_context

if TYPE_CHECKING:
    from backend.util.prompt import CompressResult

@@ -467,8 +470,6 @@ async def stream_chat_completion(
    should_retry = False

    # Generate unique IDs for AI SDK protocol
    import uuid as uuid_module

    message_id = str(uuid_module.uuid4())
    text_block_id = str(uuid_module.uuid4())

@@ -826,10 +827,6 @@ async def _manage_context_window(
    Returns:
        CompressResult with compacted messages and metadata
    """
    import openai

    from backend.util.prompt import compress_context

    # Convert messages to dict format
    messages_dict = []
    for msg in messages:
@@ -1140,8 +1137,6 @@ async def _yield_tool_call(
        KeyError: If expected tool call fields are missing
        TypeError: If tool call structure is invalid
    """
    import uuid as uuid_module

    tool_name = tool_calls[yield_idx]["function"]["name"]
    tool_call_id = tool_calls[yield_idx]["id"]

@@ -1762,8 +1757,6 @@ async def _generate_llm_continuation_with_streaming(
    after a tool result is saved. Chunks are published to the stream registry
    so reconnecting clients can receive them.
    """
    import uuid as uuid_module

    try:
        # Load fresh session from DB (bypass cache to get the updated tool result)
        await invalidate_session_cache(session_id)
@@ -1799,10 +1792,6 @@ async def _generate_llm_continuation_with_streaming(
        extra_body["session_id"] = session_id[:128]

        # Make streaming LLM call (no tools - just text response)
        from typing import cast

        from openai.types.chat import ChatCompletionMessageParam

        # Generate unique IDs for AI SDK protocol
        message_id = str(uuid_module.uuid4())
        text_block_id = str(uuid_module.uuid4())
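The hunks above hoist imports that the old code performed inside function bodies (openai, compress_context, uuid, cast, ChatCompletionMessageParam) up to module level, with CompressResult kept behind a TYPE_CHECKING guard. A minimal, generic sketch of that pattern (stdlib names only, not the project's actual modules):

```python
from typing import TYPE_CHECKING

import uuid  # module-level: resolved once at import time, not on every call

if TYPE_CHECKING:
    # Seen only by type checkers (mypy/pyright); never executed at runtime,
    # so it adds no import cost and cannot create an import cycle.
    from uuid import UUID


def new_message_id() -> "UUID":
    # No local `import uuid` is needed inside the function any more.
    return uuid.uuid4()


print(new_message_id())
```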
@@ -3,6 +3,8 @@
import logging
from typing import Any

from pydantic import BaseModel, field_validator

from backend.api.features.chat.model import ChatSession

from .agent_generator import (
@@ -28,6 +30,26 @@ from .models import (
logger = logging.getLogger(__name__)


class CreateAgentInput(BaseModel):
    """Input parameters for the create_agent tool."""

    description: str = ""
    context: str = ""
    save: bool = True
    # Internal async processing params (passed by long-running tool handler)
    _operation_id: str | None = None
    _task_id: str | None = None

    @field_validator("description", "context", mode="before")
    @classmethod
    def strip_strings(cls, v: Any) -> str:
        """Strip whitespace from string fields."""
        return v.strip() if isinstance(v, str) else (v if v is not None else "")

    class Config:
        extra = "allow"  # Allow _operation_id, _task_id from kwargs


class CreateAgentTool(BaseTool):
    """Tool for creating agents from natural language descriptions."""

@@ -85,7 +107,7 @@ class CreateAgentTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Execute the create_agent tool.

@@ -94,16 +116,14 @@ class CreateAgentTool(BaseTool):
        2. Generate agent JSON (external service handles fixing and validation)
        3. Preview or save based on the save parameter
        """
        description = kwargs.get("description", "").strip()
        context = kwargs.get("context", "")
        save = kwargs.get("save", True)
        params = CreateAgentInput(**kwargs)
        session_id = session.session_id if session else None

        # Extract async processing params (passed by long-running tool handler)
        # Extract async processing params
        operation_id = kwargs.get("_operation_id")
        task_id = kwargs.get("_task_id")

        if not description:
        if not params.description:
            return ErrorResponse(
                message="Please provide a description of what the agent should do.",
                error="Missing description parameter",
@@ -115,7 +135,7 @@ class CreateAgentTool(BaseTool):
        try:
            library_agents = await get_all_relevant_agents_for_generation(
                user_id=user_id,
                search_query=description,
                search_query=params.description,
                include_marketplace=True,
            )
            logger.debug(
@@ -126,7 +146,7 @@ class CreateAgentTool(BaseTool):

        try:
            decomposition_result = await decompose_goal(
                description, context, library_agents
                params.description, params.context, library_agents
            )
        except AgentGeneratorNotConfiguredError:
            return ErrorResponse(
@@ -142,7 +162,7 @@ class CreateAgentTool(BaseTool):
            return ErrorResponse(
                message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
                error="decomposition_failed",
                details={"description": description[:100]},
                details={"description": params.description[:100]},
                session_id=session_id,
            )

@@ -158,7 +178,7 @@ class CreateAgentTool(BaseTool):
                message=user_message,
                error=f"decomposition_failed:{error_type}",
                details={
                    "description": description[:100],
                    "description": params.description[:100],
                    "service_error": error_msg,
                    "error_type": error_type,
                },
@@ -244,7 +264,7 @@ class CreateAgentTool(BaseTool):
            return ErrorResponse(
                message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
                error="generation_failed",
                details={"description": description[:100]},
                details={"description": params.description[:100]},
                session_id=session_id,
            )

@@ -266,7 +286,7 @@ class CreateAgentTool(BaseTool):
                message=user_message,
                error=f"generation_failed:{error_type}",
                details={
                    "description": description[:100],
                    "description": params.description[:100],
                    "service_error": error_msg,
                    "error_type": error_type,
                },
@@ -291,7 +311,7 @@ class CreateAgentTool(BaseTool):
        node_count = len(agent_json.get("nodes", []))
        link_count = len(agent_json.get("links", []))

        if not save:
        if not params.save:
            return AgentPreviewResponse(
                message=(
                    f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
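The change above replaces ad-hoc `kwargs.get(...)` parsing with a dedicated Pydantic input model. The sketch below mirrors that pattern with a hypothetical ExampleToolInput; the field names and values are illustrative, only the validator shape follows the diff:

```python
from typing import Any

from pydantic import BaseModel, field_validator


class ExampleToolInput(BaseModel):
    """Hypothetical input model following the CreateAgentInput pattern."""

    description: str = ""
    save: bool = True

    @field_validator("description", mode="before")
    @classmethod
    def strip_strings(cls, v: Any) -> str:
        # Normalise before validation: trim strings, map None to "".
        return v.strip() if isinstance(v, str) else (v if v is not None else "")

    class Config:
        extra = "allow"  # unknown kwargs (e.g. _operation_id) are tolerated


params = ExampleToolInput(**{"description": "  build a scraper  ", "_operation_id": "op-1"})
assert params.description == "build a scraper"
assert params.save is True
```

With `extra = "allow"`, the internal parameters injected by the long-running tool handler can ride along in the same kwargs without failing validation.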
@@ -3,6 +3,8 @@
import logging
from typing import Any

from pydantic import BaseModel, field_validator

from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError
@@ -27,6 +29,23 @@ from .models import (
logger = logging.getLogger(__name__)


class CustomizeAgentInput(BaseModel):
    """Input parameters for the customize_agent tool."""

    agent_id: str = ""
    modifications: str = ""
    context: str = ""
    save: bool = True

    @field_validator("agent_id", "modifications", "context", mode="before")
    @classmethod
    def strip_strings(cls, v: Any) -> str:
        """Strip whitespace from string fields."""
        if isinstance(v, str):
            return v.strip()
        return v if v is not None else ""


class CustomizeAgentTool(BaseTool):
    """Tool for customizing marketplace/template agents using natural language."""

@@ -92,7 +111,7 @@ class CustomizeAgentTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Execute the customize_agent tool.

@@ -102,20 +121,17 @@ class CustomizeAgentTool(BaseTool):
        3. Call customize_template with the modification request
        4. Preview or save based on the save parameter
        """
        agent_id = kwargs.get("agent_id", "").strip()
        modifications = kwargs.get("modifications", "").strip()
        context = kwargs.get("context", "")
        save = kwargs.get("save", True)
        params = CustomizeAgentInput(**kwargs)
        session_id = session.session_id if session else None

        if not agent_id:
        if not params.agent_id:
            return ErrorResponse(
                message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
                error="missing_agent_id",
                session_id=session_id,
            )

        if not modifications:
        if not params.modifications:
            return ErrorResponse(
                message="Please describe how you want to customize this agent.",
                error="missing_modifications",
@@ -123,11 +139,11 @@ class CustomizeAgentTool(BaseTool):
            )

        # Parse agent_id in format "creator/slug"
        parts = [p.strip() for p in agent_id.split("/")]
        parts = [p.strip() for p in params.agent_id.split("/")]
        if len(parts) != 2 or not parts[0] or not parts[1]:
            return ErrorResponse(
                message=(
                    f"Invalid agent ID format: '{agent_id}'. "
                    f"Invalid agent ID format: '{params.agent_id}'. "
                    "Expected format is 'creator/agent-name' "
                    "(e.g., 'autogpt/newsletter-writer')."
                ),
@@ -145,14 +161,14 @@ class CustomizeAgentTool(BaseTool):
        except AgentNotFoundError:
            return ErrorResponse(
                message=(
                    f"Could not find marketplace agent '{agent_id}'. "
                    f"Could not find marketplace agent '{params.agent_id}'. "
                    "Please check the agent ID and try again."
                ),
                error="agent_not_found",
                session_id=session_id,
            )
        except Exception as e:
            logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
            logger.error(f"Error fetching marketplace agent {params.agent_id}: {e}")
            return ErrorResponse(
                message="Failed to fetch the marketplace agent. Please try again.",
                error="fetch_error",
@@ -162,7 +178,7 @@ class CustomizeAgentTool(BaseTool):
        if not agent_details.store_listing_version_id:
            return ErrorResponse(
                message=(
                    f"The agent '{agent_id}' does not have an available version. "
                    f"The agent '{params.agent_id}' does not have an available version. "
                    "Please try a different agent."
                ),
                error="no_version_available",
@@ -174,7 +190,7 @@ class CustomizeAgentTool(BaseTool):
            graph = await store_db.get_agent(agent_details.store_listing_version_id)
            template_agent = graph_to_json(graph)
        except Exception as e:
            logger.error(f"Error fetching agent graph for {agent_id}: {e}")
            logger.error(f"Error fetching agent graph for {params.agent_id}: {e}")
            return ErrorResponse(
                message="Failed to fetch the agent configuration. Please try again.",
                error="graph_fetch_error",
@@ -185,8 +201,8 @@ class CustomizeAgentTool(BaseTool):
        try:
            result = await customize_template(
                template_agent=template_agent,
                modification_request=modifications,
                context=context,
                modification_request=params.modifications,
                context=params.context,
            )
        except AgentGeneratorNotConfiguredError:
            return ErrorResponse(
@@ -198,7 +214,7 @@ class CustomizeAgentTool(BaseTool):
                session_id=session_id,
            )
        except Exception as e:
            logger.error(f"Error calling customize_template for {agent_id}: {e}")
            logger.error(f"Error calling customize_template for {params.agent_id}: {e}")
            return ErrorResponse(
                message=(
                    "Failed to customize the agent due to a service error. "
@@ -219,55 +235,25 @@ class CustomizeAgentTool(BaseTool):
                session_id=session_id,
            )

        # Handle error response
        if isinstance(result, dict) and result.get("type") == "error":
            error_msg = result.get("error", "Unknown error")
            error_type = result.get("error_type", "unknown")
            user_message = get_user_message_for_error(
                error_type,
                operation="customize the agent",
                llm_parse_message=(
                    "The AI had trouble customizing the agent. "
                    "Please try again or simplify your request."
                ),
                validation_message=(
                    "The customized agent failed validation. "
                    "Please try rephrasing your request."
                ),
                error_details=error_msg,
            )
            return ErrorResponse(
                message=user_message,
                error=f"customization_failed:{error_type}",
                session_id=session_id,
            )
        # Handle response using match/case for cleaner pattern matching
        return await self._handle_customization_result(
            result=result,
            params=params,
            agent_details=agent_details,
            user_id=user_id,
            session_id=session_id,
        )

        # Handle clarifying questions
        if isinstance(result, dict) and result.get("type") == "clarifying_questions":
            questions = result.get("questions") or []
            if not isinstance(questions, list):
                logger.error(
                    f"Unexpected clarifying questions format: {type(questions)}"
                )
                questions = []
            return ClarificationNeededResponse(
                message=(
                    "I need some more information to customize this agent. "
                    "Please answer the following questions:"
                ),
                questions=[
                    ClarifyingQuestion(
                        question=q.get("question", ""),
                        keyword=q.get("keyword", ""),
                        example=q.get("example"),
                    )
                    for q in questions
                    if isinstance(q, dict)
                ],
                session_id=session_id,
            )

        # Result should be the customized agent JSON
    async def _handle_customization_result(
        self,
        result: dict[str, Any],
        params: CustomizeAgentInput,
        agent_details: Any,
        user_id: str | None,
        session_id: str | None,
    ) -> ToolResponseBase:
        """Handle the result from customize_template using pattern matching."""
        # Ensure result is a dict
        if not isinstance(result, dict):
            logger.error(f"Unexpected customize_template response type: {type(result)}")
            return ErrorResponse(
@@ -276,8 +262,77 @@ class CustomizeAgentTool(BaseTool):
                session_id=session_id,
            )

        customized_agent = result
        result_type = result.get("type")

        match result_type:
            case "error":
                error_msg = result.get("error", "Unknown error")
                error_type = result.get("error_type", "unknown")
                user_message = get_user_message_for_error(
                    error_type,
                    operation="customize the agent",
                    llm_parse_message=(
                        "The AI had trouble customizing the agent. "
                        "Please try again or simplify your request."
                    ),
                    validation_message=(
                        "The customized agent failed validation. "
                        "Please try rephrasing your request."
                    ),
                    error_details=error_msg,
                )
                return ErrorResponse(
                    message=user_message,
                    error=f"customization_failed:{error_type}",
                    session_id=session_id,
                )

            case "clarifying_questions":
                questions_data = result.get("questions") or []
                if not isinstance(questions_data, list):
                    logger.error(
                        f"Unexpected clarifying questions format: {type(questions_data)}"
                    )
                    questions_data = []

                questions = [
                    ClarifyingQuestion(
                        question=q.get("question", "") if isinstance(q, dict) else "",
                        keyword=q.get("keyword", "") if isinstance(q, dict) else "",
                        example=q.get("example") if isinstance(q, dict) else None,
                    )
                    for q in questions_data
                    if isinstance(q, dict)
                ]

                return ClarificationNeededResponse(
                    message=(
                        "I need some more information to customize this agent. "
                        "Please answer the following questions:"
                    ),
                    questions=questions,
                    session_id=session_id,
                )

            case _:
                # Default case: result is the customized agent JSON
                return await self._save_or_preview_agent(
                    customized_agent=result,
                    params=params,
                    agent_details=agent_details,
                    user_id=user_id,
                    session_id=session_id,
                )

    async def _save_or_preview_agent(
        self,
        customized_agent: dict[str, Any],
        params: CustomizeAgentInput,
        agent_details: Any,
        user_id: str | None,
        session_id: str | None,
    ) -> ToolResponseBase:
        """Save or preview the customized agent based on params.save."""
        agent_name = customized_agent.get(
            "name", f"Customized {agent_details.agent_name}"
        )
@@ -287,7 +342,7 @@ class CustomizeAgentTool(BaseTool):
        node_count = len(nodes) if isinstance(nodes, list) else 0
        link_count = len(links) if isinstance(links, list) else 0

        if not save:
        if not params.save:
            return AgentPreviewResponse(
                message=(
                    f"I've customized the agent '{agent_details.agent_name}'. "
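The large hunk above folds the old if/elif handling of the customize_template result into a _handle_customization_result helper built around match/case. A self-contained sketch of dispatching on result["type"] this way (handle_result is a hypothetical function with simplified return values, not the tool's actual method):

```python
from typing import Any


def handle_result(result: dict[str, Any]) -> str:
    """Hypothetical dispatcher mirroring the shape of _handle_customization_result."""
    match result.get("type"):
        case "error":
            return f"error: {result.get('error', 'Unknown error')}"
        case "clarifying_questions":
            questions = result.get("questions") or []
            return f"need answers to {len(questions)} question(s)"
        case _:
            # Anything else is treated as the customized agent JSON.
            return f"agent with {len(result.get('nodes', []))} node(s)"


print(handle_result({"type": "error", "error": "boom"}))  # error: boom
print(handle_result({"nodes": [1, 2, 3]}))                # agent with 3 node(s)
```

The wildcard case keeps the original fallback behaviour: any payload without a recognised "type" is assumed to be the generated agent and handed to the save-or-preview step.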
@@ -3,6 +3,8 @@
import logging
from typing import Any

from pydantic import BaseModel, ConfigDict, field_validator

from backend.api.features.chat.model import ChatSession

from .agent_generator import (
@@ -27,6 +29,20 @@ from .models import (
logger = logging.getLogger(__name__)


class EditAgentInput(BaseModel):
    model_config = ConfigDict(extra="allow")

    agent_id: str = ""
    changes: str = ""
    context: str = ""
    save: bool = True

    @field_validator("agent_id", "changes", "context", mode="before")
    @classmethod
    def strip_strings(cls, v: Any) -> str:
        return v.strip() if isinstance(v, str) else (v if v is not None else "")


class EditAgentTool(BaseTool):
    """Tool for editing existing agents using natural language."""

@@ -90,7 +106,7 @@ class EditAgentTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Execute the edit_agent tool.

@@ -99,35 +115,32 @@ class EditAgentTool(BaseTool):
        2. Generate updated agent (external service handles fixing and validation)
        3. Preview or save based on the save parameter
        """
        agent_id = kwargs.get("agent_id", "").strip()
        changes = kwargs.get("changes", "").strip()
        context = kwargs.get("context", "")
        save = kwargs.get("save", True)
        params = EditAgentInput(**kwargs)
        session_id = session.session_id if session else None

        # Extract async processing params (passed by long-running tool handler)
        operation_id = kwargs.get("_operation_id")
        task_id = kwargs.get("_task_id")

        if not agent_id:
        if not params.agent_id:
            return ErrorResponse(
                message="Please provide the agent ID to edit.",
                error="Missing agent_id parameter",
                session_id=session_id,
            )

        if not changes:
        if not params.changes:
            return ErrorResponse(
                message="Please describe what changes you want to make.",
                error="Missing changes parameter",
                session_id=session_id,
            )

        current_agent = await get_agent_as_json(agent_id, user_id)
        current_agent = await get_agent_as_json(params.agent_id, user_id)

        if current_agent is None:
            return ErrorResponse(
                message=f"Could not find agent with ID '{agent_id}' in your library.",
                message=f"Could not find agent '{params.agent_id}' in your library.",
                error="agent_not_found",
                session_id=session_id,
            )
@@ -138,7 +151,7 @@ class EditAgentTool(BaseTool):
            graph_id = current_agent.get("id")
            library_agents = await get_all_relevant_agents_for_generation(
                user_id=user_id,
                search_query=changes,
                search_query=params.changes,
                exclude_graph_id=graph_id,
                include_marketplace=True,
            )
@@ -148,9 +161,11 @@ class EditAgentTool(BaseTool):
        except Exception as e:
            logger.warning(f"Failed to fetch library agents: {e}")

        update_request = changes
        if context:
            update_request = f"{changes}\n\nAdditional context:\n{context}"
        update_request = params.changes
        if params.context:
            update_request = (
                f"{params.changes}\n\nAdditional context:\n{params.context}"
            )

        try:
            result = await generate_agent_patch(
@@ -174,7 +189,7 @@ class EditAgentTool(BaseTool):
            return ErrorResponse(
                message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
                error="update_generation_failed",
                details={"agent_id": agent_id, "changes": changes[:100]},
                details={"agent_id": params.agent_id, "changes": params.changes[:100]},
                session_id=session_id,
            )

@@ -206,8 +221,8 @@ class EditAgentTool(BaseTool):
                message=user_message,
                error=f"update_generation_failed:{error_type}",
                details={
                    "agent_id": agent_id,
                    "changes": changes[:100],
                    "agent_id": params.agent_id,
                    "changes": params.changes[:100],
                    "service_error": error_msg,
                    "error_type": error_type,
                },
@@ -239,7 +254,7 @@ class EditAgentTool(BaseTool):
        node_count = len(updated_agent.get("nodes", []))
        link_count = len(updated_agent.get("links", []))

        if not save:
        if not params.save:
            return AgentPreviewResponse(
                message=(
                    f"I've updated the agent. "
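EditAgentInput above uses model_config = ConfigDict(extra="allow"), the Pydantic v2 spelling of the same "tolerate unknown kwargs" behaviour, so internal parameters like _operation_id and _task_id injected by the long-running tool handler do not fail validation. A small illustrative model (ExampleEditInput and the trace_id key are hypothetical):

```python
from typing import Any

from pydantic import BaseModel, ConfigDict, field_validator


class ExampleEditInput(BaseModel):
    """Hypothetical model using the same ConfigDict style as EditAgentInput."""

    model_config = ConfigDict(extra="allow")

    agent_id: str = ""
    changes: str = ""

    @field_validator("agent_id", "changes", mode="before")
    @classmethod
    def strip_strings(cls, v: Any) -> str:
        return v.strip() if isinstance(v, str) else (v if v is not None else "")


# Extra keys pass through validation, while declared fields are still trimmed.
params = ExampleEditInput(agent_id=" graph-123 ", changes="rename the output node", trace_id="t-9")
assert params.agent_id == "graph-123"
```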
@@ -2,6 +2,8 @@

from typing import Any

from pydantic import BaseModel, field_validator

from backend.api.features.chat.model import ChatSession

from .agent_search import search_agents
@@ -9,6 +11,18 @@ from .base import BaseTool
from .models import ToolResponseBase


class FindAgentInput(BaseModel):
    """Input parameters for the find_agent tool."""

    query: str = ""

    @field_validator("query", mode="before")
    @classmethod
    def strip_string(cls, v: Any) -> str:
        """Strip whitespace from query."""
        return v.strip() if isinstance(v, str) else (v if v is not None else "")


class FindAgentTool(BaseTool):
    """Tool for discovering agents from the marketplace."""

@@ -36,10 +50,11 @@ class FindAgentTool(BaseTool):
        }

    async def _execute(
        self, user_id: str | None, session: ChatSession, **kwargs
        self, user_id: str | None, session: ChatSession, **kwargs: Any
    ) -> ToolResponseBase:
        params = FindAgentInput(**kwargs)
        return await search_agents(
            query=kwargs.get("query", "").strip(),
            query=params.query,
            source="marketplace",
            session_id=session.session_id,
            user_id=user_id,
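FindAgentInput's single query field relies on a mode="before" validator, so values are normalised before type validation runs. A hypothetical ExampleQueryInput showing the None and whitespace handling:

```python
from typing import Any

from pydantic import BaseModel, field_validator


class ExampleQueryInput(BaseModel):
    """Hypothetical single-field model following FindAgentInput."""

    query: str = ""

    @field_validator("query", mode="before")
    @classmethod
    def strip_string(cls, v: Any) -> str:
        return v.strip() if isinstance(v, str) else (v if v is not None else "")


# mode="before" runs ahead of type validation, so None and padded strings are
# normalised instead of raising a ValidationError.
assert ExampleQueryInput(query=None).query == ""
assert ExampleQueryInput(query="  news agent ").query == "news agent"
```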
@@ -2,6 +2,7 @@ import logging
from typing import Any

from prisma.enums import ContentType
from pydantic import BaseModel, field_validator

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
@@ -18,6 +19,18 @@ from backend.data.block import get_block
logger = logging.getLogger(__name__)


class FindBlockInput(BaseModel):
    """Input parameters for the find_block tool."""

    query: str = ""

    @field_validator("query", mode="before")
    @classmethod
    def strip_string(cls, v: Any) -> str:
        """Strip whitespace from query."""
        return v.strip() if isinstance(v, str) else (v if v is not None else "")


class FindBlockTool(BaseTool):
    """Tool for searching available blocks."""

@@ -59,24 +72,24 @@ class FindBlockTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Search for blocks matching the query.

        Args:
            user_id: User ID (required)
            session: Chat session
            query: Search query
            **kwargs: Tool parameters

        Returns:
            BlockListResponse: List of matching blocks
            NoResultsResponse: No blocks found
            ErrorResponse: Error message
        """
        query = kwargs.get("query", "").strip()
        params = FindBlockInput(**kwargs)
        session_id = session.session_id

        if not query:
        if not params.query:
            return ErrorResponse(
                message="Please provide a search query",
                session_id=session_id,
@@ -85,7 +98,7 @@ class FindBlockTool(BaseTool):
        try:
            # Search for blocks using hybrid search
            results, total = await unified_hybrid_search(
                query=query,
                query=params.query,
                content_types=[ContentType.BLOCK],
                page=1,
                page_size=10,
@@ -93,7 +106,7 @@ class FindBlockTool(BaseTool):

        if not results:
            return NoResultsResponse(
                message=f"No blocks found for '{query}'",
                message=f"No blocks found for '{params.query}'",
                suggestions=[
                    "Try broader keywords like 'email', 'http', 'text', 'ai'",
                    "Check spelling of technical terms",
@@ -165,7 +178,7 @@ class FindBlockTool(BaseTool):

        if not blocks:
            return NoResultsResponse(
                message=f"No blocks found for '{query}'",
                message=f"No blocks found for '{params.query}'",
                suggestions=[
                    "Try broader keywords like 'email', 'http', 'text', 'ai'",
                ],
@@ -174,13 +187,13 @@ class FindBlockTool(BaseTool):

        return BlockListResponse(
            message=(
                f"Found {len(blocks)} block(s) matching '{query}'. "
                f"Found {len(blocks)} block(s) matching '{params.query}'. "
                "To execute a block, use run_block with the block's 'id' field "
                "and provide 'input_data' matching the block's input_schema."
            ),
            blocks=blocks,
            count=len(blocks),
            query=query,
            query=params.query,
            session_id=session_id,
        )
@@ -2,6 +2,8 @@

from typing import Any

from pydantic import BaseModel, field_validator

from backend.api.features.chat.model import ChatSession

from .agent_search import search_agents
@@ -9,6 +11,15 @@ from .base import BaseTool
from .models import ToolResponseBase


class FindLibraryAgentInput(BaseModel):
    query: str = ""

    @field_validator("query", mode="before")
    @classmethod
    def strip_string(cls, v: Any) -> str:
        return v.strip() if isinstance(v, str) else (v if v is not None else "")


class FindLibraryAgentTool(BaseTool):
    """Tool for searching agents in the user's library."""

@@ -42,10 +53,11 @@ class FindLibraryAgentTool(BaseTool):
        return True

    async def _execute(
        self, user_id: str | None, session: ChatSession, **kwargs
        self, user_id: str | None, session: ChatSession, **kwargs: Any
    ) -> ToolResponseBase:
        params = FindLibraryAgentInput(**kwargs)
        return await search_agents(
            query=kwargs.get("query", "").strip(),
            query=params.query,
            source="library",
            session_id=session.session_id,
            user_id=user_id,
@@ -4,6 +4,8 @@ import logging
from pathlib import Path
from typing import Any

from pydantic import BaseModel, field_validator

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
@@ -18,6 +20,18 @@ logger = logging.getLogger(__name__)
DOCS_BASE_URL = "https://docs.agpt.co"


class GetDocPageInput(BaseModel):
    """Input parameters for the get_doc_page tool."""

    path: str = ""

    @field_validator("path", mode="before")
    @classmethod
    def strip_string(cls, v: Any) -> str:
        """Strip whitespace from path."""
        return v.strip() if isinstance(v, str) else (v if v is not None else "")


class GetDocPageTool(BaseTool):
    """Tool for fetching full content of a documentation page."""

@@ -75,23 +89,23 @@ class GetDocPageTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Fetch full content of a documentation page.

        Args:
            user_id: User ID (not required for docs)
            session: Chat session
            path: Path to the documentation file
            **kwargs: Tool parameters

        Returns:
            DocPageResponse: Full document content
            ErrorResponse: Error message
        """
        path = kwargs.get("path", "").strip()
        params = GetDocPageInput(**kwargs)
        session_id = session.session_id if session else None

        if not path:
        if not params.path:
            return ErrorResponse(
                message="Please provide a documentation path.",
                error="Missing path parameter",
@@ -99,7 +113,7 @@ class GetDocPageTool(BaseTool):
            )

        # Sanitize path to prevent directory traversal
        if ".." in path or path.startswith("/"):
        if ".." in params.path or params.path.startswith("/"):
            return ErrorResponse(
                message="Invalid documentation path.",
                error="invalid_path",
@@ -107,11 +121,11 @@ class GetDocPageTool(BaseTool):
            )

        docs_root = self._get_docs_root()
        full_path = docs_root / path
        full_path = docs_root / params.path

        if not full_path.exists():
            return ErrorResponse(
                message=f"Documentation page not found: {path}",
                message=f"Documentation page not found: {params.path}",
                error="not_found",
                session_id=session_id,
            )
@@ -128,19 +142,19 @@ class GetDocPageTool(BaseTool):

        try:
            content = full_path.read_text(encoding="utf-8")
            title = self._extract_title(content, path)
            title = self._extract_title(content, params.path)

            return DocPageResponse(
                message=f"Retrieved documentation page: {title}",
                title=title,
                path=path,
                path=params.path,
                content=content,
                doc_url=self._make_doc_url(path),
                doc_url=self._make_doc_url(params.path),
                session_id=session_id,
            )

        except Exception as e:
            logger.error(f"Failed to read documentation page {path}: {e}")
            logger.error(f"Failed to read documentation page {params.path}: {e}")
            return ErrorResponse(
                message=f"Failed to read documentation page: {str(e)}",
                error="read_failed",
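GetDocPageTool keeps its directory-traversal guard, now applied to params.path: reject anything containing ".." or starting with "/" before joining onto the docs root. A standalone sketch of that check (resolve_doc_path is a hypothetical helper, not the tool's actual method):

```python
from pathlib import Path


def resolve_doc_path(docs_root: Path, user_path: str) -> Path | None:
    """Hypothetical helper sketching GetDocPageTool's traversal guard."""
    user_path = user_path.strip()
    # Reject empty paths, parent references and absolute paths before joining.
    if not user_path or ".." in user_path or user_path.startswith("/"):
        return None
    candidate = docs_root / user_path
    return candidate if candidate.exists() else None


print(resolve_doc_path(Path("docs"), "../etc/passwd"))       # None
print(resolve_doc_path(Path("docs"), "platform/blocks.md"))  # Path or None, depending on local layout
```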
@@ -5,6 +5,7 @@ import uuid
from collections import defaultdict
from typing import Any

from pydantic import BaseModel, field_validator
from pydantic_core import PydanticUndefined

from backend.api.features.chat.model import ChatSession
@@ -29,6 +30,25 @@ from .utils import build_missing_credentials_from_field_info
logger = logging.getLogger(__name__)


class RunBlockInput(BaseModel):
    """Input parameters for the run_block tool."""

    block_id: str = ""
    input_data: dict[str, Any] = {}

    @field_validator("block_id", mode="before")
    @classmethod
    def strip_block_id(cls, v: Any) -> str:
        """Strip whitespace from block_id."""
        return v.strip() if isinstance(v, str) else (v if v is not None else "")

    @field_validator("input_data", mode="before")
    @classmethod
    def ensure_dict(cls, v: Any) -> dict[str, Any]:
        """Ensure input_data is a dict."""
        return v if isinstance(v, dict) else {}


class RunBlockTool(BaseTool):
    """Tool for executing a block and returning its outputs."""

@@ -162,37 +182,29 @@ class RunBlockTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Execute a block with the given input data.

        Args:
            user_id: User ID (required)
            session: Chat session
            block_id: Block UUID to execute
            input_data: Input values for the block
            **kwargs: Tool parameters

        Returns:
            BlockOutputResponse: Block execution outputs
            SetupRequirementsResponse: Missing credentials
            ErrorResponse: Error message
        """
        block_id = kwargs.get("block_id", "").strip()
        input_data = kwargs.get("input_data", {})
        params = RunBlockInput(**kwargs)
        session_id = session.session_id

        if not block_id:
        if not params.block_id:
            return ErrorResponse(
                message="Please provide a block_id",
                session_id=session_id,
            )

        if not isinstance(input_data, dict):
            return ErrorResponse(
                message="input_data must be an object",
                session_id=session_id,
            )

        if not user_id:
            return ErrorResponse(
                message="Authentication required",
@@ -200,23 +212,25 @@ class RunBlockTool(BaseTool):
            )

        # Get the block
        block = get_block(block_id)
        block = get_block(params.block_id)
        if not block:
            return ErrorResponse(
                message=f"Block '{block_id}' not found",
                message=f"Block '{params.block_id}' not found",
                session_id=session_id,
            )
        if block.disabled:
            return ErrorResponse(
                message=f"Block '{block_id}' is disabled",
                message=f"Block '{params.block_id}' is disabled",
                session_id=session_id,
            )

        logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
        logger.info(
            f"Executing block {block.name} ({params.block_id}) for user {user_id}"
        )

        creds_manager = IntegrationCredentialsManager()
        matched_credentials, missing_credentials = await self._check_block_credentials(
            user_id, block, input_data
            user_id, block, params.input_data
        )

        if missing_credentials:
@@ -234,7 +248,7 @@ class RunBlockTool(BaseTool):
                ),
                session_id=session_id,
                setup_info=SetupInfo(
                    agent_id=block_id,
                    agent_id=params.block_id,
                    agent_name=block.name,
                    user_readiness=UserReadiness(
                        has_all_credentials=False,
@@ -263,7 +277,7 @@ class RunBlockTool(BaseTool):
        # - node_exec_id = unique per block execution
        synthetic_graph_id = f"copilot-session-{session.session_id}"
        synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
        synthetic_node_id = f"copilot-node-{block_id}"
        synthetic_node_id = f"copilot-node-{params.block_id}"
        synthetic_node_exec_id = (
            f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
        )
@@ -298,8 +312,8 @@ class RunBlockTool(BaseTool):

        for field_name, cred_meta in matched_credentials.items():
            # Inject metadata into input_data (for validation)
            if field_name not in input_data:
                input_data[field_name] = cred_meta.model_dump()
            if field_name not in params.input_data:
                params.input_data[field_name] = cred_meta.model_dump()

            # Fetch actual credentials and pass as kwargs (for execution)
            actual_credentials = await creds_manager.get(
@@ -316,14 +330,14 @@ class RunBlockTool(BaseTool):
        # Execute the block and collect outputs
        outputs: dict[str, list[Any]] = defaultdict(list)
        async for output_name, output_data in block.execute(
            input_data,
            params.input_data,
            **exec_kwargs,
        ):
            outputs[output_name].append(output_data)

        return BlockOutputResponse(
            message=f"Block '{block.name}' executed successfully",
            block_id=block_id,
            block_id=params.block_id,
            block_name=block.name,
            outputs=dict(outputs),
            success=True,
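RunBlockInput normalises input_data with an ensure_dict validator, which is why the explicit "input_data must be an object" error path disappears in the hunk above: non-dict values simply collapse to an empty dict before the tool body runs. A minimal sketch of that validator (ExampleRunBlockInput is hypothetical):

```python
from typing import Any

from pydantic import BaseModel, field_validator


class ExampleRunBlockInput(BaseModel):
    """Hypothetical model following RunBlockInput's shape."""

    block_id: str = ""
    input_data: dict[str, Any] = {}

    @field_validator("input_data", mode="before")
    @classmethod
    def ensure_dict(cls, v: Any) -> dict[str, Any]:
        # Non-dict values (None, strings, lists) collapse to an empty dict,
        # so downstream code can always rely on key lookups.
        return v if isinstance(v, dict) else {}


assert ExampleRunBlockInput(block_id="abc", input_data=None).input_data == {}
assert ExampleRunBlockInput(input_data={"url": "https://example.com"}).input_data["url"]
```

The trade-off of this design is that a malformed input_data no longer produces an error message; it is silently treated as "no inputs provided" and caught later by credential or schema checks.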
@@ -4,6 +4,7 @@ import logging
from typing import Any

from prisma.enums import ContentType
from pydantic import BaseModel, field_validator

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
@@ -28,6 +29,18 @@ MAX_RESULTS = 5
SNIPPET_LENGTH = 200


class SearchDocsInput(BaseModel):
    """Input parameters for the search_docs tool."""

    query: str = ""

    @field_validator("query", mode="before")
    @classmethod
    def strip_string(cls, v: Any) -> str:
        """Strip whitespace from query."""
        return v.strip() if isinstance(v, str) else (v if v is not None else "")


class SearchDocsTool(BaseTool):
    """Tool for searching AutoGPT platform documentation."""

@@ -91,24 +104,24 @@ class SearchDocsTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
        **kwargs: Any,
    ) -> ToolResponseBase:
        """Search documentation and return relevant sections.

        Args:
            user_id: User ID (not required for docs)
            session: Chat session
            query: Search query
            **kwargs: Tool parameters

        Returns:
            DocSearchResultsResponse: List of matching documentation sections
            NoResultsResponse: No results found
            ErrorResponse: Error message
        """
        query = kwargs.get("query", "").strip()
        params = SearchDocsInput(**kwargs)
        session_id = session.session_id if session else None

        if not query:
        if not params.query:
            return ErrorResponse(
                message="Please provide a search query.",
                error="Missing query parameter",
@@ -118,7 +131,7 @@ class SearchDocsTool(BaseTool):
        try:
            # Search using hybrid search for DOCUMENTATION content type only
            results, total = await unified_hybrid_search(
                query=query,
                query=params.query,
                content_types=[ContentType.DOCUMENTATION],
                page=1,
                page_size=MAX_RESULTS * 2,  # Fetch extra for deduplication
@@ -127,7 +140,7 @@ class SearchDocsTool(BaseTool):

        if not results:
            return NoResultsResponse(
                message=f"No documentation found for '{query}'.",
                message=f"No documentation found for '{params.query}'.",
                suggestions=[
                    "Try different keywords",
                    "Use more general terms",
@@ -162,7 +175,7 @@ class SearchDocsTool(BaseTool):

        if not deduplicated:
            return NoResultsResponse(
                message=f"No documentation found for '{query}'.",
                message=f"No documentation found for '{params.query}'.",
                suggestions=[
                    "Try different keywords",
                    "Use more general terms",
@@ -195,7 +208,7 @@ class SearchDocsTool(BaseTool):
            message=f"Found {len(doc_results)} relevant documentation sections.",
            results=doc_results,
            count=len(doc_results),
            query=query,
            query=params.query,
            session_id=session_id,
        )
@@ -2,9 +2,9 @@

import base64
import logging
from typing import Any, Optional
from typing import Any

from pydantic import BaseModel
from pydantic import BaseModel, field_validator

from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
@@ -78,6 +78,65 @@ class WorkspaceDeleteResponse(ToolResponseBase):
    success: bool


# Input models for workspace tools
class ListWorkspaceFilesInput(BaseModel):
    """Input parameters for list_workspace_files tool."""

    path_prefix: str | None = None
    limit: int = 50
    include_all_sessions: bool = False

    @field_validator("path_prefix", mode="before")
    @classmethod
    def strip_path(cls, v: Any) -> str | None:
        return v.strip() if isinstance(v, str) else None


class ReadWorkspaceFileInput(BaseModel):
    """Input parameters for read_workspace_file tool."""

    file_id: str | None = None
    path: str | None = None
    force_download_url: bool = False

    @field_validator("file_id", "path", mode="before")
    @classmethod
    def strip_strings(cls, v: Any) -> str | None:
        return v.strip() if isinstance(v, str) else None


class WriteWorkspaceFileInput(BaseModel):
    """Input parameters for write_workspace_file tool."""

    filename: str = ""
    content_base64: str = ""
    path: str | None = None
    mime_type: str | None = None
    overwrite: bool = False

    @field_validator("filename", "content_base64", mode="before")
    @classmethod
    def strip_required(cls, v: Any) -> str:
        return v.strip() if isinstance(v, str) else (v if v is not None else "")

    @field_validator("path", "mime_type", mode="before")
    @classmethod
    def strip_optional(cls, v: Any) -> str | None:
        return v.strip() if isinstance(v, str) else None


class DeleteWorkspaceFileInput(BaseModel):
    """Input parameters for delete_workspace_file tool."""

    file_id: str | None = None
    path: str | None = None

    @field_validator("file_id", "path", mode="before")
    @classmethod
    def strip_strings(cls, v: Any) -> str | None:
        return v.strip() if isinstance(v, str) else None


class ListWorkspaceFilesTool(BaseTool):
    """Tool for listing files in user's workspace."""

@@ -131,8 +190,9 @@ class ListWorkspaceFilesTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
        **kwargs: Any,
    ) -> ToolResponseBase:
        params = ListWorkspaceFilesInput(**kwargs)
        session_id = session.session_id

        if not user_id:
@@ -141,9 +201,7 @@ class ListWorkspaceFilesTool(BaseTool):
                session_id=session_id,
            )

        path_prefix: Optional[str] = kwargs.get("path_prefix")
        limit = min(kwargs.get("limit", 50), 100)
        include_all_sessions: bool = kwargs.get("include_all_sessions", False)
        limit = min(params.limit, 100)

        try:
            workspace = await get_or_create_workspace(user_id)
@@ -151,13 +209,13 @@ class ListWorkspaceFilesTool(BaseTool):
            manager = WorkspaceManager(user_id, workspace.id, session_id)

            files = await manager.list_files(
                path=path_prefix,
                path=params.path_prefix,
                limit=limit,
                include_all_sessions=include_all_sessions,
                include_all_sessions=params.include_all_sessions,
            )
            total = await manager.get_file_count(
                path=path_prefix,
                include_all_sessions=include_all_sessions,
                path=params.path_prefix,
                include_all_sessions=params.include_all_sessions,
            )

            file_infos = [
@@ -171,7 +229,9 @@ class ListWorkspaceFilesTool(BaseTool):
                for f in files
            ]

            scope_msg = "all sessions" if include_all_sessions else "current session"
            scope_msg = (
                "all sessions" if params.include_all_sessions else "current session"
            )
            return WorkspaceFileListResponse(
                files=file_infos,
                total_count=total,
@@ -259,8 +319,9 @@ class ReadWorkspaceFileTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
        **kwargs: Any,
    ) -> ToolResponseBase:
        params = ReadWorkspaceFileInput(**kwargs)
        session_id = session.session_id

        if not user_id:
@@ -269,11 +330,7 @@ class ReadWorkspaceFileTool(BaseTool):
                session_id=session_id,
            )

        file_id: Optional[str] = kwargs.get("file_id")
        path: Optional[str] = kwargs.get("path")
        force_download_url: bool = kwargs.get("force_download_url", False)

        if not file_id and not path:
        if not params.file_id and not params.path:
            return ErrorResponse(
                message="Please provide either file_id or path",
                session_id=session_id,
@@ -285,21 +342,21 @@ class ReadWorkspaceFileTool(BaseTool):
            manager = WorkspaceManager(user_id, workspace.id, session_id)

            # Get file info
            if file_id:
                file_info = await manager.get_file_info(file_id)
            if params.file_id:
                file_info = await manager.get_file_info(params.file_id)
                if file_info is None:
                    return ErrorResponse(
                        message=f"File not found: {file_id}",
                        message=f"File not found: {params.file_id}",
                        session_id=session_id,
                    )
                target_file_id = file_id
                target_file_id = params.file_id
            else:
                # path is guaranteed to be non-None here due to the check above
                assert path is not None
                file_info = await manager.get_file_info_by_path(path)
                assert params.path is not None
                file_info = await manager.get_file_info_by_path(params.path)
                if file_info is None:
                    return ErrorResponse(
                        message=f"File not found at path: {path}",
                        message=f"File not found at path: {params.path}",
                        session_id=session_id,
                    )
                target_file_id = file_info.id
@@ -309,7 +366,7 @@ class ReadWorkspaceFileTool(BaseTool):
            is_text_file = self._is_text_mime_type(file_info.mimeType)

            # Return inline content for small text files (unless force_download_url)
            if is_small_file and is_text_file and not force_download_url:
            if is_small_file and is_text_file and not params.force_download_url:
                content = await manager.read_file_by_id(target_file_id)
                content_b64 = base64.b64encode(content).decode("utf-8")

@@ -429,8 +486,9 @@ class WriteWorkspaceFileTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
        **kwargs: Any,
    ) -> ToolResponseBase:
        params = WriteWorkspaceFileInput(**kwargs)
        session_id = session.session_id

        if not user_id:
@@ -439,19 +497,13 @@ class WriteWorkspaceFileTool(BaseTool):
                session_id=session_id,
            )

        filename: str = kwargs.get("filename", "")
        content_b64: str = kwargs.get("content_base64", "")
        path: Optional[str] = kwargs.get("path")
        mime_type: Optional[str] = kwargs.get("mime_type")
        overwrite: bool = kwargs.get("overwrite", False)

        if not filename:
        if not params.filename:
            return ErrorResponse(
                message="Please provide a filename",
                session_id=session_id,
            )

        if not content_b64:
        if not params.content_base64:
            return ErrorResponse(
                message="Please provide content_base64",
                session_id=session_id,
@@ -459,7 +511,7 @@ class WriteWorkspaceFileTool(BaseTool):

        # Decode content
        try:
            content = base64.b64decode(content_b64)
            content = base64.b64decode(params.content_base64)
        except Exception:
            return ErrorResponse(
                message="Invalid base64-encoded content",
@@ -476,7 +528,7 @@ class WriteWorkspaceFileTool(BaseTool):

        try:
            # Virus scan
            await scan_content_safe(content, filename=filename)
            await scan_content_safe(content, filename=params.filename)

            workspace = await get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
@@ -484,10 +536,10 @@ class WriteWorkspaceFileTool(BaseTool):

            file_record = await manager.write_file(
                content=content,
                filename=filename,
                path=path,
                mime_type=mime_type,
                overwrite=overwrite,
                filename=params.filename,
                path=params.path,
                mime_type=params.mime_type,
                overwrite=params.overwrite,
            )

            return WorkspaceWriteResponse(
@@ -557,8 +609,9 @@ class DeleteWorkspaceFileTool(BaseTool):
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
        **kwargs: Any,
    ) -> ToolResponseBase:
        params = DeleteWorkspaceFileInput(**kwargs)
        session_id = session.session_id

        if not user_id:
@@ -567,10 +620,7 @@ class DeleteWorkspaceFileTool(BaseTool):
                session_id=session_id,
            )

        file_id: Optional[str] = kwargs.get("file_id")
        path: Optional[str] = kwargs.get("path")

        if not file_id and not path:
        if not params.file_id and not params.path:
            return ErrorResponse(
                message="Please provide either file_id or path",
                session_id=session_id,
@@ -583,15 +633,15 @@ class DeleteWorkspaceFileTool(BaseTool):

            # Determine the file_id to delete
            target_file_id: str
            if file_id:
                target_file_id = file_id
            if params.file_id:
                target_file_id = params.file_id
            else:
                # path is guaranteed to be non-None here due to the check above
                assert path is not None
                file_info = await manager.get_file_info_by_path(path)
                assert params.path is not None
                file_info = await manager.get_file_info_by_path(params.path)
                if file_info is None:
                    return ErrorResponse(
                        message=f"File not found at path: {path}",
                        message=f"File not found at path: {params.path}",
                        session_id=session_id,
                    )
                target_file_id = file_info.id
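The workspace input models treat file_id, path and mime_type as optional: their validators trim strings but return None for anything else, so "not provided" stays distinguishable from an empty string. A hypothetical ExampleReadFileInput illustrating this:

```python
from typing import Any

from pydantic import BaseModel, field_validator


class ExampleReadFileInput(BaseModel):
    """Hypothetical model mirroring ReadWorkspaceFileInput."""

    file_id: str | None = None
    path: str | None = None
    force_download_url: bool = False

    @field_validator("file_id", "path", mode="before")
    @classmethod
    def strip_strings(cls, v: Any) -> str | None:
        # Optional identifiers stay None when absent or non-string; strings are trimmed.
        return v.strip() if isinstance(v, str) else None


params = ExampleReadFileInput(path="  reports/summary.md  ")
assert params.file_id is None and params.path == "reports/summary.md"
```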