feature(backend): Limit Chat to Auth Users, Limit Agent Runs Per Chat (#11330)

This commit is contained in:
Swifty
2025-11-06 13:11:15 +01:00
committed by GitHub
parent 42b643579f
commit a056d9e71a
18 changed files with 277 additions and 197 deletions

View File

@@ -50,12 +50,7 @@ from backend.util.exceptions import (
NotAuthorizedError,
NotFoundError,
)
from backend.util.feature_flag import (
Flag,
create_feature_flag_dependency,
initialize_launchdarkly,
shutdown_launchdarkly,
)
from backend.util.feature_flag import initialize_launchdarkly, shutdown_launchdarkly
from backend.util.service import UnhealthyServiceError
settings = backend.util.settings.Settings()
@@ -294,9 +289,6 @@ app.include_router(
chat_routes.router,
tags=["v2", "chat"],
prefix="/api/chat",
dependencies=[
fastapi.Depends(create_feature_flag_dependency(Flag.CHAT, default=False))
],
)
app.mount("/external-api", external_app)

View File

@@ -11,9 +11,7 @@ class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
# OpenAI API Configuration
model: str = Field(
default="qwen/qwen3-235b-a22b-2507", description="Default model to use"
)
model: str = Field(default="openai/gpt-5-mini", description="Default model to use")
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
default="https://openrouter.ai/api/v1",
@@ -36,6 +34,10 @@ class ChatConfig(BaseSettings):
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(default=3, description="Maximum number of retries")
max_agent_runs: int = Field(default=3, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=3, description="Maximum number of agent schedules"
)
@field_validator("api_key", mode="before")
@classmethod

View File

@@ -28,7 +28,7 @@ config = ChatConfig()
class ChatMessage(BaseModel):
role: str
content: str
content: str | None = None
name: str | None = None
tool_call_id: str | None = None
refusal: str | None = None
@@ -50,6 +50,8 @@ class ChatSession(BaseModel):
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
started_at: datetime
updated_at: datetime
successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {}
@staticmethod
def new(user_id: str | None) -> "ChatSession":
@@ -69,7 +71,7 @@ class ChatSession(BaseModel):
if message.role == "developer":
m = ChatCompletionDeveloperMessageParam(
role="developer",
content=message.content,
content=message.content or "",
)
if message.name:
m["name"] = message.name
@@ -77,7 +79,7 @@ class ChatSession(BaseModel):
elif message.role == "system":
m = ChatCompletionSystemMessageParam(
role="system",
content=message.content,
content=message.content or "",
)
if message.name:
m["name"] = message.name
@@ -85,7 +87,7 @@ class ChatSession(BaseModel):
elif message.role == "user":
m = ChatCompletionUserMessageParam(
role="user",
content=message.content,
content=message.content or "",
)
if message.name:
m["name"] = message.name
@@ -93,7 +95,7 @@ class ChatSession(BaseModel):
elif message.role == "assistant":
m = ChatCompletionAssistantMessageParam(
role="assistant",
content=message.content,
content=message.content or "",
)
if message.function_call:
m["function_call"] = FunctionCall(
@@ -137,7 +139,7 @@ class ChatSession(BaseModel):
messages.append(
ChatCompletionToolMessageParam(
role="tool",
content=message.content,
content=message.content or "",
tool_call_id=message.tool_call_id or "",
)
)

View File

@@ -50,7 +50,7 @@ class SessionDetailResponse(BaseModel):
"/sessions",
)
async def create_session(
user_id: Annotated[str | None, Depends(auth.get_optional_user_id)],
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CreateSessionResponse:
"""
Create a new chat session.
@@ -80,7 +80,7 @@ async def create_session(
)
async def get_session(
session_id: str,
user_id: Annotated[str | None, Depends(auth.get_optional_user_id)],
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> SessionDetailResponse:
"""
Retrieve the details of a specific chat session.
@@ -113,7 +113,7 @@ async def get_session(
async def stream_chat(
session_id: str,
message: Annotated[str, Query(min_length=1, max_length=10000)],
user_id: str | None = Depends(auth.get_optional_user_id),
user_id: str | None = Depends(auth.get_user_id),
is_user_message: bool = Query(default=True),
):
"""
@@ -146,8 +146,6 @@ async def stream_chat(
async for chunk in chat_service.stream_chat_completion(
session_id, message, is_user_message=is_user_message, user_id=user_id
):
with open("chunks.log", "a") as f:
f.write(f"{session_id}: {chunk}\n")
yield chunk.to_sse()
return StreamingResponse(

View File

@@ -132,7 +132,8 @@ async def stream_chat_completion(
has_done_tool_call = False
has_received_text = False
text_streaming_ended = False
messages_to_add: list[ChatMessage] = []
tool_response_messages: list[ChatMessage] = []
accumulated_tool_calls: list[dict[str, Any]] = []
should_retry = False
try:
@@ -142,7 +143,9 @@ async def stream_chat_completion(
):
if isinstance(chunk, StreamTextChunk):
assistant_response.content += chunk.content
content = chunk.content or ""
assert assistant_response.content is not None
assistant_response.content += content
has_received_text = True
yield chunk
elif isinstance(chunk, StreamToolCallStart):
@@ -152,15 +155,24 @@ async def stream_chat_completion(
text_streaming_ended = True
yield chunk
elif isinstance(chunk, StreamToolCall):
# Just pass on the tool call notification
pass
# Accumulate tool calls in OpenAI format
accumulated_tool_calls.append(
{
"id": chunk.tool_id,
"type": "function",
"function": {
"name": chunk.tool_name,
"arguments": orjson.dumps(chunk.arguments).decode("utf-8"),
},
}
)
elif isinstance(chunk, StreamToolExecutionResult):
result_content = (
chunk.result
if isinstance(chunk.result, str)
else orjson.dumps(chunk.result).decode("utf-8")
)
messages_to_add.append(
tool_response_messages.append(
ChatMessage(
role="tool",
content=result_content,
@@ -204,9 +216,18 @@ async def stream_chat_completion(
else:
# Non-retryable error or max retries exceeded
# Save any partial progress before reporting error
messages_to_save: list[ChatMessage] = []
# Add assistant message if it has content or tool calls
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
if assistant_response.content or assistant_response.tool_calls:
messages_to_add.append(assistant_response)
session.messages.extend(messages_to_add)
messages_to_save.append(assistant_response)
# Add tool response messages after assistant message
messages_to_save.extend(tool_response_messages)
session.messages.extend(messages_to_save)
await upsert_chat_session(session)
if not has_yielded_error:
@@ -246,11 +267,27 @@ async def stream_chat_completion(
logger.info(
f"Upserting session: {session.session_id} with user id {session.user_id}"
)
# Only append assistant response if it has content or tool calls
# to avoid saving empty messages on errors
# Build the messages list in the correct order
messages_to_save: list[ChatMessage] = []
# Add assistant message with tool_calls if any
if accumulated_tool_calls:
assistant_response.tool_calls = accumulated_tool_calls
logger.info(
f"Added {len(accumulated_tool_calls)} tool calls to assistant message"
)
if assistant_response.content or assistant_response.tool_calls:
messages_to_add.append(assistant_response)
session.messages.extend(messages_to_add)
messages_to_save.append(assistant_response)
logger.info(
f"Saving assistant message with content_len={len(assistant_response.content or '')}, tool_calls={len(assistant_response.tool_calls or [])}"
)
# Add tool response messages after assistant message
messages_to_save.extend(tool_response_messages)
logger.info(f"Saving {len(tool_response_messages)} tool response messages")
session.messages.extend(messages_to_save)
await upsert_chat_session(session)
# If we did a tool call, stream the chat completion again to get the next response
@@ -451,7 +488,7 @@ async def _yield_tool_call(
parameters=arguments,
tool_call_id=tool_calls[yield_idx]["id"],
user_id=session.user_id,
session_id=session.session_id,
session=session,
)
logger.info(f"Yielding Tool execution response: {tool_execution_response}")
yield tool_execution_response

View File

@@ -2,6 +2,8 @@ from typing import TYPE_CHECKING, Any
from openai.types.chat import ChatCompletionToolParam
from backend.server.v2.chat.model import ChatSession
from .base import BaseTool
from .find_agent import FindAgentTool
from .get_agent_details import GetAgentDetailsTool
@@ -33,7 +35,7 @@ async def execute_tool(
tool_name: str,
parameters: dict[str, Any],
user_id: str | None,
session_id: str,
session: ChatSession,
tool_call_id: str,
) -> "StreamToolExecutionResult":
@@ -47,5 +49,5 @@ async def execute_tool(
if tool_name not in tool_map:
raise ValueError(f"Tool {tool_name} not found")
return await tool_map[tool_name].execute(
user_id, session_id, tool_call_id, **parameters
user_id, session, tool_call_id, **parameters
)

View File

@@ -1,4 +1,5 @@
import uuid
from datetime import UTC, datetime
from os import getenv
import pytest
@@ -12,9 +13,23 @@ from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import APIKeyCredentials
from backend.data.user import get_or_create_user
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.server.v2.chat.model import ChatSession
from backend.server.v2.store import db as store_db
def make_session(user_id: str | None = None):
    """Build a fresh ChatSession fixture for tests.

    Creates a session with a random UUID, empty message/usage history,
    and zeroed per-agent run/schedule counters. ``user_id`` defaults to
    None to represent an anonymous session.
    """
    # Collect the constructor arguments first so the field set is easy to scan.
    session_kwargs = dict(
        session_id=str(uuid.uuid4()),
        user_id=user_id,
        messages=[],
        usage=[],
        # Two separate now() calls, matching the original (timestamps may differ).
        started_at=datetime.now(UTC),
        updated_at=datetime.now(UTC),
        successful_agent_runs={},
        successful_agent_schedules={},
    )
    return ChatSession(**session_kwargs)
@pytest.fixture(scope="session")
async def setup_test_data():
"""

View File

@@ -5,6 +5,7 @@ from typing import Any
from openai.types.chat import ChatCompletionToolParam
from backend.server.v2.chat.model import ChatSession
from backend.server.v2.chat.response_model import StreamToolExecutionResult
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
@@ -49,7 +50,7 @@ class BaseTool:
async def execute(
self,
user_id: str | None,
session_id: str,
session: ChatSession,
tool_call_id: str,
**kwargs,
) -> StreamToolExecutionResult:
@@ -73,13 +74,13 @@ class BaseTool:
tool_name=self.name,
result=NeedLoginResponse(
message=f"Please sign in to use {self.name}",
session_id=session_id,
session_id=session.session_id,
).model_dump_json(),
success=False,
)
try:
result = await self._execute(user_id, session_id, **kwargs)
result = await self._execute(user_id, session, **kwargs)
return StreamToolExecutionResult(
tool_id=tool_call_id,
tool_name=self.name,
@@ -93,7 +94,7 @@ class BaseTool:
result=ErrorResponse(
message=f"An error occurred while executing {self.name}",
error=str(e),
session_id=session_id,
session_id=session.session_id,
).model_dump_json(),
success=False,
)
@@ -101,7 +102,7 @@ class BaseTool:
async def _execute(
self,
user_id: str | None,
session_id: str,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Internal execution logic to be implemented by subclasses.

View File

@@ -3,6 +3,7 @@
import logging
from typing import Any
from backend.server.v2.chat.model import ChatSession
from backend.server.v2.chat.tools.base import BaseTool
from backend.server.v2.chat.tools.models import (
AgentCarouselResponse,
@@ -46,7 +47,7 @@ class FindAgentTool(BaseTool):
async def _execute(
self,
user_id: str | None,
session_id: str,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Search for agents in the marketplace.
@@ -62,7 +63,7 @@ class FindAgentTool(BaseTool):
ErrorResponse: Error message
"""
query = kwargs.get("query", "").strip()
session_id = session.session_id
if not query:
return ErrorResponse(
message="Please provide a search query",
@@ -125,25 +126,3 @@ class FindAgentTool(BaseTool):
count=len(agents),
session_id=session_id,
)
if __name__ == "__main__":
import asyncio
import prisma
find_agent_tool = FindAgentTool()
print(find_agent_tool.as_openai_tool())
async def main():
await prisma.Prisma().connect()
agents = await find_agent_tool.execute(
tool_call_id="tool_call_id",
query="Linkedin",
user_id="user",
session_id="session",
)
print(agents)
await prisma.Prisma().disconnect()
asyncio.run(main())

View File

@@ -5,6 +5,7 @@ from typing import Any
from backend.data import graph as graph_db
from backend.data.model import CredentialsMetaInput
from backend.server.v2.chat.model import ChatSession
from backend.server.v2.chat.tools.base import BaseTool
from backend.server.v2.chat.tools.models import (
AgentDetails,
@@ -46,7 +47,7 @@ class GetAgentDetailsTool(BaseTool):
async def _execute(
self,
user_id: str | None,
session_id: str,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Get detailed information about an agent.
@@ -61,7 +62,7 @@ class GetAgentDetailsTool(BaseTool):
"""
agent_id = kwargs.get("username_agent_slug", "").strip()
session_id = session.session_id
if not agent_id or "/" not in agent_id:
return ErrorResponse(
message="Please provide an agent ID in format 'creator/agent-name'",

View File

@@ -3,7 +3,11 @@ import uuid
import orjson
import pytest
from backend.server.v2.chat.tools._test_data import setup_llm_test_data, setup_test_data
from backend.server.v2.chat.tools._test_data import (
make_session,
setup_llm_test_data,
setup_test_data,
)
from backend.server.v2.chat.tools.get_agent_details import GetAgentDetailsTool
# This is so the formatter doesn't remove the fixture imports
@@ -25,10 +29,13 @@ async def test_get_agent_details_success(setup_test_data):
# Build the proper marketplace agent_id format: username/slug
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Build session
session = make_session()
# Execute the tool
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
)
@@ -85,10 +92,12 @@ async def test_get_agent_details_with_llm_credentials(setup_llm_test_data):
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
session = make_session(user_id=user.id)
# Execute the tool
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
)
@@ -110,7 +119,6 @@ async def test_get_agent_details_with_llm_credentials(setup_llm_test_data):
credentials = agent["credentials"]
# The LLM agent should have OpenAI credentials listed
# Note: This depends on how the graph's credentials_input_schema is structured
assert isinstance(credentials, list)
# Check that inputs include the user_prompt
@@ -124,10 +132,13 @@ async def test_get_agent_details_invalid_format():
"""Test error handling when agent_id is not in correct format"""
tool = GetAgentDetailsTool()
session = make_session()
session.user_id = str(uuid.uuid4())
# Execute with invalid format (no slash)
response = await tool.execute(
user_id=str(uuid.uuid4()),
session_id=str(uuid.uuid4()),
user_id=session.user_id,
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug="invalid-format",
)
@@ -147,10 +158,13 @@ async def test_get_agent_details_empty_slug():
"""Test error handling when agent_id is empty"""
tool = GetAgentDetailsTool()
session = make_session()
session.user_id = str(uuid.uuid4())
# Execute with empty slug
response = await tool.execute(
user_id=str(uuid.uuid4()),
session_id=str(uuid.uuid4()),
user_id=session.user_id,
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug="",
)
@@ -170,10 +184,13 @@ async def test_get_agent_details_not_found():
"""Test error handling when agent is not found in marketplace"""
tool = GetAgentDetailsTool()
session = make_session()
session.user_id = str(uuid.uuid4())
# Execute with non-existent agent
response = await tool.execute(
user_id=str(uuid.uuid4()),
session_id=str(uuid.uuid4()),
user_id=session.user_id,
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug="nonexistent/agent",
)
@@ -201,10 +218,13 @@ async def test_get_agent_details_anonymous_user(setup_test_data):
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
session = make_session()
# session.user_id stays as None
# Execute the tool without a user_id (anonymous)
response = await tool.execute(
user_id=None,
session_id=str(uuid.uuid4()),
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
)
@@ -238,10 +258,13 @@ async def test_get_agent_details_authenticated_user(setup_test_data):
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
session = make_session()
session.user_id = user.id
# Execute the tool with a user_id (authenticated)
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
)
@@ -275,10 +298,13 @@ async def test_get_agent_details_includes_execution_options(setup_test_data):
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
session = make_session()
session.user_id = user.id
# Execute the tool
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
)

View File

@@ -4,6 +4,7 @@ import logging
from typing import Any
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.server.v2.chat.model import ChatSession
from backend.server.v2.chat.tools.base import BaseTool
from backend.server.v2.chat.tools.get_agent_details import GetAgentDetailsTool
from backend.server.v2.chat.tools.models import (
@@ -57,7 +58,7 @@ class GetRequiredSetupInfoTool(BaseTool):
async def _execute(
self,
user_id: str | None,
session_id: str,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""
@@ -82,11 +83,9 @@ class GetRequiredSetupInfoTool(BaseTool):
assert (
user_id is not None
), "GetRequiredSetupInfoTool - This should never happen user_id is None when auth is required"
session_id = session.session_id
# Call _execute directly since we're calling internally from another tool
agent_details = await GetAgentDetailsTool()._execute(
user_id, session_id, **kwargs
)
agent_details = await GetAgentDetailsTool()._execute(user_id, session, **kwargs)
if isinstance(agent_details, ErrorResponse):
return agent_details

View File

@@ -4,6 +4,7 @@ import orjson
import pytest
from backend.server.v2.chat.tools._test_data import (
make_session,
setup_firecrawl_test_data,
setup_llm_test_data,
setup_test_data,
@@ -21,56 +22,46 @@ setup_firecrawl_test_data = setup_firecrawl_test_data
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_success(setup_test_data):
"""Test successfully getting setup info for a simple agent"""
# Use test data from fixture
user = setup_test_data["user"]
graph = setup_test_data["graph"]
store_submission = setup_test_data["store_submission"]
# Create the tool instance
tool = GetRequiredSetupInfoTool()
# Build the proper marketplace agent_id format: username/slug
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Execute the tool
session = make_session(user_id=user.id)
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={"test_input": "Hello World"},
session=session,
)
# Verify the response
assert response is not None
assert hasattr(response, "result")
# Parse the result JSON
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Check the basic structure
assert "setup_info" in result_data
setup_info = result_data["setup_info"]
# Check agent info
assert "agent_id" in setup_info
assert setup_info["agent_id"] == graph.id
assert "agent_name" in setup_info
assert setup_info["agent_name"] == "Test Agent"
# Check requirements
assert "requirements" in setup_info
requirements = setup_info["requirements"]
assert "credentials" in requirements
assert "inputs" in requirements
assert "execution_modes" in requirements
# Simple agent should have no credentials required
assert isinstance(requirements["credentials"], list)
assert len(requirements["credentials"]) == 0
# Check inputs format
assert isinstance(requirements["inputs"], list)
if len(requirements["inputs"]) > 0:
first_input = requirements["inputs"][0]
@@ -78,75 +69,60 @@ async def test_get_required_setup_info_success(setup_test_data):
assert "title" in first_input
assert "type" in first_input
# Check execution modes
assert isinstance(requirements["execution_modes"], list)
assert "manual" in requirements["execution_modes"]
assert "scheduled" in requirements["execution_modes"]
# Check user readiness
assert "user_readiness" in setup_info
user_readiness = setup_info["user_readiness"]
assert "has_all_credentials" in user_readiness
assert "ready_to_run" in user_readiness
# Simple agent with inputs provided should be ready
assert user_readiness["ready_to_run"] is True
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_missing_credentials(setup_firecrawl_test_data):
"""Test getting setup info for an agent requiring missing credentials"""
# Use test data from fixture
user = setup_firecrawl_test_data["user"]
store_submission = setup_firecrawl_test_data["store_submission"]
# Create the tool instance
tool = GetRequiredSetupInfoTool()
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Execute the tool
session = make_session(user_id=user.id)
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={"url": "https://example.com"},
session=session,
)
# Verify the response
assert response is not None
assert hasattr(response, "result")
# Parse the result JSON
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Check setup info
assert "setup_info" in result_data
setup_info = result_data["setup_info"]
# Check requirements
requirements = setup_info["requirements"]
# Should have Firecrawl credentials required
assert "credentials" in requirements
assert isinstance(requirements["credentials"], list)
assert len(requirements["credentials"]) > 0
# Check the credential requirement
firecrawl_cred = requirements["credentials"][0]
assert "provider" in firecrawl_cred
assert firecrawl_cred["provider"] == "firecrawl"
assert "type" in firecrawl_cred
assert firecrawl_cred["type"] == "api_key"
# Check user readiness - should NOT be ready since credentials are missing
user_readiness = setup_info["user_readiness"]
assert user_readiness["has_all_credentials"] is False
assert user_readiness["ready_to_run"] is False
# Check missing credentials
assert "missing_credentials" in user_readiness
assert isinstance(user_readiness["missing_credentials"], dict)
assert len(user_readiness["missing_credentials"]) > 0
@@ -155,42 +131,34 @@ async def test_get_required_setup_info_missing_credentials(setup_firecrawl_test_
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_with_available_credentials(setup_llm_test_data):
"""Test getting setup info when user has required credentials"""
# Use test data from fixture (includes OpenAI credentials)
user = setup_llm_test_data["user"]
store_submission = setup_llm_test_data["store_submission"]
# Create the tool instance
tool = GetRequiredSetupInfoTool()
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Execute the tool
session = make_session(user_id=user.id)
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={"user_prompt": "What is 2+2?"},
session=session,
)
# Verify the response
assert response is not None
assert hasattr(response, "result")
# Parse the result JSON
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Check setup info
setup_info = result_data["setup_info"]
# Check user readiness - should be ready since credentials are available
user_readiness = setup_info["user_readiness"]
assert user_readiness["has_all_credentials"] is True
assert user_readiness["ready_to_run"] is True
# Missing credentials should be empty
assert "missing_credentials" in user_readiness
assert len(user_readiness["missing_credentials"]) == 0
@@ -198,42 +166,34 @@ async def test_get_required_setup_info_with_available_credentials(setup_llm_test
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_missing_inputs(setup_test_data):
"""Test getting setup info when required inputs are not provided"""
# Use test data from fixture
user = setup_test_data["user"]
store_submission = setup_test_data["store_submission"]
# Create the tool instance
tool = GetRequiredSetupInfoTool()
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Execute the tool WITHOUT providing inputs
session = make_session(user_id=user.id)
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={}, # Empty inputs
session=session,
)
# Verify the response
assert response is not None
assert hasattr(response, "result")
# Parse the result JSON
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Check setup info
setup_info = result_data["setup_info"]
# Check requirements
requirements = setup_info["requirements"]
assert "inputs" in requirements
assert isinstance(requirements["inputs"], list)
# User readiness depends on whether inputs are required or optional
user_readiness = setup_info["user_readiness"]
assert "ready_to_run" in user_readiness
@@ -241,26 +201,24 @@ async def test_get_required_setup_info_missing_inputs(setup_test_data):
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_invalid_agent():
"""Test getting setup info for a non-existent agent"""
# Create the tool instance
tool = GetRequiredSetupInfoTool()
# Execute with invalid agent ID
session = make_session(user_id=None)
response = await tool.execute(
user_id=str(uuid.uuid4()),
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug="invalid/agent",
inputs={},
session=session,
)
# Verify error response
assert response is not None
assert hasattr(response, "result")
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
assert "message" in result_data
# Should indicate failure or not found
assert any(
phrase in result_data["message"].lower()
for phrase in ["not found", "failed", "error"]
@@ -270,35 +228,29 @@ async def test_get_required_setup_info_invalid_agent():
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_graph_metadata(setup_test_data):
"""Test that setup info includes graph metadata"""
# Use test data from fixture
user = setup_test_data["user"]
graph = setup_test_data["graph"]
store_submission = setup_test_data["store_submission"]
# Create the tool instance
tool = GetRequiredSetupInfoTool()
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Execute the tool
session = make_session(user_id=user.id)
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={"test_input": "test"},
session=session,
)
# Verify the response
assert response is not None
assert hasattr(response, "result")
# Parse the result JSON
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Check that graph_id and graph_version are included
assert "graph_id" in result_data
assert result_data["graph_id"] == graph.id
assert "graph_version" in result_data
@@ -308,41 +260,33 @@ async def test_get_required_setup_info_graph_metadata(setup_test_data):
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_inputs_structure(setup_test_data):
"""Test that inputs are properly structured as a list"""
# Use test data from fixture
user = setup_test_data["user"]
store_submission = setup_test_data["store_submission"]
# Create the tool instance
tool = GetRequiredSetupInfoTool()
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Execute the tool
session = make_session(user_id=user.id)
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={},
session=session,
)
# Verify the response
assert response is not None
assert hasattr(response, "result")
# Parse the result JSON
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Check inputs structure
setup_info = result_data["setup_info"]
requirements = setup_info["requirements"]
# Inputs should be a list
assert isinstance(requirements["inputs"], list)
# Each input should have proper structure
for input_field in requirements["inputs"]:
assert isinstance(input_field, dict)
assert "name" in input_field
@@ -356,38 +300,31 @@ async def test_get_required_setup_info_inputs_structure(setup_test_data):
@pytest.mark.asyncio(scope="session")
async def test_get_required_setup_info_execution_modes_structure(setup_test_data):
"""Test that execution_modes are properly structured as a list"""
# Use test data from fixture
user = setup_test_data["user"]
store_submission = setup_test_data["store_submission"]
# Create the tool instance
tool = GetRequiredSetupInfoTool()
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Execute the tool
session = make_session(user_id=user.id)
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={},
session=session,
)
# Verify the response
assert response is not None
assert hasattr(response, "result")
# Parse the result JSON
assert isinstance(response.result, str)
result_data = orjson.loads(response.result)
# Check execution modes structure
setup_info = result_data["setup_info"]
requirements = setup_info["requirements"]
# execution_modes should be a list of strings
assert isinstance(requirements["execution_modes"], list)
for mode in requirements["execution_modes"]:
assert isinstance(mode, str)

View File

@@ -7,6 +7,8 @@ from backend.data.graph import get_graph
from backend.data.model import CredentialsMetaInput
from backend.executor import utils as execution_utils
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.server.v2.chat.config import ChatConfig
from backend.server.v2.chat.model import ChatSession
from backend.server.v2.chat.tools.base import BaseTool
from backend.server.v2.chat.tools.get_required_setup_info import (
GetRequiredSetupInfoTool,
@@ -22,6 +24,7 @@ from backend.server.v2.library import db as library_db
from backend.server.v2.library import model as library_model
logger = logging.getLogger(__name__)
config = ChatConfig()
class RunAgentTool(BaseTool):
@@ -65,7 +68,7 @@ class RunAgentTool(BaseTool):
async def _execute(
self,
user_id: str | None,
session_id: str,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Execute an agent manually.
@@ -84,13 +87,12 @@ class RunAgentTool(BaseTool):
user_id is not None
), "User ID is required to run an agent. Superclass enforces authentication."
session_id = session.session_id
username_agent_slug = kwargs.get("username_agent_slug", "").strip()
inputs = kwargs.get("inputs", {})
# Call _execute directly since we're calling internally from another tool
response = await GetRequiredSetupInfoTool()._execute(
user_id, session_id, **kwargs
)
response = await GetRequiredSetupInfoTool()._execute(user_id, session, **kwargs)
if not isinstance(response, SetupRequirementsResponse):
return ErrorResponse(
@@ -126,6 +128,14 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
if graph and (
session.successful_agent_runs.get(graph.id, 0) >= config.max_agent_runs
):
return ErrorResponse(
message="Maximum number of agent schedules reached. You can't schedule this agent again in this chat session.",
session_id=session.session_id,
)
# Check if we already have a library agent for this graph
existing_library_agent = await library_db.get_library_agent_by_graph_id(
graph_id=graph.id, user_id=user_id
@@ -232,6 +242,11 @@ class RunAgentTool(BaseTool):
inputs=inputs,
graph_credentials_inputs=graph_credentials_inputs,
)
session.successful_agent_runs[library_agent.graph_id] = (
session.successful_agent_runs.get(library_agent.graph_id, 0) + 1
)
return ExecutionStartedResponse(
message="Agent execution successfully started. Do not run this tool again unless specifically asked to run the agent again.",
session_id=session_id,

View File

@@ -3,7 +3,11 @@ import uuid
import orjson
import pytest
from backend.server.v2.chat.tools._test_data import setup_llm_test_data, setup_test_data
from backend.server.v2.chat.tools._test_data import (
make_session,
setup_llm_test_data,
setup_test_data,
)
from backend.server.v2.chat.tools.run_agent import RunAgentTool
# This is so the formatter doesn't remove the fixture imports
@@ -25,6 +29,9 @@ async def test_run_agent(setup_test_data):
# Build the proper marketplace agent_id format: username/slug
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Build the session
session = make_session(user_id=user.id)
# Execute the tool
response = await tool.execute(
user_id=user.id,
@@ -32,6 +39,7 @@ async def test_run_agent(setup_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={"test_input": "Hello World"},
session=session,
)
# Verify the response
@@ -61,6 +69,9 @@ async def test_run_agent_missing_inputs(setup_test_data):
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Build the session
session = make_session(user_id=user.id)
# Execute the tool without required inputs
response = await tool.execute(
user_id=user.id,
@@ -68,6 +79,7 @@ async def test_run_agent_missing_inputs(setup_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={}, # Missing required input
session=session,
)
# Verify that we get an error response
@@ -89,6 +101,9 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
# Create the tool instance
tool = RunAgentTool()
# Build the session
session = make_session(user_id=user.id)
# Execute the tool with invalid agent ID
response = await tool.execute(
user_id=user.id,
@@ -96,6 +111,7 @@ async def test_run_agent_invalid_agent_id(setup_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug="invalid/agent-id",
inputs={"test_input": "Hello World"},
session=session,
)
# Verify that we get an error response
@@ -125,6 +141,9 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Build the session
session = make_session(user_id=user.id)
# Execute the tool with a prompt for the LLM
response = await tool.execute(
user_id=user.id,
@@ -132,6 +151,7 @@ async def test_run_agent_with_llm_credentials(setup_llm_test_data):
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={"user_prompt": "What is 2+2?"},
session=session,
)
# Verify the response

View File

@@ -9,6 +9,8 @@ from backend.data.graph import get_graph
from backend.data.model import CredentialsMetaInput
from backend.data.user import get_user_by_id
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.server.v2.chat.config import ChatConfig
from backend.server.v2.chat.model import ChatSession
from backend.server.v2.chat.tools.get_required_setup_info import (
GetRequiredSetupInfoTool,
)
@@ -28,6 +30,7 @@ from backend.util.timezone_utils import (
from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase
config = ChatConfig()
logger = logging.getLogger(__name__)
@@ -113,7 +116,7 @@ class SetupAgentTool(BaseTool):
async def _execute(
self,
user_id: str | None,
session_id: str,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
"""Set up an agent with configuration.
@@ -130,6 +133,8 @@ class SetupAgentTool(BaseTool):
assert (
user_id is not None
), "User ID is required to run an agent. Superclass enforces authentication."
session_id = session.session_id
setup_type = kwargs.get("setup_type", "schedule").strip()
if setup_type != "schedule":
return ErrorResponse(
@@ -149,12 +154,22 @@ class SetupAgentTool(BaseTool):
inputs = kwargs.get("inputs", {})
library_agent = await self._get_or_add_library_agent(
username_agent_slug, user_id, session_id, **kwargs
username_agent_slug, user_id, session, **kwargs
)
if not isinstance(library_agent, AgentDetails):
# library agent is an ErrorResponse
return library_agent
if library_agent and (
session.successful_agent_schedules.get(library_agent.graph_id, 0)
if isinstance(library_agent, AgentDetails)
else 0 >= config.max_agent_schedules
):
return ErrorResponse(
message="Maximum number of agent schedules reached. You can't schedule this agent again in this chat session.",
session_id=session.session_id,
)
# At this point we know the user is ready to run the agent
# Create the schedule for the agent
from backend.server.v2.library import db as library_db
@@ -176,7 +191,7 @@ class SetupAgentTool(BaseTool):
name=cron_name,
inputs=inputs,
credentials=library_agent.required_credentials,
session_id=session_id,
session=session,
)
async def _add_graph_execution_schedule(
@@ -187,13 +202,13 @@ class SetupAgentTool(BaseTool):
name: str,
inputs: dict[str, Any],
credentials: dict[str, CredentialsMetaInput],
session_id: str,
session: ChatSession,
**kwargs,
) -> ExecutionStartedResponse | ErrorResponse:
# Use timezone from request if provided, otherwise fetch from user profile
user = await get_user_by_id(user_id)
user_timezone = get_user_timezone_or_utc(user.timezone if user else None)
session_id = session.session_id
# Map required credentials (schema field names) to actual user credential IDs
# credentials param contains CredentialsMetaInput with schema field names as keys
# We need to find the user's actual credentials that match the provider/type
@@ -252,6 +267,11 @@ class SetupAgentTool(BaseTool):
result.next_run_time = convert_utc_time_to_user_timezone(
result.next_run_time, user_timezone
)
session.successful_agent_schedules[library_agent.graph_id] = (
session.successful_agent_schedules.get(library_agent.graph_id, 0) + 1
)
return ExecutionStartedResponse(
message="Agent execution successfully scheduled. Do not run this tool again unless specifically asked to run the agent again.",
session_id=session_id,
@@ -261,12 +281,11 @@ class SetupAgentTool(BaseTool):
)
async def _get_or_add_library_agent(
self, agent_id: str, user_id: str, session_id: str, **kwargs
self, agent_id: str, user_id: str, session: ChatSession, **kwargs
) -> AgentDetails | ErrorResponse:
# Call _execute directly since we're calling internally from another tool
response = await GetRequiredSetupInfoTool()._execute(
user_id, session_id, **kwargs
)
session_id = session.session_id
response = await GetRequiredSetupInfoTool()._execute(user_id, session, **kwargs)
if not isinstance(response, SetupRequirementsResponse):
return ErrorResponse(

View File

@@ -3,7 +3,11 @@ import uuid
import orjson
import pytest
from backend.server.v2.chat.tools._test_data import setup_llm_test_data, setup_test_data
from backend.server.v2.chat.tools._test_data import (
make_session,
setup_llm_test_data,
setup_test_data,
)
from backend.server.v2.chat.tools.setup_agent import SetupAgentTool
from backend.util.clients import get_scheduler_client
@@ -22,13 +26,16 @@ async def test_setup_agent_missing_cron(setup_test_data):
# Create the tool instance
tool = SetupAgentTool()
# Build the session
session = make_session(user_id=user.id)
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Execute without cron
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
setup_type="schedule",
@@ -59,13 +66,16 @@ async def test_setup_agent_webhook_not_supported(setup_test_data):
# Create the tool instance
tool = SetupAgentTool()
# Build the session
session = make_session(user_id=user.id)
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Execute with webhook setup_type
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
setup_type="webhook",
@@ -94,13 +104,16 @@ async def test_setup_agent_schedule_success(setup_test_data):
# Create the tool instance
tool = SetupAgentTool()
# Build the session
session = make_session(user_id=user.id)
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Execute with schedule setup
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
setup_type="schedule",
@@ -137,13 +150,16 @@ async def test_setup_agent_with_credentials(setup_llm_test_data):
# Create the tool instance
tool = SetupAgentTool()
# Build the session
session = make_session(user_id=user.id)
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Execute with schedule setup
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
setup_type="schedule",
@@ -176,10 +192,13 @@ async def test_setup_agent_invalid_agent(setup_test_data):
# Create the tool instance
tool = SetupAgentTool()
# Build the session
session = make_session(user_id=user.id)
# Execute with non-existent agent
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug="nonexistent/agent",
setup_type="schedule",
@@ -214,6 +233,9 @@ async def test_setup_agent_schedule_created_in_scheduler(setup_test_data):
# Create the tool instance
tool = SetupAgentTool()
# Build the session
session = make_session(user_id=user.id)
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
@@ -223,7 +245,7 @@ async def test_setup_agent_schedule_created_in_scheduler(setup_test_data):
# Execute with schedule setup
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
setup_type="schedule",
@@ -273,6 +295,9 @@ async def test_setup_agent_schedule_with_credentials_triggered(setup_llm_test_da
# Create the tool instance
tool = SetupAgentTool()
# Build the session
session = make_session(user_id=user.id)
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
@@ -282,7 +307,7 @@ async def test_setup_agent_schedule_with_credentials_triggered(setup_llm_test_da
# Execute with schedule setup
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
setup_type="schedule",
@@ -361,13 +386,16 @@ async def test_setup_agent_creates_library_agent(setup_test_data):
# Create the tool instance
tool = SetupAgentTool()
# Build the session
session = make_session(user_id=user.id)
# Build the proper marketplace agent_id format
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
# Execute with schedule setup
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
session=session,
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
setup_type="schedule",

View File

@@ -4811,9 +4811,12 @@
}
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
},
"security": [{ "HTTPBearer": [] }]
"security": [{ "HTTPBearerJWT": [] }]
}
},
"/api/chat/sessions/{session_id}": {
@@ -4822,7 +4825,7 @@
"summary": "Get Session",
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session; raises NotFoundError if not found.",
"operationId": "getV2GetSession",
"security": [{ "HTTPBearer": [] }],
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "session_id",
@@ -4849,6 +4852,9 @@
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
}
}
@@ -4859,7 +4865,7 @@
"summary": "Stream Chat",
"description": "Stream chat responses for a session.\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n message: The user's new message to process.\n user_id: Optional authenticated user ID.\n is_user_message: Whether the message is a user message.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
"operationId": "getV2StreamChat",
"security": [{ "HTTPBearer": [] }],
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "session_id",
@@ -4901,6 +4907,9 @@
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
},
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
}
}
}
@@ -4911,7 +4920,7 @@
"summary": "Session Assign User",
"description": "Assign an authenticated user to a chat session.\n\nUsed (typically post-login) to claim an existing anonymous session as the current authenticated user.\n\nArgs:\n session_id: The identifier for the (previously anonymous) session.\n user_id: The authenticated user's ID to associate with the session.\n\nReturns:\n dict: Status of the assignment.",
"operationId": "patchV2SessionAssignUser",
"security": [{ "HTTPBearer": [] }, { "HTTPBearerJWT": [] }],
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "session_id",
@@ -4966,8 +4975,7 @@
}
}
}
},
"security": [{ "HTTPBearer": [] }]
}
}
},
"/health": {
@@ -9868,8 +9876,7 @@
"type": "apiKey",
"in": "header",
"name": "X-Postmark-Webhook-Token"
},
"HTTPBearer": { "type": "http", "scheme": "bearer" }
}
},
"responses": {
"HTTP401NotAuthenticatedError": {