mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-05 12:25:04 -05:00
Compare commits
4 Commits
dev
...
fix/code-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a093d57ed2 | ||
|
|
6692f39cbd | ||
|
|
aeba28266c | ||
|
|
6d8c83c039 |
3
autogpt_platform/backend/.gitignore
vendored
3
autogpt_platform/backend/.gitignore
vendored
@@ -19,6 +19,3 @@ load-tests/*.json
|
|||||||
load-tests/*.log
|
load-tests/*.log
|
||||||
load-tests/node_modules/*
|
load-tests/node_modules/*
|
||||||
migrations/*/rollback*.sql
|
migrations/*/rollback*.sql
|
||||||
|
|
||||||
# Workspace files
|
|
||||||
workspaces/
|
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
import uuid as uuid_module
|
||||||
from asyncio import CancelledError
|
from asyncio import CancelledError
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
|
from backend.util.prompt import compress_context
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from backend.util.prompt import CompressResult
|
from backend.util.prompt import CompressResult
|
||||||
|
|
||||||
@@ -33,7 +36,7 @@ from backend.data.understanding import (
|
|||||||
get_business_understanding,
|
get_business_understanding,
|
||||||
)
|
)
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
from backend.util.settings import AppEnvironment, Settings
|
from backend.util.settings import Settings
|
||||||
|
|
||||||
from . import db as chat_db
|
from . import db as chat_db
|
||||||
from . import stream_registry
|
from . import stream_registry
|
||||||
@@ -222,18 +225,8 @@ async def _get_system_prompt_template(context: str) -> str:
|
|||||||
try:
|
try:
|
||||||
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
# cache_ttl_seconds=0 disables SDK caching to always get the latest prompt
|
||||||
# Use asyncio.to_thread to avoid blocking the event loop
|
# Use asyncio.to_thread to avoid blocking the event loop
|
||||||
# In non-production environments, fetch the latest prompt version
|
|
||||||
# instead of the production-labeled version for easier testing
|
|
||||||
label = (
|
|
||||||
None
|
|
||||||
if settings.config.app_env == AppEnvironment.PRODUCTION
|
|
||||||
else "latest"
|
|
||||||
)
|
|
||||||
prompt = await asyncio.to_thread(
|
prompt = await asyncio.to_thread(
|
||||||
langfuse.get_prompt,
|
langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0
|
||||||
config.langfuse_prompt_name,
|
|
||||||
label=label,
|
|
||||||
cache_ttl_seconds=0,
|
|
||||||
)
|
)
|
||||||
return prompt.compile(users_information=context)
|
return prompt.compile(users_information=context)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -477,8 +470,6 @@ async def stream_chat_completion(
|
|||||||
should_retry = False
|
should_retry = False
|
||||||
|
|
||||||
# Generate unique IDs for AI SDK protocol
|
# Generate unique IDs for AI SDK protocol
|
||||||
import uuid as uuid_module
|
|
||||||
|
|
||||||
message_id = str(uuid_module.uuid4())
|
message_id = str(uuid_module.uuid4())
|
||||||
text_block_id = str(uuid_module.uuid4())
|
text_block_id = str(uuid_module.uuid4())
|
||||||
|
|
||||||
@@ -628,9 +619,6 @@ async def stream_chat_completion(
|
|||||||
total_tokens=chunk.totalTokens,
|
total_tokens=chunk.totalTokens,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif isinstance(chunk, StreamHeartbeat):
|
|
||||||
# Pass through heartbeat to keep SSE connection alive
|
|
||||||
yield chunk
|
|
||||||
else:
|
else:
|
||||||
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True)
|
||||||
|
|
||||||
@@ -839,10 +827,6 @@ async def _manage_context_window(
|
|||||||
Returns:
|
Returns:
|
||||||
CompressResult with compacted messages and metadata
|
CompressResult with compacted messages and metadata
|
||||||
"""
|
"""
|
||||||
import openai
|
|
||||||
|
|
||||||
from backend.util.prompt import compress_context
|
|
||||||
|
|
||||||
# Convert messages to dict format
|
# Convert messages to dict format
|
||||||
messages_dict = []
|
messages_dict = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
@@ -1153,8 +1137,6 @@ async def _yield_tool_call(
|
|||||||
KeyError: If expected tool call fields are missing
|
KeyError: If expected tool call fields are missing
|
||||||
TypeError: If tool call structure is invalid
|
TypeError: If tool call structure is invalid
|
||||||
"""
|
"""
|
||||||
import uuid as uuid_module
|
|
||||||
|
|
||||||
tool_name = tool_calls[yield_idx]["function"]["name"]
|
tool_name = tool_calls[yield_idx]["function"]["name"]
|
||||||
tool_call_id = tool_calls[yield_idx]["id"]
|
tool_call_id = tool_calls[yield_idx]["id"]
|
||||||
|
|
||||||
@@ -1775,8 +1757,6 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
after a tool result is saved. Chunks are published to the stream registry
|
after a tool result is saved. Chunks are published to the stream registry
|
||||||
so reconnecting clients can receive them.
|
so reconnecting clients can receive them.
|
||||||
"""
|
"""
|
||||||
import uuid as uuid_module
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load fresh session from DB (bypass cache to get the updated tool result)
|
# Load fresh session from DB (bypass cache to get the updated tool result)
|
||||||
await invalidate_session_cache(session_id)
|
await invalidate_session_cache(session_id)
|
||||||
@@ -1812,10 +1792,6 @@ async def _generate_llm_continuation_with_streaming(
|
|||||||
extra_body["session_id"] = session_id[:128]
|
extra_body["session_id"] = session_id[:128]
|
||||||
|
|
||||||
# Make streaming LLM call (no tools - just text response)
|
# Make streaming LLM call (no tools - just text response)
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
|
||||||
|
|
||||||
# Generate unique IDs for AI SDK protocol
|
# Generate unique IDs for AI SDK protocol
|
||||||
message_id = str(uuid_module.uuid4())
|
message_id = str(uuid_module.uuid4())
|
||||||
text_block_id = str(uuid_module.uuid4())
|
text_block_id = str(uuid_module.uuid4())
|
||||||
|
|||||||
@@ -7,7 +7,15 @@ from typing import Any, NotRequired, TypedDict
|
|||||||
|
|
||||||
from backend.api.features.library import db as library_db
|
from backend.api.features.library import db as library_db
|
||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
|
from backend.data.graph import (
|
||||||
|
Graph,
|
||||||
|
Link,
|
||||||
|
Node,
|
||||||
|
create_graph,
|
||||||
|
get_graph,
|
||||||
|
get_graph_all_versions,
|
||||||
|
get_store_listed_graphs,
|
||||||
|
)
|
||||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||||
|
|
||||||
from .service import (
|
from .service import (
|
||||||
@@ -20,6 +28,8 @@ from .service import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
||||||
|
|
||||||
|
|
||||||
class ExecutionSummary(TypedDict):
|
class ExecutionSummary(TypedDict):
|
||||||
"""Summary of a single execution for quality assessment."""
|
"""Summary of a single execution for quality assessment."""
|
||||||
@@ -659,6 +669,45 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _reassign_node_ids(graph: Graph) -> None:
|
||||||
|
"""Reassign all node and link IDs to new UUIDs.
|
||||||
|
|
||||||
|
This is needed when creating a new version to avoid unique constraint violations.
|
||||||
|
"""
|
||||||
|
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||||
|
|
||||||
|
for node in graph.nodes:
|
||||||
|
node.id = id_map[node.id]
|
||||||
|
|
||||||
|
for link in graph.links:
|
||||||
|
link.id = str(uuid.uuid4())
|
||||||
|
if link.source_id in id_map:
|
||||||
|
link.source_id = id_map[link.source_id]
|
||||||
|
if link.sink_id in id_map:
|
||||||
|
link.sink_id = id_map[link.sink_id]
|
||||||
|
|
||||||
|
|
||||||
|
def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
|
||||||
|
"""Populate user_id in AgentExecutorBlock nodes.
|
||||||
|
|
||||||
|
The external agent generator creates AgentExecutorBlock nodes with empty user_id.
|
||||||
|
This function fills in the actual user_id so sub-agents run with correct permissions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_json: Agent JSON dict (modified in place)
|
||||||
|
user_id: User ID to set
|
||||||
|
"""
|
||||||
|
for node in agent_json.get("nodes", []):
|
||||||
|
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
|
||||||
|
input_default = node.get("input_default") or {}
|
||||||
|
if not input_default.get("user_id"):
|
||||||
|
input_default["user_id"] = user_id
|
||||||
|
node["input_default"] = input_default
|
||||||
|
logger.debug(
|
||||||
|
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def save_agent_to_library(
|
async def save_agent_to_library(
|
||||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||||
) -> tuple[Graph, Any]:
|
) -> tuple[Graph, Any]:
|
||||||
@@ -672,10 +721,35 @@ async def save_agent_to_library(
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of (created Graph, LibraryAgent)
|
Tuple of (created Graph, LibraryAgent)
|
||||||
"""
|
"""
|
||||||
|
# Populate user_id in AgentExecutorBlock nodes before conversion
|
||||||
|
_populate_agent_executor_user_ids(agent_json, user_id)
|
||||||
|
|
||||||
graph = json_to_graph(agent_json)
|
graph = json_to_graph(agent_json)
|
||||||
|
|
||||||
if is_update:
|
if is_update:
|
||||||
return await library_db.update_graph_in_library(graph, user_id)
|
if graph.id:
|
||||||
return await library_db.create_graph_in_library(graph, user_id)
|
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
||||||
|
if existing_versions:
|
||||||
|
latest_version = max(v.version for v in existing_versions)
|
||||||
|
graph.version = latest_version + 1
|
||||||
|
_reassign_node_ids(graph)
|
||||||
|
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
||||||
|
else:
|
||||||
|
graph.id = str(uuid.uuid4())
|
||||||
|
graph.version = 1
|
||||||
|
_reassign_node_ids(graph)
|
||||||
|
logger.info(f"Creating new agent with ID {graph.id}")
|
||||||
|
|
||||||
|
created_graph = await create_graph(graph, user_id)
|
||||||
|
|
||||||
|
library_agents = await library_db.create_library_agent(
|
||||||
|
graph=created_graph,
|
||||||
|
user_id=user_id,
|
||||||
|
sensitive_action_safe_mode=True,
|
||||||
|
create_library_agents_for_sub_graphs=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return created_graph, library_agents[0]
|
||||||
|
|
||||||
|
|
||||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||||
|
|||||||
@@ -206,9 +206,9 @@ async def search_agents(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
no_results_msg = (
|
no_results_msg = (
|
||||||
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
|
f"No agents found matching '{query}'. Try different keywords or browse the marketplace."
|
||||||
if source == "marketplace"
|
if source == "marketplace"
|
||||||
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
|
else f"No agents matching '{query}' found in your library."
|
||||||
)
|
)
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
message=no_results_msg, session_id=session_id, suggestions=suggestions
|
||||||
@@ -224,10 +224,10 @@ async def search_agents(
|
|||||||
message = (
|
message = (
|
||||||
"Now you have found some options for the user to choose from. "
|
"Now you have found some options for the user to choose from. "
|
||||||
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
|
||||||
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
|
"Please ask the user if they would like to use any of these agents."
|
||||||
if source == "marketplace"
|
if source == "marketplace"
|
||||||
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
else "Found agents in the user's library. You can provide a link to view an agent at: "
|
||||||
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
|
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute."
|
||||||
)
|
)
|
||||||
|
|
||||||
return AgentsFoundResponse(
|
return AgentsFoundResponse(
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
from .agent_generator import (
|
from .agent_generator import (
|
||||||
@@ -28,6 +30,26 @@ from .models import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CreateAgentInput(BaseModel):
|
||||||
|
"""Input parameters for the create_agent tool."""
|
||||||
|
|
||||||
|
description: str = ""
|
||||||
|
context: str = ""
|
||||||
|
save: bool = True
|
||||||
|
# Internal async processing params (passed by long-running tool handler)
|
||||||
|
_operation_id: str | None = None
|
||||||
|
_task_id: str | None = None
|
||||||
|
|
||||||
|
@field_validator("description", "context", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_strings(cls, v: Any) -> str:
|
||||||
|
"""Strip whitespace from string fields."""
|
||||||
|
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "allow" # Allow _operation_id, _task_id from kwargs
|
||||||
|
|
||||||
|
|
||||||
class CreateAgentTool(BaseTool):
|
class CreateAgentTool(BaseTool):
|
||||||
"""Tool for creating agents from natural language descriptions."""
|
"""Tool for creating agents from natural language descriptions."""
|
||||||
|
|
||||||
@@ -85,7 +107,7 @@ class CreateAgentTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
"""Execute the create_agent tool.
|
"""Execute the create_agent tool.
|
||||||
|
|
||||||
@@ -94,16 +116,14 @@ class CreateAgentTool(BaseTool):
|
|||||||
2. Generate agent JSON (external service handles fixing and validation)
|
2. Generate agent JSON (external service handles fixing and validation)
|
||||||
3. Preview or save based on the save parameter
|
3. Preview or save based on the save parameter
|
||||||
"""
|
"""
|
||||||
description = kwargs.get("description", "").strip()
|
params = CreateAgentInput(**kwargs)
|
||||||
context = kwargs.get("context", "")
|
|
||||||
save = kwargs.get("save", True)
|
|
||||||
session_id = session.session_id if session else None
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
# Extract async processing params (passed by long-running tool handler)
|
# Extract async processing params
|
||||||
operation_id = kwargs.get("_operation_id")
|
operation_id = kwargs.get("_operation_id")
|
||||||
task_id = kwargs.get("_task_id")
|
task_id = kwargs.get("_task_id")
|
||||||
|
|
||||||
if not description:
|
if not params.description:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide a description of what the agent should do.",
|
message="Please provide a description of what the agent should do.",
|
||||||
error="Missing description parameter",
|
error="Missing description parameter",
|
||||||
@@ -115,7 +135,7 @@ class CreateAgentTool(BaseTool):
|
|||||||
try:
|
try:
|
||||||
library_agents = await get_all_relevant_agents_for_generation(
|
library_agents = await get_all_relevant_agents_for_generation(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
search_query=description,
|
search_query=params.description,
|
||||||
include_marketplace=True,
|
include_marketplace=True,
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -126,7 +146,7 @@ class CreateAgentTool(BaseTool):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
decomposition_result = await decompose_goal(
|
decomposition_result = await decompose_goal(
|
||||||
description, context, library_agents
|
params.description, params.context, library_agents
|
||||||
)
|
)
|
||||||
except AgentGeneratorNotConfiguredError:
|
except AgentGeneratorNotConfiguredError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
@@ -142,7 +162,7 @@ class CreateAgentTool(BaseTool):
|
|||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
|
message="Failed to analyze the goal. The agent generation service may be unavailable. Please try again.",
|
||||||
error="decomposition_failed",
|
error="decomposition_failed",
|
||||||
details={"description": description[:100]},
|
details={"description": params.description[:100]},
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -158,7 +178,7 @@ class CreateAgentTool(BaseTool):
|
|||||||
message=user_message,
|
message=user_message,
|
||||||
error=f"decomposition_failed:{error_type}",
|
error=f"decomposition_failed:{error_type}",
|
||||||
details={
|
details={
|
||||||
"description": description[:100],
|
"description": params.description[:100],
|
||||||
"service_error": error_msg,
|
"service_error": error_msg,
|
||||||
"error_type": error_type,
|
"error_type": error_type,
|
||||||
},
|
},
|
||||||
@@ -244,7 +264,7 @@ class CreateAgentTool(BaseTool):
|
|||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
|
message="Failed to generate the agent. The agent generation service may be unavailable. Please try again.",
|
||||||
error="generation_failed",
|
error="generation_failed",
|
||||||
details={"description": description[:100]},
|
details={"description": params.description[:100]},
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -266,7 +286,7 @@ class CreateAgentTool(BaseTool):
|
|||||||
message=user_message,
|
message=user_message,
|
||||||
error=f"generation_failed:{error_type}",
|
error=f"generation_failed:{error_type}",
|
||||||
details={
|
details={
|
||||||
"description": description[:100],
|
"description": params.description[:100],
|
||||||
"service_error": error_msg,
|
"service_error": error_msg,
|
||||||
"error_type": error_type,
|
"error_type": error_type,
|
||||||
},
|
},
|
||||||
@@ -291,7 +311,7 @@ class CreateAgentTool(BaseTool):
|
|||||||
node_count = len(agent_json.get("nodes", []))
|
node_count = len(agent_json.get("nodes", []))
|
||||||
link_count = len(agent_json.get("links", []))
|
link_count = len(agent_json.get("links", []))
|
||||||
|
|
||||||
if not save:
|
if not params.save:
|
||||||
return AgentPreviewResponse(
|
return AgentPreviewResponse(
|
||||||
message=(
|
message=(
|
||||||
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
|
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||||
@@ -27,6 +29,23 @@ from .models import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomizeAgentInput(BaseModel):
|
||||||
|
"""Input parameters for the customize_agent tool."""
|
||||||
|
|
||||||
|
agent_id: str = ""
|
||||||
|
modifications: str = ""
|
||||||
|
context: str = ""
|
||||||
|
save: bool = True
|
||||||
|
|
||||||
|
@field_validator("agent_id", "modifications", "context", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_strings(cls, v: Any) -> str:
|
||||||
|
"""Strip whitespace from string fields."""
|
||||||
|
if isinstance(v, str):
|
||||||
|
return v.strip()
|
||||||
|
return v if v is not None else ""
|
||||||
|
|
||||||
|
|
||||||
class CustomizeAgentTool(BaseTool):
|
class CustomizeAgentTool(BaseTool):
|
||||||
"""Tool for customizing marketplace/template agents using natural language."""
|
"""Tool for customizing marketplace/template agents using natural language."""
|
||||||
|
|
||||||
@@ -92,7 +111,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
"""Execute the customize_agent tool.
|
"""Execute the customize_agent tool.
|
||||||
|
|
||||||
@@ -102,20 +121,17 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
3. Call customize_template with the modification request
|
3. Call customize_template with the modification request
|
||||||
4. Preview or save based on the save parameter
|
4. Preview or save based on the save parameter
|
||||||
"""
|
"""
|
||||||
agent_id = kwargs.get("agent_id", "").strip()
|
params = CustomizeAgentInput(**kwargs)
|
||||||
modifications = kwargs.get("modifications", "").strip()
|
|
||||||
context = kwargs.get("context", "")
|
|
||||||
save = kwargs.get("save", True)
|
|
||||||
session_id = session.session_id if session else None
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
if not agent_id:
|
if not params.agent_id:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
||||||
error="missing_agent_id",
|
error="missing_agent_id",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not modifications:
|
if not params.modifications:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please describe how you want to customize this agent.",
|
message="Please describe how you want to customize this agent.",
|
||||||
error="missing_modifications",
|
error="missing_modifications",
|
||||||
@@ -123,11 +139,11 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Parse agent_id in format "creator/slug"
|
# Parse agent_id in format "creator/slug"
|
||||||
parts = [p.strip() for p in agent_id.split("/")]
|
parts = [p.strip() for p in params.agent_id.split("/")]
|
||||||
if len(parts) != 2 or not parts[0] or not parts[1]:
|
if len(parts) != 2 or not parts[0] or not parts[1]:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
f"Invalid agent ID format: '{agent_id}'. "
|
f"Invalid agent ID format: '{params.agent_id}'. "
|
||||||
"Expected format is 'creator/agent-name' "
|
"Expected format is 'creator/agent-name' "
|
||||||
"(e.g., 'autogpt/newsletter-writer')."
|
"(e.g., 'autogpt/newsletter-writer')."
|
||||||
),
|
),
|
||||||
@@ -145,14 +161,14 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
except AgentNotFoundError:
|
except AgentNotFoundError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
f"Could not find marketplace agent '{agent_id}'. "
|
f"Could not find marketplace agent '{params.agent_id}'. "
|
||||||
"Please check the agent ID and try again."
|
"Please check the agent ID and try again."
|
||||||
),
|
),
|
||||||
error="agent_not_found",
|
error="agent_not_found",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
|
logger.error(f"Error fetching marketplace agent {params.agent_id}: {e}")
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to fetch the marketplace agent. Please try again.",
|
message="Failed to fetch the marketplace agent. Please try again.",
|
||||||
error="fetch_error",
|
error="fetch_error",
|
||||||
@@ -162,7 +178,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
if not agent_details.store_listing_version_id:
|
if not agent_details.store_listing_version_id:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
f"The agent '{agent_id}' does not have an available version. "
|
f"The agent '{params.agent_id}' does not have an available version. "
|
||||||
"Please try a different agent."
|
"Please try a different agent."
|
||||||
),
|
),
|
||||||
error="no_version_available",
|
error="no_version_available",
|
||||||
@@ -174,7 +190,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
||||||
template_agent = graph_to_json(graph)
|
template_agent = graph_to_json(graph)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
|
logger.error(f"Error fetching agent graph for {params.agent_id}: {e}")
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to fetch the agent configuration. Please try again.",
|
message="Failed to fetch the agent configuration. Please try again.",
|
||||||
error="graph_fetch_error",
|
error="graph_fetch_error",
|
||||||
@@ -185,8 +201,8 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
try:
|
try:
|
||||||
result = await customize_template(
|
result = await customize_template(
|
||||||
template_agent=template_agent,
|
template_agent=template_agent,
|
||||||
modification_request=modifications,
|
modification_request=params.modifications,
|
||||||
context=context,
|
context=params.context,
|
||||||
)
|
)
|
||||||
except AgentGeneratorNotConfiguredError:
|
except AgentGeneratorNotConfiguredError:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
@@ -198,7 +214,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error calling customize_template for {agent_id}: {e}")
|
logger.error(f"Error calling customize_template for {params.agent_id}: {e}")
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=(
|
message=(
|
||||||
"Failed to customize the agent due to a service error. "
|
"Failed to customize the agent due to a service error. "
|
||||||
@@ -219,8 +235,37 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle error response
|
# Handle response using match/case for cleaner pattern matching
|
||||||
if isinstance(result, dict) and result.get("type") == "error":
|
return await self._handle_customization_result(
|
||||||
|
result=result,
|
||||||
|
params=params,
|
||||||
|
agent_details=agent_details,
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_customization_result(
|
||||||
|
self,
|
||||||
|
result: dict[str, Any],
|
||||||
|
params: CustomizeAgentInput,
|
||||||
|
agent_details: Any,
|
||||||
|
user_id: str | None,
|
||||||
|
session_id: str | None,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
"""Handle the result from customize_template using pattern matching."""
|
||||||
|
# Ensure result is a dict
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
||||||
|
return ErrorResponse(
|
||||||
|
message="Failed to customize the agent due to an unexpected response.",
|
||||||
|
error="unexpected_response_type",
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
result_type = result.get("type")
|
||||||
|
|
||||||
|
match result_type:
|
||||||
|
case "error":
|
||||||
error_msg = result.get("error", "Unknown error")
|
error_msg = result.get("error", "Unknown error")
|
||||||
error_type = result.get("error_type", "unknown")
|
error_type = result.get("error_type", "unknown")
|
||||||
user_message = get_user_message_for_error(
|
user_message = get_user_message_for_error(
|
||||||
@@ -242,42 +287,52 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Handle clarifying questions
|
case "clarifying_questions":
|
||||||
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
questions_data = result.get("questions") or []
|
||||||
questions = result.get("questions") or []
|
if not isinstance(questions_data, list):
|
||||||
if not isinstance(questions, list):
|
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Unexpected clarifying questions format: {type(questions)}"
|
f"Unexpected clarifying questions format: {type(questions_data)}"
|
||||||
)
|
)
|
||||||
questions = []
|
questions_data = []
|
||||||
|
|
||||||
|
questions = [
|
||||||
|
ClarifyingQuestion(
|
||||||
|
question=q.get("question", "") if isinstance(q, dict) else "",
|
||||||
|
keyword=q.get("keyword", "") if isinstance(q, dict) else "",
|
||||||
|
example=q.get("example") if isinstance(q, dict) else None,
|
||||||
|
)
|
||||||
|
for q in questions_data
|
||||||
|
if isinstance(q, dict)
|
||||||
|
]
|
||||||
|
|
||||||
return ClarificationNeededResponse(
|
return ClarificationNeededResponse(
|
||||||
message=(
|
message=(
|
||||||
"I need some more information to customize this agent. "
|
"I need some more information to customize this agent. "
|
||||||
"Please answer the following questions:"
|
"Please answer the following questions:"
|
||||||
),
|
),
|
||||||
questions=[
|
questions=questions,
|
||||||
ClarifyingQuestion(
|
|
||||||
question=q.get("question", ""),
|
|
||||||
keyword=q.get("keyword", ""),
|
|
||||||
example=q.get("example"),
|
|
||||||
)
|
|
||||||
for q in questions
|
|
||||||
if isinstance(q, dict)
|
|
||||||
],
|
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Result should be the customized agent JSON
|
case _:
|
||||||
if not isinstance(result, dict):
|
# Default case: result is the customized agent JSON
|
||||||
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
return await self._save_or_preview_agent(
|
||||||
return ErrorResponse(
|
customized_agent=result,
|
||||||
message="Failed to customize the agent due to an unexpected response.",
|
params=params,
|
||||||
error="unexpected_response_type",
|
agent_details=agent_details,
|
||||||
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
customized_agent = result
|
async def _save_or_preview_agent(
|
||||||
|
self,
|
||||||
|
customized_agent: dict[str, Any],
|
||||||
|
params: CustomizeAgentInput,
|
||||||
|
agent_details: Any,
|
||||||
|
user_id: str | None,
|
||||||
|
session_id: str | None,
|
||||||
|
) -> ToolResponseBase:
|
||||||
|
"""Save or preview the customized agent based on params.save."""
|
||||||
agent_name = customized_agent.get(
|
agent_name = customized_agent.get(
|
||||||
"name", f"Customized {agent_details.agent_name}"
|
"name", f"Customized {agent_details.agent_name}"
|
||||||
)
|
)
|
||||||
@@ -287,7 +342,7 @@ class CustomizeAgentTool(BaseTool):
|
|||||||
node_count = len(nodes) if isinstance(nodes, list) else 0
|
node_count = len(nodes) if isinstance(nodes, list) else 0
|
||||||
link_count = len(links) if isinstance(links, list) else 0
|
link_count = len(links) if isinstance(links, list) else 0
|
||||||
|
|
||||||
if not save:
|
if not params.save:
|
||||||
return AgentPreviewResponse(
|
return AgentPreviewResponse(
|
||||||
message=(
|
message=(
|
||||||
f"I've customized the agent '{agent_details.agent_name}'. "
|
f"I've customized the agent '{agent_details.agent_name}'. "
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, field_validator
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
from .agent_generator import (
|
from .agent_generator import (
|
||||||
@@ -27,6 +29,20 @@ from .models import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EditAgentInput(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
agent_id: str = ""
|
||||||
|
changes: str = ""
|
||||||
|
context: str = ""
|
||||||
|
save: bool = True
|
||||||
|
|
||||||
|
@field_validator("agent_id", "changes", "context", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_strings(cls, v: Any) -> str:
|
||||||
|
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||||
|
|
||||||
|
|
||||||
class EditAgentTool(BaseTool):
|
class EditAgentTool(BaseTool):
|
||||||
"""Tool for editing existing agents using natural language."""
|
"""Tool for editing existing agents using natural language."""
|
||||||
|
|
||||||
@@ -90,7 +106,7 @@ class EditAgentTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
"""Execute the edit_agent tool.
|
"""Execute the edit_agent tool.
|
||||||
|
|
||||||
@@ -99,35 +115,32 @@ class EditAgentTool(BaseTool):
|
|||||||
2. Generate updated agent (external service handles fixing and validation)
|
2. Generate updated agent (external service handles fixing and validation)
|
||||||
3. Preview or save based on the save parameter
|
3. Preview or save based on the save parameter
|
||||||
"""
|
"""
|
||||||
agent_id = kwargs.get("agent_id", "").strip()
|
params = EditAgentInput(**kwargs)
|
||||||
changes = kwargs.get("changes", "").strip()
|
|
||||||
context = kwargs.get("context", "")
|
|
||||||
save = kwargs.get("save", True)
|
|
||||||
session_id = session.session_id if session else None
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
# Extract async processing params (passed by long-running tool handler)
|
# Extract async processing params (passed by long-running tool handler)
|
||||||
operation_id = kwargs.get("_operation_id")
|
operation_id = kwargs.get("_operation_id")
|
||||||
task_id = kwargs.get("_task_id")
|
task_id = kwargs.get("_task_id")
|
||||||
|
|
||||||
if not agent_id:
|
if not params.agent_id:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide the agent ID to edit.",
|
message="Please provide the agent ID to edit.",
|
||||||
error="Missing agent_id parameter",
|
error="Missing agent_id parameter",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not changes:
|
if not params.changes:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please describe what changes you want to make.",
|
message="Please describe what changes you want to make.",
|
||||||
error="Missing changes parameter",
|
error="Missing changes parameter",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
current_agent = await get_agent_as_json(agent_id, user_id)
|
current_agent = await get_agent_as_json(params.agent_id, user_id)
|
||||||
|
|
||||||
if current_agent is None:
|
if current_agent is None:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=f"Could not find agent with ID '{agent_id}' in your library.",
|
message=f"Could not find agent '{params.agent_id}' in your library.",
|
||||||
error="agent_not_found",
|
error="agent_not_found",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
@@ -138,7 +151,7 @@ class EditAgentTool(BaseTool):
|
|||||||
graph_id = current_agent.get("id")
|
graph_id = current_agent.get("id")
|
||||||
library_agents = await get_all_relevant_agents_for_generation(
|
library_agents = await get_all_relevant_agents_for_generation(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
search_query=changes,
|
search_query=params.changes,
|
||||||
exclude_graph_id=graph_id,
|
exclude_graph_id=graph_id,
|
||||||
include_marketplace=True,
|
include_marketplace=True,
|
||||||
)
|
)
|
||||||
@@ -148,9 +161,11 @@ class EditAgentTool(BaseTool):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to fetch library agents: {e}")
|
logger.warning(f"Failed to fetch library agents: {e}")
|
||||||
|
|
||||||
update_request = changes
|
update_request = params.changes
|
||||||
if context:
|
if params.context:
|
||||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
update_request = (
|
||||||
|
f"{params.changes}\n\nAdditional context:\n{params.context}"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await generate_agent_patch(
|
result = await generate_agent_patch(
|
||||||
@@ -174,7 +189,7 @@ class EditAgentTool(BaseTool):
|
|||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
|
message="Failed to generate changes. The agent generation service may be unavailable or timed out. Please try again.",
|
||||||
error="update_generation_failed",
|
error="update_generation_failed",
|
||||||
details={"agent_id": agent_id, "changes": changes[:100]},
|
details={"agent_id": params.agent_id, "changes": params.changes[:100]},
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -206,8 +221,8 @@ class EditAgentTool(BaseTool):
|
|||||||
message=user_message,
|
message=user_message,
|
||||||
error=f"update_generation_failed:{error_type}",
|
error=f"update_generation_failed:{error_type}",
|
||||||
details={
|
details={
|
||||||
"agent_id": agent_id,
|
"agent_id": params.agent_id,
|
||||||
"changes": changes[:100],
|
"changes": params.changes[:100],
|
||||||
"service_error": error_msg,
|
"service_error": error_msg,
|
||||||
"error_type": error_type,
|
"error_type": error_type,
|
||||||
},
|
},
|
||||||
@@ -239,7 +254,7 @@ class EditAgentTool(BaseTool):
|
|||||||
node_count = len(updated_agent.get("nodes", []))
|
node_count = len(updated_agent.get("nodes", []))
|
||||||
link_count = len(updated_agent.get("links", []))
|
link_count = len(updated_agent.get("links", []))
|
||||||
|
|
||||||
if not save:
|
if not params.save:
|
||||||
return AgentPreviewResponse(
|
return AgentPreviewResponse(
|
||||||
message=(
|
message=(
|
||||||
f"I've updated the agent. "
|
f"I've updated the agent. "
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
from .agent_search import search_agents
|
from .agent_search import search_agents
|
||||||
@@ -9,6 +11,18 @@ from .base import BaseTool
|
|||||||
from .models import ToolResponseBase
|
from .models import ToolResponseBase
|
||||||
|
|
||||||
|
|
||||||
|
class FindAgentInput(BaseModel):
|
||||||
|
"""Input parameters for the find_agent tool."""
|
||||||
|
|
||||||
|
query: str = ""
|
||||||
|
|
||||||
|
@field_validator("query", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_string(cls, v: Any) -> str:
|
||||||
|
"""Strip whitespace from query."""
|
||||||
|
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||||
|
|
||||||
|
|
||||||
class FindAgentTool(BaseTool):
|
class FindAgentTool(BaseTool):
|
||||||
"""Tool for discovering agents from the marketplace."""
|
"""Tool for discovering agents from the marketplace."""
|
||||||
|
|
||||||
@@ -36,10 +50,11 @@ class FindAgentTool(BaseTool):
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self, user_id: str | None, session: ChatSession, **kwargs
|
self, user_id: str | None, session: ChatSession, **kwargs: Any
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
|
params = FindAgentInput(**kwargs)
|
||||||
return await search_agents(
|
return await search_agents(
|
||||||
query=kwargs.get("query", "").strip(),
|
query=params.query,
|
||||||
source="marketplace",
|
source="marketplace",
|
||||||
session_id=session.session_id,
|
session_id=session.session_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import logging
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from prisma.enums import ContentType
|
from prisma.enums import ContentType
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
|
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
|
||||||
@@ -18,6 +19,18 @@ from backend.data.block import get_block
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class FindBlockInput(BaseModel):
|
||||||
|
"""Input parameters for the find_block tool."""
|
||||||
|
|
||||||
|
query: str = ""
|
||||||
|
|
||||||
|
@field_validator("query", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_string(cls, v: Any) -> str:
|
||||||
|
"""Strip whitespace from query."""
|
||||||
|
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||||
|
|
||||||
|
|
||||||
class FindBlockTool(BaseTool):
|
class FindBlockTool(BaseTool):
|
||||||
"""Tool for searching available blocks."""
|
"""Tool for searching available blocks."""
|
||||||
|
|
||||||
@@ -59,24 +72,24 @@ class FindBlockTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
"""Search for blocks matching the query.
|
"""Search for blocks matching the query.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User ID (required)
|
user_id: User ID (required)
|
||||||
session: Chat session
|
session: Chat session
|
||||||
query: Search query
|
**kwargs: Tool parameters
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
BlockListResponse: List of matching blocks
|
BlockListResponse: List of matching blocks
|
||||||
NoResultsResponse: No blocks found
|
NoResultsResponse: No blocks found
|
||||||
ErrorResponse: Error message
|
ErrorResponse: Error message
|
||||||
"""
|
"""
|
||||||
query = kwargs.get("query", "").strip()
|
params = FindBlockInput(**kwargs)
|
||||||
session_id = session.session_id
|
session_id = session.session_id
|
||||||
|
|
||||||
if not query:
|
if not params.query:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide a search query",
|
message="Please provide a search query",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -85,7 +98,7 @@ class FindBlockTool(BaseTool):
|
|||||||
try:
|
try:
|
||||||
# Search for blocks using hybrid search
|
# Search for blocks using hybrid search
|
||||||
results, total = await unified_hybrid_search(
|
results, total = await unified_hybrid_search(
|
||||||
query=query,
|
query=params.query,
|
||||||
content_types=[ContentType.BLOCK],
|
content_types=[ContentType.BLOCK],
|
||||||
page=1,
|
page=1,
|
||||||
page_size=10,
|
page_size=10,
|
||||||
@@ -93,7 +106,7 @@ class FindBlockTool(BaseTool):
|
|||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
message=f"No blocks found for '{query}'",
|
message=f"No blocks found for '{params.query}'",
|
||||||
suggestions=[
|
suggestions=[
|
||||||
"Try broader keywords like 'email', 'http', 'text', 'ai'",
|
"Try broader keywords like 'email', 'http', 'text', 'ai'",
|
||||||
"Check spelling of technical terms",
|
"Check spelling of technical terms",
|
||||||
@@ -165,7 +178,7 @@ class FindBlockTool(BaseTool):
|
|||||||
|
|
||||||
if not blocks:
|
if not blocks:
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
message=f"No blocks found for '{query}'",
|
message=f"No blocks found for '{params.query}'",
|
||||||
suggestions=[
|
suggestions=[
|
||||||
"Try broader keywords like 'email', 'http', 'text', 'ai'",
|
"Try broader keywords like 'email', 'http', 'text', 'ai'",
|
||||||
],
|
],
|
||||||
@@ -174,13 +187,13 @@ class FindBlockTool(BaseTool):
|
|||||||
|
|
||||||
return BlockListResponse(
|
return BlockListResponse(
|
||||||
message=(
|
message=(
|
||||||
f"Found {len(blocks)} block(s) matching '{query}'. "
|
f"Found {len(blocks)} block(s) matching '{params.query}'. "
|
||||||
"To execute a block, use run_block with the block's 'id' field "
|
"To execute a block, use run_block with the block's 'id' field "
|
||||||
"and provide 'input_data' matching the block's input_schema."
|
"and provide 'input_data' matching the block's input_schema."
|
||||||
),
|
),
|
||||||
blocks=blocks,
|
blocks=blocks,
|
||||||
count=len(blocks),
|
count=len(blocks),
|
||||||
query=query,
|
query=params.query,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
|
|
||||||
from .agent_search import search_agents
|
from .agent_search import search_agents
|
||||||
@@ -9,6 +11,15 @@ from .base import BaseTool
|
|||||||
from .models import ToolResponseBase
|
from .models import ToolResponseBase
|
||||||
|
|
||||||
|
|
||||||
|
class FindLibraryAgentInput(BaseModel):
|
||||||
|
query: str = ""
|
||||||
|
|
||||||
|
@field_validator("query", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_string(cls, v: Any) -> str:
|
||||||
|
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||||
|
|
||||||
|
|
||||||
class FindLibraryAgentTool(BaseTool):
|
class FindLibraryAgentTool(BaseTool):
|
||||||
"""Tool for searching agents in the user's library."""
|
"""Tool for searching agents in the user's library."""
|
||||||
|
|
||||||
@@ -42,10 +53,11 @@ class FindLibraryAgentTool(BaseTool):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self, user_id: str | None, session: ChatSession, **kwargs
|
self, user_id: str | None, session: ChatSession, **kwargs: Any
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
|
params = FindLibraryAgentInput(**kwargs)
|
||||||
return await search_agents(
|
return await search_agents(
|
||||||
query=kwargs.get("query", "").strip(),
|
query=params.query,
|
||||||
source="library",
|
source="library",
|
||||||
session_id=session.session_id,
|
session_id=session.session_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.api.features.chat.tools.base import BaseTool
|
from backend.api.features.chat.tools.base import BaseTool
|
||||||
from backend.api.features.chat.tools.models import (
|
from backend.api.features.chat.tools.models import (
|
||||||
@@ -18,6 +20,18 @@ logger = logging.getLogger(__name__)
|
|||||||
DOCS_BASE_URL = "https://docs.agpt.co"
|
DOCS_BASE_URL = "https://docs.agpt.co"
|
||||||
|
|
||||||
|
|
||||||
|
class GetDocPageInput(BaseModel):
|
||||||
|
"""Input parameters for the get_doc_page tool."""
|
||||||
|
|
||||||
|
path: str = ""
|
||||||
|
|
||||||
|
@field_validator("path", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_string(cls, v: Any) -> str:
|
||||||
|
"""Strip whitespace from path."""
|
||||||
|
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||||
|
|
||||||
|
|
||||||
class GetDocPageTool(BaseTool):
|
class GetDocPageTool(BaseTool):
|
||||||
"""Tool for fetching full content of a documentation page."""
|
"""Tool for fetching full content of a documentation page."""
|
||||||
|
|
||||||
@@ -75,23 +89,23 @@ class GetDocPageTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
"""Fetch full content of a documentation page.
|
"""Fetch full content of a documentation page.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User ID (not required for docs)
|
user_id: User ID (not required for docs)
|
||||||
session: Chat session
|
session: Chat session
|
||||||
path: Path to the documentation file
|
**kwargs: Tool parameters
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DocPageResponse: Full document content
|
DocPageResponse: Full document content
|
||||||
ErrorResponse: Error message
|
ErrorResponse: Error message
|
||||||
"""
|
"""
|
||||||
path = kwargs.get("path", "").strip()
|
params = GetDocPageInput(**kwargs)
|
||||||
session_id = session.session_id if session else None
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
if not path:
|
if not params.path:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide a documentation path.",
|
message="Please provide a documentation path.",
|
||||||
error="Missing path parameter",
|
error="Missing path parameter",
|
||||||
@@ -99,7 +113,7 @@ class GetDocPageTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Sanitize path to prevent directory traversal
|
# Sanitize path to prevent directory traversal
|
||||||
if ".." in path or path.startswith("/"):
|
if ".." in params.path or params.path.startswith("/"):
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Invalid documentation path.",
|
message="Invalid documentation path.",
|
||||||
error="invalid_path",
|
error="invalid_path",
|
||||||
@@ -107,11 +121,11 @@ class GetDocPageTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
docs_root = self._get_docs_root()
|
docs_root = self._get_docs_root()
|
||||||
full_path = docs_root / path
|
full_path = docs_root / params.path
|
||||||
|
|
||||||
if not full_path.exists():
|
if not full_path.exists():
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=f"Documentation page not found: {path}",
|
message=f"Documentation page not found: {params.path}",
|
||||||
error="not_found",
|
error="not_found",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
@@ -128,19 +142,19 @@ class GetDocPageTool(BaseTool):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
content = full_path.read_text(encoding="utf-8")
|
content = full_path.read_text(encoding="utf-8")
|
||||||
title = self._extract_title(content, path)
|
title = self._extract_title(content, params.path)
|
||||||
|
|
||||||
return DocPageResponse(
|
return DocPageResponse(
|
||||||
message=f"Retrieved documentation page: {title}",
|
message=f"Retrieved documentation page: {title}",
|
||||||
title=title,
|
title=title,
|
||||||
path=path,
|
path=params.path,
|
||||||
content=content,
|
content=content,
|
||||||
doc_url=self._make_doc_url(path),
|
doc_url=self._make_doc_url(params.path),
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to read documentation page {path}: {e}")
|
logger.error(f"Failed to read documentation page {params.path}: {e}")
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=f"Failed to read documentation page: {str(e)}",
|
message=f"Failed to read documentation page: {str(e)}",
|
||||||
error="read_failed",
|
error="read_failed",
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import uuid
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
from pydantic_core import PydanticUndefined
|
from pydantic_core import PydanticUndefined
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
@@ -29,6 +30,25 @@ from .utils import build_missing_credentials_from_field_info
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RunBlockInput(BaseModel):
|
||||||
|
"""Input parameters for the run_block tool."""
|
||||||
|
|
||||||
|
block_id: str = ""
|
||||||
|
input_data: dict[str, Any] = {}
|
||||||
|
|
||||||
|
@field_validator("block_id", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_block_id(cls, v: Any) -> str:
|
||||||
|
"""Strip whitespace from block_id."""
|
||||||
|
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||||
|
|
||||||
|
@field_validator("input_data", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def ensure_dict(cls, v: Any) -> dict[str, Any]:
|
||||||
|
"""Ensure input_data is a dict."""
|
||||||
|
return v if isinstance(v, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
class RunBlockTool(BaseTool):
|
class RunBlockTool(BaseTool):
|
||||||
"""Tool for executing a block and returning its outputs."""
|
"""Tool for executing a block and returning its outputs."""
|
||||||
|
|
||||||
@@ -162,37 +182,29 @@ class RunBlockTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
"""Execute a block with the given input data.
|
"""Execute a block with the given input data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User ID (required)
|
user_id: User ID (required)
|
||||||
session: Chat session
|
session: Chat session
|
||||||
block_id: Block UUID to execute
|
**kwargs: Tool parameters
|
||||||
input_data: Input values for the block
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
BlockOutputResponse: Block execution outputs
|
BlockOutputResponse: Block execution outputs
|
||||||
SetupRequirementsResponse: Missing credentials
|
SetupRequirementsResponse: Missing credentials
|
||||||
ErrorResponse: Error message
|
ErrorResponse: Error message
|
||||||
"""
|
"""
|
||||||
block_id = kwargs.get("block_id", "").strip()
|
params = RunBlockInput(**kwargs)
|
||||||
input_data = kwargs.get("input_data", {})
|
|
||||||
session_id = session.session_id
|
session_id = session.session_id
|
||||||
|
|
||||||
if not block_id:
|
if not params.block_id:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide a block_id",
|
message="Please provide a block_id",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(input_data, dict):
|
|
||||||
return ErrorResponse(
|
|
||||||
message="input_data must be an object",
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not user_id:
|
if not user_id:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Authentication required",
|
message="Authentication required",
|
||||||
@@ -200,23 +212,25 @@ class RunBlockTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get the block
|
# Get the block
|
||||||
block = get_block(block_id)
|
block = get_block(params.block_id)
|
||||||
if not block:
|
if not block:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=f"Block '{block_id}' not found",
|
message=f"Block '{params.block_id}' not found",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
if block.disabled:
|
if block.disabled:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=f"Block '{block_id}' is disabled",
|
message=f"Block '{params.block_id}' is disabled",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
logger.info(
|
||||||
|
f"Executing block {block.name} ({params.block_id}) for user {user_id}"
|
||||||
|
)
|
||||||
|
|
||||||
creds_manager = IntegrationCredentialsManager()
|
creds_manager = IntegrationCredentialsManager()
|
||||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
matched_credentials, missing_credentials = await self._check_block_credentials(
|
||||||
user_id, block, input_data
|
user_id, block, params.input_data
|
||||||
)
|
)
|
||||||
|
|
||||||
if missing_credentials:
|
if missing_credentials:
|
||||||
@@ -234,7 +248,7 @@ class RunBlockTool(BaseTool):
|
|||||||
),
|
),
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
setup_info=SetupInfo(
|
setup_info=SetupInfo(
|
||||||
agent_id=block_id,
|
agent_id=params.block_id,
|
||||||
agent_name=block.name,
|
agent_name=block.name,
|
||||||
user_readiness=UserReadiness(
|
user_readiness=UserReadiness(
|
||||||
has_all_credentials=False,
|
has_all_credentials=False,
|
||||||
@@ -263,7 +277,7 @@ class RunBlockTool(BaseTool):
|
|||||||
# - node_exec_id = unique per block execution
|
# - node_exec_id = unique per block execution
|
||||||
synthetic_graph_id = f"copilot-session-{session.session_id}"
|
synthetic_graph_id = f"copilot-session-{session.session_id}"
|
||||||
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
|
synthetic_graph_exec_id = f"copilot-session-{session.session_id}"
|
||||||
synthetic_node_id = f"copilot-node-{block_id}"
|
synthetic_node_id = f"copilot-node-{params.block_id}"
|
||||||
synthetic_node_exec_id = (
|
synthetic_node_exec_id = (
|
||||||
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
|
f"copilot-{session.session_id}-{uuid.uuid4().hex[:8]}"
|
||||||
)
|
)
|
||||||
@@ -298,8 +312,8 @@ class RunBlockTool(BaseTool):
|
|||||||
|
|
||||||
for field_name, cred_meta in matched_credentials.items():
|
for field_name, cred_meta in matched_credentials.items():
|
||||||
# Inject metadata into input_data (for validation)
|
# Inject metadata into input_data (for validation)
|
||||||
if field_name not in input_data:
|
if field_name not in params.input_data:
|
||||||
input_data[field_name] = cred_meta.model_dump()
|
params.input_data[field_name] = cred_meta.model_dump()
|
||||||
|
|
||||||
# Fetch actual credentials and pass as kwargs (for execution)
|
# Fetch actual credentials and pass as kwargs (for execution)
|
||||||
actual_credentials = await creds_manager.get(
|
actual_credentials = await creds_manager.get(
|
||||||
@@ -316,14 +330,14 @@ class RunBlockTool(BaseTool):
|
|||||||
# Execute the block and collect outputs
|
# Execute the block and collect outputs
|
||||||
outputs: dict[str, list[Any]] = defaultdict(list)
|
outputs: dict[str, list[Any]] = defaultdict(list)
|
||||||
async for output_name, output_data in block.execute(
|
async for output_name, output_data in block.execute(
|
||||||
input_data,
|
params.input_data,
|
||||||
**exec_kwargs,
|
**exec_kwargs,
|
||||||
):
|
):
|
||||||
outputs[output_name].append(output_data)
|
outputs[output_name].append(output_data)
|
||||||
|
|
||||||
return BlockOutputResponse(
|
return BlockOutputResponse(
|
||||||
message=f"Block '{block.name}' executed successfully",
|
message=f"Block '{block.name}' executed successfully",
|
||||||
block_id=block_id,
|
block_id=params.block_id,
|
||||||
block_name=block.name,
|
block_name=block.name,
|
||||||
outputs=dict(outputs),
|
outputs=dict(outputs),
|
||||||
success=True,
|
success=True,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import logging
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from prisma.enums import ContentType
|
from prisma.enums import ContentType
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.api.features.chat.tools.base import BaseTool
|
from backend.api.features.chat.tools.base import BaseTool
|
||||||
@@ -28,6 +29,18 @@ MAX_RESULTS = 5
|
|||||||
SNIPPET_LENGTH = 200
|
SNIPPET_LENGTH = 200
|
||||||
|
|
||||||
|
|
||||||
|
class SearchDocsInput(BaseModel):
|
||||||
|
"""Input parameters for the search_docs tool."""
|
||||||
|
|
||||||
|
query: str = ""
|
||||||
|
|
||||||
|
@field_validator("query", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_string(cls, v: Any) -> str:
|
||||||
|
"""Strip whitespace from query."""
|
||||||
|
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||||
|
|
||||||
|
|
||||||
class SearchDocsTool(BaseTool):
|
class SearchDocsTool(BaseTool):
|
||||||
"""Tool for searching AutoGPT platform documentation."""
|
"""Tool for searching AutoGPT platform documentation."""
|
||||||
|
|
||||||
@@ -91,24 +104,24 @@ class SearchDocsTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
"""Search documentation and return relevant sections.
|
"""Search documentation and return relevant sections.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User ID (not required for docs)
|
user_id: User ID (not required for docs)
|
||||||
session: Chat session
|
session: Chat session
|
||||||
query: Search query
|
**kwargs: Tool parameters
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
DocSearchResultsResponse: List of matching documentation sections
|
DocSearchResultsResponse: List of matching documentation sections
|
||||||
NoResultsResponse: No results found
|
NoResultsResponse: No results found
|
||||||
ErrorResponse: Error message
|
ErrorResponse: Error message
|
||||||
"""
|
"""
|
||||||
query = kwargs.get("query", "").strip()
|
params = SearchDocsInput(**kwargs)
|
||||||
session_id = session.session_id if session else None
|
session_id = session.session_id if session else None
|
||||||
|
|
||||||
if not query:
|
if not params.query:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide a search query.",
|
message="Please provide a search query.",
|
||||||
error="Missing query parameter",
|
error="Missing query parameter",
|
||||||
@@ -118,7 +131,7 @@ class SearchDocsTool(BaseTool):
|
|||||||
try:
|
try:
|
||||||
# Search using hybrid search for DOCUMENTATION content type only
|
# Search using hybrid search for DOCUMENTATION content type only
|
||||||
results, total = await unified_hybrid_search(
|
results, total = await unified_hybrid_search(
|
||||||
query=query,
|
query=params.query,
|
||||||
content_types=[ContentType.DOCUMENTATION],
|
content_types=[ContentType.DOCUMENTATION],
|
||||||
page=1,
|
page=1,
|
||||||
page_size=MAX_RESULTS * 2, # Fetch extra for deduplication
|
page_size=MAX_RESULTS * 2, # Fetch extra for deduplication
|
||||||
@@ -127,7 +140,7 @@ class SearchDocsTool(BaseTool):
|
|||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
message=f"No documentation found for '{query}'.",
|
message=f"No documentation found for '{params.query}'.",
|
||||||
suggestions=[
|
suggestions=[
|
||||||
"Try different keywords",
|
"Try different keywords",
|
||||||
"Use more general terms",
|
"Use more general terms",
|
||||||
@@ -162,7 +175,7 @@ class SearchDocsTool(BaseTool):
|
|||||||
|
|
||||||
if not deduplicated:
|
if not deduplicated:
|
||||||
return NoResultsResponse(
|
return NoResultsResponse(
|
||||||
message=f"No documentation found for '{query}'.",
|
message=f"No documentation found for '{params.query}'.",
|
||||||
suggestions=[
|
suggestions=[
|
||||||
"Try different keywords",
|
"Try different keywords",
|
||||||
"Use more general terms",
|
"Use more general terms",
|
||||||
@@ -195,7 +208,7 @@ class SearchDocsTool(BaseTool):
|
|||||||
message=f"Found {len(doc_results)} relevant documentation sections.",
|
message=f"Found {len(doc_results)} relevant documentation sections.",
|
||||||
results=doc_results,
|
results=doc_results,
|
||||||
count=len(doc_results),
|
count=len(doc_results),
|
||||||
query=query,
|
query=params.query,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -8,12 +8,7 @@ from backend.api.features.library import model as library_model
|
|||||||
from backend.api.features.store import db as store_db
|
from backend.api.features.store import db as store_db
|
||||||
from backend.data import graph as graph_db
|
from backend.data import graph as graph_db
|
||||||
from backend.data.graph import GraphModel
|
from backend.data.graph import GraphModel
|
||||||
from backend.data.model import (
|
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
|
||||||
CredentialsFieldInfo,
|
|
||||||
CredentialsMetaInput,
|
|
||||||
HostScopedCredentials,
|
|
||||||
OAuth2Credentials,
|
|
||||||
)
|
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.util.exceptions import NotFoundError
|
from backend.util.exceptions import NotFoundError
|
||||||
|
|
||||||
@@ -278,14 +273,7 @@ async def match_user_credentials_to_graph(
|
|||||||
for cred in available_creds
|
for cred in available_creds
|
||||||
if cred.provider in credential_requirements.provider
|
if cred.provider in credential_requirements.provider
|
||||||
and cred.type in credential_requirements.supported_types
|
and cred.type in credential_requirements.supported_types
|
||||||
and (
|
and _credential_has_required_scopes(cred, credential_requirements)
|
||||||
cred.type != "oauth2"
|
|
||||||
or _credential_has_required_scopes(cred, credential_requirements)
|
|
||||||
)
|
|
||||||
and (
|
|
||||||
cred.type != "host_scoped"
|
|
||||||
or _credential_is_for_host(cred, credential_requirements)
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -330,10 +318,19 @@ async def match_user_credentials_to_graph(
|
|||||||
|
|
||||||
|
|
||||||
def _credential_has_required_scopes(
|
def _credential_has_required_scopes(
|
||||||
credential: OAuth2Credentials,
|
credential: Credentials,
|
||||||
requirements: CredentialsFieldInfo,
|
requirements: CredentialsFieldInfo,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if an OAuth2 credential has all the scopes required by the input."""
|
"""
|
||||||
|
Check if a credential has all the scopes required by the block.
|
||||||
|
|
||||||
|
For OAuth2 credentials, verifies that the credential's scopes are a superset
|
||||||
|
of the required scopes. For other credential types, returns True (no scope check).
|
||||||
|
"""
|
||||||
|
# Only OAuth2 credentials have scopes to check
|
||||||
|
if credential.type != "oauth2":
|
||||||
|
return True
|
||||||
|
|
||||||
# If no scopes are required, any credential matches
|
# If no scopes are required, any credential matches
|
||||||
if not requirements.required_scopes:
|
if not requirements.required_scopes:
|
||||||
return True
|
return True
|
||||||
@@ -342,22 +339,6 @@ def _credential_has_required_scopes(
|
|||||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||||
|
|
||||||
|
|
||||||
def _credential_is_for_host(
|
|
||||||
credential: HostScopedCredentials,
|
|
||||||
requirements: CredentialsFieldInfo,
|
|
||||||
) -> bool:
|
|
||||||
"""Check if a host-scoped credential matches the host required by the input."""
|
|
||||||
# We need to know the host to match host-scoped credentials to.
|
|
||||||
# Graph.aggregate_credentials_inputs() adds the node's set URL value (if any)
|
|
||||||
# to discriminator_values. No discriminator_values -> no host to match against.
|
|
||||||
if not requirements.discriminator_values:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check that credential host matches required host.
|
|
||||||
# Host-scoped credential inputs are grouped by host, so any item from the set works.
|
|
||||||
return credential.matches_url(list(requirements.discriminator_values)[0])
|
|
||||||
|
|
||||||
|
|
||||||
async def check_user_has_required_credentials(
|
async def check_user_has_required_credentials(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
required_credentials: list[CredentialsMetaInput],
|
required_credentials: list[CredentialsMetaInput],
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.api.features.chat.model import ChatSession
|
||||||
from backend.data.workspace import get_or_create_workspace
|
from backend.data.workspace import get_or_create_workspace
|
||||||
@@ -78,6 +78,65 @@ class WorkspaceDeleteResponse(ToolResponseBase):
|
|||||||
success: bool
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
|
# Input models for workspace tools
|
||||||
|
class ListWorkspaceFilesInput(BaseModel):
|
||||||
|
"""Input parameters for list_workspace_files tool."""
|
||||||
|
|
||||||
|
path_prefix: str | None = None
|
||||||
|
limit: int = 50
|
||||||
|
include_all_sessions: bool = False
|
||||||
|
|
||||||
|
@field_validator("path_prefix", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_path(cls, v: Any) -> str | None:
|
||||||
|
return v.strip() if isinstance(v, str) else None
|
||||||
|
|
||||||
|
|
||||||
|
class ReadWorkspaceFileInput(BaseModel):
|
||||||
|
"""Input parameters for read_workspace_file tool."""
|
||||||
|
|
||||||
|
file_id: str | None = None
|
||||||
|
path: str | None = None
|
||||||
|
force_download_url: bool = False
|
||||||
|
|
||||||
|
@field_validator("file_id", "path", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_strings(cls, v: Any) -> str | None:
|
||||||
|
return v.strip() if isinstance(v, str) else None
|
||||||
|
|
||||||
|
|
||||||
|
class WriteWorkspaceFileInput(BaseModel):
|
||||||
|
"""Input parameters for write_workspace_file tool."""
|
||||||
|
|
||||||
|
filename: str = ""
|
||||||
|
content_base64: str = ""
|
||||||
|
path: str | None = None
|
||||||
|
mime_type: str | None = None
|
||||||
|
overwrite: bool = False
|
||||||
|
|
||||||
|
@field_validator("filename", "content_base64", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_required(cls, v: Any) -> str:
|
||||||
|
return v.strip() if isinstance(v, str) else (v if v is not None else "")
|
||||||
|
|
||||||
|
@field_validator("path", "mime_type", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_optional(cls, v: Any) -> str | None:
|
||||||
|
return v.strip() if isinstance(v, str) else None
|
||||||
|
|
||||||
|
|
||||||
|
class DeleteWorkspaceFileInput(BaseModel):
|
||||||
|
"""Input parameters for delete_workspace_file tool."""
|
||||||
|
|
||||||
|
file_id: str | None = None
|
||||||
|
path: str | None = None
|
||||||
|
|
||||||
|
@field_validator("file_id", "path", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def strip_strings(cls, v: Any) -> str | None:
|
||||||
|
return v.strip() if isinstance(v, str) else None
|
||||||
|
|
||||||
|
|
||||||
class ListWorkspaceFilesTool(BaseTool):
|
class ListWorkspaceFilesTool(BaseTool):
|
||||||
"""Tool for listing files in user's workspace."""
|
"""Tool for listing files in user's workspace."""
|
||||||
|
|
||||||
@@ -131,8 +190,9 @@ class ListWorkspaceFilesTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
|
params = ListWorkspaceFilesInput(**kwargs)
|
||||||
session_id = session.session_id
|
session_id = session.session_id
|
||||||
|
|
||||||
if not user_id:
|
if not user_id:
|
||||||
@@ -141,9 +201,7 @@ class ListWorkspaceFilesTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
limit = min(params.limit, 100)
|
||||||
limit = min(kwargs.get("limit", 50), 100)
|
|
||||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
workspace = await get_or_create_workspace(user_id)
|
workspace = await get_or_create_workspace(user_id)
|
||||||
@@ -151,13 +209,13 @@ class ListWorkspaceFilesTool(BaseTool):
|
|||||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
files = await manager.list_files(
|
files = await manager.list_files(
|
||||||
path=path_prefix,
|
path=params.path_prefix,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
include_all_sessions=include_all_sessions,
|
include_all_sessions=params.include_all_sessions,
|
||||||
)
|
)
|
||||||
total = await manager.get_file_count(
|
total = await manager.get_file_count(
|
||||||
path=path_prefix,
|
path=params.path_prefix,
|
||||||
include_all_sessions=include_all_sessions,
|
include_all_sessions=params.include_all_sessions,
|
||||||
)
|
)
|
||||||
|
|
||||||
file_infos = [
|
file_infos = [
|
||||||
@@ -171,7 +229,9 @@ class ListWorkspaceFilesTool(BaseTool):
|
|||||||
for f in files
|
for f in files
|
||||||
]
|
]
|
||||||
|
|
||||||
scope_msg = "all sessions" if include_all_sessions else "current session"
|
scope_msg = (
|
||||||
|
"all sessions" if params.include_all_sessions else "current session"
|
||||||
|
)
|
||||||
return WorkspaceFileListResponse(
|
return WorkspaceFileListResponse(
|
||||||
files=file_infos,
|
files=file_infos,
|
||||||
total_count=total,
|
total_count=total,
|
||||||
@@ -259,8 +319,9 @@ class ReadWorkspaceFileTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
|
params = ReadWorkspaceFileInput(**kwargs)
|
||||||
session_id = session.session_id
|
session_id = session.session_id
|
||||||
|
|
||||||
if not user_id:
|
if not user_id:
|
||||||
@@ -269,11 +330,7 @@ class ReadWorkspaceFileTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
file_id: Optional[str] = kwargs.get("file_id")
|
if not params.file_id and not params.path:
|
||||||
path: Optional[str] = kwargs.get("path")
|
|
||||||
force_download_url: bool = kwargs.get("force_download_url", False)
|
|
||||||
|
|
||||||
if not file_id and not path:
|
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide either file_id or path",
|
message="Please provide either file_id or path",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -285,21 +342,21 @@ class ReadWorkspaceFileTool(BaseTool):
|
|||||||
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
manager = WorkspaceManager(user_id, workspace.id, session_id)
|
||||||
|
|
||||||
# Get file info
|
# Get file info
|
||||||
if file_id:
|
if params.file_id:
|
||||||
file_info = await manager.get_file_info(file_id)
|
file_info = await manager.get_file_info(params.file_id)
|
||||||
if file_info is None:
|
if file_info is None:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=f"File not found: {file_id}",
|
message=f"File not found: {params.file_id}",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
target_file_id = file_id
|
target_file_id = params.file_id
|
||||||
else:
|
else:
|
||||||
# path is guaranteed to be non-None here due to the check above
|
# path is guaranteed to be non-None here due to the check above
|
||||||
assert path is not None
|
assert params.path is not None
|
||||||
file_info = await manager.get_file_info_by_path(path)
|
file_info = await manager.get_file_info_by_path(params.path)
|
||||||
if file_info is None:
|
if file_info is None:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=f"File not found at path: {path}",
|
message=f"File not found at path: {params.path}",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
target_file_id = file_info.id
|
target_file_id = file_info.id
|
||||||
@@ -309,7 +366,7 @@ class ReadWorkspaceFileTool(BaseTool):
|
|||||||
is_text_file = self._is_text_mime_type(file_info.mimeType)
|
is_text_file = self._is_text_mime_type(file_info.mimeType)
|
||||||
|
|
||||||
# Return inline content for small text files (unless force_download_url)
|
# Return inline content for small text files (unless force_download_url)
|
||||||
if is_small_file and is_text_file and not force_download_url:
|
if is_small_file and is_text_file and not params.force_download_url:
|
||||||
content = await manager.read_file_by_id(target_file_id)
|
content = await manager.read_file_by_id(target_file_id)
|
||||||
content_b64 = base64.b64encode(content).decode("utf-8")
|
content_b64 = base64.b64encode(content).decode("utf-8")
|
||||||
|
|
||||||
@@ -429,8 +486,9 @@ class WriteWorkspaceFileTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
|
params = WriteWorkspaceFileInput(**kwargs)
|
||||||
session_id = session.session_id
|
session_id = session.session_id
|
||||||
|
|
||||||
if not user_id:
|
if not user_id:
|
||||||
@@ -439,19 +497,13 @@ class WriteWorkspaceFileTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
filename: str = kwargs.get("filename", "")
|
if not params.filename:
|
||||||
content_b64: str = kwargs.get("content_base64", "")
|
|
||||||
path: Optional[str] = kwargs.get("path")
|
|
||||||
mime_type: Optional[str] = kwargs.get("mime_type")
|
|
||||||
overwrite: bool = kwargs.get("overwrite", False)
|
|
||||||
|
|
||||||
if not filename:
|
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide a filename",
|
message="Please provide a filename",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not content_b64:
|
if not params.content_base64:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide content_base64",
|
message="Please provide content_base64",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -459,7 +511,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
|||||||
|
|
||||||
# Decode content
|
# Decode content
|
||||||
try:
|
try:
|
||||||
content = base64.b64decode(content_b64)
|
content = base64.b64decode(params.content_base64)
|
||||||
except Exception:
|
except Exception:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Invalid base64-encoded content",
|
message="Invalid base64-encoded content",
|
||||||
@@ -476,7 +528,7 @@ class WriteWorkspaceFileTool(BaseTool):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Virus scan
|
# Virus scan
|
||||||
await scan_content_safe(content, filename=filename)
|
await scan_content_safe(content, filename=params.filename)
|
||||||
|
|
||||||
workspace = await get_or_create_workspace(user_id)
|
workspace = await get_or_create_workspace(user_id)
|
||||||
# Pass session_id for session-scoped file access
|
# Pass session_id for session-scoped file access
|
||||||
@@ -484,10 +536,10 @@ class WriteWorkspaceFileTool(BaseTool):
|
|||||||
|
|
||||||
file_record = await manager.write_file(
|
file_record = await manager.write_file(
|
||||||
content=content,
|
content=content,
|
||||||
filename=filename,
|
filename=params.filename,
|
||||||
path=path,
|
path=params.path,
|
||||||
mime_type=mime_type,
|
mime_type=params.mime_type,
|
||||||
overwrite=overwrite,
|
overwrite=params.overwrite,
|
||||||
)
|
)
|
||||||
|
|
||||||
return WorkspaceWriteResponse(
|
return WorkspaceWriteResponse(
|
||||||
@@ -557,8 +609,9 @@ class DeleteWorkspaceFileTool(BaseTool):
|
|||||||
self,
|
self,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
session: ChatSession,
|
session: ChatSession,
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> ToolResponseBase:
|
) -> ToolResponseBase:
|
||||||
|
params = DeleteWorkspaceFileInput(**kwargs)
|
||||||
session_id = session.session_id
|
session_id = session.session_id
|
||||||
|
|
||||||
if not user_id:
|
if not user_id:
|
||||||
@@ -567,10 +620,7 @@ class DeleteWorkspaceFileTool(BaseTool):
|
|||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
file_id: Optional[str] = kwargs.get("file_id")
|
if not params.file_id and not params.path:
|
||||||
path: Optional[str] = kwargs.get("path")
|
|
||||||
|
|
||||||
if not file_id and not path:
|
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message="Please provide either file_id or path",
|
message="Please provide either file_id or path",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -583,15 +633,15 @@ class DeleteWorkspaceFileTool(BaseTool):
|
|||||||
|
|
||||||
# Determine the file_id to delete
|
# Determine the file_id to delete
|
||||||
target_file_id: str
|
target_file_id: str
|
||||||
if file_id:
|
if params.file_id:
|
||||||
target_file_id = file_id
|
target_file_id = params.file_id
|
||||||
else:
|
else:
|
||||||
# path is guaranteed to be non-None here due to the check above
|
# path is guaranteed to be non-None here due to the check above
|
||||||
assert path is not None
|
assert params.path is not None
|
||||||
file_info = await manager.get_file_info_by_path(path)
|
file_info = await manager.get_file_info_by_path(params.path)
|
||||||
if file_info is None:
|
if file_info is None:
|
||||||
return ErrorResponse(
|
return ErrorResponse(
|
||||||
message=f"File not found at path: {path}",
|
message=f"File not found at path: {params.path}",
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
target_file_id = file_info.id
|
target_file_id = file_info.id
|
||||||
|
|||||||
@@ -19,10 +19,7 @@ from backend.data.graph import GraphSettings
|
|||||||
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include
|
||||||
from backend.data.model import CredentialsMetaInput
|
from backend.data.model import CredentialsMetaInput
|
||||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||||
from backend.integrations.webhooks.graph_lifecycle_hooks import (
|
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
|
||||||
on_graph_activate,
|
|
||||||
on_graph_deactivate,
|
|
||||||
)
|
|
||||||
from backend.util.clients import get_scheduler_client
|
from backend.util.clients import get_scheduler_client
|
||||||
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
|
from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError
|
||||||
from backend.util.json import SafeJson
|
from backend.util.json import SafeJson
|
||||||
@@ -540,92 +537,6 @@ async def update_agent_version_in_library(
|
|||||||
return library_model.LibraryAgent.from_db(lib)
|
return library_model.LibraryAgent.from_db(lib)
|
||||||
|
|
||||||
|
|
||||||
async def create_graph_in_library(
|
|
||||||
graph: graph_db.Graph,
|
|
||||||
user_id: str,
|
|
||||||
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
|
||||||
"""Create a new graph and add it to the user's library."""
|
|
||||||
graph.version = 1
|
|
||||||
graph_model = graph_db.make_graph_model(graph, user_id)
|
|
||||||
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True)
|
|
||||||
|
|
||||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
|
||||||
|
|
||||||
library_agents = await create_library_agent(
|
|
||||||
graph=created_graph,
|
|
||||||
user_id=user_id,
|
|
||||||
sensitive_action_safe_mode=True,
|
|
||||||
create_library_agents_for_sub_graphs=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if created_graph.is_active:
|
|
||||||
created_graph = await on_graph_activate(created_graph, user_id=user_id)
|
|
||||||
|
|
||||||
return created_graph, library_agents[0]
|
|
||||||
|
|
||||||
|
|
||||||
async def update_graph_in_library(
|
|
||||||
graph: graph_db.Graph,
|
|
||||||
user_id: str,
|
|
||||||
) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]:
|
|
||||||
"""Create a new version of an existing graph and update the library entry."""
|
|
||||||
existing_versions = await graph_db.get_graph_all_versions(graph.id, user_id)
|
|
||||||
current_active_version = (
|
|
||||||
next((v for v in existing_versions if v.is_active), None)
|
|
||||||
if existing_versions
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
graph.version = (
|
|
||||||
max(v.version for v in existing_versions) + 1 if existing_versions else 1
|
|
||||||
)
|
|
||||||
|
|
||||||
graph_model = graph_db.make_graph_model(graph, user_id)
|
|
||||||
graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
|
||||||
|
|
||||||
created_graph = await graph_db.create_graph(graph_model, user_id)
|
|
||||||
|
|
||||||
library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id)
|
|
||||||
if not library_agent:
|
|
||||||
raise NotFoundError(f"Library agent not found for graph {created_graph.id}")
|
|
||||||
|
|
||||||
library_agent = await update_library_agent_version_and_settings(
|
|
||||||
user_id, created_graph
|
|
||||||
)
|
|
||||||
|
|
||||||
if created_graph.is_active:
|
|
||||||
created_graph = await on_graph_activate(created_graph, user_id=user_id)
|
|
||||||
await graph_db.set_graph_active_version(
|
|
||||||
graph_id=created_graph.id,
|
|
||||||
version=created_graph.version,
|
|
||||||
user_id=user_id,
|
|
||||||
)
|
|
||||||
if current_active_version:
|
|
||||||
await on_graph_deactivate(current_active_version, user_id=user_id)
|
|
||||||
|
|
||||||
return created_graph, library_agent
|
|
||||||
|
|
||||||
|
|
||||||
async def update_library_agent_version_and_settings(
|
|
||||||
user_id: str, agent_graph: graph_db.GraphModel
|
|
||||||
) -> library_model.LibraryAgent:
|
|
||||||
"""Update library agent to point to new graph version and sync settings."""
|
|
||||||
library = await update_agent_version_in_library(
|
|
||||||
user_id, agent_graph.id, agent_graph.version
|
|
||||||
)
|
|
||||||
updated_settings = GraphSettings.from_graph(
|
|
||||||
graph=agent_graph,
|
|
||||||
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
|
||||||
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
|
||||||
)
|
|
||||||
if updated_settings != library.settings:
|
|
||||||
library = await update_library_agent(
|
|
||||||
library_agent_id=library.id,
|
|
||||||
user_id=user_id,
|
|
||||||
settings=updated_settings,
|
|
||||||
)
|
|
||||||
return library
|
|
||||||
|
|
||||||
|
|
||||||
async def update_library_agent(
|
async def update_library_agent(
|
||||||
library_agent_id: str,
|
library_agent_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
|||||||
@@ -101,6 +101,7 @@ from backend.util.timezone_utils import (
|
|||||||
from backend.util.virus_scanner import scan_content_safe
|
from backend.util.virus_scanner import scan_content_safe
|
||||||
|
|
||||||
from .library import db as library_db
|
from .library import db as library_db
|
||||||
|
from .library import model as library_model
|
||||||
from .store.model import StoreAgentDetails
|
from .store.model import StoreAgentDetails
|
||||||
|
|
||||||
|
|
||||||
@@ -822,16 +823,18 @@ async def update_graph(
|
|||||||
graph: graph_db.Graph,
|
graph: graph_db.Graph,
|
||||||
user_id: Annotated[str, Security(get_user_id)],
|
user_id: Annotated[str, Security(get_user_id)],
|
||||||
) -> graph_db.GraphModel:
|
) -> graph_db.GraphModel:
|
||||||
|
# Sanity check
|
||||||
if graph.id and graph.id != graph_id:
|
if graph.id and graph.id != graph_id:
|
||||||
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
raise HTTPException(400, detail="Graph ID does not match ID in URI")
|
||||||
|
|
||||||
|
# Determine new version
|
||||||
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
|
||||||
if not existing_versions:
|
if not existing_versions:
|
||||||
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
raise HTTPException(404, detail=f"Graph #{graph_id} not found")
|
||||||
|
latest_version_number = max(g.version for g in existing_versions)
|
||||||
|
graph.version = latest_version_number + 1
|
||||||
|
|
||||||
graph.version = max(g.version for g in existing_versions) + 1
|
|
||||||
current_active_version = next((v for v in existing_versions if v.is_active), None)
|
current_active_version = next((v for v in existing_versions if v.is_active), None)
|
||||||
|
|
||||||
graph = graph_db.make_graph_model(graph, user_id)
|
graph = graph_db.make_graph_model(graph, user_id)
|
||||||
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
graph.reassign_ids(user_id=user_id, reassign_graph_id=False)
|
||||||
graph.validate_graph(for_run=False)
|
graph.validate_graph(for_run=False)
|
||||||
@@ -839,23 +842,27 @@ async def update_graph(
|
|||||||
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
|
||||||
|
|
||||||
if new_graph_version.is_active:
|
if new_graph_version.is_active:
|
||||||
await library_db.update_library_agent_version_and_settings(
|
# Keep the library agent up to date with the new active version
|
||||||
user_id, new_graph_version
|
await _update_library_agent_version_and_settings(user_id, new_graph_version)
|
||||||
)
|
|
||||||
|
# Handle activation of the new graph first to ensure continuity
|
||||||
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
|
new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id)
|
||||||
|
# Ensure new version is the only active version
|
||||||
await graph_db.set_graph_active_version(
|
await graph_db.set_graph_active_version(
|
||||||
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
graph_id=graph_id, version=new_graph_version.version, user_id=user_id
|
||||||
)
|
)
|
||||||
if current_active_version:
|
if current_active_version:
|
||||||
|
# Handle deactivation of the previously active version
|
||||||
await on_graph_deactivate(current_active_version, user_id=user_id)
|
await on_graph_deactivate(current_active_version, user_id=user_id)
|
||||||
|
|
||||||
|
# Fetch new graph version *with sub-graphs* (needed for credentials input schema)
|
||||||
new_graph_version_with_subgraphs = await graph_db.get_graph(
|
new_graph_version_with_subgraphs = await graph_db.get_graph(
|
||||||
graph_id,
|
graph_id,
|
||||||
new_graph_version.version,
|
new_graph_version.version,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
include_subgraphs=True,
|
include_subgraphs=True,
|
||||||
)
|
)
|
||||||
assert new_graph_version_with_subgraphs
|
assert new_graph_version_with_subgraphs # make type checker happy
|
||||||
return new_graph_version_with_subgraphs
|
return new_graph_version_with_subgraphs
|
||||||
|
|
||||||
|
|
||||||
@@ -893,15 +900,33 @@ async def set_graph_active_version(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Keep the library agent up to date with the new active version
|
# Keep the library agent up to date with the new active version
|
||||||
await library_db.update_library_agent_version_and_settings(
|
await _update_library_agent_version_and_settings(user_id, new_active_graph)
|
||||||
user_id, new_active_graph
|
|
||||||
)
|
|
||||||
|
|
||||||
if current_active_graph and current_active_graph.version != new_active_version:
|
if current_active_graph and current_active_graph.version != new_active_version:
|
||||||
# Handle deactivation of the previously active version
|
# Handle deactivation of the previously active version
|
||||||
await on_graph_deactivate(current_active_graph, user_id=user_id)
|
await on_graph_deactivate(current_active_graph, user_id=user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def _update_library_agent_version_and_settings(
|
||||||
|
user_id: str, agent_graph: graph_db.GraphModel
|
||||||
|
) -> library_model.LibraryAgent:
|
||||||
|
library = await library_db.update_agent_version_in_library(
|
||||||
|
user_id, agent_graph.id, agent_graph.version
|
||||||
|
)
|
||||||
|
updated_settings = GraphSettings.from_graph(
|
||||||
|
graph=agent_graph,
|
||||||
|
hitl_safe_mode=library.settings.human_in_the_loop_safe_mode,
|
||||||
|
sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode,
|
||||||
|
)
|
||||||
|
if updated_settings != library.settings:
|
||||||
|
library = await library_db.update_library_agent(
|
||||||
|
library_agent_id=library.id,
|
||||||
|
user_id=user_id,
|
||||||
|
settings=updated_settings,
|
||||||
|
)
|
||||||
|
return library
|
||||||
|
|
||||||
|
|
||||||
@v1_router.patch(
|
@v1_router.patch(
|
||||||
path="/graphs/{graph_id}/settings",
|
path="/graphs/{graph_id}/settings",
|
||||||
summary="Update graph settings",
|
summary="Update graph settings",
|
||||||
|
|||||||
@@ -162,16 +162,8 @@ class LinearClient:
|
|||||||
"searchTerm": team_name,
|
"searchTerm": team_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
result = await self.query(query, variables)
|
team_id = await self.query(query, variables)
|
||||||
nodes = result["teams"]["nodes"]
|
return team_id["teams"]["nodes"][0]["id"]
|
||||||
|
|
||||||
if not nodes:
|
|
||||||
raise LinearAPIException(
|
|
||||||
f"Team '{team_name}' not found. Check the team name or key and try again.",
|
|
||||||
status_code=404,
|
|
||||||
)
|
|
||||||
|
|
||||||
return nodes[0]["id"]
|
|
||||||
except LinearAPIException as e:
|
except LinearAPIException as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@@ -248,44 +240,17 @@ class LinearClient:
|
|||||||
except LinearAPIException as e:
|
except LinearAPIException as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def try_search_issues(
|
async def try_search_issues(self, term: str) -> list[Issue]:
|
||||||
self,
|
|
||||||
term: str,
|
|
||||||
max_results: int = 10,
|
|
||||||
team_id: str | None = None,
|
|
||||||
) -> list[Issue]:
|
|
||||||
try:
|
try:
|
||||||
query = """
|
query = """
|
||||||
query SearchIssues(
|
query SearchIssues($term: String!, $includeComments: Boolean!) {
|
||||||
$term: String!,
|
searchIssues(term: $term, includeComments: $includeComments) {
|
||||||
$first: Int,
|
|
||||||
$teamId: String
|
|
||||||
) {
|
|
||||||
searchIssues(
|
|
||||||
term: $term,
|
|
||||||
first: $first,
|
|
||||||
teamId: $teamId
|
|
||||||
) {
|
|
||||||
nodes {
|
nodes {
|
||||||
id
|
id
|
||||||
identifier
|
identifier
|
||||||
title
|
title
|
||||||
description
|
description
|
||||||
priority
|
priority
|
||||||
createdAt
|
|
||||||
state {
|
|
||||||
id
|
|
||||||
name
|
|
||||||
type
|
|
||||||
}
|
|
||||||
project {
|
|
||||||
id
|
|
||||||
name
|
|
||||||
}
|
|
||||||
assignee {
|
|
||||||
id
|
|
||||||
name
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -293,8 +258,7 @@ class LinearClient:
|
|||||||
|
|
||||||
variables: dict[str, Any] = {
|
variables: dict[str, Any] = {
|
||||||
"term": term,
|
"term": term,
|
||||||
"first": max_results,
|
"includeComments": True,
|
||||||
"teamId": team_id,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
issues = await self.query(query, variables)
|
issues = await self.query(query, variables)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from ._config import (
|
|||||||
LinearScope,
|
LinearScope,
|
||||||
linear,
|
linear,
|
||||||
)
|
)
|
||||||
from .models import CreateIssueResponse, Issue, State
|
from .models import CreateIssueResponse, Issue
|
||||||
|
|
||||||
|
|
||||||
class LinearCreateIssueBlock(Block):
|
class LinearCreateIssueBlock(Block):
|
||||||
@@ -135,20 +135,9 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
description="Linear credentials with read permissions",
|
description="Linear credentials with read permissions",
|
||||||
required_scopes={LinearScope.READ},
|
required_scopes={LinearScope.READ},
|
||||||
)
|
)
|
||||||
max_results: int = SchemaField(
|
|
||||||
description="Maximum number of results to return",
|
|
||||||
default=10,
|
|
||||||
ge=1,
|
|
||||||
le=100,
|
|
||||||
)
|
|
||||||
team_name: str | None = SchemaField(
|
|
||||||
description="Optional team name to filter results (e.g., 'Internal', 'Open Source')",
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
class Output(BlockSchemaOutput):
|
class Output(BlockSchemaOutput):
|
||||||
issues: list[Issue] = SchemaField(description="List of issues")
|
issues: list[Issue] = SchemaField(description="List of issues")
|
||||||
error: str = SchemaField(description="Error message if the search failed")
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -156,11 +145,8 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
description="Searches for issues on Linear",
|
description="Searches for issues on Linear",
|
||||||
input_schema=self.Input,
|
input_schema=self.Input,
|
||||||
output_schema=self.Output,
|
output_schema=self.Output,
|
||||||
categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING},
|
|
||||||
test_input={
|
test_input={
|
||||||
"term": "Test issue",
|
"term": "Test issue",
|
||||||
"max_results": 10,
|
|
||||||
"team_name": None,
|
|
||||||
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
"credentials": TEST_CREDENTIALS_INPUT_OAUTH,
|
||||||
},
|
},
|
||||||
test_credentials=TEST_CREDENTIALS_OAUTH,
|
test_credentials=TEST_CREDENTIALS_OAUTH,
|
||||||
@@ -170,14 +156,10 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
[
|
[
|
||||||
Issue(
|
Issue(
|
||||||
id="abc123",
|
id="abc123",
|
||||||
identifier="TST-123",
|
identifier="abc123",
|
||||||
title="Test issue",
|
title="Test issue",
|
||||||
description="Test description",
|
description="Test description",
|
||||||
priority=1,
|
priority=1,
|
||||||
state=State(
|
|
||||||
id="state1", name="In Progress", type="started"
|
|
||||||
),
|
|
||||||
createdAt="2026-01-15T10:00:00.000Z",
|
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@@ -186,12 +168,10 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
"search_issues": lambda *args, **kwargs: [
|
"search_issues": lambda *args, **kwargs: [
|
||||||
Issue(
|
Issue(
|
||||||
id="abc123",
|
id="abc123",
|
||||||
identifier="TST-123",
|
identifier="abc123",
|
||||||
title="Test issue",
|
title="Test issue",
|
||||||
description="Test description",
|
description="Test description",
|
||||||
priority=1,
|
priority=1,
|
||||||
state=State(id="state1", name="In Progress", type="started"),
|
|
||||||
createdAt="2026-01-15T10:00:00.000Z",
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@@ -201,22 +181,10 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
async def search_issues(
|
async def search_issues(
|
||||||
credentials: OAuth2Credentials | APIKeyCredentials,
|
credentials: OAuth2Credentials | APIKeyCredentials,
|
||||||
term: str,
|
term: str,
|
||||||
max_results: int = 10,
|
|
||||||
team_name: str | None = None,
|
|
||||||
) -> list[Issue]:
|
) -> list[Issue]:
|
||||||
client = LinearClient(credentials=credentials)
|
client = LinearClient(credentials=credentials)
|
||||||
|
response: list[Issue] = await client.try_search_issues(term=term)
|
||||||
# Resolve team name to ID if provided
|
return response
|
||||||
# Raises LinearAPIException with descriptive message if team not found
|
|
||||||
team_id: str | None = None
|
|
||||||
if team_name:
|
|
||||||
team_id = await client.try_get_team_by_name(team_name=team_name)
|
|
||||||
|
|
||||||
return await client.try_search_issues(
|
|
||||||
term=term,
|
|
||||||
max_results=max_results,
|
|
||||||
team_id=team_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
@@ -228,10 +196,7 @@ class LinearSearchIssuesBlock(Block):
|
|||||||
"""Execute the issue search"""
|
"""Execute the issue search"""
|
||||||
try:
|
try:
|
||||||
issues = await self.search_issues(
|
issues = await self.search_issues(
|
||||||
credentials=credentials,
|
credentials=credentials, term=input_data.term
|
||||||
term=input_data.term,
|
|
||||||
max_results=input_data.max_results,
|
|
||||||
team_name=input_data.team_name,
|
|
||||||
)
|
)
|
||||||
yield "issues", issues
|
yield "issues", issues
|
||||||
except LinearAPIException as e:
|
except LinearAPIException as e:
|
||||||
|
|||||||
@@ -36,21 +36,12 @@ class Project(BaseModel):
|
|||||||
content: str | None = None
|
content: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class State(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
type: str | None = (
|
|
||||||
None # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Issue(BaseModel):
|
class Issue(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
identifier: str
|
identifier: str
|
||||||
title: str
|
title: str
|
||||||
description: str | None
|
description: str | None
|
||||||
priority: int
|
priority: int
|
||||||
state: State | None = None
|
|
||||||
project: Project | None = None
|
project: Project | None = None
|
||||||
createdAt: str | None = None
|
createdAt: str | None = None
|
||||||
comments: list[Comment] | None = None
|
comments: list[Comment] | None = None
|
||||||
|
|||||||
@@ -165,13 +165,10 @@ class TranscribeYoutubeVideoBlock(Block):
|
|||||||
credentials: WebshareProxyCredentials,
|
credentials: WebshareProxyCredentials,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> BlockOutput:
|
) -> BlockOutput:
|
||||||
try:
|
|
||||||
video_id = self.extract_video_id(input_data.youtube_url)
|
video_id = self.extract_video_id(input_data.youtube_url)
|
||||||
|
yield "video_id", video_id
|
||||||
|
|
||||||
transcript = self.get_transcript(video_id, credentials)
|
transcript = self.get_transcript(video_id, credentials)
|
||||||
transcript_text = self.format_transcript(transcript=transcript)
|
transcript_text = self.format_transcript(transcript=transcript)
|
||||||
|
|
||||||
# Only yield after all operations succeed
|
|
||||||
yield "video_id", video_id
|
|
||||||
yield "transcript", transcript_text
|
yield "transcript", transcript_text
|
||||||
except Exception as e:
|
|
||||||
yield "error", str(e)
|
|
||||||
|
|||||||
@@ -134,16 +134,6 @@ async def test_block_credit_reset(server: SpinTestServer):
|
|||||||
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
|
month1 = datetime.now(timezone.utc).replace(month=1, day=1)
|
||||||
user_credit.time_now = lambda: month1
|
user_credit.time_now = lambda: month1
|
||||||
|
|
||||||
# IMPORTANT: Set updatedAt to December of previous year to ensure it's
|
|
||||||
# in a different month than month1 (January). This fixes a timing bug
|
|
||||||
# where if the test runs in early February, 35 days ago would be January,
|
|
||||||
# matching the mocked month1 and preventing the refill from triggering.
|
|
||||||
dec_previous_year = month1.replace(year=month1.year - 1, month=12, day=15)
|
|
||||||
await UserBalance.prisma().update(
|
|
||||||
where={"userId": DEFAULT_USER_ID},
|
|
||||||
data={"updatedAt": dec_previous_year},
|
|
||||||
)
|
|
||||||
|
|
||||||
# First call in month 1 should trigger refill
|
# First call in month 1 should trigger refill
|
||||||
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
balance = await user_credit.get_credits(DEFAULT_USER_ID)
|
||||||
assert balance == REFILL_VALUE # Should get 1000 credits
|
assert balance == REFILL_VALUE # Should get 1000 credits
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
get_args,
|
get_args,
|
||||||
)
|
)
|
||||||
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from prisma.enums import CreditTransactionType, OnboardingStep
|
from prisma.enums import CreditTransactionType, OnboardingStep
|
||||||
@@ -41,7 +42,6 @@ from typing_extensions import TypedDict
|
|||||||
|
|
||||||
from backend.integrations.providers import ProviderName
|
from backend.integrations.providers import ProviderName
|
||||||
from backend.util.json import loads as json_loads
|
from backend.util.json import loads as json_loads
|
||||||
from backend.util.request import parse_url
|
|
||||||
from backend.util.settings import Secrets
|
from backend.util.settings import Secrets
|
||||||
|
|
||||||
# Type alias for any provider name (including custom ones)
|
# Type alias for any provider name (including custom ones)
|
||||||
@@ -397,25 +397,19 @@ class HostScopedCredentials(_BaseCredentials):
|
|||||||
def matches_url(self, url: str) -> bool:
|
def matches_url(self, url: str) -> bool:
|
||||||
"""Check if this credential should be applied to the given URL."""
|
"""Check if this credential should be applied to the given URL."""
|
||||||
|
|
||||||
request_host, request_port = _extract_host_from_url(url)
|
parsed_url = urlparse(url)
|
||||||
cred_scope_host, cred_scope_port = _extract_host_from_url(self.host)
|
# Extract hostname without port
|
||||||
|
request_host = parsed_url.hostname
|
||||||
if not request_host:
|
if not request_host:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# If a port is specified in credential host, the request host port must match
|
# Simple host matching - exact match or wildcard subdomain match
|
||||||
if cred_scope_port is not None and request_port != cred_scope_port:
|
if self.host == request_host:
|
||||||
return False
|
|
||||||
# Non-standard ports are only allowed if explicitly specified in credential host
|
|
||||||
elif cred_scope_port is None and request_port not in (80, 443, None):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Simple host matching
|
|
||||||
if cred_scope_host == request_host:
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
|
# Support wildcard matching (e.g., "*.example.com" matches "api.example.com")
|
||||||
if cred_scope_host.startswith("*."):
|
if self.host.startswith("*."):
|
||||||
domain = cred_scope_host[2:] # Remove "*."
|
domain = self.host[2:] # Remove "*."
|
||||||
return request_host.endswith(f".{domain}") or request_host == domain
|
return request_host.endswith(f".{domain}") or request_host == domain
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@@ -557,13 +551,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _extract_host_from_url(url: str) -> tuple[str, int | None]:
|
def _extract_host_from_url(url: str) -> str:
|
||||||
"""Extract host and port from URL for grouping host-scoped credentials."""
|
"""Extract host from URL for grouping host-scoped credentials."""
|
||||||
try:
|
try:
|
||||||
parsed = parse_url(url)
|
parsed = urlparse(url)
|
||||||
return parsed.hostname or url, parsed.port
|
return parsed.hostname or url
|
||||||
except Exception:
|
except Exception:
|
||||||
return "", None
|
return ""
|
||||||
|
|
||||||
|
|
||||||
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||||
@@ -612,7 +606,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
|||||||
providers = frozenset(
|
providers = frozenset(
|
||||||
[cast(CP, "http")]
|
[cast(CP, "http")]
|
||||||
+ [
|
+ [
|
||||||
cast(CP, parse_url(str(value)).netloc)
|
cast(CP, _extract_host_from_url(str(value)))
|
||||||
for value in field.discriminator_values
|
for value in field.discriminator_values
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -79,23 +79,10 @@ class TestHostScopedCredentials:
|
|||||||
headers={"Authorization": SecretStr("Bearer token")},
|
headers={"Authorization": SecretStr("Bearer token")},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Non-standard ports require explicit port in credential host
|
assert creds.matches_url("http://localhost:8080/api/v1")
|
||||||
assert not creds.matches_url("http://localhost:8080/api/v1")
|
|
||||||
assert creds.matches_url("https://localhost:443/secure/endpoint")
|
assert creds.matches_url("https://localhost:443/secure/endpoint")
|
||||||
assert creds.matches_url("http://localhost/simple")
|
assert creds.matches_url("http://localhost/simple")
|
||||||
|
|
||||||
def test_matches_url_with_explicit_port(self):
|
|
||||||
"""Test URL matching with explicit port in credential host."""
|
|
||||||
creds = HostScopedCredentials(
|
|
||||||
provider="custom",
|
|
||||||
host="localhost:8080",
|
|
||||||
headers={"Authorization": SecretStr("Bearer token")},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert creds.matches_url("http://localhost:8080/api/v1")
|
|
||||||
assert not creds.matches_url("http://localhost:3000/api/v1")
|
|
||||||
assert not creds.matches_url("http://localhost/simple")
|
|
||||||
|
|
||||||
def test_empty_headers_dict(self):
|
def test_empty_headers_dict(self):
|
||||||
"""Test HostScopedCredentials with empty headers."""
|
"""Test HostScopedCredentials with empty headers."""
|
||||||
creds = HostScopedCredentials(
|
creds = HostScopedCredentials(
|
||||||
@@ -141,20 +128,8 @@ class TestHostScopedCredentials:
|
|||||||
("*.example.com", "https://sub.api.example.com/test", True),
|
("*.example.com", "https://sub.api.example.com/test", True),
|
||||||
("*.example.com", "https://example.com/test", True),
|
("*.example.com", "https://example.com/test", True),
|
||||||
("*.example.com", "https://example.org/test", False),
|
("*.example.com", "https://example.org/test", False),
|
||||||
# Non-standard ports require explicit port in credential host
|
("localhost", "http://localhost:3000/test", True),
|
||||||
("localhost", "http://localhost:3000/test", False),
|
|
||||||
("localhost:3000", "http://localhost:3000/test", True),
|
|
||||||
("localhost", "http://127.0.0.1:3000/test", False),
|
("localhost", "http://127.0.0.1:3000/test", False),
|
||||||
# IPv6 addresses (frontend stores with brackets via URL.hostname)
|
|
||||||
("[::1]", "http://[::1]/test", True),
|
|
||||||
("[::1]", "http://[::1]:80/test", True),
|
|
||||||
("[::1]", "https://[::1]:443/test", True),
|
|
||||||
("[::1]", "http://[::1]:8080/test", False), # Non-standard port
|
|
||||||
("[::1]:8080", "http://[::1]:8080/test", True),
|
|
||||||
("[::1]:8080", "http://[::1]:9090/test", False),
|
|
||||||
("[2001:db8::1]", "http://[2001:db8::1]/path", True),
|
|
||||||
("[2001:db8::1]", "https://[2001:db8::1]:443/path", True),
|
|
||||||
("[2001:db8::1]", "http://[2001:db8::ff]/path", False),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
|
def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool):
|
||||||
|
|||||||
@@ -157,7 +157,12 @@ async def validate_url(
|
|||||||
is_trusted: Boolean indicating if the hostname is in trusted_origins
|
is_trusted: Boolean indicating if the hostname is in trusted_origins
|
||||||
ip_addresses: List of IP addresses for the host; empty if the host is trusted
|
ip_addresses: List of IP addresses for the host; empty if the host is trusted
|
||||||
"""
|
"""
|
||||||
parsed = parse_url(url)
|
# Canonicalize URL
|
||||||
|
url = url.strip("/ ").replace("\\", "/")
|
||||||
|
parsed = urlparse(url)
|
||||||
|
if not parsed.scheme:
|
||||||
|
url = f"http://{url}"
|
||||||
|
parsed = urlparse(url)
|
||||||
|
|
||||||
# Check scheme
|
# Check scheme
|
||||||
if parsed.scheme not in ALLOWED_SCHEMES:
|
if parsed.scheme not in ALLOWED_SCHEMES:
|
||||||
@@ -215,17 +220,6 @@ async def validate_url(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_url(url: str) -> URL:
|
|
||||||
"""Canonicalizes and parses a URL string."""
|
|
||||||
url = url.strip("/ ").replace("\\", "/")
|
|
||||||
|
|
||||||
# Ensure scheme is present for proper parsing
|
|
||||||
if not re.match(r"[a-z0-9+.\-]+://", url):
|
|
||||||
url = f"http://{url}"
|
|
||||||
|
|
||||||
return urlparse(url)
|
|
||||||
|
|
||||||
|
|
||||||
def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
|
def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL:
|
||||||
"""
|
"""
|
||||||
Pins a URL to a specific IP address to prevent DNS rebinding attacks.
|
Pins a URL to a specific IP address to prevent DNS rebinding attacks.
|
||||||
|
|||||||
@@ -1,17 +1,6 @@
|
|||||||
import { OAuthPopupResultMessage } from "./types";
|
import { OAuthPopupResultMessage } from "./types";
|
||||||
import { NextResponse } from "next/server";
|
import { NextResponse } from "next/server";
|
||||||
|
|
||||||
/**
|
|
||||||
* Safely encode a value as JSON for embedding in a script tag.
|
|
||||||
* Escapes characters that could break out of the script context to prevent XSS.
|
|
||||||
*/
|
|
||||||
function safeJsonStringify(value: unknown): string {
|
|
||||||
return JSON.stringify(value)
|
|
||||||
.replace(/</g, "\\u003c")
|
|
||||||
.replace(/>/g, "\\u003e")
|
|
||||||
.replace(/&/g, "\\u0026");
|
|
||||||
}
|
|
||||||
|
|
||||||
// This route is intended to be used as the callback for integration OAuth flows,
|
// This route is intended to be used as the callback for integration OAuth flows,
|
||||||
// controlled by the CredentialsInput component. The CredentialsInput opens the login
|
// controlled by the CredentialsInput component. The CredentialsInput opens the login
|
||||||
// page in a pop-up window, which then redirects to this route to close the loop.
|
// page in a pop-up window, which then redirects to this route to close the loop.
|
||||||
@@ -34,13 +23,12 @@ export async function GET(request: Request) {
|
|||||||
console.debug("Sending message to opener:", message);
|
console.debug("Sending message to opener:", message);
|
||||||
|
|
||||||
// Return a response with the message as JSON and a script to close the window
|
// Return a response with the message as JSON and a script to close the window
|
||||||
// Use safeJsonStringify to prevent XSS by escaping <, >, and & characters
|
|
||||||
return new NextResponse(
|
return new NextResponse(
|
||||||
`
|
`
|
||||||
<html>
|
<html>
|
||||||
<body>
|
<body>
|
||||||
<script>
|
<script>
|
||||||
window.opener.postMessage(${safeJsonStringify(message)});
|
window.opener.postMessage(${JSON.stringify(message)});
|
||||||
window.close();
|
window.close();
|
||||||
</script>
|
</script>
|
||||||
</body>
|
</body>
|
||||||
|
|||||||
@@ -26,20 +26,8 @@ export function buildCopilotChatUrl(prompt: string): string {
|
|||||||
|
|
||||||
export function getQuickActions(): string[] {
|
export function getQuickActions(): string[] {
|
||||||
return [
|
return [
|
||||||
"I don't know where to start, just ask me stuff",
|
"Show me what I can automate",
|
||||||
"I do the same thing every week and it's killing me",
|
"Design a custom workflow",
|
||||||
"Help me find where I'm wasting my time",
|
"Help me with content creation",
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
export function getInputPlaceholder(width?: number) {
|
|
||||||
if (!width) return "What's your role and what eats up most of your day?";
|
|
||||||
|
|
||||||
if (width < 500) {
|
|
||||||
return "I'm a chef and I hate...";
|
|
||||||
}
|
|
||||||
if (width <= 1080) {
|
|
||||||
return "What's your role and what eats up most of your day?";
|
|
||||||
}
|
|
||||||
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,9 +6,7 @@ import { Text } from "@/components/atoms/Text/Text";
|
|||||||
import { Chat } from "@/components/contextual/Chat/Chat";
|
import { Chat } from "@/components/contextual/Chat/Chat";
|
||||||
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
||||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||||
import { useEffect, useState } from "react";
|
|
||||||
import { useCopilotStore } from "./copilot-page-store";
|
import { useCopilotStore } from "./copilot-page-store";
|
||||||
import { getInputPlaceholder } from "./helpers";
|
|
||||||
import { useCopilotPage } from "./useCopilotPage";
|
import { useCopilotPage } from "./useCopilotPage";
|
||||||
|
|
||||||
export default function CopilotPage() {
|
export default function CopilotPage() {
|
||||||
@@ -16,25 +14,8 @@ export default function CopilotPage() {
|
|||||||
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
|
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
|
||||||
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
|
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
|
||||||
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
|
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
|
||||||
|
|
||||||
const [inputPlaceholder, setInputPlaceholder] = useState(
|
|
||||||
getInputPlaceholder(),
|
|
||||||
);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
const handleResize = () => {
|
|
||||||
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
|
|
||||||
};
|
|
||||||
|
|
||||||
handleResize();
|
|
||||||
|
|
||||||
window.addEventListener("resize", handleResize);
|
|
||||||
return () => window.removeEventListener("resize", handleResize);
|
|
||||||
}, []);
|
|
||||||
|
|
||||||
const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
|
const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
|
||||||
state;
|
state;
|
||||||
|
|
||||||
const {
|
const {
|
||||||
handleQuickAction,
|
handleQuickAction,
|
||||||
startChatWithPrompt,
|
startChatWithPrompt,
|
||||||
@@ -92,7 +73,7 @@ export default function CopilotPage() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-3 py-5 md:px-6 md:py-10">
|
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
|
||||||
<div className="w-full text-center">
|
<div className="w-full text-center">
|
||||||
{isLoading ? (
|
{isLoading ? (
|
||||||
<div className="mx-auto max-w-2xl">
|
<div className="mx-auto max-w-2xl">
|
||||||
@@ -109,25 +90,25 @@ export default function CopilotPage() {
|
|||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<div className="mx-auto max-w-3xl">
|
<div className="mx-auto max-w-2xl">
|
||||||
<Text
|
<Text
|
||||||
variant="h3"
|
variant="h3"
|
||||||
className="mb-1 !text-[1.375rem] text-zinc-700"
|
className="mb-3 !text-[1.375rem] text-zinc-700"
|
||||||
>
|
>
|
||||||
Hey, <span className="text-violet-600">{greetingName}</span>
|
Hey, <span className="text-violet-600">{greetingName}</span>
|
||||||
</Text>
|
</Text>
|
||||||
<Text variant="h3" className="mb-8 !font-normal">
|
<Text variant="h3" className="mb-8 !font-normal">
|
||||||
Tell me about your work — I'll find what to automate.
|
What do you want to automate?
|
||||||
</Text>
|
</Text>
|
||||||
|
|
||||||
<div className="mb-6">
|
<div className="mb-6">
|
||||||
<ChatInput
|
<ChatInput
|
||||||
onSend={startChatWithPrompt}
|
onSend={startChatWithPrompt}
|
||||||
placeholder={inputPlaceholder}
|
placeholder='You can search or just ask - e.g. "create a blog post outline"'
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
<div className="flex flex-nowrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||||
{quickActions.map((action) => (
|
{quickActions.map((action) => (
|
||||||
<Button
|
<Button
|
||||||
key={action}
|
key={action}
|
||||||
@@ -135,7 +116,7 @@ export default function CopilotPage() {
|
|||||||
variant="outline"
|
variant="outline"
|
||||||
size="small"
|
size="small"
|
||||||
onClick={() => handleQuickAction(action)}
|
onClick={() => handleQuickAction(action)}
|
||||||
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
|
className="h-auto shrink-0 border-zinc-600 !px-4 !py-2 text-[1rem] text-zinc-600"
|
||||||
>
|
>
|
||||||
{action}
|
{action}
|
||||||
</Button>
|
</Button>
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessi
|
|||||||
import { Button } from "@/components/atoms/Button/Button";
|
import { Button } from "@/components/atoms/Button/Button";
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
import { Text } from "@/components/atoms/Text/Text";
|
||||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||||
|
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { GlobeHemisphereEastIcon } from "@phosphor-icons/react";
|
import { GlobeHemisphereEastIcon } from "@phosphor-icons/react";
|
||||||
import { useEffect } from "react";
|
import { useEffect } from "react";
|
||||||
@@ -55,6 +56,10 @@ export function ChatContainer({
|
|||||||
onStreamingChange?.(isStreaming);
|
onStreamingChange?.(isStreaming);
|
||||||
}, [isStreaming, onStreamingChange]);
|
}, [isStreaming, onStreamingChange]);
|
||||||
|
|
||||||
|
const breakpoint = useBreakpoint();
|
||||||
|
const isMobile =
|
||||||
|
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
@@ -122,7 +127,11 @@ export function ChatContainer({
|
|||||||
disabled={isStreaming || !sessionId}
|
disabled={isStreaming || !sessionId}
|
||||||
isStreaming={isStreaming}
|
isStreaming={isStreaming}
|
||||||
onStop={stopStreaming}
|
onStop={stopStreaming}
|
||||||
placeholder="What else can I help with?"
|
placeholder={
|
||||||
|
isMobile
|
||||||
|
? "You can search or just ask"
|
||||||
|
: 'You can search or just ask — e.g. "create a blog post outline"'
|
||||||
|
}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@@ -74,20 +74,19 @@ export function ChatInput({
|
|||||||
hasMultipleLines ? "rounded-xlarge" : "rounded-full",
|
hasMultipleLines ? "rounded-xlarge" : "rounded-full",
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{!value && !isRecording && (
|
|
||||||
<div
|
|
||||||
className="pointer-events-none absolute inset-0 top-0.5 flex items-center justify-start pl-14 text-[1rem] text-zinc-400"
|
|
||||||
aria-hidden="true"
|
|
||||||
>
|
|
||||||
{isTranscribing ? "Transcribing..." : placeholder}
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
<textarea
|
<textarea
|
||||||
id={inputId}
|
id={inputId}
|
||||||
aria-label="Chat message input"
|
aria-label="Chat message input"
|
||||||
value={value}
|
value={value}
|
||||||
onChange={handleChange}
|
onChange={handleChange}
|
||||||
onKeyDown={handleKeyDown}
|
onKeyDown={handleKeyDown}
|
||||||
|
placeholder={
|
||||||
|
isTranscribing
|
||||||
|
? "Transcribing..."
|
||||||
|
: isRecording
|
||||||
|
? ""
|
||||||
|
: placeholder
|
||||||
|
}
|
||||||
disabled={isInputDisabled}
|
disabled={isInputDisabled}
|
||||||
rows={1}
|
rows={1}
|
||||||
className={cn(
|
className={cn(
|
||||||
@@ -123,14 +122,13 @@ export function ChatInput({
|
|||||||
size="icon"
|
size="icon"
|
||||||
aria-label={isRecording ? "Stop recording" : "Start recording"}
|
aria-label={isRecording ? "Stop recording" : "Start recording"}
|
||||||
onClick={toggleRecording}
|
onClick={toggleRecording}
|
||||||
disabled={disabled || isTranscribing || isStreaming}
|
disabled={disabled || isTranscribing}
|
||||||
className={cn(
|
className={cn(
|
||||||
isRecording
|
isRecording
|
||||||
? "animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600"
|
? "animate-pulse border-red-500 bg-red-500 text-white hover:border-red-600 hover:bg-red-600"
|
||||||
: isTranscribing
|
: isTranscribing
|
||||||
? "border-zinc-300 bg-zinc-100 text-zinc-400"
|
? "border-zinc-300 bg-zinc-100 text-zinc-400"
|
||||||
: "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
|
: "border-zinc-300 bg-white text-zinc-500 hover:border-zinc-400 hover:bg-zinc-50 hover:text-zinc-700",
|
||||||
isStreaming && "opacity-40",
|
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{isTranscribing ? (
|
{isTranscribing ? (
|
||||||
|
|||||||
@@ -38,8 +38,8 @@ export function AudioWaveform({
|
|||||||
// Create audio context and analyser
|
// Create audio context and analyser
|
||||||
const audioContext = new AudioContext();
|
const audioContext = new AudioContext();
|
||||||
const analyser = audioContext.createAnalyser();
|
const analyser = audioContext.createAnalyser();
|
||||||
analyser.fftSize = 256;
|
analyser.fftSize = 512;
|
||||||
analyser.smoothingTimeConstant = 0.3;
|
analyser.smoothingTimeConstant = 0.8;
|
||||||
|
|
||||||
// Connect the stream to the analyser
|
// Connect the stream to the analyser
|
||||||
const source = audioContext.createMediaStreamSource(stream);
|
const source = audioContext.createMediaStreamSource(stream);
|
||||||
@@ -73,11 +73,10 @@ export function AudioWaveform({
|
|||||||
maxAmplitude = Math.max(maxAmplitude, amplitude);
|
maxAmplitude = Math.max(maxAmplitude, amplitude);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize amplitude (0-128 range) to 0-1
|
// Map amplitude (0-128) to bar height
|
||||||
const normalized = maxAmplitude / 128;
|
const normalized = (maxAmplitude / 128) * 255;
|
||||||
// Apply sensitivity boost (multiply by 4) and use sqrt curve to amplify quiet sounds
|
const height =
|
||||||
const boosted = Math.min(1, Math.sqrt(normalized) * 4);
|
minBarHeight + (normalized / 255) * (maxBarHeight - minBarHeight);
|
||||||
const height = minBarHeight + boosted * (maxBarHeight - minBarHeight);
|
|
||||||
newBars.push(height);
|
newBars.push(height);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -224,7 +224,7 @@ export function useVoiceRecording({
|
|||||||
[value, isTranscribing, toggleRecording, baseHandleKeyDown],
|
[value, isTranscribing, toggleRecording, baseHandleKeyDown],
|
||||||
);
|
);
|
||||||
|
|
||||||
const showMicButton = isSupported;
|
const showMicButton = isSupported && !isStreaming;
|
||||||
const isInputDisabled = disabled || isStreaming || isTranscribing;
|
const isInputDisabled = disabled || isStreaming || isTranscribing;
|
||||||
|
|
||||||
// Cleanup on unmount
|
// Cleanup on unmount
|
||||||
|
|||||||
@@ -346,7 +346,6 @@ export function ChatMessage({
|
|||||||
toolId={message.toolId}
|
toolId={message.toolId}
|
||||||
toolName={message.toolName}
|
toolName={message.toolName}
|
||||||
result={message.result}
|
result={message.result}
|
||||||
onSendMessage={onSendMessage}
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -73,7 +73,6 @@ export function MessageList({
|
|||||||
key={index}
|
key={index}
|
||||||
message={message}
|
message={message}
|
||||||
prevMessage={messages[index - 1]}
|
prevMessage={messages[index - 1]}
|
||||||
onSendMessage={onSendMessage}
|
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,13 +5,11 @@ import { shouldSkipAgentOutput } from "../../helpers";
|
|||||||
export interface LastToolResponseProps {
|
export interface LastToolResponseProps {
|
||||||
message: ChatMessageData;
|
message: ChatMessageData;
|
||||||
prevMessage: ChatMessageData | undefined;
|
prevMessage: ChatMessageData | undefined;
|
||||||
onSendMessage?: (content: string) => void;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function LastToolResponse({
|
export function LastToolResponse({
|
||||||
message,
|
message,
|
||||||
prevMessage,
|
prevMessage,
|
||||||
onSendMessage,
|
|
||||||
}: LastToolResponseProps) {
|
}: LastToolResponseProps) {
|
||||||
if (message.type !== "tool_response") return null;
|
if (message.type !== "tool_response") return null;
|
||||||
|
|
||||||
@@ -23,7 +21,6 @@ export function LastToolResponse({
|
|||||||
toolId={message.toolId}
|
toolId={message.toolId}
|
||||||
toolName={message.toolName}
|
toolName={message.toolName}
|
||||||
result={message.result}
|
result={message.result}
|
||||||
onSendMessage={onSendMessage}
|
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
import { Progress } from "@/components/atoms/Progress/Progress";
|
|
||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { useEffect, useRef, useState } from "react";
|
import { useEffect, useRef, useState } from "react";
|
||||||
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
||||||
import { useAsymptoticProgress } from "../ToolCallMessage/useAsymptoticProgress";
|
|
||||||
|
|
||||||
export interface ThinkingMessageProps {
|
export interface ThinkingMessageProps {
|
||||||
className?: string;
|
className?: string;
|
||||||
@@ -13,19 +11,18 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
|
|||||||
const [showCoffeeMessage, setShowCoffeeMessage] = useState(false);
|
const [showCoffeeMessage, setShowCoffeeMessage] = useState(false);
|
||||||
const timerRef = useRef<NodeJS.Timeout | null>(null);
|
const timerRef = useRef<NodeJS.Timeout | null>(null);
|
||||||
const coffeeTimerRef = useRef<NodeJS.Timeout | null>(null);
|
const coffeeTimerRef = useRef<NodeJS.Timeout | null>(null);
|
||||||
const progress = useAsymptoticProgress(showCoffeeMessage);
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (timerRef.current === null) {
|
if (timerRef.current === null) {
|
||||||
timerRef.current = setTimeout(() => {
|
timerRef.current = setTimeout(() => {
|
||||||
setShowSlowLoader(true);
|
setShowSlowLoader(true);
|
||||||
}, 3000);
|
}, 8000);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (coffeeTimerRef.current === null) {
|
if (coffeeTimerRef.current === null) {
|
||||||
coffeeTimerRef.current = setTimeout(() => {
|
coffeeTimerRef.current = setTimeout(() => {
|
||||||
setShowCoffeeMessage(true);
|
setShowCoffeeMessage(true);
|
||||||
}, 8000);
|
}, 10000);
|
||||||
}
|
}
|
||||||
|
|
||||||
return () => {
|
return () => {
|
||||||
@@ -52,18 +49,9 @@ export function ThinkingMessage({ className }: ThinkingMessageProps) {
|
|||||||
<AIChatBubble>
|
<AIChatBubble>
|
||||||
<div className="transition-all duration-500 ease-in-out">
|
<div className="transition-all duration-500 ease-in-out">
|
||||||
{showCoffeeMessage ? (
|
{showCoffeeMessage ? (
|
||||||
<div className="flex flex-col items-center gap-3">
|
|
||||||
<div className="flex w-full max-w-[280px] flex-col gap-1.5">
|
|
||||||
<div className="flex items-center justify-between text-xs text-neutral-500">
|
|
||||||
<span>Working on it...</span>
|
|
||||||
<span>{Math.round(progress)}%</span>
|
|
||||||
</div>
|
|
||||||
<Progress value={progress} className="h-2 w-full" />
|
|
||||||
</div>
|
|
||||||
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
|
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
|
||||||
This could take a few minutes, grab a coffee ☕️
|
This could take a few minutes, grab a coffee ☕️
|
||||||
</span>
|
</span>
|
||||||
</div>
|
|
||||||
) : showSlowLoader ? (
|
) : showSlowLoader ? (
|
||||||
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
|
<span className="inline-block animate-shimmer bg-gradient-to-r from-neutral-400 via-neutral-600 to-neutral-400 bg-[length:200%_100%] bg-clip-text text-transparent">
|
||||||
Taking a bit more time...
|
Taking a bit more time...
|
||||||
|
|||||||
@@ -1,50 +0,0 @@
|
|||||||
import { useEffect, useRef, useState } from "react";
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Hook that returns a progress value that starts fast and slows down,
|
|
||||||
* asymptotically approaching but never reaching the max value.
|
|
||||||
*
|
|
||||||
* Uses a half-life formula: progress = max * (1 - 0.5^(time/halfLife))
|
|
||||||
* This creates the "game loading bar" effect where:
|
|
||||||
* - 50% is reached at halfLifeSeconds
|
|
||||||
* - 75% is reached at 2 * halfLifeSeconds
|
|
||||||
* - 87.5% is reached at 3 * halfLifeSeconds
|
|
||||||
* - and so on...
|
|
||||||
*
|
|
||||||
* @param isActive - Whether the progress should be animating
|
|
||||||
* @param halfLifeSeconds - Time in seconds to reach 50% progress (default: 30)
|
|
||||||
* @param maxProgress - Maximum progress value to approach (default: 100)
|
|
||||||
* @param intervalMs - Update interval in milliseconds (default: 100)
|
|
||||||
* @returns Current progress value (0-maxProgress)
|
|
||||||
*/
|
|
||||||
export function useAsymptoticProgress(
|
|
||||||
isActive: boolean,
|
|
||||||
halfLifeSeconds = 30,
|
|
||||||
maxProgress = 100,
|
|
||||||
intervalMs = 100,
|
|
||||||
) {
|
|
||||||
const [progress, setProgress] = useState(0);
|
|
||||||
const elapsedTimeRef = useRef(0);
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
if (!isActive) {
|
|
||||||
setProgress(0);
|
|
||||||
elapsedTimeRef.current = 0;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const interval = setInterval(() => {
|
|
||||||
elapsedTimeRef.current += intervalMs / 1000;
|
|
||||||
// Half-life approach: progress = max * (1 - 0.5^(time/halfLife))
|
|
||||||
// At t=halfLife: 50%, at t=2*halfLife: 75%, at t=3*halfLife: 87.5%, etc.
|
|
||||||
const newProgress =
|
|
||||||
maxProgress *
|
|
||||||
(1 - Math.pow(0.5, elapsedTimeRef.current / halfLifeSeconds));
|
|
||||||
setProgress(newProgress);
|
|
||||||
}, intervalMs);
|
|
||||||
|
|
||||||
return () => clearInterval(interval);
|
|
||||||
}, [isActive, halfLifeSeconds, maxProgress, intervalMs]);
|
|
||||||
|
|
||||||
return progress;
|
|
||||||
}
|
|
||||||
@@ -1,128 +0,0 @@
|
|||||||
"use client";
|
|
||||||
|
|
||||||
import { useGetV2GetLibraryAgent } from "@/app/api/__generated__/endpoints/library/library";
|
|
||||||
import { GraphExecutionJobInfo } from "@/app/api/__generated__/models/graphExecutionJobInfo";
|
|
||||||
import { GraphExecutionMeta } from "@/app/api/__generated__/models/graphExecutionMeta";
|
|
||||||
import { RunAgentModal } from "@/app/(platform)/library/agents/[id]/components/NewAgentLibraryView/components/modals/RunAgentModal/RunAgentModal";
|
|
||||||
import { Button } from "@/components/atoms/Button/Button";
|
|
||||||
import { Text } from "@/components/atoms/Text/Text";
|
|
||||||
import {
|
|
||||||
CheckCircleIcon,
|
|
||||||
PencilLineIcon,
|
|
||||||
PlayIcon,
|
|
||||||
} from "@phosphor-icons/react";
|
|
||||||
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
|
||||||
|
|
||||||
interface Props {
|
|
||||||
agentName: string;
|
|
||||||
libraryAgentId: string;
|
|
||||||
onSendMessage?: (content: string) => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function AgentCreatedPrompt({
|
|
||||||
agentName,
|
|
||||||
libraryAgentId,
|
|
||||||
onSendMessage,
|
|
||||||
}: Props) {
|
|
||||||
// Fetch library agent eagerly so modal is ready when user clicks
|
|
||||||
const { data: libraryAgentResponse, isLoading } = useGetV2GetLibraryAgent(
|
|
||||||
libraryAgentId,
|
|
||||||
{
|
|
||||||
query: {
|
|
||||||
enabled: !!libraryAgentId,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
const libraryAgent =
|
|
||||||
libraryAgentResponse?.status === 200 ? libraryAgentResponse.data : null;
|
|
||||||
|
|
||||||
function handleRunWithPlaceholders() {
|
|
||||||
onSendMessage?.(
|
|
||||||
`Run the agent "${agentName}" with placeholder/example values so I can test it.`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleRunCreated(execution: GraphExecutionMeta) {
|
|
||||||
onSendMessage?.(
|
|
||||||
`I've started the agent "${agentName}". The execution ID is ${execution.id}. Please monitor its progress and let me know when it completes.`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
function handleScheduleCreated(schedule: GraphExecutionJobInfo) {
|
|
||||||
const scheduleInfo = schedule.cron
|
|
||||||
? `with cron schedule "${schedule.cron}"`
|
|
||||||
: "to run on the specified schedule";
|
|
||||||
onSendMessage?.(
|
|
||||||
`I've scheduled the agent "${agentName}" ${scheduleInfo}. The schedule ID is ${schedule.id}.`,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<AIChatBubble>
|
|
||||||
<div className="flex flex-col gap-4">
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<div className="flex h-8 w-8 items-center justify-center rounded-full bg-green-100">
|
|
||||||
<CheckCircleIcon
|
|
||||||
size={18}
|
|
||||||
weight="fill"
|
|
||||||
className="text-green-600"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<Text variant="body-medium" className="text-neutral-900">
|
|
||||||
Agent Created Successfully
|
|
||||||
</Text>
|
|
||||||
<Text variant="small" className="text-neutral-500">
|
|
||||||
"{agentName}" is ready to test
|
|
||||||
</Text>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="flex flex-col gap-2">
|
|
||||||
<Text variant="small-medium" className="text-neutral-700">
|
|
||||||
Ready to test?
|
|
||||||
</Text>
|
|
||||||
<div className="flex flex-wrap gap-2">
|
|
||||||
<Button
|
|
||||||
variant="outline"
|
|
||||||
size="small"
|
|
||||||
onClick={handleRunWithPlaceholders}
|
|
||||||
className="gap-2"
|
|
||||||
>
|
|
||||||
<PlayIcon size={16} />
|
|
||||||
Run with example values
|
|
||||||
</Button>
|
|
||||||
{libraryAgent ? (
|
|
||||||
<RunAgentModal
|
|
||||||
triggerSlot={
|
|
||||||
<Button variant="outline" size="small" className="gap-2">
|
|
||||||
<PencilLineIcon size={16} />
|
|
||||||
Run with my inputs
|
|
||||||
</Button>
|
|
||||||
}
|
|
||||||
agent={libraryAgent}
|
|
||||||
onRunCreated={handleRunCreated}
|
|
||||||
onScheduleCreated={handleScheduleCreated}
|
|
||||||
/>
|
|
||||||
) : (
|
|
||||||
<Button
|
|
||||||
variant="outline"
|
|
||||||
size="small"
|
|
||||||
loading={isLoading}
|
|
||||||
disabled
|
|
||||||
className="gap-2"
|
|
||||||
>
|
|
||||||
<PencilLineIcon size={16} />
|
|
||||||
Run with my inputs
|
|
||||||
</Button>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
<Text variant="small" className="text-neutral-500">
|
|
||||||
or just ask me
|
|
||||||
</Text>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</AIChatBubble>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
@@ -2,13 +2,11 @@ import { Text } from "@/components/atoms/Text/Text";
|
|||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import type { ToolResult } from "@/types/chat";
|
import type { ToolResult } from "@/types/chat";
|
||||||
import { WarningCircleIcon } from "@phosphor-icons/react";
|
import { WarningCircleIcon } from "@phosphor-icons/react";
|
||||||
import { AgentCreatedPrompt } from "./AgentCreatedPrompt";
|
|
||||||
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
import { AIChatBubble } from "../AIChatBubble/AIChatBubble";
|
||||||
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
import { MarkdownContent } from "../MarkdownContent/MarkdownContent";
|
||||||
import {
|
import {
|
||||||
formatToolResponse,
|
formatToolResponse,
|
||||||
getErrorMessage,
|
getErrorMessage,
|
||||||
isAgentSavedResponse,
|
|
||||||
isErrorResponse,
|
isErrorResponse,
|
||||||
} from "./helpers";
|
} from "./helpers";
|
||||||
|
|
||||||
@@ -18,7 +16,6 @@ export interface ToolResponseMessageProps {
|
|||||||
result?: ToolResult;
|
result?: ToolResult;
|
||||||
success?: boolean;
|
success?: boolean;
|
||||||
className?: string;
|
className?: string;
|
||||||
onSendMessage?: (content: string) => void;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ToolResponseMessage({
|
export function ToolResponseMessage({
|
||||||
@@ -27,7 +24,6 @@ export function ToolResponseMessage({
|
|||||||
result,
|
result,
|
||||||
success: _success,
|
success: _success,
|
||||||
className,
|
className,
|
||||||
onSendMessage,
|
|
||||||
}: ToolResponseMessageProps) {
|
}: ToolResponseMessageProps) {
|
||||||
if (isErrorResponse(result)) {
|
if (isErrorResponse(result)) {
|
||||||
const errorMessage = getErrorMessage(result);
|
const errorMessage = getErrorMessage(result);
|
||||||
@@ -47,18 +43,6 @@ export function ToolResponseMessage({
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for agent_saved response - show special prompt
|
|
||||||
const agentSavedData = isAgentSavedResponse(result);
|
|
||||||
if (agentSavedData.isSaved) {
|
|
||||||
return (
|
|
||||||
<AgentCreatedPrompt
|
|
||||||
agentName={agentSavedData.agentName}
|
|
||||||
libraryAgentId={agentSavedData.libraryAgentId}
|
|
||||||
onSendMessage={onSendMessage}
|
|
||||||
/>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
const formattedText = formatToolResponse(result, toolName);
|
const formattedText = formatToolResponse(result, toolName);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
|||||||
@@ -6,43 +6,6 @@ function stripInternalReasoning(content: string): string {
|
|||||||
.trim();
|
.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface AgentSavedData {
|
|
||||||
isSaved: boolean;
|
|
||||||
agentName: string;
|
|
||||||
agentId: string;
|
|
||||||
libraryAgentId: string;
|
|
||||||
libraryAgentLink: string;
|
|
||||||
}
|
|
||||||
|
|
||||||
export function isAgentSavedResponse(result: unknown): AgentSavedData {
|
|
||||||
if (typeof result !== "object" || result === null) {
|
|
||||||
return {
|
|
||||||
isSaved: false,
|
|
||||||
agentName: "",
|
|
||||||
agentId: "",
|
|
||||||
libraryAgentId: "",
|
|
||||||
libraryAgentLink: "",
|
|
||||||
};
|
|
||||||
}
|
|
||||||
const response = result as Record<string, unknown>;
|
|
||||||
if (response.type === "agent_saved") {
|
|
||||||
return {
|
|
||||||
isSaved: true,
|
|
||||||
agentName: (response.agent_name as string) || "Agent",
|
|
||||||
agentId: (response.agent_id as string) || "",
|
|
||||||
libraryAgentId: (response.library_agent_id as string) || "",
|
|
||||||
libraryAgentLink: (response.library_agent_link as string) || "",
|
|
||||||
};
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
isSaved: false,
|
|
||||||
agentName: "",
|
|
||||||
agentId: "",
|
|
||||||
libraryAgentId: "",
|
|
||||||
libraryAgentLink: "",
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
export function isErrorResponse(result: unknown): boolean {
|
export function isErrorResponse(result: unknown): boolean {
|
||||||
if (typeof result === "string") {
|
if (typeof result === "string") {
|
||||||
const lower = result.toLowerCase();
|
const lower = result.toLowerCase();
|
||||||
|
|||||||
@@ -41,17 +41,7 @@ export function HostScopedCredentialsModal({
|
|||||||
const currentHost = currentUrl ? getHostFromUrl(currentUrl) : "";
|
const currentHost = currentUrl ? getHostFromUrl(currentUrl) : "";
|
||||||
|
|
||||||
const formSchema = z.object({
|
const formSchema = z.object({
|
||||||
host: z
|
host: z.string().min(1, "Host is required"),
|
||||||
.string()
|
|
||||||
.min(1, "Host is required")
|
|
||||||
.refine((val) => !/^[a-zA-Z][a-zA-Z\d+\-.]*:\/\//.test(val), {
|
|
||||||
message: "Enter only the host (e.g. api.example.com), not a full URL",
|
|
||||||
})
|
|
||||||
.refine((val) => !val.includes("/"), {
|
|
||||||
message:
|
|
||||||
"Enter only the host (e.g. api.example.com), without a trailing path. " +
|
|
||||||
"You may specify a port (e.g. api.example.com:8080) if needed.",
|
|
||||||
}),
|
|
||||||
title: z.string().optional(),
|
title: z.string().optional(),
|
||||||
headers: z.record(z.string()).optional(),
|
headers: z.record(z.string()).optional(),
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import {
|
|||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
|
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
|
||||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||||
|
import { storage, Key as StorageKey } from "@/services/storage/local-storage";
|
||||||
import { WalletIcon } from "@phosphor-icons/react";
|
import { WalletIcon } from "@phosphor-icons/react";
|
||||||
import { PopoverClose } from "@radix-ui/react-popover";
|
import { PopoverClose } from "@radix-ui/react-popover";
|
||||||
import { X } from "lucide-react";
|
import { X } from "lucide-react";
|
||||||
@@ -174,6 +175,7 @@ export function Wallet() {
|
|||||||
const [prevCredits, setPrevCredits] = useState<number | null>(credits);
|
const [prevCredits, setPrevCredits] = useState<number | null>(credits);
|
||||||
const [flash, setFlash] = useState(false);
|
const [flash, setFlash] = useState(false);
|
||||||
const [walletOpen, setWalletOpen] = useState(false);
|
const [walletOpen, setWalletOpen] = useState(false);
|
||||||
|
const [lastSeenCredits, setLastSeenCredits] = useState<number | null>(null);
|
||||||
|
|
||||||
const totalCount = useMemo(() => {
|
const totalCount = useMemo(() => {
|
||||||
return groups.reduce((acc, group) => acc + group.tasks.length, 0);
|
return groups.reduce((acc, group) => acc + group.tasks.length, 0);
|
||||||
@@ -198,6 +200,38 @@ export function Wallet() {
|
|||||||
setCompletedCount(completed);
|
setCompletedCount(completed);
|
||||||
}, [groups, state?.completedSteps]);
|
}, [groups, state?.completedSteps]);
|
||||||
|
|
||||||
|
// Load last seen credits from localStorage once on mount
|
||||||
|
useEffect(() => {
|
||||||
|
const stored = storage.get(StorageKey.WALLET_LAST_SEEN_CREDITS);
|
||||||
|
if (stored !== undefined && stored !== null) {
|
||||||
|
const parsed = parseFloat(stored);
|
||||||
|
if (!Number.isNaN(parsed)) setLastSeenCredits(parsed);
|
||||||
|
else setLastSeenCredits(0);
|
||||||
|
} else {
|
||||||
|
setLastSeenCredits(0);
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// Auto-open once if never shown, otherwise open only when credits increase beyond last seen
|
||||||
|
useEffect(() => {
|
||||||
|
if (typeof credits !== "number") return;
|
||||||
|
// Open once for first-time users
|
||||||
|
if (state && state.walletShown === false) {
|
||||||
|
requestAnimationFrame(() => setWalletOpen(true));
|
||||||
|
// Mark as shown so it won't reopen on every reload
|
||||||
|
updateState({ walletShown: true });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Open if user gained more credits than last acknowledged
|
||||||
|
if (
|
||||||
|
lastSeenCredits !== null &&
|
||||||
|
credits > lastSeenCredits &&
|
||||||
|
walletOpen === false
|
||||||
|
) {
|
||||||
|
requestAnimationFrame(() => setWalletOpen(true));
|
||||||
|
}
|
||||||
|
}, [credits, lastSeenCredits, state?.walletShown, updateState, walletOpen]);
|
||||||
|
|
||||||
const onWalletOpen = useCallback(async () => {
|
const onWalletOpen = useCallback(async () => {
|
||||||
if (!state?.walletShown) {
|
if (!state?.walletShown) {
|
||||||
updateState({ walletShown: true });
|
updateState({ walletShown: true });
|
||||||
@@ -290,7 +324,19 @@ export function Wallet() {
|
|||||||
if (credits === null || !state) return null;
|
if (credits === null || !state) return null;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Popover open={walletOpen} onOpenChange={(open) => setWalletOpen(open)}>
|
<Popover
|
||||||
|
open={walletOpen}
|
||||||
|
onOpenChange={(open) => {
|
||||||
|
setWalletOpen(open);
|
||||||
|
if (!open) {
|
||||||
|
// Persist the latest acknowledged credits so we only auto-open on future gains
|
||||||
|
if (typeof credits === "number") {
|
||||||
|
storage.set(StorageKey.WALLET_LAST_SEEN_CREDITS, String(credits));
|
||||||
|
setLastSeenCredits(credits);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
>
|
||||||
<PopoverTrigger asChild>
|
<PopoverTrigger asChild>
|
||||||
<div className="relative inline-block">
|
<div className="relative inline-block">
|
||||||
<button
|
<button
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
|
|||||||
| [Get Store Agent Details](block-integrations/system/store_operations.md#get-store-agent-details) | Get detailed information about an agent from the store |
|
| [Get Store Agent Details](block-integrations/system/store_operations.md#get-store-agent-details) | Get detailed information about an agent from the store |
|
||||||
| [Get Weather Information](block-integrations/basic.md#get-weather-information) | Retrieves weather information for a specified location using OpenWeatherMap API |
|
| [Get Weather Information](block-integrations/basic.md#get-weather-information) | Retrieves weather information for a specified location using OpenWeatherMap API |
|
||||||
| [Human In The Loop](block-integrations/basic.md#human-in-the-loop) | Pause execution and wait for human approval or modification of data |
|
| [Human In The Loop](block-integrations/basic.md#human-in-the-loop) | Pause execution and wait for human approval or modification of data |
|
||||||
|
| [Linear Search Issues](block-integrations/linear/issues.md#linear-search-issues) | Searches for issues on Linear |
|
||||||
| [List Is Empty](block-integrations/basic.md#list-is-empty) | Checks if a list is empty |
|
| [List Is Empty](block-integrations/basic.md#list-is-empty) | Checks if a list is empty |
|
||||||
| [List Library Agents](block-integrations/system/library_operations.md#list-library-agents) | List all agents in your personal library |
|
| [List Library Agents](block-integrations/system/library_operations.md#list-library-agents) | List all agents in your personal library |
|
||||||
| [Note](block-integrations/basic.md#note) | A visual annotation block that displays a sticky note in the workflow editor for documentation and organization purposes |
|
| [Note](block-integrations/basic.md#note) | A visual annotation block that displays a sticky note in the workflow editor for documentation and organization purposes |
|
||||||
@@ -570,7 +571,6 @@ Below is a comprehensive list of all available blocks, categorized by their prim
|
|||||||
| [Linear Create Comment](block-integrations/linear/comment.md#linear-create-comment) | Creates a new comment on a Linear issue |
|
| [Linear Create Comment](block-integrations/linear/comment.md#linear-create-comment) | Creates a new comment on a Linear issue |
|
||||||
| [Linear Create Issue](block-integrations/linear/issues.md#linear-create-issue) | Creates a new issue on Linear |
|
| [Linear Create Issue](block-integrations/linear/issues.md#linear-create-issue) | Creates a new issue on Linear |
|
||||||
| [Linear Get Project Issues](block-integrations/linear/issues.md#linear-get-project-issues) | Gets issues from a Linear project filtered by status and assignee |
|
| [Linear Get Project Issues](block-integrations/linear/issues.md#linear-get-project-issues) | Gets issues from a Linear project filtered by status and assignee |
|
||||||
| [Linear Search Issues](block-integrations/linear/issues.md#linear-search-issues) | Searches for issues on Linear |
|
|
||||||
| [Linear Search Projects](block-integrations/linear/projects.md#linear-search-projects) | Searches for projects on Linear |
|
| [Linear Search Projects](block-integrations/linear/projects.md#linear-search-projects) | Searches for projects on Linear |
|
||||||
|
|
||||||
## Hardware
|
## Hardware
|
||||||
|
|||||||
@@ -90,9 +90,9 @@ Searches for issues on Linear
|
|||||||
|
|
||||||
### How it works
|
### How it works
|
||||||
<!-- MANUAL: how_it_works -->
|
<!-- MANUAL: how_it_works -->
|
||||||
This block searches for issues in Linear using a text query. It searches across issue titles, descriptions, and other fields to find matching issues. You can limit the number of results returned using the `max_results` parameter (default: 10, max: 100) to control token consumption and response size.
|
This block searches for issues in Linear using a text query. It searches across issue titles, descriptions, and other fields to find matching issues.
|
||||||
|
|
||||||
Optionally filter results by team name to narrow searches to specific workspaces. If a team name is provided, the block resolves it to a team ID before searching. Returns matching issues with their state, creation date, project, and assignee information. If the search or team resolution fails, an error message is returned.
|
Returns a list of issues matching the search term.
|
||||||
<!-- END MANUAL -->
|
<!-- END MANUAL -->
|
||||||
|
|
||||||
### Inputs
|
### Inputs
|
||||||
@@ -100,14 +100,12 @@ Optionally filter results by team name to narrow searches to specific workspaces
|
|||||||
| Input | Description | Type | Required |
|
| Input | Description | Type | Required |
|
||||||
|-------|-------------|------|----------|
|
|-------|-------------|------|----------|
|
||||||
| term | Term to search for issues | str | Yes |
|
| term | Term to search for issues | str | Yes |
|
||||||
| max_results | Maximum number of results to return | int | No |
|
|
||||||
| team_name | Optional team name to filter results (e.g., 'Internal', 'Open Source') | str | No |
|
|
||||||
|
|
||||||
### Outputs
|
### Outputs
|
||||||
|
|
||||||
| Output | Description | Type |
|
| Output | Description | Type |
|
||||||
|--------|-------------|------|
|
|--------|-------------|------|
|
||||||
| error | Error message if the search failed | str |
|
| error | Error message if the operation failed | str |
|
||||||
| issues | List of issues | List[Issue] |
|
| issues | List of issues | List[Issue] |
|
||||||
|
|
||||||
### Possible use case
|
### Possible use case
|
||||||
|
|||||||
Reference in New Issue
Block a user