mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feature(backend): Limit Chat to Auth Users, Limit Agent Runs Per Chat (#11330)
This commit is contained in:
@@ -50,12 +50,7 @@ from backend.util.exceptions import (
|
||||
NotAuthorizedError,
|
||||
NotFoundError,
|
||||
)
|
||||
from backend.util.feature_flag import (
|
||||
Flag,
|
||||
create_feature_flag_dependency,
|
||||
initialize_launchdarkly,
|
||||
shutdown_launchdarkly,
|
||||
)
|
||||
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
|
||||
from backend.util.service import UnhealthyServiceError
|
||||
|
||||
settings = backend.util.settings.Settings()
|
||||
@@ -294,9 +289,6 @@ app.include_router(
|
||||
chat_routes.router,
|
||||
tags=["v2", "chat"],
|
||||
prefix="/api/chat",
|
||||
dependencies=[
|
||||
fastapi.Depends(create_feature_flag_dependency(Flag.CHAT, default=False))
|
||||
],
|
||||
)
|
||||
|
||||
app.mount("/external-api", external_app)
|
||||
|
||||
@@ -11,9 +11,7 @@ class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
|
||||
# OpenAI API Configuration
|
||||
model: str = Field(
|
||||
default="qwen/qwen3-235b-a22b-2507", description="Default model to use"
|
||||
)
|
||||
model: str = Field(default="openai/gpt-5-mini", description="Default model to use")
|
||||
api_key: str | None = Field(default=None, description="OpenAI API key")
|
||||
base_url: str | None = Field(
|
||||
default="https://openrouter.ai/api/v1",
|
||||
@@ -36,6 +34,10 @@ class ChatConfig(BaseSettings):
|
||||
|
||||
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
|
||||
max_retries: int = Field(default=3, description="Maximum number of retries")
|
||||
max_agent_runs: int = Field(default=3, description="Maximum number of agent runs")
|
||||
max_agent_schedules: int = Field(
|
||||
default=3, description="Maximum number of agent schedules"
|
||||
)
|
||||
|
||||
@field_validator("api_key", mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -28,7 +28,7 @@ config = ChatConfig()
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
content: str | None = None
|
||||
name: str | None = None
|
||||
tool_call_id: str | None = None
|
||||
refusal: str | None = None
|
||||
@@ -50,6 +50,8 @@ class ChatSession(BaseModel):
|
||||
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
|
||||
started_at: datetime
|
||||
updated_at: datetime
|
||||
successful_agent_runs: dict[str, int] = {}
|
||||
successful_agent_schedules: dict[str, int] = {}
|
||||
|
||||
@staticmethod
|
||||
def new(user_id: str | None) -> "ChatSession":
|
||||
@@ -69,7 +71,7 @@ class ChatSession(BaseModel):
|
||||
if message.role == "developer":
|
||||
m = ChatCompletionDeveloperMessageParam(
|
||||
role="developer",
|
||||
content=message.content,
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
@@ -77,7 +79,7 @@ class ChatSession(BaseModel):
|
||||
elif message.role == "system":
|
||||
m = ChatCompletionSystemMessageParam(
|
||||
role="system",
|
||||
content=message.content,
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
@@ -85,7 +87,7 @@ class ChatSession(BaseModel):
|
||||
elif message.role == "user":
|
||||
m = ChatCompletionUserMessageParam(
|
||||
role="user",
|
||||
content=message.content,
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.name:
|
||||
m["name"] = message.name
|
||||
@@ -93,7 +95,7 @@ class ChatSession(BaseModel):
|
||||
elif message.role == "assistant":
|
||||
m = ChatCompletionAssistantMessageParam(
|
||||
role="assistant",
|
||||
content=message.content,
|
||||
content=message.content or "",
|
||||
)
|
||||
if message.function_call:
|
||||
m["function_call"] = FunctionCall(
|
||||
@@ -137,7 +139,7 @@ class ChatSession(BaseModel):
|
||||
messages.append(
|
||||
ChatCompletionToolMessageParam(
|
||||
role="tool",
|
||||
content=message.content,
|
||||
content=message.content or "",
|
||||
tool_call_id=message.tool_call_id or "",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -50,7 +50,7 @@ class SessionDetailResponse(BaseModel):
|
||||
"/sessions",
|
||||
)
|
||||
async def create_session(
|
||||
user_id: Annotated[str | None, Depends(auth.get_optional_user_id)],
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CreateSessionResponse:
|
||||
"""
|
||||
Create a new chat session.
|
||||
@@ -80,7 +80,7 @@ async def create_session(
|
||||
)
|
||||
async def get_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_optional_user_id)],
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> SessionDetailResponse:
|
||||
"""
|
||||
Retrieve the details of a specific chat session.
|
||||
@@ -113,7 +113,7 @@ async def get_session(
|
||||
async def stream_chat(
|
||||
session_id: str,
|
||||
message: Annotated[str, Query(min_length=1, max_length=10000)],
|
||||
user_id: str | None = Depends(auth.get_optional_user_id),
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
is_user_message: bool = Query(default=True),
|
||||
):
|
||||
"""
|
||||
@@ -146,8 +146,6 @@ async def stream_chat(
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id, message, is_user_message=is_user_message, user_id=user_id
|
||||
):
|
||||
with open("chunks.log", "a") as f:
|
||||
f.write(f"{session_id}: {chunk}\n")
|
||||
yield chunk.to_sse()
|
||||
|
||||
return StreamingResponse(
|
||||
|
||||
@@ -132,7 +132,8 @@ async def stream_chat_completion(
|
||||
has_done_tool_call = False
|
||||
has_received_text = False
|
||||
text_streaming_ended = False
|
||||
messages_to_add: list[ChatMessage] = []
|
||||
tool_response_messages: list[ChatMessage] = []
|
||||
accumulated_tool_calls: list[dict[str, Any]] = []
|
||||
should_retry = False
|
||||
|
||||
try:
|
||||
@@ -142,7 +143,9 @@ async def stream_chat_completion(
|
||||
):
|
||||
|
||||
if isinstance(chunk, StreamTextChunk):
|
||||
assistant_response.content += chunk.content
|
||||
content = chunk.content or ""
|
||||
assert assistant_response.content is not None
|
||||
assistant_response.content += content
|
||||
has_received_text = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamToolCallStart):
|
||||
@@ -152,15 +155,24 @@ async def stream_chat_completion(
|
||||
text_streaming_ended = True
|
||||
yield chunk
|
||||
elif isinstance(chunk, StreamToolCall):
|
||||
# Just pass on the tool call notification
|
||||
pass
|
||||
# Accumulate tool calls in OpenAI format
|
||||
accumulated_tool_calls.append(
|
||||
{
|
||||
"id": chunk.tool_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": chunk.tool_name,
|
||||
"arguments": orjson.dumps(chunk.arguments).decode("utf-8"),
|
||||
},
|
||||
}
|
||||
)
|
||||
elif isinstance(chunk, StreamToolExecutionResult):
|
||||
result_content = (
|
||||
chunk.result
|
||||
if isinstance(chunk.result, str)
|
||||
else orjson.dumps(chunk.result).decode("utf-8")
|
||||
)
|
||||
messages_to_add.append(
|
||||
tool_response_messages.append(
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
content=result_content,
|
||||
@@ -204,9 +216,18 @@ async def stream_chat_completion(
|
||||
else:
|
||||
# Non-retryable error or max retries exceeded
|
||||
# Save any partial progress before reporting error
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
|
||||
# Add assistant message if it has content or tool calls
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
messages_to_add.append(assistant_response)
|
||||
session.messages.extend(messages_to_add)
|
||||
messages_to_save.append(assistant_response)
|
||||
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
|
||||
session.messages.extend(messages_to_save)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
if not has_yielded_error:
|
||||
@@ -246,11 +267,27 @@ async def stream_chat_completion(
|
||||
logger.info(
|
||||
f"Upserting session: {session.session_id} with user id {session.user_id}"
|
||||
)
|
||||
# Only append assistant response if it has content or tool calls
|
||||
# to avoid saving empty messages on errors
|
||||
|
||||
# Build the messages list in the correct order
|
||||
messages_to_save: list[ChatMessage] = []
|
||||
|
||||
# Add assistant message with tool_calls if any
|
||||
if accumulated_tool_calls:
|
||||
assistant_response.tool_calls = accumulated_tool_calls
|
||||
logger.info(
|
||||
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
|
||||
)
|
||||
if assistant_response.content or assistant_response.tool_calls:
|
||||
messages_to_add.append(assistant_response)
|
||||
session.messages.extend(messages_to_add)
|
||||
messages_to_save.append(assistant_response)
|
||||
logger.info(
|
||||
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
|
||||
)
|
||||
|
||||
# Add tool response messages after assistant message
|
||||
messages_to_save.extend(tool_response_messages)
|
||||
logger.info(f"Saving {len(tool_response_messages)} tool response messages")
|
||||
|
||||
session.messages.extend(messages_to_save)
|
||||
await upsert_chat_session(session)
|
||||
|
||||
# If we did a tool call, stream the chat completion again to get the next response
|
||||
@@ -451,7 +488,7 @@ async def _yield_tool_call(
|
||||
parameters=arguments,
|
||||
tool_call_id=tool_calls[yield_idx]["id"],
|
||||
user_id=session.user_id,
|
||||
session_id=session.session_id,
|
||||
session=session,
|
||||
)
|
||||
logger.info(f"Yielding Tool execution response: {tool_execution_response}")
|
||||
yield tool_execution_response
|
||||
|
||||
@@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
|
||||
from .base import BaseTool
|
||||
from .find_agent import FindAgentTool
|
||||
from .get_agent_details import GetAgentDetailsTool
|
||||
@@ -33,7 +35,7 @@ async def execute_tool(
|
||||
tool_name: str,
|
||||
parameters: dict[str, Any],
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
session: ChatSession,
|
||||
tool_call_id: str,
|
||||
) -> "StreamToolExecutionResult":
|
||||
|
||||
@@ -47,5 +49,5 @@ async def execute_tool(
|
||||
if tool_name not in tool_map:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
return await tool_map[tool_name].execute(
|
||||
user_id, session_id, tool_call_id, **parameters
|
||||
user_id, session, tool_call_id, **parameters
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
@@ -12,9 +13,23 @@ from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.model import APIKeyCredentials
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.integrations.credentials_store import IntegrationCredentialsStore
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.store import db as store_db
|
||||
|
||||
|
||||
def make_session(user_id: str | None = None):
|
||||
return ChatSession(
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
messages=[],
|
||||
usage=[],
|
||||
started_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
successful_agent_runs={},
|
||||
successful_agent_schedules={},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
async def setup_test_data():
|
||||
"""
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.response_model import StreamToolExecutionResult
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
@@ -49,7 +50,7 @@ class BaseTool:
|
||||
async def execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
session: ChatSession,
|
||||
tool_call_id: str,
|
||||
**kwargs,
|
||||
) -> StreamToolExecutionResult:
|
||||
@@ -73,13 +74,13 @@ class BaseTool:
|
||||
tool_name=self.name,
|
||||
result=NeedLoginResponse(
|
||||
message=f"Please sign in to use {self.name}",
|
||||
session_id=session_id,
|
||||
session_id=session.session_id,
|
||||
).model_dump_json(),
|
||||
success=False,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self._execute(user_id, session_id, **kwargs)
|
||||
result = await self._execute(user_id, session, **kwargs)
|
||||
return StreamToolExecutionResult(
|
||||
tool_id=tool_call_id,
|
||||
tool_name=self.name,
|
||||
@@ -93,7 +94,7 @@ class BaseTool:
|
||||
result=ErrorResponse(
|
||||
message=f"An error occurred while executing {self.name}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
session_id=session.session_id,
|
||||
).model_dump_json(),
|
||||
success=False,
|
||||
)
|
||||
@@ -101,7 +102,7 @@ class BaseTool:
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Internal execution logic to be implemented by subclasses.
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
AgentCarouselResponse,
|
||||
@@ -46,7 +47,7 @@ class FindAgentTool(BaseTool):
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Search for agents in the marketplace.
|
||||
@@ -62,7 +63,7 @@ class FindAgentTool(BaseTool):
|
||||
ErrorResponse: Error message
|
||||
"""
|
||||
query = kwargs.get("query", "").strip()
|
||||
|
||||
session_id = session.session_id
|
||||
if not query:
|
||||
return ErrorResponse(
|
||||
message="Please provide a search query",
|
||||
@@ -125,25 +126,3 @@ class FindAgentTool(BaseTool):
|
||||
count=len(agents),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
import prisma
|
||||
|
||||
find_agent_tool = FindAgentTool()
|
||||
print(find_agent_tool.as_openai_tool())
|
||||
|
||||
async def main():
|
||||
await prisma.Prisma().connect()
|
||||
agents = await find_agent_tool.execute(
|
||||
tool_call_id="tool_call_id",
|
||||
query="Linkedin",
|
||||
user_id="user",
|
||||
session_id="session",
|
||||
)
|
||||
print(agents)
|
||||
await prisma.Prisma().disconnect()
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any
|
||||
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
AgentDetails,
|
||||
@@ -46,7 +47,7 @@ class GetAgentDetailsTool(BaseTool):
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Get detailed information about an agent.
|
||||
@@ -61,7 +62,7 @@ class GetAgentDetailsTool(BaseTool):
|
||||
|
||||
"""
|
||||
agent_id = kwargs.get("username_agent_slug", "").strip()
|
||||
|
||||
session_id = session.session_id
|
||||
if not agent_id or "/" not in agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide an agent ID in format 'creator/agent-name'",
|
||||
|
||||
@@ -3,7 +3,11 @@ import uuid
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.chat.tools._test_data import setup_llm_test_data, setup_test_data
|
||||
from backend.server.v2.chat.tools._test_data import (
|
||||
make_session,
|
||||
setup_llm_test_data,
|
||||
setup_test_data,
|
||||
)
|
||||
from backend.server.v2.chat.tools.get_agent_details import GetAgentDetailsTool
|
||||
|
||||
# This is so the formatter doesn't remove the fixture imports
|
||||
@@ -25,10 +29,13 @@ async def test_get_agent_details_success(setup_test_data):
|
||||
# Build the proper marketplace agent_id format: username/slug
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Build session
|
||||
session = make_session()
|
||||
|
||||
# Execute the tool
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
)
|
||||
@@ -85,10 +92,12 @@ async def test_get_agent_details_with_llm_credentials(setup_llm_test_data):
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Execute the tool
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
)
|
||||
@@ -110,7 +119,6 @@ async def test_get_agent_details_with_llm_credentials(setup_llm_test_data):
|
||||
credentials = agent["credentials"]
|
||||
|
||||
# The LLM agent should have OpenAI credentials listed
|
||||
# Note: This depends on how the graph's credentials_input_schema is structured
|
||||
assert isinstance(credentials, list)
|
||||
|
||||
# Check that inputs include the user_prompt
|
||||
@@ -124,10 +132,13 @@ async def test_get_agent_details_invalid_format():
|
||||
"""Test error handling when agent_id is not in correct format"""
|
||||
tool = GetAgentDetailsTool()
|
||||
|
||||
session = make_session()
|
||||
session.user_id = str(uuid.uuid4())
|
||||
|
||||
# Execute with invalid format (no slash)
|
||||
response = await tool.execute(
|
||||
user_id=str(uuid.uuid4()),
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=session.user_id,
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug="invalid-format",
|
||||
)
|
||||
@@ -147,10 +158,13 @@ async def test_get_agent_details_empty_slug():
|
||||
"""Test error handling when agent_id is empty"""
|
||||
tool = GetAgentDetailsTool()
|
||||
|
||||
session = make_session()
|
||||
session.user_id = str(uuid.uuid4())
|
||||
|
||||
# Execute with empty slug
|
||||
response = await tool.execute(
|
||||
user_id=str(uuid.uuid4()),
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=session.user_id,
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug="",
|
||||
)
|
||||
@@ -170,10 +184,13 @@ async def test_get_agent_details_not_found():
|
||||
"""Test error handling when agent is not found in marketplace"""
|
||||
tool = GetAgentDetailsTool()
|
||||
|
||||
session = make_session()
|
||||
session.user_id = str(uuid.uuid4())
|
||||
|
||||
# Execute with non-existent agent
|
||||
response = await tool.execute(
|
||||
user_id=str(uuid.uuid4()),
|
||||
session_id=str(uuid.uuid4()),
|
||||
user_id=session.user_id,
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug="nonexistent/agent",
|
||||
)
|
||||
@@ -201,10 +218,13 @@ async def test_get_agent_details_anonymous_user(setup_test_data):
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
session = make_session()
|
||||
# session.user_id stays as None
|
||||
|
||||
# Execute the tool without a user_id (anonymous)
|
||||
response = await tool.execute(
|
||||
user_id=None,
|
||||
session_id=str(uuid.uuid4()),
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
)
|
||||
@@ -238,10 +258,13 @@ async def test_get_agent_details_authenticated_user(setup_test_data):
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
session = make_session()
|
||||
session.user_id = user.id
|
||||
|
||||
# Execute the tool with a user_id (authenticated)
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
)
|
||||
@@ -275,10 +298,13 @@ async def test_get_agent_details_includes_execution_options(setup_test_data):
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
session = make_session()
|
||||
session.user_id = user.id
|
||||
|
||||
# Execute the tool
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.get_agent_details import GetAgentDetailsTool
|
||||
from backend.server.v2.chat.tools.models import (
|
||||
@@ -57,7 +58,7 @@ class GetRequiredSetupInfoTool(BaseTool):
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""
|
||||
@@ -82,11 +83,9 @@ class GetRequiredSetupInfoTool(BaseTool):
|
||||
assert (
|
||||
user_id is not None
|
||||
), "GetRequiredSetupInfoTool - This should never happen user_id is None when auth is required"
|
||||
|
||||
session_id = session.session_id
|
||||
# Call _execute directly since we're calling internally from another tool
|
||||
agent_details = await GetAgentDetailsTool()._execute(
|
||||
user_id, session_id, **kwargs
|
||||
)
|
||||
agent_details = await GetAgentDetailsTool()._execute(user_id, session, **kwargs)
|
||||
|
||||
if isinstance(agent_details, ErrorResponse):
|
||||
return agent_details
|
||||
|
||||
@@ -4,6 +4,7 @@ import orjson
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.chat.tools._test_data import (
|
||||
make_session,
|
||||
setup_firecrawl_test_data,
|
||||
setup_llm_test_data,
|
||||
setup_test_data,
|
||||
@@ -21,56 +22,46 @@ setup_firecrawl_test_data = setup_firecrawl_test_data
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_get_required_setup_info_success(setup_test_data):
|
||||
"""Test successfully getting setup info for a simple agent"""
|
||||
# Use test data from fixture
|
||||
user = setup_test_data["user"]
|
||||
graph = setup_test_data["graph"]
|
||||
store_submission = setup_test_data["store_submission"]
|
||||
|
||||
# Create the tool instance
|
||||
tool = GetRequiredSetupInfoTool()
|
||||
|
||||
# Build the proper marketplace agent_id format: username/slug
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute the tool
|
||||
session = make_session(user_id=user.id)
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
inputs={"test_input": "Hello World"},
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
|
||||
# Parse the result JSON
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
|
||||
# Check the basic structure
|
||||
assert "setup_info" in result_data
|
||||
setup_info = result_data["setup_info"]
|
||||
|
||||
# Check agent info
|
||||
assert "agent_id" in setup_info
|
||||
assert setup_info["agent_id"] == graph.id
|
||||
assert "agent_name" in setup_info
|
||||
assert setup_info["agent_name"] == "Test Agent"
|
||||
|
||||
# Check requirements
|
||||
assert "requirements" in setup_info
|
||||
requirements = setup_info["requirements"]
|
||||
assert "credentials" in requirements
|
||||
assert "inputs" in requirements
|
||||
assert "execution_modes" in requirements
|
||||
|
||||
# Simple agent should have no credentials required
|
||||
assert isinstance(requirements["credentials"], list)
|
||||
assert len(requirements["credentials"]) == 0
|
||||
|
||||
# Check inputs format
|
||||
assert isinstance(requirements["inputs"], list)
|
||||
if len(requirements["inputs"]) > 0:
|
||||
first_input = requirements["inputs"][0]
|
||||
@@ -78,75 +69,60 @@ async def test_get_required_setup_info_success(setup_test_data):
|
||||
assert "title" in first_input
|
||||
assert "type" in first_input
|
||||
|
||||
# Check execution modes
|
||||
assert isinstance(requirements["execution_modes"], list)
|
||||
assert "manual" in requirements["execution_modes"]
|
||||
assert "scheduled" in requirements["execution_modes"]
|
||||
|
||||
# Check user readiness
|
||||
assert "user_readiness" in setup_info
|
||||
user_readiness = setup_info["user_readiness"]
|
||||
assert "has_all_credentials" in user_readiness
|
||||
assert "ready_to_run" in user_readiness
|
||||
# Simple agent with inputs provided should be ready
|
||||
assert user_readiness["ready_to_run"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_get_required_setup_info_missing_credentials(setup_firecrawl_test_data):
|
||||
"""Test getting setup info for an agent requiring missing credentials"""
|
||||
# Use test data from fixture
|
||||
user = setup_firecrawl_test_data["user"]
|
||||
store_submission = setup_firecrawl_test_data["store_submission"]
|
||||
|
||||
# Create the tool instance
|
||||
tool = GetRequiredSetupInfoTool()
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute the tool
|
||||
session = make_session(user_id=user.id)
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
inputs={"url": "https://example.com"},
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
|
||||
# Parse the result JSON
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
|
||||
# Check setup info
|
||||
assert "setup_info" in result_data
|
||||
setup_info = result_data["setup_info"]
|
||||
|
||||
# Check requirements
|
||||
requirements = setup_info["requirements"]
|
||||
|
||||
# Should have Firecrawl credentials required
|
||||
assert "credentials" in requirements
|
||||
assert isinstance(requirements["credentials"], list)
|
||||
assert len(requirements["credentials"]) > 0
|
||||
|
||||
# Check the credential requirement
|
||||
firecrawl_cred = requirements["credentials"][0]
|
||||
assert "provider" in firecrawl_cred
|
||||
assert firecrawl_cred["provider"] == "firecrawl"
|
||||
assert "type" in firecrawl_cred
|
||||
assert firecrawl_cred["type"] == "api_key"
|
||||
|
||||
# Check user readiness - should NOT be ready since credentials are missing
|
||||
user_readiness = setup_info["user_readiness"]
|
||||
assert user_readiness["has_all_credentials"] is False
|
||||
assert user_readiness["ready_to_run"] is False
|
||||
|
||||
# Check missing credentials
|
||||
assert "missing_credentials" in user_readiness
|
||||
assert isinstance(user_readiness["missing_credentials"], dict)
|
||||
assert len(user_readiness["missing_credentials"]) > 0
|
||||
@@ -155,42 +131,34 @@ async def test_get_required_setup_info_missing_credentials(setup_firecrawl_test_
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_get_required_setup_info_with_available_credentials(setup_llm_test_data):
|
||||
"""Test getting setup info when user has required credentials"""
|
||||
# Use test data from fixture (includes OpenAI credentials)
|
||||
user = setup_llm_test_data["user"]
|
||||
store_submission = setup_llm_test_data["store_submission"]
|
||||
|
||||
# Create the tool instance
|
||||
tool = GetRequiredSetupInfoTool()
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute the tool
|
||||
session = make_session(user_id=user.id)
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
inputs={"user_prompt": "What is 2+2?"},
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
|
||||
# Parse the result JSON
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
|
||||
# Check setup info
|
||||
setup_info = result_data["setup_info"]
|
||||
|
||||
# Check user readiness - should be ready since credentials are available
|
||||
user_readiness = setup_info["user_readiness"]
|
||||
assert user_readiness["has_all_credentials"] is True
|
||||
assert user_readiness["ready_to_run"] is True
|
||||
|
||||
# Missing credentials should be empty
|
||||
assert "missing_credentials" in user_readiness
|
||||
assert len(user_readiness["missing_credentials"]) == 0
|
||||
|
||||
@@ -198,42 +166,34 @@ async def test_get_required_setup_info_with_available_credentials(setup_llm_test
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_get_required_setup_info_missing_inputs(setup_test_data):
|
||||
"""Test getting setup info when required inputs are not provided"""
|
||||
# Use test data from fixture
|
||||
user = setup_test_data["user"]
|
||||
store_submission = setup_test_data["store_submission"]
|
||||
|
||||
# Create the tool instance
|
||||
tool = GetRequiredSetupInfoTool()
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute the tool WITHOUT providing inputs
|
||||
session = make_session(user_id=user.id)
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
inputs={}, # Empty inputs
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
|
||||
# Parse the result JSON
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
|
||||
# Check setup info
|
||||
setup_info = result_data["setup_info"]
|
||||
|
||||
# Check requirements
|
||||
requirements = setup_info["requirements"]
|
||||
assert "inputs" in requirements
|
||||
assert isinstance(requirements["inputs"], list)
|
||||
|
||||
# User readiness depends on whether inputs are required or optional
|
||||
user_readiness = setup_info["user_readiness"]
|
||||
assert "ready_to_run" in user_readiness
|
||||
|
||||
@@ -241,26 +201,24 @@ async def test_get_required_setup_info_missing_inputs(setup_test_data):
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_get_required_setup_info_invalid_agent():
|
||||
"""Test getting setup info for a non-existent agent"""
|
||||
# Create the tool instance
|
||||
tool = GetRequiredSetupInfoTool()
|
||||
|
||||
# Execute with invalid agent ID
|
||||
session = make_session(user_id=None)
|
||||
response = await tool.execute(
|
||||
user_id=str(uuid.uuid4()),
|
||||
session_id=str(uuid.uuid4()),
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug="invalid/agent",
|
||||
inputs={},
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Verify error response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
assert "message" in result_data
|
||||
# Should indicate failure or not found
|
||||
assert any(
|
||||
phrase in result_data["message"].lower()
|
||||
for phrase in ["not found", "failed", "error"]
|
||||
@@ -270,35 +228,29 @@ async def test_get_required_setup_info_invalid_agent():
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_get_required_setup_info_graph_metadata(setup_test_data):
|
||||
"""Test that setup info includes graph metadata"""
|
||||
# Use test data from fixture
|
||||
user = setup_test_data["user"]
|
||||
graph = setup_test_data["graph"]
|
||||
store_submission = setup_test_data["store_submission"]
|
||||
|
||||
# Create the tool instance
|
||||
tool = GetRequiredSetupInfoTool()
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute the tool
|
||||
session = make_session(user_id=user.id)
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
inputs={"test_input": "test"},
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
|
||||
# Parse the result JSON
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
|
||||
# Check that graph_id and graph_version are included
|
||||
assert "graph_id" in result_data
|
||||
assert result_data["graph_id"] == graph.id
|
||||
assert "graph_version" in result_data
|
||||
@@ -308,41 +260,33 @@ async def test_get_required_setup_info_graph_metadata(setup_test_data):
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_get_required_setup_info_inputs_structure(setup_test_data):
|
||||
"""Test that inputs are properly structured as a list"""
|
||||
# Use test data from fixture
|
||||
user = setup_test_data["user"]
|
||||
store_submission = setup_test_data["store_submission"]
|
||||
|
||||
# Create the tool instance
|
||||
tool = GetRequiredSetupInfoTool()
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute the tool
|
||||
session = make_session(user_id=user.id)
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
inputs={},
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
|
||||
# Parse the result JSON
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
|
||||
# Check inputs structure
|
||||
setup_info = result_data["setup_info"]
|
||||
requirements = setup_info["requirements"]
|
||||
|
||||
# Inputs should be a list
|
||||
assert isinstance(requirements["inputs"], list)
|
||||
|
||||
# Each input should have proper structure
|
||||
for input_field in requirements["inputs"]:
|
||||
assert isinstance(input_field, dict)
|
||||
assert "name" in input_field
|
||||
@@ -356,38 +300,31 @@ async def test_get_required_setup_info_inputs_structure(setup_test_data):
|
||||
@pytest.mark.asyncio(scope="session")
|
||||
async def test_get_required_setup_info_execution_modes_structure(setup_test_data):
|
||||
"""Test that execution_modes are properly structured as a list"""
|
||||
# Use test data from fixture
|
||||
user = setup_test_data["user"]
|
||||
store_submission = setup_test_data["store_submission"]
|
||||
|
||||
# Create the tool instance
|
||||
tool = GetRequiredSetupInfoTool()
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute the tool
|
||||
session = make_session(user_id=user.id)
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
inputs={},
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
assert response is not None
|
||||
assert hasattr(response, "result")
|
||||
|
||||
# Parse the result JSON
|
||||
assert isinstance(response.result, str)
|
||||
result_data = orjson.loads(response.result)
|
||||
|
||||
# Check execution modes structure
|
||||
setup_info = result_data["setup_info"]
|
||||
requirements = setup_info["requirements"]
|
||||
|
||||
# execution_modes should be a list of strings
|
||||
assert isinstance(requirements["execution_modes"], list)
|
||||
for mode in requirements["execution_modes"]:
|
||||
assert isinstance(mode, str)
|
||||
|
||||
@@ -7,6 +7,8 @@ from backend.data.graph import get_graph
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.base import BaseTool
|
||||
from backend.server.v2.chat.tools.get_required_setup_info import (
|
||||
GetRequiredSetupInfoTool,
|
||||
@@ -22,6 +24,7 @@ from backend.server.v2.library import db as library_db
|
||||
from backend.server.v2.library import model as library_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class RunAgentTool(BaseTool):
|
||||
@@ -65,7 +68,7 @@ class RunAgentTool(BaseTool):
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute an agent manually.
|
||||
@@ -84,13 +87,12 @@ class RunAgentTool(BaseTool):
|
||||
user_id is not None
|
||||
), "User ID is required to run an agent. Superclass enforces authentication."
|
||||
|
||||
session_id = session.session_id
|
||||
username_agent_slug = kwargs.get("username_agent_slug", "").strip()
|
||||
inputs = kwargs.get("inputs", {})
|
||||
|
||||
# Call _execute directly since we're calling internally from another tool
|
||||
response = await GetRequiredSetupInfoTool()._execute(
|
||||
user_id, session_id, **kwargs
|
||||
)
|
||||
response = await GetRequiredSetupInfoTool()._execute(user_id, session, **kwargs)
|
||||
|
||||
if not isinstance(response, SetupRequirementsResponse):
|
||||
return ErrorResponse(
|
||||
@@ -126,6 +128,14 @@ class RunAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if graph and (
|
||||
session.successful_agent_runs.get(graph.id, 0) >= config.max_agent_runs
|
||||
):
|
||||
return ErrorResponse(
|
||||
message="Maximum number of agent schedules reached. You can't schedule this agent again in this chat session.",
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
# Check if we already have a library agent for this graph
|
||||
existing_library_agent = await library_db.get_library_agent_by_graph_id(
|
||||
graph_id=graph.id, user_id=user_id
|
||||
@@ -232,6 +242,11 @@ class RunAgentTool(BaseTool):
|
||||
inputs=inputs,
|
||||
graph_credentials_inputs=graph_credentials_inputs,
|
||||
)
|
||||
|
||||
session.successful_agent_runs[library_agent.graph_id] = (
|
||||
session.successful_agent_runs.get(library_agent.graph_id, 0) + 1
|
||||
)
|
||||
|
||||
return ExecutionStartedResponse(
|
||||
message="Agent execution successfully started. Do not run this tool again unless specifically asked to run the agent again.",
|
||||
session_id=session_id,
|
||||
|
||||
@@ -3,7 +3,11 @@ import uuid
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.chat.tools._test_data import setup_llm_test_data, setup_test_data
|
||||
from backend.server.v2.chat.tools._test_data import (
|
||||
make_session,
|
||||
setup_llm_test_data,
|
||||
setup_test_data,
|
||||
)
|
||||
from backend.server.v2.chat.tools.run_agent import RunAgentTool
|
||||
|
||||
# This is so the formatter doesn't remove the fixture imports
|
||||
@@ -25,6 +29,9 @@ async def test_run_agent(setup_test_data):
|
||||
# Build the proper marketplace agent_id format: username/slug
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Execute the tool
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
@@ -32,6 +39,7 @@ async def test_run_agent(setup_test_data):
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
inputs={"test_input": "Hello World"},
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
@@ -61,6 +69,9 @@ async def test_run_agent_missing_inputs(setup_test_data):
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Execute the tool without required inputs
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
@@ -68,6 +79,7 @@ async def test_run_agent_missing_inputs(setup_test_data):
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
inputs={}, # Missing required input
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Verify that we get an error response
|
||||
@@ -89,6 +101,9 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
|
||||
# Create the tool instance
|
||||
tool = RunAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Execute the tool with invalid agent ID
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
@@ -96,6 +111,7 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug="invalid/agent-id",
|
||||
inputs={"test_input": "Hello World"},
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Verify that we get an error response
|
||||
@@ -125,6 +141,9 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Execute the tool with a prompt for the LLM
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
@@ -132,6 +151,7 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
inputs={"user_prompt": "What is 2+2?"},
|
||||
session=session,
|
||||
)
|
||||
|
||||
# Verify the response
|
||||
|
||||
@@ -9,6 +9,8 @@ from backend.data.graph import get_graph
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.user import get_user_by_id
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.server.v2.chat.config import ChatConfig
|
||||
from backend.server.v2.chat.model import ChatSession
|
||||
from backend.server.v2.chat.tools.get_required_setup_info import (
|
||||
GetRequiredSetupInfoTool,
|
||||
)
|
||||
@@ -28,6 +30,7 @@ from backend.util.timezone_utils import (
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase
|
||||
|
||||
config = ChatConfig()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -113,7 +116,7 @@ class SetupAgentTool(BaseTool):
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session_id: str,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Set up an agent with configuration.
|
||||
@@ -130,6 +133,8 @@ class SetupAgentTool(BaseTool):
|
||||
assert (
|
||||
user_id is not None
|
||||
), "User ID is required to run an agent. Superclass enforces authentication."
|
||||
|
||||
session_id = session.session_id
|
||||
setup_type = kwargs.get("setup_type", "schedule").strip()
|
||||
if setup_type != "schedule":
|
||||
return ErrorResponse(
|
||||
@@ -149,12 +154,22 @@ class SetupAgentTool(BaseTool):
|
||||
inputs = kwargs.get("inputs", {})
|
||||
|
||||
library_agent = await self._get_or_add_library_agent(
|
||||
username_agent_slug, user_id, session_id, **kwargs
|
||||
username_agent_slug, user_id, session, **kwargs
|
||||
)
|
||||
|
||||
if not isinstance(library_agent, AgentDetails):
|
||||
# library agent is an ErrorResponse
|
||||
return library_agent
|
||||
|
||||
if library_agent and (
|
||||
session.successful_agent_schedules.get(library_agent.graph_id, 0)
|
||||
if isinstance(library_agent, AgentDetails)
|
||||
else 0 >= config.max_agent_schedules
|
||||
):
|
||||
return ErrorResponse(
|
||||
message="Maximum number of agent schedules reached. You can't schedule this agent again in this chat session.",
|
||||
session_id=session.session_id,
|
||||
)
|
||||
# At this point we know the user is ready to run the agent
|
||||
# Create the schedule for the agent
|
||||
from backend.server.v2.library import db as library_db
|
||||
@@ -176,7 +191,7 @@ class SetupAgentTool(BaseTool):
|
||||
name=cron_name,
|
||||
inputs=inputs,
|
||||
credentials=library_agent.required_credentials,
|
||||
session_id=session_id,
|
||||
session=session,
|
||||
)
|
||||
|
||||
async def _add_graph_execution_schedule(
|
||||
@@ -187,13 +202,13 @@ class SetupAgentTool(BaseTool):
|
||||
name: str,
|
||||
inputs: dict[str, Any],
|
||||
credentials: dict[str, CredentialsMetaInput],
|
||||
session_id: str,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ExecutionStartedResponse | ErrorResponse:
|
||||
# Use timezone from request if provided, otherwise fetch from user profile
|
||||
user = await get_user_by_id(user_id)
|
||||
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
|
||||
|
||||
session_id = session.session_id
|
||||
# Map required credentials (schema field names) to actual user credential IDs
|
||||
# credentials param contains CredentialsMetaInput with schema field names as keys
|
||||
# We need to find the user's actual credentials that match the provider/type
|
||||
@@ -252,6 +267,11 @@ class SetupAgentTool(BaseTool):
|
||||
result.next_run_time = convert_utc_time_to_user_timezone(
|
||||
result.next_run_time, user_timezone
|
||||
)
|
||||
|
||||
session.successful_agent_schedules[library_agent.graph_id] = (
|
||||
session.successful_agent_schedules.get(library_agent.graph_id, 0) + 1
|
||||
)
|
||||
|
||||
return ExecutionStartedResponse(
|
||||
message="Agent execution successfully scheduled. Do not run this tool again unless specifically asked to run the agent again.",
|
||||
session_id=session_id,
|
||||
@@ -261,12 +281,11 @@ class SetupAgentTool(BaseTool):
|
||||
)
|
||||
|
||||
async def _get_or_add_library_agent(
|
||||
self, agent_id: str, user_id: str, session_id: str, **kwargs
|
||||
self, agent_id: str, user_id: str, session: ChatSession, **kwargs
|
||||
) -> AgentDetails | ErrorResponse:
|
||||
# Call _execute directly since we're calling internally from another tool
|
||||
response = await GetRequiredSetupInfoTool()._execute(
|
||||
user_id, session_id, **kwargs
|
||||
)
|
||||
session_id = session.session_id
|
||||
response = await GetRequiredSetupInfoTool()._execute(user_id, session, **kwargs)
|
||||
|
||||
if not isinstance(response, SetupRequirementsResponse):
|
||||
return ErrorResponse(
|
||||
|
||||
@@ -3,7 +3,11 @@ import uuid
|
||||
import orjson
|
||||
import pytest
|
||||
|
||||
from backend.server.v2.chat.tools._test_data import setup_llm_test_data, setup_test_data
|
||||
from backend.server.v2.chat.tools._test_data import (
|
||||
make_session,
|
||||
setup_llm_test_data,
|
||||
setup_test_data,
|
||||
)
|
||||
from backend.server.v2.chat.tools.setup_agent import SetupAgentTool
|
||||
from backend.util.clients import get_scheduler_client
|
||||
|
||||
@@ -22,13 +26,16 @@ async def test_setup_agent_missing_cron(setup_test_data):
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute without cron
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
setup_type="schedule",
|
||||
@@ -59,13 +66,16 @@ async def test_setup_agent_webhook_not_supported(setup_test_data):
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute with webhook setup_type
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
setup_type="webhook",
|
||||
@@ -94,13 +104,16 @@ async def test_setup_agent_schedule_success(setup_test_data):
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute with schedule setup
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
setup_type="schedule",
|
||||
@@ -137,13 +150,16 @@ async def test_setup_agent_with_credentials(setup_llm_test_data):
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute with schedule setup
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
setup_type="schedule",
|
||||
@@ -176,10 +192,13 @@ async def test_setup_agent_invalid_agent(setup_test_data):
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Execute with non-existent agent
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug="nonexistent/agent",
|
||||
setup_type="schedule",
|
||||
@@ -214,6 +233,9 @@ async def test_setup_agent_schedule_created_in_scheduler(setup_test_data):
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
@@ -223,7 +245,7 @@ async def test_setup_agent_schedule_created_in_scheduler(setup_test_data):
|
||||
# Execute with schedule setup
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
setup_type="schedule",
|
||||
@@ -273,6 +295,9 @@ async def test_setup_agent_schedule_with_credentials_triggered(setup_llm_test_da
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
@@ -282,7 +307,7 @@ async def test_setup_agent_schedule_with_credentials_triggered(setup_llm_test_da
|
||||
# Execute with schedule setup
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
setup_type="schedule",
|
||||
@@ -361,13 +386,16 @@ async def test_setup_agent_creates_library_agent(setup_test_data):
|
||||
# Create the tool instance
|
||||
tool = SetupAgentTool()
|
||||
|
||||
# Build the session
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Build the proper marketplace agent_id format
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
|
||||
# Execute with schedule setup
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
session=session,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
setup_type="schedule",
|
||||
|
||||
@@ -4811,9 +4811,12 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearer": [] }]
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions/{session_id}": {
|
||||
@@ -4822,7 +4825,7 @@
|
||||
"summary": "Get Session",
|
||||
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.",
|
||||
"operationId": "getV2GetSession",
|
||||
"security": [{ "HTTPBearer": [] }],
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
@@ -4849,6 +4852,9 @@
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4859,7 +4865,7 @@
|
||||
"summary": "Stream Chat",
|
||||
"description": "Stream chat responses for a session.\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n message: The user's new message to process.\n user_id: Optional authenticated user ID.\n is_user_message: Whether the message is a user message.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
|
||||
"operationId": "getV2StreamChat",
|
||||
"security": [{ "HTTPBearer": [] }],
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
@@ -4901,6 +4907,9 @@
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4911,7 +4920,7 @@
|
||||
"summary": "Session Assign User",
|
||||
"description": "Assign an authenticated user to a chat session.\n\nUsed (typically post-login) to claim an existing anonymous session as the current authenticated user.\n\nArgs:\n session_id: The identifier for the (previously anonymous) session.\n user_id: The authenticated user's ID to associate with the session.\n\nReturns:\n dict: Status of the assignment.",
|
||||
"operationId": "patchV2SessionAssignUser",
|
||||
"security": [{ "HTTPBearer": [] }, { "HTTPBearerJWT": [] }],
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "session_id",
|
||||
@@ -4966,8 +4975,7 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [{ "HTTPBearer": [] }]
|
||||
}
|
||||
}
|
||||
},
|
||||
"/health": {
|
||||
@@ -9868,8 +9876,7 @@
|
||||
"type": "apiKey",
|
||||
"in": "header",
|
||||
"name": "X-Postmark-Webhook-Token"
|
||||
},
|
||||
"HTTPBearer": { "type": "http", "scheme": "bearer" }
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"HTTP401NotAuthenticatedError": {
|
||||
|
||||
Reference in New Issue
Block a user