mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-19 20:18:22 -05:00
Compare commits
25 Commits
fix/undefi
...
swiftyos/s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
821532d068 | ||
|
|
8257b62100 | ||
|
|
4c67ed89af | ||
|
|
9a209c54ea | ||
|
|
8160e3a6a1 | ||
|
|
3549b99b23 | ||
|
|
bc54ccdb8e | ||
|
|
a53545abbc | ||
|
|
e404544db2 | ||
|
|
951a8ac068 | ||
|
|
4ef4b3e04d | ||
|
|
833971b861 | ||
|
|
6addea33f7 | ||
|
|
6475557351 | ||
|
|
2ad4a2ac09 | ||
|
|
9d72b444c6 | ||
|
|
7b79954003 | ||
|
|
c4fe4f2233 | ||
|
|
87a45d75cf | ||
|
|
04460e43f3 | ||
|
|
0d93315767 | ||
|
|
7528920b5d | ||
|
|
911b119296 | ||
|
|
df50c7352d | ||
|
|
7db3db6b2e |
@@ -1,12 +1,14 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import Future
|
||||
from typing import Mapping, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, JsonValue, ValidationError
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
|
||||
from pydantic import BaseModel, JsonValue, SecretStr, ValidationError
|
||||
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data import graph as graph_db
|
||||
@@ -34,6 +36,7 @@ from backend.data.graph import GraphModel, Node
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_async_execution_queue,
|
||||
@@ -46,6 +49,66 @@ from backend.util.settings import Config
|
||||
from backend.util.type import convert
|
||||
|
||||
|
||||
def get_system_credentials():
|
||||
"""Get system-provided credentials from environment and AutoRegistry."""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
system_credentials = {}
|
||||
|
||||
try:
|
||||
# Get SDK-registered credentials
|
||||
system_creds_list = AutoRegistry.get_all_credentials()
|
||||
for cred in system_creds_list:
|
||||
system_credentials[f"system-{cred.provider}"] = cred
|
||||
|
||||
# WORKAROUND: Check for common LLM providers that don't use SDK pattern
|
||||
# System credentials never expire - set to far future (Unix timestamp)
|
||||
expires_at = int(
|
||||
(datetime.utcnow() + timedelta(days=36500)).timestamp()
|
||||
) # 100 years
|
||||
|
||||
# Check for OpenAI
|
||||
if "system-openai" not in system_credentials:
|
||||
openai_key = os.getenv("OPENAI_API_KEY")
|
||||
if openai_key:
|
||||
system_credentials["system-openai"] = APIKeyCredentials(
|
||||
id="system-openai",
|
||||
provider="openai",
|
||||
api_key=SecretStr(openai_key),
|
||||
title="System OpenAI API Key",
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
# Check for Anthropic
|
||||
if "system-anthropic" not in system_credentials:
|
||||
anthropic_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if anthropic_key:
|
||||
system_credentials["system-anthropic"] = APIKeyCredentials(
|
||||
id="system-anthropic",
|
||||
provider="anthropic",
|
||||
api_key=SecretStr(anthropic_key),
|
||||
title="System Anthropic API Key",
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
# Check for Replicate
|
||||
if "system-replicate" not in system_credentials:
|
||||
replicate_key = os.getenv("REPLICATE_API_KEY")
|
||||
if replicate_key:
|
||||
system_credentials["system-replicate"] = APIKeyCredentials(
|
||||
id="system-replicate",
|
||||
provider="replicate",
|
||||
api_key=SecretStr(replicate_key),
|
||||
title="System Replicate API Key",
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error getting system credentials: {e}")
|
||||
|
||||
return system_credentials
|
||||
|
||||
|
||||
async def get_user_context(user_id: str) -> UserContext:
|
||||
"""
|
||||
Get UserContext for a user, always returns a valid context with timezone.
|
||||
@@ -291,17 +354,27 @@ async def _validate_node_input_credentials(
|
||||
credential_errors[node.id][field_name] = f"Invalid credentials: {e}"
|
||||
continue
|
||||
|
||||
try:
|
||||
# Fetch the corresponding Credentials and perform sanity checks
|
||||
credentials = await get_integration_credentials_store().get_creds_by_id(
|
||||
user_id, credentials_meta.id
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any errors fetching credentials
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = f"Credentials not available: {e}"
|
||||
continue
|
||||
# Check if this is a system credential first
|
||||
credentials = None
|
||||
if credentials_meta.id.startswith("system-"):
|
||||
system_creds = get_system_credentials()
|
||||
credentials = system_creds.get(credentials_meta.id)
|
||||
|
||||
# If not a system credential or not found, try user credentials
|
||||
if not credentials:
|
||||
try:
|
||||
# Fetch the corresponding Credentials and perform sanity checks
|
||||
credentials = (
|
||||
await get_integration_credentials_store().get_creds_by_id(
|
||||
user_id, credentials_meta.id
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# Handle any errors fetching credentials
|
||||
credential_errors[node.id][
|
||||
field_name
|
||||
] = f"Credentials not available: {e}"
|
||||
continue
|
||||
|
||||
if not credentials:
|
||||
credential_errors[node.id][
|
||||
@@ -346,9 +419,13 @@ def make_node_credentials_input_map(
|
||||
# Get aggregated credentials fields for the graph
|
||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||
|
||||
logging.debug(f"Graph expects credentials: {list(graph_cred_inputs.keys())}")
|
||||
logging.debug(f"Provided credentials: {list(graph_credentials_input.keys())}")
|
||||
|
||||
for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items():
|
||||
# Best-effort map: skip missing items
|
||||
if graph_input_name not in graph_credentials_input:
|
||||
logging.warning(f"Missing credential for graph input: {graph_input_name}")
|
||||
continue
|
||||
|
||||
# Use passed-in credentials for all compatible node input fields
|
||||
|
||||
@@ -27,6 +27,7 @@ import backend.server.v2.admin.credit_admin_routes
|
||||
import backend.server.v2.admin.store_admin_routes
|
||||
import backend.server.v2.builder
|
||||
import backend.server.v2.builder.routes
|
||||
import backend.server.v2.chat.routes
|
||||
import backend.server.v2.library.db
|
||||
import backend.server.v2.library.model
|
||||
import backend.server.v2.library.routes
|
||||
@@ -272,6 +273,11 @@ app.include_router(
|
||||
tags=["v2", "turnstile"],
|
||||
prefix="/api/turnstile",
|
||||
)
|
||||
app.include_router(
|
||||
backend.server.v2.chat.routes.router,
|
||||
tags=["v2", "chat"],
|
||||
prefix="/api/v2/chat",
|
||||
)
|
||||
|
||||
app.include_router(
|
||||
backend.server.routers.postmark.postmark.router,
|
||||
|
||||
426
autogpt_platform/backend/backend/server/v2/chat/chat.py
Normal file
426
autogpt_platform/backend/backend/server/v2/chat/chat.py
Normal file
@@ -0,0 +1,426 @@
|
||||
"""Chat streaming functions for handling OpenAI chat completions with tool calling."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import (
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
)
|
||||
from prisma.enums import ChatMessageRole
|
||||
|
||||
from backend.server.v2.chat import db
|
||||
from backend.server.v2.chat.config import get_config
|
||||
from backend.server.v2.chat.models import (
|
||||
Error,
|
||||
LoginNeeded,
|
||||
StreamEnd,
|
||||
TextChunk,
|
||||
ToolCall,
|
||||
ToolResponse,
|
||||
)
|
||||
from backend.server.v2.chat.tool_exports import execute_tool, tools
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Global client cache
|
||||
_client_cache: AsyncOpenAI | None = None
|
||||
|
||||
|
||||
def get_openai_client(force_new: bool = False) -> AsyncOpenAI:
|
||||
"""Get or create an OpenAI client instance.
|
||||
|
||||
Args:
|
||||
force_new: Force creation of a new client instance
|
||||
|
||||
Returns:
|
||||
AsyncOpenAI client instance
|
||||
|
||||
"""
|
||||
global _client_cache
|
||||
config = get_config()
|
||||
|
||||
if not force_new and config.cache_client and _client_cache is not None:
|
||||
return _client_cache
|
||||
|
||||
# Create new client with configuration
|
||||
client_kwargs = {}
|
||||
if config.api_key:
|
||||
client_kwargs["api_key"] = config.api_key
|
||||
if config.base_url:
|
||||
client_kwargs["base_url"] = config.base_url
|
||||
|
||||
client = AsyncOpenAI(**client_kwargs)
|
||||
|
||||
# Cache if configured
|
||||
if config.cache_client:
|
||||
_client_cache = client
|
||||
|
||||
return client
|
||||
|
||||
|
||||
async def stream_chat_response(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
user_id: str,
|
||||
model: str | None = None,
|
||||
on_assistant_message: (
|
||||
Callable[[str, list[dict[str, Any]] | None], Awaitable[None]] | None
|
||||
) = None,
|
||||
on_tool_response: Callable[[str, str, str], Awaitable[None]] | None = None,
|
||||
session_id: str | None = None, # Optional for login needed responses
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Pure streaming function for OpenAI chat completions with tool calling.
|
||||
|
||||
This function is database-agnostic and focuses only on streaming logic.
|
||||
|
||||
Args:
|
||||
messages: Conversation context as ChatCompletionMessageParam list
|
||||
user_id: User ID for tool execution
|
||||
model: OpenAI model to use (overrides config)
|
||||
on_assistant_message: Callback for assistant messages (content, tool_calls)
|
||||
on_tool_response: Callback for tool responses (tool_call_id, content, role)
|
||||
session_id: Optional session ID for login responses
|
||||
|
||||
Yields:
|
||||
SSE formatted JSON response objects
|
||||
|
||||
"""
|
||||
config = get_config()
|
||||
model = model or config.model
|
||||
|
||||
try:
|
||||
logger.info("Starting pure chat stream")
|
||||
|
||||
# Get OpenAI client
|
||||
client = get_openai_client()
|
||||
|
||||
# Loop to handle tool calls and continue conversation
|
||||
while True:
|
||||
try:
|
||||
logger.info("Creating OpenAI chat completion stream...")
|
||||
|
||||
# Create the stream with proper types
|
||||
stream = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="auto",
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Variables to accumulate the response
|
||||
assistant_message: str = ""
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Process the stream
|
||||
chunk: ChatCompletionChunk
|
||||
async for chunk in stream:
|
||||
if chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta
|
||||
|
||||
# Capture finish reason
|
||||
if choice.finish_reason:
|
||||
finish_reason = choice.finish_reason
|
||||
logger.info(f"Finish reason: {finish_reason}")
|
||||
|
||||
# Handle content streaming
|
||||
if delta.content:
|
||||
assistant_message += delta.content
|
||||
# Stream the text chunk
|
||||
text_response = TextChunk(
|
||||
content=delta.content,
|
||||
timestamp=datetime.utcnow().isoformat(),
|
||||
)
|
||||
yield text_response.to_sse()
|
||||
|
||||
# Handle tool calls
|
||||
if delta.tool_calls:
|
||||
for tc_chunk in delta.tool_calls:
|
||||
idx = tc_chunk.index
|
||||
|
||||
# Ensure we have a tool call object at this index
|
||||
while len(tool_calls) <= idx:
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": "",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Accumulate the tool call data
|
||||
if tc_chunk.id:
|
||||
tool_calls[idx]["id"] = tc_chunk.id
|
||||
if tc_chunk.function:
|
||||
if tc_chunk.function.name:
|
||||
tool_calls[idx]["function"][
|
||||
"name"
|
||||
] = tc_chunk.function.name
|
||||
if tc_chunk.function.arguments:
|
||||
tool_calls[idx]["function"][
|
||||
"arguments"
|
||||
] += tc_chunk.function.arguments
|
||||
|
||||
logger.info(f"Stream complete. Finish reason: {finish_reason}")
|
||||
|
||||
# Notify about assistant message if callback provided
|
||||
if on_assistant_message and (assistant_message or tool_calls):
|
||||
await on_assistant_message(
|
||||
assistant_message if assistant_message else "",
|
||||
tool_calls if tool_calls else None,
|
||||
)
|
||||
|
||||
# Check if we need to execute tools
|
||||
if finish_reason == "tool_calls" and tool_calls:
|
||||
logger.info(f"Processing {len(tool_calls)} tool call(s)")
|
||||
|
||||
# Add assistant message with tool calls to context
|
||||
assistant_msg = {
|
||||
"role": "assistant",
|
||||
"content": assistant_message if assistant_message else None,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
messages.append(assistant_msg) # type: ignore
|
||||
|
||||
# Process each tool call
|
||||
for tool_call in tool_calls:
|
||||
tool_name: str = tool_call.get("function", {}).get(
|
||||
"name",
|
||||
"",
|
||||
)
|
||||
tool_id: str = tool_call.get("id", "")
|
||||
|
||||
# Parse arguments
|
||||
try:
|
||||
tool_args: dict[str, Any] = json.loads(
|
||||
tool_call.get("function", {}).get("arguments", "{}"),
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
tool_args = {}
|
||||
|
||||
logger.info(
|
||||
f"Executing tool: {tool_name} with args: {tool_args}",
|
||||
)
|
||||
|
||||
# Stream tool call notification
|
||||
tool_call_response = ToolCall(
|
||||
tool_id=tool_id,
|
||||
tool_name=tool_name,
|
||||
arguments=tool_args,
|
||||
timestamp=datetime.utcnow().isoformat(),
|
||||
)
|
||||
yield tool_call_response.to_sse()
|
||||
|
||||
# Small delay for UI responsiveness
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Execute the tool (returns JSON string)
|
||||
tool_result_str = await execute_tool(
|
||||
tool_name,
|
||||
tool_args,
|
||||
user_id=user_id,
|
||||
session_id=session_id or "",
|
||||
)
|
||||
|
||||
# Parse the JSON result
|
||||
try:
|
||||
tool_result = json.loads(tool_result_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# If not JSON, use as string
|
||||
tool_result = tool_result_str
|
||||
|
||||
# Check for special responses (login needed, etc.)
|
||||
if isinstance(tool_result, dict):
|
||||
result_type = tool_result.get("type")
|
||||
if result_type == "need_login":
|
||||
login_response = LoginNeeded(
|
||||
message=tool_result.get(
|
||||
"message", "Authentication required"
|
||||
),
|
||||
session_id=session_id or "",
|
||||
agent_info=tool_result.get("agent_info"),
|
||||
timestamp=datetime.utcnow().isoformat(),
|
||||
)
|
||||
yield login_response.to_sse()
|
||||
else:
|
||||
# Stream tool response
|
||||
tool_response = ToolResponse(
|
||||
tool_id=tool_id,
|
||||
tool_name=tool_name,
|
||||
result=tool_result,
|
||||
success=True,
|
||||
timestamp=datetime.utcnow().isoformat(),
|
||||
)
|
||||
yield tool_response.to_sse()
|
||||
else:
|
||||
# Stream tool response
|
||||
tool_response = ToolResponse(
|
||||
tool_id=tool_id,
|
||||
tool_name=tool_name,
|
||||
result=tool_result_str, # Use original string
|
||||
success=True,
|
||||
timestamp=datetime.utcnow().isoformat(),
|
||||
)
|
||||
yield tool_response.to_sse()
|
||||
|
||||
logger.info(
|
||||
f"Tool result: {tool_result_str[:200] if len(tool_result_str) > 200 else tool_result_str}"
|
||||
)
|
||||
|
||||
# Notify about tool response if callback provided
|
||||
if on_tool_response:
|
||||
await on_tool_response(
|
||||
tool_id,
|
||||
tool_result_str, # Already a string
|
||||
"tool",
|
||||
)
|
||||
|
||||
# Add tool result to context
|
||||
tool_msg: ChatCompletionToolMessageParam = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_id,
|
||||
"content": tool_result_str, # Already JSON string
|
||||
}
|
||||
messages.append(tool_msg)
|
||||
|
||||
# Continue the loop to get final response
|
||||
logger.info("Making follow-up call with tool results...")
|
||||
continue
|
||||
else:
|
||||
# No tool calls, conversation complete
|
||||
logger.info("Conversation complete")
|
||||
|
||||
# Send stream end marker
|
||||
end_response = StreamEnd(
|
||||
timestamp=datetime.utcnow().isoformat(),
|
||||
summary={
|
||||
"message_count": len(messages),
|
||||
"had_tool_calls": len(tool_calls) > 0,
|
||||
},
|
||||
)
|
||||
yield end_response.to_sse()
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream: {e!s}", exc_info=True)
|
||||
error_response = Error(
|
||||
message=str(e),
|
||||
timestamp=datetime.utcnow().isoformat(),
|
||||
)
|
||||
yield error_response.to_sse()
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in stream_chat_response: {e!s}", exc_info=True)
|
||||
error_response = Error(
|
||||
message=str(e),
|
||||
timestamp=datetime.utcnow().isoformat(),
|
||||
)
|
||||
yield error_response.to_sse()
|
||||
|
||||
|
||||
# Wrapper function that handles database operations
|
||||
async def stream_chat_completion(
|
||||
session_id: str,
|
||||
user_message: str,
|
||||
user_id: str,
|
||||
model: str = "gpt-4o",
|
||||
max_messages: int = 50,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Main entry point for streaming chat completions with database handling.
|
||||
|
||||
This function handles all database operations and delegates streaming
|
||||
to the pure stream_chat_response function.
|
||||
|
||||
Args:
|
||||
session_id: Chat session ID
|
||||
user_message: User's input message
|
||||
user_id: User ID for authentication
|
||||
model: OpenAI model to use
|
||||
max_messages: Maximum context messages to include
|
||||
|
||||
Yields:
|
||||
SSE formatted JSON strings with response data
|
||||
|
||||
"""
|
||||
config = get_config()
|
||||
logger.warn(
|
||||
f"Streaming chat completion for session {session_id} with user {user_id} and message {user_message}"
|
||||
)
|
||||
# Store user message in database
|
||||
await db.create_chat_message(
|
||||
session_id=session_id,
|
||||
content=user_message,
|
||||
role=ChatMessageRole.USER,
|
||||
)
|
||||
|
||||
# Get conversation context (already typed as List[ChatCompletionMessageParam])
|
||||
context = await db.get_conversation_context(
|
||||
session_id=session_id,
|
||||
max_messages=max_messages,
|
||||
include_system=True,
|
||||
)
|
||||
|
||||
# Add system prompt if this is the first message
|
||||
if not any(msg.get("role") == "system" for msg in context):
|
||||
system_prompt = config.get_system_prompt()
|
||||
system_message: ChatCompletionMessageParam = {
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
}
|
||||
context.insert(0, system_message)
|
||||
|
||||
# Add current user message to context
|
||||
user_msg: ChatCompletionMessageParam = {
|
||||
"role": "user",
|
||||
"content": user_message,
|
||||
}
|
||||
context.append(user_msg)
|
||||
|
||||
# Define database callbacks
|
||||
async def save_assistant_message(
|
||||
content: str, tool_calls: list[dict[str, Any]] | None
|
||||
) -> None:
|
||||
"""Save assistant message to database."""
|
||||
await db.create_chat_message(
|
||||
session_id=session_id,
|
||||
content=content,
|
||||
role=ChatMessageRole.ASSISTANT,
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
async def save_tool_response(tool_call_id: str, content: str, role: str) -> None:
|
||||
"""Save tool response to database."""
|
||||
await db.create_chat_message(
|
||||
session_id=session_id,
|
||||
content=content,
|
||||
role=ChatMessageRole.TOOL,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
|
||||
# Stream the response using the pure function
|
||||
async for chunk in stream_chat_response(
|
||||
messages=context,
|
||||
user_id=user_id,
|
||||
model=model,
|
||||
on_assistant_message=save_assistant_message,
|
||||
on_tool_response=save_tool_response,
|
||||
session_id=session_id,
|
||||
):
|
||||
yield chunk
|
||||
207
autogpt_platform/backend/backend/server/v2/chat/config.py
Normal file
207
autogpt_platform/backend/backend/server/v2/chat/config.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Configuration management for chat system."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
|
||||
# OpenAI API Configuration
|
||||
model: str = Field(
|
||||
default="qwen/qwen3-235b-a22b-2507", description="Default model to use"
|
||||
)
|
||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||
base_url: str | None = Field(
|
||||
default="https://openrouter.ai/api/v1",
|
||||
description="Base URL for API (e.g., for OpenRouter)",
|
||||
)
|
||||
|
||||
# System Prompt Configuration
|
||||
system_prompt_path: str = Field(
|
||||
default="prompts/chat_system.md",
|
||||
description="Path to system prompt file relative to chat module",
|
||||
)
|
||||
|
||||
# Streaming Configuration
|
||||
max_context_messages: int = Field(
|
||||
default=50, ge=1, le=200, description="Maximum context messages"
|
||||
)
|
||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||
|
||||
# Client Configuration
|
||||
cache_client: bool = Field(
|
||||
default=True, description="Whether to cache the OpenAI client"
|
||||
)
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
def get_api_key(cls, v):
|
||||
"""Get API key from environment if not provided."""
|
||||
if v is None:
|
||||
# Try to get from environment variables
|
||||
# First check for CHAT_API_KEY (Pydantic prefix)
|
||||
v = os.getenv("CHAT_API_KEY")
|
||||
if not v:
|
||||
# Fall back to OPEN_ROUTER_API_KEY
|
||||
v = os.getenv("OPEN_ROUTER_API_KEY")
|
||||
if not v:
|
||||
# Fall back to OPENAI_API_KEY
|
||||
v = os.getenv("OPENAI_API_KEY")
|
||||
return v
|
||||
|
||||
@field_validator("base_url", mode="before")
|
||||
@classmethod
|
||||
def get_base_url(cls, v):
|
||||
"""Get base URL from environment if not provided."""
|
||||
if v is None:
|
||||
# Check for OpenRouter or custom base URL
|
||||
v = os.getenv("OPENAI_BASE_URL") or os.getenv("OPENROUTER_BASE_URL")
|
||||
if os.getenv("USE_OPENROUTER") == "true" and not v:
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
def get_system_prompt(self, **template_vars) -> str:
|
||||
"""Load and render the system prompt from file.
|
||||
|
||||
Args:
|
||||
**template_vars: Variables to substitute in the template
|
||||
|
||||
Returns:
|
||||
Rendered system prompt string
|
||||
|
||||
"""
|
||||
# Get the path relative to this module
|
||||
module_dir = Path(__file__).parent
|
||||
prompt_path = module_dir / self.system_prompt_path
|
||||
|
||||
# Check for .j2 extension first (Jinja2 template)
|
||||
j2_path = Path(str(prompt_path) + ".j2")
|
||||
if j2_path.exists():
|
||||
try:
|
||||
from jinja2 import Template
|
||||
|
||||
template = Template(j2_path.read_text())
|
||||
return template.render(**template_vars)
|
||||
except ImportError:
|
||||
# Jinja2 not installed, fall back to reading as plain text
|
||||
return j2_path.read_text()
|
||||
|
||||
# Check for markdown file
|
||||
if prompt_path.exists():
|
||||
content = prompt_path.read_text()
|
||||
|
||||
# Simple variable substitution if Jinja2 is not available
|
||||
for key, value in template_vars.items():
|
||||
placeholder = f"{{{key}}}"
|
||||
content = content.replace(placeholder, str(value))
|
||||
|
||||
return content
|
||||
|
||||
# Fallback to default system prompt if file not found
|
||||
return self._get_default_system_prompt()
|
||||
|
||||
def _get_default_system_prompt(self) -> str:
|
||||
"""Get the default system prompt if file is not found."""
|
||||
return """# AutoGPT Agent Setup Assistant
|
||||
|
||||
You help users find and set up AutoGPT agents to solve their business problems. **Bias toward action** - move quickly to get agents running.
|
||||
|
||||
## THE FLOW (Always Follow This Order)
|
||||
|
||||
1. **find_agent** → Search for agents that solve their problem
|
||||
2. **get_agent_details** → Get comprehensive info about chosen agent
|
||||
3. **get_required_setup_info** → Verify user has required credentials (MANDATORY before next step)
|
||||
4. **setup_agent** or **run_agent** → Execute the agent
|
||||
|
||||
## YOUR APPROACH
|
||||
|
||||
### STEP 1: UNDERSTAND THE PROBLEM (Quick)
|
||||
- One or two targeted questions max
|
||||
- What business problem are they trying to solve?
|
||||
- Move quickly to searching for solutions
|
||||
|
||||
### STEP 2: FIND AGENTS
|
||||
- Use `find_agent` immediately with relevant keywords
|
||||
- Suggest the best option based on what you know
|
||||
- Explain briefly how it solves their problem
|
||||
- Ask them if they would like to use it, if they do move to step 3
|
||||
|
||||
### STEP 3: GET DETAILS
|
||||
- Use `get_agent_details` on their chosen agent
|
||||
- Explain what the agent does and its requirements
|
||||
- Keep explanations brief and outcome-focused
|
||||
|
||||
### STEP 4: VERIFY SETUP (CRITICAL)
|
||||
- **ALWAYS** use `get_required_setup_info` before proceeding
|
||||
- This checks if user has all required credentials
|
||||
- Tell user what credentials they need (if any)
|
||||
- Explain credentials are added via the frontend interface
|
||||
|
||||
### STEP 5: EXECUTE
|
||||
- Once credentials verified, use `setup_agent` for scheduled runs OR `run_agent` for immediate execution
|
||||
- Confirm successful setup/run
|
||||
- Provide clear next steps
|
||||
|
||||
## KEY RULES
|
||||
|
||||
### What You DON'T Do:
|
||||
- Don't help with login (frontend handles this)
|
||||
- Don't help add credentials (frontend handles this)
|
||||
- Don't skip `get_required_setup_info` (it's mandatory)
|
||||
- Don't over-explain technical details
|
||||
- Don't use ** to highlight text
|
||||
|
||||
### What You DO:
|
||||
- Act fast - get to agent discovery quickly
|
||||
- Use tools proactively without asking permission
|
||||
- Keep explanations short and business-focused
|
||||
- Always verify credentials before setup/run
|
||||
- Focus on outcomes and value
|
||||
|
||||
### Error Handling:
|
||||
- If authentication needed → Tell user to sign in via the interface
|
||||
- If credentials missing → Tell user what's needed and where to add them in the frontend
|
||||
- If setup fails → Identify issue, provide clear fix
|
||||
|
||||
## SUCCESS LOOKS LIKE:
|
||||
- User has an agent running within minutes
|
||||
- User understands what their agent does
|
||||
- User knows how to use their agent going forward
|
||||
- Minimal back-and-forth, maximum action
|
||||
|
||||
**Remember: Speed to value. Find agent → Get details → Verify credentials → Run. Keep it simple, keep it moving.**"""
|
||||
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore" # Ignore extra environment variables
|
||||
|
||||
|
||||
# Global configuration instance
|
||||
_config: ChatConfig | None = None
|
||||
|
||||
|
||||
def get_config() -> ChatConfig:
|
||||
"""Get or create the chat configuration."""
|
||||
global _config
|
||||
if _config is None:
|
||||
_config = ChatConfig()
|
||||
return _config
|
||||
|
||||
|
||||
def set_config(config: ChatConfig) -> None:
|
||||
"""Set the chat configuration."""
|
||||
global _config
|
||||
_config = config
|
||||
|
||||
|
||||
def reset_config() -> None:
|
||||
"""Reset the configuration to defaults."""
|
||||
global _config
|
||||
_config = None
|
||||
311
autogpt_platform/backend/backend/server/v2/chat/db.py
Normal file
311
autogpt_platform/backend/backend/server/v2/chat/db.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""Database operations for chat functionality."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import prisma.errors
|
||||
import prisma.models
|
||||
import prisma.types
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from prisma import Json
|
||||
from prisma.enums import ChatMessageRole
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ========== ChatSession Functions ==========
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
user_id: str,
|
||||
) -> prisma.models.ChatSession:
|
||||
"""Create a new chat session for a user.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user creating the session
|
||||
|
||||
Returns:
|
||||
The created ChatSession object
|
||||
|
||||
"""
|
||||
# For anonymous users, create a temporary user record
|
||||
if user_id.startswith("anon_"):
|
||||
# Check if anonymous user already exists
|
||||
existing_user = await prisma.models.User.prisma().find_unique(
|
||||
where={"id": user_id},
|
||||
)
|
||||
|
||||
if not existing_user:
|
||||
# Create anonymous user with minimal data
|
||||
await prisma.models.User.prisma().create(
|
||||
data={
|
||||
"id": user_id,
|
||||
"email": f"{user_id}@anonymous.local",
|
||||
"name": "Anonymous User",
|
||||
},
|
||||
)
|
||||
logger.info(f"Created anonymous user: {user_id}")
|
||||
|
||||
return await prisma.models.ChatSession.prisma().create(
|
||||
data={
|
||||
"userId": user_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def get_chat_session(
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
include_messages: bool = False,
|
||||
) -> prisma.models.ChatSession:
|
||||
"""Get a chat session by ID.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session
|
||||
user_id: Optional user ID to verify ownership
|
||||
include_messages: Whether to include messages in the response
|
||||
|
||||
Returns:
|
||||
The ChatSession object
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the session doesn't exist or user doesn't have access
|
||||
|
||||
"""
|
||||
where_clause: dict[str, Any] = {"id": session_id}
|
||||
if user_id:
|
||||
where_clause["userId"] = user_id
|
||||
|
||||
session = await prisma.models.ChatSession.prisma().find_first(
|
||||
where=prisma.types.ChatSessionWhereInput(**where_clause), # type: ignore
|
||||
include={"messages": include_messages} if include_messages else None,
|
||||
)
|
||||
|
||||
if not session:
|
||||
msg = f"Chat session {session_id} not found"
|
||||
raise NotFoundError(msg)
|
||||
|
||||
return session
|
||||
|
||||
|
||||
async def list_chat_sessions(
|
||||
user_id: str,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
include_last_message: bool = False,
|
||||
) -> list[prisma.models.ChatSession]:
|
||||
"""List chat sessions for a user.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user
|
||||
limit: Maximum number of sessions to return
|
||||
offset: Number of sessions to skip
|
||||
include_last_message: Whether to include the last message for each session
|
||||
|
||||
Returns:
|
||||
List of ChatSession objects
|
||||
|
||||
"""
|
||||
where_clause: dict[str, Any] = {"userId": user_id}
|
||||
|
||||
include_clause = None
|
||||
if include_last_message:
|
||||
include_clause = {"messages": {"take": 1, "order": [{"sequence": "desc"}]}}
|
||||
|
||||
return await prisma.models.ChatSession.prisma().find_many(
|
||||
where=prisma.types.ChatSessionWhereInput(**where_clause), # type: ignore
|
||||
include=include_clause, # type: ignore
|
||||
order=[{"updatedAt": "desc"}],
|
||||
skip=offset,
|
||||
take=limit,
|
||||
)
|
||||
|
||||
|
||||
# ========== ChatMessage Functions ==========
|
||||
|
||||
|
||||
async def create_chat_message(
|
||||
session_id: str,
|
||||
content: str,
|
||||
role: ChatMessageRole,
|
||||
sequence: int | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
parent_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
prompt_tokens: int | None = None,
|
||||
completion_tokens: int | None = None,
|
||||
error: str | None = None,
|
||||
) -> prisma.models.ChatMessage:
|
||||
"""Create a new chat message.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the chat session
|
||||
content: The message content
|
||||
role: The role of the message sender
|
||||
sequence: Optional sequence number (auto-incremented if not provided)
|
||||
tool_call_id: For tool responses
|
||||
tool_calls: List of tool calls made by assistant
|
||||
parent_id: Parent message ID for threading
|
||||
metadata: Additional metadata
|
||||
prompt_tokens: Number of prompt tokens used
|
||||
completion_tokens: Number of completion tokens used
|
||||
error: Error message if any
|
||||
|
||||
Returns:
|
||||
The created ChatMessage object
|
||||
|
||||
"""
|
||||
# Auto-increment sequence if not provided
|
||||
if sequence is None:
|
||||
last_message = await prisma.models.ChatMessage.prisma().find_first(
|
||||
where={"sessionId": session_id},
|
||||
order=[{"sequence": "desc"}],
|
||||
)
|
||||
sequence = (last_message.sequence + 1) if last_message else 0
|
||||
|
||||
total_tokens = None
|
||||
if prompt_tokens is not None and completion_tokens is not None:
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
# Build the data dict dynamically to avoid setting None values
|
||||
data: dict[str, Any] = {
|
||||
"sessionId": session_id,
|
||||
"content": content,
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
# Only add optional fields if they have values
|
||||
if tool_call_id:
|
||||
data["toolCallId"] = tool_call_id
|
||||
if tool_calls:
|
||||
data["toolCalls"] = Json(tool_calls) # type: ignore
|
||||
if parent_id:
|
||||
data["parentId"] = parent_id
|
||||
if metadata:
|
||||
data["metadata"] = Json(metadata)
|
||||
if prompt_tokens is not None:
|
||||
data["promptTokens"] = prompt_tokens
|
||||
if completion_tokens is not None:
|
||||
data["completionTokens"] = completion_tokens
|
||||
if total_tokens is not None:
|
||||
data["totalTokens"] = total_tokens
|
||||
if error:
|
||||
data["error"] = error
|
||||
|
||||
message = await prisma.models.ChatMessage.prisma().create(data=prisma.types.ChatMessageCreateInput(**data)) # type: ignore
|
||||
|
||||
# Update session's updatedAt timestamp
|
||||
await prisma.models.ChatSession.prisma().update(where={"id": session_id}, data={})
|
||||
|
||||
return message
|
||||
|
||||
|
||||
async def get_chat_messages(
|
||||
session_id: str,
|
||||
limit: int | None = None,
|
||||
offset: int = 0,
|
||||
parent_id: str | None = None,
|
||||
include_children: bool = False,
|
||||
) -> list[prisma.models.ChatMessage]:
|
||||
"""Get messages for a chat session.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the chat session
|
||||
limit: Maximum number of messages to return
|
||||
offset: Number of messages to skip
|
||||
parent_id: Filter by parent message (for threaded conversations)
|
||||
include_children: Whether to include child messages
|
||||
|
||||
Returns:
|
||||
List of ChatMessage objects ordered by sequence
|
||||
|
||||
"""
|
||||
where_clause: dict[str, Any] = {"sessionId": session_id}
|
||||
|
||||
if parent_id is not None:
|
||||
where_clause["parentId"] = parent_id
|
||||
|
||||
include_clause = {"children": True} if include_children else None
|
||||
|
||||
return await prisma.models.ChatMessage.prisma().find_many(
|
||||
where=prisma.types.ChatMessageWhereInput(**where_clause), # type: ignore
|
||||
include=include_clause, # type: ignore
|
||||
order=[{"sequence": "asc"}],
|
||||
skip=offset,
|
||||
take=limit,
|
||||
)
|
||||
|
||||
|
||||
# ========== Helper Functions ==========
|
||||
|
||||
|
||||
async def get_conversation_context(
|
||||
session_id: str,
|
||||
max_messages: int = 50,
|
||||
include_system: bool = True,
|
||||
) -> list[ChatCompletionMessageParam]:
|
||||
"""Get the conversation context formatted for OpenAI API.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the chat session
|
||||
max_messages: Maximum number of messages to include
|
||||
include_system: Whether to include system messages
|
||||
|
||||
Returns:
|
||||
List of ChatCompletionMessageParam for OpenAI API
|
||||
|
||||
"""
|
||||
messages = await get_chat_messages(session_id, limit=max_messages)
|
||||
|
||||
context: list[ChatCompletionMessageParam] = []
|
||||
for msg in messages:
|
||||
if not include_system and msg.role == ChatMessageRole.SYSTEM:
|
||||
continue
|
||||
|
||||
# Handle role - it might be a string or an enum
|
||||
role_value = msg.role.value if hasattr(msg.role, "value") else msg.role
|
||||
role = role_value.lower()
|
||||
|
||||
message: dict[str, Any]
|
||||
|
||||
# Build the message based on role
|
||||
if role == "assistant" and msg.toolCalls:
|
||||
# Assistant message with tool calls
|
||||
message = {
|
||||
"role": "assistant",
|
||||
"content": msg.content if msg.content else None,
|
||||
"tool_calls": msg.toolCalls,
|
||||
}
|
||||
elif role == "tool":
|
||||
# Tool response message
|
||||
message = {
|
||||
"role": "tool",
|
||||
"content": msg.content,
|
||||
"tool_call_id": msg.toolCallId or "",
|
||||
}
|
||||
elif role == "system":
|
||||
# System message
|
||||
message = {
|
||||
"role": "system",
|
||||
"content": msg.content,
|
||||
}
|
||||
elif role == "user":
|
||||
# User message
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": msg.content,
|
||||
}
|
||||
else:
|
||||
# Default assistant message
|
||||
message = {
|
||||
"role": "assistant",
|
||||
"content": msg.content,
|
||||
}
|
||||
|
||||
context.append(message) # type: ignore
|
||||
|
||||
return context
|
||||
122
autogpt_platform/backend/backend/server/v2/chat/models.py
Normal file
122
autogpt_platform/backend/backend/server/v2/chat/models.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Response models for chat streaming."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of streaming responses."""
|
||||
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TOOL_CALL = "tool_call"
|
||||
TOOL_RESPONSE = "tool_response"
|
||||
LOGIN_NEEDED = "login_needed"
|
||||
ERROR = "error"
|
||||
STREAM_END = "stream_end"
|
||||
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
"""Base response model for all streaming responses."""
|
||||
|
||||
type: ResponseType
|
||||
timestamp: str | None = None
|
||||
|
||||
|
||||
class TextChunk(BaseResponse):
|
||||
"""Streaming text content from the assistant."""
|
||||
|
||||
type: ResponseType = ResponseType.TEXT_CHUNK
|
||||
content: str = Field(..., description="Text content chunk")
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
class ToolCall(BaseResponse):
|
||||
"""Tool invocation notification."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_CALL
|
||||
tool_id: str = Field(..., description="Unique tool call ID")
|
||||
tool_name: str = Field(..., description="Name of the tool being called")
|
||||
arguments: dict[str, Any] = Field(
|
||||
default_factory=dict, description="Tool arguments"
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
class ToolResponse(BaseResponse):
|
||||
"""Tool execution result."""
|
||||
|
||||
type: ResponseType = ResponseType.TOOL_RESPONSE
|
||||
tool_id: str = Field(..., description="Tool call ID this responds to")
|
||||
tool_name: str = Field(..., description="Name of the tool that was executed")
|
||||
result: str | dict[str, Any] = Field(..., description="Tool execution result")
|
||||
success: bool = Field(
|
||||
default=True, description="Whether the tool execution succeeded"
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
class LoginNeeded(BaseResponse):
|
||||
"""Authentication required notification."""
|
||||
|
||||
type: ResponseType = ResponseType.LOGIN_NEEDED
|
||||
message: str = Field(..., description="Message explaining why login is needed")
|
||||
session_id: str = Field(..., description="Current session ID to preserve")
|
||||
agent_info: dict[str, Any] | None = Field(
|
||||
default=None, description="Agent context if applicable"
|
||||
)
|
||||
required_action: str = Field(
|
||||
default="login", description="Required action (login/signup)"
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
class Error(BaseResponse):
|
||||
"""Error response."""
|
||||
|
||||
type: ResponseType = ResponseType.ERROR
|
||||
message: str = Field(..., description="Error message")
|
||||
code: str | None = Field(default=None, description="Error code")
|
||||
details: dict[str, Any] | None = Field(
|
||||
default=None, description="Additional error details"
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
class StreamEnd(BaseResponse):
|
||||
"""End of stream marker."""
|
||||
|
||||
type: ResponseType = ResponseType.STREAM_END
|
||||
summary: dict[str, Any] | None = Field(
|
||||
default=None, description="Stream summary statistics"
|
||||
)
|
||||
|
||||
def to_sse(self) -> str:
|
||||
"""Convert to SSE format."""
|
||||
return f"data: {self.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
# Additional model for agent carousel data
|
||||
class AgentCarouselData(BaseModel):
|
||||
"""Data structure for agent carousel display."""
|
||||
|
||||
type: str = "agent_carousel"
|
||||
query: str
|
||||
count: int
|
||||
agents: list[dict[str, Any]]
|
||||
@@ -0,0 +1,67 @@
|
||||
# AutoGPT Agent Setup Assistant
|
||||
|
||||
You help users find and set up AutoGPT agents to solve their business problems. **Bias toward action** - move quickly to get agents running.
|
||||
|
||||
## THE FLOW (Always Follow This Order)
|
||||
|
||||
1. **find_agent** → Search for agents that solve their problem
|
||||
2. **get_agent_details** → Get comprehensive info about chosen agent
|
||||
3. **get_required_setup_info** → Verify user has required credentials (MANDATORY before next step)
|
||||
4. **setup_agent** or **run_agent** → Execute the agent
|
||||
|
||||
## YOUR APPROACH
|
||||
|
||||
### STEP 1: UNDERSTAND THE PROBLEM (Quick)
|
||||
- One or two targeted questions max
|
||||
- What business problem are they trying to solve?
|
||||
- Move quickly to searching for solutions
|
||||
|
||||
### STEP 2: FIND AGENTS
|
||||
- Use `find_agent` immediately with relevant keywords
|
||||
- Suggest the best option based on what you know
|
||||
- Explain briefly how it solves their problem
|
||||
- Ask them if they would like to use it, if they do move to step 3
|
||||
|
||||
### STEP 3: GET DETAILS
|
||||
- Use `get_agent_details` on their chosen agent
|
||||
- Explain what the agent does and its requirements
|
||||
- Keep explanations brief and outcome-focused
|
||||
|
||||
### STEP 4: VERIFY SETUP (CRITICAL)
|
||||
- **ALWAYS** use `get_required_setup_info` before proceeding
|
||||
- This checks if user has all required credentials
|
||||
- Tell user what credentials they need (if any)
|
||||
- Explain credentials are added via the frontend interface
|
||||
|
||||
### STEP 5: EXECUTE
|
||||
- Once credentials verified, use `setup_agent` for scheduled runs OR `run_agent` for immediate execution
|
||||
- Confirm successful setup/run
|
||||
- Provide clear next steps
|
||||
|
||||
## KEY RULES
|
||||
|
||||
### What You DON'T Do:
|
||||
- Don't help with login (frontend handles this)
|
||||
- Don't help add credentials (frontend handles this)
|
||||
- Don't skip `get_required_setup_info` (it's mandatory)
|
||||
- Don't over-explain technical details
|
||||
|
||||
### What You DO:
|
||||
- Act fast - get to agent discovery quickly
|
||||
- Use tools proactively without asking permission
|
||||
- Keep explanations short and business-focused
|
||||
- Always verify credentials before setup/run
|
||||
- Focus on outcomes and value
|
||||
|
||||
### Error Handling:
|
||||
- If authentication needed → Tell user to sign in via the interface
|
||||
- If credentials missing → Tell user what's needed and where to add them in the frontend
|
||||
- If setup fails → Identify issue, provide clear fix
|
||||
|
||||
## SUCCESS LOOKS LIKE:
|
||||
- User has an agent running within minutes
|
||||
- User understands what their agent does
|
||||
- User knows how to use their agent going forward
|
||||
- Minimal back-and-forth, maximum action
|
||||
|
||||
**Remember: Speed to value. Find agent → Get details → Verify credentials → Run. Keep it simple, keep it moving.**
|
||||
613
autogpt_platform/backend/backend/server/v2/chat/routes.py
Normal file
613
autogpt_platform/backend/backend/server/v2/chat/routes.py
Normal file
@@ -0,0 +1,613 @@
|
||||
"""Chat API routes for SSE streaming and session management."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
import prisma.models
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from prisma.enums import ChatMessageRole
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.status import HTTP_404_NOT_FOUND
|
||||
|
||||
from backend.server.v2.chat import chat, db
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Optional bearer token authentication
|
||||
optional_bearer = HTTPBearer(auto_error=False)
|
||||
|
||||
router = APIRouter(
|
||||
tags=["chat"],
|
||||
responses={
|
||||
404: {"description": "Resource not found"},
|
||||
401: {"description": "Unauthorized"},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_optional_user_id(
|
||||
credentials: HTTPAuthorizationCredentials | None = Security(optional_bearer),
|
||||
) -> str | None:
|
||||
"""Get user ID from auth token if present, otherwise None for anonymous."""
|
||||
if not credentials:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Parse JWT token to get user ID
|
||||
from autogpt_libs.auth.jwt_utils import parse_jwt_token
|
||||
|
||||
payload = parse_jwt_token(credentials.credentials)
|
||||
return payload.get("sub")
|
||||
except Exception as e:
|
||||
logger.debug(f"Auth token validation failed (anonymous access): {e}")
|
||||
return None
|
||||
|
||||
|
||||
# ========== Request/Response Models ==========
|
||||
|
||||
|
||||
class CreateSessionRequest(BaseModel):
|
||||
"""Request model for creating a new chat session."""
|
||||
|
||||
metadata: dict | None = Field(
|
||||
default_factory=dict,
|
||||
description="Optional metadata",
|
||||
)
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
"""Response model for created chat session."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
user_id: str
|
||||
|
||||
|
||||
class SendMessageRequest(BaseModel):
|
||||
"""Request model for sending a chat message."""
|
||||
|
||||
message: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=10000,
|
||||
description="Message content",
|
||||
)
|
||||
model: str = Field(default="gpt-4o", description="AI model to use")
|
||||
max_context_messages: int = Field(
|
||||
default=50,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Max context messages",
|
||||
)
|
||||
|
||||
|
||||
class SendMessageResponse(BaseModel):
|
||||
"""Response model for non-streaming message."""
|
||||
|
||||
message_id: str
|
||||
content: str
|
||||
role: str
|
||||
tokens_used: dict | None = None
|
||||
|
||||
|
||||
class SessionListResponse(BaseModel):
|
||||
"""Response model for session list."""
|
||||
|
||||
sessions: list[dict]
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
"""Response model for session details."""
|
||||
|
||||
id: str
|
||||
created_at: str
|
||||
updated_at: str
|
||||
user_id: str
|
||||
messages: list[dict]
|
||||
metadata: dict
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions",
|
||||
)
|
||||
async def create_session(
|
||||
request: CreateSessionRequest,
|
||||
user_id: Annotated[str | None, Depends(get_optional_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""Create a new chat session for the authenticated or anonymous user.
|
||||
|
||||
Args:
|
||||
request: Session creation parameters
|
||||
user_id: Optional authenticated user ID
|
||||
|
||||
Returns:
|
||||
Created session details
|
||||
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Creating session with user_id: {user_id}")
|
||||
|
||||
# Create the session (anonymous if no user_id)
|
||||
# Use a special anonymous user ID if not authenticated
|
||||
import uuid
|
||||
|
||||
session_user_id = user_id if user_id else f"anon_{uuid.uuid4().hex[:12]}"
|
||||
logger.info(f"Using session_user_id: {session_user_id}")
|
||||
|
||||
session = await db.create_chat_session(user_id=session_user_id)
|
||||
|
||||
logger.info(f"Created chat session {session.id} for user {user_id}")
|
||||
|
||||
return CreateSessionResponse(
|
||||
id=session.id,
|
||||
created_at=session.createdAt.isoformat(),
|
||||
user_id=session.userId,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create session: {e!s}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create session: {e!s}")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def list_sessions(
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
include_last_message: Annotated[bool, Query()] = True,
|
||||
) -> SessionListResponse:
|
||||
"""List chat sessions for the authenticated user.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of sessions to return
|
||||
offset: Number of sessions to skip
|
||||
include_last_message: Whether to include the last message
|
||||
user_id: Authenticated user ID
|
||||
|
||||
Returns:
|
||||
List of user's chat sessions
|
||||
|
||||
"""
|
||||
try:
|
||||
sessions = await db.list_chat_sessions(
|
||||
user_id=user_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_last_message=include_last_message,
|
||||
)
|
||||
|
||||
# Format sessions for response
|
||||
session_list = []
|
||||
for session in sessions:
|
||||
session_dict: dict[str, str | dict[str, str]] = {
|
||||
"id": session.id,
|
||||
"created_at": session.createdAt.isoformat(),
|
||||
"updated_at": session.updatedAt.isoformat(),
|
||||
}
|
||||
|
||||
# Add last message if included
|
||||
if include_last_message and session.messages:
|
||||
last_msg = session.messages[0]
|
||||
session_dict["last_message"] = {
|
||||
"content": (
|
||||
last_msg.content[:100] if last_msg.content else ""
|
||||
), # Preview
|
||||
"role": (
|
||||
last_msg.role.value
|
||||
if hasattr(last_msg.role, "value")
|
||||
else str(last_msg.role)
|
||||
),
|
||||
"created_at": last_msg.createdAt.isoformat(),
|
||||
}
|
||||
|
||||
session_list.append(session_dict)
|
||||
|
||||
return SessionListResponse(
|
||||
sessions=session_list,
|
||||
total=len(session_list), # TODO: Add total count query
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to list sessions: {e!s}")
|
||||
raise HTTPException(status_code=500, detail="Failed to list sessions")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(get_optional_user_id)],
|
||||
include_messages: Annotated[bool, Query()] = True,
|
||||
) -> SessionDetailResponse:
|
||||
"""Get details of a specific chat session.
|
||||
|
||||
Args:
|
||||
session_id: ID of the session to retrieve
|
||||
include_messages: Whether to include all messages
|
||||
user_id: Authenticated user ID
|
||||
|
||||
Returns:
|
||||
Session details with optional messages
|
||||
|
||||
"""
|
||||
try:
|
||||
# For anonymous sessions, we don't check ownership
|
||||
if user_id:
|
||||
# Authenticated user - verify ownership
|
||||
session = await db.get_chat_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
include_messages=include_messages,
|
||||
)
|
||||
else:
|
||||
# Anonymous user - just get the session by ID
|
||||
from backend.data.db import prisma
|
||||
|
||||
if include_messages:
|
||||
session = await prisma.chatsession.find_unique(
|
||||
where={"id": session_id},
|
||||
include={"messages": True},
|
||||
)
|
||||
# Sort messages if they were included
|
||||
if session and session.messages:
|
||||
session.messages.sort(key=lambda m: m.sequence)
|
||||
else:
|
||||
session = await prisma.chatsession.find_unique(
|
||||
where={"id": session_id},
|
||||
)
|
||||
if not session:
|
||||
msg = f"Session {session_id} not found"
|
||||
raise NotFoundError(msg)
|
||||
|
||||
# Format messages if included
|
||||
messages = []
|
||||
if include_messages and session.messages:
|
||||
for msg in session.messages:
|
||||
messages.append(
|
||||
{
|
||||
"id": msg.id,
|
||||
"content": msg.content,
|
||||
"role": msg.role,
|
||||
"created_at": msg.createdAt.isoformat(),
|
||||
"tool_calls": msg.toolCalls,
|
||||
"tool_call_id": msg.toolCallId,
|
||||
"tokens": (
|
||||
{
|
||||
"prompt": msg.promptTokens,
|
||||
"completion": msg.completionTokens,
|
||||
"total": msg.totalTokens,
|
||||
}
|
||||
if msg.totalTokens
|
||||
else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.id,
|
||||
created_at=session.createdAt.isoformat(),
|
||||
updated_at=session.updatedAt.isoformat(),
|
||||
user_id=session.userId,
|
||||
messages=messages,
|
||||
metadata={}, # TODO: Add session metadata support
|
||||
)
|
||||
except NotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail=f"Session {session_id} not found",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to get session: {e!s}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get session")
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}", dependencies=[Security(auth.requires_user)])
|
||||
async def delete_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> dict:
|
||||
"""Delete a chat session and all its messages.
|
||||
|
||||
Args:
|
||||
session_id: ID of the session to delete
|
||||
user_id: Authenticated user ID
|
||||
|
||||
Returns:
|
||||
Deletion confirmation
|
||||
|
||||
"""
|
||||
try:
|
||||
# Verify ownership first
|
||||
await db.get_chat_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Delete the session (cascade deletes messages)
|
||||
await prisma.models.ChatSession.prisma().delete(where={"id": session_id})
|
||||
|
||||
logger.info(f"Deleted session {session_id} for user {user_id}")
|
||||
|
||||
return {"status": "success", "message": f"Session {session_id} deleted"}
|
||||
except NotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail=f"Session {session_id} not found",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to delete session: {e!s}")
|
||||
raise HTTPException(status_code=500, detail="Failed to delete session")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/messages",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
)
|
||||
async def send_message(
|
||||
session_id: str,
|
||||
request: SendMessageRequest,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> SendMessageResponse:
|
||||
"""Send a message to a chat session (non-streaming).
|
||||
|
||||
This endpoint processes the message and returns the complete response.
|
||||
For streaming responses, use the /stream endpoint.
|
||||
|
||||
Args:
|
||||
session_id: ID of the session
|
||||
request: Message parameters
|
||||
user_id: Authenticated user ID
|
||||
|
||||
Returns:
|
||||
Complete assistant response
|
||||
|
||||
"""
|
||||
try:
|
||||
# Verify session ownership
|
||||
await db.get_chat_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Store user message
|
||||
await db.create_chat_message(
|
||||
session_id=session_id,
|
||||
content=request.message,
|
||||
role=ChatMessageRole.USER,
|
||||
)
|
||||
|
||||
# Collect the complete response using the refactored function
|
||||
full_response = ""
|
||||
async for chunk in chat.stream_chat_completion(
|
||||
session_id=session_id,
|
||||
user_message=request.message,
|
||||
user_id=user_id,
|
||||
model=request.model,
|
||||
max_messages=request.max_context_messages,
|
||||
):
|
||||
# Parse SSE data
|
||||
if chunk.startswith("data: "):
|
||||
import json
|
||||
|
||||
try:
|
||||
data = json.loads(chunk[6:].strip())
|
||||
if data.get("type") == "text_chunk":
|
||||
full_response += data.get("content", "")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Get the last assistant message for token counts
|
||||
messages = await db.get_chat_messages(session_id=session_id, limit=1)
|
||||
last_msg = messages[0] if messages else None
|
||||
|
||||
tokens_used = None
|
||||
if last_msg and last_msg.totalTokens:
|
||||
tokens_used = {
|
||||
"prompt": last_msg.promptTokens,
|
||||
"completion": last_msg.completionTokens,
|
||||
"total": last_msg.totalTokens,
|
||||
}
|
||||
|
||||
return SendMessageResponse(
|
||||
message_id=last_msg.id if last_msg else "",
|
||||
content=full_response.strip(),
|
||||
role="ASSISTANT",
|
||||
tokens_used=tokens_used,
|
||||
)
|
||||
except NotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail=f"Session {session_id} not found",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to send message: {e!s}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to send message: {e!s}")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
async def stream_chat(
|
||||
session_id: str,
|
||||
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
||||
model: Annotated[str, Query()] = "gpt-4o",
|
||||
max_context: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
user_id: str | None = Depends(get_optional_user_id),
|
||||
):
|
||||
"""Stream chat responses using Server-Sent Events (SSE).
|
||||
|
||||
This endpoint streams the AI response in real-time, including:
|
||||
- Text chunks as they're generated
|
||||
- Tool call UI elements
|
||||
- Tool execution results
|
||||
|
||||
Args:
|
||||
session_id: ID of the session
|
||||
message: User's message
|
||||
model: AI model to use
|
||||
max_context: Maximum context messages
|
||||
user_id: Optional authenticated user ID
|
||||
|
||||
Returns:
|
||||
SSE stream of response chunks
|
||||
|
||||
"""
|
||||
try:
|
||||
|
||||
# Get session - allow anonymous access by session ID
|
||||
# For anonymous users, we just verify the session exists
|
||||
if user_id:
|
||||
session = await db.get_chat_session(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
else:
|
||||
# For anonymous, just verify session exists (no ownership check)
|
||||
from backend.data.db import prisma
|
||||
|
||||
session = await prisma.chatsession.find_unique(
|
||||
where={"id": session_id},
|
||||
)
|
||||
if not session:
|
||||
msg = f"Session {session_id} not found"
|
||||
raise NotFoundError(msg)
|
||||
|
||||
# Use the session's user_id for tool execution
|
||||
effective_user_id = user_id if user_id else session.userId
|
||||
|
||||
logger.info(f"Starting SSE stream for session {session_id}")
|
||||
|
||||
# Get the streaming generator using the refactored function
|
||||
stream_generator = chat.stream_chat_completion(
|
||||
session_id=session_id,
|
||||
user_message=message,
|
||||
user_id=effective_user_id,
|
||||
model=model,
|
||||
max_messages=max_context,
|
||||
)
|
||||
|
||||
# Return as SSE stream
|
||||
return StreamingResponse(
|
||||
stream_generator,
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
"Access-Control-Allow-Origin": "*", # TODO: Configure proper CORS
|
||||
},
|
||||
)
|
||||
except NotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail=f"Session {session_id} not found",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to stream chat: {e!s}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to stream: {e!s}")
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/sessions/{session_id}/assign-user", dependencies=[Security(auth.requires_user)]
|
||||
)
|
||||
async def assign_user_to_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> dict:
|
||||
"""Assign an authenticated user to an anonymous session.
|
||||
|
||||
This is called after a user logs in to claim their anonymous session.
|
||||
|
||||
Args:
|
||||
session_id: ID of the anonymous session
|
||||
user_id: Authenticated user ID
|
||||
|
||||
Returns:
|
||||
Success status
|
||||
|
||||
"""
|
||||
try:
|
||||
# Get the session (should be anonymous)
|
||||
from backend.data.db import prisma
|
||||
|
||||
session = await prisma.chatsession.find_unique(
|
||||
where={"id": session_id},
|
||||
)
|
||||
|
||||
if not session:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_404_NOT_FOUND,
|
||||
detail=f"Session {session_id} not found",
|
||||
)
|
||||
|
||||
# Check if session is anonymous (starts with anon_)
|
||||
if not session.userId.startswith("anon_"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Session already has an assigned user",
|
||||
)
|
||||
|
||||
# Update the session with the real user ID
|
||||
await prisma.chatsession.update(
|
||||
where={"id": session_id},
|
||||
data={"userId": user_id}, # type: ignore
|
||||
)
|
||||
|
||||
logger.info(f"Assigned user {user_id} to session {session_id}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Session {session_id} assigned to user",
|
||||
"user_id": user_id,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to assign user to session: {e!s}")
|
||||
raise HTTPException(status_code=500, detail="Failed to assign user")
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check() -> dict:
|
||||
"""Check if the chat service is healthy.
|
||||
|
||||
Returns:
|
||||
Health status
|
||||
|
||||
"""
|
||||
try:
|
||||
# Try to get the OpenAI client to verify connectivity
|
||||
from backend.server.v2.chat.config import get_config
|
||||
|
||||
config = get_config()
|
||||
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "chat",
|
||||
"version": "2.0",
|
||||
"model": config.model,
|
||||
"has_api_key": config.api_key is not None,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception(f"Health check failed: {e!s}")
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"error": str(e),
|
||||
}
|
||||
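The stream endpoint above emits "data: {...}" lines that the non-streaming send_message handler already parses for "text_chunk" events. For context, here is a minimal client-side sketch of the same parsing loop; the "/api/chat" mount prefix and local port are assumptions, not taken from this changeset:

# Hedged client sketch for GET /sessions/{session_id}/stream (SSE).
import asyncio
import json

import httpx


async def read_stream(base_url: str, session_id: str, message: str) -> str:
    full_response = ""
    params = {"message": message, "model": "gpt-4o", "max_context": 50}
    url = f"{base_url}/sessions/{session_id}/stream"
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url, params=params) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                try:
                    event = json.loads(line[6:].strip())
                except json.JSONDecodeError:
                    continue
                # Same event shape the send_message endpoint accumulates.
                if event.get("type") == "text_chunk":
                    full_response += event.get("content", "")
    return full_response


if __name__ == "__main__":
    # Base URL is illustrative; point it at wherever this router is mounted.
    print(asyncio.run(read_stream("http://localhost:8006/api/chat", "<session-id>", "Hi")))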
@@ -0,0 +1,91 @@
|
||||
"""Chat tools for OpenAI function calling - main exports."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.server.v2.chat.tools import (
|
||||
FindAgentTool,
|
||||
GetAgentDetailsTool,
|
||||
GetRequiredSetupInfoTool,
|
||||
RunAgentTool,
|
||||
SetupAgentTool,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize tool instances
|
||||
find_agent_tool = FindAgentTool()
|
||||
get_agent_details_tool = GetAgentDetailsTool()
|
||||
get_required_setup_info_tool = GetRequiredSetupInfoTool()
|
||||
setup_agent_tool = SetupAgentTool()
|
||||
run_agent_tool = RunAgentTool()
|
||||
|
||||
# Export tools as OpenAI format
|
||||
tools: list[ChatCompletionToolParam] = [
|
||||
find_agent_tool.as_openai_tool(),
|
||||
get_agent_details_tool.as_openai_tool(),
|
||||
get_required_setup_info_tool.as_openai_tool(),
|
||||
setup_agent_tool.as_openai_tool(),
|
||||
run_agent_tool.as_openai_tool(),
|
||||
]
|
||||
|
||||
|
||||
# Tool execution dispatcher
|
||||
async def execute_tool(
|
||||
tool_name: str,
|
||||
parameters: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
) -> str:
|
||||
"""Execute a tool by name with the given parameters.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to execute
|
||||
parameters: Tool parameters
|
||||
user_id: User ID (may be anonymous)
|
||||
session_id: Chat session ID
|
||||
|
||||
Returns:
|
||||
JSON string result from the tool
|
||||
|
||||
"""
|
||||
# Map tool names to instances
|
||||
tool_map = {
|
||||
"find_agent": find_agent_tool,
|
||||
"get_agent_details": get_agent_details_tool,
|
||||
"get_required_setup_info": get_required_setup_info_tool,
|
||||
"setup_agent": setup_agent_tool,
|
||||
"run_agent": run_agent_tool,
|
||||
}
|
||||
|
||||
tool = tool_map.get(tool_name)
|
||||
if not tool:
|
||||
return json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": f"Unknown tool: {tool_name}",
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
# Execute tool - returns Pydantic model
|
||||
result = await tool.execute(user_id, session_id, **parameters)
|
||||
|
||||
# Convert Pydantic model to JSON string
|
||||
if isinstance(result, BaseModel):
|
||||
return result.model_dump_json(indent=2)
|
||||
# Fallback for non-Pydantic responses
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {tool_name}: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": f"Tool execution failed: {e!s}",
|
||||
}
|
||||
)
|
||||
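For reference, a hedged sketch of how a caller might feed a single OpenAI tool call into execute_tool(); the tool_call shape follows the standard OpenAI chat-completions format, and the surrounding chat loop is assumed rather than copied from stream_chat_completion:

# Hypothetical dispatch of one OpenAI tool call through execute_tool().
import json


async def handle_tool_call(tool_call, user_id: str | None, session_id: str) -> dict:
    arguments = json.loads(tool_call.function.arguments or "{}")
    result_json = await execute_tool(
        tool_name=tool_call.function.name,
        parameters=arguments,
        user_id=user_id,
        session_id=session_id,
    )
    # execute_tool always returns a JSON string (a serialized ToolResponseBase
    # or an error payload), ready to be appended as a "tool" role message.
    return {
        "role": "tool",
        "tool_call_id": tool_call.id,
        "content": result_json,
    }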
@@ -0,0 +1,39 @@
"""Chat tools for agent discovery, setup, and execution."""

from .base import BaseTool
from .find_agent import FindAgentTool
from .get_agent_details import GetAgentDetailsTool
from .get_required_setup_info import GetRequiredSetupInfoTool
from .run_agent import RunAgentTool
from .setup_agent import SetupAgentTool

__all__ = [
    "CHAT_TOOLS",
    "BaseTool",
    "FindAgentTool",
    "GetAgentDetailsTool",
    "GetRequiredSetupInfoTool",
    "RunAgentTool",
    "SetupAgentTool",
    "find_agent_tool",
    "get_agent_details_tool",
    "get_required_setup_info_tool",
    "run_agent_tool",
    "setup_agent_tool",
]

# Initialize all tools
find_agent_tool = FindAgentTool()
get_agent_details_tool = GetAgentDetailsTool()
get_required_setup_info_tool = GetRequiredSetupInfoTool()
setup_agent_tool = SetupAgentTool()
run_agent_tool = RunAgentTool()

# Export tool instances
CHAT_TOOLS = [
    find_agent_tool,
    get_agent_details_tool,
    get_required_setup_info_tool,
    setup_agent_tool,
    run_agent_tool,
]
autogpt_platform/backend/backend/server/v2/chat/tools/base.py (new file, 100 lines)
@@ -0,0 +1,100 @@
|
||||
"""Base classes and shared utilities for chat tools."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
|
||||
class BaseTool:
|
||||
"""Base class for all chat tools."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Tool name for OpenAI function calling."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Tool description for OpenAI."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
"""Tool parameters schema for OpenAI."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
"""Whether this tool requires authentication."""
|
||||
return False
|
||||
|
||||
def as_openai_tool(self) -> ChatCompletionToolParam:
|
||||
"""Convert to OpenAI tool format."""
|
||||
return ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters,
|
||||
},
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the tool with authentication check.
|
||||
|
||||
Args:
|
||||
user_id: User ID (may be anonymous like "anon_123")
|
||||
session_id: Chat session ID
|
||||
**kwargs: Tool-specific parameters
|
||||
|
||||
Returns:
|
||||
Pydantic response object
|
||||
|
||||
"""
|
||||
# Check authentication if required
|
||||
if self.requires_auth and (not user_id or user_id.startswith("anon_")):
|
||||
return NeedLoginResponse(
|
||||
message=f"Please sign in to use {self.name}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
return await self._execute(user_id, session_id, **kwargs)
|
||||
except Exception as e:
|
||||
# Log the error internally but return a safe message
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"Error in {self.name}: {e}", exc_info=True)
|
||||
|
||||
return ErrorResponse(
|
||||
message=f"An error occurred while executing {self.name}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Internal execution logic to be implemented by subclasses.
|
||||
|
||||
Args:
|
||||
user_id: User ID (authenticated or anonymous)
|
||||
session_id: Chat session ID
|
||||
**kwargs: Tool-specific parameters
|
||||
|
||||
Returns:
|
||||
Pydantic response object
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
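As a hedged illustration of the contract above, a minimal hypothetical subclass (not part of this changeset) only overrides the four abstract members; execute() then supplies the auth check and error wrapping for free:

# Hypothetical example subclass -- illustrates the BaseTool contract only.
from typing import Any

from .models import ResponseType, ToolResponseBase


class EchoTool(BaseTool):
    """Echoes its input back; demonstrates the minimal overrides."""

    @property
    def name(self) -> str:
        return "echo"

    @property
    def description(self) -> str:
        return "Echo the provided text back to the user."

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {"text": {"type": "string", "description": "Text to echo"}},
            "required": ["text"],
        }

    async def _execute(
        self, user_id: str | None, session_id: str, **kwargs
    ) -> ToolResponseBase:
        return ToolResponseBase(
            type=ResponseType.SUCCESS,
            message=kwargs.get("text", ""),
            session_id=session_id,
        )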
@@ -0,0 +1,152 @@
|
||||
"""Tool for discovering agents from marketplace and user library."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
AgentCarouselResponse,
|
||||
AgentInfo,
|
||||
ErrorResponse,
|
||||
NoResultsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.server.v2.store import db as store_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FindAgentTool(BaseTool):
|
||||
"""Tool for discovering agents based on user needs."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Discover agents from the marketplace based on capabilities and user needs."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query describing what the user wants to accomplish",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search for agents in the marketplace.
|
||||
|
||||
Args:
|
||||
user_id: User ID (may be anonymous)
|
||||
session_id: Chat session ID
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
Pydantic response model
|
||||
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Search marketplace agents
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
|
||||
# Use search_store_agents for vector search - limit to 5 agents
|
||||
store_results = await store_db.search_store_agents(
|
||||
search_query=query,
|
||||
limit=5,
|
||||
)
|
||||
logger.info(f"Find agents tool found {len(store_results.agents)} agents")
|
||||
# Format marketplace agents
|
||||
agents = []
|
||||
for agent in store_results.agents:
|
||||
# Build the full agent ID as username/slug for marketplace lookup
|
||||
# Ensure we're using the slug from the agent, not any other ID field
|
||||
agent_slug = agent.slug
|
||||
agent_creator = agent.creator
|
||||
agent_id = f"{agent_creator}/{agent_slug}"
|
||||
logger.info(
|
||||
f"Building agent ID: creator={agent_creator}, slug={agent_slug}, full_id={agent_id}"
|
||||
)
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent_id, # Use username/slug format for marketplace agents
|
||||
name=agent.agent_name,
|
||||
description=agent.description or "",
|
||||
source="marketplace",
|
||||
in_library=False,
|
||||
creator=agent.creator,
|
||||
category="general", # StoreAgent doesn't have categories
|
||||
rating=agent.rating,
|
||||
runs=agent.runs,
|
||||
is_featured=False, # StoreAgent doesn't have is_featured
|
||||
),
|
||||
)
|
||||
|
||||
if not agents:
|
||||
return NoResultsResponse(
|
||||
message=f"No agents found matching '{query}'. Try different keywords or browse the marketplace.",
|
||||
session_id=session_id,
|
||||
suggestions=[
|
||||
"Try more general terms",
|
||||
"Browse categories in the marketplace",
|
||||
"Check spelling",
|
||||
],
|
||||
)
|
||||
|
||||
# Return formatted carousel
|
||||
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
|
||||
return AgentCarouselResponse(
|
||||
message=title,
|
||||
title=title,
|
||||
agents=agents,
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching agents: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message="Failed to search for agents. Please try again.",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import asyncio

    import prisma

    find_agent_tool = FindAgentTool()
    print(find_agent_tool.parameters)

    async def main():
        # Use a single client so connect() and disconnect() act on the same instance
        client = prisma.Prisma()
        await client.connect()
        agents = await find_agent_tool.execute(
            user_id="user", session_id="session", query="Linkedin"
        )
        print(agents)
        await client.disconnect()

    asyncio.run(main())
@@ -0,0 +1,511 @@
|
||||
"""Tool for getting detailed information about a specific agent."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data import graph as graph_db
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
AgentDetails,
|
||||
AgentDetailsNeedCredentialsResponse,
|
||||
AgentDetailsNeedLoginResponse,
|
||||
AgentDetailsResponse,
|
||||
ErrorResponse,
|
||||
ExecutionOptions,
|
||||
InputField,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.server.v2.store import db as store_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GetAgentDetailsTool(BaseTool):
|
||||
"""Tool for getting detailed information about an agent."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_agent_details"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Get detailed information about a specific agent including inputs, credentials required, and execution options."
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": "The marketplace agent slug (e.g., 'username/agent-name' or just 'agent-name' to search)",
|
||||
},
|
||||
"agent_version": {
|
||||
"type": "integer",
|
||||
"description": "Optional specific version of the agent (defaults to latest)",
|
||||
},
|
||||
},
|
||||
"required": ["agent_id"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Get detailed information about an agent.
|
||||
|
||||
Args:
|
||||
user_id: User ID (may be anonymous)
|
||||
session_id: Chat session ID
|
||||
agent_id: Agent ID or slug
|
||||
agent_version: Optional version number
|
||||
|
||||
Returns:
|
||||
Pydantic response model
|
||||
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
agent_version = kwargs.get("agent_version")
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide an agent ID",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Always try to get from marketplace first
|
||||
graph = None
|
||||
store_agent = None
|
||||
in_library = False
|
||||
is_marketplace = False
|
||||
|
||||
# Check if it's a slug format (username/agent_name)
|
||||
if "/" in agent_id:
|
||||
try:
|
||||
# Parse username/agent_name from slug
|
||||
username, agent_name = agent_id.split("/", 1)
|
||||
store_agent = await store_db.get_store_agent_details(
|
||||
username, agent_name
|
||||
)
|
||||
logger.info(f"Found agent {agent_id} in marketplace")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get from marketplace: {e}")
|
||||
else:
|
||||
# Try to find by agent slug alone (search all agents)
|
||||
try:
|
||||
# Search for the agent in the store
|
||||
search_results = await store_db.search_store_agents(
|
||||
search_query=agent_id, limit=1
|
||||
)
|
||||
if search_results.agents:
|
||||
first_agent = search_results.agents[0]
|
||||
# Now get the full details using the slug
|
||||
if "/" in first_agent.slug:
|
||||
username, agent_name = first_agent.slug.split("/", 1)
|
||||
store_agent = await store_db.get_store_agent_details(
|
||||
username, agent_name
|
||||
)
|
||||
logger.info("Found agent by search in marketplace")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to search marketplace: {e}")
|
||||
|
||||
# If we found a store agent, get its graph
|
||||
if store_agent:
|
||||
try:
|
||||
# Use get_available_graph to get the graph from store listing version
|
||||
graph_meta = await store_db.get_available_graph(
|
||||
store_agent.store_listing_version_id
|
||||
)
|
||||
# Now get the full graph with that ID
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=graph_meta.id,
|
||||
version=graph_meta.version,
|
||||
user_id=None, # Public access
|
||||
include_subgraphs=True,
|
||||
)
|
||||
is_marketplace = True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get graph for store agent: {e}")
|
||||
|
||||
if not graph:
|
||||
return ErrorResponse(
|
||||
message=f"Agent '{agent_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Parse input schema
|
||||
input_fields = {}
|
||||
if hasattr(graph, "input_schema") and graph.input_schema:
|
||||
if isinstance(graph.input_schema, dict):
|
||||
properties = graph.input_schema.get("properties", {})
|
||||
required = graph.input_schema.get("required", [])
|
||||
|
||||
input_required = []
|
||||
input_optional = []
|
||||
|
||||
for key, schema in properties.items():
|
||||
field = InputField(
|
||||
name=key,
|
||||
type=schema.get("type", "string"),
|
||||
description=schema.get("description", ""),
|
||||
required=key in required,
|
||||
default=schema.get("default"),
|
||||
options=schema.get("enum"),
|
||||
format=schema.get("format"),
|
||||
)
|
||||
|
||||
if key in required:
|
||||
input_required.append(field)
|
||||
else:
|
||||
input_optional.append(field)
|
||||
|
||||
input_fields = {
|
||||
"schema": graph.input_schema,
|
||||
"required": input_required,
|
||||
"optional": input_optional,
|
||||
}
|
||||
|
||||
# Parse credential requirements and check availability
|
||||
credentials = []
|
||||
needs_auth = False
|
||||
missing_credentials = []
|
||||
if (
|
||||
hasattr(graph, "credentials_input_schema")
|
||||
and graph.credentials_input_schema
|
||||
):
|
||||
# Get system-provided credentials
|
||||
system_credentials = {}
|
||||
try:
|
||||
system_creds_list = AutoRegistry.get_all_credentials()
|
||||
system_credentials = {c.provider: c for c in system_creds_list}
|
||||
|
||||
# WORKAROUND: Check for common LLM providers that don't use SDK pattern
|
||||
import os
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials
|
||||
|
||||
# Check for OpenAI
|
||||
if "openai" not in system_credentials:
|
||||
openai_key = os.getenv("OPENAI_API_KEY")
|
||||
if openai_key:
|
||||
system_credentials["openai"] = APIKeyCredentials(
|
||||
id="openai-system",
|
||||
provider="openai",
|
||||
api_key=SecretStr(openai_key),
|
||||
title="System OpenAI API Key",
|
||||
)
|
||||
|
||||
# Check for Anthropic
|
||||
if "anthropic" not in system_credentials:
|
||||
anthropic_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if anthropic_key:
|
||||
system_credentials["anthropic"] = APIKeyCredentials(
|
||||
id="anthropic-system",
|
||||
provider="anthropic",
|
||||
api_key=SecretStr(anthropic_key),
|
||||
title="System Anthropic API Key",
|
||||
)
|
||||
|
||||
# Check for other common providers
|
||||
for provider, env_var in [
|
||||
("groq", "GROQ_API_KEY"),
|
||||
("ollama", "OLLAMA_API_KEY"),
|
||||
("open_router", "OPEN_ROUTER_API_KEY"),
|
||||
]:
|
||||
if provider not in system_credentials:
|
||||
api_key = os.getenv(env_var)
|
||||
if api_key:
|
||||
system_credentials[provider] = APIKeyCredentials(
|
||||
id=f"{provider}-system",
|
||||
provider=provider,
|
||||
api_key=SecretStr(api_key),
|
||||
title=f"System {provider} API Key",
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"System provides credentials for: {list(system_credentials.keys())}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get system credentials: {e}")
|
||||
|
||||
# Get user's credentials if authenticated
|
||||
user_credentials = {}
|
||||
if user_id and not user_id.startswith("anon_"):
|
||||
try:
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
)
|
||||
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
user_creds_list = await creds_manager.store.get_all_creds(
|
||||
user_id
|
||||
)
|
||||
user_credentials = {c.provider: c for c in user_creds_list}
|
||||
logger.debug(
|
||||
f"User has credentials for: {list(user_credentials.keys())}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get user credentials: {e}")
|
||||
|
||||
# Handle nested schema structure
|
||||
credentials_to_check = {}
|
||||
if isinstance(graph.credentials_input_schema, dict):
|
||||
if "properties" in graph.credentials_input_schema:
|
||||
credentials_to_check = graph.credentials_input_schema[
|
||||
"properties"
|
||||
]
|
||||
else:
|
||||
credentials_to_check = graph.credentials_input_schema
|
||||
|
||||
# Process credentials from the schema dict into a list
|
||||
credentials = []
|
||||
for cred_key, cred_schema in credentials_to_check.items():
|
||||
# Extract the actual provider name
|
||||
actual_provider = None
|
||||
if isinstance(cred_schema, dict):
|
||||
# Try to extract provider from credentials_provider field
|
||||
if "credentials_provider" in cred_schema:
|
||||
providers = cred_schema["credentials_provider"]
|
||||
if isinstance(providers, list) and len(providers) > 0:
|
||||
actual_provider = str(providers[0])
|
||||
if "ProviderName." in actual_provider:
|
||||
actual_provider = (
|
||||
actual_provider.split("'")[1]
|
||||
if "'" in actual_provider
|
||||
else actual_provider.split(".")[-1].lower()
|
||||
)
|
||||
|
||||
cred_meta = {
|
||||
"id": cred_key,
|
||||
"provider": actual_provider
|
||||
or cred_schema.get("credentials_provider", cred_key),
|
||||
"type": cred_schema.get("credentials_type", "api_key"),
|
||||
"title": cred_schema.get("title")
|
||||
or cred_schema.get("description"),
|
||||
}
|
||||
credentials.append(cred_meta)
|
||||
|
||||
# Check if this credential is available
|
||||
provider_name = actual_provider or cred_key
|
||||
if (
|
||||
provider_name not in user_credentials
|
||||
and provider_name not in system_credentials
|
||||
):
|
||||
missing_credentials.append(provider_name)
|
||||
logger.debug(
|
||||
f"Missing credential for provider: {provider_name}"
|
||||
)
|
||||
|
||||
# Only needs auth if there are missing credentials
|
||||
needs_auth = bool(missing_credentials)
|
||||
|
||||
# Determine execution options
|
||||
execution_options = ExecutionOptions(
|
||||
manual=True, # Always support manual execution
|
||||
scheduled=True, # Most agents support scheduling
|
||||
webhook=False, # Check for webhook support
|
||||
)
|
||||
|
||||
# Check for webhook/trigger support
|
||||
if hasattr(graph, "has_external_trigger"):
|
||||
execution_options.webhook = graph.has_external_trigger
|
||||
elif hasattr(graph, "webhook_input_node") and graph.webhook_input_node:
|
||||
execution_options.webhook = True
|
||||
|
||||
# Build trigger info if available
|
||||
trigger_info = None
|
||||
if hasattr(graph, "trigger_setup_info") and graph.trigger_setup_info:
|
||||
trigger_info = {
|
||||
"supported": True,
|
||||
"config": (
|
||||
graph.trigger_setup_info.dict()
|
||||
if hasattr(graph.trigger_setup_info, "dict")
|
||||
else graph.trigger_setup_info
|
||||
),
|
||||
}
|
||||
|
||||
# Build stats if available
|
||||
stats = None
|
||||
if hasattr(graph, "executions_count"):
|
||||
stats = {
|
||||
"total_runs": graph.executions_count, # type: ignore
|
||||
"last_run": (
|
||||
graph.last_execution.isoformat() # type: ignore
|
||||
if hasattr(graph, "last_execution") and graph.last_execution # type: ignore
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
# Create agent details
|
||||
details = AgentDetails(
|
||||
id=graph.id,
|
||||
name=graph.name,
|
||||
description=graph.description,
|
||||
version=graph.version,
|
||||
is_latest=graph.is_active if hasattr(graph, "is_active") else True,
|
||||
in_library=in_library,
|
||||
is_marketplace=is_marketplace,
|
||||
inputs=input_fields,
|
||||
credentials=credentials, # type: ignore
|
||||
execution_options=execution_options,
|
||||
trigger_info=trigger_info,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
# Check if user needs to log in or set up credentials
|
||||
if needs_auth:
|
||||
if not user_id or user_id.startswith("anon_"):
|
||||
# Anonymous user needs to log in first
|
||||
# Build a descriptive message about what credentials are needed
|
||||
cred_list = []
|
||||
for cred in credentials:
|
||||
cred_desc = f"{cred.get('provider', 'Unknown')}"
|
||||
if cred.get("type"):
|
||||
cred_desc += f" ({cred.get('type')})"
|
||||
cred_list.append(cred_desc)
|
||||
|
||||
cred_message = f"This agent requires the following credentials: {', '.join(cred_list)}. Please sign in to set up and run this agent."
|
||||
|
||||
return AgentDetailsNeedLoginResponse(
|
||||
message=cred_message,
|
||||
session_id=session_id,
|
||||
agent=details,
|
||||
agent_info={
|
||||
"agent_id": agent_id,
|
||||
"agent_version": agent_version,
|
||||
"name": details.name,
|
||||
"graph_id": graph.id,
|
||||
},
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
)
|
||||
else:
|
||||
# Authenticated user needs to set up credentials
|
||||
# Return the credentials schema so the frontend can show the setup UI
|
||||
cred_message = f"The agent '{details.name}' requires credentials to be configured. Please provide the required credentials to continue."
|
||||
|
||||
return AgentDetailsNeedCredentialsResponse(
|
||||
message=cred_message,
|
||||
session_id=session_id,
|
||||
agent=details,
|
||||
credentials_schema=graph.credentials_input_schema,
|
||||
agent_info={
|
||||
"agent_id": agent_id,
|
||||
"agent_version": agent_version,
|
||||
"name": details.name,
|
||||
"graph_id": graph.id,
|
||||
},
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
)
|
||||
|
||||
# Build a descriptive message about the agent
|
||||
message_parts = [f"Agent '{graph.name}' details loaded successfully."]
|
||||
|
||||
if credentials:
|
||||
cred_list = []
|
||||
for cred in credentials:
|
||||
cred_desc = f"{cred.get('provider', 'Unknown')}"
|
||||
if cred.get("type"):
|
||||
cred_desc += f" ({cred.get('type')})"
|
||||
cred_list.append(cred_desc)
|
||||
message_parts.append(f"Required credentials: {', '.join(cred_list)}")
|
||||
|
||||
# Be very explicit about required inputs
|
||||
if input_fields:
|
||||
if input_fields.get("required"):
|
||||
message_parts.append("\n**REQUIRED INPUTS:**")
|
||||
for field in input_fields["required"]:
|
||||
desc = f" - {field.name} ({field.type})"
|
||||
if field.description:
|
||||
desc += f": {field.description}"
|
||||
message_parts.append(desc)
|
||||
|
||||
# Build example dict format
|
||||
example_dict = {}
|
||||
for field in input_fields["required"]:
|
||||
if field.type == "string":
|
||||
example_dict[field.name] = f"<{field.name}_value>"
|
||||
elif field.type == "number" or field.type == "integer":
|
||||
example_dict[field.name] = 123
|
||||
elif field.type == "boolean":
|
||||
example_dict[field.name] = True
|
||||
else:
|
||||
example_dict[field.name] = f"<{field.type}_value>"
|
||||
|
||||
message_parts.append(
|
||||
"\n**IMPORTANT:** To run this agent, you MUST pass these inputs as a dictionary to run_agent, setup_agent, and get_required_setup_info tools."
|
||||
)
|
||||
message_parts.append(f"Example format: inputs={example_dict}")
|
||||
|
||||
if input_fields.get("optional"):
|
||||
message_parts.append("\n**OPTIONAL INPUTS:**")
|
||||
for field in input_fields["optional"]:
|
||||
desc = f" - {field.name} ({field.type})"
|
||||
if field.description:
|
||||
desc += f": {field.description}"
|
||||
if field.default is not None:
|
||||
desc += f" [default: {field.default}]"
|
||||
message_parts.append(desc)
|
||||
|
||||
return AgentDetailsResponse(
|
||||
message=" ".join(message_parts),
|
||||
session_id=session_id,
|
||||
agent=details,
|
||||
user_authenticated=bool(user_id and not user_id.startswith("anon_")),
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting agent details: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to get agent details: {e!s}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import asyncio

    from backend.data.db import prisma

    get_agent_details_tool = GetAgentDetailsTool()
    print(get_agent_details_tool.parameters)

    async def main():
        await prisma.connect()

        # Test with a logged-in user
        print("\n=== Testing agent with logged-in user ===")
        result1 = await get_agent_details_tool._execute(
            agent_id="autogpt-store/slug-a",
            user_id="3e53486c-cf57-477e-ba2a-cb02dc828e1a",
            session_id="session1",
        )
        print(f"Result type: {result1.type}")
        print(f"Has credentials schema: {'credentials_schema' in result1.__dict__}")
        if hasattr(result1, "message"):
            print(f"Message: {result1.message}")

        # Test with an anonymous user
        print("\n=== Testing with anonymous user ===")
        result2 = await get_agent_details_tool._execute(
            agent_id="autogpt-store/slug-a",
            user_id="anon_user123",
            session_id="session2",
        )
        print(f"Result type: {result2.type}")
        print(f"Message: {result2.message}")

        await prisma.disconnect()

    asyncio.run(main())
@@ -0,0 +1,423 @@
|
||||
"""Tool for getting required setup information for an agent."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data import graph as graph_db
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.sdk.registry import AutoRegistry
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
ErrorResponse,
|
||||
ExecutionModeInfo,
|
||||
InputField,
|
||||
SetupInfo,
|
||||
SetupRequirementInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
from backend.server.v2.store import db as store_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GetRequiredSetupInfoTool(BaseTool):
|
||||
"""Tool for getting required setup information including credentials and inputs."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "get_required_setup_info"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Check if an agent can be set up with the provided input data and credentials.
|
||||
Call this AFTER get_agent_details to validate that you have all required inputs.
|
||||
Pass the input dictionary you plan to use with run_agent or setup_agent to verify it's complete."""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": "The marketplace agent slug (e.g., 'username/agent-name' or just 'agent-name' to search)",
|
||||
},
|
||||
"agent_version": {
|
||||
"type": "integer",
|
||||
"description": "Optional specific version of the agent (defaults to latest)",
|
||||
},
|
||||
"inputs": {
|
||||
"type": "object",
|
||||
"description": "The input dictionary you plan to provide. Should contain ALL required inputs from get_agent_details",
|
||||
"additionalProperties": True,
|
||||
},
|
||||
},
|
||||
"required": ["agent_id"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
"""This tool requires authentication."""
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Get required setup information for an agent.
|
||||
|
||||
Args:
|
||||
user_id: Authenticated user ID
|
||||
session_id: Chat session ID
|
||||
agent_id: Agent/Graph ID
|
||||
agent_version: Optional version
|
||||
|
||||
Returns:
|
||||
JSON formatted setup requirements
|
||||
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
agent_version = kwargs.get("agent_version")
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide an agent ID",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
graph = None
|
||||
|
||||
# Check if it's a marketplace slug format (username/agent_name)
|
||||
if "/" in agent_id:
|
||||
try:
|
||||
# Parse username/agent_name from slug
|
||||
username, agent_name = agent_id.split("/", 1)
|
||||
store_agent = await store_db.get_store_agent_details(
|
||||
username, agent_name
|
||||
)
|
||||
if store_agent:
|
||||
# Get graph from store listing
|
||||
graph_meta = await store_db.get_available_graph(
|
||||
store_agent.store_listing_version_id
|
||||
)
|
||||
# Now get the full graph with that ID
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=graph_meta.id,
|
||||
version=graph_meta.version,
|
||||
user_id=None, # Public access
|
||||
include_subgraphs=True,
|
||||
)
|
||||
logger.info(f"Found agent {agent_id} in marketplace")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get from marketplace: {e}")
|
||||
else:
|
||||
# Try direct graph ID lookup
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=agent_id,
|
||||
version=agent_version,
|
||||
user_id=user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
|
||||
if not graph:
|
||||
# Try to get from marketplace/public
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=agent_id,
|
||||
version=agent_version,
|
||||
user_id=None,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
|
||||
if not graph:
|
||||
return ErrorResponse(
|
||||
message=f"Agent '{agent_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
setup_info = SetupInfo(
|
||||
agent_id=graph.id,
|
||||
agent_name=graph.name,
|
||||
version=graph.version,
|
||||
)
|
||||
|
||||
# Get credential manager
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
|
||||
# Analyze credential requirements
|
||||
if (
|
||||
hasattr(graph, "credentials_input_schema")
|
||||
and graph.credentials_input_schema
|
||||
):
|
||||
user_credentials = {}
|
||||
system_credentials = {}
|
||||
|
||||
try:
|
||||
# Get user's existing credentials
|
||||
if user_id:
|
||||
user_creds_list = await creds_manager.store.get_all_creds(
|
||||
user_id
|
||||
)
|
||||
else:
|
||||
user_creds_list = []
|
||||
user_credentials = {c.provider: c for c in user_creds_list}
|
||||
logger.info(
|
||||
f"User has credentials for providers: {list(user_credentials.keys())}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get user credentials: {e}")
|
||||
|
||||
# Get system-provided default credentials
|
||||
try:
|
||||
system_creds_list = AutoRegistry.get_all_credentials()
|
||||
system_credentials = {c.provider: c for c in system_creds_list}
|
||||
|
||||
# WORKAROUND: Check for common LLM providers that don't use SDK pattern
|
||||
import os
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials
|
||||
|
||||
# Check for OpenAI
|
||||
if "openai" not in system_credentials:
|
||||
openai_key = os.getenv("OPENAI_API_KEY")
|
||||
if openai_key:
|
||||
system_credentials["openai"] = APIKeyCredentials(
|
||||
id="openai-system",
|
||||
provider="openai",
|
||||
api_key=SecretStr(openai_key),
|
||||
title="System OpenAI API Key",
|
||||
)
|
||||
|
||||
# Check for Anthropic
|
||||
if "anthropic" not in system_credentials:
|
||||
anthropic_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if anthropic_key:
|
||||
system_credentials["anthropic"] = APIKeyCredentials(
|
||||
id="anthropic-system",
|
||||
provider="anthropic",
|
||||
api_key=SecretStr(anthropic_key),
|
||||
title="System Anthropic API Key",
|
||||
)
|
||||
|
||||
# Check for other common providers
|
||||
provider_env_map = {
|
||||
"groq": "GROQ_API_KEY",
|
||||
"ollama": "OLLAMA_API_KEY",
|
||||
"open_router": "OPEN_ROUTER_API_KEY",
|
||||
}
|
||||
|
||||
for provider, env_var in provider_env_map.items():
|
||||
if provider not in system_credentials:
|
||||
api_key = os.getenv(env_var)
|
||||
if api_key:
|
||||
system_credentials[provider] = APIKeyCredentials(
|
||||
id=f"{provider}-system",
|
||||
provider=provider,
|
||||
api_key=SecretStr(api_key),
|
||||
title=f"System {provider} API Key",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"System provides credentials for: {list(system_credentials.keys())}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get system credentials: {e}")
|
||||
|
||||
# Handle the nested schema structure
|
||||
credentials_to_check = {}
|
||||
if isinstance(graph.credentials_input_schema, dict):
|
||||
# Check if it's a JSON schema with properties
|
||||
if "properties" in graph.credentials_input_schema:
|
||||
credentials_to_check = graph.credentials_input_schema[
|
||||
"properties"
|
||||
]
|
||||
else:
|
||||
# Fallback to treating the whole schema as credentials
|
||||
credentials_to_check = graph.credentials_input_schema
|
||||
|
||||
cred_in_schema = {}
|
||||
for cred_key, cred_schema in credentials_to_check.items():
|
||||
cred_req = SetupRequirementInfo(
|
||||
key=cred_key,
|
||||
provider=cred_key,
|
||||
required=True,
|
||||
user_has=False,
|
||||
)
|
||||
|
||||
# Parse credential schema to extract the actual provider
|
||||
actual_provider = None
|
||||
if isinstance(cred_schema, dict):
|
||||
# Try to extract provider from credentials_provider field
|
||||
if "credentials_provider" in cred_schema:
|
||||
providers = cred_schema["credentials_provider"]
|
||||
if isinstance(providers, list) and len(providers) > 0:
|
||||
# Extract the actual provider name from the enum
|
||||
actual_provider = str(providers[0])
|
||||
# Handle ProviderName enum format
|
||||
if "ProviderName." in actual_provider:
|
||||
actual_provider = (
|
||||
actual_provider.split("'")[1]
|
||||
if "'" in actual_provider
|
||||
else actual_provider.split(".")[-1].lower()
|
||||
)
|
||||
cred_req.provider = actual_provider
|
||||
elif "provider" in cred_schema:
|
||||
cred_req.provider = cred_schema["provider"]
|
||||
actual_provider = cred_schema["provider"]
|
||||
|
||||
if "type" in cred_schema:
|
||||
cred_req.type = cred_schema["type"] # oauth, api_key
|
||||
if "scopes" in cred_schema:
|
||||
cred_req.scopes = cred_schema["scopes"]
|
||||
if "description" in cred_schema:
|
||||
cred_req.description = cred_schema["description"]
|
||||
|
||||
# Check if user has this credential using the actual provider name
|
||||
provider_name = actual_provider or cred_req.provider
|
||||
logger.debug(
|
||||
f"Checking credential {cred_key}: provider={provider_name}, available={list(user_credentials.keys())}"
|
||||
)
|
||||
|
||||
# Check user credentials first, then system credentials
|
||||
if provider_name in user_credentials:
|
||||
cred_req.user_has = True
|
||||
cred_req.credential_id = user_credentials[provider_name].id
|
||||
logger.info(f"User has credential for {provider_name}")
|
||||
elif provider_name in system_credentials:
|
||||
cred_req.user_has = True
|
||||
cred_req.credential_id = f"system-{provider_name}"
|
||||
logger.info(f"System provides credential for {provider_name}")
|
||||
else:
|
||||
cred_in_schema[cred_key] = cred_schema
|
||||
logger.info(
|
||||
f"User missing credential for {provider_name} (not provided by system either)"
|
||||
)
|
||||
|
||||
setup_info.requirements["credentials"].append(cred_req)
|
||||
setup_info.user_readiness.missing_credentials = cred_in_schema
|
||||
|
||||
# Analyze input requirements
|
||||
if hasattr(graph, "input_schema") and graph.input_schema:
|
||||
if isinstance(graph.input_schema, dict):
|
||||
properties = graph.input_schema.get("properties", {})
|
||||
required = graph.input_schema.get("required", [])
|
||||
|
||||
for key, schema in properties.items():
|
||||
input_req = InputField(
|
||||
name=key,
|
||||
type=schema.get("type", "string"),
|
||||
required=key in required,
|
||||
description=schema.get("description", ""),
|
||||
)
|
||||
|
||||
# Add default value if present
|
||||
if "default" in schema:
|
||||
input_req.default = schema["default"]
|
||||
|
||||
# Add enum values if present
|
||||
if "enum" in schema:
|
||||
input_req.options = schema["enum"]
|
||||
|
||||
# Add format hints
|
||||
if "format" in schema:
|
||||
input_req.format = schema["format"]
|
||||
|
||||
setup_info.requirements["inputs"].append(input_req)
|
||||
|
||||
# Determine supported execution modes
|
||||
execution_modes = []
|
||||
|
||||
# Manual execution is always supported
|
||||
execution_modes.append(
|
||||
ExecutionModeInfo(
|
||||
type="manual",
|
||||
description="Run the agent immediately with provided inputs",
|
||||
supported=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Check for scheduled execution support
|
||||
execution_modes.append(
|
||||
ExecutionModeInfo(
|
||||
type="scheduled",
|
||||
description="Run the agent on a recurring schedule (cron)",
|
||||
supported=True,
|
||||
config_required={
|
||||
"cron": "Cron expression (e.g., '0 9 * * 1' for Mondays at 9 AM)",
|
||||
"timezone": "User timezone (converted to UTC)",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Check for webhook support
|
||||
webhook_supported = False
|
||||
if hasattr(graph, "has_external_trigger"):
|
||||
webhook_supported = graph.has_external_trigger
|
||||
elif hasattr(graph, "webhook_input_node") and graph.webhook_input_node:
|
||||
webhook_supported = True
|
||||
|
||||
if webhook_supported:
|
||||
webhook_mode = ExecutionModeInfo(
|
||||
type="webhook",
|
||||
description="Trigger the agent via external webhook",
|
||||
supported=True,
|
||||
config_required={},
|
||||
)
|
||||
|
||||
# Add trigger setup info if available
|
||||
if hasattr(graph, "trigger_setup_info") and graph.trigger_setup_info:
|
||||
webhook_mode.trigger_info = (
|
||||
graph.trigger_setup_info.dict() # type: ignore
|
||||
if hasattr(graph.trigger_setup_info, "dict")
|
||||
else graph.trigger_setup_info # type: ignore
|
||||
)
|
||||
|
||||
execution_modes.append(webhook_mode)
|
||||
else:
|
||||
execution_modes.append(
|
||||
ExecutionModeInfo(
|
||||
type="webhook",
|
||||
description="Webhook triggers not supported for this agent",
|
||||
supported=False,
|
||||
)
|
||||
)
|
||||
|
||||
setup_info.requirements["execution_modes"] = execution_modes
|
||||
|
||||
# Check overall readiness
|
||||
has_all_creds = len(setup_info.user_readiness.missing_credentials) == 0
|
||||
setup_info.user_readiness.has_all_credentials = has_all_creds
|
||||
|
||||
# Agent is ready if all required credentials are present
|
||||
setup_info.user_readiness.ready_to_run = has_all_creds
|
||||
|
||||
# Add setup instructions
|
||||
if not setup_info.user_readiness.ready_to_run:
|
||||
instructions = []
|
||||
if setup_info.user_readiness.missing_credentials:
|
||||
instructions.append(
|
||||
f"Add credentials for: {', '.join(setup_info.user_readiness.missing_credentials)}",
|
||||
)
|
||||
setup_info.setup_instructions = instructions
|
||||
else:
|
||||
setup_info.setup_instructions = ["Agent is ready to set up and run!"]
|
||||
|
||||
return SetupRequirementsResponse(
|
||||
message=f"Setup requirements for '{graph.name}' retrieved successfully",
|
||||
setup_info=setup_info,
|
||||
session_id=session_id,
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting setup requirements: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to get setup requirements: {e!s}",
|
||||
session_id=session_id,
|
||||
)
|
||||
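A hedged usage sketch for the response above: a hypothetical caller (not part of this changeset) can gate run_agent on the readiness flags instead of re-parsing the credential schema itself:

# Hypothetical consumer of SetupRequirementsResponse.
from backend.server.v2.chat.tools.get_required_setup_info import GetRequiredSetupInfoTool
from backend.server.v2.chat.tools.models import SetupRequirementsResponse


async def ensure_ready(agent_id: str, user_id: str, session_id: str) -> bool:
    tool = GetRequiredSetupInfoTool()
    result = await tool.execute(user_id, session_id, agent_id=agent_id)

    if not isinstance(result, SetupRequirementsResponse):
        # NeedLoginResponse / ErrorResponse etc. -- surface the message as-is.
        print(result.message)
        return False

    readiness = result.setup_info.user_readiness
    if not readiness.ready_to_run:
        # missing_credentials is keyed by credential field name.
        missing = ", ".join(readiness.missing_credentials) or "unknown credentials"
        print(f"Missing before run_agent can be called: {missing}")
    return readiness.ready_to_run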
autogpt_platform/backend/backend/server/v2/chat/tools/models.py (new file, 284 lines)
@@ -0,0 +1,284 @@
|
||||
"""Pydantic models for tool responses."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
|
||||
|
||||
class ResponseType(str, Enum):
|
||||
"""Types of tool responses."""
|
||||
|
||||
AGENT_CAROUSEL = "agent_carousel"
|
||||
AGENT_DETAILS = "agent_details"
|
||||
AGENT_DETAILS_NEED_LOGIN = "agent_details_need_login"
|
||||
AGENT_DETAILS_NEED_CREDENTIALS = "agent_details_need_credentials"
|
||||
SETUP_REQUIREMENTS = "setup_requirements"
|
||||
SCHEDULE_CREATED = "schedule_created"
|
||||
WEBHOOK_CREATED = "webhook_created"
|
||||
PRESET_CREATED = "preset_created"
|
||||
EXECUTION_STARTED = "execution_started"
|
||||
NEED_LOGIN = "need_login"
|
||||
NEED_CREDENTIALS = "need_credentials"
|
||||
INSUFFICIENT_CREDITS = "insufficient_credits"
|
||||
VALIDATION_ERROR = "validation_error"
|
||||
ERROR = "error"
|
||||
NO_RESULTS = "no_results"
|
||||
SUCCESS = "success"
|
||||
|
||||
|
||||
# Base response model
|
||||
class ToolResponseBase(BaseModel):
|
||||
"""Base model for all tool responses."""
|
||||
|
||||
type: ResponseType
|
||||
message: str
|
||||
session_id: str | None = None
|
||||
|
||||
|
||||
# Agent discovery models
|
||||
class AgentInfo(BaseModel):
|
||||
"""Information about an agent."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
source: str = Field(description="marketplace or library")
|
||||
in_library: bool = False
|
||||
creator: str | None = None
|
||||
category: str | None = None
|
||||
rating: float | None = None
|
||||
runs: int | None = None
|
||||
is_featured: bool | None = None
|
||||
status: str | None = None
|
||||
can_access_graph: bool | None = None
|
||||
has_external_trigger: bool | None = None
|
||||
new_output: bool | None = None
|
||||
graph_id: str | None = None
|
||||
|
||||
|
||||
class AgentCarouselResponse(ToolResponseBase):
|
||||
"""Response for find_agent tool."""
|
||||
|
||||
type: ResponseType = ResponseType.AGENT_CAROUSEL
|
||||
title: str = "Available Agents"
|
||||
agents: list[AgentInfo]
|
||||
count: int
|
||||
|
||||
|
||||
class NoResultsResponse(ToolResponseBase):
|
||||
"""Response when no agents found."""
|
||||
|
||||
type: ResponseType = ResponseType.NO_RESULTS
|
||||
suggestions: list[str] = []
|
||||
|
||||
|
||||
# Agent details models
|
||||
class InputField(BaseModel):
|
||||
"""Input field specification."""
|
||||
|
||||
name: str
|
||||
type: str = "string"
|
||||
description: str = ""
|
||||
required: bool = False
|
||||
default: Any | None = None
|
||||
options: list[Any] | None = None
|
||||
format: str | None = None
|
||||
|
||||
|
||||
class ExecutionOptions(BaseModel):
|
||||
"""Available execution options for an agent."""
    manual: bool = True
    scheduled: bool = True
    webhook: bool = False


class AgentDetails(BaseModel):
    """Detailed agent information."""

    id: str
    name: str
    description: str
    version: int
    is_latest: bool = True
    in_library: bool = False
    is_marketplace: bool = False
    inputs: dict[str, Any] = {}
    credentials: list[CredentialsMetaInput] = []
    execution_options: ExecutionOptions = Field(default_factory=ExecutionOptions)
    trigger_info: dict[str, Any] | None = None
    stats: dict[str, Any] | None = None


class AgentDetailsResponse(ToolResponseBase):
    """Response for get_agent_details tool."""

    type: ResponseType = ResponseType.AGENT_DETAILS
    agent: AgentDetails
    user_authenticated: bool = False
    graph_id: str | None = None
    graph_version: int | None = None


class AgentDetailsNeedLoginResponse(ToolResponseBase):
    """Response when agent details need login."""

    type: ResponseType = ResponseType.AGENT_DETAILS_NEED_LOGIN
    agent: AgentDetails
    agent_info: dict[str, Any] | None = None
    graph_id: str | None = None
    graph_version: int | None = None


class AgentDetailsNeedCredentialsResponse(ToolResponseBase):
    """Response when agent needs credentials to be configured."""

    type: ResponseType = ResponseType.NEED_CREDENTIALS
    agent: AgentDetails
    credentials_schema: dict[str, Any]
    agent_info: dict[str, Any] | None = None
    graph_id: str | None = None
    graph_version: int | None = None


# Setup info models
class SetupRequirementInfo(BaseModel):
    """Setup requirement information."""

    key: str
    provider: str
    required: bool = True
    user_has: bool = False
    credential_id: str | None = None
    type: str | None = None
    scopes: list[str] | None = None
    description: str | None = None


class ExecutionModeInfo(BaseModel):
    """Execution mode information."""

    type: str  # manual, scheduled, webhook
    description: str
    supported: bool
    config_required: dict[str, str] | None = None
    trigger_info: dict[str, Any] | None = None


class UserReadiness(BaseModel):
    """User readiness status."""

    has_all_credentials: bool = False
    missing_credentials: dict[str, Any] = {}
    ready_to_run: bool = False


class SetupInfo(BaseModel):
    """Complete setup information."""

    agent_id: str
    agent_name: str
    version: int
    requirements: dict[str, list[Any]] = Field(
        default_factory=lambda: {
            "credentials": [],
            "inputs": [],
            "execution_modes": [],
        },
    )
    user_readiness: UserReadiness = Field(default_factory=UserReadiness)
    setup_instructions: list[str] = []


class SetupRequirementsResponse(ToolResponseBase):
    """Response for get_required_setup_info tool."""

    type: ResponseType = ResponseType.SETUP_REQUIREMENTS
    setup_info: SetupInfo
    graph_id: str | None = None
    graph_version: int | None = None


# Setup agent models
class ScheduleCreatedResponse(ToolResponseBase):
    """Response for scheduled agent setup."""

    type: ResponseType = ResponseType.SCHEDULE_CREATED
    schedule_id: str
    name: str
    cron: str
    timezone: str = "UTC"
    next_run: str | None = None
    graph_id: str
    graph_name: str


class WebhookCreatedResponse(ToolResponseBase):
    """Response for webhook agent setup."""

    type: ResponseType = ResponseType.WEBHOOK_CREATED
    webhook_id: str
    webhook_url: str
    preset_id: str | None = None
    name: str
    graph_id: str
    graph_name: str


class PresetCreatedResponse(ToolResponseBase):
    """Response for preset agent setup."""

    type: ResponseType = ResponseType.PRESET_CREATED
    preset_id: str
    name: str
    graph_id: str
    graph_name: str


# Run agent models
class ExecutionStartedResponse(ToolResponseBase):
    """Response for agent execution started."""

    type: ResponseType = ResponseType.EXECUTION_STARTED
    execution_id: str
    graph_id: str
    graph_name: str
    status: str = "QUEUED"
    ended_at: str | None = None
    outputs: dict[str, Any] | None = None
    error: str | None = None
    timeout_reached: bool | None = None


class InsufficientCreditsResponse(ToolResponseBase):
    """Response for insufficient credits."""

    type: ResponseType = ResponseType.INSUFFICIENT_CREDITS
    balance: float


class ValidationErrorResponse(ToolResponseBase):
    """Response for validation errors."""

    type: ResponseType = ResponseType.VALIDATION_ERROR
    error: str
    details: dict[str, Any] | None = None


# Auth/error models
class NeedLoginResponse(ToolResponseBase):
    """Response when login is needed."""

    type: ResponseType = ResponseType.NEED_LOGIN
    agent_info: dict[str, Any] | None = None


class ErrorResponse(ToolResponseBase):
    """Response for errors."""

    type: ResponseType = ResponseType.ERROR
    error: str | None = None
    details: dict[str, Any] | None = None
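
For orientation, a minimal hypothetical sketch (not part of the diff) of how a tool handler constructs and serializes one of these response models; all identifiers below are illustrative, and the message/session_id fields are assumed to come from ToolResponseBase:

# Hypothetical usage sketch -- mirrors the construction done in run_agent below.
response = ExecutionStartedResponse(
    message="Agent 'Example Agent' execution started",  # assumed ToolResponseBase field
    execution_id="exec-0000",   # illustrative ID
    graph_id="graph-0000",      # illustrative ID
    graph_name="Example Agent",
    status="QUEUED",
    session_id="session-0000",  # assumed ToolResponseBase field
)
payload = response.model_dump()  # Pydantic v2 serialization for the chat client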
@@ -0,0 +1,567 @@
"""Tool for running an agent manually (one-off execution)."""

import asyncio
import logging
from typing import Any

from backend.data import graph as graph_db
from backend.data.credit import get_user_credit_model
from backend.data.execution import get_graph_execution, get_graph_execution_meta
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
from backend.executor import utils as execution_utils
from backend.sdk.registry import AutoRegistry
from backend.server.v2.chat.tools.base import BaseTool
from backend.server.v2.chat.tools.models import (
    ErrorResponse,
    ExecutionStartedResponse,
    InsufficientCreditsResponse,
    ToolResponseBase,
    ValidationErrorResponse,
)
from backend.server.v2.library import db as library_db

logger = logging.getLogger(__name__)


class RunAgentTool(BaseTool):
|
||||
"""Tool for executing an agent manually with immediate results."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "run_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Run an agent immediately (one-off manual execution).
|
||||
IMPORTANT: Before calling this tool, you MUST first call get_agent_details to determine what inputs are required.
|
||||
The 'inputs' parameter must be a dictionary containing ALL required input values identified by get_agent_details.
|
||||
Example: If get_agent_details shows required inputs 'search_query' and 'max_results', you must pass:
|
||||
inputs={"search_query": "user's query", "max_results": 10}"""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": "The ID of the agent to run (graph ID or marketplace slug)",
|
||||
},
|
||||
"agent_version": {
|
||||
"type": "integer",
|
||||
"description": "Optional version number of the agent",
|
||||
},
|
||||
"inputs": {
|
||||
"type": "object",
|
||||
"description": 'REQUIRED: Dictionary of input values. Must include ALL required inputs from get_agent_details. Format: {"input_name": value}',
|
||||
"additionalProperties": True,
|
||||
},
|
||||
"credentials": {
|
||||
"type": "object",
|
||||
"description": "Credentials for the agent (if needed)",
|
||||
"additionalProperties": True,
|
||||
},
|
||||
"wait_for_result": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to wait for execution to complete (max 30s)",
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
"required": ["agent_id"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
"""This tool requires authentication."""
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute an agent manually.
|
||||
|
||||
Args:
|
||||
user_id: Authenticated user ID
|
||||
session_id: Chat session ID
|
||||
**kwargs: Execution parameters
|
||||
|
||||
Returns:
|
||||
JSON formatted execution result
|
||||
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
agent_version = kwargs.get("agent_version")
|
||||
inputs = kwargs.get("inputs", {})
|
||||
credentials = kwargs.get("credentials", {})
|
||||
wait_for_result = kwargs.get("wait_for_result", False)
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide an agent ID",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Check if user is authenticated (required for running agents)
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required to run agents",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check credit balance
|
||||
credit_model = get_user_credit_model()
|
||||
balance = await credit_model.get_credits(user_id)
|
||||
|
||||
if balance <= 0:
|
||||
return InsufficientCreditsResponse(
|
||||
message="Insufficient credits. Please top up your account.",
|
||||
balance=balance,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if agent_id looks like a marketplace slug
|
||||
graph = None
|
||||
marketplace_graph = None
|
||||
|
||||
if "/" in agent_id:
|
||||
# Looks like a marketplace slug, try to get from store first
|
||||
from backend.server.v2.store import db as store_db
|
||||
|
||||
try:
|
||||
username, agent_name = agent_id.split("/", 1)
|
||||
agent_details = await store_db.get_store_agent_details(
|
||||
username, agent_name
|
||||
)
|
||||
if agent_details:
|
||||
# Get the graph from the store listing version
|
||||
graph_meta = await store_db.get_available_graph(
|
||||
agent_details.store_listing_version_id
|
||||
)
|
||||
marketplace_graph = await graph_db.get_graph(
|
||||
graph_id=graph_meta.id,
|
||||
version=graph_meta.version,
|
||||
user_id=None, # Public access
|
||||
include_subgraphs=True,
|
||||
)
|
||||
logger.info(f"Found marketplace agent by slug: {agent_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get agent by slug: {e}")
|
||||
|
||||
# If we have a marketplace graph from the slug lookup, handle it
|
||||
if marketplace_graph:
|
||||
# Check if already in user's library
|
||||
library_agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id=user_id,
|
||||
graph_id=marketplace_graph.id,
|
||||
graph_version=marketplace_graph.version,
|
||||
)
|
||||
|
||||
if library_agent:
|
||||
logger.info(
|
||||
f"Agent {agent_id} already in user library, using existing entry"
|
||||
)
|
||||
# Get the graph from the library agent
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=library_agent.graph_id,
|
||||
version=library_agent.graph_version,
|
||||
user_id=user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
else:
|
||||
logger.info(f"Adding marketplace agent {agent_id} to user library")
|
||||
await library_db.create_library_agent(
|
||||
graph=marketplace_graph,
|
||||
user_id=user_id,
|
||||
create_library_agents_for_sub_graphs=True,
|
||||
)
|
||||
graph = marketplace_graph
|
||||
else:
|
||||
# Not found via slug, try as direct graph ID
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=agent_id,
|
||||
version=agent_version,
|
||||
user_id=user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
|
||||
if not graph:
|
||||
# Try as marketplace agent by ID
|
||||
marketplace_graph = await graph_db.get_graph(
|
||||
graph_id=agent_id,
|
||||
version=agent_version,
|
||||
user_id=None, # Public access
|
||||
include_subgraphs=True,
|
||||
)
|
||||
|
||||
if marketplace_graph:
|
||||
# Check if already in user's library
|
||||
library_agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id=user_id,
|
||||
graph_id=marketplace_graph.id,
|
||||
graph_version=marketplace_graph.version,
|
||||
)
|
||||
|
||||
if library_agent:
|
||||
logger.info(
|
||||
f"Agent {agent_id} already in user library, using existing entry"
|
||||
)
|
||||
# Get the graph from the library agent
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=library_agent.graph_id,
|
||||
version=library_agent.graph_version,
|
||||
user_id=user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Adding marketplace agent {agent_id} to user library"
|
||||
)
|
||||
await library_db.create_library_agent(
|
||||
graph=marketplace_graph,
|
||||
user_id=user_id,
|
||||
create_library_agents_for_sub_graphs=True,
|
||||
)
|
||||
graph = marketplace_graph
|
||||
|
||||
if not graph:
|
||||
return ErrorResponse(
|
||||
message=f"Agent '{agent_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Get system-provided credentials
|
||||
system_credentials = {}
|
||||
try:
|
||||
system_creds_list = AutoRegistry.get_all_credentials()
|
||||
system_credentials = {c.provider: c for c in system_creds_list}
|
||||
|
||||
# WORKAROUND: Check for common LLM providers that don't use SDK pattern
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
# System credentials never expire - set to far future (Unix timestamp)
|
||||
expires_at = int(
|
||||
(datetime.utcnow() + timedelta(days=36500)).timestamp()
|
||||
) # 100 years
|
||||
|
||||
# Check for OpenAI
|
||||
if "openai" not in system_credentials:
|
||||
openai_key = os.getenv("OPENAI_API_KEY")
|
||||
if openai_key:
|
||||
system_credentials["openai"] = APIKeyCredentials(
|
||||
id="system-openai",
|
||||
provider="openai",
|
||||
api_key=SecretStr(openai_key),
|
||||
title="System OpenAI API Key",
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
# Check for Anthropic
|
||||
if "anthropic" not in system_credentials:
|
||||
anthropic_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if anthropic_key:
|
||||
system_credentials["anthropic"] = APIKeyCredentials(
|
||||
id="system-anthropic",
|
||||
provider="anthropic",
|
||||
api_key=SecretStr(anthropic_key),
|
||||
title="System Anthropic API Key",
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
# Check for other common providers
|
||||
for provider, env_var in [
|
||||
("groq", "GROQ_API_KEY"),
|
||||
("ollama", "OLLAMA_API_KEY"),
|
||||
("open_router", "OPEN_ROUTER_API_KEY"),
|
||||
]:
|
||||
if provider not in system_credentials:
|
||||
api_key = os.getenv(env_var)
|
||||
if api_key:
|
||||
system_credentials[provider] = APIKeyCredentials(
|
||||
id=f"system-{provider}",
|
||||
provider=provider,
|
||||
api_key=SecretStr(api_key),
|
||||
title=f"System {provider} API Key",
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"System provides credentials for: {list(system_credentials.keys())}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get system credentials: {e}")
|
||||
|
||||
# Convert credentials to CredentialsMetaInput format
|
||||
# Fill in missing credentials with system-provided ones
|
||||
input_credentials = {}
|
||||
|
||||
# First, process user-provided credentials
|
||||
for key, value in credentials.items():
|
||||
if isinstance(value, dict):
|
||||
input_credentials[key] = CredentialsMetaInput(**value)
|
||||
else:
|
||||
# Assume it's a credential ID
|
||||
input_credentials[key] = CredentialsMetaInput(
|
||||
id=value,
|
||||
provider=key, # Use the key as provider name
|
||||
type="api_key",
|
||||
)
|
||||
|
||||
# Get user credentials if authenticated
|
||||
user_credentials = {}
|
||||
if user_id and not user_id.startswith("anon_"):
|
||||
try:
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
)
|
||||
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
user_creds_list = await creds_manager.store.get_all_creds(user_id)
|
||||
for cred in user_creds_list:
|
||||
user_credentials[cred.provider] = cred
|
||||
logger.info(
|
||||
f"User has credentials for: {list(user_credentials.keys())}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get user credentials: {e}")
|
||||
|
||||
# Use the graph's aggregated credentials to properly map credentials
|
||||
# This ensures we use the same keys that the graph expects
|
||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||
logger.info(
|
||||
f"Graph aggregate credentials: {list(graph_cred_inputs.keys())}"
|
||||
)
|
||||
logger.info(f"User provided credentials: {list(input_credentials.keys())}")
|
||||
logger.info(
|
||||
f"Available system credentials: {list(system_credentials.keys())}"
|
||||
)
|
||||
logger.info(f"Available user credentials: {list(user_credentials.keys())}")
|
||||
|
||||
# Process each aggregated credential field
|
||||
for agg_key, (field_info, node_fields) in graph_cred_inputs.items():
|
||||
if agg_key not in input_credentials:
|
||||
# Extract provider from field_info (it's a frozenset, get the first element)
|
||||
provider_set = field_info.provider
|
||||
if (
|
||||
isinstance(provider_set, (set, frozenset))
|
||||
and len(provider_set) > 0
|
||||
):
|
||||
# Get the first provider from the set
|
||||
provider_enum = next(iter(provider_set))
|
||||
# Get the string value from the enum
|
||||
provider_name = (
|
||||
provider_enum.value
|
||||
if hasattr(provider_enum, "value")
|
||||
else str(provider_enum)
|
||||
)
|
||||
else:
|
||||
provider_name = str(provider_set) if provider_set else None
|
||||
|
||||
logger.info(
|
||||
f"Checking credential {agg_key} for provider {provider_name}"
|
||||
)
|
||||
|
||||
# Try to find credential from user or system
|
||||
credential_found = False
|
||||
|
||||
# First check user credentials
|
||||
if provider_name and provider_name in user_credentials:
|
||||
logger.info(
|
||||
f"Using user credential for {provider_name} (key: {agg_key})"
|
||||
)
|
||||
user_cred = user_credentials[provider_name]
|
||||
# Use the provider_enum we already extracted from the frozenset
|
||||
if (
|
||||
isinstance(provider_set, (set, frozenset))
|
||||
and len(provider_set) > 0
|
||||
):
|
||||
provider_enum = next(iter(provider_set))
|
||||
input_credentials[agg_key] = CredentialsMetaInput(
|
||||
id=user_cred.id,
|
||||
provider=provider_enum,
|
||||
type=(
|
||||
user_cred.type
|
||||
if hasattr(user_cred, "type")
|
||||
else "api_key"
|
||||
),
|
||||
)
|
||||
credential_found = True
|
||||
logger.info(
|
||||
f"Added user credential to input_credentials[{agg_key}]"
|
||||
)
|
||||
|
||||
# If not found in user creds, check system credentials
|
||||
if (
|
||||
not credential_found
|
||||
and provider_name
|
||||
and provider_name in system_credentials
|
||||
):
|
||||
logger.info(
|
||||
f"Using system credential for {provider_name} (key: {agg_key})"
|
||||
)
|
||||
# Use the provider_enum we already extracted from the frozenset
|
||||
if (
|
||||
isinstance(provider_set, (set, frozenset))
|
||||
and len(provider_set) > 0
|
||||
):
|
||||
provider_enum = next(iter(provider_set))
|
||||
input_credentials[agg_key] = CredentialsMetaInput(
|
||||
id=f"system-{provider_name}",
|
||||
provider=provider_enum,
|
||||
type="api_key",
|
||||
)
|
||||
credential_found = True
|
||||
logger.info(
|
||||
f"Added system credential to input_credentials[{agg_key}]"
|
||||
)
|
||||
|
||||
if not credential_found:
|
||||
logger.warning(
|
||||
f"Could not find credential for {agg_key} (provider: {provider_name}) in user or system stores"
|
||||
)
|
||||
|
||||
# Check if the graph needs inputs that weren't provided
|
||||
if hasattr(graph, "input_schema") and graph.input_schema:
|
||||
required_inputs = []
|
||||
optional_inputs = []
|
||||
|
||||
# Parse the input schema
|
||||
input_schema = graph.input_schema
|
||||
if isinstance(input_schema, dict):
|
||||
properties = input_schema.get("properties", {})
|
||||
required = input_schema.get("required", [])
|
||||
|
||||
for key, schema in properties.items():
|
||||
if key not in inputs:
|
||||
input_info = {
|
||||
"name": key,
|
||||
"type": schema.get("type", "string"),
|
||||
"description": schema.get("description", ""),
|
||||
}
|
||||
|
||||
if key in required:
|
||||
required_inputs.append(input_info)
|
||||
else:
|
||||
optional_inputs.append(input_info)
|
||||
|
||||
# If there are required inputs missing, return an error
|
||||
if required_inputs:
|
||||
return ValidationErrorResponse(
|
||||
message="Missing required inputs for agent execution",
|
||||
session_id=session_id,
|
||||
error="Missing required inputs",
|
||||
details={
|
||||
"missing_inputs": required_inputs,
|
||||
"optional_inputs": optional_inputs,
|
||||
},
|
||||
)
|
||||
|
||||
# Execute the graph
|
||||
logger.info(
|
||||
f"Executing agent {graph.name} (ID: {graph.id}) for user {user_id}"
|
||||
)
|
||||
logger.info(
|
||||
f"Final credentials being passed: {list(input_credentials.keys())}"
|
||||
)
|
||||
for key, cred in input_credentials.items():
|
||||
logger.debug(
|
||||
f" {key}: id={cred.id}, provider={cred.provider}, type={cred.type}"
|
||||
)
|
||||
|
||||
graph_exec = await execution_utils.add_graph_execution(
|
||||
graph_id=graph.id,
|
||||
user_id=user_id,
|
||||
inputs=inputs,
|
||||
graph_version=graph.version,
|
||||
graph_credentials_inputs=input_credentials,
|
||||
)
|
||||
|
||||
result = ExecutionStartedResponse(
|
||||
message=f"Agent '{graph.name}' execution started",
|
||||
execution_id=graph_exec.id,
|
||||
graph_id=graph.id,
|
||||
graph_name=graph.name,
|
||||
status="QUEUED",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Optionally wait for completion (with timeout)
|
||||
if wait_for_result:
|
||||
logger.info(f"Waiting for execution {graph_exec.id} to complete...")
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
timeout = 30 # 30 seconds max wait
|
||||
|
||||
while asyncio.get_event_loop().time() - start_time < timeout:
|
||||
# Get execution status
|
||||
exec_status = await get_graph_execution_meta(user_id, graph_exec.id)
|
||||
|
||||
if exec_status and exec_status.status in [
|
||||
"COMPLETED",
|
||||
"FAILED",
|
||||
]:
|
||||
result.status = exec_status.status
|
||||
result.ended_at = (
|
||||
exec_status.ended_at.isoformat()
|
||||
if exec_status.ended_at
|
||||
else None
|
||||
)
|
||||
|
||||
if exec_status.status == "COMPLETED":
|
||||
result.message = "Agent completed successfully"
|
||||
|
||||
# Try to get outputs
|
||||
try:
|
||||
full_exec = await get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=graph_exec.id,
|
||||
include_node_executions=True,
|
||||
)
|
||||
if (
|
||||
full_exec
|
||||
and hasattr(full_exec, "output_data")
|
||||
and full_exec.output_data # type: ignore
|
||||
):
|
||||
result.outputs = full_exec.output_data # type: ignore
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get execution outputs: {e}")
|
||||
else:
|
||||
result.message = "Agent execution failed"
|
||||
if (
|
||||
hasattr(exec_status, "stats")
|
||||
and exec_status.stats
|
||||
and hasattr(exec_status.stats, "error")
|
||||
):
|
||||
result.error = exec_status.stats.error
|
||||
break
|
||||
|
||||
# Wait before checking again
|
||||
await asyncio.sleep(2)
|
||||
else:
|
||||
# Timeout reached
|
||||
result.status = "RUNNING"
|
||||
result.message = "Execution still running. Check status later."
|
||||
result.timeout_reached = True
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing agent: {e}", exc_info=True)
|
||||
|
||||
# Check for specific error types
|
||||
if "validation" in str(e).lower():
|
||||
return ValidationErrorResponse(
|
||||
message="Input validation failed",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return ErrorResponse(
|
||||
message=f"Failed to execute agent: {e!s}",
|
||||
session_id=session_id,
|
||||
)
|
||||
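
Below, a hypothetical invocation sketch for the tool above, following the inputs contract its description spells out; the user, session, and agent identifiers are illustrative, and it assumes an async caller:

# Hypothetical call site (inside an async function) -- not part of the diff.
tool = RunAgentTool()
response = await tool._execute(
    user_id="user-0000",
    session_id="session-0000",
    agent_id="creator/example-agent",  # marketplace slug or a graph ID
    inputs={"search_query": "user's query", "max_results": 10},  # ALL required inputs
    wait_for_result=True,  # poll up to ~30s for a terminal status
)
print(response.type, response.message)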
@@ -0,0 +1,584 @@
"""Tool for setting up an agent with credentials and configuration."""

import logging
import os
from datetime import datetime, timedelta
from typing import Any

import pytz
from apscheduler.triggers.cron import CronTrigger
from pydantic import SecretStr

from backend.data import graph as graph_db
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
from backend.integrations.webhooks.utils import setup_webhook_for_block
from backend.sdk.registry import AutoRegistry
from backend.server.v2.library import db as library_db
from backend.util.clients import get_scheduler_client

from .base import BaseTool
from .models import (
    ErrorResponse,
    PresetCreatedResponse,
    ScheduleCreatedResponse,
    ToolResponseBase,
    WebhookCreatedResponse,
)

logger = logging.getLogger(__name__)


class SetupAgentTool(BaseTool):
|
||||
"""Tool for setting up an agent with scheduled execution or webhook triggers."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "setup_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return """Set up an agent with credentials and configure it for scheduled execution or webhook triggers.
|
||||
IMPORTANT: Before calling this tool, you MUST first call get_agent_details to determine what inputs are required.
|
||||
|
||||
For SCHEDULED execution:
|
||||
- Cron format: "minute hour day month weekday" (e.g., "0 9 * * 1-5" = 9am weekdays)
|
||||
- Common patterns: "0 * * * *" (hourly), "0 0 * * *" (daily at midnight), "0 9 * * 1" (Mondays at 9am)
|
||||
- Timezone: Use IANA timezone names like "America/New_York", "Europe/London", "Asia/Tokyo"
|
||||
- The 'inputs' parameter must contain ALL required inputs from get_agent_details as a dictionary
|
||||
|
||||
For WEBHOOK triggers:
|
||||
- The agent will be triggered by external events
|
||||
- Still requires all input values from get_agent_details"""
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": "The agent ID (graph ID) to set up",
|
||||
},
|
||||
"setup_type": {
|
||||
"type": "string",
|
||||
"enum": ["schedule", "webhook", "preset"],
|
||||
"description": "Type of setup: 'schedule' for cron, 'webhook' for triggers, 'preset' for saved configuration",
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name for this setup/schedule (e.g., 'Daily Report', 'Weekly Summary')",
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Description of this setup",
|
||||
},
|
||||
"cron": {
|
||||
"type": "string",
|
||||
"description": "Cron expression (5 fields: minute hour day month weekday). Examples: '0 9 * * 1-5' (9am weekdays), '*/30 * * * *' (every 30 min)",
|
||||
},
|
||||
"timezone": {
|
||||
"type": "string",
|
||||
"description": "IANA timezone (e.g., 'America/New_York', 'Europe/London', 'UTC'). Defaults to UTC if not specified.",
|
||||
},
|
||||
"inputs": {
|
||||
"type": "object",
|
||||
"description": 'REQUIRED: Dictionary with ALL required inputs from get_agent_details. Format: {"input_name": value}',
|
||||
"additionalProperties": True,
|
||||
},
|
||||
"credentials": {
|
||||
"type": "object",
|
||||
"description": "Credentials configuration",
|
||||
"additionalProperties": True,
|
||||
},
|
||||
"webhook_config": {
|
||||
"type": "object",
|
||||
"description": "Webhook configuration (required if setup_type is 'webhook')",
|
||||
"additionalProperties": True,
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "setup_type"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
"""This tool requires authentication."""
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Set up an agent with configuration.
|
||||
|
||||
Args:
|
||||
user_id: Authenticated user ID
|
||||
session_id: Chat session ID
|
||||
**kwargs: Setup parameters
|
||||
|
||||
Returns:
|
||||
JSON formatted setup result
|
||||
|
||||
"""
|
||||
# Check if user is authenticated (required for setting up agents)
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required to set up agents",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
setup_type = kwargs.get("setup_type", "").strip()
|
||||
name = kwargs.get("name", f"Setup for {agent_id}")
|
||||
description = kwargs.get("description", "")
|
||||
inputs = kwargs.get("inputs", {})
|
||||
credentials = kwargs.get("credentials", {})
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide an agent ID",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not setup_type:
|
||||
return ErrorResponse(
|
||||
message="Please specify setup type: 'schedule', 'webhook', or 'preset'",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# Check if agent_id looks like a marketplace slug
|
||||
graph = None
|
||||
marketplace_graph = None
|
||||
|
||||
if "/" in agent_id:
|
||||
# Looks like a marketplace slug, try to get from store first
|
||||
from backend.server.v2.store import db as store_db
|
||||
|
||||
try:
|
||||
username, agent_name = agent_id.split("/", 1)
|
||||
agent_details = await store_db.get_store_agent_details(
|
||||
username, agent_name
|
||||
)
|
||||
if agent_details:
|
||||
# Get the graph from the store listing version
|
||||
graph_meta = await store_db.get_available_graph(
|
||||
agent_details.store_listing_version_id
|
||||
)
|
||||
marketplace_graph = await graph_db.get_graph(
|
||||
graph_id=graph_meta.id,
|
||||
version=graph_meta.version,
|
||||
user_id=None, # Public access
|
||||
include_subgraphs=True,
|
||||
)
|
||||
logger.info(f"Found marketplace agent by slug: {agent_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to get agent by slug: {e}")
|
||||
|
||||
# If we have a marketplace graph from the slug lookup, handle it
|
||||
if marketplace_graph:
|
||||
# Check if already in user's library
|
||||
library_agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id=user_id,
|
||||
graph_id=marketplace_graph.id,
|
||||
graph_version=marketplace_graph.version,
|
||||
)
|
||||
|
||||
if library_agent:
|
||||
logger.info(
|
||||
f"Agent {agent_id} already in user library, using existing entry"
|
||||
)
|
||||
# Get the graph from the library agent
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=library_agent.graph_id,
|
||||
version=library_agent.graph_version,
|
||||
user_id=user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
else:
|
||||
logger.info(f"Adding marketplace agent {agent_id} to user library")
|
||||
await library_db.create_library_agent(
|
||||
graph=marketplace_graph,
|
||||
user_id=user_id,
|
||||
create_library_agents_for_sub_graphs=True,
|
||||
)
|
||||
graph = marketplace_graph
|
||||
else:
|
||||
# Not found via slug, try as direct graph ID
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=agent_id,
|
||||
version=None, # Use latest
|
||||
user_id=user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
|
||||
if not graph:
|
||||
# Try as marketplace agent by ID
|
||||
marketplace_graph = await graph_db.get_graph(
|
||||
graph_id=agent_id,
|
||||
version=None,
|
||||
user_id=None, # Public access
|
||||
include_subgraphs=True,
|
||||
)
|
||||
|
||||
if marketplace_graph:
|
||||
# Check if already in user's library
|
||||
library_agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id=user_id,
|
||||
graph_id=marketplace_graph.id,
|
||||
graph_version=marketplace_graph.version,
|
||||
)
|
||||
|
||||
if library_agent:
|
||||
logger.info(
|
||||
f"Agent {agent_id} already in user library, using existing entry"
|
||||
)
|
||||
# Get the graph from the library agent
|
||||
graph = await graph_db.get_graph(
|
||||
graph_id=library_agent.graph_id,
|
||||
version=library_agent.graph_version,
|
||||
user_id=user_id,
|
||||
include_subgraphs=True,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Adding marketplace agent {agent_id} to user library"
|
||||
)
|
||||
await library_db.create_library_agent(
|
||||
graph=marketplace_graph,
|
||||
user_id=user_id,
|
||||
create_library_agents_for_sub_graphs=True,
|
||||
)
|
||||
graph = marketplace_graph
|
||||
|
||||
if not graph:
|
||||
return ErrorResponse(
|
||||
message=f"Agent '{agent_id}' not found",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Get system-provided credentials
|
||||
system_credentials = {}
|
||||
try:
|
||||
# Get SDK-registered credentials
|
||||
system_creds_list = AutoRegistry.get_all_credentials()
|
||||
for cred in system_creds_list:
|
||||
system_credentials[cred.provider] = cred
|
||||
|
||||
# System credentials never expire - set to far future (Unix timestamp)
|
||||
expires_at = int(
|
||||
(datetime.utcnow() + timedelta(days=36500)).timestamp()
|
||||
)
|
||||
|
||||
# Check for OpenAI
|
||||
if "openai" not in system_credentials:
|
||||
openai_key = os.getenv("OPENAI_API_KEY")
|
||||
if openai_key:
|
||||
system_credentials["openai"] = APIKeyCredentials(
|
||||
id="system-openai",
|
||||
provider="openai",
|
||||
api_key=SecretStr(openai_key),
|
||||
title="System OpenAI API Key",
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
# Check for Anthropic
|
||||
if "anthropic" not in system_credentials:
|
||||
anthropic_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if anthropic_key:
|
||||
system_credentials["anthropic"] = APIKeyCredentials(
|
||||
id="system-anthropic",
|
||||
provider="anthropic",
|
||||
api_key=SecretStr(anthropic_key),
|
||||
title="System Anthropic API Key",
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
# Check for other common providers
|
||||
for provider, env_var in [
|
||||
("groq", "GROQ_API_KEY"),
|
||||
("ollama", "OLLAMA_API_KEY"),
|
||||
("open_router", "OPEN_ROUTER_API_KEY"),
|
||||
]:
|
||||
if provider not in system_credentials:
|
||||
api_key = os.getenv(env_var)
|
||||
if api_key:
|
||||
system_credentials[provider] = APIKeyCredentials(
|
||||
id=f"system-{provider}",
|
||||
provider=provider,
|
||||
api_key=SecretStr(api_key),
|
||||
title=f"System {provider} API Key",
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"System provides credentials for: {list(system_credentials.keys())}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get system credentials: {e}")
|
||||
|
||||
# Get user credentials if authenticated
|
||||
user_credentials = {}
|
||||
try:
|
||||
from backend.integrations.creds_manager import (
|
||||
IntegrationCredentialsManager,
|
||||
)
|
||||
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
user_creds_list = await creds_manager.store.get_all_creds(user_id)
|
||||
for cred in user_creds_list:
|
||||
user_credentials[cred.provider] = cred
|
||||
logger.info(
|
||||
f"User has credentials for: {list(user_credentials.keys())}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get user credentials: {e}")
|
||||
|
||||
# Convert provided credentials to CredentialsMetaInput format
|
||||
input_credentials = {}
|
||||
for key, value in credentials.items():
|
||||
if isinstance(value, dict):
|
||||
input_credentials[key] = CredentialsMetaInput(**value)
|
||||
elif isinstance(value, str):
|
||||
# Assume it's a credential ID
|
||||
input_credentials[key] = CredentialsMetaInput(
|
||||
id=value,
|
||||
provider=key, # Use the key as provider name
|
||||
type="api_key", # Default type
|
||||
)
|
||||
|
||||
# Use the graph's aggregated credentials to properly map credentials
|
||||
# This ensures we use the same keys that the graph expects
|
||||
graph_cred_inputs = graph.aggregate_credentials_inputs()
|
||||
logger.info(
|
||||
f"Graph aggregate credentials: {list(graph_cred_inputs.keys())}"
|
||||
)
|
||||
logger.info(f"User provided credentials: {list(input_credentials.keys())}")
|
||||
|
||||
# Process each aggregated credential field
|
||||
for agg_key, (field_info, node_fields) in graph_cred_inputs.items():
|
||||
if agg_key not in input_credentials:
|
||||
# Extract provider from field_info (it's a frozenset, get the first element)
|
||||
provider_set = field_info.provider
|
||||
if (
|
||||
isinstance(provider_set, (set, frozenset))
|
||||
and len(provider_set) > 0
|
||||
):
|
||||
# Get the first provider from the set
|
||||
provider_enum = next(iter(provider_set))
|
||||
# Get the string value from the enum
|
||||
provider_name = (
|
||||
provider_enum.value
|
||||
if hasattr(provider_enum, "value")
|
||||
else str(provider_enum)
|
||||
)
|
||||
else:
|
||||
provider_name = str(provider_set) if provider_set else None
|
||||
|
||||
logger.info(
|
||||
f"Checking credential {agg_key} for provider {provider_name}"
|
||||
)
|
||||
|
||||
# Try to find credential from user or system
|
||||
credential_found = False
|
||||
|
||||
# First check user credentials
|
||||
if provider_name and provider_name in user_credentials:
|
||||
logger.info(
|
||||
f"Using user credential for {provider_name} (key: {agg_key})"
|
||||
)
|
||||
user_cred = user_credentials[provider_name]
|
||||
# Use the provider_enum we already extracted from the frozenset
|
||||
if (
|
||||
isinstance(provider_set, (set, frozenset))
|
||||
and len(provider_set) > 0
|
||||
):
|
||||
provider_enum = next(iter(provider_set))
|
||||
input_credentials[agg_key] = CredentialsMetaInput(
|
||||
id=user_cred.id,
|
||||
provider=provider_enum,
|
||||
type=(
|
||||
user_cred.type
|
||||
if hasattr(user_cred, "type")
|
||||
else "api_key"
|
||||
),
|
||||
)
|
||||
credential_found = True
|
||||
logger.info(
|
||||
f"Added user credential to input_credentials[{agg_key}]"
|
||||
)
|
||||
|
||||
# If not found in user creds, check system credentials
|
||||
if (
|
||||
not credential_found
|
||||
and provider_name
|
||||
and provider_name in system_credentials
|
||||
):
|
||||
logger.info(
|
||||
f"Using system credential for {provider_name} (key: {agg_key})"
|
||||
)
|
||||
# Use the provider_enum we already extracted from the frozenset
|
||||
if (
|
||||
isinstance(provider_set, (set, frozenset))
|
||||
and len(provider_set) > 0
|
||||
):
|
||||
provider_enum = next(iter(provider_set))
|
||||
input_credentials[agg_key] = CredentialsMetaInput(
|
||||
id=f"system-{provider_name}",
|
||||
provider=provider_enum,
|
||||
type="api_key",
|
||||
)
|
||||
credential_found = True
|
||||
logger.info(
|
||||
f"Added system credential to input_credentials[{agg_key}]"
|
||||
)
|
||||
|
||||
if not credential_found:
|
||||
logger.warning(
|
||||
f"Could not find credential for {agg_key} (provider: {provider_name}) in user or system stores"
|
||||
)
|
||||
|
||||
result = {}
|
||||
|
||||
if setup_type == "schedule":
|
||||
# Set up scheduled execution
|
||||
cron = kwargs.get("cron")
|
||||
if not cron:
|
||||
return ErrorResponse(
|
||||
message="Cron expression is required for scheduled execution",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Validate cron expression
|
||||
try:
|
||||
CronTrigger.from_crontab(cron)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Invalid cron expression '{cron}': {e!s}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Convert timezone if provided
|
||||
timezone = kwargs.get("timezone", "UTC")
|
||||
try:
|
||||
pytz.timezone(timezone)
|
||||
except Exception:
|
||||
return ErrorResponse(
|
||||
message=f"Invalid timezone '{timezone}'",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Create schedule via scheduler client
|
||||
schedule_info = await get_scheduler_client().add_execution_schedule(
|
||||
user_id=user_id,
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
cron=cron,
|
||||
input_data=inputs,
|
||||
input_credentials=input_credentials,
|
||||
name=name,
|
||||
)
|
||||
|
||||
result = ScheduleCreatedResponse(
|
||||
message=f"Schedule '{name}' created successfully",
|
||||
schedule_id=schedule_info.id,
|
||||
name=name,
|
||||
cron=cron,
|
||||
timezone=timezone,
|
||||
next_run=schedule_info.next_run_time,
|
||||
graph_id=graph.id,
|
||||
graph_name=graph.name,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
elif setup_type == "webhook":
|
||||
# Set up webhook trigger
|
||||
if not graph.webhook_input_node:
|
||||
return ErrorResponse(
|
||||
message=f"Agent '{graph.name}' does not support webhook triggers",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
webhook_config = kwargs.get("webhook_config", {})
|
||||
|
||||
# Combine webhook config with credentials
|
||||
trigger_config = {**webhook_config, **input_credentials}
|
||||
|
||||
# Set up webhook
|
||||
new_webhook, feedback = await setup_webhook_for_block(
|
||||
user_id=user_id,
|
||||
trigger_block=graph.webhook_input_node.block,
|
||||
trigger_config=trigger_config,
|
||||
)
|
||||
|
||||
if not new_webhook:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to create webhook: {feedback}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Create preset with webhook
|
||||
preset = await library_db.create_preset(
|
||||
user_id=user_id,
|
||||
preset={ # type: ignore
|
||||
"graph_id": graph.id,
|
||||
"graph_version": graph.version,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"inputs": inputs,
|
||||
"credentials": input_credentials,
|
||||
"webhook_id": new_webhook.id,
|
||||
"is_active": True,
|
||||
},
|
||||
)
|
||||
|
||||
result = WebhookCreatedResponse(
|
||||
message=f"Webhook trigger '{name}' created successfully",
|
||||
webhook_id=new_webhook.id,
|
||||
webhook_url=new_webhook.webhook_url, # type: ignore
|
||||
preset_id=preset.id,
|
||||
name=name,
|
||||
graph_id=graph.id,
|
||||
graph_name=graph.name,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
elif setup_type == "preset":
|
||||
# Create a preset configuration for manual execution
|
||||
preset = await library_db.create_preset(
|
||||
user_id=user_id,
|
||||
preset={ # type: ignore
|
||||
"graph_id": graph.id,
|
||||
"graph_version": graph.version,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"inputs": inputs,
|
||||
"credentials": input_credentials,
|
||||
"is_active": True,
|
||||
},
|
||||
)
|
||||
|
||||
result = PresetCreatedResponse(
|
||||
message=f"Preset configuration '{name}' created successfully",
|
||||
preset_id=preset.id,
|
||||
name=name,
|
||||
graph_id=graph.id,
|
||||
graph_name=graph.name,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
else:
|
||||
return ErrorResponse(
|
||||
message=f"Unknown setup type: {setup_type}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting up agent: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to set up agent: {e!s}",
|
||||
session_id=session_id,
|
||||
)
|
||||
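
As a quick reference, a hypothetical standalone sketch of the schedule validation performed in the 'schedule' branch above (same CronTrigger and pytz calls; the example values are illustrative):

# Hypothetical validation sketch -- not part of the diff.
import pytz
from apscheduler.triggers.cron import CronTrigger

cron = "0 9 * * 1-5"           # 9am on weekdays
tz_name = "America/New_York"   # IANA timezone name

CronTrigger.from_crontab(cron)  # raises ValueError for a malformed expression
pytz.timezone(tz_name)          # raises UnknownTimeZoneError for an unknown name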
@@ -1,5 +1,6 @@
import asyncio
import logging
import os
from datetime import datetime, timezone

import fastapi
@@ -25,11 +26,29 @@ from backend.data.notifications import (
    NotificationEventModel,
)
from backend.notifications.notifications import queue_notification_async
from backend.server.v2.store.embeddings import SearchFieldType, StoreAgentSearchService
from backend.server.v2.store.embeddings import (
    SubmissionStatus as SearchSubmissionStatus,
)
from backend.server.v2.store.embeddings import create_embedding
from backend.util.settings import Settings

logger = logging.getLogger(__name__)
settings = Settings()

# Initialize the search service with database URL
search_service: StoreAgentSearchService | None = None


def get_search_service() -> StoreAgentSearchService:
    """Get or create the search service instance"""
    global search_service
    if search_service is None:
        # Get database URL from environment variable (same as Prisma uses)
        db_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
        search_service = StoreAgentSearchService(db_url)
    return search_service


# Constants for default admin values
DEFAULT_ADMIN_NAME = "AutoGPT Admin"
@@ -55,6 +74,114 @@ def sanitize_query(query: str | None) -> str | None:
|
||||
)
|
||||
|
||||
|
||||
async def search_store_agents(
|
||||
search_query: str,
|
||||
limit: int = 30,
|
||||
) -> backend.server.v2.store.model.StoreAgentsResponse:
|
||||
"""
|
||||
Search for store agents using embeddings with SQLAlchemy.
|
||||
Falls back to text search if embedding creation fails.
|
||||
"""
|
||||
try:
|
||||
# Try to create embedding for semantic search
|
||||
query_embedding = await create_embedding(search_query)
|
||||
|
||||
if query_embedding is None:
|
||||
# Fallback to text-based search if embedding fails
|
||||
logger.warning(
|
||||
f"Failed to create embedding for query: {search_query}. "
|
||||
"Falling back to text search."
|
||||
)
|
||||
return await get_store_agents(
|
||||
search_query=search_query,
|
||||
page=1,
|
||||
page_size=limit,
|
||||
)
|
||||
|
||||
# Use SQLAlchemy service for vector search
|
||||
service = get_search_service()
|
||||
results = await service.search_by_embedding(query_embedding, limit=limit)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during vector search: {e}. Falling back to text search.")
|
||||
# Fallback to regular text search on any error
|
||||
return await get_store_agents(
|
||||
search_query=search_query,
|
||||
page=1,
|
||||
page_size=30,
|
||||
)
|
||||
|
||||
# Convert raw results to StoreAgent models
|
||||
agents = []
|
||||
for row in results:
|
||||
try:
|
||||
# Handle agent_image - it could be a list or a single string
|
||||
agent_image = row.get("agent_image", "")
|
||||
if isinstance(agent_image, list) and agent_image:
|
||||
agent_image = str(agent_image[0])
|
||||
elif not agent_image:
|
||||
agent_image = ""
|
||||
else:
|
||||
agent_image = str(agent_image)
|
||||
|
||||
agent = backend.server.v2.store.model.StoreAgent(
|
||||
slug=row.get("slug", ""),
|
||||
agent_name=row.get("agent_name", ""),
|
||||
agent_image=agent_image,
|
||||
creator=row.get("creator_username") or "Needs Profile",
|
||||
creator_avatar=row.get("creator_avatar") or "",
|
||||
sub_heading=row.get("sub_heading", ""),
|
||||
description=row.get("description", ""),
|
||||
runs=row.get("runs", 0),
|
||||
rating=(
|
||||
float(row.get("rating", 0.0))
|
||||
if row.get("rating") is not None
|
||||
else 0.0
|
||||
),
|
||||
)
|
||||
agents.append(agent)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating StoreAgent from search result: {e}")
|
||||
continue
|
||||
|
||||
return backend.server.v2.store.model.StoreAgentsResponse(
|
||||
agents=agents,
|
||||
pagination=backend.server.v2.store.model.Pagination(
|
||||
current_page=1,
|
||||
total_items=len(agents),
|
||||
total_pages=1,
|
||||
page_size=20,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def search_agents(
|
||||
search_query: str,
|
||||
featured: bool | None = None,
|
||||
creators: list[str] | None = None,
|
||||
category: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> backend.server.v2.store.model.StoreAgentsResponse:
|
||||
"""
|
||||
Search for store agents using embeddings with optional filters.
|
||||
Falls back to text search if embedding service is unavailable.
|
||||
"""
|
||||
try:
|
||||
# Try vector search first
|
||||
return await search_store_agents(search_query)
|
||||
except Exception as e:
|
||||
logger.error(f"Vector search failed: {e}. Using text search fallback.")
|
||||
# Fallback to text search with filters
|
||||
return await get_store_agents(
|
||||
featured=featured or False,
|
||||
creators=creators,
|
||||
search_query=search_query,
|
||||
category=category,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
async def get_store_agents(
|
||||
featured: bool = False,
|
||||
creators: list[str] | None = None,
|
||||
@@ -123,7 +250,6 @@ async def get_store_agents(
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
# Skip this agent if there was an error
|
||||
# You could log the error here if needed
|
||||
logger.error(
|
||||
f"Error parsing Store agent when getting store agents from db: {e}"
|
||||
)
|
||||
@@ -1529,6 +1655,64 @@ async def review_store_submission(
|
||||
f"Failed to update store listing version {store_listing_version_id}"
|
||||
)
|
||||
|
||||
# Create embeddings if approved
|
||||
if is_approved and submission.StoreListing:
|
||||
try:
|
||||
service = get_search_service()
|
||||
|
||||
# Create embeddings for the approved listing
|
||||
fields_to_embed = [
|
||||
("name", submission.name, SearchFieldType.NAME),
|
||||
(
|
||||
"description",
|
||||
submission.description,
|
||||
SearchFieldType.DESCRIPTION,
|
||||
),
|
||||
]
|
||||
|
||||
if submission.subHeading:
|
||||
fields_to_embed.append(
|
||||
(
|
||||
"subHeading",
|
||||
submission.subHeading,
|
||||
SearchFieldType.SUBHEADING,
|
||||
)
|
||||
)
|
||||
|
||||
if submission.categories:
|
||||
categories_text = ", ".join(submission.categories)
|
||||
fields_to_embed.append(
|
||||
("categories", categories_text, SearchFieldType.CATEGORIES)
|
||||
)
|
||||
|
||||
for field_name, field_value, field_type in fields_to_embed:
|
||||
# Create embedding asynchronously
|
||||
embedding = await create_embedding(field_value)
|
||||
|
||||
# Only store if embedding was created successfully
|
||||
if embedding:
|
||||
await service.upsert_search_record(
|
||||
store_listing_version_id=store_listing_version_id,
|
||||
store_listing_id=submission.StoreListing.id,
|
||||
field_name=field_name,
|
||||
field_value=field_value,
|
||||
embedding=embedding,
|
||||
field_type=field_type,
|
||||
submission_status=SearchSubmissionStatus.APPROVED,
|
||||
is_available=submission.isAvailable,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to create embedding for {field_name} in listing {store_listing_version_id}. "
|
||||
"Search indexing skipped for this field."
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but don't fail the approval process
|
||||
logger.error(
|
||||
f"Error creating search embeddings for listing {store_listing_version_id}: {e}. "
|
||||
"Approval will continue without search indexing."
|
||||
)
|
||||
|
||||
# Send email notification to the agent creator
|
||||
if store_listing_version.AgentGraph and store_listing_version.AgentGraph.User:
|
||||
agent_creator = store_listing_version.AgentGraph.User
|
||||
|
||||
autogpt_platform/backend/backend/server/v2/store/embeddings.py (new file, 469 lines)
@@ -0,0 +1,469 @@
"""Store search functionality with embeddings and pgvector using SQLAlchemy"""

import logging
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence
from uuid import uuid4

from openai import AsyncOpenAI, OpenAIError
from pgvector.sqlalchemy import Vector
from sqlalchemy import Boolean, Column, DateTime
from sqlalchemy import Enum as SQLEnum
from sqlalchemy import Index, String, UniqueConstraint, and_, select, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.declarative import declarative_base

from backend.util.settings import Settings

logger = logging.getLogger(__name__)
settings = Settings()

# Initialize async OpenAI client with proper configuration
_openai_client: Optional[AsyncOpenAI] = None


def get_openai_client() -> AsyncOpenAI:
|
||||
"""Get or create the async OpenAI client with proper configuration."""
|
||||
global _openai_client
|
||||
if _openai_client is None:
|
||||
api_key = settings.secrets.openai_api_key
|
||||
if not api_key:
|
||||
logger.warning(
|
||||
"OpenAI API key not configured. Vector search will use fallback text search."
|
||||
)
|
||||
raise ValueError(
|
||||
"OpenAI API key is not configured. Please set OPENAI_API_KEY in environment."
|
||||
)
|
||||
_openai_client = AsyncOpenAI(api_key=api_key)
|
||||
return _openai_client
|
||||
|
||||
|
||||
async def create_embedding(text: str) -> Optional[list[float]]:
|
||||
"""Create an embedding for the given text using OpenAI's API.
|
||||
|
||||
Args:
|
||||
text: The text to create an embedding for
|
||||
|
||||
Returns:
|
||||
A list of floats representing the embedding, or None if creation fails
|
||||
"""
|
||||
try:
|
||||
client = get_openai_client()
|
||||
response = await client.embeddings.create(
|
||||
input=text,
|
||||
model="text-embedding-3-small",
|
||||
)
|
||||
return response.data[0].embedding
|
||||
except ValueError as e:
|
||||
# API key not configured
|
||||
logger.error(f"OpenAI configuration error: {e}")
|
||||
return None
|
||||
except OpenAIError as e:
|
||||
# Handle specific OpenAI errors
|
||||
logger.error(f"OpenAI API error creating embedding: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
# Handle unexpected errors
|
||||
logger.error(f"Unexpected error creating embedding: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# SQLAlchemy models
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class SubmissionStatus(str, Enum):
|
||||
PENDING = "PENDING"
|
||||
APPROVED = "APPROVED"
|
||||
REJECTED = "REJECTED"
|
||||
|
||||
|
||||
class SearchFieldType(str, Enum):
|
||||
NAME = "NAME"
|
||||
DESCRIPTION = "DESCRIPTION"
|
||||
CATEGORIES = "CATEGORIES"
|
||||
SUBHEADING = "SUBHEADING"
|
||||
|
||||
|
||||
class StoreAgentSearch(Base):
|
||||
__tablename__ = "StoreAgentSearch"
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid4()))
|
||||
createdAt = Column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
|
||||
)
|
||||
updatedAt = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Relations (foreign keys exist in DB but not modeled here)
|
||||
storeListingVersionId = Column(String, nullable=False)
|
||||
storeListingId = Column(String, nullable=False)
|
||||
|
||||
# Searchable fields
|
||||
fieldName = Column(String, nullable=False)
|
||||
fieldValue = Column(String, nullable=False)
|
||||
|
||||
# Vector embedding for similarity search
|
||||
# text-embedding-3-small produces 1536-dimensional embeddings
|
||||
embedding = Column(Vector(1536), nullable=False)
|
||||
|
||||
# Metadata
|
||||
fieldType = Column(
|
||||
SQLEnum(SearchFieldType, name="SearchFieldType", schema="platform"),
|
||||
nullable=False,
|
||||
)
|
||||
submissionStatus = Column(
|
||||
SQLEnum(SubmissionStatus, name="SubmissionStatus", schema="platform"),
|
||||
nullable=False,
|
||||
)
|
||||
isAvailable = Column(Boolean, nullable=False)
|
||||
|
||||
# Constraints and schema
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"storeListingVersionId",
|
||||
"fieldName",
|
||||
name="_store_listing_version_field_unique",
|
||||
),
|
||||
Index("ix_store_agent_search_listing_version", "storeListingVersionId"),
|
||||
Index("ix_store_agent_search_listing", "storeListingId"),
|
||||
Index("ix_store_agent_search_field_name", "fieldName"),
|
||||
Index("ix_store_agent_search_field_type", "fieldType"),
|
||||
Index(
|
||||
"ix_store_agent_search_status_available", "submissionStatus", "isAvailable"
|
||||
),
|
||||
{"schema": "platform"}, # Specify the schema
|
||||
)
|
||||
|
||||
|
||||
class StoreAgentSearchService:
|
||||
"""Service class for Store Agent Search operations using SQLAlchemy"""
|
||||
|
||||
def __init__(self, database_url: str):
|
||||
"""Initialize the search service with async database connection"""
|
||||
# Parse the URL to handle schema and other params separately
|
||||
from urllib.parse import parse_qs, urlparse, urlunparse
|
||||
|
||||
parsed = urlparse(database_url)
|
||||
query_params = parse_qs(parsed.query)
|
||||
|
||||
# Extract schema if present
|
||||
schema = query_params.pop("schema", ["public"])[0]
|
||||
|
||||
# Remove connect_timeout from query params (will be handled in connect_args)
|
||||
connect_timeout = query_params.pop("connect_timeout", [None])[0]
|
||||
|
||||
# Rebuild query string without schema and connect_timeout
|
||||
new_query = "&".join([f"{k}={v[0]}" for k, v in query_params.items()])
|
||||
|
||||
# Rebuild URL without schema and connect_timeout parameters
|
||||
clean_url = urlunparse(
|
||||
(
|
||||
parsed.scheme,
|
||||
parsed.netloc,
|
||||
parsed.path,
|
||||
parsed.params,
|
||||
new_query,
|
||||
parsed.fragment,
|
||||
)
|
||||
)
|
||||
|
||||
# Convert to async URL for asyncpg
|
||||
if clean_url.startswith("postgresql://"):
|
||||
clean_url = clean_url.replace("postgresql://", "postgresql+asyncpg://")
|
||||
elif clean_url.startswith("postgres://"):
|
||||
clean_url = clean_url.replace("postgres://", "postgresql+asyncpg://")
|
||||
|
||||
# Build connect_args
|
||||
connect_args: dict[str, Any] = {"server_settings": {"search_path": schema}}
|
||||
|
||||
# Add timeout if present (asyncpg uses 'timeout' not 'connect_timeout')
|
||||
if connect_timeout:
|
||||
connect_args["timeout"] = float(connect_timeout)
|
||||
|
||||
# Create engine with schema in connect_args
|
||||
self.engine = create_async_engine(
|
||||
clean_url,
|
||||
echo=False, # Set to True for debugging SQL queries
|
||||
future=True,
|
||||
connect_args=connect_args,
|
||||
)
|
||||
|
||||
self.async_session = async_sessionmaker(
|
||||
self.engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async def create_search_record(
|
||||
self,
|
||||
store_listing_version_id: str,
|
||||
store_listing_id: str,
|
||||
field_name: str,
|
||||
field_value: str,
|
||||
embedding: list[float],
|
||||
field_type: SearchFieldType,
|
||||
submission_status: SubmissionStatus,
|
||||
is_available: bool,
|
||||
) -> Optional[StoreAgentSearch]:
|
||||
"""Create a new search record with embedding.
|
||||
|
||||
Returns:
|
||||
The created search record or None if creation fails.
|
||||
"""
|
||||
try:
|
||||
async with self.async_session() as session:
|
||||
search_record = StoreAgentSearch(
|
||||
storeListingVersionId=store_listing_version_id,
|
||||
storeListingId=store_listing_id,
|
||||
fieldName=field_name,
|
||||
fieldValue=field_value,
|
||||
embedding=embedding,
|
||||
fieldType=field_type,
|
||||
submissionStatus=submission_status,
|
||||
isAvailable=is_available,
|
||||
)
|
||||
session.add(search_record)
|
||||
await session.commit()
|
||||
await session.refresh(search_record)
|
||||
return search_record
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create search record: {e}")
|
||||
return None
|
||||
|
||||
    async def batch_create_search_records(
        self, records: list[dict]
    ) -> list[StoreAgentSearch]:
        """Batch create multiple search records"""
        async with self.async_session() as session:
            search_records = []
            for record in records:
                search_record = StoreAgentSearch(
                    storeListingVersionId=record["storeListingVersionId"],
                    storeListingId=record["storeListingId"],
                    fieldName=record["fieldName"],
                    fieldValue=record["fieldValue"],
                    embedding=record["embedding"],
                    fieldType=record.get("fieldType", SearchFieldType.NAME),
                    submissionStatus=record["submissionStatus"],
                    isAvailable=record["isAvailable"],
                )
                session.add(search_record)
                search_records.append(search_record)

            await session.commit()
            return search_records

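A hedged usage sketch for the batch API: one dict per searchable field of a listing version. The ids, field values, and zero vectors below are placeholders; real embeddings would come from the embedding model.

async def index_listing_version(service) -> None:
    records = [
        {
            "storeListingVersionId": "slv-123",  # placeholder id
            "storeListingId": "sl-456",          # placeholder id
            "fieldName": field_name,
            "fieldValue": field_value,
            "fieldType": field_type,
            "embedding": [0.0] * 1536,           # placeholder vector
            "submissionStatus": SubmissionStatus.APPROVED,
            "isAvailable": True,
        }
        for field_name, field_value, field_type in [
            ("name", "Blog Writer Pro", SearchFieldType.NAME),
            ("description", "Generates SEO-optimised blog posts", SearchFieldType.DESCRIPTION),
        ]
    ]
    await service.batch_create_search_records(records)
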
    async def search_by_embedding(
        self, query_embedding: List[float], limit: int = 30
    ) -> List[Dict[str, Any]]:
        """
        Search for store agents using vector similarity.
        Returns the best matching store listings based on embedding similarity.

        Args:
            query_embedding: The embedding vector to search with
            limit: Maximum number of results to return

        Returns:
            List of matching store agents with similarity scores

        Raises:
            Exception: If the database query fails
        """
        if not query_embedding:
            logger.warning("Empty embedding provided for search")
            return []

        try:
            async with self.async_session() as session:
                # Use parameterized query to prevent SQL injection
                query = text(
                    """
                    WITH similarity_scores AS (
                        SELECT
                            sas."storeListingId",
                            MIN(sas.embedding <=> CAST(:embedding AS vector)) AS similarity_score
                        FROM platform."StoreAgentSearch" sas
                        WHERE
                            sas."submissionStatus" = 'APPROVED'
                            AND sas."isAvailable" = true
                        GROUP BY sas."storeListingId"
                        ORDER BY similarity_score
                        LIMIT :limit
                    )
                    SELECT
                        sa.listing_id,
                        sa.slug,
                        sa.agent_name,
                        sa.agent_image,
                        sa.description,
                        sa.sub_heading,
                        sa.featured,
                        sa.runs,
                        sa.rating,
                        sa.creator_username,
                        sa.creator_avatar,
                        ss.similarity_score
                    FROM similarity_scores ss
                    INNER JOIN platform."StoreAgent" sa
                        ON sa.listing_id = ss."storeListingId"
                    ORDER BY ss.similarity_score;
                    """
                )

                # Format embedding as PostgreSQL array safely
                embedding_str = "[" + ",".join(map(str, query_embedding)) + "]"

                result = await session.execute(
                    query, {"embedding": embedding_str, "limit": limit}
                )

                rows = result.fetchall()

                # Convert rows to dictionaries
                return [dict(row._mapping) for row in rows]
        except Exception as e:
            logger.error(f"Vector search query failed: {e}")
            # Return empty results instead of propagating error
            # This allows fallback to text search
            return []

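A hedged sketch of the calling side: embed the user's query text and hand the vector to search_by_embedding. The client usage assumes the openai >= 1.x SDK and the text-embedding-3-small model mentioned in the schema comments (1536 dimensions).

from openai import AsyncOpenAI

async def semantic_search(service, query_text: str, limit: int = 10) -> list[dict]:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    response = await client.embeddings.create(
        model="text-embedding-3-small",
        input=query_text,
    )
    query_embedding = response.data[0].embedding  # list[float] of length 1536
    return await service.search_by_embedding(query_embedding, limit=limit)
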
    async def get_search_records(
        self,
        store_listing_version_id: Optional[str] = None,
        field_name: Optional[str] = None,
        is_available: Optional[bool] = None,
    ) -> Sequence[StoreAgentSearch]:
        """
        Get search records using SQLAlchemy ORM
        """
        async with self.async_session() as session:
            stmt = select(StoreAgentSearch)

            # Build filters
            filters = []
            if store_listing_version_id:
                filters.append(
                    StoreAgentSearch.storeListingVersionId == store_listing_version_id
                )
            if field_name:
                filters.append(StoreAgentSearch.fieldName == field_name)
            if is_available is not None:
                filters.append(StoreAgentSearch.isAvailable == is_available)

            if filters:
                stmt = stmt.where(and_(*filters))

            result = await session.execute(stmt)
            return result.scalars().all()

    async def update_search_embeddings(
        self, store_listing_version_id: str, updates: Dict[str, List[float]]
    ) -> None:
        """Update embeddings for existing search records"""
        async with self.async_session() as session:
            for field_name, embedding in updates.items():
                # For vector updates, we still need raw SQL due to pgvector
                # Use $ parameters for asyncpg
                query = text(
                    """
                    UPDATE platform."StoreAgentSearch"
                    SET embedding = CAST(:embedding AS vector),
                        "updatedAt" = CAST(:updated_at AS TIMESTAMPTZ)
                    WHERE "storeListingVersionId" = :version_id
                        AND "fieldName" = :field_name
                    """
                )

                embedding_str = "[" + ",".join(map(str, embedding)) + "]"

                await session.execute(
                    query,
                    {
                        "embedding": embedding_str,
                        "updated_at": datetime.now(timezone.utc),
                        "version_id": store_listing_version_id,
                        "field_name": field_name,
                    },
                )

            await session.commit()

    async def delete_search_records(self, store_listing_version_id: str) -> None:
        """Delete all search records for a store listing version using SQLAlchemy ORM"""
        async with self.async_session() as session:
            # Use SQLAlchemy ORM for deletion
            stmt = select(StoreAgentSearch).where(
                StoreAgentSearch.storeListingVersionId == store_listing_version_id
            )
            result = await session.execute(stmt)
            records = result.scalars().all()

            for record in records:
                await session.delete(record)

            await session.commit()

    async def upsert_search_record(
        self,
        store_listing_version_id: str,
        store_listing_id: str,
        field_name: str,
        field_value: str,
        embedding: List[float],
        field_type: SearchFieldType,
        submission_status: SubmissionStatus,
        is_available: bool,
    ) -> Optional[StoreAgentSearch]:
        """Upsert a search record (update if exists, create if not).

        Returns:
            The upserted search record or None if operation fails.
        """
        try:
            async with self.async_session() as session:
                # Check if record exists
                stmt = select(StoreAgentSearch).where(
                    and_(
                        StoreAgentSearch.storeListingVersionId
                        == store_listing_version_id,
                        StoreAgentSearch.fieldName == field_name,
                    )
                )
                result = await session.execute(stmt)
                existing_record = result.scalar_one_or_none()

                if existing_record:
                    # Update existing record
                    existing_record.fieldValue = field_value  # type: ignore[attr-defined]
                    existing_record.embedding = embedding  # type: ignore[attr-defined]
                    existing_record.fieldType = field_type  # type: ignore[attr-defined]
                    existing_record.submissionStatus = submission_status  # type: ignore[attr-defined]
                    existing_record.isAvailable = is_available  # type: ignore[attr-defined]
                    existing_record.updatedAt = datetime.now(timezone.utc)  # type: ignore[attr-defined]

                    await session.commit()
                    await session.refresh(existing_record)
                    return existing_record
                else:
                    # Create new record
                    return await self.create_search_record(
                        store_listing_version_id=store_listing_version_id,
                        store_listing_id=store_listing_id,
                        field_name=field_name,
                        field_value=field_value,
                        embedding=embedding,
                        field_type=field_type,
                        submission_status=submission_status,
                        is_available=is_available,
                    )
        except Exception as e:
            logger.error(f"Failed to upsert search record: {e}")
            return None

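A brief hedged sketch of how the upsert would typically be used when a single field of a listing is edited (ids and the zero vector are placeholders), inside an async context:

await service.upsert_search_record(
    store_listing_version_id="slv-123",
    store_listing_id="sl-456",
    field_name="description",
    field_value="Now also drafts newsletters",
    embedding=[0.0] * 1536,  # in practice, the recomputed embedding of the new field value
    field_type=SearchFieldType.DESCRIPTION,
    submission_status=SubmissionStatus.APPROVED,
    is_available=True,
)
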
    async def close(self):
        """Close the database connection"""
        await self.engine.dispose()
@@ -234,7 +234,7 @@ def test_get_agents_search(
            page_size=20,
        ),
    )
    mock_db_call = mocker.patch("backend.server.v2.store.db.get_store_agents")
    mock_db_call = mocker.patch("backend.server.v2.store.db.search_agents")
    mock_db_call.return_value = mocked_value
    response = client.get("/agents?search_query=specific")
    assert response.status_code == 200
@@ -246,10 +246,9 @@ def test_get_agents_search(
    snapshot.snapshot_dir = "snapshots"
    snapshot.assert_match(json.dumps(response.json(), indent=2), "agts_search")
    mock_db_call.assert_called_once_with(
        search_query="specific",
        featured=False,
        creators=None,
        sorted_by=None,
        search_query="specific",
        category=None,
        page=1,
        page_size=20,

@@ -0,0 +1,54 @@
-- CreateEnum
CREATE TYPE "ChatMessageRole" AS ENUM ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL');

-- CreateTable
CREATE TABLE "ChatSession" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "userId" TEXT NOT NULL,

    CONSTRAINT "ChatSession_pkey" PRIMARY KEY ("id")
);

-- CreateTable
CREATE TABLE "ChatMessage" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "content" TEXT NOT NULL,
    "role" "ChatMessageRole" NOT NULL,
    "toolCallId" TEXT,
    "toolCalls" JSONB,
    "promptTokens" INTEGER,
    "completionTokens" INTEGER,
    "totalTokens" INTEGER,
    "metadata" JSONB DEFAULT '{}',
    "sequence" INTEGER NOT NULL,
    "error" TEXT,
    "sessionId" TEXT NOT NULL,
    "parentId" TEXT,

    CONSTRAINT "ChatMessage_pkey" PRIMARY KEY ("id")
);

-- CreateIndex
CREATE INDEX "ChatSession_userId_createdAt_idx" ON "ChatSession"("userId", "createdAt");

-- CreateIndex
CREATE INDEX "ChatMessage_sessionId_sequence_idx" ON "ChatMessage"("sessionId", "sequence");

-- CreateIndex
CREATE INDEX "ChatMessage_role_idx" ON "ChatMessage"("role");

-- CreateIndex
CREATE INDEX "ChatMessage_createdAt_idx" ON "ChatMessage"("createdAt");

-- AddForeignKey
ALTER TABLE "ChatSession" ADD CONSTRAINT "ChatSession_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "ChatMessage" ADD CONSTRAINT "ChatMessage_sessionId_fkey" FOREIGN KEY ("sessionId") REFERENCES "ChatSession"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "ChatMessage" ADD CONSTRAINT "ChatMessage_parentId_fkey" FOREIGN KEY ("parentId") REFERENCES "ChatMessage"("id") ON DELETE SET NULL ON UPDATE CASCADE;
@@ -0,0 +1,49 @@
-- Enable pgvector extension
CREATE EXTENSION IF NOT EXISTS vector;

-- CreateEnum
CREATE TYPE "SearchFieldType" AS ENUM ('NAME', 'DESCRIPTION', 'CATEGORIES', 'SUBHEADING');

-- CreateTable
CREATE TABLE "StoreAgentSearch" (
    "id" TEXT NOT NULL,
    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "updatedAt" TIMESTAMP(3) NOT NULL,
    "storeListingVersionId" TEXT NOT NULL,
    "storeListingId" TEXT NOT NULL,
    "fieldName" TEXT NOT NULL,
    "fieldValue" TEXT NOT NULL,
    "embedding" vector(1536),
    "fieldType" "SearchFieldType" NOT NULL,
    "submissionStatus" "SubmissionStatus" NOT NULL,
    "isAvailable" BOOLEAN NOT NULL,

    CONSTRAINT "StoreAgentSearch_pkey" PRIMARY KEY ("id")
);

-- CreateIndex
CREATE INDEX "StoreAgentSearch_storeListingVersionId_idx" ON "StoreAgentSearch"("storeListingVersionId");

-- CreateIndex
CREATE INDEX "StoreAgentSearch_storeListingId_idx" ON "StoreAgentSearch"("storeListingId");

-- CreateIndex
CREATE INDEX "StoreAgentSearch_fieldName_idx" ON "StoreAgentSearch"("fieldName");

-- CreateIndex
CREATE INDEX "StoreAgentSearch_fieldType_idx" ON "StoreAgentSearch"("fieldType");

-- CreateIndex
CREATE INDEX "StoreAgentSearch_submissionStatus_isAvailable_idx" ON "StoreAgentSearch"("submissionStatus", "isAvailable");

-- CreateIndex
CREATE UNIQUE INDEX "StoreAgentSearch_storeListingVersionId_fieldName_key" ON "StoreAgentSearch"("storeListingVersionId", "fieldName");

-- Create HNSW index for vector similarity search (pgvector)
CREATE INDEX "StoreAgentSearch_embedding_idx" ON "StoreAgentSearch" USING hnsw (embedding vector_cosine_ops);

-- AddForeignKey
ALTER TABLE "StoreAgentSearch" ADD CONSTRAINT "StoreAgentSearch_storeListingVersionId_fkey" FOREIGN KEY ("storeListingVersionId") REFERENCES "StoreListingVersion"("id") ON DELETE CASCADE ON UPDATE CASCADE;

-- AddForeignKey
ALTER TABLE "StoreAgentSearch" ADD CONSTRAINT "StoreAgentSearch_storeListingId_fkey" FOREIGN KEY ("storeListingId") REFERENCES "StoreListing"("id") ON DELETE CASCADE ON UPDATE CASCADE;
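The `<=>` operator used with vector_cosine_ops above returns cosine distance (0 for identical directions, larger means less similar). A minimal NumPy sketch of the same quantity, for intuition only:

import numpy as np

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

a, b = np.random.rand(1536), np.random.rand(1536)
print(cosine_distance(a, a))  # ~0.0 for identical vectors
print(cosine_distance(a, b))  # larger value = less similar
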
84  autogpt_platform/backend/poetry.lock  generated
@@ -311,6 +311,73 @@ files = [
|
||||
{file = "async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asyncpg"
|
||||
version = "0.30.0"
|
||||
description = "An asyncio PostgreSQL driver"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "asyncpg-0.30.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bfb4dd5ae0699bad2b233672c8fc5ccbd9ad24b89afded02341786887e37927e"},
|
||||
{file = "asyncpg-0.30.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dc1f62c792752a49f88b7e6f774c26077091b44caceb1983509edc18a2222ec0"},
|
||||
{file = "asyncpg-0.30.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3152fef2e265c9c24eec4ee3d22b4f4d2703d30614b0b6753e9ed4115c8a146f"},
|
||||
{file = "asyncpg-0.30.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7255812ac85099a0e1ffb81b10dc477b9973345793776b128a23e60148dd1af"},
|
||||
{file = "asyncpg-0.30.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:578445f09f45d1ad7abddbff2a3c7f7c291738fdae0abffbeb737d3fc3ab8b75"},
|
||||
{file = "asyncpg-0.30.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c42f6bb65a277ce4d93f3fba46b91a265631c8df7250592dd4f11f8b0152150f"},
|
||||
{file = "asyncpg-0.30.0-cp310-cp310-win32.whl", hash = "sha256:aa403147d3e07a267ada2ae34dfc9324e67ccc4cdca35261c8c22792ba2b10cf"},
|
||||
{file = "asyncpg-0.30.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb622c94db4e13137c4c7f98834185049cc50ee01d8f657ef898b6407c7b9c50"},
|
||||
{file = "asyncpg-0.30.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5e0511ad3dec5f6b4f7a9e063591d407eee66b88c14e2ea636f187da1dcfff6a"},
|
||||
{file = "asyncpg-0.30.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:915aeb9f79316b43c3207363af12d0e6fd10776641a7de8a01212afd95bdf0ed"},
|
||||
{file = "asyncpg-0.30.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c198a00cce9506fcd0bf219a799f38ac7a237745e1d27f0e1f66d3707c84a5a"},
|
||||
{file = "asyncpg-0.30.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3326e6d7381799e9735ca2ec9fd7be4d5fef5dcbc3cb555d8a463d8460607956"},
|
||||
{file = "asyncpg-0.30.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:51da377487e249e35bd0859661f6ee2b81db11ad1f4fc036194bc9cb2ead5056"},
|
||||
{file = "asyncpg-0.30.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bc6d84136f9c4d24d358f3b02be4b6ba358abd09f80737d1ac7c444f36108454"},
|
||||
{file = "asyncpg-0.30.0-cp311-cp311-win32.whl", hash = "sha256:574156480df14f64c2d76450a3f3aaaf26105869cad3865041156b38459e935d"},
|
||||
{file = "asyncpg-0.30.0-cp311-cp311-win_amd64.whl", hash = "sha256:3356637f0bd830407b5597317b3cb3571387ae52ddc3bca6233682be88bbbc1f"},
|
||||
{file = "asyncpg-0.30.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c902a60b52e506d38d7e80e0dd5399f657220f24635fee368117b8b5fce1142e"},
|
||||
{file = "asyncpg-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aca1548e43bbb9f0f627a04666fedaca23db0a31a84136ad1f868cb15deb6e3a"},
|
||||
{file = "asyncpg-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c2a2ef565400234a633da0eafdce27e843836256d40705d83ab7ec42074efb3"},
|
||||
{file = "asyncpg-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1292b84ee06ac8a2ad8e51c7475aa309245874b61333d97411aab835c4a2f737"},
|
||||
{file = "asyncpg-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0f5712350388d0cd0615caec629ad53c81e506b1abaaf8d14c93f54b35e3595a"},
|
||||
{file = "asyncpg-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:db9891e2d76e6f425746c5d2da01921e9a16b5a71a1c905b13f30e12a257c4af"},
|
||||
{file = "asyncpg-0.30.0-cp312-cp312-win32.whl", hash = "sha256:68d71a1be3d83d0570049cd1654a9bdfe506e794ecc98ad0873304a9f35e411e"},
|
||||
{file = "asyncpg-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a0292c6af5c500523949155ec17b7fe01a00ace33b68a476d6b5059f9630305"},
|
||||
{file = "asyncpg-0.30.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:05b185ebb8083c8568ea8a40e896d5f7af4b8554b64d7719c0eaa1eb5a5c3a70"},
|
||||
{file = "asyncpg-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c47806b1a8cbb0a0db896f4cd34d89942effe353a5035c62734ab13b9f938da3"},
|
||||
{file = "asyncpg-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b6fde867a74e8c76c71e2f64f80c64c0f3163e687f1763cfaf21633ec24ec33"},
|
||||
{file = "asyncpg-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46973045b567972128a27d40001124fbc821c87a6cade040cfcd4fa8a30bcdc4"},
|
||||
{file = "asyncpg-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9110df111cabc2ed81aad2f35394a00cadf4f2e0635603db6ebbd0fc896f46a4"},
|
||||
{file = "asyncpg-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04ff0785ae7eed6cc138e73fc67b8e51d54ee7a3ce9b63666ce55a0bf095f7ba"},
|
||||
{file = "asyncpg-0.30.0-cp313-cp313-win32.whl", hash = "sha256:ae374585f51c2b444510cdf3595b97ece4f233fde739aa14b50e0d64e8a7a590"},
|
||||
{file = "asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e"},
|
||||
{file = "asyncpg-0.30.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:29ff1fc8b5bf724273782ff8b4f57b0f8220a1b2324184846b39d1ab4122031d"},
|
||||
{file = "asyncpg-0.30.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:64e899bce0600871b55368b8483e5e3e7f1860c9482e7f12e0a771e747988168"},
|
||||
{file = "asyncpg-0.30.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b290f4726a887f75dcd1b3006f484252db37602313f806e9ffc4e5996cfe5cb"},
|
||||
{file = "asyncpg-0.30.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f86b0e2cd3f1249d6fe6fd6cfe0cd4538ba994e2d8249c0491925629b9104d0f"},
|
||||
{file = "asyncpg-0.30.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:393af4e3214c8fa4c7b86da6364384c0d1b3298d45803375572f415b6f673f38"},
|
||||
{file = "asyncpg-0.30.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:fd4406d09208d5b4a14db9a9dbb311b6d7aeeab57bded7ed2f8ea41aeef39b34"},
|
||||
{file = "asyncpg-0.30.0-cp38-cp38-win32.whl", hash = "sha256:0b448f0150e1c3b96cb0438a0d0aa4871f1472e58de14a3ec320dbb2798fb0d4"},
|
||||
{file = "asyncpg-0.30.0-cp38-cp38-win_amd64.whl", hash = "sha256:f23b836dd90bea21104f69547923a02b167d999ce053f3d502081acea2fba15b"},
|
||||
{file = "asyncpg-0.30.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f4e83f067b35ab5e6371f8a4c93296e0439857b4569850b178a01385e82e9ad"},
|
||||
{file = "asyncpg-0.30.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5df69d55add4efcd25ea2a3b02025b669a285b767bfbf06e356d68dbce4234ff"},
|
||||
{file = "asyncpg-0.30.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3479a0d9a852c7c84e822c073622baca862d1217b10a02dd57ee4a7a081f708"},
|
||||
{file = "asyncpg-0.30.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26683d3b9a62836fad771a18ecf4659a30f348a561279d6227dab96182f46144"},
|
||||
{file = "asyncpg-0.30.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1b982daf2441a0ed314bd10817f1606f1c28b1136abd9e4f11335358c2c631cb"},
|
||||
{file = "asyncpg-0.30.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1c06a3a50d014b303e5f6fc1e5f95eb28d2cee89cf58384b700da621e5d5e547"},
|
||||
{file = "asyncpg-0.30.0-cp39-cp39-win32.whl", hash = "sha256:1b11a555a198b08f5c4baa8f8231c74a366d190755aa4f99aacec5970afe929a"},
|
||||
{file = "asyncpg-0.30.0-cp39-cp39-win_amd64.whl", hash = "sha256:8b684a3c858a83cd876f05958823b68e8d14ec01bb0c0d14a6704c5bf9711773"},
|
||||
{file = "asyncpg-0.30.0.tar.gz", hash = "sha256:c551e9928ab6707602f44811817f82ba3c446e018bfe1d3abecc8ba5f3eac851"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
async-timeout = {version = ">=4.0.3", markers = "python_version < \"3.11.0\""}
|
||||
|
||||
[package.extras]
|
||||
docs = ["Sphinx (>=8.1.3,<8.2.0)", "sphinx-rtd-theme (>=1.2.2)"]
|
||||
gssauth = ["gssapi ; platform_system != \"Windows\"", "sspilib ; platform_system == \"Windows\""]
|
||||
test = ["distro (>=1.9.0,<1.10.0)", "flake8 (>=6.1,<7.0)", "flake8-pyi (>=24.1.0,<24.2.0)", "gssapi ; platform_system == \"Linux\"", "k5test ; platform_system == \"Linux\"", "mypy (>=1.8.0,<1.9.0)", "sspilib ; platform_system == \"Windows\"", "uvloop (>=0.15.3) ; platform_system != \"Windows\" and python_version < \"3.14.0\""]
|
||||
|
||||
[[package]]
|
||||
name = "attrs"
|
||||
version = "25.3.0"
|
||||
@@ -3705,6 +3772,21 @@ all = ["pbs-installer[download,install]"]
|
||||
download = ["httpx (>=0.27.0,<1)"]
|
||||
install = ["zstandard (>=0.21.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "pgvector"
|
||||
version = "0.4.1"
|
||||
description = "pgvector support for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pgvector-0.4.1-py3-none-any.whl", hash = "sha256:34bb4e99e1b13d08a2fe82dda9f860f15ddcd0166fbb25bffe15821cbfeb7362"},
|
||||
{file = "pgvector-0.4.1.tar.gz", hash = "sha256:83d3a1c044ff0c2f1e95d13dfb625beb0b65506cfec0941bfe81fd0ad44f4003"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = "*"
|
||||
|
||||
[[package]]
|
||||
name = "pika"
|
||||
version = "1.3.2"
|
||||
@@ -7274,4 +7356,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "ff0f6f8d90793ea95f1f7008f7c845432ff46fca0937d5068b4f7cfec0ee7674"
|
||||
content-hash = "1dbb6515ca50d1ef55aa24d78499aafa84b79546091def48763213c845942645"
|
||||
|
||||
@@ -82,6 +82,8 @@ firecrawl-py = "^4.3.6"
exa-py = "^1.14.20"
croniter = "^6.0.0"
stagehand = "^0.5.1"
pgvector = "^0.4.1"
asyncpg = "^0.30.0"

[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"

@@ -57,6 +57,7 @@ model User {
  APIKeys             APIKey[]
  IntegrationWebhooks IntegrationWebhook[]
  NotificationBatches UserNotificationBatch[]
  ChatSessions        ChatSession[]
}

enum OnboardingStep {
@@ -762,6 +763,7 @@ model StoreListing {

  // Relations
  Versions      StoreListingVersion[] @relation("ListingVersions")
  SearchIndexes StoreAgentSearch[]    @relation("SearchIndexes")

  // Unique index on agentId to ensure only one listing per agent, regardless of number of versions the agent has.
  @@unique([agentGraphId])
@@ -822,6 +824,9 @@ model StoreListingVersion {

  // Reviews for this specific version
  Reviews StoreListingReview[]

  // Search index entries for vector search
  SearchIndexes StoreAgentSearch[] @relation("SearchIndexes")

  @@unique([storeListingId, version])
  @@index([storeListingId, submissionStatus, isAvailable])
@@ -830,6 +835,55 @@ model StoreListingVersion {
  @@index([agentGraphId, agentGraphVersion]) // Non-unique index for efficient lookups
}

model StoreAgentSearch {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  // Relation to StoreListingVersion
  storeListingVersionId String
  StoreListingVersion   StoreListingVersion @relation("SearchIndexes", fields: [storeListingVersionId], references: [id], onDelete: Cascade)

  // Direct relation to StoreListing for easier joins
  storeListingId String
  StoreListing   StoreListing @relation("SearchIndexes", fields: [storeListingId], references: [id], onDelete: Cascade)

  // Searchable fields that will be embedded
  fieldName  String // The field being indexed (name, description, categories, subheading)
  fieldValue String // The actual text content for this field

  // Vector embedding for similarity search using pgvector
  // text-embedding-3-small produces 1536-dimensional embeddings
  embedding Float[] // pgvector field for storing 1536-dimensional embeddings

  // Metadata for search optimization
  fieldType SearchFieldType // Type of field for filtering

  // Denormalized fields for filtering without joins
  submissionStatus SubmissionStatus // Copy from StoreListingVersion for filtering approved content
  isAvailable      Boolean          // Copy from StoreListingVersion for filtering available content

  // Unique constraint to prevent duplicate indexing of the same field
  @@unique([storeListingVersionId, fieldName])

  // Indexes for performance
  @@index([storeListingVersionId])
  @@index([storeListingId])
  @@index([fieldName])
  @@index([fieldType])
  @@index([submissionStatus, isAvailable]) // For filtering approved and available content

  // Note: pgvector index will be created manually in migration for vector similarity search
}

// Enum for searchable field types
enum SearchFieldType {
  NAME        // Agent name
  DESCRIPTION // Agent description
  CATEGORIES  // Agent categories
  SUBHEADING  // Agent subheading
}

model StoreListingReview {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
@@ -892,3 +946,67 @@ enum APIKeyStatus {
  REVOKED
  SUSPENDED
}

// Chat Session models for tracking conversations
model ChatSession {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  // Relations
  userId   String
  User     User          @relation(fields: [userId], references: [id], onDelete: Cascade)
  messages ChatMessage[]

  @@index([userId, createdAt])
}

model ChatMessage {
  id        String   @id @default(uuid())
  createdAt DateTime @default(now())
  updatedAt DateTime @updatedAt

  // Message content
  content String @db.Text

  // Role: user, assistant, system, tool
  role ChatMessageRole

  // For tool messages
  toolCallId String?
  toolCalls  Json? // Array of tool calls if role is assistant

  // Token usage for this message
  promptTokens     Int?
  completionTokens Int?
  totalTokens      Int?

  // Response metadata (timing, model used, etc)
  metadata Json? @default("{}")

  // Order in conversation
  sequence Int

  // Error tracking
  error String?

  // Relations
  sessionId   String
  ChatSession ChatSession @relation(fields: [sessionId], references: [id], onDelete: Cascade)

  // Parent message for threading/branching conversations
  parentId String?
  parent   ChatMessage?  @relation("MessageThread", fields: [parentId], references: [id])
  children ChatMessage[] @relation("MessageThread")

  @@index([sessionId, sequence])
  @@index([role])
  @@index([createdAt])
}

enum ChatMessageRole {
  USER
  ASSISTANT
  SYSTEM
  TOOL
}

@@ -25,7 +25,8 @@ export async function GET(request: Request) {
  const api = new BackendAPI();
  await api.createUser();

  if (await shouldShowOnboarding()) {
  // Only show onboarding if no explicit next URL is provided
  if (next === "/" && (await shouldShowOnboarding())) {
    next = "/onboarding";
    revalidatePath("/onboarding", "layout");
  } else {

19  autogpt_platform/frontend/src/app/(platform)/chat/page.tsx  Normal file
@@ -0,0 +1,19 @@
"use client";

import { ChatInterface } from "@/components/chat/ChatInterface";

export default function ChatPage() {
  return (
    <div className="flex h-full flex-col">
      <div className="border-b px-4 py-3">
        <h1 className="text-xl font-semibold">AI Agent Discovery Chat</h1>
        <p className="text-sm text-muted-foreground">
          Discover and interact with AI agents through natural conversation
        </p>
      </div>
      <div className="flex-1 overflow-hidden">
        <ChatInterface className="h-full" />
      </div>
    </div>
  );
}
@@ -3,7 +3,7 @@ import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { BehaveAs, getBehaveAs } from "@/lib/utils";
import { loginFormSchema, LoginProvider } from "@/types/auth";
import { zodResolver } from "@hookform/resolvers/zod";
import { useRouter } from "next/navigation";
import { useRouter, useSearchParams } from "next/navigation";
import { useCallback, useEffect, useState } from "react";
import { useForm } from "react-hook-form";
import z from "zod";
@@ -14,6 +14,7 @@ export function useLoginPage() {
  const [feedback, setFeedback] = useState<string | null>(null);
  const [captchaKey, setCaptchaKey] = useState(0);
  const router = useRouter();
  const searchParams = useSearchParams();
  const { toast } = useToast();
  const [isLoading, setIsLoading] = useState(false);
  const [isGoogleLoading, setIsGoogleLoading] = useState(false);
@@ -41,8 +42,22 @@ export function useLoginPage() {
  }, [turnstile]);

  useEffect(() => {
    if (user) router.push("/");
  }, [user]);
    if (user) {
      // Check for return URL from query params
      const returnUrl = searchParams.get("returnUrl");
      if (returnUrl) {
        router.push(decodeURIComponent(returnUrl));
      } else {
        // Check for pending chat session
        const pendingSession = localStorage.getItem("pending_chat_session");
        if (pendingSession) {
          router.push("/marketplace/discover");
        } else {
          router.push("/");
        }
      }
    }
  }, [user, searchParams, router]);

  async function handleProviderLogin(provider: LoginProvider) {
    setIsGoogleLoading(true);

@@ -3,9 +3,16 @@
import { FilterChips } from "../FilterChips/FilterChips";
import { SearchBar } from "../SearchBar/SearchBar";
import { useHeroSection } from "./useHeroSection";
import { useRouter } from "next/navigation";
import { MessageCircle, Sparkles } from "lucide-react";

export const HeroSection = () => {
  const router = useRouter();
  const { onFilterChange, searchTerms } = useHeroSection();

  const handleDiscoverClick = () => {
    router.push("/chat");
  };
  return (
    <div className="mb-2 mt-8 flex flex-col items-center justify-center px-4 sm:mb-4 sm:mt-12 sm:px-6 md:mb-6 md:mt-16 lg:my-24 lg:px-8 xl:my-16">
      <div className="w-full max-w-3xl lg:max-w-4xl xl:max-w-5xl">
@@ -29,6 +36,26 @@ export const HeroSection = () => {
        <h3 className="mb:text-2xl mb-6 text-center font-sans text-xl font-normal leading-loose text-neutral-700 dark:text-neutral-300 md:mb-12">
          Bringing you AI agents designed by thinkers from around the world
        </h3>

        {/* New AI Discovery CTA */}
        <div className="mb-6 flex justify-center">
          <button
            onClick={handleDiscoverClick}
            className="group relative flex items-center gap-3 rounded-full bg-gradient-to-r from-violet-600 to-purple-600 px-8 py-4 text-white shadow-lg transition-all duration-300 hover:scale-105 hover:shadow-xl"
          >
            <MessageCircle className="h-5 w-5" />
            <span className="text-lg font-medium">
              Start AI-Powered Discovery
            </span>
            <Sparkles className="h-5 w-5 animate-pulse" />
            <div className="absolute inset-0 rounded-full bg-white opacity-0 transition-opacity duration-300 group-hover:opacity-10" />
          </button>
        </div>

        <div className="mb-3 text-center text-sm text-neutral-600 dark:text-neutral-400">
          or search directly below
        </div>

        <div className="mb-4 flex justify-center sm:mb-5">
          <SearchBar height="h-[74px]" />
        </div>

@@ -0,0 +1,270 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState, useEffect, useRef } from "react";
|
||||
import { ChatMessage } from "@/components/chat/ChatMessage";
|
||||
import { ChatInput } from "@/components/chat/ChatInput";
|
||||
import { ToolCallWidget } from "@/components/chat/ToolCallWidget";
|
||||
import { AgentDiscoveryCard } from "@/components/chat/AgentDiscoveryCard";
|
||||
import { ChatMessage as ChatMessageType } from "@/lib/autogpt-server-api/chat";
|
||||
// import { Loader2 } from "lucide-react";
|
||||
|
||||
// Demo page that simulates the chat interface without requiring authentication
|
||||
export default function DiscoverDemoPage() {
|
||||
const [messages, setMessages] = useState<ChatMessageType[]>([]);
|
||||
const [isStreaming, setIsStreaming] = useState(false);
|
||||
const [streamingContent, setStreamingContent] = useState("");
|
||||
const [toolCalls, setToolCalls] = useState<any[]>([]);
|
||||
const [discoveredAgents, setDiscoveredAgents] = useState<any[]>([]);
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Auto-scroll to bottom
|
||||
useEffect(() => {
|
||||
messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
|
||||
}, [messages, streamingContent, toolCalls]);
|
||||
|
||||
// Add welcome message on mount
|
||||
useEffect(() => {
|
||||
setMessages([
|
||||
{
|
||||
content:
|
||||
"Hello! I'm your AI agent discovery assistant. I can help you find and set up the perfect AI agents for your needs. What would you like to automate today?",
|
||||
role: "ASSISTANT",
|
||||
created_at: new Date().toISOString(),
|
||||
},
|
||||
]);
|
||||
}, []);
|
||||
|
||||
const simulateResponse = async (message: string) => {
|
||||
setIsStreaming(true);
|
||||
setStreamingContent("");
|
||||
setToolCalls([]);
|
||||
|
||||
// Add user message
|
||||
const userMessage: ChatMessageType = {
|
||||
content: message,
|
||||
role: "USER",
|
||||
created_at: new Date().toISOString(),
|
||||
};
|
||||
setMessages((prev) => [...prev, userMessage]);
|
||||
|
||||
// Simulate thinking
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
|
||||
// Check for keywords and simulate appropriate response
|
||||
const lowerMessage = message.toLowerCase();
|
||||
|
||||
if (
|
||||
lowerMessage.includes("content") ||
|
||||
lowerMessage.includes("write") ||
|
||||
lowerMessage.includes("blog")
|
||||
) {
|
||||
// Simulate tool call
|
||||
setToolCalls([
|
||||
{
|
||||
id: "tool-1",
|
||||
name: "find_agent",
|
||||
parameters: { search_query: "content creation" },
|
||||
status: "calling",
|
||||
},
|
||||
]);
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 1000));
|
||||
|
||||
setToolCalls([
|
||||
{
|
||||
id: "tool-1",
|
||||
name: "find_agent",
|
||||
parameters: { search_query: "content creation" },
|
||||
status: "executing",
|
||||
},
|
||||
]);
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 1500));
|
||||
|
||||
setToolCalls([
|
||||
{
|
||||
id: "tool-1",
|
||||
name: "find_agent",
|
||||
parameters: { search_query: "content creation" },
|
||||
status: "completed",
|
||||
result: "Found 3 agents for content creation",
|
||||
},
|
||||
]);
|
||||
|
||||
// Simulate discovered agents
|
||||
setDiscoveredAgents([
|
||||
{
|
||||
id: "agent-001",
|
||||
version: "1.0.0",
|
||||
name: "Blog Writer Pro",
|
||||
description:
|
||||
"Generates high-quality blog posts with SEO optimization",
|
||||
creator: "AutoGPT Team",
|
||||
rating: 4.8,
|
||||
runs: 5420,
|
||||
categories: ["Content", "SEO", "Marketing"],
|
||||
},
|
||||
{
|
||||
id: "agent-002",
|
||||
version: "2.1.0",
|
||||
name: "Social Media Content Creator",
|
||||
description:
|
||||
"Creates engaging social media posts for multiple platforms",
|
||||
creator: "Community",
|
||||
rating: 4.6,
|
||||
runs: 3200,
|
||||
categories: ["Social Media", "Marketing"],
|
||||
},
|
||||
{
|
||||
id: "agent-003",
|
||||
version: "1.5.0",
|
||||
name: "Technical Documentation Writer",
|
||||
description:
|
||||
"Generates comprehensive technical documentation from code",
|
||||
creator: "DevTools Inc",
|
||||
rating: 4.9,
|
||||
runs: 2100,
|
||||
categories: ["Documentation", "Development"],
|
||||
},
|
||||
]);
|
||||
|
||||
// Simulate streaming response
|
||||
const response =
|
||||
"I found some excellent content creation agents for you! These agents can help with blog writing, social media content, and technical documentation. Each one has been highly rated by the community.";
|
||||
|
||||
for (let i = 0; i < response.length; i += 5) {
|
||||
setStreamingContent(response.substring(0, i + 5));
|
||||
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||
}
|
||||
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
content: response,
|
||||
role: "ASSISTANT",
|
||||
created_at: new Date().toISOString(),
|
||||
},
|
||||
]);
|
||||
} else if (
|
||||
lowerMessage.includes("automat") ||
|
||||
lowerMessage.includes("task")
|
||||
) {
|
||||
// Different response for automation
|
||||
const response =
|
||||
"I can help you find automation agents! What specific tasks would you like to automate? For example:\n\n- Data processing and analysis\n- Email management\n- File organization\n- Web scraping\n- Report generation\n- API integrations\n\nJust describe what you need and I'll find the perfect agent for you!";
|
||||
|
||||
for (let i = 0; i < response.length; i += 5) {
|
||||
setStreamingContent(response.substring(0, i + 5));
|
||||
await new Promise((resolve) => setTimeout(resolve, 30));
|
||||
}
|
||||
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
content: response,
|
||||
role: "ASSISTANT",
|
||||
created_at: new Date().toISOString(),
|
||||
},
|
||||
]);
|
||||
} else {
|
||||
// Generic response
|
||||
const response = `I understand you're interested in "${message}". Let me search for relevant agents that can help you with that.`;
|
||||
|
||||
for (let i = 0; i < response.length; i += 5) {
|
||||
setStreamingContent(response.substring(0, i + 5));
|
||||
await new Promise((resolve) => setTimeout(resolve, 40));
|
||||
}
|
||||
|
||||
setMessages((prev) => [
|
||||
...prev,
|
||||
{
|
||||
content: response,
|
||||
role: "ASSISTANT",
|
||||
created_at: new Date().toISOString(),
|
||||
},
|
||||
]);
|
||||
}
|
||||
|
||||
setStreamingContent("");
|
||||
setIsStreaming(false);
|
||||
};
|
||||
|
||||
const handleSendMessage = (message: string) => {
|
||||
if (!isStreaming) {
|
||||
simulateResponse(message);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSelectAgent = (agent: any) => {
|
||||
handleSendMessage(`I want to set up the agent "${agent.name}"`);
|
||||
};
|
||||
|
||||
const handleGetAgentDetails = (agent: any) => {
|
||||
handleSendMessage(`Tell me more about "${agent.name}"`);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex h-screen flex-col bg-neutral-50 dark:bg-neutral-950">
|
||||
{/* Header */}
|
||||
<div className="border-b border-neutral-200 bg-white px-4 py-3 dark:border-neutral-700 dark:bg-neutral-900">
|
||||
<div className="mx-auto max-w-4xl">
|
||||
<h1 className="text-lg font-semibold text-neutral-900 dark:text-neutral-100">
|
||||
AI Agent Discovery Assistant (Demo)
|
||||
</h1>
|
||||
<p className="text-sm text-neutral-600 dark:text-neutral-400">
|
||||
This is a demo of the chat-based agent discovery interface
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Messages Area */}
|
||||
<div className="flex-1 overflow-y-auto">
|
||||
<div className="mx-auto max-w-4xl py-4">
|
||||
{messages.map((message, index) => (
|
||||
<ChatMessage key={index} message={message} />
|
||||
))}
|
||||
|
||||
{streamingContent && (
|
||||
<ChatMessage
|
||||
message={{
|
||||
content: streamingContent,
|
||||
role: "ASSISTANT",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
{toolCalls.map((toolCall) => (
|
||||
<div key={toolCall.id} className="px-4">
|
||||
<ToolCallWidget
|
||||
toolName={toolCall.name}
|
||||
parameters={toolCall.parameters}
|
||||
result={toolCall.result}
|
||||
status={toolCall.status}
|
||||
error={toolCall.error}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
|
||||
{discoveredAgents.length > 0 && (
|
||||
<div className="px-4">
|
||||
<AgentDiscoveryCard
|
||||
agents={discoveredAgents}
|
||||
onSelectAgent={handleSelectAgent}
|
||||
onGetDetails={handleGetAgentDetails}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div ref={messagesEndRef} />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Input Area */}
|
||||
<ChatInput
|
||||
onSendMessage={handleSendMessage}
|
||||
isStreaming={isStreaming}
|
||||
placeholder="Try: 'I need help with content creation' or 'Show me automation agents'"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
"use client";
|
||||
|
||||
import { ChatInterface } from "@/components/chat/ChatInterface";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useEffect } from "react";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
|
||||
export default function DiscoverPage() {
|
||||
const { user } = useSupabase();
|
||||
const searchParams = useSearchParams();
|
||||
const sessionId = searchParams.get("sessionId");
|
||||
|
||||
useEffect(() => {
|
||||
// Check if we need to assign user to anonymous session after login
|
||||
const assignSessionToUser = async () => {
|
||||
// Priority 1: Session from URL (user returning from auth)
|
||||
const urlSession = sessionId;
|
||||
// Priority 2: Pending session from localStorage
|
||||
const pendingSession = localStorage.getItem("pending_chat_session");
|
||||
|
||||
const sessionToAssign = urlSession || pendingSession;
|
||||
|
||||
if (sessionToAssign && user) {
|
||||
try {
|
||||
const api = new BackendAPI();
|
||||
// Call the assign-user endpoint
|
||||
await (api as any)._request(
|
||||
"PATCH",
|
||||
`/v2/chat/sessions/${sessionToAssign}/assign-user`,
|
||||
{},
|
||||
);
|
||||
|
||||
// Clear the pending session flag
|
||||
localStorage.removeItem("pending_chat_session");
|
||||
|
||||
// The session is now owned by the user
|
||||
console.log(
|
||||
`Session ${sessionToAssign} assigned to user successfully`,
|
||||
);
|
||||
} catch (e: any) {
|
||||
// Check if error is because session already has a user
|
||||
if (e.message?.includes("already has an assigned user")) {
|
||||
console.log("Session already assigned to user, continuing...");
|
||||
} else {
|
||||
console.error("Failed to assign session to user:", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if (user) {
|
||||
assignSessionToUser();
|
||||
}
|
||||
}, [user, sessionId]);
|
||||
|
||||
return (
|
||||
<div className="h-screen">
|
||||
<ChatInterface
|
||||
sessionId={sessionId || undefined}
|
||||
systemPrompt="You are a helpful assistant that helps users discover and set up AI agents from the AutoGPT marketplace. Be conversational, friendly, and guide users through finding the right agent for their needs. When users describe what they want to accomplish, search for relevant agents and present them in an engaging way. Help them understand what each agent does and guide them through the setup process."
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState } from "react";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
|
||||
export default function TestChatPage() {
|
||||
const [result, setResult] = useState<string>("");
|
||||
const [error, setError] = useState<string>("");
|
||||
const [loading, setLoading] = useState(false);
|
||||
const { supabase, user } = useSupabase();
|
||||
|
||||
const testAuth = async () => {
|
||||
setLoading(true);
|
||||
setError("");
|
||||
setResult("");
|
||||
|
||||
try {
|
||||
// Test Supabase session
|
||||
if (!supabase) {
|
||||
setError("Supabase client not initialized");
|
||||
return;
|
||||
}
|
||||
|
||||
const {
|
||||
data: { session },
|
||||
} = await supabase.auth.getSession();
|
||||
|
||||
if (!session) {
|
||||
setError("No Supabase session found. Please log in.");
|
||||
return;
|
||||
}
|
||||
|
||||
setResult(
|
||||
`Session found!\nUser ID: ${session.user.id}\nToken: ${session.access_token.substring(0, 20)}...`,
|
||||
);
|
||||
|
||||
// Test BackendAPI authentication
|
||||
const api = new BackendAPI();
|
||||
const isAuth = await api.isAuthenticated();
|
||||
setResult((prev) => prev + `\n\nBackendAPI authenticated: ${isAuth}`);
|
||||
|
||||
// Test chat API
|
||||
try {
|
||||
const chatSession = await api.chat.createSession();
|
||||
setResult(
|
||||
(prev) =>
|
||||
prev + `\n\nChat session created!\nSession ID: ${chatSession.id}`,
|
||||
);
|
||||
} catch (chatError: any) {
|
||||
setResult((prev) => prev + `\n\nChat API error: ${chatError.message}`);
|
||||
}
|
||||
} catch (err: any) {
|
||||
setError(err.message || "Unknown error");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="container mx-auto p-8">
|
||||
<h1 className="mb-4 text-2xl font-bold">Chat Authentication Test</h1>
|
||||
|
||||
<div className="mb-4">
|
||||
<p className="text-sm text-gray-600">
|
||||
User: {user?.email || "Not logged in"}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<button
|
||||
onClick={testAuth}
|
||||
disabled={loading}
|
||||
className="rounded bg-blue-500 px-4 py-2 text-white hover:bg-blue-600 disabled:bg-gray-400"
|
||||
>
|
||||
{loading ? "Testing..." : "Test Authentication"}
|
||||
</button>
|
||||
|
||||
{error && (
|
||||
<div className="mt-4 rounded border border-red-400 bg-red-100 p-4 text-red-700">
|
||||
<h3 className="font-bold">Error:</h3>
|
||||
<pre className="mt-2 text-sm">{error}</pre>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{result && (
|
||||
<div className="mt-4 rounded border border-green-400 bg-green-100 p-4 text-green-700">
|
||||
<h3 className="font-bold">Result:</h3>
|
||||
<pre className="mt-2 whitespace-pre-wrap text-sm">{result}</pre>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -3,7 +3,7 @@ import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { BehaveAs, getBehaveAs } from "@/lib/utils";
|
||||
import { LoginProvider, signupFormSchema } from "@/types/auth";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
import { useCallback, useEffect, useState } from "react";
|
||||
import { useForm } from "react-hook-form";
|
||||
import z from "zod";
|
||||
@@ -15,6 +15,7 @@ export function useSignupPage() {
|
||||
const [captchaKey, setCaptchaKey] = useState(0);
|
||||
const { toast } = useToast();
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
|
||||
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
|
||||
@@ -44,8 +45,22 @@ export function useSignupPage() {
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (user) router.push("/");
|
||||
}, [user]);
|
||||
if (user) {
|
||||
// Check for return URL from query params
|
||||
const returnUrl = searchParams.get("returnUrl");
|
||||
if (returnUrl) {
|
||||
router.push(decodeURIComponent(returnUrl));
|
||||
} else {
|
||||
// Check for pending chat session
|
||||
const pendingSession = localStorage.getItem("pending_chat_session");
|
||||
if (pendingSession) {
|
||||
router.push("/marketplace/discover");
|
||||
} else {
|
||||
router.push("/");
|
||||
}
|
||||
}
|
||||
}
|
||||
}, [user, searchParams, router]);
|
||||
|
||||
async function handleProviderSignup(provider: LoginProvider) {
|
||||
setIsGoogleLoading(true);
|
||||
|
||||
@@ -14,6 +14,10 @@ const API_PROXY_BASE_URL = `${FRONTEND_BASE_URL}/api/proxy`; // Sending request

const getBaseUrl = (): string => {
  if (!isServerSide()) {
    // In the browser, use the current origin to handle dynamic ports
    if (typeof window !== "undefined") {
      return `${window.location.origin}/api/proxy`;
    }
    return API_PROXY_BASE_URL;
  } else {
    return getAgptServerBaseUrl();

@@ -4682,6 +4682,379 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/v2/chat/sessions": {
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Create Session",
|
||||
"description": "Create a new chat session for the authenticated or anonymous user.\n\nArgs:\n request: Session creation parameters\n user_id: Optional authenticated user ID\n\nReturns:\n Created session details",
|
||||
"operationId": "postV2CreateSession",
|
||||
"security": [{ "HTTPBearer": [] }],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/CreateSessionRequest" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/CreateSessionResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": { "description": "Resource not found" },
|
||||
"401": { "description": "Unauthorized" },
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "List Sessions",
|
||||
"description": "List chat sessions for the authenticated user.\n\nArgs:\n limit: Maximum number of sessions to return\n offset: Number of sessions to skip\n include_last_message: Whether to include the last message\n user_id: Authenticated user ID\n\nReturns:\n List of user's chat sessions",
|
||||
"operationId": "getV2ListSessions",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "limit",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer",
|
||||
"maximum": 100,
|
||||
"minimum": 1,
|
||||
"default": 50,
|
||||
"title": "Limit"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "offset",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"default": 0,
|
||||
"title": "Offset"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "include_last_message",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"title": "Include Last Message"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/SessionListResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": { "description": "Resource not found" },
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/v2/chat/sessions/{session_id}": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Session",
|
||||
"description": "Get details of a specific chat session.\n\nArgs:\n session_id: ID of the session to retrieve\n include_messages: Whether to include all messages\n user_id: Authenticated user ID\n\nReturns:\n Session details with optional messages",
|
||||
"operationId": "getV2GetSession",
|
||||
"security": [{ "HTTPBearer": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Session Id" }
|
||||
},
|
||||
{
|
||||
"name": "include_messages",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "boolean",
|
||||
"default": true,
|
||||
"title": "Include Messages"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/SessionDetailResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": { "description": "Resource not found" },
|
||||
"401": { "description": "Unauthorized" },
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"delete": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Delete Session",
|
||||
"description": "Delete a chat session and all its messages.\n\nArgs:\n session_id: ID of the session to delete\n user_id: Authenticated user ID\n\nReturns:\n Deletion confirmation",
|
||||
"operationId": "deleteV2DeleteSession",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Session Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Response Deletev2Deletesession"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": { "description": "Resource not found" },
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/v2/chat/sessions/{session_id}/messages": {
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Send Message",
|
||||
"description": "Send a message to a chat session (non-streaming).\n\nThis endpoint processes the message and returns the complete response.\nFor streaming responses, use the /stream endpoint.\n\nArgs:\n session_id: ID of the session\n request: Message parameters\n user_id: Authenticated user ID\n\nReturns:\n Complete assistant response",
|
||||
"operationId": "postV2SendMessage",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Session Id" }
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/SendMessageRequest" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/SendMessageResponse" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": { "description": "Resource not found" },
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/v2/chat/sessions/{session_id}/stream": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Stream Chat",
|
||||
"description": "Stream chat responses using Server-Sent Events (SSE).\n\nThis endpoint streams the AI response in real-time, including:\n- Text chunks as they're generated\n- Tool call UI elements\n- Tool execution results\n\nArgs:\n session_id: ID of the session\n message: User's message\n model: AI model to use\n max_context: Maximum context messages\n user_id: Optional authenticated user ID\n\nReturns:\n SSE stream of response chunks",
|
||||
"operationId": "getV2StreamChat",
|
||||
"security": [{ "HTTPBearer": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Session Id" }
|
||||
},
|
||||
{
|
||||
"name": "message",
|
||||
"in": "query",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 10000,
|
||||
"title": "Message"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "model",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"default": "gpt-4o",
|
||||
"title": "Model"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "max_context",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "integer",
|
||||
"maximum": 100,
|
||||
"minimum": 1,
|
||||
"default": 50,
|
||||
"title": "Max Context"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": { "application/json": { "schema": {} } }
|
||||
},
|
||||
"404": { "description": "Resource not found" },
|
||||
"401": { "description": "Unauthorized" },
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/v2/chat/sessions/{session_id}/assign-user": {
|
||||
"patch": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Assign User To Session",
|
||||
"description": "Assign an authenticated user to an anonymous session.\n\nThis is called after a user logs in to claim their anonymous session.\n\nArgs:\n session_id: ID of the anonymous session\n user_id: Authenticated user ID\n\nReturns:\n Success status",
|
||||
"operationId": "patchV2AssignUserToSession",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Session Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Response Patchv2Assignusertosession"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": { "description": "Resource not found" },
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/v2/chat/health": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Health Check",
|
||||
"description": "Check if the chat service is healthy.\n\nReturns:\n Health status",
|
||||
"operationId": "getV2HealthCheck",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Response Getv2Healthcheck"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"404": { "description": "Resource not found" },
|
||||
"401": { "description": "Unauthorized" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/email/unsubscribe": {
|
||||
"post": {
|
||||
"tags": ["v1", "email"],
|
||||
@@ -5342,6 +5715,32 @@
|
||||
"required": ["graph"],
|
||||
"title": "CreateGraph"
|
||||
},
|
||||
"CreateSessionRequest": {
|
||||
"properties": {
|
||||
"metadata": {
|
||||
"anyOf": [
|
||||
{ "additionalProperties": true, "type": "object" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Metadata",
|
||||
"description": "Optional metadata"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"title": "CreateSessionRequest",
|
||||
"description": "Request model for creating a new chat session."
|
||||
},
|
||||
"CreateSessionResponse": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"created_at": { "type": "string", "title": "Created At" },
|
||||
"user_id": { "type": "string", "title": "User Id" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["id", "created_at", "user_id"],
|
||||
"title": "CreateSessionResponse",
|
||||
"description": "Response model for created chat session."
|
||||
},
|
||||
"Creator": {
|
||||
"properties": {
|
||||
"name": { "type": "string", "title": "Name" },
|
||||
@@ -7471,6 +7870,98 @@
|
||||
"required": ["items", "total_items", "page", "more_pages"],
|
||||
"title": "SearchResponse"
|
||||
},
|
||||
"SendMessageRequest": {
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"maxLength": 10000,
|
||||
"minLength": 1,
|
||||
"title": "Message",
|
||||
"description": "Message content"
|
||||
},
|
||||
"model": {
|
||||
"type": "string",
|
||||
"title": "Model",
|
||||
"description": "AI model to use",
|
||||
"default": "gpt-4o"
|
||||
},
|
||||
"max_context_messages": {
|
||||
"type": "integer",
|
||||
"maximum": 100.0,
|
||||
"minimum": 1.0,
|
||||
"title": "Max Context Messages",
|
||||
"description": "Max context messages",
|
||||
"default": 50
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["message"],
|
||||
"title": "SendMessageRequest",
|
||||
"description": "Request model for sending a chat message."
|
||||
},
|
||||
"SendMessageResponse": {
|
||||
"properties": {
|
||||
"message_id": { "type": "string", "title": "Message Id" },
|
||||
"content": { "type": "string", "title": "Content" },
|
||||
"role": { "type": "string", "title": "Role" },
|
||||
"tokens_used": {
|
||||
"anyOf": [
|
||||
{ "additionalProperties": true, "type": "object" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Tokens Used"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["message_id", "content", "role"],
|
||||
"title": "SendMessageResponse",
|
||||
"description": "Response model for non-streaming message."
|
||||
},
|
||||
"SessionDetailResponse": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
"created_at": { "type": "string", "title": "Created At" },
|
||||
"updated_at": { "type": "string", "title": "Updated At" },
|
||||
"user_id": { "type": "string", "title": "User Id" },
|
||||
"messages": {
|
||||
"items": { "additionalProperties": true, "type": "object" },
|
||||
"type": "array",
|
||||
"title": "Messages"
|
||||
},
|
||||
"metadata": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Metadata"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
"id",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"user_id",
|
||||
"messages",
|
||||
"metadata"
|
||||
],
|
||||
"title": "SessionDetailResponse",
|
||||
"description": "Response model for session details."
|
||||
},
|
||||
"SessionListResponse": {
|
||||
"properties": {
|
||||
"sessions": {
|
||||
"items": { "additionalProperties": true, "type": "object" },
|
||||
"type": "array",
|
||||
"title": "Sessions"
|
||||
},
|
||||
"total": { "type": "integer", "title": "Total" },
|
||||
"limit": { "type": "integer", "title": "Limit" },
|
||||
"offset": { "type": "integer", "title": "Offset" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["sessions", "total", "limit", "offset"],
|
||||
"title": "SessionListResponse",
|
||||
"description": "Response model for session list."
|
||||
},
|
||||
"SetGraphActiveVersion": {
|
||||
"properties": {
|
||||
"active_graph_version": {
|
||||
@@ -9621,6 +10112,7 @@
|
||||
"scheme": "bearer",
|
||||
"bearerFormat": "jwt"
|
||||
},
|
||||
"HTTPBearer": { "type": "http", "scheme": "bearer" },
|
||||
"APIKeyAuthenticator-X-Postmark-Webhook-Token": {
|
||||
"type": "apiKey",
|
||||
"in": "header",
|
||||
|
||||
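As a quick orientation for the non-streaming flow documented above, here is a minimal TypeScript sketch of a client call against the send-message endpoint. The bearer-token handling and same-origin base URL are assumptions; the path and the SendMessageRequest / SendMessageResponse shapes come from the spec.

// Minimal sketch: calling the non-streaming send-message endpoint above.
// Token handling and same-origin base URL are assumptions; the path and the
// SendMessageRequest / SendMessageResponse shapes come from the spec.
interface SendMessageResponse {
  message_id: string;
  content: string;
  role: string;
  tokens_used?: Record<string, unknown> | null;
}

async function sendChatMessage(
  token: string,
  sessionId: string,
  message: string,
): Promise<SendMessageResponse> {
  const res = await fetch(`/api/v2/chat/sessions/${sessionId}/messages`, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
      Authorization: `Bearer ${token}`,
    },
    // model ("gpt-4o") and max_context_messages (50) fall back to their defaults
    body: JSON.stringify({ message }),
  });
  if (!res.ok) throw new Error(`Send failed: ${res.status}`);
  return res.json();
}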
@@ -2,6 +2,7 @@ import {
|
||||
ApiError,
|
||||
makeAuthenticatedFileUpload,
|
||||
makeAuthenticatedRequest,
|
||||
getServerAuthToken,
|
||||
} from "@/lib/autogpt-server-api/helpers";
|
||||
import { getAgptServerBaseUrl } from "@/lib/env-config";
|
||||
import { NextRequest, NextResponse } from "next/server";
|
||||
@@ -161,6 +162,58 @@ async function handler(
|
||||
const method = req.method;
|
||||
const contentType = req.headers.get("Content-Type");
|
||||
|
||||
// Special handling for SSE streaming endpoints
|
||||
const isStreamingEndpoint = path.some((segment) => segment === "stream");
|
||||
if (isStreamingEndpoint && method === "GET") {
|
||||
try {
|
||||
const token = await getServerAuthToken();
|
||||
const headers: HeadersInit = {};
|
||||
|
||||
if (token && token !== "no-token-found") {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
// Forward the request to the backend
|
||||
const response = await fetch(backendUrl, {
|
||||
method: "GET",
|
||||
headers,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
return NextResponse.json(
|
||||
{ error: `Failed to stream: ${error}` },
|
||||
{ status: response.status },
|
||||
);
|
||||
}
|
||||
|
||||
// Stream the SSE response directly
|
||||
const stream = response.body;
|
||||
if (!stream) {
|
||||
return NextResponse.json(
|
||||
{ error: "No stream available" },
|
||||
{ status: 502 },
|
||||
);
|
||||
}
|
||||
|
||||
// Return the streaming response with proper SSE headers
|
||||
return new NextResponse(stream, {
|
||||
status: 200,
|
||||
headers: {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
Connection: "keep-alive",
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
return createErrorResponse(
|
||||
error,
|
||||
path.map((s) => `/${s}`).join(""),
|
||||
method,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let responseBody: any;
|
||||
const responseHeaders: Record<string, string> = {
|
||||
"Content-Type": "application/json",
|
||||
|
||||
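The proxy hunk above forwards the backend SSE response with text/event-stream headers, so the browser can consume it with a plain EventSource (the proxy attaches the auth token server-side, which is why no Authorization header is needed here). A small sketch follows; the stream path and query parameters (message, model, max_context) come from the spec, while the proxy route prefix and the per-event payload shape are assumptions.

// Sketch: consuming the proxied SSE stream from the browser.
// Path and query params come from the OpenAPI spec above; the proxy route
// prefix and the shape of each event's data payload are assumptions.
function streamChat(sessionId: string, message: string): EventSource {
  const params = new URLSearchParams({
    message,
    model: "gpt-4o",
    max_context: "50",
  });
  const source = new EventSource(
    // assumed Next.js proxy route prefix in front of the backend path
    `/api/proxy/api/v2/chat/sessions/${sessionId}/stream?${params}`,
  );

  source.onmessage = (event) => {
    // Each SSE "data:" line is expected to carry a JSON chunk
    // (text delta, tool call UI element, or tool result).
    const chunk = JSON.parse(event.data);
    console.log("chunk", chunk);
  };

  source.onerror = () => {
    // EventSource reconnects automatically; close explicitly when done.
    source.close();
  };

  return source;
}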
autogpt_platform/frontend/src/components/chat/AgentCarousel.tsx (new file, 272 lines)
@@ -0,0 +1,272 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState, useRef, useEffect } from "react";
|
||||
import Image from "next/image";
|
||||
import {
|
||||
ChevronLeft,
|
||||
ChevronRight,
|
||||
Star,
|
||||
Play,
|
||||
Info,
|
||||
Sparkles,
|
||||
} from "lucide-react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface Agent {
|
||||
id: string;
|
||||
name: string;
|
||||
sub_heading: string;
|
||||
description: string;
|
||||
creator: string;
|
||||
creator_avatar?: string;
|
||||
agent_image?: string;
|
||||
rating?: number;
|
||||
runs?: number;
|
||||
}
|
||||
|
||||
interface AgentCarouselProps {
|
||||
agents: Agent[];
|
||||
query: string;
|
||||
onSelectAgent: (agent: Agent) => void;
|
||||
onGetDetails: (agent: Agent) => void;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function AgentCarousel({
|
||||
agents,
|
||||
query,
|
||||
onSelectAgent,
|
||||
onGetDetails,
|
||||
className,
|
||||
}: AgentCarouselProps) {
|
||||
const [currentIndex, setCurrentIndex] = useState(0);
|
||||
const [isAutoScrolling, setIsAutoScrolling] = useState(true);
|
||||
const carouselRef = useRef<HTMLDivElement>(null);
|
||||
const scrollContainerRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Deduplicate agents by ID
|
||||
const uniqueAgents = React.useMemo(() => {
|
||||
const seen = new Set<string>();
|
||||
return agents.filter((agent) => {
|
||||
if (seen.has(agent.id)) {
|
||||
return false;
|
||||
}
|
||||
seen.add(agent.id);
|
||||
return true;
|
||||
});
|
||||
}, [agents]);
|
||||
|
||||
// Auto-scroll effect
|
||||
useEffect(() => {
|
||||
if (!isAutoScrolling || uniqueAgents.length <= 3) return;
|
||||
|
||||
const timer = setInterval(() => {
|
||||
setCurrentIndex(
|
||||
(prev) => (prev + 1) % Math.max(1, uniqueAgents.length - 2),
|
||||
);
|
||||
}, 5000);
|
||||
|
||||
return () => clearInterval(timer);
|
||||
}, [isAutoScrolling, uniqueAgents.length]);
|
||||
|
||||
// Scroll to current index
|
||||
useEffect(() => {
|
||||
if (scrollContainerRef.current) {
|
||||
const cardWidth = 320; // Approximate card width including gap
|
||||
scrollContainerRef.current.scrollTo({
|
||||
left: currentIndex * cardWidth,
|
||||
behavior: "smooth",
|
||||
});
|
||||
}
|
||||
}, [currentIndex]);
|
||||
|
||||
const handlePrevious = () => {
|
||||
setIsAutoScrolling(false);
|
||||
setCurrentIndex((prev) => Math.max(0, prev - 1));
|
||||
};
|
||||
|
||||
const handleNext = () => {
|
||||
setIsAutoScrolling(false);
|
||||
setCurrentIndex((prev) =>
|
||||
Math.min(Math.max(0, uniqueAgents.length - 3), prev + 1),
|
||||
);
|
||||
};
|
||||
|
||||
const handleDotClick = (index: number) => {
|
||||
setIsAutoScrolling(false);
|
||||
setCurrentIndex(index);
|
||||
};
|
||||
|
||||
if (!uniqueAgents || uniqueAgents.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const maxVisibleIndex = Math.max(0, uniqueAgents.length - 3);
|
||||
|
||||
return (
|
||||
<div className={cn("my-6 space-y-4", className)} ref={carouselRef}>
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between px-4">
|
||||
<div className="flex items-center gap-2">
|
||||
<Sparkles className="h-5 w-5 text-violet-600" />
|
||||
<h3 className="text-base font-semibold text-neutral-900 dark:text-neutral-100">
|
||||
Found {uniqueAgents.length} agents for “{query}”
|
||||
</h3>
|
||||
</div>
|
||||
{uniqueAgents.length > 3 && (
|
||||
<div className="flex items-center gap-2">
|
||||
<Button
|
||||
onClick={handlePrevious}
|
||||
variant="secondary"
|
||||
size="small"
|
||||
disabled={currentIndex === 0}
|
||||
className="p-1"
|
||||
>
|
||||
<ChevronLeft className="h-4 w-4" />
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleNext}
|
||||
variant="secondary"
|
||||
size="small"
|
||||
disabled={currentIndex >= maxVisibleIndex}
|
||||
className="p-1"
|
||||
>
|
||||
<ChevronRight className="h-4 w-4" />
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Carousel Container */}
|
||||
<div className="relative overflow-hidden px-4">
|
||||
<div
|
||||
ref={scrollContainerRef}
|
||||
className="scrollbar-hide flex gap-4 overflow-x-auto scroll-smooth"
|
||||
style={{ scrollbarWidth: "none", msOverflowStyle: "none" }}
|
||||
>
|
||||
{uniqueAgents.map((agent) => (
|
||||
<div
|
||||
key={agent.id}
|
||||
className={cn(
|
||||
"w-[300px] flex-shrink-0",
|
||||
"group relative overflow-hidden rounded-xl",
|
||||
"border border-neutral-200 dark:border-neutral-700",
|
||||
"bg-white dark:bg-neutral-900",
|
||||
"transition-all duration-300",
|
||||
"hover:scale-[1.02] hover:shadow-xl",
|
||||
"animate-in fade-in-50 slide-in-from-bottom-2",
|
||||
)}
|
||||
>
|
||||
{/* Agent Image Header */}
|
||||
<div className="relative h-32 bg-gradient-to-br from-violet-500/20 via-purple-500/20 to-indigo-500/20">
|
||||
{agent.agent_image ? (
|
||||
<Image
|
||||
src={agent.agent_image}
|
||||
alt={agent.name}
|
||||
fill
|
||||
className="object-cover opacity-90"
|
||||
/>
|
||||
) : (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<div className="text-4xl">🤖</div>
|
||||
</div>
|
||||
)}
|
||||
{agent.rating && (
|
||||
<div className="absolute right-2 top-2 flex items-center gap-1 rounded-full bg-black/50 px-2 py-1 backdrop-blur">
|
||||
<Star className="h-3 w-3 fill-yellow-400 text-yellow-400" />
|
||||
<span className="text-xs font-medium text-white">
|
||||
{agent.rating.toFixed(1)}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Agent Content */}
|
||||
<div className="space-y-3 p-4">
|
||||
<div>
|
||||
<h4 className="line-clamp-1 font-semibold text-neutral-900 dark:text-neutral-100">
|
||||
{agent.name}
|
||||
</h4>
|
||||
{agent.sub_heading && (
|
||||
<p className="mt-0.5 line-clamp-1 text-xs text-violet-600 dark:text-violet-400">
|
||||
{agent.sub_heading}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<p className="line-clamp-2 text-sm text-neutral-600 dark:text-neutral-400">
|
||||
{agent.description}
|
||||
</p>
|
||||
|
||||
{/* Creator Info */}
|
||||
<div className="flex items-center gap-2 text-xs text-neutral-500">
|
||||
{agent.creator_avatar ? (
|
||||
<div className="relative h-4 w-4">
|
||||
<Image
|
||||
src={agent.creator_avatar}
|
||||
alt={agent.creator}
|
||||
fill
|
||||
className="rounded-full object-cover"
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<div className="h-4 w-4 rounded-full bg-neutral-300 dark:bg-neutral-600" />
|
||||
)}
|
||||
<span>by {agent.creator}</span>
|
||||
{agent.runs && (
|
||||
<>
|
||||
<span className="text-neutral-400">•</span>
|
||||
<span>{agent.runs.toLocaleString()} runs</span>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Action Buttons */}
|
||||
<div className="flex gap-2 pt-2">
|
||||
<Button
|
||||
onClick={() => onGetDetails(agent)}
|
||||
variant="secondary"
|
||||
size="small"
|
||||
className="flex-1"
|
||||
>
|
||||
<Info className="mr-1 h-3 w-3" />
|
||||
Details
|
||||
</Button>
|
||||
<Button
|
||||
onClick={() => onSelectAgent(agent)}
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="flex-1"
|
||||
>
|
||||
<Play className="mr-1 h-3 w-3" />
|
||||
Set Up
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Pagination Dots */}
|
||||
{uniqueAgents.length > 3 && (
|
||||
<div className="flex justify-center gap-1.5 px-4">
|
||||
{Array.from({ length: maxVisibleIndex + 1 }).map((_, index) => (
|
||||
<button
|
||||
key={index}
|
||||
onClick={() => handleDotClick(index)}
|
||||
className={cn(
|
||||
"h-1.5 rounded-full transition-all duration-300",
|
||||
index === currentIndex
|
||||
? "w-6 bg-violet-600"
|
||||
: "w-1.5 bg-neutral-300 hover:bg-neutral-400 dark:bg-neutral-600",
|
||||
)}
|
||||
aria-label={`Go to slide ${index + 1}`}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
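For reference, a hedged usage sketch of the carousel above. The agent objects only need to satisfy the local Agent interface; the wrapper component and its handlers are placeholders, not part of the diff.

// Sketch: rendering AgentCarousel from a marketplace search result.
import { AgentCarousel } from "@/components/chat/AgentCarousel";

export function SearchResults({
  agents,
  query,
}: {
  agents: any[];
  query: string;
}) {
  return (
    <AgentCarousel
      agents={agents}
      query={query}
      onSelectAgent={(agent) => console.log("set up", agent.id)} // placeholder handler
      onGetDetails={(agent) => console.log("details", agent.id)} // placeholder handler
    />
  );
}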
@@ -0,0 +1,134 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import { Star, Download, ArrowRight, Info } from "lucide-react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface Agent {
|
||||
id: string;
|
||||
version: string;
|
||||
name: string;
|
||||
description: string;
|
||||
creator?: string;
|
||||
rating?: number;
|
||||
runs?: number;
|
||||
downloads?: number;
|
||||
categories?: string[];
|
||||
}
|
||||
|
||||
interface AgentDiscoveryCardProps {
|
||||
agents: Agent[];
|
||||
onSelectAgent: (agent: Agent) => void;
|
||||
onGetDetails: (agent: Agent) => void;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function AgentDiscoveryCard({
|
||||
agents,
|
||||
onSelectAgent,
|
||||
onGetDetails,
|
||||
className,
|
||||
}: AgentDiscoveryCardProps) {
|
||||
if (!agents || agents.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className={cn("my-4 space-y-3", className)}>
|
||||
<div className="text-sm font-medium text-neutral-700 dark:text-neutral-300">
|
||||
🎯 Recommended Agents for You:
|
||||
</div>
|
||||
|
||||
<div className="grid gap-3 md:grid-cols-2 lg:grid-cols-3">
|
||||
{agents.slice(0, 3).map((agent) => (
|
||||
<div
|
||||
key={`${agent.id}-${agent.version}`}
|
||||
className={cn(
|
||||
"group relative overflow-hidden rounded-lg border",
|
||||
"border-neutral-200 dark:border-neutral-700",
|
||||
"bg-white dark:bg-neutral-900",
|
||||
"transition-all duration-300 hover:shadow-lg",
|
||||
"animate-in fade-in-50 slide-in-from-bottom-2",
|
||||
)}
|
||||
>
|
||||
<div className="bg-gradient-to-br from-violet-500/10 to-purple-500/10 p-4">
|
||||
<div className="mb-2 flex items-start justify-between">
|
||||
<h3 className="font-semibold text-neutral-900 dark:text-neutral-100">
|
||||
{agent.name}
|
||||
</h3>
|
||||
{agent.rating && (
|
||||
<div className="flex items-center gap-1 text-sm">
|
||||
<Star className="h-3.5 w-3.5 fill-yellow-400 text-yellow-400" />
|
||||
<span className="text-neutral-600 dark:text-neutral-400">
|
||||
{agent.rating.toFixed(1)}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<p className="mb-3 line-clamp-2 text-sm text-neutral-600 dark:text-neutral-400">
|
||||
{agent.description}
|
||||
</p>
|
||||
|
||||
{agent.creator && (
|
||||
<p className="mb-2 text-xs text-neutral-500 dark:text-neutral-500">
|
||||
by {agent.creator}
|
||||
</p>
|
||||
)}
|
||||
|
||||
<div className="mb-3 flex items-center gap-3 text-xs text-neutral-500 dark:text-neutral-500">
|
||||
{agent.runs && (
|
||||
<div className="flex items-center gap-1">
|
||||
<Download className="h-3 w-3" />
|
||||
{agent.runs.toLocaleString()} runs
|
||||
</div>
|
||||
)}
|
||||
{agent.downloads && (
|
||||
<div className="flex items-center gap-1">
|
||||
<Download className="h-3 w-3" />
|
||||
{agent.downloads.toLocaleString()} downloads
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{agent.categories && agent.categories.length > 0 && (
|
||||
<div className="mb-3 flex flex-wrap gap-1">
|
||||
{agent.categories.slice(0, 3).map((category) => (
|
||||
<span
|
||||
key={category}
|
||||
className="rounded-full bg-neutral-100 px-2 py-0.5 text-xs text-neutral-600 dark:bg-neutral-800 dark:text-neutral-400"
|
||||
>
|
||||
{category}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex gap-2">
|
||||
<Button
|
||||
onClick={() => onGetDetails(agent)}
|
||||
variant="secondary"
|
||||
size="small"
|
||||
className="flex-1"
|
||||
>
|
||||
<Info className="mr-1 h-3 w-3" />
|
||||
Details
|
||||
</Button>
|
||||
<Button
|
||||
onClick={() => onSelectAgent(agent)}
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="flex-1"
|
||||
>
|
||||
Set Up
|
||||
<ArrowRight className="ml-1 h-3 w-3" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
autogpt_platform/frontend/src/components/chat/AgentRunCard.tsx (new file, 200 lines)
@@ -0,0 +1,200 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
Play,
|
||||
ExternalLink,
|
||||
Clock,
|
||||
CheckCircle,
|
||||
XCircle,
|
||||
Loader2,
|
||||
} from "lucide-react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import Link from "next/link";
|
||||
|
||||
interface AgentRunCardProps {
|
||||
executionId: string;
|
||||
graphId: string;
|
||||
graphName: string;
|
||||
status: string;
|
||||
inputs?: Record<string, any>;
|
||||
outputs?: Record<string, any>;
|
||||
error?: string;
|
||||
timeoutReached?: boolean;
|
||||
endedAt?: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function AgentRunCard({
|
||||
executionId,
|
||||
graphId,
|
||||
graphName,
|
||||
status,
|
||||
inputs,
|
||||
outputs,
|
||||
error,
|
||||
timeoutReached,
|
||||
endedAt,
|
||||
className,
|
||||
}: AgentRunCardProps) {
|
||||
const getStatusIcon = () => {
|
||||
switch (status) {
|
||||
case "COMPLETED":
|
||||
return <CheckCircle className="h-5 w-5 text-green-600" />;
|
||||
case "FAILED":
|
||||
case "ERROR":
|
||||
return <XCircle className="h-5 w-5 text-red-600" />;
|
||||
case "RUNNING":
|
||||
case "EXECUTING":
|
||||
return <Loader2 className="h-5 w-5 animate-spin text-blue-600" />;
|
||||
case "QUEUED":
|
||||
return <Clock className="h-5 w-5 text-amber-600" />;
|
||||
default:
|
||||
return <Play className="h-5 w-5 text-neutral-600" />;
|
||||
}
|
||||
};
|
||||
|
||||
const getStatusColor = () => {
|
||||
switch (status) {
|
||||
case "COMPLETED":
|
||||
return "border-green-200 bg-green-50 dark:border-green-800 dark:bg-green-950/30";
|
||||
case "FAILED":
|
||||
case "ERROR":
|
||||
return "border-red-200 bg-red-50 dark:border-red-800 dark:bg-red-950/30";
|
||||
case "RUNNING":
|
||||
case "EXECUTING":
|
||||
return "border-blue-200 bg-blue-50 dark:border-blue-800 dark:bg-blue-950/30";
|
||||
case "QUEUED":
|
||||
return "border-amber-200 bg-amber-50 dark:border-amber-800 dark:bg-amber-950/30";
|
||||
default:
|
||||
return "border-neutral-200 bg-neutral-50 dark:border-neutral-800 dark:bg-neutral-950/30";
|
||||
}
|
||||
};
|
||||
|
||||
const formatValue = (value: any): string => {
|
||||
if (value === null || value === undefined) return "null";
|
||||
if (typeof value === "object") {
|
||||
try {
|
||||
return JSON.stringify(value, null, 2);
|
||||
} catch {
|
||||
return String(value);
|
||||
}
|
||||
}
|
||||
return String(value);
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"my-4 overflow-hidden rounded-lg border transition-all duration-300",
|
||||
getStatusColor(),
|
||||
"animate-in fade-in-50 slide-in-from-bottom-2",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="px-6 py-5">
|
||||
{/* Header */}
|
||||
<div className="mb-4 flex items-start justify-between">
|
||||
<div className="flex items-center gap-3">
|
||||
{getStatusIcon()}
|
||||
<div>
|
||||
<h3 className="text-lg font-semibold text-neutral-900 dark:text-neutral-100">
|
||||
{graphName}
|
||||
</h3>
|
||||
<p className="text-sm text-neutral-600 dark:text-neutral-400">
|
||||
Execution ID: {executionId.slice(0, 8)}...
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<Link
|
||||
href={`/library/agents/${graphId}`}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
<Button variant="outline" size="small">
|
||||
<ExternalLink className="mr-2 h-4 w-4" />
|
||||
Go to Run
|
||||
</Button>
|
||||
</Link>
|
||||
</div>
|
||||
|
||||
{/* Input Data */}
|
||||
{inputs && Object.keys(inputs).length > 0 && (
|
||||
<div className="mb-4">
|
||||
<h4 className="mb-2 text-sm font-medium text-neutral-700 dark:text-neutral-300">
|
||||
Input Data:
|
||||
</h4>
|
||||
<div className="rounded-md bg-white/50 p-3 dark:bg-neutral-900/50">
|
||||
{Object.entries(inputs).map(([key, value]) => (
|
||||
<div key={key} className="mb-2 last:mb-0">
|
||||
<span className="font-mono text-xs text-neutral-600 dark:text-neutral-400">
|
||||
{key}:
|
||||
</span>
|
||||
<pre className="mt-1 overflow-x-auto rounded bg-neutral-100 p-2 text-xs text-neutral-800 dark:bg-neutral-800 dark:text-neutral-200">
|
||||
{formatValue(value)}
|
||||
</pre>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Output Data */}
|
||||
{outputs && Object.keys(outputs).length > 0 && (
|
||||
<div className="mb-4">
|
||||
<h4 className="mb-2 text-sm font-medium text-neutral-700 dark:text-neutral-300">
|
||||
Output Data:
|
||||
</h4>
|
||||
<div className="rounded-md bg-white/50 p-3 dark:bg-neutral-900/50">
|
||||
{Object.entries(outputs).map(([key, value]) => (
|
||||
<div key={key} className="mb-2 last:mb-0">
|
||||
<span className="font-mono text-xs text-neutral-600 dark:text-neutral-400">
|
||||
{key}:
|
||||
</span>
|
||||
<pre className="mt-1 overflow-x-auto rounded bg-neutral-100 p-2 text-xs text-neutral-800 dark:bg-neutral-800 dark:text-neutral-200">
|
||||
{formatValue(value)}
|
||||
</pre>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Error Message */}
|
||||
{error && (
|
||||
<div className="mb-4 rounded-md border border-red-200 bg-red-50 p-3 dark:border-red-800 dark:bg-red-950/50">
|
||||
<p className="text-sm text-red-700 dark:text-red-300">
|
||||
<strong>Error:</strong> {error}
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Timeout Message */}
|
||||
{timeoutReached && (
|
||||
<div className="mb-4 rounded-md border border-amber-200 bg-amber-50 p-3 dark:border-amber-800 dark:bg-amber-950/50">
|
||||
<p className="text-sm text-amber-700 dark:text-amber-300">
|
||||
<strong>Note:</strong> Execution timed out after 30 seconds. The
|
||||
agent may still be running in the background.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Status Bar */}
|
||||
<div className="flex items-center justify-between text-xs">
|
||||
<span className="font-medium text-neutral-600 dark:text-neutral-400">
|
||||
Status:{" "}
|
||||
<span className="text-neutral-900 dark:text-neutral-100">
|
||||
{status}
|
||||
</span>
|
||||
</span>
|
||||
{endedAt && (
|
||||
<span className="text-neutral-500 dark:text-neutral-500">
|
||||
Ended: {new Date(endedAt).toLocaleString()}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
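A hedged usage sketch of AgentRunCard follows; the props mirror the AgentRunCardProps interface above, and all values are illustrative rather than taken from a real execution.

// Sketch: rendering a completed execution with AgentRunCard.
// All values are illustrative placeholders.
import { AgentRunCard } from "@/components/chat/AgentRunCard";

export function ExampleRun() {
  return (
    <AgentRunCard
      executionId="3f2c9b1e-0000-0000-0000-000000000000"
      graphId="graph-123"
      graphName="Daily Digest Agent"
      status="COMPLETED"
      inputs={{ topic: "AI news" }}
      outputs={{ summary: "Top stories of the day..." }}
      endedAt={new Date().toISOString()}
    />
  );
}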
autogpt_platform/frontend/src/components/chat/AgentSetupCard.tsx (new file, 253 lines)
@@ -0,0 +1,253 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import {
|
||||
CheckCircle,
|
||||
Calendar,
|
||||
Webhook,
|
||||
ExternalLink,
|
||||
Clock,
|
||||
PlayCircle,
|
||||
Library,
|
||||
} from "lucide-react";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface AgentSetupCardProps {
|
||||
status: string;
|
||||
triggerType: "schedule" | "webhook";
|
||||
name: string;
|
||||
graphId: string;
|
||||
graphVersion: number;
|
||||
scheduleId?: string;
|
||||
webhookUrl?: string;
|
||||
cron?: string;
|
||||
_cronUtc?: string;
|
||||
timezone?: string;
|
||||
nextRun?: string;
|
||||
addedToLibrary?: boolean;
|
||||
libraryId?: string;
|
||||
message: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function AgentSetupCard({
|
||||
status,
|
||||
triggerType,
|
||||
name,
|
||||
graphId,
|
||||
graphVersion,
|
||||
scheduleId,
|
||||
webhookUrl,
|
||||
cron,
|
||||
_cronUtc,
|
||||
timezone,
|
||||
nextRun,
|
||||
addedToLibrary,
|
||||
libraryId,
|
||||
message,
|
||||
className,
|
||||
}: AgentSetupCardProps) {
|
||||
const isSuccess = status === "success";
|
||||
|
||||
const formatNextRun = (isoString: string) => {
|
||||
try {
|
||||
const date = new Date(isoString);
|
||||
return date.toLocaleString();
|
||||
} catch {
|
||||
return isoString;
|
||||
}
|
||||
};
|
||||
|
||||
const handleViewInLibrary = () => {
|
||||
if (libraryId) {
|
||||
window.open(`/library/agents/${libraryId}`, "_blank");
|
||||
} else {
|
||||
window.open(`/library`, "_blank");
|
||||
}
|
||||
};
|
||||
|
||||
const handleViewRuns = () => {
|
||||
if (scheduleId) {
|
||||
window.open(`/library/runs?scheduleId=${scheduleId}`, "_blank");
|
||||
} else {
|
||||
window.open(`/library/runs`, "_blank");
|
||||
}
|
||||
};
|
||||
|
||||
const copyToClipboard = (text: string) => {
|
||||
navigator.clipboard.writeText(text).then(() => {
|
||||
// Could add a toast notification here
|
||||
console.log("Copied to clipboard:", text);
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"my-4 overflow-hidden rounded-lg border",
|
||||
isSuccess
|
||||
? "border-green-200 bg-gradient-to-br from-green-50 to-emerald-50 dark:border-green-800 dark:from-green-950/30 dark:to-emerald-950/30"
|
||||
: "border-red-200 bg-gradient-to-br from-red-50 to-rose-50 dark:border-red-800 dark:from-red-950/30 dark:to-rose-950/30",
|
||||
"duration-500 animate-in fade-in-50 slide-in-from-bottom-2",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="px-6 py-5">
|
||||
<div className="mb-4 flex items-center gap-3">
|
||||
<div
|
||||
className={cn(
|
||||
"flex h-10 w-10 items-center justify-center rounded-full",
|
||||
isSuccess ? "bg-green-600" : "bg-red-600",
|
||||
)}
|
||||
>
|
||||
{isSuccess ? (
|
||||
<CheckCircle className="h-5 w-5 text-white" />
|
||||
) : (
|
||||
<ExternalLink className="h-5 w-5 text-white" />
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
<h3 className="text-lg font-semibold text-neutral-900 dark:text-neutral-100">
|
||||
{isSuccess ? "Agent Setup Complete" : "Setup Failed"}
|
||||
</h3>
|
||||
<p className="text-sm text-neutral-600 dark:text-neutral-400">
|
||||
{name}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Success message */}
|
||||
<div className="mb-5 rounded-md bg-white/50 p-4 dark:bg-neutral-900/50">
|
||||
<p className="text-sm text-neutral-700 dark:text-neutral-300">
|
||||
{message}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{/* Setup details */}
|
||||
{isSuccess && (
|
||||
<div className="mb-5 space-y-3">
|
||||
{/* Trigger type badge */}
|
||||
<div className="flex items-center gap-2">
|
||||
{triggerType === "schedule" ? (
|
||||
<>
|
||||
<Calendar className="h-4 w-4 text-blue-600" />
|
||||
<span className="text-sm font-medium text-blue-700 dark:text-blue-400">
|
||||
Scheduled Execution
|
||||
</span>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Webhook className="h-4 w-4 text-purple-600" />
|
||||
<span className="text-sm font-medium text-purple-700 dark:text-purple-400">
|
||||
Webhook Trigger
|
||||
</span>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Schedule details */}
|
||||
{triggerType === "schedule" && cron && (
|
||||
<div className="space-y-2 rounded-md bg-blue-50 p-3 text-sm dark:bg-blue-950/30">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-neutral-600 dark:text-neutral-400">
|
||||
Schedule:
|
||||
</span>
|
||||
<code className="rounded bg-neutral-200 px-2 py-0.5 font-mono text-xs dark:bg-neutral-800">
|
||||
{cron}
|
||||
</code>
|
||||
</div>
|
||||
{timezone && (
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-neutral-600 dark:text-neutral-400">
|
||||
Timezone:
|
||||
</span>
|
||||
<span className="font-medium">{timezone}</span>
|
||||
</div>
|
||||
)}
|
||||
{nextRun && (
|
||||
<div className="flex items-center gap-2">
|
||||
<Clock className="h-4 w-4 text-blue-600" />
|
||||
<span className="text-neutral-600 dark:text-neutral-400">
|
||||
Next run:
|
||||
</span>
|
||||
<span className="font-medium">
|
||||
{formatNextRun(nextRun)}
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Webhook details */}
|
||||
{triggerType === "webhook" && webhookUrl && (
|
||||
<div className="space-y-2 rounded-md bg-purple-50 p-3 dark:bg-purple-950/30">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-sm text-neutral-600 dark:text-neutral-400">
|
||||
Webhook URL:
|
||||
</span>
|
||||
<button
|
||||
onClick={() => copyToClipboard(webhookUrl)}
|
||||
className="text-xs text-purple-600 hover:text-purple-700 hover:underline dark:text-purple-400"
|
||||
>
|
||||
Copy
|
||||
</button>
|
||||
</div>
|
||||
<code className="block break-all rounded bg-neutral-200 p-2 font-mono text-xs dark:bg-neutral-800">
|
||||
{webhookUrl}
|
||||
</code>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Library status */}
|
||||
{addedToLibrary && (
|
||||
<div className="flex items-center gap-2 text-sm text-green-700 dark:text-green-400">
|
||||
<Library className="h-4 w-4" />
|
||||
<span>Added to your library</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Action buttons */}
|
||||
{isSuccess && (
|
||||
<div className="flex gap-3">
|
||||
<Button
|
||||
onClick={handleViewInLibrary}
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="flex-1"
|
||||
>
|
||||
<Library className="mr-2 h-4 w-4" />
|
||||
View in Library
|
||||
</Button>
|
||||
{triggerType === "schedule" && (
|
||||
<Button
|
||||
onClick={handleViewRuns}
|
||||
variant="secondary"
|
||||
size="small"
|
||||
className="flex-1"
|
||||
>
|
||||
<PlayCircle className="mr-2 h-4 w-4" />
|
||||
View Runs
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Additional info */}
|
||||
<div className="mt-4 space-y-1 text-xs text-neutral-500 dark:text-neutral-500">
|
||||
<p>
|
||||
Agent ID: <span className="font-mono">{graphId}</span>
|
||||
</p>
|
||||
<p>Version: {graphVersion}</p>
|
||||
{scheduleId && (
|
||||
<p>
|
||||
Schedule ID: <span className="font-mono">{scheduleId}</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { LogIn, UserPlus, Shield } from "lucide-react";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface AuthPromptWidgetProps {
|
||||
message: string;
|
||||
sessionId: string;
|
||||
agentInfo?: {
|
||||
graph_id: string;
|
||||
name: string;
|
||||
trigger_type: string;
|
||||
};
|
||||
returnUrl?: string;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function AuthPromptWidget({
|
||||
message,
|
||||
sessionId,
|
||||
agentInfo,
|
||||
returnUrl = "/marketplace/discover",
|
||||
className,
|
||||
}: AuthPromptWidgetProps) {
|
||||
const router = useRouter();
|
||||
|
||||
const handleSignIn = () => {
|
||||
// Store session info to return after auth
|
||||
if (typeof window !== "undefined") {
|
||||
localStorage.setItem("pending_chat_session", sessionId);
|
||||
if (agentInfo) {
|
||||
localStorage.setItem("pending_agent_setup", JSON.stringify(agentInfo));
|
||||
}
|
||||
}
|
||||
|
||||
// Build return URL with session ID
|
||||
const returnUrlWithSession = `${returnUrl}?sessionId=${sessionId}`;
|
||||
const encodedReturnUrl = encodeURIComponent(returnUrlWithSession);
|
||||
router.push(`/login?returnUrl=${encodedReturnUrl}`);
|
||||
};
|
||||
|
||||
const handleSignUp = () => {
|
||||
// Store session info to return after auth
|
||||
if (typeof window !== "undefined") {
|
||||
localStorage.setItem("pending_chat_session", sessionId);
|
||||
if (agentInfo) {
|
||||
localStorage.setItem("pending_agent_setup", JSON.stringify(agentInfo));
|
||||
}
|
||||
}
|
||||
|
||||
// Build return URL with session ID
|
||||
const returnUrlWithSession = `${returnUrl}?sessionId=${sessionId}`;
|
||||
const encodedReturnUrl = encodeURIComponent(returnUrlWithSession);
|
||||
router.push(`/signup?returnUrl=${encodedReturnUrl}`);
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"my-4 overflow-hidden rounded-lg border border-violet-200 dark:border-violet-800",
|
||||
"bg-gradient-to-br from-violet-50 to-purple-50 dark:from-violet-950/30 dark:to-purple-950/30",
|
||||
"duration-500 animate-in fade-in-50 slide-in-from-bottom-2",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="px-6 py-5">
|
||||
<div className="mb-4 flex items-center gap-3">
|
||||
<div className="flex h-10 w-10 items-center justify-center rounded-full bg-violet-600">
|
||||
<Shield className="h-5 w-5 text-white" />
|
||||
</div>
|
||||
<div>
|
||||
<h3 className="text-lg font-semibold text-neutral-900 dark:text-neutral-100">
|
||||
Authentication Required
|
||||
</h3>
|
||||
<p className="text-sm text-neutral-600 dark:text-neutral-400">
|
||||
Sign in to set up and manage agents
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mb-5 rounded-md bg-white/50 p-4 dark:bg-neutral-900/50">
|
||||
<p className="text-sm text-neutral-700 dark:text-neutral-300">
|
||||
{message}
|
||||
</p>
|
||||
{agentInfo && (
|
||||
<div className="mt-3 text-xs text-neutral-600 dark:text-neutral-400">
|
||||
<p>
|
||||
Ready to set up:{" "}
|
||||
<span className="font-medium">{agentInfo.name}</span>
|
||||
</p>
|
||||
<p>
|
||||
Type:{" "}
|
||||
<span className="font-medium">{agentInfo.trigger_type}</span>
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex gap-3">
|
||||
<Button
|
||||
onClick={handleSignIn}
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="flex-1"
|
||||
>
|
||||
<LogIn className="mr-2 h-4 w-4" />
|
||||
Sign In
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleSignUp}
|
||||
variant="secondary"
|
||||
size="small"
|
||||
className="flex-1"
|
||||
>
|
||||
<UserPlus className="mr-2 h-4 w-4" />
|
||||
Create Account
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
<div className="mt-4 text-center text-xs text-neutral-500 dark:text-neutral-500">
|
||||
Your chat session will be preserved after signing in
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
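The widget above stores the pending session in localStorage before redirecting to login. A sketch of the post-login half of that flow: reclaim the anonymous session via the assign-user endpoint from the spec. The localStorage keys and the path come from the code and spec above; the token plumbing and same-origin access are assumptions.

// Sketch: after login, claim the anonymous chat session saved by AuthPromptWidget.
// localStorage keys and the assign-user path come from the code/spec above;
// token handling and same-origin backend access are assumptions.
async function claimPendingChatSession(token: string): Promise<void> {
  const sessionId = localStorage.getItem("pending_chat_session");
  if (!sessionId) return;

  const res = await fetch(
    `/api/v2/chat/sessions/${sessionId}/assign-user`,
    {
      method: "PATCH",
      headers: { Authorization: `Bearer ${token}` },
    },
  );

  if (res.ok) {
    localStorage.removeItem("pending_chat_session");
    localStorage.removeItem("pending_agent_setup");
  }
}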
autogpt_platform/frontend/src/components/chat/ChatInput.tsx (new file, 129 lines)
@@ -0,0 +1,129 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState, useRef, useEffect, KeyboardEvent } from "react";
|
||||
import { Send, X } from "lucide-react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface ChatInputProps {
|
||||
onSendMessage: (message: string) => void;
|
||||
onStopStreaming?: () => void;
|
||||
isStreaming?: boolean;
|
||||
disabled?: boolean;
|
||||
placeholder?: string;
|
||||
maxLength?: number;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ChatInput({
|
||||
onSendMessage,
|
||||
onStopStreaming,
|
||||
isStreaming = false,
|
||||
disabled = false,
|
||||
placeholder = "Ask about AI agents or describe what you want to automate...",
|
||||
maxLength = 10000,
|
||||
className,
|
||||
}: ChatInputProps) {
|
||||
const [message, setMessage] = useState("");
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
// Auto-resize textarea
|
||||
useEffect(() => {
|
||||
if (textareaRef.current) {
|
||||
textareaRef.current.style.height = "auto";
|
||||
textareaRef.current.style.height = `${Math.min(textareaRef.current.scrollHeight, 200)}px`;
|
||||
}
|
||||
}, [message]);
|
||||
|
||||
const handleSubmit = () => {
|
||||
if (message.trim() && !disabled && !isStreaming) {
|
||||
onSendMessage(message.trim());
|
||||
setMessage("");
|
||||
}
|
||||
};
|
||||
|
||||
const handleKeyDown = (e: KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
if (e.key === "Enter" && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
handleSubmit();
|
||||
}
|
||||
};
|
||||
|
||||
const isDisabled = disabled || (isStreaming && !onStopStreaming);
|
||||
const charactersRemaining = maxLength - message.length;
|
||||
const showCharacterCount = message.length > maxLength * 0.8;
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"border-t border-neutral-200 bg-white dark:border-neutral-700 dark:bg-neutral-900",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="mx-auto max-w-4xl px-4 py-4">
|
||||
<div className="flex items-end gap-3">
|
||||
<div className="flex-1">
|
||||
<textarea
|
||||
ref={textareaRef}
|
||||
value={message}
|
||||
onChange={(e) => setMessage(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder={placeholder}
|
||||
disabled={isDisabled}
|
||||
maxLength={maxLength}
|
||||
className={cn(
|
||||
"w-full resize-none rounded-lg border border-neutral-300 dark:border-neutral-600",
|
||||
"bg-white px-4 py-3 dark:bg-neutral-800",
|
||||
"text-neutral-900 dark:text-neutral-100",
|
||||
"placeholder:text-neutral-500 dark:placeholder:text-neutral-400",
|
||||
"focus:border-violet-500 focus:outline-none focus:ring-2 focus:ring-violet-500/20",
|
||||
"disabled:cursor-not-allowed disabled:opacity-50",
|
||||
"max-h-[200px] min-h-[52px]",
|
||||
)}
|
||||
rows={1}
|
||||
/>
|
||||
{showCharacterCount && (
|
||||
<div
|
||||
className={cn(
|
||||
"mt-1 text-xs",
|
||||
charactersRemaining < 100
|
||||
? "text-red-500"
|
||||
: "text-neutral-500 dark:text-neutral-400",
|
||||
)}
|
||||
>
|
||||
{charactersRemaining} characters remaining
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{isStreaming && onStopStreaming ? (
|
||||
<Button
|
||||
onClick={onStopStreaming}
|
||||
variant="secondary"
|
||||
size="small"
|
||||
className="mb-[2px]"
|
||||
>
|
||||
<X className="mr-1 h-4 w-4" />
|
||||
Stop
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
onClick={handleSubmit}
|
||||
disabled={!message.trim() || isDisabled}
|
||||
variant="primary"
|
||||
size="small"
|
||||
className="mb-[2px]"
|
||||
>
|
||||
<Send className="mr-1 h-4 w-4" />
|
||||
Send
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="mt-2 text-xs text-neutral-500 dark:text-neutral-400">
|
||||
Press Enter to send, Shift+Enter for new line
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
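A hedged sketch of how ChatInput might be wired up; the streaming state and callbacks are assumed to live in the surrounding chat state (for example ChatInterface, whose diff is suppressed below), so the handlers here are placeholders.

// Sketch: wiring ChatInput to streaming state. The send/stop handlers are
// placeholders for whatever the surrounding chat container actually does.
import { useState } from "react";
import { ChatInput } from "@/components/chat/ChatInput";

export function ChatComposer() {
  const [isStreaming, setIsStreaming] = useState(false);

  return (
    <ChatInput
      isStreaming={isStreaming}
      onSendMessage={(text) => {
        setIsStreaming(true);
        console.log("send", text); // placeholder for the real streaming call
      }}
      onStopStreaming={() => setIsStreaming(false)}
    />
  );
}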
autogpt_platform/frontend/src/components/chat/ChatInterface.tsx (new file, 1343 lines; diff suppressed because it is too large)
autogpt_platform/frontend/src/components/chat/ChatMessage.tsx (new file, 133 lines)
@@ -0,0 +1,133 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import { ChatMessage as ChatMessageType } from "@/lib/autogpt-server-api/chat";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { User, Bot } from "lucide-react";
|
||||
|
||||
interface ChatMessageProps {
|
||||
message: ChatMessageType;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function ChatMessage({ message, className }: ChatMessageProps) {
|
||||
const isUser = message.role === "USER";
|
||||
const isAssistant = message.role === "ASSISTANT";
|
||||
const isSystem = message.role === "SYSTEM";
|
||||
const isTool = message.role === "TOOL";
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"flex gap-4 px-4 py-6",
|
||||
isUser && "justify-end",
|
||||
!isUser && "justify-start",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
{!isUser && (
|
||||
<div className="flex-shrink-0">
|
||||
<div className="flex h-8 w-8 items-center justify-center rounded-full bg-violet-600">
|
||||
<Bot className="h-5 w-5 text-white" />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div
|
||||
className={cn(
|
||||
"max-w-[70%] rounded-lg px-4 py-3",
|
||||
isUser && "bg-neutral-100 dark:bg-neutral-800",
|
||||
isAssistant &&
|
||||
"border border-neutral-200 bg-white dark:border-neutral-700 dark:bg-neutral-900",
|
||||
isSystem &&
|
||||
"border border-blue-200 bg-blue-50 dark:border-blue-800 dark:bg-blue-900/20",
|
||||
isTool &&
|
||||
"border border-green-200 bg-green-50 dark:border-green-800 dark:bg-green-900/20",
|
||||
)}
|
||||
>
|
||||
{isSystem && (
|
||||
<div className="mb-2 text-xs font-medium text-blue-600 dark:text-blue-400">
|
||||
System
|
||||
</div>
|
||||
)}
|
||||
|
||||
{isTool && (
|
||||
<div className="mb-2 text-xs font-medium text-green-600 dark:text-green-400">
|
||||
Tool Response
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="prose prose-sm dark:prose-invert max-w-none">
|
||||
{/* Simple markdown-like rendering without external dependencies */}
|
||||
<div className="whitespace-pre-wrap">
|
||||
{message.content.split("\n").map((line, index) => {
|
||||
// Basic markdown parsing
|
||||
if (line.startsWith("# ")) {
|
||||
return (
|
||||
<h1 key={index} className="mb-2 text-xl font-bold">
|
||||
{line.substring(2)}
|
||||
</h1>
|
||||
);
|
||||
} else if (line.startsWith("## ")) {
|
||||
return (
|
||||
<h2 key={index} className="mb-2 text-lg font-bold">
|
||||
{line.substring(3)}
|
||||
</h2>
|
||||
);
|
||||
} else if (line.startsWith("### ")) {
|
||||
return (
|
||||
<h3 key={index} className="mb-2 text-base font-bold">
|
||||
{line.substring(4)}
|
||||
</h3>
|
||||
);
|
||||
} else if (line.startsWith("- ")) {
|
||||
return (
|
||||
<li key={index} className="ml-4 list-disc">
|
||||
{line.substring(2)}
|
||||
</li>
|
||||
);
|
||||
} else if (line.startsWith("```")) {
|
||||
return (
|
||||
<pre
|
||||
key={index}
|
||||
className="my-2 overflow-x-auto rounded bg-neutral-100 p-2 dark:bg-neutral-800"
|
||||
>
|
||||
<code>{line.substring(3)}</code>
|
||||
</pre>
|
||||
);
|
||||
} else if (line.trim() === "") {
|
||||
return <br key={index} />;
|
||||
} else {
|
||||
return (
|
||||
<p key={index} className="mb-1">
|
||||
{line}
|
||||
</p>
|
||||
);
|
||||
}
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{message.tokens && (
|
||||
<div className="mt-3 flex gap-3 text-xs text-neutral-500 dark:text-neutral-400">
|
||||
{message.tokens.prompt && (
|
||||
<span>Prompt: {message.tokens.prompt}</span>
|
||||
)}
|
||||
{message.tokens.completion && (
|
||||
<span>Completion: {message.tokens.completion}</span>
|
||||
)}
|
||||
{message.tokens.total && <span>Total: {message.tokens.total}</span>}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{isUser && (
|
||||
<div className="flex-shrink-0">
|
||||
<div className="flex h-8 w-8 items-center justify-center rounded-full bg-neutral-600">
|
||||
<User className="h-5 w-5 text-white" />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState } from "react";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import { Key, CheckCircle, AlertCircle, Loader2 } from "lucide-react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { CredentialsInput } from "@/app/(platform)/library/agents/[id]/components/AgentRunsView/components/CredentialsInputs/CredentialsInputs";
|
||||
import {
|
||||
BlockIOCredentialsSubSchema,
|
||||
CredentialsMetaInput,
|
||||
} from "@/lib/autogpt-server-api/types";
|
||||
|
||||
interface CredentialsSetupWidgetProps {
|
||||
agentInfo: {
|
||||
id: string;
|
||||
name: string;
|
||||
graph_id?: string;
|
||||
};
|
||||
credentialsSchema: any; // This will be the credentials_input_schema from the agent
|
||||
onCredentialsSubmit?: (
|
||||
credentials: Record<string, CredentialsMetaInput>,
|
||||
) => void;
|
||||
onSkip?: () => void;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export function CredentialsSetupWidget({
|
||||
agentInfo,
|
||||
credentialsSchema,
|
||||
onCredentialsSubmit,
|
||||
onSkip,
|
||||
className,
|
||||
}: CredentialsSetupWidgetProps) {
|
||||
const [selectedCredentials, setSelectedCredentials] = useState<
|
||||
Record<string, CredentialsMetaInput>
|
||||
>({});
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
// Parse the credentials schema to extract individual credential requirements
|
||||
// Handle both nested (with properties) and flat schema structures
|
||||
const schemaProperties =
|
||||
credentialsSchema?.properties || credentialsSchema || {};
|
||||
const credentialKeys = Object.keys(schemaProperties);
|
||||
const allCredentialsSelected = credentialKeys.every(
|
||||
(key) => selectedCredentials[key],
|
||||
);
|
||||
|
||||
const handleCredentialSelect = (
|
||||
key: string,
|
||||
credential?: CredentialsMetaInput,
|
||||
) => {
|
||||
if (credential) {
|
||||
setSelectedCredentials((prev) => ({
|
||||
...prev,
|
||||
[key]: credential,
|
||||
}));
|
||||
setError(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubmit = async () => {
|
||||
if (!allCredentialsSelected) {
|
||||
setError("Please provide all required credentials");
|
||||
return;
|
||||
}
|
||||
|
||||
setIsSubmitting(true);
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
if (onCredentialsSubmit) {
|
||||
await onCredentialsSubmit(selectedCredentials);
|
||||
}
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to set up credentials",
|
||||
);
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"my-4 overflow-hidden rounded-lg border border-amber-200 dark:border-amber-800",
|
||||
"bg-gradient-to-br from-amber-50 to-orange-50 dark:from-amber-950/30 dark:to-orange-950/30",
|
||||
"duration-500 animate-in fade-in-50 slide-in-from-bottom-2",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<div className="px-6 py-5">
|
||||
<div className="mb-4 flex items-center gap-3">
|
||||
<div className="flex h-10 w-10 items-center justify-center rounded-full bg-amber-600">
|
||||
<AlertCircle className="h-5 w-5 text-white" />
|
||||
</div>
|
||||
<div>
|
||||
<h3 className="text-lg font-semibold text-neutral-900 dark:text-neutral-100">
|
||||
Credentials Required
|
||||
</h3>
|
||||
<p className="text-sm text-neutral-600 dark:text-neutral-400">
|
||||
The agent "{agentInfo.name}" requires credentials to
|
||||
run. Please provide the following:
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Credentials inputs */}
|
||||
<div className="space-y-4">
|
||||
{credentialKeys.map((key) => {
|
||||
const schema = schemaProperties[key] as BlockIOCredentialsSubSchema;
|
||||
const isSelected = !!selectedCredentials[key];
|
||||
|
||||
return (
|
||||
<div
|
||||
key={key}
|
||||
className={cn(
|
||||
"relative rounded-lg bg-white/50 p-4 dark:bg-neutral-900/50",
|
||||
isSelected && "bg-green-50 dark:bg-green-950/30",
|
||||
)}
|
||||
>
|
||||
<CredentialsInput
|
||||
schema={schema}
|
||||
selectedCredentials={selectedCredentials[key]}
|
||||
onSelectCredentials={(cred) =>
|
||||
handleCredentialSelect(key, cred)
|
||||
}
|
||||
hideIfSingleCredentialAvailable={false}
|
||||
/>
|
||||
{isSelected && (
|
||||
<div className="absolute right-4 top-4">
|
||||
<CheckCircle className="h-5 w-5 text-green-600" />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="mt-4 rounded-lg border border-red-200 bg-red-50 p-3 text-sm text-red-700 dark:border-red-800 dark:bg-red-950/30 dark:text-red-300">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="mt-6 flex gap-2">
|
||||
<Button
|
||||
onClick={handleSubmit}
|
||||
disabled={!allCredentialsSelected || isSubmitting}
|
||||
className="flex-1"
|
||||
>
|
||||
{isSubmitting ? (
|
||||
<>
|
||||
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||
Setting up...
|
||||
</>
|
||||
) : (
|
||||
"Continue with Setup"
|
||||
)}
|
||||
</Button>
|
||||
{onSkip && (
|
||||
<Button variant="outline" onClick={onSkip} disabled={isSubmitting}>
|
||||
Skip for now
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{!allCredentialsSelected && (
|
||||
<div className="mt-4 flex items-center gap-2 rounded-md bg-amber-100 p-3 text-xs text-amber-700 dark:bg-amber-900/30 dark:text-amber-300">
|
||||
<Key className="h-4 w-4 flex-shrink-0" />
|
||||
<span>
|
||||
You need to configure all required credentials before this agent
|
||||
can be set up.
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,347 @@
"use client";

import React, { useMemo } from "react";
import { cn } from "@/lib/utils";
import { User, Bot } from "lucide-react";
import { ToolCallWidget } from "./ToolCallWidget";
import { AgentCarousel } from "./AgentCarousel";
import { CredentialsSetupWidget } from "./CredentialsSetupWidget";
import { AgentSetupCard } from "./AgentSetupCard";
import { AgentRunCard } from "./AgentRunCard";

interface ContentSegment {
  type:
    | "text"
    | "tool"
    | "carousel"
    | "credentials_setup"
    | "agent_setup"
    | "agent_run"
    | "auth_required";
  content: any;
  id?: string;
}

interface StreamingMessageProps {
  role: "USER" | "ASSISTANT" | "SYSTEM" | "TOOL";
  segments: ContentSegment[];
  className?: string;
  onSelectAgent?: (agent: any) => void;
  onGetAgentDetails?: (agent: any) => void;
}

export function StreamingMessage({
  role,
  segments,
  className,
  onSelectAgent,
  onGetAgentDetails,
}: StreamingMessageProps) {
  const isUser = role === "USER";
  const isAssistant = role === "ASSISTANT";
  const isSystem = role === "SYSTEM";
  const isTool = role === "TOOL";

  // Helper function to parse text with ** markers for bold
  const parseTextWithBold = (text: string): React.ReactNode[] => {
    const parts: React.ReactNode[] = [];
    let currentIndex = 0;
    let isInBold = false;
    let boldStartIndex = -1;

    while (currentIndex < text.length) {
      const nextMarkerIndex = text.indexOf("**", currentIndex);

      if (nextMarkerIndex === -1) {
        // No more markers, add the rest of the text
        if (isInBold && boldStartIndex !== -1) {
          // We're in bold but no closing marker found, treat the opening ** as regular text
          parts.push(text.substring(boldStartIndex));
        } else {
          parts.push(text.substring(currentIndex));
        }
        break;
      }

      if (isInBold) {
        // This is a closing marker
        const boldText = text.substring(currentIndex, nextMarkerIndex);
        parts.push(<strong key={`bold-${parts.length}`}>{boldText}</strong>);
        isInBold = false;
        currentIndex = nextMarkerIndex + 2;
        boldStartIndex = -1;
      } else {
        // This is an opening marker
        if (nextMarkerIndex > currentIndex) {
          // Add text before the marker
          parts.push(text.substring(currentIndex, nextMarkerIndex));
        }
        isInBold = true;
        boldStartIndex = nextMarkerIndex;
        currentIndex = nextMarkerIndex + 2;
      }
    }

    return parts;
  };

  // Process segments to combine consecutive text segments
  const processedSegments = useMemo(() => {
    const result: ContentSegment[] = [];
    let currentText = "";

    segments.forEach((segment) => {
      if (segment.type === "text") {
        currentText += segment.content;
      } else {
        // Flush any accumulated text
        if (currentText) {
          result.push({ type: "text", content: currentText });
          currentText = "";
        }
        // Add the non-text segment
        result.push(segment);
      }
    });

    // Flush remaining text
    if (currentText) {
      result.push({ type: "text", content: currentText });
    }

    return result;
  }, [segments]);

  const renderSegment = (segment: ContentSegment, index: number) => {
    // Generate a unique key based on segment type, content hash, and index
    const segmentKey = `${segment.type}-${index}-${segment.id || segment.content?.id || Date.now()}`;

    switch (segment.type) {
      case "text":
        return (
          <div key={segmentKey} className="inline">
            {/* Simple markdown-like rendering */}
            {segment.content
              .split("\n")
              .map((line: string, lineIndex: number) => {
                const lineKey = `${segmentKey}-line-${lineIndex}`;
                if (line.startsWith("# ")) {
                  return (
                    <h1 key={lineKey} className="mb-2 text-xl font-bold">
                      {parseTextWithBold(line.substring(2))}
                    </h1>
                  );
                } else if (line.startsWith("## ")) {
                  return (
                    <h2 key={lineKey} className="mb-2 text-lg font-bold">
                      {parseTextWithBold(line.substring(3))}
                    </h2>
                  );
                } else if (line.startsWith("### ")) {
                  return (
                    <h3 key={lineKey} className="mb-2 text-base font-bold">
                      {parseTextWithBold(line.substring(4))}
                    </h3>
                  );
                } else if (line.startsWith("- ")) {
                  return (
                    <li key={lineKey} className="ml-4 list-disc">
                      {parseTextWithBold(line.substring(2))}
                    </li>
                  );
                } else if (line.startsWith("```")) {
                  return (
                    <pre
                      key={lineKey}
                      className="my-2 overflow-x-auto rounded bg-neutral-100 p-2 dark:bg-neutral-800"
                    >
                      <code>{line.substring(3)}</code>
                    </pre>
                  );
                } else if (line.trim() === "") {
                  return <br key={lineKey} />;
                } else {
                  return (
                    <span key={lineKey}>
                      {parseTextWithBold(line)}
                      {lineIndex < segment.content.split("\n").length - 1 &&
                        "\n"}
                    </span>
                  );
                }
              })}
          </div>
        );

      case "tool":
        const toolData = segment.content;
        return (
          <div key={segmentKey} className="my-3">
            <ToolCallWidget
              toolName={toolData.name}
              parameters={toolData.parameters}
              result={toolData.result}
              status={toolData.status}
              error={toolData.error}
            />
          </div>
        );

      case "carousel":
        const carouselData = segment.content;
        return (
          <div key={segmentKey} className="my-4">
            <AgentCarousel
              agents={carouselData.agents}
              query={carouselData.query}
              onSelectAgent={onSelectAgent!}
              onGetDetails={onGetAgentDetails!}
            />
          </div>
        );

      case "credentials_setup":
        const credentialsData = segment.content;
        return (
          <div key={segmentKey} className="my-4">
            <CredentialsSetupWidget
              agentInfo={{
                id: credentialsData.agent_id || credentialsData.agent?.id,
                name: credentialsData.agent_name || credentialsData.agent?.name,
                graph_id:
                  credentialsData.graph_id || credentialsData.agent?.graph_id,
              }}
              credentialsSchema={
                credentialsData.credentials_schema ||
                credentialsData.credentials ||
                {}
              }
              onCredentialsSubmit={async (_credentials) => {
                // After credentials are set up, retry the agent setup
                const agentInfo =
                  credentialsData.agent_info || credentialsData.agent;
                if (agentInfo && onSelectAgent) {
                  // Send a message to retry setting up the agent now that credentials are configured
                  const message = `The credentials have been configured. Now set up the agent "${agentInfo.name || agentInfo.agent_id}" (ID: ${agentInfo.graph_id || agentInfo.agent_id})`;
                  console.log(
                    "Retrying agent setup after credentials:",
                    message,
                  );
                  // Trigger the onSelectAgent callback which should send the appropriate message
                  onSelectAgent(agentInfo);
                }
              }}
              onSkip={() => {
                console.log("User skipped credentials setup");
                // Optionally show a message that the agent cannot be run without credentials
              }}
            />
          </div>
        );

      case "agent_setup":
        const setupData = segment.content;
        return (
          <div key={segmentKey} className="my-4">
            <AgentSetupCard
              status={setupData.status}
              triggerType={setupData.trigger_type}
              name={setupData.name}
              graphId={setupData.graph_id}
              graphVersion={setupData.graph_version}
              scheduleId={setupData.schedule_id}
              webhookUrl={setupData.webhook_url}
              cron={setupData.cron}
              _cronUtc={setupData.cron_utc}
              timezone={setupData.timezone}
              nextRun={setupData.next_run}
              addedToLibrary={setupData.added_to_library}
              libraryId={setupData.library_id}
              message={setupData.message}
            />
          </div>
        );

      case "agent_run":
        const runData = segment.content;
        return (
          <div key={segmentKey} className="my-4">
            <AgentRunCard
              executionId={runData.execution_id}
              graphId={runData.graph_id}
              graphName={runData.graph_name}
              status={runData.status}
              inputs={runData.inputs}
              outputs={runData.outputs}
              error={runData.error}
              timeoutReached={runData.timeout_reached}
              endedAt={runData.ended_at}
            />
          </div>
        );

      case "auth_required":
        // Auth required segments are handled separately by ChatInterface
        // They trigger the auth prompt widget, not rendered inline
        return null;

      default:
        return null;
    }
  };

  return (
    <div
      className={cn(
        "flex gap-4 px-4 py-6",
        isUser && "justify-end",
        !isUser && "justify-start",
        className,
      )}
    >
      {!isUser && (
        <div className="flex-shrink-0">
          <div className="flex h-8 w-8 items-center justify-center rounded-full bg-violet-600">
            <Bot className="h-5 w-5 text-white" />
          </div>
        </div>
      )}

      <div
        className={cn(
          "max-w-[70%]",
          isUser && "rounded-lg bg-neutral-100 px-4 py-3 dark:bg-neutral-800",
          isAssistant && "space-y-2",
          isSystem &&
            "rounded-lg border border-blue-200 bg-blue-50 px-4 py-3 dark:border-blue-800 dark:bg-blue-900/20",
          isTool &&
            "rounded-lg border border-green-200 bg-green-50 px-4 py-3 dark:border-green-800 dark:bg-green-900/20",
        )}
      >
        {isSystem && (
          <div className="mb-2 text-xs font-medium text-blue-600 dark:text-blue-400">
            System
          </div>
        )}

        {isTool && (
          <div className="mb-2 text-xs font-medium text-green-600 dark:text-green-400">
            Tool Response
          </div>
        )}

        <div className="prose prose-sm dark:prose-invert max-w-none">
          {processedSegments.map(renderSegment)}
        </div>
      </div>

      {isUser && (
        <div className="flex-shrink-0">
          <div className="flex h-8 w-8 items-center justify-center rounded-full bg-neutral-600">
            <User className="h-5 w-5 text-white" />
          </div>
        </div>
      )}
    </div>
  );
}
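For orientation only (not part of the commits above): a minimal sketch of how StreamingMessage might be consumed once messages have already been split into segments. The chatItems data, the MessageList component, and the "@/components/chat/StreamingMessage" import path are assumptions for this example; the real segment parsing lives in the chat page, which is not shown in this diff.

import React from "react";
import { StreamingMessage } from "@/components/chat/StreamingMessage";

// Hypothetical pre-parsed messages; real parsing happens in the chat page/ChatInterface.
const chatItems = [
  {
    role: "USER" as const,
    segments: [{ type: "text" as const, content: "Find me a **news** agent" }],
  },
  {
    role: "ASSISTANT" as const,
    segments: [
      { type: "text" as const, content: "## Results\n- Searching the marketplace..." },
    ],
  },
];

export function MessageList() {
  return (
    <>
      {chatItems.map((item, i) => (
        <StreamingMessage key={i} role={item.role} segments={item.segments} />
      ))}
    </>
  );
}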
autogpt_platform/frontend/src/components/chat/ToolCallWidget.tsx (new file, 179 lines)
@@ -0,0 +1,179 @@
"use client";

import React, { useState } from "react";
import {
  ChevronDown,
  ChevronUp,
  Wrench,
  Loader2,
  CheckCircle,
  XCircle,
} from "lucide-react";
import { cn } from "@/lib/utils";

interface ToolCallWidgetProps {
  toolName: string;
  parameters?: Record<string, any>;
  result?: string;
  status: "calling" | "executing" | "completed" | "error";
  error?: string;
  className?: string;
}

export function ToolCallWidget({
  toolName,
  parameters,
  result,
  status,
  error,
  className,
}: ToolCallWidgetProps) {
  const [isExpanded, setIsExpanded] = useState(false);

  const getStatusIcon = () => {
    switch (status) {
      case "calling":
        return <Loader2 className="h-4 w-4 animate-spin text-blue-500" />;
      case "executing":
        return <Loader2 className="h-4 w-4 animate-spin text-yellow-500" />;
      case "completed":
        return <CheckCircle className="h-4 w-4 text-green-500" />;
      case "error":
        return <XCircle className="h-4 w-4 text-red-500" />;
    }
  };

  const getStatusText = () => {
    switch (status) {
      case "calling":
        return "Preparing tool...";
      case "executing":
        return "Executing...";
      case "completed":
        return "Completed";
      case "error":
        return "Error";
    }
  };

  const getToolDisplayName = () => {
    const toolDisplayNames: Record<string, string> = {
      find_agent: "🔍 Search Marketplace",
      get_agent_details: "📋 Get Agent Details",
      check_credentials: "🔑 Check Credentials",
      setup_agent: "⚙️ Setup Agent",
      run_agent: "▶️ Run Agent",
      get_required_setup_info: "📝 Get Setup Requirements",
    };
    return toolDisplayNames[toolName] || toolName;
  };

  return (
    <div
      className={cn(
        "overflow-hidden rounded-lg border transition-all duration-200",
        status === "error"
          ? "border-red-200 dark:border-red-800"
          : "border-neutral-200 dark:border-neutral-700",
        "bg-white dark:bg-neutral-900",
        "animate-in fade-in-50 slide-in-from-top-1",
        className,
      )}
    >
      <div
        className={cn(
          "flex items-center justify-between px-3 py-2",
          "bg-gradient-to-r",
          status === "error"
            ? "from-red-50 to-red-100 dark:from-red-900/20 dark:to-red-800/20"
            : "from-neutral-50 to-neutral-100 dark:from-neutral-800/20 dark:to-neutral-700/20",
        )}
      >
        <div className="flex items-center gap-2">
          <Wrench className="h-4 w-4 text-neutral-500 dark:text-neutral-400" />
          <span className="text-sm font-medium text-neutral-700 dark:text-neutral-300">
            {getToolDisplayName()}
          </span>
          <div className="ml-2 flex items-center gap-1.5">
            {getStatusIcon()}
            <span className="text-xs text-neutral-500 dark:text-neutral-400">
              {getStatusText()}
            </span>
          </div>
        </div>

        <button
          onClick={() => setIsExpanded(!isExpanded)}
          className="rounded p-1 hover:bg-neutral-200/50 dark:hover:bg-neutral-700/50"
        >
          {isExpanded ? (
            <ChevronUp className="h-4 w-4 text-neutral-600 dark:text-neutral-400" />
          ) : (
            <ChevronDown className="h-4 w-4 text-neutral-600 dark:text-neutral-400" />
          )}
        </button>
      </div>

      {isExpanded && (
        <div className="px-4 py-3">
          {parameters && Object.keys(parameters).length > 0 && (
            <div className="mb-3">
              <div className="mb-2 text-xs font-medium text-neutral-600 dark:text-neutral-400">
                Parameters:
              </div>
              <div className="rounded-md bg-neutral-50 p-3 dark:bg-neutral-800">
                <pre className="text-xs text-neutral-700 dark:text-neutral-300">
                  {JSON.stringify(parameters, null, 2)}
                </pre>
              </div>
            </div>
          )}

          {result &&
            status === "completed" &&
            (() => {
              // Check if result is agent carousel data - if so, don't display raw JSON
              try {
                const parsed = JSON.parse(result);
                if (parsed.type === "agent_carousel") {
                  return (
                    <div className="text-xs text-neutral-600 dark:text-neutral-400">
                      Found {parsed.count} agents matching “{parsed.query}
                      ”
                    </div>
                  );
                }
              } catch {}

              // Display regular result
              return (
                <div>
                  <div className="mb-2 text-xs font-medium text-neutral-600 dark:text-neutral-400">
                    Result:
                  </div>
                  <div className="rounded-md bg-green-50 p-3 dark:bg-green-900/20">
                    <pre className="whitespace-pre-wrap text-xs text-green-800 dark:text-green-200">
                      {result}
                    </pre>
                  </div>
                </div>
              );
            })()}

          {error && status === "error" && (
            <div>
              <div className="mb-2 text-xs font-medium text-red-600 dark:text-red-400">
                Error:
              </div>
              <div className="rounded-md bg-red-50 p-3 dark:bg-red-900/20">
                <p className="text-sm text-red-800 dark:text-red-200">
                  {error}
                </p>
              </div>
            </div>
          )}
        </div>
      )}
    </div>
  );
}
autogpt_platform/frontend/src/hooks/useChatSession.ts (new file, 233 lines)
@@ -0,0 +1,233 @@
import { useState, useEffect, useCallback, useMemo, useRef } from "react";
import {
  // ChatAPI,
  ChatSession,
  ChatMessage,
} from "@/lib/autogpt-server-api/chat";
import BackendAPI from "@/lib/autogpt-server-api";

interface UseChatSessionResult {
  session: ChatSession | null;
  messages: ChatMessage[];
  isLoading: boolean;
  error: Error | null;
  createSession: () => Promise<void>;
  loadSession: (sessionId: string, retryOnFailure?: boolean) => Promise<void>;
  refreshSession: () => Promise<void>;
  deleteSession: () => Promise<void>;
  clearSession: () => void;
}

export function useChatSession(
  urlSessionId?: string | null,
): UseChatSessionResult {
  const [session, setSession] = useState<ChatSession | null>(null);
  const [messages, setMessages] = useState<ChatMessage[]>([]);
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<Error | null>(null);
  const urlSessionIdRef = useRef(urlSessionId);

  const api = useMemo(() => new BackendAPI(), []);
  const chatAPI = useMemo(() => api.chat, [api]);

  // Keep ref updated
  useEffect(() => {
    urlSessionIdRef.current = urlSessionId;
  }, [urlSessionId]);

  const createSession = useCallback(async () => {
    setIsLoading(true);
    setError(null);

    try {
      const newSession = await chatAPI.createSession({});

      setSession(newSession);
      setMessages(newSession.messages || []);

      // Store session ID in localStorage
      localStorage.setItem("chat_session_id", newSession.id);
    } catch (err) {
      setError(err as Error);
      console.error("Failed to create chat session:", err);
    } finally {
      setIsLoading(false);
    }
  }, [chatAPI]);

  const loadSession = useCallback(
    async (sessionId: string, retryOnFailure = true) => {
      // For URL-based sessions, always try to load (don't skip based on previous failures)
      const failedSessionsKey = "failed_chat_sessions";
      const failedSessions = JSON.parse(
        localStorage.getItem(failedSessionsKey) || "[]",
      );

      // Only skip if it's not explicitly requested via URL (urlSessionId)
      if (
        failedSessions.includes(sessionId) &&
        sessionId !== urlSessionIdRef.current
      ) {
        console.log(
          `Session ${sessionId} previously failed to load, skipping...`,
        );
        // Clear the stored session ID and don't retry
        localStorage.removeItem("chat_session_id");
        if (!session && retryOnFailure) {
          // Create a new session instead
          await createSession();
        }
        return;
      }

      setIsLoading(true);
      setError(null);

      try {
        const loadedSession = await chatAPI.getSession(sessionId, true);
        console.log("🔍 Loaded session:", sessionId, loadedSession);
        console.log("📝 Messages in session:", loadedSession.messages);
        setSession(loadedSession);
        setMessages(loadedSession.messages || []);

        // Update localStorage to remember this session
        localStorage.setItem("chat_session_id", sessionId);
        // Clear any pending session flag
        localStorage.removeItem("pending_chat_session");

        // Remove from failed sessions if it was there
        const updatedFailedSessions = failedSessions.filter(
          (id: string) => id !== sessionId,
        );
        localStorage.setItem(
          failedSessionsKey,
          JSON.stringify(updatedFailedSessions),
        );
      } catch (err) {
        console.error("Failed to load chat session:", err);

        // Mark this session as failed
        failedSessions.push(sessionId);
        localStorage.setItem(failedSessionsKey, JSON.stringify(failedSessions));

        // If session doesn't exist, clear localStorage
        localStorage.removeItem("chat_session_id");

        if (retryOnFailure) {
          // Create a new session instead
          console.log("Session not found, creating a new one...");
          try {
            const newSession = await chatAPI.createSession();

            setSession(newSession);
            setMessages(newSession.messages || []);
            localStorage.setItem("chat_session_id", newSession.id);
          } catch (createErr) {
            setError(createErr as Error);
            console.error("Failed to create new session:", createErr);
          }
        }
      } finally {
        setIsLoading(false);
      }
    },
    [chatAPI, createSession, session],
  );

  const deleteSession = useCallback(async () => {
    if (!session) return;

    setIsLoading(true);
    setError(null);

    try {
      await chatAPI.deleteSession(session.id);
      clearSession();
    } catch (err) {
      setError(err as Error);
      console.error("Failed to delete chat session:", err);
    } finally {
      setIsLoading(false);
    }
  }, [session, chatAPI]);

  // Load session from localStorage or URL on mount
  useEffect(() => {
    // If urlSessionId is explicitly null, don't load any session (will create new one)
    if (urlSessionId === null) {
      // Clear stored session to start fresh
      localStorage.removeItem("chat_session_id");
      return;
    }

    // Priority 1: URL session ID (explicit navigation to a session)
    if (urlSessionId) {
      console.log("📍 Loading session from URL:", urlSessionId);
      loadSession(urlSessionId, false); // Don't auto-create new session if URL session fails
      return;
    }

    // Priority 2: Pending session (from auth redirect)
    const pendingSessionId = localStorage.getItem("pending_chat_session");
    if (pendingSessionId) {
      console.log("📍 Loading pending session:", pendingSessionId);
      loadSession(pendingSessionId, false); // Don't retry on failure
      // Clear the pending session flag
      localStorage.removeItem("pending_chat_session");
      localStorage.setItem("chat_session_id", pendingSessionId);
      return;
    }

    // Priority 3: Regular stored session - ONLY load if explicitly in URL
    // Don't automatically load the last session just because it exists in localStorage
    // This prevents unwanted session persistence across page loads
  }, [urlSessionId, loadSession]);

  const refreshSession = useCallback(async () => {
    if (!session) return;

    try {
      console.log("🔄 Refreshing session:", session.id);
      const refreshedSession = await chatAPI.getSession(session.id, true);
      console.log("✅ Refreshed session data:", refreshedSession);
      console.log("📝 Refreshed messages:", refreshedSession.messages);
      setSession(refreshedSession);
      setMessages(refreshedSession.messages || []);
    } catch (err) {
      console.error("Failed to refresh session:", err);
    }
  }, [session, chatAPI]);

  const clearSession = useCallback(() => {
    setSession(null);
    setMessages([]);
    setError(null);
    localStorage.removeItem("chat_session_id");
  }, []);

  const _addMessage = useCallback((message: ChatMessage) => {
    setMessages((prev) => [...prev, message]);
  }, []);

  const _updateLastMessage = useCallback((content: string) => {
    setMessages((prev) => {
      const newMessages = [...prev];
      if (newMessages.length > 0) {
        newMessages[newMessages.length - 1].content = content;
      }
      return newMessages;
    });
  }, []);

  return {
    session,
    messages,
    isLoading,
    error,
    createSession,
    loadSession,
    refreshSession,
    deleteSession,
    clearSession,
  };
}
autogpt_platform/frontend/src/hooks/useChatStream.ts (new file, 84 lines)
@@ -0,0 +1,84 @@
import { useState, useCallback, useRef, useMemo } from "react";
import { StreamChunk } from "@/lib/autogpt-server-api/chat";
import BackendAPI from "@/lib/autogpt-server-api";

interface UseChatStreamResult {
  isStreaming: boolean;
  error: Error | null;
  sendMessage: (
    sessionId: string,
    message: string,
    onChunk?: (chunk: StreamChunk) => void,
  ) => Promise<void>;
  stopStreaming: () => void;
}

export function useChatStream(): UseChatStreamResult {
  const [isStreaming, setIsStreaming] = useState(false);
  const [error, setError] = useState<Error | null>(null);
  const abortControllerRef = useRef<AbortController | null>(null);

  const api = useMemo(() => new BackendAPI(), []);
  const chatAPI = useMemo(() => api.chat, [api]);

  const sendMessage = useCallback(
    async (
      sessionId: string,
      message: string,
      onChunk?: (chunk: StreamChunk) => void,
    ) => {
      setIsStreaming(true);
      setError(null);

      // Create new abort controller for this stream
      abortControllerRef.current = new AbortController();

      try {
        const stream = chatAPI.streamChat(
          sessionId,
          message,
          "gpt-4o",
          50,
          (err) => {
            setError(err);
            console.error("Stream error:", err);
          },
        );

        for await (const chunk of stream) {
          // Check if streaming was aborted
          if (abortControllerRef.current?.signal.aborted) {
            break;
          }

          if (onChunk) {
            onChunk(chunk);
          }
        }
      } catch (err) {
        if (err instanceof Error && err.name !== "AbortError") {
          setError(err);
          console.error("Failed to stream message:", err);
        }
      } finally {
        setIsStreaming(false);
        abortControllerRef.current = null;
      }
    },
    [chatAPI],
  );

  const stopStreaming = useCallback(() => {
    if (abortControllerRef.current) {
      abortControllerRef.current.abort();
      setIsStreaming(false);
    }
  }, []);

  return {
    isStreaming,
    error,
    sendMessage,
    stopStreaming,
  };
}
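For orientation only (not part of the commits above): a minimal sketch of how useChatSession and useChatStream could be combined in a component, assuming a session already exists and that the hooks are importable from "@/hooks/...". The ChatPanel component, its button, and the reply state are hypothetical and exist only for this example.

import React, { useState } from "react";
import { useChatSession } from "@/hooks/useChatSession";
import { useChatStream } from "@/hooks/useChatStream";

// Hypothetical component for illustration; session creation/loading is left to useChatSession.
export function ChatPanel({ sessionId }: { sessionId?: string | null }) {
  const { session } = useChatSession(sessionId);
  const { sendMessage, isStreaming } = useChatStream();
  const [reply, setReply] = useState("");

  async function ask(text: string) {
    if (!session) return; // nothing to stream against until a session exists
    setReply("");
    // Append streamed text chunks; other chunk types are ignored in this sketch
    await sendMessage(session.id, text, (chunk) => {
      if (chunk.type === "text_chunk") setReply((prev) => prev + chunk.content);
    });
  }

  return (
    <div>
      <button disabled={isStreaming || !session} onClick={() => ask("Find me an agent")}>
        Ask
      </button>
      <p>{reply}</p>
    </div>
  );
}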
@@ -38,6 +38,11 @@ export default function useCredentials(
 ): CredentialsData | null {
   const allProviders = useContext(CredentialsProvidersContext);

+  // If block input schema doesn't have credentials, return null
+  if (!credsInputSchema) {
+    return null;
+  }
+
   const discriminatorValue = [
     credsInputSchema.discriminator
       ? getValue(credsInputSchema.discriminator, nodeInputValues)
@@ -50,7 +55,22 @@ export default function useCredentials(
       : null;

   let providerName: CredentialsProviderName;
-  if (credsInputSchema.credentials_provider.length > 1) {
+
+  // Handle cases where credentials_provider might be undefined or not an array
+  const credentialsProviders = Array.isArray(
+    credsInputSchema.credentials_provider,
+  )
+    ? credsInputSchema.credentials_provider
+    : credsInputSchema.credentials_provider
+      ? [credsInputSchema.credentials_provider]
+      : [];
+
+  if (credentialsProviders.length === 0) {
+    console.warn("No credentials provider specified in schema");
+    return null;
+  }
+
+  if (credentialsProviders.length > 1) {
     if (!credsInputSchema.discriminator) {
       throw new Error(
         "Multi-provider credential input requires discriminator!",
@@ -65,21 +85,16 @@ export default function useCredentials(
     }
     providerName = discriminatedProvider;
   } else {
-    providerName = credsInputSchema.credentials_provider[0];
+    providerName = credentialsProviders[0];
   }
   const provider = allProviders ? allProviders[providerName] : null;

-  // If block input schema doesn't have credentials, return null
-  if (!credsInputSchema) {
-    return null;
-  }
-
-  const supportsApiKey = credsInputSchema.credentials_types.includes("api_key");
-  const supportsOAuth2 = credsInputSchema.credentials_types.includes("oauth2");
-  const supportsUserPassword =
-    credsInputSchema.credentials_types.includes("user_password");
-  const supportsHostScoped =
-    credsInputSchema.credentials_types.includes("host_scoped");
+  // Safely handle credentials_types which might be undefined
+  const credentialsTypes = credsInputSchema.credentials_types || [];
+  const supportsApiKey = credentialsTypes.includes("api_key");
+  const supportsOAuth2 = credentialsTypes.includes("oauth2");
+  const supportsUserPassword = credentialsTypes.includes("user_password");
+  const supportsHostScoped = credentialsTypes.includes("host_scoped");

   // No provider means maybe it's still loading
   if (!provider) {
autogpt_platform/frontend/src/lib/autogpt-server-api/chat.ts (new file, 360 lines)
@@ -0,0 +1,360 @@
import type BackendAPI from "./client";

export interface ChatSession {
  id: string;
  created_at: string;
  updated_at?: string;
  user_id: string;
  messages?: ChatMessage[];
  metadata?: Record<string, any>;
}

export interface ChatMessage {
  id?: string;
  content: string;
  role: "USER" | "ASSISTANT" | "SYSTEM" | "TOOL";
  created_at?: string;
  tool_calls?: any[];
  tool_call_id?: string;
  tokens?: {
    prompt?: number;
    completion?: number;
    total?: number;
  };
}

export interface CreateSessionRequest {
  metadata?: Record<string, any>;
}

export interface SendMessageRequest {
  message: string;
  model?: string;
  max_context_messages?: number;
}

export interface StreamChunk {
  type:
    | "text"
    | "html"
    | "error"
    | "text_chunk"
    | "tool_call"
    | "tool_response"
    | "login_needed"
    | "stream_end";
  content: string;
  // Additional fields for structured responses
  tool_id?: string;
  tool_name?: string;
  arguments?: Record<string, any>;
  result?: any;
  success?: boolean;
  message?: string;
  session_id?: string;
  agent_info?: any;
  timestamp?: string;
  summary?: any;
}

export class ChatAPI {
  private api: BackendAPI;

  constructor(api: BackendAPI) {
    this.api = api;
  }

  async createSession(request?: CreateSessionRequest): Promise<ChatSession> {
    // Generate a unique anonymous ID for this session
    const anonId =
      typeof window !== "undefined"
        ? localStorage.getItem("anon_id") ||
          Math.random().toString(36).substring(2, 15)
        : "server-anon";

    if (typeof window !== "undefined" && !localStorage.getItem("anon_id")) {
      localStorage.setItem("anon_id", anonId);
    }

    try {
      // First try with authentication if available
      const supabase = await (this.api as any).getSupabaseClient();
      const {
        data: { session },
      } = await supabase.auth.getSession();

      if (session?.access_token) {
        // User is authenticated, use normal request
        const response = await (this.api as any)._request(
          "POST",
          "/v2/chat/sessions",
          request || {},
        );
        return response;
      }
    } catch (_e) {
      // Continue with anonymous session
    }

    // Create anonymous session through proxy
    const proxyUrl = `/api/proxy/api/v2/chat/sessions`;
    const response = await fetch(proxyUrl, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        ...request,
        metadata: { anon_id: anonId },
      }),
    });

    if (!response.ok) {
      const error = await response.text();
      throw new Error(`Failed to create chat session: ${error}`);
    }

    return response.json();
  }

  async createSessionOld(request?: CreateSessionRequest): Promise<ChatSession> {
    const headers: HeadersInit = {
      "Content-Type": "application/json",
    };

    const response = await fetch(`/api/proxy/api/v2/chat/sessions`, {
      method: "POST",
      headers,
      body: JSON.stringify(request || {}),
    });

    if (!response.ok) {
      const error = await response.text();
      throw new Error(`Failed to create chat session: ${error}`);
    }

    return response.json();
  }

  async getSession(
    sessionId: string,
    includeMessages = true,
  ): Promise<ChatSession> {
    const response = await (this.api as any)._get(
      `/v2/chat/sessions/${sessionId}?include_messages=${includeMessages}`,
    );
    return response;
  }

  async getSessionOld(
    sessionId: string,
    includeMessages = true,
  ): Promise<ChatSession> {
    const headers: HeadersInit = {
      "Content-Type": "application/json",
    };

    const response = await fetch(
      `/api/proxy/api/v2/chat/sessions/${sessionId}?include_messages=${includeMessages}`,
      {
        method: "GET",
        headers,
      },
    );

    if (!response.ok) {
      const error = await response.text();
      throw new Error(`Failed to get chat session: ${error}`);
    }

    return response.json();
  }

  async listSessions(
    limit = 50,
    offset = 0,
    includeLastMessage = true,
  ): Promise<{
    sessions: ChatSession[];
    total: number;
    limit: number;
    offset: number;
  }> {
    const params = new URLSearchParams({
      limit: limit.toString(),
      offset: offset.toString(),
      include_last_message: includeLastMessage.toString(),
    });

    const response = await (this.api as any)._get(
      `/v2/chat/sessions?${params}`,
    );
    return response;
  }

  async listSessionsOld(
    limit = 50,
    offset = 0,
    includeLastMessage = true,
  ): Promise<{
    sessions: ChatSession[];
    total: number;
    limit: number;
    offset: number;
  }> {
    const headers: HeadersInit = {
      "Content-Type": "application/json",
    };

    const params = new URLSearchParams({
      limit: limit.toString(),
      offset: offset.toString(),
      include_last_message: includeLastMessage.toString(),
    });

    const response = await fetch(`/api/proxy/api/v2/chat/sessions?${params}`, {
      method: "GET",
      headers,
    });

    if (!response.ok) {
      const error = await response.text();
      throw new Error(`Failed to list chat sessions: ${error}`);
    }

    return response.json();
  }

  async deleteSession(sessionId: string): Promise<void> {
    await (this.api as any)._delete(`/v2/chat/sessions/${sessionId}`);
  }

  async deleteSessionOld(sessionId: string): Promise<void> {
    const headers: HeadersInit = {
      "Content-Type": "application/json",
    };

    const response = await fetch(
      `/api/proxy/api/v2/chat/sessions/${sessionId}`,
      {
        method: "DELETE",
        headers,
      },
    );

    if (!response.ok) {
      const error = await response.text();
      throw new Error(`Failed to delete chat session: ${error}`);
    }
  }

  async sendMessage(
    sessionId: string,
    request: SendMessageRequest,
  ): Promise<ChatMessage> {
    const response = await (this.api as any)._request(
      "POST",
      `/v2/chat/sessions/${sessionId}/messages`,
      request,
    );
    return response;
  }

  async sendMessageOld(
    sessionId: string,
    request: SendMessageRequest,
  ): Promise<ChatMessage> {
    const headers: HeadersInit = {
      "Content-Type": "application/json",
    };

    const response = await fetch(
      `/api/proxy/api/v2/chat/sessions/${sessionId}/messages`,
      {
        method: "POST",
        headers,
        body: JSON.stringify(request),
      },
    );

    if (!response.ok) {
      const error = await response.text();
      throw new Error(`Failed to send message: ${error}`);
    }

    return response.json();
  }

  async *streamChat(
    sessionId: string,
    message: string,
    model = "gpt-4o",
    maxContext = 50,
    onError?: (error: Error) => void,
  ): AsyncGenerator<StreamChunk, void, unknown> {
    const params = new URLSearchParams({
      message,
      model,
      max_context: maxContext.toString(),
    });

    try {
      // Use the proxy endpoint for authentication
      // The proxy will handle adding the auth token from the server session
      const proxyUrl = `/api/proxy/api/v2/chat/sessions/${sessionId}/stream?${params}`;

      const response = await fetch(proxyUrl, {
        method: "GET",
        // No need to set headers here - the proxy handles authentication
      });

      if (!response.ok) {
        const error = await response.text();
        throw new Error(`Failed to stream chat: ${error}`);
      }

      const reader = response.body?.getReader();
      if (!reader) {
        throw new Error("No response body available");
      }

      const decoder = new TextDecoder();
      let buffer = "";

      while (true) {
        const { done, value } = await reader.read();

        if (done) break;

        buffer += decoder.decode(value, { stream: true });
        const lines = buffer.split("\n");

        // Keep the last incomplete line in the buffer
        buffer = lines.pop() || "";

        for (const line of lines) {
          if (line.startsWith("data: ")) {
            const data = line.slice(6).trim();

            if (data === "[DONE]") {
              return;
            }

            try {
              const chunk = JSON.parse(data) as StreamChunk;
              yield chunk;
            } catch (_e) {
              console.error("Failed to parse SSE data:", data, _e);
            }
          }
        }
      }
    } catch (error) {
      if (onError) {
        onError(error as Error);
      } else {
        throw error;
      }
    }
  }
}
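For orientation only (not part of the commits above): a minimal sketch of consuming ChatAPI.streamChat through the BackendAPI client, assuming it runs in the browser where the /api/proxy routes are available. The askAndLog helper is hypothetical; streamChat, its "gpt-4o" and max-context defaults, and the StreamChunk types come from the file above.

import BackendAPI from "@/lib/autogpt-server-api";

// Hypothetical helper: stream one prompt and log chunks as they arrive.
async function askAndLog(sessionId: string, prompt: string): Promise<void> {
  const api = new BackendAPI();
  // streamChat yields StreamChunk objects until the server sends "data: [DONE]"
  for await (const chunk of api.chat.streamChat(sessionId, prompt)) {
    if (chunk.type === "text" || chunk.type === "text_chunk") {
      console.log(chunk.content);
    } else if (chunk.type === "error") {
      console.error("Stream error:", chunk.content);
    }
  }
}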
@@ -10,6 +10,7 @@ import {
   getSupabaseAnonKey,
 } from "@/lib/env-config";
 import * as Sentry from "@sentry/nextjs";
+import { ChatAPI } from "./chat";
 import type {
   AddUserCreditsResponse,
   AnalyticsDetails,
@@ -86,6 +87,7 @@ export default class BackendAPI {
   private wsOnDisconnectHandlers: Set<() => void> = new Set();
   private wsMessageHandlers: Record<string, Set<(data: any) => void>> = {};
   private isIntentionallyDisconnected: boolean = false;
+  public chat: ChatAPI;

   readonly HEARTBEAT_INTERVAL = 100_000; // 100 seconds
   readonly HEARTBEAT_TIMEOUT = 10_000; // 10 seconds
@@ -98,6 +100,7 @@
   ) {
     this.baseUrl = baseUrl;
     this.wsUrl = wsUrl;
+    this.chat = new ChatAPI(this);
   }

   private async getSupabaseClient(): Promise<SupabaseClient | null> {
@@ -447,7 +450,7 @@

   getStoreProfile(): Promise<ProfileDetails | null> {
     try {
-      const result = this._get("/store/profile");
+      const result = this._get("/v2/store/profile");
       return result;
     } catch (error) {
       console.error("Error fetching store profile:", error);
@@ -563,7 +566,7 @@
   }

   updateStoreProfile(profile: ProfileDetails): Promise<ProfileDetails> {
-    return this._request("POST", "/store/profile", profile);
+    return this._request("POST", "/v2/store/profile", profile);
   }

   reviewAgent(