Compare commits

..

2 Commits

Author SHA1 Message Date
Otto
7f7a7067ec refactor(copilot): use Pydantic models and match/case in customize_agent
Addresses review feedback from ntindle:

1. Use typed parameters instead of kwargs.get():
   - Added CustomizeAgentInput Pydantic model with field_validator for stripping strings
   - Tool now uses params = CustomizeAgentInput(**kwargs) pattern

2. Use match/case for cleaner pattern matching:
   - Extracted response handling to _handle_customization_result method
   - Uses match result_type: case 'error' | 'clarifying_questions' | _

3. Improved code organization:
   - Split monolithic _execute into smaller focused methods
   - _handle_customization_result for response type handling
   - _save_or_preview_agent for final save/preview logic
2026-02-04 08:53:02 +00:00
Krzysztof Czerwinski
c026485023 feat(frontend): Disable auto-opening wallet (#11961)
<!-- Clearly explain the need for these changes: -->

### Changes 🏗️

- Disable auto-opening Wallet for first time user and on credit increase
- Remove no longer needed `lastSeenCredits` state and storage

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Wallet doesn't open automatically
2026-02-04 06:11:41 +00:00
12 changed files with 151 additions and 352 deletions

View File

@@ -1,15 +1,12 @@
import asyncio
import logging
import time
import uuid as uuid_module
from asyncio import CancelledError
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, cast
import openai
from backend.util.prompt import compress_context
if TYPE_CHECKING:
from backend.util.prompt import CompressResult
@@ -470,6 +467,8 @@ async def stream_chat_completion(
should_retry = False
# Generate unique IDs for AI SDK protocol
import uuid as uuid_module
message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4())
@@ -827,6 +826,10 @@ async def _manage_context_window(
Returns:
CompressResult with compacted messages and metadata
"""
import openai
from backend.util.prompt import compress_context
# Convert messages to dict format
messages_dict = []
for msg in messages:
@@ -1137,6 +1140,8 @@ async def _yield_tool_call(
KeyError: If expected tool call fields are missing
TypeError: If tool call structure is invalid
"""
import uuid as uuid_module
tool_name = tool_calls[yield_idx]["function"]["name"]
tool_call_id = tool_calls[yield_idx]["id"]
@@ -1757,6 +1762,8 @@ async def _generate_llm_continuation_with_streaming(
after a tool result is saved. Chunks are published to the stream registry
so reconnecting clients can receive them.
"""
import uuid as uuid_module
try:
# Load fresh session from DB (bypass cache to get the updated tool result)
await invalidate_session_cache(session_id)
@@ -1792,6 +1799,10 @@ async def _generate_llm_continuation_with_streaming(
extra_body["session_id"] = session_id[:128]
# Make streaming LLM call (no tools - just text response)
from typing import cast
from openai.types.chat import ChatCompletionMessageParam
# Generate unique IDs for AI SDK protocol
message_id = str(uuid_module.uuid4())
text_block_id = str(uuid_module.uuid4())

View File

@@ -3,8 +3,6 @@
import logging
from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
@@ -30,26 +28,6 @@ from .models import (
logger = logging.getLogger(__name__)
class CreateAgentInput(BaseModel):
"""Input parameters for the create_agent tool."""
description: str = ""
context: str = ""
save: bool = True
# Internal async processing params (passed by long-running tool handler)
_operation_id: str | None = None
_task_id: str | None = None
@field_validator("description", "context", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str:
"""Strip whitespace from string fields."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class Config:
extra = "allow" # Allow _operation_id, _task_id from kwargs
class CreateAgentTool(BaseTool):
"""Tool for creating agents from natural language descriptions."""
@@ -107,7 +85,7 @@ class CreateAgentTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
**kwargs,
) -> ToolResponseBase:
"""Execute the create_agent tool.
@@ -116,14 +94,16 @@ class CreateAgentTool(BaseTool):
2. Generate agent JSON (external service handles fixing and validation)
3. Preview or save based on the save parameter
"""
params = CreateAgentInput(**kwargs)
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not params.description:
if not description:
return ErrorResponse(
message="Please provide a description of what the agent should do.",
error="Missing description parameter",
@@ -135,7 +115,7 @@ class CreateAgentTool(BaseTool):
try:
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=params.description,
search_query=description,
include_marketplace=True,
)
logger.debug(
@@ -146,7 +126,7 @@ class CreateAgentTool(BaseTool):
try:
decomposition_result = await decompose_goal(
params.description, params.context, library_agents
description, context, library_agents
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
@@ -162,7 +142,7 @@ class CreateAgentTool(BaseTool):
return ErrorResponse(
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
error="decomposition_failed",
details={"description": params.description[:100]},
details={"description": description[:100]},
session_id=session_id,
)
@@ -178,7 +158,7 @@ class CreateAgentTool(BaseTool):
message=user_message,
error=f"decomposition_failed:{error_type}",
details={
"description": params.description[:100],
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
@@ -264,7 +244,7 @@ class CreateAgentTool(BaseTool):
return ErrorResponse(
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
error="generation_failed",
details={"description": params.description[:100]},
details={"description": description[:100]},
session_id=session_id,
)
@@ -286,7 +266,7 @@ class CreateAgentTool(BaseTool):
message=user_message,
error=f"generation_failed:{error_type}",
details={
"description": params.description[:100],
"description": description[:100],
"service_error": error_msg,
"error_type": error_type,
},
@@ -311,7 +291,7 @@ class CreateAgentTool(BaseTool):
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
if not params.save:
if not save:
return AgentPreviewResponse(
message=(
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "

View File

@@ -139,7 +139,7 @@ class CustomizeAgentTool(BaseTool):
)
# Parse agent_id in format "creator/slug"
parts = [p.strip() for p in params.agent_id.split("/")]
parts = params.agent_id.split("/")
if len(parts) != 2 or not parts[0] or not parts[1]:
return ErrorResponse(
message=(

View File

@@ -3,8 +3,6 @@
import logging
from typing import Any
from pydantic import BaseModel, ConfigDict, field_validator
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
@@ -29,20 +27,6 @@ from .models import (
logger = logging.getLogger(__name__)
class EditAgentInput(BaseModel):
model_config = ConfigDict(extra="allow")
agent_id: str = ""
changes: str = ""
context: str = ""
save: bool = True
@field_validator("agent_id", "changes", "context", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str:
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class EditAgentTool(BaseTool):
"""Tool for editing existing agents using natural language."""
@@ -106,7 +90,7 @@ class EditAgentTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
**kwargs,
) -> ToolResponseBase:
"""Execute the edit_agent tool.
@@ -115,32 +99,35 @@ class EditAgentTool(BaseTool):
2. Generate updated agent (external service handles fixing and validation)
3. Preview or save based on the save parameter
"""
params = EditAgentInput(**kwargs)
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
context = kwargs.get("context", "")
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not params.agent_id:
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
error="Missing agent_id parameter",
session_id=session_id,
)
if not params.changes:
if not changes:
return ErrorResponse(
message="Please describe what changes you want to make.",
error="Missing changes parameter",
session_id=session_id,
)
current_agent = await get_agent_as_json(params.agent_id, user_id)
current_agent = await get_agent_as_json(agent_id, user_id)
if current_agent is None:
return ErrorResponse(
message=f"Could not find agent '{params.agent_id}' in your library.",
message=f"Could not find agent with ID '{agent_id}' in your library.",
error="agent_not_found",
session_id=session_id,
)
@@ -151,7 +138,7 @@ class EditAgentTool(BaseTool):
graph_id = current_agent.get("id")
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=params.changes,
search_query=changes,
exclude_graph_id=graph_id,
include_marketplace=True,
)
@@ -161,11 +148,9 @@ class EditAgentTool(BaseTool):
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
update_request = params.changes
if params.context:
update_request = (
f"{params.changes}\n\nAdditional context:\n{params.context}"
)
update_request = changes
if context:
update_request = f"{changes}\n\nAdditional context:\n{context}"
try:
result = await generate_agent_patch(
@@ -189,7 +174,7 @@ class EditAgentTool(BaseTool):
return ErrorResponse(
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
error="update_generation_failed",
details={"agent_id": params.agent_id, "changes": params.changes[:100]},
details={"agent_id": agent_id, "changes": changes[:100]},
session_id=session_id,
)
@@ -221,8 +206,8 @@ class EditAgentTool(BaseTool):
message=user_message,
error=f"update_generation_failed:{error_type}",
details={
"agent_id": params.agent_id,
"changes": params.changes[:100],
"agent_id": agent_id,
"changes": changes[:100],
"service_error": error_msg,
"error_type": error_type,
},
@@ -254,7 +239,7 @@ class EditAgentTool(BaseTool):
node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", []))
if not params.save:
if not save:
return AgentPreviewResponse(
message=(
f"I've updated the agent. "

View File

@@ -2,8 +2,6 @@
from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
@@ -11,18 +9,6 @@ from .base import BaseTool
from .models import ToolResponseBase
class FindAgentInput(BaseModel):
"""Input parameters for the find_agent tool."""
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from query."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class FindAgentTool(BaseTool):
"""Tool for discovering agents from the marketplace."""
@@ -50,11 +36,10 @@ class FindAgentTool(BaseTool):
}
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs: Any
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
params = FindAgentInput(**kwargs)
return await search_agents(
query=params.query,
query=kwargs.get("query", "").strip(),
source="marketplace",
session_id=session.session_id,
user_id=user_id,

View File

@@ -2,7 +2,6 @@ import logging
from typing import Any
from prisma.enums import ContentType
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
@@ -19,18 +18,6 @@ from backend.data.block import get_block
logger = logging.getLogger(__name__)
class FindBlockInput(BaseModel):
"""Input parameters for the find_block tool."""
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from query."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class FindBlockTool(BaseTool):
"""Tool for searching available blocks."""
@@ -72,24 +59,24 @@ class FindBlockTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
**kwargs,
) -> ToolResponseBase:
"""Search for blocks matching the query.
Args:
user_id: User ID (required)
session: Chat session
**kwargs: Tool parameters
query: Search query
Returns:
BlockListResponse: List of matching blocks
NoResultsResponse: No blocks found
ErrorResponse: Error message
"""
params = FindBlockInput(**kwargs)
query = kwargs.get("query", "").strip()
session_id = session.session_id
if not params.query:
if not query:
return ErrorResponse(
message="Please provide a search query",
session_id=session_id,
@@ -98,7 +85,7 @@ class FindBlockTool(BaseTool):
try:
# Search for blocks using hybrid search
results, total = await unified_hybrid_search(
query=params.query,
query=query,
content_types=[ContentType.BLOCK],
page=1,
page_size=10,
@@ -106,7 +93,7 @@ class FindBlockTool(BaseTool):
if not results:
return NoResultsResponse(
message=f"No blocks found for '{params.query}'",
message=f"No blocks found for '{query}'",
suggestions=[
"Try broader keywords like 'email', 'http', 'text', 'ai'",
"Check spelling of technical terms",
@@ -178,7 +165,7 @@ class FindBlockTool(BaseTool):
if not blocks:
return NoResultsResponse(
message=f"No blocks found for '{params.query}'",
message=f"No blocks found for '{query}'",
suggestions=[
"Try broader keywords like 'email', 'http', 'text', 'ai'",
],
@@ -187,13 +174,13 @@ class FindBlockTool(BaseTool):
return BlockListResponse(
message=(
f"Found {len(blocks)} block(s) matching '{params.query}'. "
f"Found {len(blocks)} block(s) matching '{query}'. "
"To execute a block, use run_block with the block's 'id' field "
"and provide 'input_data' matching the block's input_schema."
),
blocks=blocks,
count=len(blocks),
query=params.query,
query=query,
session_id=session_id,
)

View File

@@ -2,8 +2,6 @@
from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
@@ -11,15 +9,6 @@ from .base import BaseTool
from .models import ToolResponseBase
class FindLibraryAgentInput(BaseModel):
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class FindLibraryAgentTool(BaseTool):
"""Tool for searching agents in the user's library."""
@@ -53,11 +42,10 @@ class FindLibraryAgentTool(BaseTool):
return True
async def _execute(
self, user_id: str | None, session: ChatSession, **kwargs: Any
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
params = FindLibraryAgentInput(**kwargs)
return await search_agents(
query=params.query,
query=kwargs.get("query", "").strip(),
source="library",
session_id=session.session_id,
user_id=user_id,

View File

@@ -4,8 +4,6 @@ import logging
from pathlib import Path
from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
@@ -20,18 +18,6 @@ logger = logging.getLogger(__name__)
DOCS_BASE_URL = "https://docs.agpt.co"
class GetDocPageInput(BaseModel):
"""Input parameters for the get_doc_page tool."""
path: str = ""
@field_validator("path", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from path."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class GetDocPageTool(BaseTool):
"""Tool for fetching full content of a documentation page."""
@@ -89,23 +75,23 @@ class GetDocPageTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
**kwargs,
) -> ToolResponseBase:
"""Fetch full content of a documentation page.
Args:
user_id: User ID (not required for docs)
session: Chat session
**kwargs: Tool parameters
path: Path to the documentation file
Returns:
DocPageResponse: Full document content
ErrorResponse: Error message
"""
params = GetDocPageInput(**kwargs)
path = kwargs.get("path", "").strip()
session_id = session.session_id if session else None
if not params.path:
if not path:
return ErrorResponse(
message="Please provide a documentation path.",
error="Missing path parameter",
@@ -113,7 +99,7 @@ class GetDocPageTool(BaseTool):
)
# Sanitize path to prevent directory traversal
if ".." in params.path or params.path.startswith("/"):
if ".." in path or path.startswith("/"):
return ErrorResponse(
message="Invalid documentation path.",
error="invalid_path",
@@ -121,11 +107,11 @@ class GetDocPageTool(BaseTool):
)
docs_root = self._get_docs_root()
full_path = docs_root / params.path
full_path = docs_root / path
if not full_path.exists():
return ErrorResponse(
message=f"Documentation page not found: {params.path}",
message=f"Documentation page not found: {path}",
error="not_found",
session_id=session_id,
)
@@ -142,19 +128,19 @@ class GetDocPageTool(BaseTool):
try:
content = full_path.read_text(encoding="utf-8")
title = self._extract_title(content, params.path)
title = self._extract_title(content, path)
return DocPageResponse(
message=f"Retrieved documentation page: {title}",
title=title,
path=params.path,
path=path,
content=content,
doc_url=self._make_doc_url(params.path),
doc_url=self._make_doc_url(path),
session_id=session_id,
)
except Exception as e:
logger.error(f"Failed to read documentation page {params.path}: {e}")
logger.error(f"Failed to read documentation page {path}: {e}")
return ErrorResponse(
message=f"Failed to read documentation page: {str(e)}",
error="read_failed",

View File

@@ -5,7 +5,6 @@ import uuid
from collections import defaultdict
from typing import Any
from pydantic import BaseModel, field_validator
from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession
@@ -30,25 +29,6 @@ from .utils import build_missing_credentials_from_field_info
logger = logging.getLogger(__name__)
class RunBlockInput(BaseModel):
"""Input parameters for the run_block tool."""
block_id: str = ""
input_data: dict[str, Any] = {}
@field_validator("block_id", mode="before")
@classmethod
def strip_block_id(cls, v: Any) -> str:
"""Strip whitespace from block_id."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
@field_validator("input_data", mode="before")
@classmethod
def ensure_dict(cls, v: Any) -> dict[str, Any]:
"""Ensure input_data is a dict."""
return v if isinstance(v, dict) else {}
class RunBlockTool(BaseTool):
"""Tool for executing a block and returning its outputs."""
@@ -182,29 +162,37 @@ class RunBlockTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
**kwargs,
) -> ToolResponseBase:
"""Execute a block with the given input data.
Args:
user_id: User ID (required)
session: Chat session
**kwargs: Tool parameters
block_id: Block UUID to execute
input_data: Input values for the block
Returns:
BlockOutputResponse: Block execution outputs
SetupRequirementsResponse: Missing credentials
ErrorResponse: Error message
"""
params = RunBlockInput(**kwargs)
block_id = kwargs.get("block_id", "").strip()
input_data = kwargs.get("input_data", {})
session_id = session.session_id
if not params.block_id:
if not block_id:
return ErrorResponse(
message="Please provide a block_id",
session_id=session_id,
)
if not isinstance(input_data, dict):
return ErrorResponse(
message="input_data must be an object",
session_id=session_id,
)
if not user_id:
return ErrorResponse(
message="Authentication required",
@@ -212,25 +200,23 @@ class RunBlockTool(BaseTool):
)
# Get the block
block = get_block(params.block_id)
block = get_block(block_id)
if not block:
return ErrorResponse(
message=f"Block '{params.block_id}' not found",
message=f"Block '{block_id}' not found",
session_id=session_id,
)
if block.disabled:
return ErrorResponse(
message=f"Block '{params.block_id}' is disabled",
message=f"Block '{block_id}' is disabled",
session_id=session_id,
)
logger.info(
f"Executing block {block.name} ({params.block_id}) for user {user_id}"
)
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block, params.input_data
user_id, block, input_data
)
if missing_credentials:
@@ -248,7 +234,7 @@ class RunBlockTool(BaseTool):
),
session_id=session_id,
setup_info=SetupInfo(
agent_id=params.block_id,
agent_id=block_id,
agent_name=block.name,
user_readiness=UserReadiness(
has_all_credentials=False,
@@ -277,7 +263,7 @@ class RunBlockTool(BaseTool):
# - node_exec_id = unique per block execution
synthetic_graph_id = f"copilot-session-{session.session_id}"
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
synthetic_node_id = f"copilot-node-{params.block_id}"
synthetic_node_id = f"copilot-node-{block_id}"
synthetic_node_exec_id = (
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
)
@@ -312,8 +298,8 @@ class RunBlockTool(BaseTool):
for field_name, cred_meta in matched_credentials.items():
# Inject metadata into input_data (for validation)
if field_name not in params.input_data:
params.input_data[field_name] = cred_meta.model_dump()
if field_name not in input_data:
input_data[field_name] = cred_meta.model_dump()
# Fetch actual credentials and pass as kwargs (for execution)
actual_credentials = await creds_manager.get(
@@ -330,14 +316,14 @@ class RunBlockTool(BaseTool):
# Execute the block and collect outputs
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in block.execute(
params.input_data,
input_data,
**exec_kwargs,
):
outputs[output_name].append(output_data)
return BlockOutputResponse(
message=f"Block '{block.name}' executed successfully",
block_id=params.block_id,
block_id=block_id,
block_name=block.name,
outputs=dict(outputs),
success=True,

View File

@@ -4,7 +4,6 @@ import logging
from typing import Any
from prisma.enums import ContentType
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
@@ -29,18 +28,6 @@ MAX_RESULTS = 5
SNIPPET_LENGTH = 200
class SearchDocsInput(BaseModel):
"""Input parameters for the search_docs tool."""
query: str = ""
@field_validator("query", mode="before")
@classmethod
def strip_string(cls, v: Any) -> str:
"""Strip whitespace from query."""
return v.strip() if isinstance(v, str) else (v if v is not None else "")
class SearchDocsTool(BaseTool):
"""Tool for searching AutoGPT platform documentation."""
@@ -104,24 +91,24 @@ class SearchDocsTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
**kwargs,
) -> ToolResponseBase:
"""Search documentation and return relevant sections.
Args:
user_id: User ID (not required for docs)
session: Chat session
**kwargs: Tool parameters
query: Search query
Returns:
DocSearchResultsResponse: List of matching documentation sections
NoResultsResponse: No results found
ErrorResponse: Error message
"""
params = SearchDocsInput(**kwargs)
query = kwargs.get("query", "").strip()
session_id = session.session_id if session else None
if not params.query:
if not query:
return ErrorResponse(
message="Please provide a search query.",
error="Missing query parameter",
@@ -131,7 +118,7 @@ class SearchDocsTool(BaseTool):
try:
# Search using hybrid search for DOCUMENTATION content type only
results, total = await unified_hybrid_search(
query=params.query,
query=query,
content_types=[ContentType.DOCUMENTATION],
page=1,
page_size=MAX_RESULTS * 2, # Fetch extra for deduplication
@@ -140,7 +127,7 @@ class SearchDocsTool(BaseTool):
if not results:
return NoResultsResponse(
message=f"No documentation found for '{params.query}'.",
message=f"No documentation found for '{query}'.",
suggestions=[
"Try different keywords",
"Use more general terms",
@@ -175,7 +162,7 @@ class SearchDocsTool(BaseTool):
if not deduplicated:
return NoResultsResponse(
message=f"No documentation found for '{params.query}'.",
message=f"No documentation found for '{query}'.",
suggestions=[
"Try different keywords",
"Use more general terms",
@@ -208,7 +195,7 @@ class SearchDocsTool(BaseTool):
message=f"Found {len(doc_results)} relevant documentation sections.",
results=doc_results,
count=len(doc_results),
query=params.query,
query=query,
session_id=session_id,
)

View File

@@ -2,9 +2,9 @@
import base64
import logging
from typing import Any
from typing import Any, Optional
from pydantic import BaseModel, field_validator
from pydantic import BaseModel
from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
@@ -78,65 +78,6 @@ class WorkspaceDeleteResponse(ToolResponseBase):
success: bool
# Input models for workspace tools
class ListWorkspaceFilesInput(BaseModel):
"""Input parameters for list_workspace_files tool."""
path_prefix: str | None = None
limit: int = 50
include_all_sessions: bool = False
@field_validator("path_prefix", mode="before")
@classmethod
def strip_path(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class ReadWorkspaceFileInput(BaseModel):
"""Input parameters for read_workspace_file tool."""
file_id: str | None = None
path: str | None = None
force_download_url: bool = False
@field_validator("file_id", "path", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class WriteWorkspaceFileInput(BaseModel):
"""Input parameters for write_workspace_file tool."""
filename: str = ""
content_base64: str = ""
path: str | None = None
mime_type: str | None = None
overwrite: bool = False
@field_validator("filename", "content_base64", mode="before")
@classmethod
def strip_required(cls, v: Any) -> str:
return v.strip() if isinstance(v, str) else (v if v is not None else "")
@field_validator("path", "mime_type", mode="before")
@classmethod
def strip_optional(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class DeleteWorkspaceFileInput(BaseModel):
"""Input parameters for delete_workspace_file tool."""
file_id: str | None = None
path: str | None = None
@field_validator("file_id", "path", mode="before")
@classmethod
def strip_strings(cls, v: Any) -> str | None:
return v.strip() if isinstance(v, str) else None
class ListWorkspaceFilesTool(BaseTool):
"""Tool for listing files in user's workspace."""
@@ -190,9 +131,8 @@ class ListWorkspaceFilesTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
**kwargs,
) -> ToolResponseBase:
params = ListWorkspaceFilesInput(**kwargs)
session_id = session.session_id
if not user_id:
@@ -201,7 +141,9 @@ class ListWorkspaceFilesTool(BaseTool):
session_id=session_id,
)
limit = min(params.limit, 100)
path_prefix: Optional[str] = kwargs.get("path_prefix")
limit = min(kwargs.get("limit", 50), 100)
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
workspace = await get_or_create_workspace(user_id)
@@ -209,13 +151,13 @@ class ListWorkspaceFilesTool(BaseTool):
manager = WorkspaceManager(user_id, workspace.id, session_id)
files = await manager.list_files(
path=params.path_prefix,
path=path_prefix,
limit=limit,
include_all_sessions=params.include_all_sessions,
include_all_sessions=include_all_sessions,
)
total = await manager.get_file_count(
path=params.path_prefix,
include_all_sessions=params.include_all_sessions,
path=path_prefix,
include_all_sessions=include_all_sessions,
)
file_infos = [
@@ -229,9 +171,7 @@ class ListWorkspaceFilesTool(BaseTool):
for f in files
]
scope_msg = (
"all sessions" if params.include_all_sessions else "current session"
)
scope_msg = "all sessions" if include_all_sessions else "current session"
return WorkspaceFileListResponse(
files=file_infos,
total_count=total,
@@ -319,9 +259,8 @@ class ReadWorkspaceFileTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
**kwargs,
) -> ToolResponseBase:
params = ReadWorkspaceFileInput(**kwargs)
session_id = session.session_id
if not user_id:
@@ -330,7 +269,11 @@ class ReadWorkspaceFileTool(BaseTool):
session_id=session_id,
)
if not params.file_id and not params.path:
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
force_download_url: bool = kwargs.get("force_download_url", False)
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
@@ -342,21 +285,21 @@ class ReadWorkspaceFileTool(BaseTool):
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Get file info
if params.file_id:
file_info = await manager.get_file_info(params.file_id)
if file_id:
file_info = await manager.get_file_info(file_id)
if file_info is None:
return ErrorResponse(
message=f"File not found: {params.file_id}",
message=f"File not found: {file_id}",
session_id=session_id,
)
target_file_id = params.file_id
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert params.path is not None
file_info = await manager.get_file_info_by_path(params.path)
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {params.path}",
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
@@ -366,7 +309,7 @@ class ReadWorkspaceFileTool(BaseTool):
is_text_file = self._is_text_mime_type(file_info.mimeType)
# Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not params.force_download_url:
if is_small_file and is_text_file and not force_download_url:
content = await manager.read_file_by_id(target_file_id)
content_b64 = base64.b64encode(content).decode("utf-8")
@@ -486,9 +429,8 @@ class WriteWorkspaceFileTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
**kwargs,
) -> ToolResponseBase:
params = WriteWorkspaceFileInput(**kwargs)
session_id = session.session_id
if not user_id:
@@ -497,13 +439,19 @@ class WriteWorkspaceFileTool(BaseTool):
session_id=session_id,
)
if not params.filename:
filename: str = kwargs.get("filename", "")
content_b64: str = kwargs.get("content_base64", "")
path: Optional[str] = kwargs.get("path")
mime_type: Optional[str] = kwargs.get("mime_type")
overwrite: bool = kwargs.get("overwrite", False)
if not filename:
return ErrorResponse(
message="Please provide a filename",
session_id=session_id,
)
if not params.content_base64:
if not content_b64:
return ErrorResponse(
message="Please provide content_base64",
session_id=session_id,
@@ -511,7 +459,7 @@ class WriteWorkspaceFileTool(BaseTool):
# Decode content
try:
content = base64.b64decode(params.content_base64)
content = base64.b64decode(content_b64)
except Exception:
return ErrorResponse(
message="Invalid base64-encoded content",
@@ -528,7 +476,7 @@ class WriteWorkspaceFileTool(BaseTool):
try:
# Virus scan
await scan_content_safe(content, filename=params.filename)
await scan_content_safe(content, filename=filename)
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
@@ -536,10 +484,10 @@ class WriteWorkspaceFileTool(BaseTool):
file_record = await manager.write_file(
content=content,
filename=params.filename,
path=params.path,
mime_type=params.mime_type,
overwrite=params.overwrite,
filename=filename,
path=path,
mime_type=mime_type,
overwrite=overwrite,
)
return WorkspaceWriteResponse(
@@ -609,9 +557,8 @@ class DeleteWorkspaceFileTool(BaseTool):
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
**kwargs,
) -> ToolResponseBase:
params = DeleteWorkspaceFileInput(**kwargs)
session_id = session.session_id
if not user_id:
@@ -620,7 +567,10 @@ class DeleteWorkspaceFileTool(BaseTool):
session_id=session_id,
)
if not params.file_id and not params.path:
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
@@ -633,15 +583,15 @@ class DeleteWorkspaceFileTool(BaseTool):
# Determine the file_id to delete
target_file_id: str
if params.file_id:
target_file_id = params.file_id
if file_id:
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert params.path is not None
file_info = await manager.get_file_info_by_path(params.path)
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {params.path}",
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id

View File

@@ -15,7 +15,6 @@ import {
import { cn } from "@/lib/utils";
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { storage, Key as StorageKey } from "@/services/storage/local-storage";
import { WalletIcon } from "@phosphor-icons/react";
import { PopoverClose } from "@radix-ui/react-popover";
import { X } from "lucide-react";
@@ -175,7 +174,6 @@ export function Wallet() {
const [prevCredits, setPrevCredits] = useState<number | null>(credits);
const [flash, setFlash] = useState(false);
const [walletOpen, setWalletOpen] = useState(false);
const [lastSeenCredits, setLastSeenCredits] = useState<number | null>(null);
const totalCount = useMemo(() => {
return groups.reduce((acc, group) => acc + group.tasks.length, 0);
@@ -200,38 +198,6 @@ export function Wallet() {
setCompletedCount(completed);
}, [groups, state?.completedSteps]);
// Load last seen credits from localStorage once on mount
useEffect(() => {
const stored = storage.get(StorageKey.WALLET_LAST_SEEN_CREDITS);
if (stored !== undefined && stored !== null) {
const parsed = parseFloat(stored);
if (!Number.isNaN(parsed)) setLastSeenCredits(parsed);
else setLastSeenCredits(0);
} else {
setLastSeenCredits(0);
}
}, []);
// Auto-open once if never shown, otherwise open only when credits increase beyond last seen
useEffect(() => {
if (typeof credits !== "number") return;
// Open once for first-time users
if (state && state.walletShown === false) {
requestAnimationFrame(() => setWalletOpen(true));
// Mark as shown so it won't reopen on every reload
updateState({ walletShown: true });
return;
}
// Open if user gained more credits than last acknowledged
if (
lastSeenCredits !== null &&
credits > lastSeenCredits &&
walletOpen === false
) {
requestAnimationFrame(() => setWalletOpen(true));
}
}, [credits, lastSeenCredits, state?.walletShown, updateState, walletOpen]);
const onWalletOpen = useCallback(async () => {
if (!state?.walletShown) {
updateState({ walletShown: true });
@@ -324,19 +290,7 @@ export function Wallet() {
if (credits === null || !state) return null;
return (
<Popover
open={walletOpen}
onOpenChange={(open) => {
setWalletOpen(open);
if (!open) {
// Persist the latest acknowledged credits so we only auto-open on future gains
if (typeof credits === "number") {
storage.set(StorageKey.WALLET_LAST_SEEN_CREDITS, String(credits));
setLastSeenCredits(credits);
}
}
}}
>
<Popover open={walletOpen} onOpenChange={(open) => setWalletOpen(open)}>
<PopoverTrigger asChild>
<div className="relative inline-block">
<button