Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-12 15:55:03 -05:00)

Compare commits: fix/copilo... → pwuts/open...
26 commits in this range (the author and date columns were not captured in this mirror):

2a46d3fbf4, ab25516a46, 6e2f595c7d, e523eb62b5, 97ff65ef6a, e8b81f71ef,
d652821ed5, 80659d90e4, eef892893c, 23175708e6, f02c00374e, 2fa166d839,
d927e4b611, 6591b2171c, 85d97a9d5c, 16c8b2a6e3, cad54a9f3e, ca0620b102,
7a4cf4e186, fe9debd80f, 7083dcf226, ee2805d14c, f15362d619, 6c2374593f,
0f4c33308f, ecb9fdae25
@@ -1,4 +1,9 @@
-"""Common test fixtures for server tests."""
+"""Common test fixtures for server tests.
+
+Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
+setup_test_user, and setup_admin_user are defined in the parent conftest.py
+(backend/conftest.py) and are available here automatically.
+"""

 import pytest
 from pytest_snapshot.plugin import Snapshot
@@ -11,54 +16,6 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
     return snapshot


-@pytest.fixture
-def test_user_id() -> str:
-    """Test user ID fixture."""
-    return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
-
-
-@pytest.fixture
-def admin_user_id() -> str:
-    """Admin user ID fixture."""
-    return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
-
-
-@pytest.fixture
-def target_user_id() -> str:
-    """Target user ID fixture."""
-    return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
-
-
-@pytest.fixture
-async def setup_test_user(test_user_id):
-    """Create test user in database before tests."""
-    from backend.data.user import get_or_create_user
-
-    # Create the test user in the database using JWT token format
-    user_data = {
-        "sub": test_user_id,
-        "email": "test@example.com",
-        "user_metadata": {"name": "Test User"},
-    }
-    await get_or_create_user(user_data)
-    return test_user_id
-
-
-@pytest.fixture
-async def setup_admin_user(admin_user_id):
-    """Create admin user in database before tests."""
-    from backend.data.user import get_or_create_user
-
-    # Create the admin user in the database using JWT token format
-    user_data = {
-        "sub": admin_user_id,
-        "email": "test-admin@example.com",
-        "user_metadata": {"name": "Test Admin"},
-    }
-    await get_or_create_user(user_data)
-    return admin_user_id
-
-
-@pytest.fixture
-def mock_jwt_user(test_user_id):
-    """Provide mock JWT payload for regular user testing."""
@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
 from pydantic import BaseModel, Field

 from backend.api.external.middleware import require_permission
-from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
-from backend.api.features.chat.tools.models import ToolResponseBase
+from backend.copilot.model import ChatSession
+from backend.copilot.tools import find_agent_tool, run_agent_tool
+from backend.copilot.tools.models import ToolResponseBase
 from backend.data.auth.base import APIAuthorizationInfo

 logger = logging.getLogger(__name__)
@@ -10,15 +10,22 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response,
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel

-from backend.util.exceptions import NotFoundError
-
-from . import service as chat_service
-from . import stream_registry
-from .completion_handler import process_operation_failure, process_operation_success
-from .config import ChatConfig
-from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
-from .response_model import StreamFinish, StreamHeartbeat
-from .tools.models import (
+from backend.copilot import service as chat_service
+from backend.copilot import stream_registry
+from backend.copilot.completion_handler import (
+    process_operation_failure,
+    process_operation_success,
+)
+from backend.copilot.config import ChatConfig
+from backend.copilot.executor.utils import enqueue_copilot_task
+from backend.copilot.model import (
+    ChatSession,
+    create_chat_session,
+    get_chat_session,
+    get_user_sessions,
+)
+from backend.copilot.response_model import StreamFinish, StreamHeartbeat
+from backend.copilot.tools.models import (
     AgentDetailsResponse,
     AgentOutputResponse,
     AgentPreviewResponse,
@@ -40,6 +47,7 @@ from .tools.models import (
     SetupRequirementsResponse,
     UnderstandingUpdatedResponse,
 )
+from backend.util.exceptions import NotFoundError

 config = ChatConfig()

@@ -301,7 +309,7 @@ async def stream_chat_post(
         extra={"json_fields": log_meta},
     )

-    session = await _validate_and_get_session(session_id, user_id)
+    _session = await _validate_and_get_session(session_id, user_id)  # noqa: F841
     logger.info(
         f"[TIMING] session validated in {(time.perf_counter() - stream_start_time)*1000:.1f}ms",
         extra={
@@ -336,82 +344,20 @@ async def stream_chat_post(
         },
     )

-    # Background task that runs the AI generation independently of SSE connection
-    async def run_ai_generation():
-        import time as time_module
+    # Enqueue the task to RabbitMQ for processing by the CoPilot executor
+    await enqueue_copilot_task(
+        task_id=task_id,
+        session_id=session_id,
+        user_id=user_id,
+        operation_id=operation_id,
+        message=request.message,
+        is_user_message=request.is_user_message,
+        context=request.context,
+    )

-        gen_start_time = time_module.perf_counter()
-        logger.info(
-            f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
-            extra={"json_fields": log_meta},
-        )
-        first_chunk_time, ttfc = None, None
-        chunk_count = 0
-        try:
-            async for chunk in chat_service.stream_chat_completion(
-                session_id,
-                request.message,
-                is_user_message=request.is_user_message,
-                user_id=user_id,
-                session=session,  # Pass pre-fetched session to avoid double-fetch
-                context=request.context,
-                _task_id=task_id,  # Pass task_id so service emits start with taskId for reconnection
-            ):
-                chunk_count += 1
-                if first_chunk_time is None:
-                    first_chunk_time = time_module.perf_counter()
-                    ttfc = first_chunk_time - gen_start_time
-                    logger.info(
-                        f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
-                        extra={
-                            "json_fields": {
-                                **log_meta,
-                                "chunk_type": type(chunk).__name__,
-                                "time_to_first_chunk_ms": ttfc * 1000,
-                            }
-                        },
-                    )
-                # Write to Redis (subscribers will receive via XREAD)
-                await stream_registry.publish_chunk(task_id, chunk)
-
-            gen_end_time = time_module.perf_counter()
-            total_time = (gen_end_time - gen_start_time) * 1000
-            logger.info(
-                f"[TIMING] run_ai_generation FINISHED in {total_time/1000:.1f}s; "
-                f"task={task_id}, session={session_id}, "
-                f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "total_time_ms": total_time,
-                        "time_to_first_chunk_ms": (
-                            ttfc * 1000 if ttfc is not None else None
-                        ),
-                        "n_chunks": chunk_count,
-                    }
-                },
-            )
-            await stream_registry.mark_task_completed(task_id, "completed")
-        except Exception as e:
-            elapsed = time_module.perf_counter() - gen_start_time
-            logger.error(
-                f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "elapsed_ms": elapsed * 1000,
-                        "error": str(e),
-                    }
-                },
-            )
-            await stream_registry.mark_task_completed(task_id, "failed")
-
-    # Start the AI generation in a background task
-    bg_task = asyncio.create_task(run_ai_generation())
-    await stream_registry.set_task_asyncio_task(task_id, bg_task)
     setup_time = (time.perf_counter() - stream_start_time) * 1000
     logger.info(
-        f"[TIMING] Background task started, setup={setup_time:.1f}ms",
+        f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
        extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
    )
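The handler above now only enqueues work; the generated chunks travel through the stream registry, which the surrounding comments describe as Redis Streams consumed via XREAD. Below is a minimal sketch of that pattern using redis-py's asyncio client; the key name and payload shape are illustrative assumptions, not the actual stream_registry API.

import json

import redis.asyncio as redis


async def publish_chunk(r: redis.Redis, task_id: str, chunk: dict) -> None:
    # XADD appends the chunk to a per-task stream; any number of readers can
    # follow it independently, which is what lets an SSE client reconnect.
    await r.xadd(f"copilot:stream:{task_id}", {"data": json.dumps(chunk)})


async def follow_stream(r: redis.Redis, task_id: str, last_id: str = "0") -> None:
    # XREAD from last_id lets a reconnecting client resume where it left off
    # instead of losing chunks produced while it was disconnected.
    while True:
        entries = await r.xread({f"copilot:stream:{task_id}": last_id}, block=5000)
        for _stream, messages in entries:
            for message_id, fields in messages:
                last_id = message_id
                chunk = json.loads(fields[b"data"])
                if chunk.get("type") == "finish":
                    return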
@@ -1,369 +0,0 @@
"""Feature request tools - search and create feature requests via Linear."""

import logging
from typing import Any

from pydantic import SecretStr

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
    ErrorResponse,
    FeatureRequestCreatedResponse,
    FeatureRequestInfo,
    FeatureRequestSearchResponse,
    NoResultsResponse,
    ToolResponseBase,
)
from backend.blocks.linear._api import LinearClient
from backend.data.model import APIKeyCredentials
from backend.util.settings import Settings

logger = logging.getLogger(__name__)

# Target project and team IDs in our Linear workspace
FEATURE_REQUEST_PROJECT_ID = "13f066f3-f639-4a67-aaa3-31483ebdf8cd"
TEAM_ID = "557fd3d5-087e-43a9-83e3-476c8313ce49"

MAX_SEARCH_RESULTS = 10

# GraphQL queries/mutations
SEARCH_ISSUES_QUERY = """
query SearchFeatureRequests($term: String!, $filter: IssueFilter, $first: Int) {
  searchIssues(term: $term, filter: $filter, first: $first) {
    nodes {
      id
      identifier
      title
      description
    }
  }
}
"""

CUSTOMER_UPSERT_MUTATION = """
mutation CustomerUpsert($input: CustomerUpsertInput!) {
  customerUpsert(input: $input) {
    success
    customer {
      id
      name
      externalIds
    }
  }
}
"""

ISSUE_CREATE_MUTATION = """
mutation IssueCreate($input: IssueCreateInput!) {
  issueCreate(input: $input) {
    success
    issue {
      id
      identifier
      title
      url
    }
  }
}
"""

CUSTOMER_NEED_CREATE_MUTATION = """
mutation CustomerNeedCreate($input: CustomerNeedCreateInput!) {
  customerNeedCreate(input: $input) {
    success
    need {
      id
      body
      customer {
        id
        name
      }
      issue {
        id
        identifier
        title
        url
      }
    }
  }
}
"""


_settings: Settings | None = None


def _get_settings() -> Settings:
    global _settings
    if _settings is None:
        _settings = Settings()
    return _settings


def _get_linear_client() -> LinearClient:
    """Create a Linear client using the system API key from settings."""
    api_key = _get_settings().secrets.linear_api_key
    if not api_key:
        raise RuntimeError("LINEAR_API_KEY secret is not configured")
    credentials = APIKeyCredentials(
        id="system-linear",
        provider="linear",
        api_key=SecretStr(api_key),
        title="System Linear API Key",
    )
    return LinearClient(credentials=credentials)


class SearchFeatureRequestsTool(BaseTool):
    """Tool for searching existing feature requests in Linear."""

    @property
    def name(self) -> str:
        return "search_feature_requests"

    @property
    def description(self) -> str:
        return (
            "Search existing feature requests to check if a similar request "
            "already exists before creating a new one. Returns matching feature "
            "requests with their ID, title, and description."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search term to find matching feature requests.",
                },
            },
            "required": ["query"],
        }

    @property
    def requires_auth(self) -> bool:
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        query = kwargs.get("query", "").strip()
        session_id = session.session_id if session else None

        if not query:
            return ErrorResponse(
                message="Please provide a search query.",
                error="Missing query parameter",
                session_id=session_id,
            )

        client = _get_linear_client()
        data = await client.query(
            SEARCH_ISSUES_QUERY,
            {
                "term": query,
                "filter": {
                    "project": {"id": {"eq": FEATURE_REQUEST_PROJECT_ID}},
                },
                "first": MAX_SEARCH_RESULTS,
            },
        )

        nodes = data.get("searchIssues", {}).get("nodes", [])

        if not nodes:
            return NoResultsResponse(
                message=f"No feature requests found matching '{query}'.",
                suggestions=[
                    "Try different keywords",
                    "Use broader search terms",
                    "You can create a new feature request if none exists",
                ],
                session_id=session_id,
            )

        results = [
            FeatureRequestInfo(
                id=node["id"],
                identifier=node["identifier"],
                title=node["title"],
                description=node.get("description"),
            )
            for node in nodes
        ]

        return FeatureRequestSearchResponse(
            message=f"Found {len(results)} feature request(s) matching '{query}'.",
            results=results,
            count=len(results),
            query=query,
            session_id=session_id,
        )


class CreateFeatureRequestTool(BaseTool):
    """Tool for creating feature requests (or adding needs to existing ones)."""

    @property
    def name(self) -> str:
        return "create_feature_request"

    @property
    def description(self) -> str:
        return (
            "Create a new feature request or add a customer need to an existing one. "
            "Always search first with search_feature_requests to avoid duplicates. "
            "If a matching request exists, pass its ID as existing_issue_id to add "
            "the user's need to it instead of creating a duplicate."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "title": {
                    "type": "string",
                    "description": "Title for the feature request.",
                },
                "description": {
                    "type": "string",
                    "description": "Detailed description of what the user wants and why.",
                },
                "existing_issue_id": {
                    "type": "string",
                    "description": (
                        "If adding a need to an existing feature request, "
                        "provide its Linear issue ID (from search results). "
                        "Omit to create a new feature request."
                    ),
                },
            },
            "required": ["title", "description"],
        }

    @property
    def requires_auth(self) -> bool:
        return True

    async def _find_or_create_customer(
        self, client: LinearClient, user_id: str
    ) -> dict:
        """Find existing customer by user_id or create a new one via upsert."""
        data = await client.mutate(
            CUSTOMER_UPSERT_MUTATION,
            {
                "input": {
                    "name": user_id,
                    "externalId": user_id,
                },
            },
        )
        result = data.get("customerUpsert", {})
        if not result.get("success"):
            raise RuntimeError(f"Failed to upsert customer: {data}")
        return result["customer"]

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        title = kwargs.get("title", "").strip()
        description = kwargs.get("description", "").strip()
        existing_issue_id = kwargs.get("existing_issue_id")
        session_id = session.session_id if session else None

        if not title or not description:
            return ErrorResponse(
                message="Both title and description are required.",
                error="Missing required parameters",
                session_id=session_id,
            )

        if not user_id:
            return ErrorResponse(
                message="Authentication required to create feature requests.",
                error="Missing user_id",
                session_id=session_id,
            )

        client = _get_linear_client()

        # Step 1: Find or create customer for this user
        customer = await self._find_or_create_customer(client, user_id)
        customer_id = customer["id"]
        customer_name = customer["name"]

        # Step 2: Create or reuse issue
        if existing_issue_id:
            # Add need to existing issue - we still need the issue details for response
            is_new_issue = False
            issue_id = existing_issue_id
        else:
            # Create new issue in the feature requests project
            data = await client.mutate(
                ISSUE_CREATE_MUTATION,
                {
                    "input": {
                        "title": title,
                        "description": description,
                        "teamId": TEAM_ID,
                        "projectId": FEATURE_REQUEST_PROJECT_ID,
                    },
                },
            )
            result = data.get("issueCreate", {})
            if not result.get("success"):
                return ErrorResponse(
                    message="Failed to create feature request issue.",
                    error=str(data),
                    session_id=session_id,
                )
            issue = result["issue"]
            issue_id = issue["id"]
            is_new_issue = True

        # Step 3: Create customer need on the issue
        data = await client.mutate(
            CUSTOMER_NEED_CREATE_MUTATION,
            {
                "input": {
                    "customerId": customer_id,
                    "issueId": issue_id,
                    "body": description,
                    "priority": 0,
                },
            },
        )
        need_result = data.get("customerNeedCreate", {})
        if not need_result.get("success"):
            return ErrorResponse(
                message="Failed to attach customer need to the feature request.",
                error=str(data),
                session_id=session_id,
            )

        need = need_result["need"]
        issue_info = need["issue"]

        return FeatureRequestCreatedResponse(
            message=(
                f"{'Created new feature request' if is_new_issue else 'Added your request to existing feature request'} "
                f"[{issue_info['identifier']}] {issue_info['title']}."
            ),
            issue_id=issue_info["id"],
            issue_identifier=issue_info["identifier"],
            issue_title=issue_info["title"],
            issue_url=issue_info.get("url", ""),
            is_new_issue=is_new_issue,
            customer_name=customer_name,
            session_id=session_id,
        )
@@ -40,11 +40,11 @@ import backend.data.user
 import backend.integrations.webhooks.utils
 import backend.util.service
 import backend.util.settings
-from backend.api.features.chat.completion_consumer import (
+from backend.blocks.llm import DEFAULT_LLM_MODEL
+from backend.copilot.completion_consumer import (
     start_completion_consumer,
     stop_completion_consumer,
 )
-from backend.blocks.llm import DEFAULT_LLM_MODEL
 from backend.data.model import Credentials
 from backend.integrations.providers import ProviderName
 from backend.monitoring.instrumentation import instrument_fastapi
@@ -38,7 +38,9 @@ def main(**kwargs):

     from backend.api.rest_api import AgentServer
     from backend.api.ws_api import WebsocketServer
-    from backend.executor import DatabaseManager, ExecutionManager, Scheduler
+    from backend.copilot.executor.manager import CoPilotExecutor
+    from backend.data.db_manager import DatabaseManager
+    from backend.executor import ExecutionManager, Scheduler
     from backend.notifications import NotificationManager

     run_processes(
@@ -48,6 +50,7 @@ def main(**kwargs):
         WebsocketServer(),
         AgentServer(),
         ExecutionManager(),
+        CoPilotExecutor(),
         **kwargs,
     )
@@ -1,6 +1,7 @@
 import logging
 import os

 import pytest
+import pytest_asyncio
 from dotenv import load_dotenv

@@ -27,6 +28,54 @@ async def server():
     yield server


+@pytest.fixture
+def test_user_id() -> str:
+    """Test user ID fixture."""
+    return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
+
+
+@pytest.fixture
+def admin_user_id() -> str:
+    """Admin user ID fixture."""
+    return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
+
+
+@pytest.fixture
+def target_user_id() -> str:
+    """Target user ID fixture."""
+    return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
+
+
+@pytest.fixture
+async def setup_test_user(test_user_id):
+    """Create test user in database before tests."""
+    from backend.data.user import get_or_create_user
+
+    # Create the test user in the database using JWT token format
+    user_data = {
+        "sub": test_user_id,
+        "email": "test@example.com",
+        "user_metadata": {"name": "Test User"},
+    }
+    await get_or_create_user(user_data)
+    return test_user_id
+
+
+@pytest.fixture
+async def setup_admin_user(admin_user_id):
+    """Create admin user in database before tests."""
+    from backend.data.user import get_or_create_user
+
+    # Create the admin user in the database using JWT token format
+    user_data = {
+        "sub": admin_user_id,
+        "email": "test-admin@example.com",
+        "user_metadata": {"name": "Test Admin"},
+    }
+    await get_or_create_user(user_data)
+    return admin_user_id
+
+
 @pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
 async def graph_cleanup(server):
     created_graph_ids = []
autogpt_platform/backend/backend/copilot/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
"""CoPilot module - AI assistant for AutoGPT platform.

This module contains the core CoPilot functionality including:
- AI generation service (LLM calls)
- Tool execution
- Session management
- Stream registry for SSE reconnection
"""
@@ -119,8 +119,9 @@ class ChatCompletionConsumer:
         """Lazily initialize Prisma client on first use."""
         if self._prisma is None:
             database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
-            self._prisma = Prisma(datasource={"url": database_url})
-            await self._prisma.connect()
+            prisma = Prisma(datasource={"url": database_url})
+            await prisma.connect()
+            self._prisma = prisma
             logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
         return self._prisma
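Why the reorder matters (a sketch of the hazard, not project code): with the old order, a second coroutine could observe self._prisma between assignment and connect() and issue queries on an unconnected client.

# coroutine A: self._prisma = Prisma(...)    # attribute becomes visible
# coroutine B: if self._prisma is None: ...  # False, so B gets the client
# coroutine B: await self._prisma.query(...) # fails: not yet connected
# coroutine A: await self._prisma.connect()  # completes too late for B
#
# Assigning self._prisma only after connect() completes keeps the attribute
# unset until the client is actually usable.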
@@ -14,7 +14,7 @@ from prisma.types import (
     ChatSessionWhereInput,
 )

-from backend.data.db import transaction
+from backend.data import db
 from backend.util.json import SafeJson

 logger = logging.getLogger(__name__)
@@ -147,7 +147,7 @@ async def add_chat_messages_batch(

     created_messages = []

-    async with transaction() as tx:
+    async with db.transaction() as tx:
         for i, msg in enumerate(messages):
             # Build input dict dynamically rather than using ChatMessageCreateInput
             # directly because Prisma's TypedDict validation rejects optional fields
@@ -0,0 +1,5 @@
"""CoPilot Executor - Dedicated service for AI generation and tool execution.

This module contains the executor service that processes CoPilot tasks
from RabbitMQ, following the graph executor pattern.
"""
@@ -0,0 +1,18 @@
"""Entry point for running the CoPilot Executor service.

Usage:
    python -m backend.copilot.executor
"""

from backend.app import run_processes

from .manager import CoPilotExecutor


def main():
    """Run the CoPilot Executor service."""
    run_processes(CoPilotExecutor())


if __name__ == "__main__":
    main()
autogpt_platform/backend/backend/copilot/executor/manager.py (new file, 508 lines)
@@ -0,0 +1,508 @@
"""CoPilot Executor Manager - main service for CoPilot task execution.

This module contains the CoPilotExecutor class that consumes chat tasks from
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
"""

import logging
import os
import threading
import time
import uuid
from concurrent.futures import Future, ThreadPoolExecutor

from pika.adapters.blocking_connection import BlockingChannel
from pika.exceptions import AMQPChannelError, AMQPConnectionError
from pika.spec import Basic, BasicProperties
from prometheus_client import Gauge, start_http_server

from backend.data import redis_client as redis
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.cluster_lock import ClusterLock
from backend.util.decorator import error_logged
from backend.util.logging import TruncatedLogger
from backend.util.process import AppProcess
from backend.util.retry import continuous_retry
from backend.util.settings import Settings

from .processor import execute_copilot_task, init_worker
from .utils import (
    COPILOT_CANCEL_QUEUE_NAME,
    COPILOT_EXECUTION_QUEUE_NAME,
    GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
    CancelCoPilotEvent,
    CoPilotExecutionEntry,
    create_copilot_queue_config,
)

logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
settings = Settings()

# Prometheus metrics
active_tasks_gauge = Gauge(
    "copilot_executor_active_tasks",
    "Number of active CoPilot tasks",
)
pool_size_gauge = Gauge(
    "copilot_executor_pool_size",
    "Maximum number of CoPilot executor workers",
)
utilization_gauge = Gauge(
    "copilot_executor_utilization_ratio",
    "Ratio of active tasks to pool size",
)


class CoPilotExecutor(AppProcess):
    """CoPilot Executor service for processing chat generation tasks.

    This service consumes tasks from RabbitMQ, processes them using a thread pool,
    and publishes results to Redis Streams. It follows the graph executor pattern
    for reliable message handling and graceful shutdown.

    Key features:
    - RabbitMQ-based task distribution with manual acknowledgment
    - Thread pool executor for concurrent task processing
    - Cluster lock for duplicate prevention across pods
    - Graceful shutdown with timeout for in-flight tasks
    - FANOUT exchange for cancellation broadcast
    """

    def __init__(self):
        super().__init__()
        self.pool_size = settings.config.num_copilot_workers
        self.active_tasks: dict[str, tuple[Future, threading.Event]] = {}
        self.executor_id = str(uuid.uuid4())

        self._executor = None
        self._stop_consuming = None

        self._cancel_thread = None
        self._cancel_client = None
        self._run_thread = None
        self._run_client = None

        self._task_locks: dict[str, ClusterLock] = {}

    # ============ Main Entry Points (AppProcess interface) ============ #

    def run(self):
        """Main service loop - consume from RabbitMQ."""
        logger.info(f"Pod assigned executor_id: {self.executor_id}")
        logger.info(f"Spawn max-{self.pool_size} workers...")

        pool_size_gauge.set(self.pool_size)
        self._update_metrics()
        start_http_server(settings.config.copilot_executor_port)

        self.cancel_thread.start()
        self.run_thread.start()

        while True:
            time.sleep(1e5)

    def cleanup(self):
        """Graceful shutdown with active execution waiting."""
        pid = os.getpid()
        logger.info(f"[cleanup {pid}] Starting graceful shutdown...")

        # Signal the consumer thread to stop
        try:
            self.stop_consuming.set()
            run_channel = self.run_client.get_channel()
            run_channel.connection.add_callback_threadsafe(
                lambda: run_channel.stop_consuming()
            )
            logger.info(f"[cleanup {pid}] Consumer has been signaled to stop")
        except Exception as e:
            logger.error(f"[cleanup {pid}] Error stopping consumer: {e}")

        # Wait for active executions to complete
        if self.active_tasks:
            logger.info(
                f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
            )

            start_time = time.monotonic()
            last_refresh = start_time
            lock_refresh_interval = settings.config.cluster_lock_timeout / 10

            while (
                self.active_tasks
                and (time.monotonic() - start_time) < GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS
            ):
                self._cleanup_completed_tasks()
                if not self.active_tasks:
                    break

                # Refresh cluster locks periodically
                current_time = time.monotonic()
                if current_time - last_refresh >= lock_refresh_interval:
                    for lock in self._task_locks.values():
                        try:
                            lock.refresh()
                        except Exception as e:
                            logger.warning(
                                f"[cleanup {pid}] Failed to refresh lock: {e}"
                            )
                    last_refresh = current_time

                logger.info(
                    f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..."
                )
                time.sleep(10.0)

        # Stop message consumers
        if self._run_thread:
            self._stop_message_consumers(
                self._run_thread, self.run_client, "[cleanup][run]"
            )
        if self._cancel_thread:
            self._stop_message_consumers(
                self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
            )

        # Shutdown executor
        if self._executor:
            logger.info(f"[cleanup {pid}] Shutting down executor...")
            self._executor.shutdown(wait=False)

        # Release any remaining locks
        for task_id, lock in list(self._task_locks.items()):
            try:
                lock.release()
                logger.info(f"[cleanup {pid}] Released lock for {task_id}")
            except Exception as e:
                logger.error(
                    f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
                )

        logger.info(f"[cleanup {pid}] Graceful shutdown completed")

    # ============ RabbitMQ Consumer Methods ============ #

    @continuous_retry()
    def _consume_cancel(self):
        """Consume cancellation messages from FANOUT exchange."""
        if self.stop_consuming.is_set() and not self.active_tasks:
            logger.info("Stop reconnecting cancel consumer - service cleaned up")
            return

        if not self.cancel_client.is_ready:
            self.cancel_client.disconnect()
            self.cancel_client.connect()

        # Check again after connect - shutdown may have been requested
        if self.stop_consuming.is_set() and not self.active_tasks:
            logger.info("Stop consuming requested during reconnect - disconnecting")
            self.cancel_client.disconnect()
            return

        cancel_channel = self.cancel_client.get_channel()
        cancel_channel.basic_consume(
            queue=COPILOT_CANCEL_QUEUE_NAME,
            on_message_callback=self._handle_cancel_message,
            auto_ack=True,
        )
        logger.info("Starting cancel message consumer...")
        cancel_channel.start_consuming()
        if not self.stop_consuming.is_set() or self.active_tasks:
            raise RuntimeError("Cancel message consumer stopped unexpectedly")
        logger.info("Cancel message consumer stopped gracefully")

    @continuous_retry()
    def _consume_run(self):
        """Consume run messages from DIRECT exchange."""
        if self.stop_consuming.is_set():
            logger.info("Stop reconnecting run consumer - service cleaned up")
            return

        if not self.run_client.is_ready:
            self.run_client.disconnect()
            self.run_client.connect()

        # Check again after connect - shutdown may have been requested
        if self.stop_consuming.is_set():
            logger.info("Stop consuming requested during reconnect - disconnecting")
            self.run_client.disconnect()
            return

        run_channel = self.run_client.get_channel()
        run_channel.basic_qos(prefetch_count=self.pool_size)

        run_channel.basic_consume(
            queue=COPILOT_EXECUTION_QUEUE_NAME,
            on_message_callback=self._handle_run_message,
            auto_ack=False,
            consumer_tag="copilot_execution_consumer",
        )
        logger.info("Starting to consume run messages...")
        run_channel.start_consuming()
        if not self.stop_consuming.is_set():
            raise RuntimeError("Run message consumer stopped unexpectedly")
        logger.info("Run message consumer stopped gracefully")

    # ============ Message Handlers ============ #

    @error_logged(swallow=True)
    def _handle_cancel_message(
        self,
        _channel: BlockingChannel,
        _method: Basic.Deliver,
        _properties: BasicProperties,
        body: bytes,
    ):
        """Handle cancel message from FANOUT exchange."""
        request = CancelCoPilotEvent.model_validate_json(body)
        task_id = request.task_id
        if not task_id:
            logger.warning("Cancel message missing 'task_id'")
            return
        if task_id not in self.active_tasks:
            logger.debug(f"Cancel received for {task_id} but not active")
            return

        _, cancel_event = self.active_tasks[task_id]
        logger.info(f"Received cancel for {task_id}")
        if not cancel_event.is_set():
            cancel_event.set()
        else:
            logger.debug(f"Cancel already set for {task_id}")

    def _handle_run_message(
        self,
        _channel: BlockingChannel,
        method: Basic.Deliver,
        _properties: BasicProperties,
        body: bytes,
    ):
        """Handle run message from DIRECT exchange."""
        delivery_tag = method.delivery_tag
        # Capture the channel used at message delivery time to ensure we ack
        # on the correct channel. Delivery tags are channel-scoped and become
        # invalid if the channel is recreated after reconnection.
        delivery_channel = _channel

        def ack_message(reject: bool, requeue: bool):
            """Acknowledge or reject the message.

            Uses the channel from the original message delivery. If the channel
            is no longer open (e.g., after reconnection), logs a warning and
            skips the ack - RabbitMQ will redeliver the message automatically.
            """
            try:
                if not delivery_channel.is_open:
                    logger.warning(
                        f"Channel closed, cannot ack delivery_tag={delivery_tag}. "
                        "Message will be redelivered by RabbitMQ."
                    )
                    return

                if reject:
                    delivery_channel.connection.add_callback_threadsafe(
                        lambda: delivery_channel.basic_nack(
                            delivery_tag, requeue=requeue
                        )
                    )
                else:
                    delivery_channel.connection.add_callback_threadsafe(
                        lambda: delivery_channel.basic_ack(delivery_tag)
                    )
            except (AMQPChannelError, AMQPConnectionError) as e:
                # Channel/connection errors indicate stale delivery tag - don't retry
                logger.warning(
                    f"Cannot ack delivery_tag={delivery_tag} due to channel/connection "
                    f"error: {e}. Message will be redelivered by RabbitMQ."
                )
            except Exception as e:
                # Other errors might be transient, but log and skip to avoid blocking
                logger.error(
                    f"Unexpected error acking delivery_tag={delivery_tag}: {e}"
                )

        # Check if we're shutting down
        if self.stop_consuming.is_set():
            logger.info("Rejecting new task during shutdown")
            ack_message(reject=True, requeue=True)
            return

        # Check if we can accept more tasks
        self._cleanup_completed_tasks()
        if len(self.active_tasks) >= self.pool_size:
            ack_message(reject=True, requeue=True)
            return

        try:
            entry = CoPilotExecutionEntry.model_validate_json(body)
        except Exception as e:
            logger.error(f"Could not parse run message: {e}, body={body}")
            ack_message(reject=True, requeue=False)
            return

        task_id = entry.task_id

        # Check for local duplicate - task is already running on this executor
        if task_id in self.active_tasks:
            logger.warning(
                f"Task {task_id} already running locally, rejecting duplicate"
            )
            ack_message(reject=True, requeue=False)
            return

        # Try to acquire cluster-wide lock
        cluster_lock = ClusterLock(
            redis=redis.get_redis(),
            key=f"copilot:task:{task_id}:lock",
            owner_id=self.executor_id,
            timeout=settings.config.cluster_lock_timeout,
        )
        current_owner = cluster_lock.try_acquire()
        if current_owner != self.executor_id:
            if current_owner is not None:
                logger.warning(f"Task {task_id} already running on pod {current_owner}")
                ack_message(reject=True, requeue=False)
            else:
                logger.warning(
                    f"Could not acquire lock for {task_id} - Redis unavailable"
                )
                ack_message(reject=True, requeue=True)
            return

        # Execute the task
        try:
            self._task_locks[task_id] = cluster_lock

            logger.info(
                f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
            )

            cancel_event = threading.Event()
            future = self.executor.submit(
                execute_copilot_task, entry, cancel_event, cluster_lock
            )
            self.active_tasks[task_id] = (future, cancel_event)
        except Exception as e:
            logger.warning(f"Failed to setup execution for {task_id}: {e}")
            cluster_lock.release()
            if task_id in self._task_locks:
                del self._task_locks[task_id]
            ack_message(reject=True, requeue=True)
            return

        self._update_metrics()

        def on_run_done(f: Future):
            logger.info(f"Run completed for {task_id}")
            try:
                if exec_error := f.exception():
                    logger.error(f"Execution for {task_id} failed: {exec_error}")
                    # Don't requeue failed tasks - they've been marked as failed
                    # in the stream registry. Requeuing would cause infinite retries
                    # for deterministic failures.
                    ack_message(reject=True, requeue=False)
                else:
                    ack_message(reject=False, requeue=False)
            except BaseException as e:
                logger.exception(f"Error in run completion callback: {e}")
            finally:
                # Release the cluster lock
                if task_id in self._task_locks:
                    logger.info(f"Releasing cluster lock for {task_id}")
                    self._task_locks[task_id].release()
                    del self._task_locks[task_id]
                self._cleanup_completed_tasks()

        future.add_done_callback(on_run_done)

    # ============ Helper Methods ============ #

    def _cleanup_completed_tasks(self) -> list[str]:
        """Remove completed futures from active_tasks and update metrics."""
        completed_tasks = []
        for task_id, (future, _) in self.active_tasks.items():
            if future.done():
                completed_tasks.append(task_id)

        for task_id in completed_tasks:
            logger.info(f"Cleaned up completed task {task_id}")
            self.active_tasks.pop(task_id, None)

        self._update_metrics()
        return completed_tasks

    def _update_metrics(self):
        """Update Prometheus metrics."""
        active_count = len(self.active_tasks)
        active_tasks_gauge.set(active_count)
        if self.stop_consuming.is_set():
            utilization_gauge.set(1.0)
        else:
            utilization_gauge.set(
                active_count / self.pool_size if self.pool_size > 0 else 0
            )

    def _stop_message_consumers(
        self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str
    ):
        """Stop a message consumer thread."""
        try:
            channel = client.get_channel()
            channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())

            thread.join(timeout=300)
            if thread.is_alive():
                logger.error(
                    f"{prefix} Thread did not finish in time, forcing disconnect"
                )

            client.disconnect()
            logger.info(f"{prefix} Client disconnected")
        except Exception as e:
            logger.error(f"{prefix} Error disconnecting client: {e}")

    # ============ Lazy-initialized Properties ============ #

    @property
    def cancel_thread(self) -> threading.Thread:
        if self._cancel_thread is None:
            self._cancel_thread = threading.Thread(
                target=lambda: self._consume_cancel(),
                daemon=True,
            )
        return self._cancel_thread

    @property
    def run_thread(self) -> threading.Thread:
        if self._run_thread is None:
            self._run_thread = threading.Thread(
                target=lambda: self._consume_run(),
                daemon=True,
            )
        return self._run_thread

    @property
    def stop_consuming(self) -> threading.Event:
        if self._stop_consuming is None:
            self._stop_consuming = threading.Event()
        return self._stop_consuming

    @property
    def executor(self) -> ThreadPoolExecutor:
        if self._executor is None:
            self._executor = ThreadPoolExecutor(
                max_workers=self.pool_size,
                initializer=init_worker,
            )
        return self._executor

    @property
    def cancel_client(self) -> SyncRabbitMQ:
        if self._cancel_client is None:
            self._cancel_client = SyncRabbitMQ(create_copilot_queue_config())
        return self._cancel_client

    @property
    def run_client(self) -> SyncRabbitMQ:
        if self._run_client is None:
            self._run_client = SyncRabbitMQ(create_copilot_queue_config())
        return self._run_client
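manager.py leans on ClusterLock from backend.executor.cluster_lock, which this diff does not include. A minimal sketch of the interface the calls above assume (try_acquire returning the current owner, plus refresh and release), built on a Redis key with a TTL; the implementation details here are assumptions, and a production version would use a Lua script to make the owner check and update atomic.

import redis


class ClusterLockSketch:
    """Illustrative stand-in for backend.executor.cluster_lock.ClusterLock."""

    def __init__(self, redis_client: redis.Redis, key: str, owner_id: str, timeout: int):
        self.redis = redis_client
        self.key = key
        self.owner_id = owner_id
        self.timeout = timeout

    def try_acquire(self) -> str | None:
        # SET NX claims the key only if it is free; otherwise report who holds
        # it (None would signal that Redis itself was unreachable).
        if self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout):
            return self.owner_id
        current = self.redis.get(self.key)
        return current.decode() if current else None

    def refresh(self) -> None:
        # Extend the TTL only while we still own the key.
        if self.redis.get(self.key) == self.owner_id.encode():
            self.redis.expire(self.key, self.timeout)

    def release(self) -> None:
        if self.redis.get(self.key) == self.owner_id.encode():
            self.redis.delete(self.key)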
autogpt_platform/backend/backend/copilot/executor/processor.py (new file, 237 lines)
@@ -0,0 +1,237 @@
"""CoPilot execution processor - per-worker execution logic.

This module contains the processor class that handles CoPilot task execution
in a thread-local context, following the graph executor pattern.
"""

import asyncio
import logging
import threading
import time

from backend.copilot import service as copilot_service
from backend.copilot import stream_registry
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
from backend.executor.cluster_lock import ClusterLock
from backend.util.decorator import error_logged
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.process import set_service_name
from backend.util.retry import func_retry

from .utils import CoPilotExecutionEntry, CoPilotLogMetadata

logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")


# ============ Module Entry Points ============ #

# Thread-local storage for processor instances
_tls = threading.local()


def execute_copilot_task(
    entry: CoPilotExecutionEntry,
    cancel: threading.Event,
    cluster_lock: ClusterLock,
):
    """Execute a CoPilot task using the thread-local processor.

    This function is the entry point called by the thread pool executor.

    Args:
        entry: The task payload
        cancel: Threading event to signal cancellation
        cluster_lock: Distributed lock for this execution
    """
    processor: CoPilotProcessor = _tls.processor
    return processor.execute(entry, cancel, cluster_lock)


def init_worker():
    """Initialize the processor for the current worker thread.

    This function is called by the thread pool executor when a new worker
    thread is created. It ensures each worker has its own processor instance.
    """
    _tls.processor = CoPilotProcessor()
    _tls.processor.on_executor_start()


# ============ Processor Class ============ #


class CoPilotProcessor:
    """Per-worker execution logic for CoPilot tasks.

    This class is instantiated once per worker thread and handles the execution
    of CoPilot chat generation tasks. It maintains an async event loop for
    running the async service code.

    The execution flow:
    1. CoPilot task is picked from RabbitMQ queue
    2. Manager submits task to thread pool
    3. Processor executes the task in its event loop
    4. Results are published to Redis Streams
    """

    @func_retry
    def on_executor_start(self):
        """Initialize the processor when the worker thread starts.

        This method is called once per worker thread to set up the async event
        loop and initialize any required resources.

        Database is accessed only through DatabaseManager, so we don't need to connect
        to Prisma directly.
        """
        configure_logging()
        set_service_name("CoPilotExecutor")
        self.tid = threading.get_ident()
        self.execution_loop = asyncio.new_event_loop()
        self.execution_thread = threading.Thread(
            target=self.execution_loop.run_forever, daemon=True
        )
        self.execution_thread.start()

        logger.info(f"[CoPilotExecutor] Worker {self.tid} started")

    @error_logged(swallow=False)
    def execute(
        self,
        entry: CoPilotExecutionEntry,
        cancel: threading.Event,
        cluster_lock: ClusterLock,
    ):
        """Execute a CoPilot task.

        This is the main entry point for task execution. It runs the async
        execution logic in the worker's event loop and handles errors.

        Args:
            entry: The task payload containing session and message info
            cancel: Threading event to signal cancellation
            cluster_lock: Distributed lock to prevent duplicate execution
        """
        log = CoPilotLogMetadata(
            logging.getLogger(__name__),
            task_id=entry.task_id,
            session_id=entry.session_id,
            user_id=entry.user_id,
        )
        log.info("Starting execution")

        start_time = time.monotonic()

        try:
            # Run the async execution in our event loop
            future = asyncio.run_coroutine_threadsafe(
                self._execute_async(entry, cancel, cluster_lock, log),
                self.execution_loop,
            )

            # Wait for completion, checking cancel periodically
            while not future.done():
                try:
                    future.result(timeout=1.0)
                except asyncio.TimeoutError:
                    if cancel.is_set():
                        log.info("Cancellation requested")
                        future.cancel()
                        break
                    # Refresh cluster lock to maintain ownership
                    cluster_lock.refresh()

            if not future.cancelled():
                # Get result to propagate any exceptions
                future.result()

            elapsed = time.monotonic() - start_time
            log.info(f"Execution completed in {elapsed:.2f}s")

        except Exception as e:
            elapsed = time.monotonic() - start_time
            log.error(f"Execution failed after {elapsed:.2f}s: {e}")
            # Note: _execute_async already marks the task as failed before re-raising,
            # so we don't call _mark_task_failed here to avoid duplicate error events.
            raise

    async def _execute_async(
        self,
        entry: CoPilotExecutionEntry,
        cancel: threading.Event,
        cluster_lock: ClusterLock,
        log: CoPilotLogMetadata,
    ):
        """Async execution logic for CoPilot task.

        This method calls the existing stream_chat_completion service function
        and publishes results to the stream registry.

        Args:
            entry: The task payload
            cancel: Threading event to signal cancellation
            cluster_lock: Distributed lock for refresh
            log: Structured logger for this task
        """
        last_refresh = time.monotonic()
        refresh_interval = 30.0  # Refresh lock every 30 seconds

        try:
            # Stream chat completion and publish chunks to Redis
            async for chunk in copilot_service.stream_chat_completion(
                session_id=entry.session_id,
                message=entry.message if entry.message else None,
                is_user_message=entry.is_user_message,
                user_id=entry.user_id,
                context=entry.context,
                _task_id=entry.task_id,
            ):
                # Check for cancellation
                if cancel.is_set():
                    log.info("Cancelled during streaming")
                    await stream_registry.publish_chunk(
                        entry.task_id, StreamError(errorText="Operation cancelled")
                    )
                    await stream_registry.publish_chunk(
                        entry.task_id, StreamFinishStep()
                    )
                    await stream_registry.publish_chunk(entry.task_id, StreamFinish())
                    await stream_registry.mark_task_completed(
                        entry.task_id, status="failed"
                    )
                    return

                # Refresh cluster lock periodically
                current_time = time.monotonic()
                if current_time - last_refresh >= refresh_interval:
                    cluster_lock.refresh()
                    last_refresh = current_time

                # Publish chunk to stream registry
                await stream_registry.publish_chunk(entry.task_id, chunk)

            # Mark task as completed
            await stream_registry.mark_task_completed(entry.task_id, status="completed")
            log.info("Task completed successfully")

        except asyncio.CancelledError:
            log.info("Task cancelled")
            await stream_registry.mark_task_completed(entry.task_id, status="failed")
            raise

        except Exception as e:
            log.error(f"Task failed: {e}")
            await self._mark_task_failed(entry.task_id, str(e))
            raise

    async def _mark_task_failed(self, task_id: str, error_message: str):
        """Mark a task as failed and publish error to stream registry."""
        try:
            await stream_registry.publish_chunk(
                task_id, StreamError(errorText=error_message)
            )
            await stream_registry.publish_chunk(task_id, StreamFinishStep())
            await stream_registry.publish_chunk(task_id, StreamFinish())
            await stream_registry.mark_task_completed(task_id, status="failed")
        except Exception as e:
            logger.error(f"Failed to mark task {task_id} as failed: {e}")
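The processor's central trick is pinning an event loop to a background thread and driving coroutines onto it from the synchronous worker via asyncio.run_coroutine_threadsafe. The same pattern in isolation (a standalone sketch, not project code):

import asyncio
import threading


def start_loop_thread() -> asyncio.AbstractEventLoop:
    # Create a loop and run it forever on a daemon thread, as
    # on_executor_start() does above.
    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()
    return loop


async def work() -> str:
    await asyncio.sleep(0.1)
    return "done"


loop = start_loop_thread()
future = asyncio.run_coroutine_threadsafe(work(), loop)
# run_coroutine_threadsafe returns a concurrent.futures.Future, so calling
# future.result(timeout=1.0) and catching the timeout while the coroutine is
# still running is exactly the periodic hook execute() uses to check the
# cancel event and refresh the cluster lock.
print(future.result(timeout=5))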
207
autogpt_platform/backend/backend/copilot/executor/utils.py
Normal file
207
autogpt_platform/backend/backend/copilot/executor/utils.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""RabbitMQ queue configuration for CoPilot executor.
|
||||
|
||||
Defines two exchanges and queues following the graph executor pattern:
|
||||
- 'copilot_execution' (DIRECT) for chat generation tasks
|
||||
- 'copilot_cancel' (FANOUT) for cancellation requests
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============ Logging Helper ============ #
|
||||
|
||||
|
||||
class CoPilotLogMetadata(TruncatedLogger):
|
||||
"""Structured logging helper for CoPilot executor.
|
||||
|
||||
In cloud environments (structured logging enabled), uses a simple prefix
|
||||
and passes metadata via json_fields. In local environments, uses a detailed
|
||||
prefix with all metadata key-value pairs for easier debugging.
|
||||
|
||||
Args:
|
||||
logger: The underlying logger instance
|
||||
max_length: Maximum log message length before truncation
|
||||
**kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
|
||||
These are added to json_fields in cloud mode, or to the prefix in local mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logger: logging.Logger,
|
||||
max_length: int = 1000,
|
||||
**kwargs: str | None,
|
||||
):
|
||||
# Filter out None values
|
||||
metadata = {k: v for k, v in kwargs.items() if v is not None}
|
||||
metadata["component"] = "CoPilotExecutor"
|
||||
|
||||
if is_structured_logging_enabled():
|
||||
prefix = "[CoPilotExecutor]"
|
||||
else:
|
||||
# Build prefix from metadata key-value pairs
|
||||
meta_parts = "|".join(
|
||||
f"{k}:{v}" for k, v in metadata.items() if k != "component"
|
||||
)
|
||||
prefix = (
|
||||
f"[CoPilotExecutor|{meta_parts}]" if meta_parts else "[CoPilotExecutor]"
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
logger,
|
||||
max_length=max_length,
|
||||
prefix=prefix,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
# ============ Exchange and Queue Configuration ============ #
|
||||
|
||||
COPILOT_EXECUTION_EXCHANGE = Exchange(
|
||||
name="copilot_execution",
|
||||
type=ExchangeType.DIRECT,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue"
|
||||
COPILOT_EXECUTION_ROUTING_KEY = "copilot.run"
|
||||
|
||||
COPILOT_CANCEL_EXCHANGE = Exchange(
|
||||
name="copilot_cancel",
|
||||
type=ExchangeType.FANOUT,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"
|
||||
|
||||
# CoPilot operations can include extended thinking and agent generation
|
||||
# which may take 30+ minutes to complete
|
||||
COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour
|
||||
|
||||
# Graceful shutdown timeout - allow in-flight operations to complete
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def create_copilot_queue_config() -> RabbitMQConfig:
|
||||
"""Create RabbitMQ configuration for CoPilot executor.
|
||||
|
||||
Defines two exchanges and queues:
|
||||
- 'copilot_execution' (DIRECT) for chat generation tasks
|
||||
- 'copilot_cancel' (FANOUT) for cancellation requests
|
||||
|
||||
Returns:
|
||||
RabbitMQConfig with exchanges and queues defined
|
||||
"""
|
||||
run_queue = Queue(
|
||||
name=COPILOT_EXECUTION_QUEUE_NAME,
|
||||
exchange=COPILOT_EXECUTION_EXCHANGE,
|
||||
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
arguments={
|
||||
# Extended consumer timeout for long-running LLM operations
|
||||
# Default 30-minute timeout is insufficient for extended thinking
|
||||
# and agent generation which can take 30+ minutes
|
||||
"x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS
|
||||
* 1000,
|
||||
},
|
||||
)
|
||||
cancel_queue = Queue(
|
||||
name=COPILOT_CANCEL_QUEUE_NAME,
|
||||
exchange=COPILOT_CANCEL_EXCHANGE,
|
||||
routing_key="", # not used for FANOUT
|
||||
durable=True,
|
||||
auto_delete=False,
|
||||
)
|
||||
return RabbitMQConfig(
|
||||
vhost="/",
|
||||
exchanges=[COPILOT_EXECUTION_EXCHANGE, COPILOT_CANCEL_EXCHANGE],
|
||||
queues=[run_queue, cancel_queue],
|
||||
)
|
||||
|
||||
|
||||
# ============ Message Models ============ #


class CoPilotExecutionEntry(BaseModel):
    """Task payload for CoPilot AI generation.

    This model represents a chat generation task to be processed by the executor.
    """

    task_id: str
    """Unique identifier for this task (used for stream registry)"""

    session_id: str
    """Chat session ID"""

    user_id: str | None
    """User ID (may be None for anonymous users)"""

    operation_id: str
    """Operation ID for webhook callbacks and completion tracking"""

    message: str
    """User's message to process"""

    is_user_message: bool = True
    """Whether the message is from the user (vs system/assistant)"""

    context: dict[str, str] | None = None
    """Optional context for the message (e.g., {url: str, content: str})"""


class CancelCoPilotEvent(BaseModel):
    """Event to cancel a CoPilot operation."""

    task_id: str
    """Task ID to cancel"""


# ============ Queue Publishing Helpers ============ #


async def enqueue_copilot_task(
    task_id: str,
    session_id: str,
    user_id: str | None,
    operation_id: str,
    message: str,
    is_user_message: bool = True,
    context: dict[str, str] | None = None,
) -> None:
    """Enqueue a CoPilot task for processing by the executor service.

    Args:
        task_id: Unique identifier for this task (used for stream registry)
        session_id: Chat session ID
        user_id: User ID (may be None for anonymous users)
        operation_id: Operation ID for webhook callbacks and completion tracking
        message: User's message to process
        is_user_message: Whether the message is from the user (vs system/assistant)
        context: Optional context for the message (e.g., {url: str, content: str})
    """
    from backend.util.clients import get_async_copilot_queue

    entry = CoPilotExecutionEntry(
        task_id=task_id,
        session_id=session_id,
        user_id=user_id,
        operation_id=operation_id,
        message=message,
        is_user_message=is_user_message,
        context=context,
    )

    queue_client = await get_async_copilot_queue()
    await queue_client.publish_message(
        routing_key=COPILOT_EXECUTION_ROUTING_KEY,
        message=entry.model_dump_json(),
        exchange=COPILOT_EXECUTION_EXCHANGE,
    )
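A call site only needs to mint IDs and pass the user's message; a hedged usage sketch (the session ID and message below are illustrative, not real data):

import asyncio
import uuid

async def submit_example() -> None:
    # Fire-and-forget: the executor service picks the task up from RabbitMQ.
    await enqueue_copilot_task(
        task_id=str(uuid.uuid4()),
        session_id="session-123",        # an existing chat session ID
        user_id=None,                    # None is allowed for anonymous users
        operation_id=str(uuid.uuid4()),
        message="Find me an agent that summarizes RSS feeds",
    )

asyncio.run(submit_example())
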
@@ -23,26 +23,17 @@ from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from pydantic import BaseModel

from backend.data.db_accessors import chat_db
from backend.data.redis_client import get_redis_async
from backend.util import json
from backend.util.exceptions import DatabaseError, RedisError

from . import db as chat_db
from .config import ChatConfig

logger = logging.getLogger(__name__)
config = ChatConfig()


def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
    """Parse a JSON field that may be stored as string or already parsed."""
    if value is None:
        return default
    if isinstance(value, str):
        return json.loads(value)
    return value


# Redis cache key prefix for chat sessions
CHAT_SESSION_CACHE_PREFIX = "chat:session:"

@@ -52,28 +43,7 @@ def _get_session_cache_key(session_id: str) -> str:
    return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"


# Session-level locks to prevent race conditions during concurrent upserts.
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
# `async with lock:` completes). Explicit cleanup also occurs in delete_chat_session().
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()


async def _get_session_lock(session_id: str) -> asyncio.Lock:
    """Get or create a lock for a specific session to prevent concurrent upserts.

    Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
    when no coroutine holds a reference to them, preventing memory leaks from
    unbounded growth of session locks.
    """
    async with _session_locks_mutex:
        lock = _session_locks.get(session_id)
        if lock is None:
            lock = asyncio.Lock()
            _session_locks[session_id] = lock
        return lock
# ===================== Chat data models ===================== #


class ChatMessage(BaseModel):
@@ -322,38 +292,26 @@ class ChatSession(BaseModel):
        return self._merge_consecutive_assistant_messages(messages)


async def _get_session_from_cache(session_id: str) -> ChatSession | None:
    """Get a chat session from Redis cache."""
    redis_key = _get_session_cache_key(session_id)
    async_redis = await get_redis_async()
    raw_session: bytes | None = await async_redis.get(redis_key)

    if raw_session is None:
        return None

    try:
        session = ChatSession.model_validate_json(raw_session)
        logger.info(
            f"Loading session {session_id} from cache: "
            f"message_count={len(session.messages)}, "
            f"roles={[m.role for m in session.messages]}"
        )
        return session
    except Exception as e:
        logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
        raise RedisError(f"Corrupted session data for {session_id}") from e
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
    """Parse a JSON field that may be stored as string or already parsed."""
    if value is None:
        return default
    if isinstance(value, str):
        return json.loads(value)
    return value


async def _cache_session(session: ChatSession) -> None:
    """Cache a chat session in Redis."""
    redis_key = _get_session_cache_key(session.session_id)
    async_redis = await get_redis_async()
    await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
# ================ Chat cache + DB operations ================ #

# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
# connected directly.


async def cache_chat_session(session: ChatSession) -> None:
    """Cache a chat session without persisting to the database."""
    await _cache_session(session)
    """Cache a chat session in Redis (without persisting to the database)."""
    redis_key = _get_session_cache_key(session.session_id)
    async_redis = await get_redis_async()
    await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())


async def invalidate_session_cache(session_id: str) -> None:
@@ -371,80 +329,6 @@ async def invalidate_session_cache(session_id: str) -> None:
        logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")


async def _get_session_from_db(session_id: str) -> ChatSession | None:
    """Get a chat session from the database."""
    prisma_session = await chat_db.get_chat_session(session_id)
    if not prisma_session:
        return None

    messages = prisma_session.Messages
    logger.info(
        f"Loading session {session_id} from DB: "
        f"has_messages={messages is not None}, "
        f"message_count={len(messages) if messages else 0}, "
        f"roles={[m.role for m in messages] if messages else []}"
    )

    return ChatSession.from_db(prisma_session, messages)


async def _save_session_to_db(
    session: ChatSession, existing_message_count: int
) -> None:
    """Save or update a chat session in the database."""
    # Check if session exists in DB
    existing = await chat_db.get_chat_session(session.session_id)

    if not existing:
        # Create new session
        await chat_db.create_chat_session(
            session_id=session.session_id,
            user_id=session.user_id,
        )
        existing_message_count = 0

    # Calculate total tokens from usage
    total_prompt = sum(u.prompt_tokens for u in session.usage)
    total_completion = sum(u.completion_tokens for u in session.usage)

    # Update session metadata
    await chat_db.update_chat_session(
        session_id=session.session_id,
        credentials=session.credentials,
        successful_agent_runs=session.successful_agent_runs,
        successful_agent_schedules=session.successful_agent_schedules,
        total_prompt_tokens=total_prompt,
        total_completion_tokens=total_completion,
    )

    # Add new messages (only those after existing count)
    new_messages = session.messages[existing_message_count:]
    if new_messages:
        messages_data = []
        for msg in new_messages:
            messages_data.append(
                {
                    "role": msg.role,
                    "content": msg.content,
                    "name": msg.name,
                    "tool_call_id": msg.tool_call_id,
                    "refusal": msg.refusal,
                    "tool_calls": msg.tool_calls,
                    "function_call": msg.function_call,
                }
            )
        logger.info(
            f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
            f"roles={[m['role'] for m in messages_data]}, "
            f"start_sequence={existing_message_count}"
        )
        await chat_db.add_chat_messages_batch(
            session_id=session.session_id,
            messages=messages_data,
            start_sequence=existing_message_count,
        )


async def get_chat_session(
    session_id: str,
    user_id: str | None = None,
@@ -492,7 +376,7 @@ async def get_chat_session(

    # Cache the session from DB
    try:
        await _cache_session(session)
        await cache_chat_session(session)
        logger.info(f"Cached session {session_id} from database")
    except Exception as e:
        logger.warning(f"Failed to cache session {session_id}: {e}")
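The read path above is the classic cache-aside pattern: try Redis, fall back to the database, then repopulate the cache best-effort. A minimal standalone sketch of the same flow, assuming a redis.asyncio client and a hypothetical load_from_db coroutine (neither is project code):

import redis.asyncio as redis

TTL_SECONDS = 3600  # stand-in for config.session_ttl

async def load_from_db(session_id: str) -> str | None:
    ...  # hypothetical DB loader stub

async def get_session_json(r: redis.Redis, session_id: str) -> str | None:
    key = f"chat:session:{session_id}"
    cached = await r.get(key)
    if cached is not None:
        return cached  # cache hit
    data = await load_from_db(session_id)
    if data is not None:
        # Best-effort repopulation; a cache failure must not fail the read.
        try:
            await r.setex(key, TTL_SECONDS, data)
        except Exception:
            pass
    return data
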
@@ -500,6 +384,45 @@ async def get_chat_session(
    return session


async def _get_session_from_cache(session_id: str) -> ChatSession | None:
    """Get a chat session from Redis cache."""
    redis_key = _get_session_cache_key(session_id)
    async_redis = await get_redis_async()
    raw_session: bytes | None = await async_redis.get(redis_key)

    if raw_session is None:
        return None

    try:
        session = ChatSession.model_validate_json(raw_session)
        logger.info(
            f"Loading session {session_id} from cache: "
            f"message_count={len(session.messages)}, "
            f"roles={[m.role for m in session.messages]}"
        )
        return session
    except Exception as e:
        logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
        raise RedisError(f"Corrupted session data for {session_id}") from e


async def _get_session_from_db(session_id: str) -> ChatSession | None:
    """Get a chat session from the database."""
    prisma_session = await chat_db().get_chat_session(session_id)
    if not prisma_session:
        return None

    messages = prisma_session.Messages
    logger.info(
        f"Loading session {session_id} from DB: "
        f"has_messages={messages is not None}, "
        f"message_count={len(messages) if messages else 0}, "
        f"roles={[m.role for m in messages] if messages else []}"
    )

    return ChatSession.from_db(prisma_session, messages)


async def upsert_chat_session(
    session: ChatSession,
) -> ChatSession:
@@ -520,7 +443,7 @@ async def upsert_chat_session(

    async with lock:
        # Get existing message count from DB for incremental saves
        existing_message_count = await chat_db.get_chat_session_message_count(
        existing_message_count = await chat_db().get_chat_session_message_count(
            session.session_id
        )

@@ -537,7 +460,7 @@ async def upsert_chat_session(

    # Save to cache (best-effort, even if DB failed)
    try:
        await _cache_session(session)
        await cache_chat_session(session)
    except Exception as e:
        # If DB succeeded but cache failed, raise cache error
        if db_error is None:
@@ -558,6 +481,65 @@ async def upsert_chat_session(
    return session


async def _save_session_to_db(
    session: ChatSession, existing_message_count: int
) -> None:
    """Save or update a chat session in the database."""
    db = chat_db()

    # Check if session exists in DB
    existing = await db.get_chat_session(session.session_id)

    if not existing:
        # Create new session
        await db.create_chat_session(
            session_id=session.session_id,
            user_id=session.user_id,
        )
        existing_message_count = 0

    # Calculate total tokens from usage
    total_prompt = sum(u.prompt_tokens for u in session.usage)
    total_completion = sum(u.completion_tokens for u in session.usage)

    # Update session metadata
    await db.update_chat_session(
        session_id=session.session_id,
        credentials=session.credentials,
        successful_agent_runs=session.successful_agent_runs,
        successful_agent_schedules=session.successful_agent_schedules,
        total_prompt_tokens=total_prompt,
        total_completion_tokens=total_completion,
    )

    # Add new messages (only those after existing count)
    new_messages = session.messages[existing_message_count:]
    if new_messages:
        messages_data = []
        for msg in new_messages:
            messages_data.append(
                {
                    "role": msg.role,
                    "content": msg.content,
                    "name": msg.name,
                    "tool_call_id": msg.tool_call_id,
                    "refusal": msg.refusal,
                    "tool_calls": msg.tool_calls,
                    "function_call": msg.function_call,
                }
            )
        logger.info(
            f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
            f"roles={[m['role'] for m in messages_data]}, "
            f"start_sequence={existing_message_count}"
        )
        await db.add_chat_messages_batch(
            session_id=session.session_id,
            messages=messages_data,
            start_sequence=existing_message_count,
        )


async def create_chat_session(user_id: str) -> ChatSession:
    """Create a new chat session and persist it.

@@ -570,7 +552,7 @@ async def create_chat_session(user_id: str) -> ChatSession:

    # Create in database first - fail fast if this fails
    try:
        await chat_db.create_chat_session(
        await chat_db().create_chat_session(
            session_id=session.session_id,
            user_id=user_id,
        )
@@ -582,7 +564,7 @@ async def create_chat_session(user_id: str) -> ChatSession:

    # Cache the session (best-effort optimization, DB is source of truth)
    try:
        await _cache_session(session)
        await cache_chat_session(session)
    except Exception as e:
        logger.warning(f"Failed to cache new session {session.session_id}: {e}")

@@ -600,8 +582,9 @@ async def get_user_sessions(
        A tuple of (sessions, total_count) where total_count is the overall
        number of sessions for the user (not just the current page).
    """
    prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
    total_count = await chat_db.get_user_session_count(user_id)
    db = chat_db()
    prisma_sessions = await db.get_user_chat_sessions(user_id, limit, offset)
    total_count = await db.get_user_session_count(user_id)

    sessions = []
    for prisma_session in prisma_sessions:
@@ -624,7 +607,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
    """
    # Delete from database first (with optional user_id validation)
    # This confirms ownership before invalidating cache
    deleted = await chat_db.delete_chat_session(session_id, user_id)
    deleted = await chat_db().delete_chat_session(session_id, user_id)

    if not deleted:
        return False
@@ -659,7 +642,7 @@ async def update_session_title(session_id: str, title: str) -> bool:
        True if updated successfully, False otherwise.
    """
    try:
        result = await chat_db.update_chat_session(session_id=session_id, title=title)
        result = await chat_db().update_chat_session(session_id=session_id, title=title)
        if result is None:
            logger.warning(f"Session {session_id} not found for title update")
            return False
@@ -676,3 +659,29 @@ async def update_session_title(session_id: str, title: str) -> bool:
    except Exception as e:
        logger.error(f"Failed to update title for session {session_id}: {e}")
        return False


# ==================== Chat session locks ==================== #

_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()


async def _get_session_lock(session_id: str) -> asyncio.Lock:
    """Get or create a lock for a specific session to prevent concurrent upserts.

    This was originally added to solve the specific problem of race conditions between
    the session title thread and the conversation thread, which always occurs on the
    same instance as we prevent rapid request sends on the frontend.

    Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
    when no coroutine holds a reference to them, preventing memory leaks from
    unbounded growth of session locks. Explicit cleanup also occurs
    in `delete_chat_session()`.
    """
    async with _session_locks_mutex:
        lock = _session_locks.get(session_id)
        if lock is None:
            lock = asyncio.Lock()
            _session_locks[session_id] = lock
        return lock
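The weak-reference trick is worth seeing in isolation: while any coroutine holds the lock object, the WeakValueDictionary keeps its entry; once the last strong reference is dropped, the entry disappears on its own. A small self-contained demo (illustrative only; the immediate reclamation relies on CPython's reference counting):

import asyncio
import gc
from weakref import WeakValueDictionary

_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()

async def demo() -> None:
    lock = asyncio.Lock()
    _locks["session-1"] = lock
    assert "session-1" in _locks  # entry lives while `lock` is referenced
    async with lock:
        pass  # critical section
    del lock  # drop the last strong reference
    gc.collect()  # prompt collection; CPython usually frees on `del` already
    assert "session-1" not in _locks  # entry vanished with the lock

asyncio.run(demo())
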
@@ -27,6 +27,7 @@ from openai.types.chat import (
    ChatCompletionToolParam,
)

from backend.data.db_accessors import chat_db
from backend.data.redis_client import get_redis_async
from backend.data.understanding import (
    format_understanding_for_prompt,
@@ -35,7 +36,6 @@ from backend.data.understanding import (
from backend.util.exceptions import NotFoundError
from backend.util.settings import AppEnvironment, Settings

from . import db as chat_db
from . import stream_registry
from .config import ChatConfig
from .model import (
@@ -1744,7 +1744,7 @@ async def _update_pending_operation(
    This is called by background tasks when long-running operations complete.
    """
    # Update the message in database
    updated = await chat_db.update_tool_message_content(
    updated = await chat_db().update_tool_message_content(
        session_id=session_id,
        tool_call_id=tool_call_id,
        new_content=result,
@@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Any

from openai.types.chat import ChatCompletionToolParam

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tracking import track_tool_called
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_tool_called

from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
@@ -12,7 +12,6 @@ from .base import BaseTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool
from .find_agent import FindAgentTool
from .find_block import FindBlockTool
from .find_library_agent import FindLibraryAgentTool
@@ -28,7 +27,7 @@ from .workspace_files import (
)

if TYPE_CHECKING:
    from backend.api.features.chat.response_model import StreamToolOutputAvailable
    from backend.copilot.response_model import StreamToolOutputAvailable

logger = logging.getLogger(__name__)

@@ -46,9 +45,6 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
    "view_agent_output": AgentOutputTool(),
    "search_docs": SearchDocsTool(),
    "get_doc_page": GetDocPageTool(),
    # Feature request tools
    "search_feature_requests": SearchFeatureRequestsTool(),
    "create_feature_request": CreateFeatureRequestTool(),
    # Workspace tools for CoPilot file operations
    "list_workspace_files": ListWorkspaceFilesTool(),
    "read_workspace_file": ReadWorkspaceFileTool(),
@@ -6,11 +6,11 @@ import pytest
from prisma.types import ProfileCreateInput
from pydantic import SecretStr

from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import AITextGeneratorBlock
from backend.copilot.model import ChatSession
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import APIKeyCredentials
@@ -3,11 +3,9 @@

import logging
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.data.understanding import (
    BusinessUnderstandingInput,
    upsert_business_understanding,
)
from backend.copilot.model import ChatSession
from backend.data.db_accessors import understanding_db
from backend.data.understanding import BusinessUnderstandingInput

from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
@@ -99,7 +97,9 @@ and automations for the user's specific needs."""
        ]

        # Upsert with merge
        understanding = await upsert_business_understanding(user_id, input_data)
        understanding = await understanding_db().upsert_business_understanding(
            user_id, input_data
        )

        # Build current understanding summary (filter out empty values)
        current_understanding = {
@@ -5,9 +5,8 @@ import re
import uuid
from typing import Any, NotRequired, TypedDict

from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
from backend.data.db_accessors import graph_db, library_db, store_db
from backend.data.graph import Graph, Link, Node
from backend.util.exceptions import DatabaseError, NotFoundError

from .service import (
@@ -145,8 +144,9 @@ async def get_library_agent_by_id(
    Returns:
        LibraryAgentSummary if found, None otherwise
    """
    db = library_db()
    try:
        agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
        agent = await db.get_library_agent_by_graph_id(user_id, agent_id)
        if agent:
            logger.debug(f"Found library agent by graph_id: {agent.name}")
            return LibraryAgentSummary(
@@ -163,7 +163,7 @@ async def get_library_agent_by_id(
        logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")

    try:
        agent = await library_db.get_library_agent(agent_id, user_id)
        agent = await db.get_library_agent(agent_id, user_id)
        if agent:
            logger.debug(f"Found library agent by library_id: {agent.name}")
            return LibraryAgentSummary(
@@ -215,7 +215,7 @@ async def get_library_agents_for_generation(
        List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
    """
    try:
        response = await library_db.list_library_agents(
        response = await library_db().list_library_agents(
            user_id=user_id,
            search_term=search_query,
            page=1,
@@ -272,7 +272,7 @@ async def search_marketplace_agents_for_generation(
        List of LibraryAgentSummary with full input/output schemas
    """
    try:
        response = await store_db.get_store_agents(
        response = await store_db().get_store_agents(
            search_query=search_query,
            page=1,
            page_size=max_results,
@@ -286,7 +286,7 @@ async def search_marketplace_agents_for_generation(
        return []

    graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
    graphs = await get_store_listed_graphs(*graph_ids)
    graphs = await graph_db().get_store_listed_graphs(*graph_ids)

    results: list[LibraryAgentSummary] = []
    for agent in agents_with_graphs:
@@ -673,9 +673,10 @@ async def save_agent_to_library(
        Tuple of (created Graph, LibraryAgent)
    """
    graph = json_to_graph(agent_json)
    db = library_db()
    if is_update:
        return await library_db.update_graph_in_library(graph, user_id)
    return await library_db.create_graph_in_library(graph, user_id)
        return await db.update_graph_in_library(graph, user_id)
    return await db.create_graph_in_library(graph, user_id)


def graph_to_json(graph: Graph) -> dict[str, Any]:
@@ -735,12 +736,14 @@ async def get_agent_as_json(
    Returns:
        Agent as JSON dict or None if not found
    """
    graph = await get_graph(agent_id, version=None, user_id=user_id)
    db = graph_db()

    graph = await db.get_graph(agent_id, version=None, user_id=user_id)

    if not graph and user_id:
        try:
            library_agent = await library_db.get_library_agent(agent_id, user_id)
            graph = await get_graph(
            library_agent = await library_db().get_library_agent(agent_id, user_id)
            graph = await db.get_graph(
                library_agent.graph_id, version=None, user_id=user_id
            )
        except NotFoundError:
@@ -7,10 +7,9 @@ from typing import Any

from pydantic import BaseModel, field_validator

from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
from backend.api.features.library.model import LibraryAgent
from backend.data import execution as execution_db
from backend.copilot.model import ChatSession
from backend.data.db_accessors import execution_db, library_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta

from .base import BaseTool
@@ -165,10 +164,12 @@ class AgentOutputTool(BaseTool):
        Resolve agent from provided identifiers.
        Returns (library_agent, error_message).
        """
        lib_db = library_db()

        # Priority 1: Exact library agent ID
        if library_agent_id:
            try:
                agent = await library_db.get_library_agent(library_agent_id, user_id)
                agent = await lib_db.get_library_agent(library_agent_id, user_id)
                return agent, None
            except Exception as e:
                logger.warning(f"Failed to get library agent by ID: {e}")
@@ -182,7 +183,7 @@ class AgentOutputTool(BaseTool):
                return None, f"Agent '{store_slug}' not found in marketplace"

            # Find in user's library by graph_id
            agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
            agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id)
            if not agent:
                return (
                    None,
@@ -194,7 +195,7 @@ class AgentOutputTool(BaseTool):
        # Priority 3: Fuzzy name search in library
        if agent_name:
            try:
                response = await library_db.list_library_agents(
                response = await lib_db.list_library_agents(
                    user_id=user_id,
                    search_term=agent_name,
                    page_size=5,
@@ -228,9 +229,11 @@ class AgentOutputTool(BaseTool):
        Fetch execution(s) based on filters.
        Returns (single_execution, available_executions_meta, error_message).
        """
        exec_db = execution_db()

        # If specific execution_id provided, fetch it directly
        if execution_id:
            execution = await execution_db.get_graph_execution(
            execution = await exec_db.get_graph_execution(
                user_id=user_id,
                execution_id=execution_id,
                include_node_executions=False,
@@ -240,7 +243,7 @@ class AgentOutputTool(BaseTool):
            return execution, [], None

        # Get completed executions with time filters
        executions = await execution_db.get_graph_executions(
        executions = await exec_db.get_graph_executions(
            graph_id=graph_id,
            user_id=user_id,
            statuses=[ExecutionStatus.COMPLETED],
@@ -254,7 +257,7 @@ class AgentOutputTool(BaseTool):

        # If only one execution, fetch full details
        if len(executions) == 1:
            full_execution = await execution_db.get_graph_execution(
            full_execution = await exec_db.get_graph_execution(
                user_id=user_id,
                execution_id=executions[0].id,
                include_node_executions=False,
@@ -262,7 +265,7 @@ class AgentOutputTool(BaseTool):
            return full_execution, [], None

        # Multiple executions - return latest with full details, plus list of available
        full_execution = await execution_db.get_graph_execution(
        full_execution = await exec_db.get_graph_execution(
            user_id=user_id,
            execution_id=executions[0].id,
            include_node_executions=False,
@@ -380,7 +383,7 @@ class AgentOutputTool(BaseTool):
            and not input_data.store_slug
        ):
            # Fetch execution directly to get graph_id
            execution = await execution_db.get_graph_execution(
            execution = await execution_db().get_graph_execution(
                user_id=user_id,
                execution_id=input_data.execution_id,
                include_node_executions=False,
@@ -392,7 +395,7 @@ class AgentOutputTool(BaseTool):
            )

            # Find library agent by graph_id
            agent = await library_db.get_library_agent_by_graph_id(
            agent = await library_db().get_library_agent_by_graph_id(
                user_id, execution.graph_id
            )
            if not agent:
@@ -4,8 +4,7 @@

import logging
import re
from typing import Literal

from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.data.db_accessors import library_db, store_db
from backend.util.exceptions import DatabaseError, NotFoundError

from .models import (
@@ -45,8 +44,10 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
    Returns:
        AgentInfo if found, None otherwise
    """
    lib_db = library_db()

    try:
        agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
        agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
        if agent:
            logger.debug(f"Found library agent by graph_id: {agent.name}")
            return AgentInfo(
@@ -71,7 +72,7 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
        )

    try:
        agent = await library_db.get_library_agent(agent_id, user_id)
        agent = await lib_db.get_library_agent(agent_id, user_id)
        if agent:
            logger.debug(f"Found library agent by library_id: {agent.name}")
            return AgentInfo(
@@ -133,7 +134,7 @@ async def search_agents(
    try:
        if source == "marketplace":
            logger.info(f"Searching marketplace for: {query}")
            results = await store_db.get_store_agents(search_query=query, page_size=5)
            results = await store_db().get_store_agents(search_query=query, page_size=5)
            for agent in results.agents:
                agents.append(
                    AgentInfo(
@@ -159,7 +160,7 @@ async def search_agents(

    if not agents:
        logger.info(f"Searching user library for: {query}")
        results = await library_db.list_library_agents(
        results = await library_db().list_library_agents(
            user_id=user_id,  # type: ignore[arg-type]
            search_term=query,
            page_size=10,
@@ -5,8 +5,8 @@ from typing import Any

from openai.types.chat import ChatCompletionToolParam

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.response_model import StreamToolOutputAvailable
from backend.copilot.model import ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable

from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
@@ -3,7 +3,7 @@

import logging
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.copilot.model import ChatSession

from .agent_generator import (
    AgentGeneratorNotConfiguredError,
@@ -3,9 +3,9 @@

import logging
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError
from backend.copilot.model import ChatSession
from backend.data.db_accessors import store_db as get_store_db

from .agent_generator import (
    AgentGeneratorNotConfiguredError,
@@ -137,6 +137,8 @@ class CustomizeAgentTool(BaseTool):

        creator_username, agent_slug = parts

        store_db = get_store_db()

        # Fetch the marketplace agent details
        try:
            agent_details = await store_db.get_store_agent_details(
@@ -3,7 +3,7 @@

import logging
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.copilot.model import ChatSession

from .agent_generator import (
    AgentGeneratorNotConfiguredError,
@@ -2,7 +2,7 @@

from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.copilot.model import ChatSession

from .agent_search import search_agents
from .base import BaseTool
@@ -3,18 +3,18 @@ from typing import Any

from prisma.enums import ContentType

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
from backend.api.features.chat.tools.models import (
from backend.blocks import get_block
from backend.blocks._base import BlockType
from backend.copilot.model import ChatSession
from backend.copilot.tools.base import BaseTool, ToolResponseBase
from backend.copilot.tools.models import (
    BlockInfoSummary,
    BlockInputFieldInfo,
    BlockListResponse,
    ErrorResponse,
    NoResultsResponse,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.blocks import get_block
from backend.blocks._base import BlockType
from backend.data.db_accessors import search

logger = logging.getLogger(__name__)

@@ -107,7 +107,7 @@ class FindBlockTool(BaseTool):

        try:
            # Search for blocks using hybrid search
            results, total = await unified_hybrid_search(
            results, total = await search().unified_hybrid_search(
                query=query,
                content_types=[ContentType.BLOCK],
                page=1,
@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.api.features.chat.tools.find_block import (
from backend.blocks._base import BlockType
from backend.copilot.tools.find_block import (
    COPILOT_EXCLUDED_BLOCK_IDS,
    COPILOT_EXCLUDED_BLOCK_TYPES,
    FindBlockTool,
)
from backend.api.features.chat.tools.models import BlockListResponse
from backend.blocks._base import BlockType
from backend.copilot.tools.models import BlockListResponse

from ._test_data import make_session

@@ -75,13 +75,17 @@ class TestFindBlockFiltering:
            "standard-block-id": standard_block,
        }.get(block_id)

        mock_search_db = MagicMock()
        mock_search_db.unified_hybrid_search = AsyncMock(
            return_value=(search_results, 2)
        )

        with patch(
            "backend.api.features.chat.tools.find_block.unified_hybrid_search",
            new_callable=AsyncMock,
            return_value=(search_results, 2),
            "backend.copilot.tools.find_block.search",
            return_value=mock_search_db,
        ):
            with patch(
                "backend.api.features.chat.tools.find_block.get_block",
                "backend.copilot.tools.find_block.get_block",
                side_effect=mock_get_block,
            ):
                tool = FindBlockTool()
@@ -119,13 +123,17 @@ class TestFindBlockFiltering:
            "normal-block-id": normal_block,
        }.get(block_id)

        mock_search_db = MagicMock()
        mock_search_db.unified_hybrid_search = AsyncMock(
            return_value=(search_results, 2)
        )

        with patch(
            "backend.api.features.chat.tools.find_block.unified_hybrid_search",
            new_callable=AsyncMock,
            return_value=(search_results, 2),
            "backend.copilot.tools.find_block.search",
            return_value=mock_search_db,
        ):
            with patch(
                "backend.api.features.chat.tools.find_block.get_block",
                "backend.copilot.tools.find_block.get_block",
                side_effect=mock_get_block,
            ):
                tool = FindBlockTool()
@@ -2,7 +2,7 @@

from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.copilot.model import ChatSession

from .agent_search import search_agents
from .base import BaseTool
@@ -4,9 +4,9 @@ import logging
from pathlib import Path
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
from backend.copilot.model import ChatSession
from backend.copilot.tools.base import BaseTool
from backend.copilot.tools.models import (
    DocPageResponse,
    ErrorResponse,
    ToolResponseBase,
@@ -40,9 +40,6 @@ class ResponseType(str, Enum):
    OPERATION_IN_PROGRESS = "operation_in_progress"
    # Input validation
    INPUT_VALIDATION_ERROR = "input_validation_error"
    # Feature request types
    FEATURE_REQUEST_SEARCH = "feature_request_search"
    FEATURE_REQUEST_CREATED = "feature_request_created"


# Base response model
@@ -424,34 +421,3 @@ class AsyncProcessingResponse(ToolResponseBase):
    status: str = "accepted"  # Must be "accepted" for detection
    operation_id: str | None = None
    task_id: str | None = None


# Feature request models
class FeatureRequestInfo(BaseModel):
    """Information about a feature request issue."""

    id: str
    identifier: str
    title: str
    description: str | None = None


class FeatureRequestSearchResponse(ToolResponseBase):
    """Response for search_feature_requests tool."""

    type: ResponseType = ResponseType.FEATURE_REQUEST_SEARCH
    results: list[FeatureRequestInfo]
    count: int
    query: str


class FeatureRequestCreatedResponse(ToolResponseBase):
    """Response for create_feature_request tool."""

    type: ResponseType = ResponseType.FEATURE_REQUEST_CREATED
    issue_id: str
    issue_identifier: str
    issue_title: str
    issue_url: str
    is_new_issue: bool  # False if added to existing
    customer_name: str
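The file's underlying pattern — one ResponseType enum member per response shape, paired with a pydantic model subclassing ToolResponseBase — survives the removal above. A minimal self-contained sketch of that pattern (MY_RESULT and MyResultResponse are hypothetical stand-ins, not real project models):

from enum import Enum

from pydantic import BaseModel

class ResponseType(str, Enum):  # trimmed stand-in for the real enum
    MY_RESULT = "my_result"  # hypothetical member

class ToolResponseBase(BaseModel):  # stand-in for the project's base class
    type: ResponseType

class MyResultResponse(ToolResponseBase):
    type: ResponseType = ResponseType.MY_RESULT
    count: int

resp = MyResultResponse(count=3)
print(resp.model_dump_json())  # {"type":"my_result","count":3}
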
@@ -5,16 +5,12 @@ from typing import Any

from pydantic import BaseModel, Field, field_validator

from backend.api.features.chat.config import ChatConfig
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tracking import (
    track_agent_run_success,
    track_agent_scheduled,
)
from backend.api.features.library import db as library_db
from backend.copilot.config import ChatConfig
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
from backend.data.db_accessors import graph_db, library_db, user_db
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.data.user import get_user_by_id
from backend.executor import utils as execution_utils
from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, NotFoundError
@@ -200,7 +196,7 @@ class RunAgentTool(BaseTool):

        # Priority: library_agent_id if provided
        if has_library_id:
            library_agent = await library_db.get_library_agent(
            library_agent = await library_db().get_library_agent(
                params.library_agent_id, user_id
            )
            if not library_agent:
@@ -209,9 +205,7 @@ class RunAgentTool(BaseTool):
                    session_id=session_id,
                )
            # Get the graph from the library agent
            from backend.data.graph import get_graph

            graph = await get_graph(
            graph = await graph_db().get_graph(
                library_agent.graph_id,
                library_agent.graph_version,
                user_id=user_id,
@@ -522,7 +516,7 @@ class RunAgentTool(BaseTool):
        library_agent = await get_or_create_library_agent(graph, user_id)

        # Get user timezone
        user = await get_user_by_id(user_id)
        user = await user_db().get_user_by_id(user_id)
        user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone)

        # Create schedule
@@ -7,16 +7,16 @@ from typing import Any

from pydantic_core import PydanticUndefined

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.find_block import (
from backend.blocks import get_block
from backend.blocks._base import AnyBlockSchema
from backend.copilot.model import ChatSession
from backend.copilot.tools.find_block import (
    COPILOT_EXCLUDED_BLOCK_IDS,
    COPILOT_EXCLUDED_BLOCK_TYPES,
)
from backend.blocks import get_block
from backend.blocks._base import AnyBlockSchema
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError

@@ -190,7 +190,7 @@ class RunBlockTool(BaseTool):

        try:
            # Get or create user's workspace for CoPilot file operations
            workspace = await get_or_create_workspace(user_id)
            workspace = await workspace_db().get_or_create_workspace(user_id)

            # Generate synthetic IDs for CoPilot context
            # Each chat session is treated as its own agent with one continuous run
@@ -4,9 +4,9 @@ from unittest.mock import MagicMock, patch

import pytest

from backend.api.features.chat.tools.models import ErrorResponse
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.blocks._base import BlockType
from backend.copilot.tools.models import ErrorResponse
from backend.copilot.tools.run_block import RunBlockTool

from ._test_data import make_session

@@ -39,7 +39,7 @@ class TestRunBlockFiltering:
        input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)

        with patch(
            "backend.api.features.chat.tools.run_block.get_block",
            "backend.copilot.tools.run_block.get_block",
            return_value=input_block,
        ):
            tool = RunBlockTool()
@@ -65,7 +65,7 @@ class TestRunBlockFiltering:
        )

        with patch(
            "backend.api.features.chat.tools.run_block.get_block",
            "backend.copilot.tools.run_block.get_block",
            return_value=smart_block,
        ):
            tool = RunBlockTool()
@@ -89,7 +89,7 @@ class TestRunBlockFiltering:
        )

        with patch(
            "backend.api.features.chat.tools.run_block.get_block",
            "backend.copilot.tools.run_block.get_block",
            return_value=standard_block,
        ):
            tool = RunBlockTool()
@@ -5,16 +5,16 @@ from typing import Any

from prisma.enums import ContentType

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
from backend.copilot.model import ChatSession
from backend.copilot.tools.base import BaseTool
from backend.copilot.tools.models import (
    DocSearchResult,
    DocSearchResultsResponse,
    ErrorResponse,
    NoResultsResponse,
    ToolResponseBase,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.data.db_accessors import search

logger = logging.getLogger(__name__)

@@ -117,7 +117,7 @@ class SearchDocsTool(BaseTool):

        try:
            # Search using hybrid search for DOCUMENTATION content type only
            results, total = await unified_hybrid_search(
            results, total = await search().unified_hybrid_search(
                query=query,
                content_types=[ContentType.DOCUMENTATION],
                page=1,
@@ -3,9 +3,8 @@

import logging
from typing import Any

from backend.api.features.library import db as library_db
from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data.db_accessors import library_db, store_db
from backend.data.graph import GraphModel
from backend.data.model import (
    Credentials,
@@ -38,13 +37,14 @@ async def fetch_graph_from_store_slug(
    Raises:
        DatabaseError: If there's a database error during lookup.
    """
    sdb = store_db()
    try:
        store_agent = await store_db.get_store_agent_details(username, agent_name)
        store_agent = await sdb.get_store_agent_details(username, agent_name)
    except NotFoundError:
        return None, None

    # Get the graph from store listing version
    graph = await store_db.get_available_graph(
    graph = await sdb.get_available_graph(
        store_agent.store_listing_version_id, hide_nodes=False
    )
    return graph, store_agent
@@ -209,13 +209,13 @@ async def get_or_create_library_agent(
    Returns:
        LibraryAgent instance
    """
    existing = await library_db.get_library_agent_by_graph_id(
    existing = await library_db().get_library_agent_by_graph_id(
        graph_id=graph.id, user_id=user_id
    )
    if existing:
        return existing

    library_agents = await library_db.create_library_agent(
    library_agents = await library_db().create_library_agent(
        graph=graph,
        user_id=user_id,
        create_library_agents_for_sub_graphs=False,
@@ -6,8 +6,8 @@ from typing import Any, Optional

from pydantic import BaseModel

from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
@@ -146,7 +146,7 @@ class ListWorkspaceFilesTool(BaseTool):
        include_all_sessions: bool = kwargs.get("include_all_sessions", False)

        try:
            workspace = await get_or_create_workspace(user_id)
            workspace = await workspace_db().get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

@@ -280,7 +280,7 @@ class ReadWorkspaceFileTool(BaseTool):
        )

        try:
            workspace = await get_or_create_workspace(user_id)
            workspace = await workspace_db().get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

@@ -478,7 +478,7 @@ class WriteWorkspaceFileTool(BaseTool):
        # Virus scan
        await scan_content_safe(content, filename=filename)

        workspace = await get_or_create_workspace(user_id)
        workspace = await workspace_db().get_or_create_workspace(user_id)
        # Pass session_id for session-scoped file access
        manager = WorkspaceManager(user_id, workspace.id, session_id)

@@ -577,7 +577,7 @@ class DeleteWorkspaceFileTool(BaseTool):
        )

        try:
            workspace = await get_or_create_workspace(user_id)
            workspace = await workspace_db().get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)
autogpt_platform/backend/backend/data/db_accessors.py (new file, 118 lines)
@@ -0,0 +1,118 @@
from backend.data import db


def chat_db():
    if db.is_connected():
        from backend.copilot import db as _chat_db

        chat_db = _chat_db
    else:
        from backend.util.clients import get_database_manager_async_client

        chat_db = get_database_manager_async_client()

    return chat_db


def graph_db():
    if db.is_connected():
        from backend.data import graph as _graph_db

        graph_db = _graph_db
    else:
        from backend.util.clients import get_database_manager_async_client

        graph_db = get_database_manager_async_client()

    return graph_db


def library_db():
    if db.is_connected():
        from backend.api.features.library import db as _library_db

        library_db = _library_db
    else:
        from backend.util.clients import get_database_manager_async_client

        library_db = get_database_manager_async_client()

    return library_db


def store_db():
    if db.is_connected():
        from backend.api.features.store import db as _store_db

        store_db = _store_db
    else:
        from backend.util.clients import get_database_manager_async_client

        store_db = get_database_manager_async_client()

    return store_db


def search():
    if db.is_connected():
        from backend.api.features.store import hybrid_search as _search

        search = _search
    else:
        from backend.util.clients import get_database_manager_async_client

        search = get_database_manager_async_client()

    return search


def execution_db():
    if db.is_connected():
        from backend.data import execution as _execution_db

        execution_db = _execution_db
    else:
        from backend.util.clients import get_database_manager_async_client

        execution_db = get_database_manager_async_client()

    return execution_db


def user_db():
    if db.is_connected():
        from backend.data import user as _user_db

        user_db = _user_db
    else:
        from backend.util.clients import get_database_manager_async_client

        user_db = get_database_manager_async_client()

    return user_db


def understanding_db():
    if db.is_connected():
        from backend.data import understanding as _understanding_db

        understanding_db = _understanding_db
    else:
        from backend.util.clients import get_database_manager_async_client

        understanding_db = get_database_manager_async_client()

    return understanding_db


def workspace_db():
    if db.is_connected():
        from backend.data import workspace as _workspace_db

        workspace_db = _workspace_db
    else:
        from backend.util.clients import get_database_manager_async_client

        workspace_db = get_database_manager_async_client()

    return workspace_db
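Each accessor resolves at call time: the real module when Prisma is connected in-process, otherwise the DatabaseManager RPC client, with both sides exposing the same method names. Call sites therefore always invoke the accessor as a function at the moment of use; a minimal illustration (the function wrapping it is hypothetical):

from backend.data.db_accessors import library_db

async def lookup_agent(user_id: str, agent_id: str):
    # Resolved per call: direct module access if Prisma is connected,
    # otherwise a transparent RPC to the DatabaseManager service.
    return await library_db().get_library_agent(agent_id, user_id)
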
@@ -4,14 +4,26 @@ from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cast

from backend.api.features.library.db import (
    add_store_agent_to_library,
    create_graph_in_library,
    create_library_agent,
    get_library_agent,
    get_library_agent_by_graph_id,
    list_library_agents,
    update_graph_in_library,
)
from backend.api.features.store.db import (
    get_agent,
    get_available_graph,
    get_store_agent_details,
    get_store_agents,
)
from backend.api.features.store.db import get_store_agent_details, get_store_agents
from backend.api.features.store.embeddings import (
    backfill_missing_embeddings,
    cleanup_orphaned_embeddings,
    get_embedding_stats,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.copilot import db as chat_db
from backend.data import db
from backend.data.analytics import (
    get_accuracy_trends_and_alerts,
@@ -48,6 +60,7 @@ from backend.data.graph import (
    get_graph_metadata,
    get_graph_settings,
    get_node,
    get_store_listed_graphs,
    validate_graph_execution_permissions,
)
from backend.data.human_review import (
@@ -67,6 +80,10 @@ from backend.data.notifications import (
    remove_notifications_from_batch,
)
from backend.data.onboarding import increment_onboarding_runs
from backend.data.understanding import (
    get_business_understanding,
    upsert_business_understanding,
)
from backend.data.user import (
    get_active_user_ids_in_timerange,
    get_user_by_id,
@@ -76,6 +93,7 @@ from backend.data.user import (
    get_user_notification_preference,
    update_user_integrations,
)
from backend.data.workspace import get_or_create_workspace
from backend.util.service import (
    AppService,
    AppServiceClient,
@@ -107,6 +125,13 @@ async def _get_credits(user_id: str) -> int:


class DatabaseManager(AppService):
    """Database connection pooling service.

    This service connects to the Prisma engine and exposes database
    operations via RPC endpoints. It acts as a centralized connection pool
    for all services that need database access.
    """

    @asynccontextmanager
    async def lifespan(self, app: "FastAPI"):
        async with super().lifespan(app):
@@ -142,11 +167,15 @@ class DatabaseManager(AppService):
    def _(
        f: Callable[P, R], name: str | None = None
    ) -> Callable[Concatenate[object, P], R]:
        """
        Exposes a function as an RPC endpoint, and adds a virtual `self` param
        to the function's type so it can be bound as a method.
        """
        if name is not None:
            f.__name__ = name
        return cast(Callable[Concatenate[object, P], R], expose(f))

    # Executions
    # ============ Graph Executions ============ #
    get_child_graph_executions = _(get_child_graph_executions)
    get_graph_executions = _(get_graph_executions)
    get_graph_executions_count = _(get_graph_executions_count)
@@ -170,36 +199,37 @@ class DatabaseManager(AppService):
    get_frequently_executed_graphs = _(get_frequently_executed_graphs)
    get_marketplace_graphs_for_monitoring = _(get_marketplace_graphs_for_monitoring)

    # Graphs
    # ============ Graphs ============ #
    get_node = _(get_node)
    get_graph = _(get_graph)
    get_connected_output_nodes = _(get_connected_output_nodes)
    get_graph_metadata = _(get_graph_metadata)
    get_graph_settings = _(get_graph_settings)
    get_store_listed_graphs = _(get_store_listed_graphs)

    # Credits
    # ============ Credits ============ #
    spend_credits = _(_spend_credits, name="spend_credits")
    get_credits = _(_get_credits, name="get_credits")

    # User + User Metadata + User Integrations
    # ============ User + Integrations ============ #
    get_user_by_id = _(get_user_by_id)
    get_user_integrations = _(get_user_integrations)
    update_user_integrations = _(update_user_integrations)

    # User Comms - async
    # ============ User Comms ============ #
    get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
    get_user_by_id = _(get_user_by_id)
    get_user_email_by_id = _(get_user_email_by_id)
    get_user_email_verification = _(get_user_email_verification)
    get_user_notification_preference = _(get_user_notification_preference)

    # Human In The Loop
    # ============ Human In The Loop ============ #
    cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
    check_approval = _(check_approval)
    get_or_create_human_review = _(get_or_create_human_review)
    has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
    update_review_processed_status = _(update_review_processed_status)

    # Notifications - async
    # ============ Notifications ============ #
    clear_all_user_notification_batches = _(clear_all_user_notification_batches)
    create_or_add_to_user_notification_batch = _(
        create_or_add_to_user_notification_batch
@@ -212,29 +242,56 @@ class DatabaseManager(AppService):
        get_user_notification_oldest_message_in_batch
    )

    # Library
    # ============ Library ============ #
    list_library_agents = _(list_library_agents)
    add_store_agent_to_library = _(add_store_agent_to_library)
    create_graph_in_library = _(create_graph_in_library)
    create_library_agent = _(create_library_agent)
    get_library_agent = _(get_library_agent)
    get_library_agent_by_graph_id = _(get_library_agent_by_graph_id)
    update_graph_in_library = _(update_graph_in_library)
    validate_graph_execution_permissions = _(validate_graph_execution_permissions)

    # Onboarding
    # ============ Onboarding ============ #
    increment_onboarding_runs = _(increment_onboarding_runs)

    # OAuth
    # ============ OAuth ============ #
    cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)

    # Store
    # ============ Store ============ #
    get_store_agents = _(get_store_agents)
    get_store_agent_details = _(get_store_agent_details)
    get_agent = _(get_agent)
    get_available_graph = _(get_available_graph)

    # Store Embeddings
    # ============ Search ============ #
    get_embedding_stats = _(get_embedding_stats)
    backfill_missing_embeddings = _(backfill_missing_embeddings)
    cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings)
    unified_hybrid_search = _(unified_hybrid_search)

    # Summary data - async
    # ============ Summary Data ============ #
    get_user_execution_summary_data = _(get_user_execution_summary_data)

    # ============ Workspace ============ #
    get_or_create_workspace = _(get_or_create_workspace)

    # ============ Understanding ============ #
    get_business_understanding = _(get_business_understanding)
    upsert_business_understanding = _(upsert_business_understanding)

    # ============ CoPilot Chat Sessions ============ #
    get_chat_session = _(chat_db.get_chat_session)
    create_chat_session = _(chat_db.create_chat_session)
    update_chat_session = _(chat_db.update_chat_session)
    add_chat_message = _(chat_db.add_chat_message)
    add_chat_messages_batch = _(chat_db.add_chat_messages_batch)
    get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
    get_user_session_count = _(chat_db.get_user_session_count)
    delete_chat_session = _(chat_db.delete_chat_session)
    get_chat_session_message_count = _(chat_db.get_chat_session_message_count)
    update_tool_message_content = _(chat_db.update_tool_message_content)


class DatabaseManagerClient(AppServiceClient):
    d = DatabaseManager
@@ -296,43 +353,50 @@ class DatabaseManagerAsyncClient(AppServiceClient):
    def get_service_type(cls):
        return DatabaseManager

    # ============ Graph Executions ============ #
    create_graph_execution = d.create_graph_execution
    get_child_graph_executions = d.get_child_graph_executions
    get_connected_output_nodes = d.get_connected_output_nodes
    get_latest_node_execution = d.get_latest_node_execution
    get_graph = d.get_graph
    get_graph_metadata = d.get_graph_metadata
    get_graph_settings = d.get_graph_settings
    get_graph_execution = d.get_graph_execution
    get_graph_execution_meta = d.get_graph_execution_meta
    get_node = d.get_node
    get_graph_executions = d.get_graph_executions
    get_node_execution = d.get_node_execution
    get_node_executions = d.get_node_executions
    get_user_by_id = d.get_user_by_id
    get_user_integrations = d.get_user_integrations
    upsert_execution_input = d.upsert_execution_input
    upsert_execution_output = d.upsert_execution_output
    get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
    update_graph_execution_stats = d.update_graph_execution_stats
    update_node_execution_status = d.update_node_execution_status
    update_node_execution_status_batch = d.update_node_execution_status_batch
    update_user_integrations = d.update_user_integrations
    upsert_execution_input = d.upsert_execution_input
    upsert_execution_output = d.upsert_execution_output
    get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
    get_execution_kv_data = d.get_execution_kv_data
    set_execution_kv_data = d.set_execution_kv_data

    # Human In The Loop
    # ============ Graphs ============ #
    get_graph = d.get_graph
    get_graph_metadata = d.get_graph_metadata
    get_graph_settings = d.get_graph_settings
    get_node = d.get_node
    get_store_listed_graphs = d.get_store_listed_graphs

    # ============ User + Integrations ============ #
    get_user_by_id = d.get_user_by_id
    get_user_integrations = d.get_user_integrations
    update_user_integrations = d.update_user_integrations

    # ============ Human In The Loop ============ #
    cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
    check_approval = d.check_approval
    get_or_create_human_review = d.get_or_create_human_review
    update_review_processed_status = d.update_review_processed_status

    # User Comms
    # ============ User Comms ============ #
    get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
    get_user_email_by_id = d.get_user_email_by_id
    get_user_email_verification = d.get_user_email_verification
    get_user_notification_preference = d.get_user_notification_preference

    # Notifications
    # ============ Notifications ============ #
    clear_all_user_notification_batches = d.clear_all_user_notification_batches
    create_or_add_to_user_notification_batch = (
        d.create_or_add_to_user_notification_batch
@@ -345,20 +409,49 @@ class DatabaseManagerAsyncClient(AppServiceClient):
        d.get_user_notification_oldest_message_in_batch
    )

    # Library
    # ============ Library ============ #
    list_library_agents = d.list_library_agents
    add_store_agent_to_library = d.add_store_agent_to_library
    create_graph_in_library = d.create_graph_in_library
    create_library_agent = d.create_library_agent
    get_library_agent = d.get_library_agent
    get_library_agent_by_graph_id = d.get_library_agent_by_graph_id
    update_graph_in_library = d.update_graph_in_library
    validate_graph_execution_permissions = d.validate_graph_execution_permissions

    # Onboarding
    # ============ Onboarding ============ #
    increment_onboarding_runs = d.increment_onboarding_runs

    # OAuth
    # ============ OAuth ============ #
    cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens

    # Store
    # ============ Store ============ #
    get_store_agents = d.get_store_agents
    get_store_agent_details = d.get_store_agent_details
    get_agent = d.get_agent
    get_available_graph = d.get_available_graph

    # Summary data
    # ============ Search ============ #
    unified_hybrid_search = d.unified_hybrid_search

    # ============ Summary Data ============ #
    get_user_execution_summary_data = d.get_user_execution_summary_data

    # ============ Workspace ============ #
    get_or_create_workspace = d.get_or_create_workspace

    # ============ Understanding ============ #
    get_business_understanding = d.get_business_understanding
    upsert_business_understanding = d.upsert_business_understanding

    # ============ CoPilot Chat Sessions ============ #
    get_chat_session = d.get_chat_session
    create_chat_session = d.create_chat_session
    update_chat_session = d.update_chat_session
    add_chat_message = d.add_chat_message
    add_chat_messages_batch = d.add_chat_messages_batch
    get_user_chat_sessions = d.get_user_chat_sessions
    get_user_session_count = d.get_user_session_count
    delete_chat_session = d.delete_chat_session
    get_chat_session_message_count = d.get_chat_session_message_count
    update_tool_message_content = d.update_tool_message_content
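The `_` helper above does double duty: it registers each function as an RPC endpoint via `expose`, and it casts the type to `Callable[Concatenate[object, P], R]` so a bare module-level function can sit in a class body and still type-check when accessed as a method. A runnable sketch of just the typing trick (the real `expose` also performs RPC registration, which is omitted here):

import functools
from typing import Callable, Concatenate, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def as_method(f: Callable[P, R]) -> Callable[Concatenate[object, P], R]:
    # Wrap the function so the instance passed in by attribute binding
    # is swallowed; the original module-level function never sees it.
    @functools.wraps(f)
    def wrapper(_self: object, *args: P.args, **kwargs: P.kwargs) -> R:
        return f(*args, **kwargs)

    return wrapper


def add(a: int, b: int) -> int:
    return a + b


class Calculator:
    add = as_method(add)  # bare function, usable as a method


assert Calculator().add(1, 2) == 3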
@@ -1,5 +1,5 @@
from backend.app import run_processes
from backend.executor import DatabaseManager
from backend.data.db_manager import DatabaseManager


def main():
@@ -1,11 +1,7 @@
from .database import DatabaseManager, DatabaseManagerAsyncClient, DatabaseManagerClient
from .manager import ExecutionManager
from .scheduler import Scheduler

__all__ = [
    "DatabaseManager",
    "DatabaseManagerClient",
    "DatabaseManagerAsyncClient",
    "ExecutionManager",
    "Scheduler",
]
@@ -22,7 +22,7 @@ from backend.util.settings import Settings
from backend.util.truncate import truncate

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerAsyncClient
    from backend.data.db_manager import DatabaseManagerAsyncClient

logger = logging.getLogger(__name__)
@@ -4,7 +4,7 @@ import logging
from typing import TYPE_CHECKING, Any, Literal

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerAsyncClient
    from backend.data.db_manager import DatabaseManagerAsyncClient

from pydantic import ValidationError
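Both hunks only move imports that sit under `typing.TYPE_CHECKING`, the standard guard for type-only imports: the names stay visible to type checkers but are never imported at runtime, which keeps the rename from dragging the whole executor package into every import graph and avoids circular imports. The generic shape of the pattern:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime,
    # so it cannot create an import cycle or slow startup.
    from backend.data.db_manager import DatabaseManagerAsyncClient


async def use_client(client: "DatabaseManagerAsyncClient") -> None:
    # The quoted annotation means the name need not exist at runtime.
    ...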
@@ -1,6 +1,7 @@
"""Redis-based distributed locking for cluster coordination."""

import logging
import threading
import time
from typing import TYPE_CHECKING

@@ -19,6 +20,7 @@ class ClusterLock:
        self.owner_id = owner_id
        self.timeout = timeout
        self._last_refresh = 0.0
        self._refresh_lock = threading.Lock()

    def try_acquire(self) -> str | None:
        """Try to acquire the lock.
@@ -31,7 +33,8 @@ class ClusterLock:
        try:
            success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
            if success:
                self._last_refresh = time.time()
                with self._refresh_lock:
                    self._last_refresh = time.time()
                return self.owner_id  # Successfully acquired

            # Failed to acquire, get current owner
@@ -57,23 +60,27 @@ class ClusterLock:
        Rate limited to at most once every timeout/10 seconds (minimum 1 second).
        During rate limiting, still verifies lock existence but skips TTL extension.
        Setting _last_refresh to 0 bypasses rate limiting for testing.

        Thread-safe: uses _refresh_lock to protect _last_refresh access.
        """
        # Calculate refresh interval: max(timeout // 10, 1)
        refresh_interval = max(self.timeout // 10, 1)
        current_time = time.time()

        # Check if we're within the rate limit period
        # Check if we're within the rate limit period (thread-safe read)
        # _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
        with self._refresh_lock:
            last_refresh = self._last_refresh
        is_rate_limited = (
            self._last_refresh > 0
            and (current_time - self._last_refresh) < refresh_interval
            last_refresh > 0 and (current_time - last_refresh) < refresh_interval
        )

        try:
            # Always verify lock existence, even during rate limiting
            current_value = self.redis.get(self.key)
            if not current_value:
                self._last_refresh = 0
                with self._refresh_lock:
                    self._last_refresh = 0
                return False

            stored_owner = (
@@ -82,7 +89,8 @@ class ClusterLock:
                else str(current_value)
            )
            if stored_owner != self.owner_id:
                self._last_refresh = 0
                with self._refresh_lock:
                    self._last_refresh = 0
                return False

            # If rate limited, return True but don't update TTL or timestamp
@@ -91,25 +99,30 @@ class ClusterLock:

            # Perform actual refresh
            if self.redis.expire(self.key, self.timeout):
                self._last_refresh = current_time
                with self._refresh_lock:
                    self._last_refresh = current_time
                return True

            self._last_refresh = 0
            with self._refresh_lock:
                self._last_refresh = 0
            return False

        except Exception as e:
            logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
            self._last_refresh = 0
            with self._refresh_lock:
                self._last_refresh = 0
            return False

    def release(self):
        """Release the lock."""
        if self._last_refresh == 0:
            return
        with self._refresh_lock:
            if self._last_refresh == 0:
                return

        try:
            self.redis.delete(self.key)
        except Exception:
            pass

        self._last_refresh = 0.0
        with self._refresh_lock:
            self._last_refresh = 0.0
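Functionally, the diff above only adds `threading.Lock` protection around `_last_refresh`; the Redis lease recipe itself is unchanged: acquire with SET NX EX, keep alive with EXPIRE, and treat a missing key or a foreign owner as lost ownership. A stripped-down sketch of that recipe with redis-py, omitting the rate limiting and thread safety (connection details are placeholders):

import time

import redis


class TinyLease:
    """Stripped-down acquire/refresh/release cycle for a Redis lease."""

    def __init__(self, r: redis.Redis, key: str, owner: str, timeout: int):
        self.r, self.key, self.owner, self.timeout = r, key, owner, timeout

    def try_acquire(self) -> bool:
        # SET key owner NX EX timeout: succeeds only if the key is absent.
        return bool(self.r.set(self.key, self.owner, nx=True, ex=self.timeout))

    def refresh(self) -> bool:
        value = self.r.get(self.key)
        if value is None or value.decode() != self.owner:
            return False  # lease expired or taken over; do not extend
        return bool(self.r.expire(self.key, self.timeout))

    def release(self) -> None:
        self.r.delete(self.key)  # naive delete, same as the release above


lease = TinyLease(redis.Redis(), "cluster:leader", owner="worker-1", timeout=30)
if lease.try_acquire():
    try:
        for _ in range(5):  # do leader-only work between refreshes
            time.sleep(3)  # roughly timeout // 10, mirroring the cadence above
            if not lease.refresh():
                break  # ownership lost mid-flight
    finally:
        lease.release()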
@@ -92,7 +92,10 @@ from .utils import (
)

if TYPE_CHECKING:
    from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
    from backend.data.db_manager import (
        DatabaseManagerAsyncClient,
        DatabaseManagerClient,
    )


_logger = logging.getLogger(__name__)
@@ -13,12 +13,15 @@ if TYPE_CHECKING:
    from openai import AsyncOpenAI
    from supabase import AClient, Client

    from backend.data.db_manager import (
        DatabaseManagerAsyncClient,
        DatabaseManagerClient,
    )
    from backend.data.execution import (
        AsyncRedisExecutionEventBus,
        RedisExecutionEventBus,
    )
    from backend.data.rabbitmq import AsyncRabbitMQ, SyncRabbitMQ
    from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
    from backend.executor.scheduler import SchedulerClient
    from backend.integrations.credentials_store import IntegrationCredentialsStore
    from backend.notifications.notifications import NotificationManagerClient
@@ -27,7 +30,7 @@ if TYPE_CHECKING:
@thread_cached
def get_database_manager_client() -> "DatabaseManagerClient":
    """Get a thread-cached DatabaseManagerClient with request retry enabled."""
    from backend.executor import DatabaseManagerClient
    from backend.data.db_manager import DatabaseManagerClient
    from backend.util.service import get_service_client

    return get_service_client(DatabaseManagerClient, request_retry=True)
@@ -38,7 +41,7 @@ def get_database_manager_async_client(
    should_retry: bool = True,
) -> "DatabaseManagerAsyncClient":
    """Get a thread-cached DatabaseManagerAsyncClient with request retry enabled."""
    from backend.executor import DatabaseManagerAsyncClient
    from backend.data.db_manager import DatabaseManagerAsyncClient
    from backend.util.service import get_service_client

    return get_service_client(DatabaseManagerAsyncClient, request_retry=should_retry)
@@ -106,6 +109,20 @@ async def get_async_execution_queue() -> "AsyncRabbitMQ":
    return client


# ============ CoPilot Queue Helpers ============ #


@thread_cached
async def get_async_copilot_queue() -> "AsyncRabbitMQ":
    """Get a thread-cached AsyncRabbitMQ CoPilot queue client."""
    from backend.copilot.executor.utils import create_copilot_queue_config
    from backend.data.rabbitmq import AsyncRabbitMQ

    client = AsyncRabbitMQ(create_copilot_queue_config())
    await client.connect()
    return client


# ============ Integration Credentials Store ============ #
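Every helper in this hunk is wrapped in `thread_cached`, which memoizes the constructed client once per thread so each worker reuses its connection instead of rebuilding it per call. A minimal stand-in built on `threading.local()`; the platform's actual decorator may differ in detail (argument handling, async support):

import threading
from functools import wraps
from typing import Callable, TypeVar

T = TypeVar("T")


def thread_cached_sketch(factory: Callable[[], T]) -> Callable[[], T]:
    """Cache the factory's result once per thread (illustrative stand-in)."""
    local = threading.local()

    @wraps(factory)
    def wrapper() -> T:
        if not hasattr(local, "value"):
            local.value = factory()  # built once per thread
        return local.value

    return wrapper


@thread_cached_sketch
def get_client() -> dict:
    print("building client")  # prints once per thread
    return {"connected": True}


get_client()
get_client()  # second call on the same thread hits the cache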
@@ -211,16 +211,23 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
        description="The port for execution manager daemon to run on",
    )

    num_copilot_workers: int = Field(
        default=5,
        ge=1,
        le=100,
        description="Number of concurrent CoPilot executor workers",
    )

    copilot_executor_port: int = Field(
        default=8008,
        description="The port for CoPilot executor daemon to run on",
    )

    execution_scheduler_port: int = Field(
        default=8003,
        description="The port for execution scheduler daemon to run on",
    )

    agent_server_port: int = Field(
        default=8004,
        description="The port for agent server daemon to run on",
    )

    database_api_port: int = Field(
        default=8005,
        description="The port for database server API to run on",
@@ -658,9 +665,6 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
    mem0_api_key: str = Field(default="", description="Mem0 API key")
    elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")

    linear_api_key: str = Field(
        default="", description="Linear API key for system-level operations"
    )
    linear_client_id: str = Field(default="", description="Linear client ID")
    linear_client_secret: str = Field(default="", description="Linear client secret")
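The new `num_copilot_workers` field leans on pydantic's `ge`/`le` bounds, so an out-of-range value (say, from a mistyped environment variable) fails at startup instead of spawning a nonsensical worker count. A self-contained sketch of the same validation, assuming pydantic v2 with the pydantic-settings package:

from pydantic import Field, ValidationError
from pydantic_settings import BaseSettings


class WorkerConfig(BaseSettings):
    # Same shape as the platform field: bounded int with a default.
    num_copilot_workers: int = Field(default=5, ge=1, le=100)


print(WorkerConfig().num_copilot_workers)  # 5

try:
    WorkerConfig(num_copilot_workers=0)  # violates ge=1
except ValidationError as e:
    print(e.errors()[0]["type"])  # greater_than_equal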
@@ -11,6 +11,7 @@ from backend.api.rest_api import AgentServer
from backend.blocks._base import Block, BlockSchema
from backend.data import db
from backend.data.block import initialize_blocks
from backend.data.db_manager import DatabaseManager
from backend.data.execution import (
    ExecutionContext,
    ExecutionStatus,
@@ -19,7 +20,7 @@ from backend.data.execution import (
)
from backend.data.model import _BaseCredentials
from backend.data.user import create_default_user
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
from backend.executor import ExecutionManager, Scheduler
from backend.notifications.notifications import NotificationManager

log = logging.getLogger(__name__)
@@ -116,6 +116,7 @@ ws = "backend.ws:main"
scheduler = "backend.scheduler:main"
notification = "backend.notification:main"
executor = "backend.exec:main"
copilot-executor = "backend.copilot.executor.__main__:main"
cli = "backend.cli:main"
format = "linter:format"
lint = "linter:lint"
@@ -9,10 +9,8 @@ from unittest.mock import AsyncMock, patch

import pytest

from backend.api.features.chat.tools.agent_generator import core
from backend.api.features.chat.tools.agent_generator.core import (
    AgentGeneratorNotConfiguredError,
)
from backend.copilot.tools.agent_generator import core
from backend.copilot.tools.agent_generator.core import AgentGeneratorNotConfiguredError


class TestServiceNotConfigured:

@@ -9,7 +9,7 @@ from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.api.features.chat.tools.agent_generator import core
from backend.copilot.tools.agent_generator import core


class TestGetLibraryAgentsForGeneration:
@@ -31,18 +31,20 @@ class TestGetLibraryAgentsForGeneration:
        mock_response = MagicMock()
        mock_response.agents = [mock_agent]

        mock_db = MagicMock()
        mock_db.list_library_agents = AsyncMock(return_value=mock_response)

        with patch.object(
            core.library_db,
            "list_library_agents",
            new_callable=AsyncMock,
            return_value=mock_response,
        ) as mock_list:
            core,
            "library_db",
            return_value=mock_db,
        ):
            result = await core.get_library_agents_for_generation(
                user_id="user-123",
                search_query="send email",
            )

        mock_list.assert_called_once_with(
        mock_db.list_library_agents.assert_called_once_with(
            user_id="user-123",
            search_term="send email",
            page=1,
@@ -80,11 +82,13 @@ class TestGetLibraryAgentsForGeneration:
            ),
        ]

        mock_db = MagicMock()
        mock_db.list_library_agents = AsyncMock(return_value=mock_response)

        with patch.object(
            core.library_db,
            "list_library_agents",
            new_callable=AsyncMock,
            return_value=mock_response,
            core,
            "library_db",
            return_value=mock_db,
        ):
            result = await core.get_library_agents_for_generation(
                user_id="user-123",
@@ -101,18 +105,20 @@ class TestGetLibraryAgentsForGeneration:
        mock_response = MagicMock()
        mock_response.agents = []

        mock_db = MagicMock()
        mock_db.list_library_agents = AsyncMock(return_value=mock_response)

        with patch.object(
            core.library_db,
            "list_library_agents",
            new_callable=AsyncMock,
            return_value=mock_response,
        ) as mock_list:
            core,
            "library_db",
            return_value=mock_db,
        ):
            await core.get_library_agents_for_generation(
                user_id="user-123",
                max_results=5,
            )

        mock_list.assert_called_once_with(
        mock_db.list_library_agents.assert_called_once_with(
            user_id="user-123",
            search_term=None,
            page=1,
@@ -144,24 +150,24 @@ class TestSearchMarketplaceAgentsForGeneration:
        mock_graph.input_schema = {"type": "object"}
        mock_graph.output_schema = {"type": "object"}

        mock_store_db = MagicMock()
        mock_store_db.get_store_agents = AsyncMock(return_value=mock_response)

        mock_graph_db = MagicMock()
        mock_graph_db.get_store_listed_graphs = AsyncMock(
            return_value={"graph-123": mock_graph}
        )

        with (
            patch(
                "backend.api.features.store.db.get_store_agents",
                new_callable=AsyncMock,
                return_value=mock_response,
            ) as mock_search,
            patch(
                "backend.api.features.chat.tools.agent_generator.core.get_store_listed_graphs",
                new_callable=AsyncMock,
                return_value={"graph-123": mock_graph},
            ),
            patch.object(core, "store_db", return_value=mock_store_db),
            patch.object(core, "graph_db", return_value=mock_graph_db),
        ):
            result = await core.search_marketplace_agents_for_generation(
                search_query="automation",
                max_results=10,
            )

        mock_search.assert_called_once_with(
        mock_store_db.get_store_agents.assert_called_once_with(
            search_query="automation",
            page=1,
            page_size=10,
@@ -707,7 +713,7 @@ class TestExtractUuidsFromText:


class TestGetLibraryAgentById:
    """Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id)."""
    """Test get_library_agent_by_id function (alias: get_library_agent_by_graph_id)."""

    @pytest.mark.asyncio
    async def test_returns_agent_when_found_by_graph_id(self):
@@ -720,12 +726,10 @@ class TestGetLibraryAgentById:
        mock_agent.input_schema = {"properties": {}}
        mock_agent.output_schema = {"properties": {}}

        with patch.object(
            core.library_db,
            "get_library_agent_by_graph_id",
            new_callable=AsyncMock,
            return_value=mock_agent,
        ):
        mock_db = MagicMock()
        mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)

        with patch.object(core, "library_db", return_value=mock_db):
            result = await core.get_library_agent_by_id("user-123", "agent-123")

        assert result is not None
@@ -743,20 +747,11 @@ class TestGetLibraryAgentById:
        mock_agent.input_schema = {"properties": {}}
        mock_agent.output_schema = {"properties": {}}

        with (
            patch.object(
                core.library_db,
                "get_library_agent_by_graph_id",
                new_callable=AsyncMock,
                return_value=None,  # Not found by graph_id
            ),
            patch.object(
                core.library_db,
                "get_library_agent",
                new_callable=AsyncMock,
                return_value=mock_agent,  # Found by library ID
            ),
        ):
        mock_db = MagicMock()
        mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=None)
        mock_db.get_library_agent = AsyncMock(return_value=mock_agent)

        with patch.object(core, "library_db", return_value=mock_db):
            result = await core.get_library_agent_by_id("user-123", "library-id-123")

        assert result is not None
@@ -766,20 +761,13 @@ class TestGetLibraryAgentById:
    @pytest.mark.asyncio
    async def test_returns_none_when_not_found_by_either_method(self):
        """Test that None is returned when agent not found by either method."""
        with (
            patch.object(
                core.library_db,
                "get_library_agent_by_graph_id",
                new_callable=AsyncMock,
                return_value=None,
            ),
            patch.object(
                core.library_db,
                "get_library_agent",
                new_callable=AsyncMock,
                side_effect=core.NotFoundError("Not found"),
            ),
        ):
        mock_db = MagicMock()
        mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=None)
        mock_db.get_library_agent = AsyncMock(
            side_effect=core.NotFoundError("Not found")
        )

        with patch.object(core, "library_db", return_value=mock_db):
            result = await core.get_library_agent_by_id("user-123", "nonexistent")

        assert result is None
@@ -787,27 +775,20 @@ class TestGetLibraryAgentById:
    @pytest.mark.asyncio
    async def test_returns_none_on_exception(self):
        """Test that None is returned when exception occurs in both lookups."""
        with (
            patch.object(
                core.library_db,
                "get_library_agent_by_graph_id",
                new_callable=AsyncMock,
                side_effect=Exception("Database error"),
            ),
            patch.object(
                core.library_db,
                "get_library_agent",
                new_callable=AsyncMock,
                side_effect=Exception("Database error"),
            ),
        ):
        mock_db = MagicMock()
        mock_db.get_library_agent_by_graph_id = AsyncMock(
            side_effect=Exception("Database error")
        )
        mock_db.get_library_agent = AsyncMock(side_effect=Exception("Database error"))

        with patch.object(core, "library_db", return_value=mock_db):
            result = await core.get_library_agent_by_id("user-123", "agent-123")

        assert result is None

    @pytest.mark.asyncio
    async def test_alias_works(self):
        """Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id."""
        """Test that get_library_agent_by_graph_id is an alias."""
        assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id


@@ -828,20 +809,11 @@ class TestGetAllRelevantAgentsWithUuids:
        mock_response = MagicMock()
        mock_response.agents = []

        with (
            patch.object(
                core.library_db,
                "get_library_agent_by_graph_id",
                new_callable=AsyncMock,
                return_value=mock_agent,
            ),
            patch.object(
                core.library_db,
                "list_library_agents",
                new_callable=AsyncMock,
                return_value=mock_response,
            ),
        ):
        mock_db = MagicMock()
        mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
        mock_db.list_library_agents = AsyncMock(return_value=mock_response)

        with patch.object(core, "library_db", return_value=mock_db):
            result = await core.get_all_relevant_agents_for_generation(
                user_id="user-123",
                search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d",

@@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest

from backend.api.features.chat.tools.agent_generator import service
from backend.copilot.tools.agent_generator import service


class TestServiceConfiguration:
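All of these test rewrites follow one move: instead of patching individual coroutine functions on a `library_db` module object, they patch the `library_db()` accessor itself to return a `MagicMock`, so a single patch covers however many db calls the code under test makes, and the assertions read straight off the mock. The shape reduced to a runnable toy (`core` and its methods here are illustrative stand-ins, not the real module):

import asyncio
from unittest.mock import AsyncMock, MagicMock, patch


class core:  # stand-in for the module under test
    @staticmethod
    def library_db():
        raise RuntimeError("no real backend in tests")

    @staticmethod
    async def list_agent_names() -> list[str]:
        response = await core.library_db().list_library_agents(page=1)
        return [a.name for a in response.agents]


agent = MagicMock()
agent.name = "Email Agent"
mock_db = MagicMock()
mock_db.list_library_agents = AsyncMock(return_value=MagicMock(agents=[agent]))

# One patch swaps the whole backend: the accessor now returns mock_db.
with patch.object(core, "library_db", return_value=mock_db):
    assert asyncio.run(core.list_agent_names()) == ["Email Agent"]

mock_db.list_library_agents.assert_called_once_with(page=1)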
@@ -1,468 +0,0 @@
"""
Test script for Linear GraphQL API - Customer Requests operations.

Tests the exact GraphQL calls needed for:
1. search_feature_requests - search issues in the Customer Feature Requests project
2. add_feature_request - upsert customer + create customer need on issue

Requires LINEAR_API_KEY in backend/.env
Generate one at: https://linear.app/settings/api
"""

import json
import os
import sys

import httpx
from dotenv import load_dotenv

load_dotenv()

LINEAR_API_URL = "https://api.linear.app/graphql"
API_KEY = os.getenv("LINEAR_API_KEY")

# Target project for feature requests
FEATURE_REQUEST_PROJECT_ID = "13f066f3-f639-4a67-aaa3-31483ebdf8cd"
# Team: Internal
TEAM_ID = "557fd3d5-087e-43a9-83e3-476c8313ce49"

if not API_KEY:
    print("ERROR: LINEAR_API_KEY not found in .env")
    print("Generate a personal API key at: https://linear.app/settings/api")
    print("Then add LINEAR_API_KEY=lin_api_... to backend/.env")
    sys.exit(1)

HEADERS = {
    "Authorization": API_KEY,
    "Content-Type": "application/json",
}


def graphql(query: str, variables: dict | None = None) -> dict:
    """Execute a GraphQL query against Linear API."""
    payload = {"query": query}
    if variables:
        payload["variables"] = variables

    resp = httpx.post(LINEAR_API_URL, json=payload, headers=HEADERS, timeout=30)
    if resp.status_code != 200:
        print(f"HTTP {resp.status_code}: {resp.text[:500]}")
    resp.raise_for_status()
    data = resp.json()

    if "errors" in data:
        print(f"GraphQL Errors: {json.dumps(data['errors'], indent=2)}")

    return data


# ---------------------------------------------------------------------------
# QUERIES
# ---------------------------------------------------------------------------

# Search issues within the feature requests project by title/description
SEARCH_ISSUES_IN_PROJECT = """
query SearchFeatureRequests($filter: IssueFilter!, $first: Int) {
  issues(filter: $filter, first: $first) {
    nodes {
      id
      identifier
      title
      description
      url
      state {
        name
        type
      }
      project {
        id
        name
      }
      labels {
        nodes {
          name
        }
      }
    }
  }
}
"""

# Get issue with its customer needs
GET_ISSUE_WITH_NEEDS = """
query GetIssueWithNeeds($id: String!) {
  issue(id: $id) {
    id
    identifier
    title
    url
    needs {
      nodes {
        id
        body
        priority
        customer {
          id
          name
          domains
          externalIds
        }
      }
    }
  }
}
"""

# Search customers
SEARCH_CUSTOMERS = """
query SearchCustomers($filter: CustomerFilter, $first: Int) {
  customers(filter: $filter, first: $first) {
    nodes {
      id
      name
      domains
      externalIds
      revenue
      size
      status {
        name
      }
      tier {
        name
      }
    }
  }
}
"""

# ---------------------------------------------------------------------------
# MUTATIONS
# ---------------------------------------------------------------------------

CUSTOMER_UPSERT = """
mutation CustomerUpsert($input: CustomerUpsertInput!) {
  customerUpsert(input: $input) {
    success
    customer {
      id
      name
      domains
      externalIds
    }
  }
}
"""

CUSTOMER_NEED_CREATE = """
mutation CustomerNeedCreate($input: CustomerNeedCreateInput!) {
  customerNeedCreate(input: $input) {
    success
    need {
      id
      body
      priority
      customer {
        id
        name
      }
      issue {
        id
        identifier
        title
      }
    }
  }
}
"""

ISSUE_CREATE = """
mutation IssueCreate($input: IssueCreateInput!) {
  issueCreate(input: $input) {
    success
    issue {
      id
      identifier
      title
      url
    }
  }
}
"""


# ---------------------------------------------------------------------------
# TESTS
# ---------------------------------------------------------------------------


def test_1_search_feature_requests():
    """Search for feature requests in the target project by keyword."""
    print("\n" + "=" * 60)
    print("TEST 1: Search feature requests in project by keyword")
    print("=" * 60)

    search_term = "agent"
    result = graphql(
        SEARCH_ISSUES_IN_PROJECT,
        {
            "filter": {
                "project": {"id": {"eq": FEATURE_REQUEST_PROJECT_ID}},
                "or": [
                    {"title": {"containsIgnoreCase": search_term}},
                    {"description": {"containsIgnoreCase": search_term}},
                ],
            },
            "first": 5,
        },
    )

    issues = result.get("data", {}).get("issues", {}).get("nodes", [])
    for issue in issues:
        proj = issue.get("project") or {}
        print(f"\n [{issue['identifier']}] {issue['title']}")
        print(f" Project: {proj.get('name', 'N/A')}")
        print(f" State: {issue['state']['name']}")
        print(f" URL: {issue['url']}")

    print(f"\n Found {len(issues)} issues matching '{search_term}'")
    return issues


def test_2_list_all_in_project():
    """List all issues in the feature requests project."""
    print("\n" + "=" * 60)
    print("TEST 2: List all issues in Customer Feature Requests project")
    print("=" * 60)

    result = graphql(
        SEARCH_ISSUES_IN_PROJECT,
        {
            "filter": {
                "project": {"id": {"eq": FEATURE_REQUEST_PROJECT_ID}},
            },
            "first": 10,
        },
    )

    issues = result.get("data", {}).get("issues", {}).get("nodes", [])
    if not issues:
        print(" No issues in project yet (empty project)")
    for issue in issues:
        print(f"\n [{issue['identifier']}] {issue['title']}")
        print(f" State: {issue['state']['name']}")

    print(f"\n Total: {len(issues)} issues")
    return issues


def test_3_search_customers():
    """List existing customers."""
    print("\n" + "=" * 60)
    print("TEST 3: List customers")
    print("=" * 60)

    result = graphql(SEARCH_CUSTOMERS, {"first": 10})
    customers = result.get("data", {}).get("customers", {}).get("nodes", [])

    if not customers:
        print(" No customers exist yet")
    for c in customers:
        status = c.get("status") or {}
        tier = c.get("tier") or {}
        print(f"\n [{c['id'][:8]}...] {c['name']}")
        print(f" Domains: {c.get('domains', [])}")
        print(f" External IDs: {c.get('externalIds', [])}")
        print(
            f" Status: {status.get('name', 'N/A')}, Tier: {tier.get('name', 'N/A')}"
        )

    print(f"\n Total: {len(customers)} customers")
    return customers


def test_4_customer_upsert():
    """Upsert a test customer."""
    print("\n" + "=" * 60)
    print("TEST 4: Customer upsert (find-or-create)")
    print("=" * 60)

    result = graphql(
        CUSTOMER_UPSERT,
        {
            "input": {
                "name": "Test Customer (API Test)",
                "domains": ["test-api-customer.example.com"],
                "externalId": "test-customer-001",
            }
        },
    )

    upsert = result.get("data", {}).get("customerUpsert", {})
    if upsert.get("success"):
        customer = upsert["customer"]
        print(f" Success! Customer: {customer['name']}")
        print(f" ID: {customer['id']}")
        print(f" Domains: {customer['domains']}")
        print(f" External IDs: {customer['externalIds']}")
        return customer
    else:
        print(f" Failed: {json.dumps(result, indent=2)}")
        return None


def test_5_create_issue_and_need(customer_id: str):
    """Create a new feature request issue and attach a customer need."""
    print("\n" + "=" * 60)
    print("TEST 5: Create issue + customer need")
    print("=" * 60)

    # Step 1: Create issue in the project
    result = graphql(
        ISSUE_CREATE,
        {
            "input": {
                "title": "Test Feature Request (API Test - safe to delete)",
                "description": "This is a test feature request created via the GraphQL API.",
                "teamId": TEAM_ID,
                "projectId": FEATURE_REQUEST_PROJECT_ID,
            }
        },
    )

    data = result.get("data")
    if not data:
        print(f" Issue creation failed: {json.dumps(result, indent=2)}")
        return None
    issue_data = data.get("issueCreate", {})
    if not issue_data.get("success"):
        print(f" Issue creation failed: {json.dumps(result, indent=2)}")
        return None

    issue = issue_data["issue"]
    print(f" Created issue: [{issue['identifier']}] {issue['title']}")
    print(f" URL: {issue['url']}")

    # Step 2: Attach customer need
    result = graphql(
        CUSTOMER_NEED_CREATE,
        {
            "input": {
                "customerId": customer_id,
                "issueId": issue["id"],
                "body": "Our team really needs this feature for our workflow. High priority for us!",
                "priority": 0,
            }
        },
    )

    need_data = result.get("data", {}).get("customerNeedCreate", {})
    if need_data.get("success"):
        need = need_data["need"]
        print(f" Attached customer need: {need['id']}")
        print(f" Customer: {need['customer']['name']}")
        print(f" Body: {need['body'][:80]}")
    else:
        print(f" Customer need creation failed: {json.dumps(result, indent=2)}")

    # Step 3: Verify by fetching the issue with needs
    print("\n Verifying...")
    verify = graphql(GET_ISSUE_WITH_NEEDS, {"id": issue["id"]})
    issue_verify = verify.get("data", {}).get("issue", {})
    needs = issue_verify.get("needs", {}).get("nodes", [])
    print(f" Issue now has {len(needs)} customer need(s)")
    for n in needs:
        cust = n.get("customer") or {}
        print(f" - {cust.get('name', 'N/A')}: {n.get('body', '')[:60]}")

    return issue


def test_6_add_need_to_existing(customer_id: str, issue_id: str):
    """Add a customer need to an existing issue (the common case)."""
    print("\n" + "=" * 60)
    print("TEST 6: Add customer need to existing issue")
    print("=" * 60)

    result = graphql(
        CUSTOMER_NEED_CREATE,
        {
            "input": {
                "customerId": customer_id,
                "issueId": issue_id,
                "body": "We also want this! +1 from our organization.",
                "priority": 0,
            }
        },
    )

    need_data = result.get("data", {}).get("customerNeedCreate", {})
    if need_data.get("success"):
        need = need_data["need"]
        print(f" Success! Need: {need['id']}")
        print(f" Customer: {need['customer']['name']}")
        print(f" Issue: [{need['issue']['identifier']}] {need['issue']['title']}")
        return need
    else:
        print(f" Failed: {json.dumps(result, indent=2)}")
        return None


def main():
    print("Linear GraphQL API - Customer Requests Test Suite")
    print("=" * 60)
    print(f"API URL: {LINEAR_API_URL}")
    print(f"API Key: {API_KEY[:10]}...")
    print(f"Project: Customer Feature Requests ({FEATURE_REQUEST_PROJECT_ID[:8]}...)")

    # --- Read-only tests ---
    test_1_search_feature_requests()
    test_2_list_all_in_project()
    test_3_search_customers()

    # --- Write tests ---
    print("\n" + "=" * 60)
    answer = (
        input("Run WRITE tests? (creates test customer + issue + need) [y/N]: ")
        .strip()
        .lower()
    )
    if answer != "y":
        print("Skipped write tests.")
        print("\nDone!")
        return

    customer = test_4_customer_upsert()
    if not customer:
        print("Customer upsert failed, stopping.")
        return

    issue = test_5_create_issue_and_need(customer["id"])
    if not issue:
        print("Issue creation failed, stopping.")
        return

    # Test adding a second need to the same issue (simulates another customer requesting same feature)
    # First upsert a second customer
    result = graphql(
        CUSTOMER_UPSERT,
        {
            "input": {
                "name": "Second Test Customer",
                "domains": ["second-test.example.com"],
                "externalId": "test-customer-002",
            }
        },
    )
    customer2 = result.get("data", {}).get("customerUpsert", {}).get("customer")
    if customer2:
        test_6_add_need_to_existing(customer2["id"], issue["id"])

    print("\n" + "=" * 60)
    print("All tests complete!")
    print(
        "Check the project: https://linear.app/autogpt/project/customer-feature-requests-710dcbf8bf4e/issues"
    )


if __name__ == "__main__":
    main()
@@ -158,6 +158,41 @@ services:
        max-size: "10m"
        max-file: "3"

  copilot_executor:
    build:
      context: ../
      dockerfile: autogpt_platform/backend/Dockerfile
      target: server
    command: ["python", "-m", "backend.copilot.executor"]
    develop:
      watch:
        - path: ./
          target: autogpt_platform/backend/
          action: rebuild
    depends_on:
      redis:
        condition: service_healthy
      rabbitmq:
        condition: service_healthy
      db:
        condition: service_healthy
      migrate:
        condition: service_completed_successfully
      database_manager:
        condition: service_started
    <<: *backend-env-files
    environment:
      <<: *backend-env
    ports:
      - "8008:8008"
    networks:
      - app-network
    logging:
      driver: json-file
      options:
        max-size: "10m"
        max-file: "3"

  websocket_server:
    build:
      context: ../
@@ -53,6 +53,12 @@ services:
      file: ./docker-compose.platform.yml
      service: executor

  copilot_executor:
    <<: *agpt-services
    extends:
      file: ./docker-compose.platform.yml
      service: copilot_executor

  websocket_server:
    <<: *agpt-services
    extends:
@@ -174,5 +180,6 @@ services:
      - deps
      - rest_server
      - executor
      - copilot_executor
      - websocket_server
      - database_manager
@@ -15,10 +15,6 @@ import { ToolUIPart, UIDataTypes, UIMessage, UITools } from "ai";
import { useEffect, useRef, useState } from "react";
import { CreateAgentTool } from "../../tools/CreateAgent/CreateAgent";
import { EditAgentTool } from "../../tools/EditAgent/EditAgent";
import {
  CreateFeatureRequestTool,
  SearchFeatureRequestsTool,
} from "../../tools/FeatureRequests/FeatureRequests";
import { FindAgentsTool } from "../../tools/FindAgents/FindAgents";
import { FindBlocksTool } from "../../tools/FindBlocks/FindBlocks";
import { RunAgentTool } from "../../tools/RunAgent/RunAgent";
@@ -163,7 +159,7 @@ export const ChatMessagesContainer = ({

  return (
    <Conversation className="min-h-0 flex-1">
      <ConversationContent className="flex flex-1 flex-col gap-6 px-3 py-6">
      <ConversationContent className="flex min-h-screen flex-1 flex-col gap-6 px-3 py-6">
        {isLoading && messages.length === 0 && (
          <div className="flex min-h-full flex-1 items-center justify-center">
            <LoadingSpinner className="text-neutral-600" />
@@ -258,20 +254,6 @@ export const ChatMessagesContainer = ({
                  part={part as ToolUIPart}
                />
              );
            case "tool-search_feature_requests":
              return (
                <SearchFeatureRequestsTool
                  key={`${message.id}-${i}`}
                  part={part as ToolUIPart}
                />
              );
            case "tool-create_feature_request":
              return (
                <CreateFeatureRequestTool
                  key={`${message.id}-${i}`}
                  part={part as ToolUIPart}
                />
              );
            default:
              return null;
          }
@@ -14,10 +14,6 @@ import { Text } from "@/components/atoms/Text/Text";
import { CopilotChatActionsProvider } from "../components/CopilotChatActionsProvider/CopilotChatActionsProvider";
import { CreateAgentTool } from "../tools/CreateAgent/CreateAgent";
import { EditAgentTool } from "../tools/EditAgent/EditAgent";
import {
  CreateFeatureRequestTool,
  SearchFeatureRequestsTool,
} from "../tools/FeatureRequests/FeatureRequests";
import { FindAgentsTool } from "../tools/FindAgents/FindAgents";
import { FindBlocksTool } from "../tools/FindBlocks/FindBlocks";
import { RunAgentTool } from "../tools/RunAgent/RunAgent";
@@ -49,8 +45,6 @@ const SECTIONS = [
  "Tool: Create Agent",
  "Tool: Edit Agent",
  "Tool: View Agent Output",
  "Tool: Search Feature Requests",
  "Tool: Create Feature Request",
  "Full Conversation Example",
] as const;
@@ -1427,235 +1421,6 @@ export default function StyleguidePage() {
        </SubSection>
      </Section>

      {/* ============================================================= */}
      {/* SEARCH FEATURE REQUESTS */}
      {/* ============================================================= */}

      <Section title="Tool: Search Feature Requests">
        <SubSection label="Input streaming">
          <SearchFeatureRequestsTool
            part={{
              type: "tool-search_feature_requests",
              toolCallId: uid(),
              state: "input-streaming",
              input: { query: "dark mode" },
            }}
          />
        </SubSection>

        <SubSection label="Input available">
          <SearchFeatureRequestsTool
            part={{
              type: "tool-search_feature_requests",
              toolCallId: uid(),
              state: "input-available",
              input: { query: "dark mode" },
            }}
          />
        </SubSection>

        <SubSection label="Output available (with results)">
          <SearchFeatureRequestsTool
            part={{
              type: "tool-search_feature_requests",
              toolCallId: uid(),
              state: "output-available",
              input: { query: "dark mode" },
              output: {
                type: "feature_request_search",
                message:
                  'Found 2 feature request(s) matching "dark mode".',
                query: "dark mode",
                count: 2,
                results: [
                  {
                    id: "fr-001",
                    identifier: "INT-42",
                    title: "Add dark mode to the platform",
                    description:
                      "Users have requested a dark mode option for the builder and copilot interfaces to reduce eye strain during long sessions.",
                  },
                  {
                    id: "fr-002",
                    identifier: "INT-87",
                    title: "Dark theme for agent output viewer",
                    description:
                      "Specifically requesting dark theme support for the agent output/execution viewer panel.",
                  },
                ],
              },
            }}
          />
        </SubSection>

        <SubSection label="Output available (no results)">
          <SearchFeatureRequestsTool
            part={{
              type: "tool-search_feature_requests",
              toolCallId: uid(),
              state: "output-available",
              input: { query: "teleportation" },
              output: {
                type: "no_results",
                message:
                  "No feature requests found matching 'teleportation'.",
                suggestions: [
                  "Try different keywords",
                  "Use broader search terms",
                  "You can create a new feature request if none exists",
                ],
              },
            }}
          />
        </SubSection>

        <SubSection label="Output available (error)">
          <SearchFeatureRequestsTool
            part={{
              type: "tool-search_feature_requests",
              toolCallId: uid(),
              state: "output-available",
              input: { query: "dark mode" },
              output: {
                type: "error",
                message: "Failed to search feature requests.",
                error: "LINEAR_API_KEY environment variable is not set",
              },
            }}
          />
        </SubSection>

        <SubSection label="Output error">
          <SearchFeatureRequestsTool
            part={{
              type: "tool-search_feature_requests",
              toolCallId: uid(),
              state: "output-error",
              input: { query: "dark mode" },
            }}
          />
        </SubSection>
      </Section>

      {/* ============================================================= */}
      {/* CREATE FEATURE REQUEST */}
      {/* ============================================================= */}

      <Section title="Tool: Create Feature Request">
        <SubSection label="Input streaming">
          <CreateFeatureRequestTool
            part={{
              type: "tool-create_feature_request",
              toolCallId: uid(),
              state: "input-streaming",
              input: {
                title: "Add dark mode",
                description: "I would love dark mode for the platform.",
              },
            }}
          />
        </SubSection>

        <SubSection label="Input available">
          <CreateFeatureRequestTool
            part={{
              type: "tool-create_feature_request",
              toolCallId: uid(),
              state: "input-available",
              input: {
                title: "Add dark mode",
                description: "I would love dark mode for the platform.",
              },
            }}
          />
        </SubSection>

        <SubSection label="Output available (new issue created)">
          <CreateFeatureRequestTool
            part={{
              type: "tool-create_feature_request",
              toolCallId: uid(),
              state: "output-available",
              input: {
                title: "Add dark mode",
                description: "I would love dark mode for the platform.",
              },
              output: {
                type: "feature_request_created",
                message:
                  "Created new feature request [INT-105] Add dark mode.",
                issue_id: "issue-new-123",
                issue_identifier: "INT-105",
                issue_title: "Add dark mode",
                issue_url:
                  "https://linear.app/autogpt/issue/INT-105/add-dark-mode",
                is_new_issue: true,
                customer_name: "user-abc-123",
              },
            }}
          />
        </SubSection>

        <SubSection label="Output available (added to existing issue)">
          <CreateFeatureRequestTool
            part={{
              type: "tool-create_feature_request",
              toolCallId: uid(),
              state: "output-available",
              input: {
                title: "Dark mode support",
                description:
                  "Please add dark mode, it would help with long sessions.",
                existing_issue_id: "fr-001",
              },
              output: {
                type: "feature_request_created",
                message:
                  "Added your request to existing feature request [INT-42] Add dark mode to the platform.",
                issue_id: "fr-001",
                issue_identifier: "INT-42",
                issue_title: "Add dark mode to the platform",
                issue_url:
                  "https://linear.app/autogpt/issue/INT-42/add-dark-mode-to-the-platform",
                is_new_issue: false,
                customer_name: "user-xyz-789",
              },
            }}
          />
        </SubSection>

        <SubSection label="Output available (error)">
          <CreateFeatureRequestTool
            part={{
              type: "tool-create_feature_request",
              toolCallId: uid(),
              state: "output-available",
              input: {
                title: "Add dark mode",
                description: "I would love dark mode.",
              },
              output: {
                type: "error",
                message:
                  "Failed to attach customer need to the feature request.",
                error: "Linear API request failed (500): Internal error",
              },
            }}
          />
        </SubSection>

        <SubSection label="Output error">
          <CreateFeatureRequestTool
            part={{
              type: "tool-create_feature_request",
              toolCallId: uid(),
              state: "output-error",
              input: { title: "Add dark mode" },
            }}
          />
        </SubSection>
      </Section>

      {/* ============================================================= */}
      {/* FULL CONVERSATION EXAMPLE */}
      {/* ============================================================= */}
@@ -1,240 +0,0 @@
"use client";

import type { ToolUIPart } from "ai";
import { useMemo } from "react";

import { MorphingTextAnimation } from "../../components/MorphingTextAnimation/MorphingTextAnimation";
import {
  ContentBadge,
  ContentCard,
  ContentCardDescription,
  ContentCardHeader,
  ContentCardTitle,
  ContentGrid,
  ContentLink,
  ContentMessage,
  ContentSuggestionsList,
} from "../../components/ToolAccordion/AccordionContent";
import { ToolAccordion } from "../../components/ToolAccordion/ToolAccordion";
import {
  AccordionIcon,
  getAccordionTitle,
  getAnimationText,
  getFeatureRequestOutput,
  isCreatedOutput,
  isErrorOutput,
  isNoResultsOutput,
  isSearchResultsOutput,
  ToolIcon,
  type FeatureRequestToolType,
} from "./helpers";

export interface FeatureRequestToolPart {
  type: FeatureRequestToolType;
  toolCallId: string;
  state: ToolUIPart["state"];
  input?: unknown;
  output?: unknown;
}

interface Props {
  part: FeatureRequestToolPart;
}

function truncate(text: string, maxChars: number): string {
  const trimmed = text.trim();
  if (trimmed.length <= maxChars) return trimmed;
  return `${trimmed.slice(0, maxChars).trimEnd()}…`;
}

export function SearchFeatureRequestsTool({ part }: Props) {
  const output = getFeatureRequestOutput(part);
  const text = getAnimationText(part);
  const isStreaming =
    part.state === "input-streaming" || part.state === "input-available";
  const isError =
    part.state === "output-error" || (!!output && isErrorOutput(output));

  const normalized = useMemo(() => {
    if (!output) return null;
    return { title: getAccordionTitle(part.type, output) };
  }, [output, part.type]);

  const isOutputAvailable = part.state === "output-available" && !!output;

  const searchOutput =
    isOutputAvailable && output && isSearchResultsOutput(output)
      ? output
      : null;
  const noResultsOutput =
    isOutputAvailable && output && isNoResultsOutput(output) ? output : null;
  const errorOutput =
    isOutputAvailable && output && isErrorOutput(output) ? output : null;

  const hasExpandableContent =
    isOutputAvailable &&
    ((!!searchOutput && searchOutput.count > 0) ||
      !!noResultsOutput ||
      !!errorOutput);

  const accordionDescription =
    hasExpandableContent && searchOutput
      ? `Found ${searchOutput.count} result${searchOutput.count === 1 ? "" : "s"} for "${searchOutput.query}"`
      : hasExpandableContent && (noResultsOutput || errorOutput)
        ? ((noResultsOutput ?? errorOutput)?.message ?? null)
        : null;

  return (
    <div className="py-2">
      <div className="flex items-center gap-2 text-sm text-muted-foreground">
        <ToolIcon
          toolType={part.type}
          isStreaming={isStreaming}
          isError={isError}
        />
        <MorphingTextAnimation
          text={text}
          className={isError ? "text-red-500" : undefined}
        />
      </div>

      {hasExpandableContent && normalized && (
        <ToolAccordion
          icon={<AccordionIcon toolType={part.type} />}
          title={normalized.title}
          description={accordionDescription}
        >
          {searchOutput && (
            <ContentGrid>
              {searchOutput.results.map((r) => (
                <ContentCard key={r.id}>
                  <ContentCardHeader>
                    <ContentCardTitle>
                      {r.identifier} — {r.title}
                    </ContentCardTitle>
                  </ContentCardHeader>
                  {r.description && (
                    <ContentCardDescription>
                      {truncate(r.description, 200)}
                    </ContentCardDescription>
                  )}
                </ContentCard>
              ))}
            </ContentGrid>
          )}

          {noResultsOutput && (
            <div>
              <ContentMessage>{noResultsOutput.message}</ContentMessage>
              {noResultsOutput.suggestions &&
                noResultsOutput.suggestions.length > 0 && (
                  <ContentSuggestionsList items={noResultsOutput.suggestions} />
                )}
            </div>
          )}

          {errorOutput && (
            <div>
              <ContentMessage>{errorOutput.message}</ContentMessage>
              {errorOutput.error && (
                <ContentCardDescription>
                  {errorOutput.error}
                </ContentCardDescription>
              )}
            </div>
          )}
        </ToolAccordion>
      )}
    </div>
  );
}

export function CreateFeatureRequestTool({ part }: Props) {
  const output = getFeatureRequestOutput(part);
  const text = getAnimationText(part);
  const isStreaming =
    part.state === "input-streaming" || part.state === "input-available";
  const isError =
    part.state === "output-error" || (!!output && isErrorOutput(output));

  const normalized = useMemo(() => {
    if (!output) return null;
    return { title: getAccordionTitle(part.type, output) };
  }, [output, part.type]);

  const isOutputAvailable = part.state === "output-available" && !!output;

  const createdOutput =
    isOutputAvailable && output && isCreatedOutput(output) ? output : null;
  const errorOutput =
    isOutputAvailable && output && isErrorOutput(output) ? output : null;

  const hasExpandableContent =
    isOutputAvailable && (!!createdOutput || !!errorOutput);

  const accordionDescription =
    hasExpandableContent && createdOutput
      ? `${createdOutput.issue_identifier} — ${createdOutput.issue_title}`
      : hasExpandableContent && errorOutput
        ? errorOutput.message
        : null;

  return (
    <div className="py-2">
      <div className="flex items-center gap-2 text-sm text-muted-foreground">
        <ToolIcon
          toolType={part.type}
          isStreaming={isStreaming}
          isError={isError}
        />
        <MorphingTextAnimation
          text={text}
          className={isError ? "text-red-500" : undefined}
        />
      </div>

      {hasExpandableContent && normalized && (
        <ToolAccordion
          icon={<AccordionIcon toolType={part.type} />}
          title={normalized.title}
          description={accordionDescription}
        >
          {createdOutput && (
            <ContentCard>
              <ContentCardHeader
                action={
                  createdOutput.issue_url ? (
                    <ContentLink href={createdOutput.issue_url}>
                      View
                    </ContentLink>
                  ) : undefined
                }
              >
                <ContentCardTitle>
                  {createdOutput.issue_identifier} — {createdOutput.issue_title}
                </ContentCardTitle>
              </ContentCardHeader>
              <div className="mt-2 flex items-center gap-2">
                <ContentBadge>
                  {createdOutput.is_new_issue ? "New" : "Existing"}
                </ContentBadge>
              </div>
              <ContentMessage>{createdOutput.message}</ContentMessage>
            </ContentCard>
          )}

          {errorOutput && (
            <div>
              <ContentMessage>{errorOutput.message}</ContentMessage>
              {errorOutput.error && (
                <ContentCardDescription>
                  {errorOutput.error}
                </ContentCardDescription>
              )}
            </div>
          )}
        </ToolAccordion>
      )}
    </div>
  );
}
@@ -1,271 +0,0 @@
import {
  CheckCircleIcon,
  LightbulbIcon,
  MagnifyingGlassIcon,
  PlusCircleIcon,
} from "@phosphor-icons/react";
import type { ToolUIPart } from "ai";

/* ------------------------------------------------------------------ */
/* Types (local until API client is regenerated) */
/* ------------------------------------------------------------------ */

interface FeatureRequestInfo {
  id: string;
  identifier: string;
  title: string;
  description?: string | null;
}

export interface FeatureRequestSearchResponse {
  type: "feature_request_search";
  message: string;
  results: FeatureRequestInfo[];
  count: number;
  query: string;
}

export interface FeatureRequestCreatedResponse {
  type: "feature_request_created";
  message: string;
  issue_id: string;
  issue_identifier: string;
  issue_title: string;
  issue_url: string;
  is_new_issue: boolean;
  customer_name: string;
}

interface NoResultsResponse {
  type: "no_results";
  message: string;
  suggestions?: string[];
}

interface ErrorResponse {
  type: "error";
  message: string;
  error?: string;
}

export type FeatureRequestOutput =
  | FeatureRequestSearchResponse
  | FeatureRequestCreatedResponse
  | NoResultsResponse
  | ErrorResponse;

export type FeatureRequestToolType =
  | "tool-search_feature_requests"
  | "tool-create_feature_request"
  | string;

/* ------------------------------------------------------------------ */
/* Output parsing */
/* ------------------------------------------------------------------ */

function parseOutput(output: unknown): FeatureRequestOutput | null {
  if (!output) return null;
  if (typeof output === "string") {
    const trimmed = output.trim();
    if (!trimmed) return null;
    try {
      return parseOutput(JSON.parse(trimmed) as unknown);
    } catch {
      return null;
    }
  }
  if (typeof output === "object") {
    const type = (output as { type?: unknown }).type;
    if (
      type === "feature_request_search" ||
      type === "feature_request_created" ||
      type === "no_results" ||
      type === "error"
    ) {
      return output as FeatureRequestOutput;
    }
    // Fallback structural checks
    if ("results" in output && "query" in output)
      return output as FeatureRequestSearchResponse;
    if ("issue_identifier" in output)
      return output as FeatureRequestCreatedResponse;
    if ("suggestions" in output && !("error" in output))
      return output as NoResultsResponse;
    if ("error" in output || "details" in output)
      return output as ErrorResponse;
  }
  return null;
}

export function getFeatureRequestOutput(
  part: unknown,
): FeatureRequestOutput | null {
  if (!part || typeof part !== "object") return null;
  return parseOutput((part as { output?: unknown }).output);
}

/* ------------------------------------------------------------------ */
/* Type guards */
/* ------------------------------------------------------------------ */

export function isSearchResultsOutput(
  output: FeatureRequestOutput,
): output is FeatureRequestSearchResponse {
  return (
    output.type === "feature_request_search" ||
    ("results" in output && "query" in output)
  );
}

export function isCreatedOutput(
  output: FeatureRequestOutput,
): output is FeatureRequestCreatedResponse {
  return (
    output.type === "feature_request_created" || "issue_identifier" in output
  );
}

export function isNoResultsOutput(
  output: FeatureRequestOutput,
): output is NoResultsResponse {
  return (
    output.type === "no_results" ||
    ("suggestions" in output && !("error" in output))
  );
}

export function isErrorOutput(
  output: FeatureRequestOutput,
): output is ErrorResponse {
  return output.type === "error" || "error" in output;
}

/* ------------------------------------------------------------------ */
/* Accordion metadata */
/* ------------------------------------------------------------------ */

export function getAccordionTitle(
  toolType: FeatureRequestToolType,
  output: FeatureRequestOutput,
): string {
  if (toolType === "tool-search_feature_requests") {
    if (isSearchResultsOutput(output)) return "Feature requests";
    if (isNoResultsOutput(output)) return "No feature requests found";
    return "Feature request search error";
  }
  if (isCreatedOutput(output)) {
    return output.is_new_issue
      ? "Feature request created"
      : "Added to feature request";
  }
  if (isErrorOutput(output)) return "Feature request error";
  return "Feature request";
}

/* ------------------------------------------------------------------ */
/* Animation text */
/* ------------------------------------------------------------------ */

interface AnimationPart {
  type: FeatureRequestToolType;
  state: ToolUIPart["state"];
  input?: unknown;
  output?: unknown;
}

export function getAnimationText(part: AnimationPart): string {
  if (part.type === "tool-search_feature_requests") {
    const query = (part.input as { query?: string } | undefined)?.query?.trim();
    const queryText = query ? ` for "${query}"` : "";

    switch (part.state) {
      case "input-streaming":
      case "input-available":
        return `Searching feature requests${queryText}`;
      case "output-available": {
        const output = parseOutput(part.output);
        if (!output) return `Searching feature requests${queryText}`;
        if (isSearchResultsOutput(output)) {
          return `Found ${output.count} feature request${output.count === 1 ? "" : "s"}${queryText}`;
        }
        if (isNoResultsOutput(output))
          return `No feature requests found${queryText}`;
        return `Error searching feature requests${queryText}`;
      }
      case "output-error":
        return `Error searching feature requests${queryText}`;
      default:
        return "Searching feature requests";
    }
  }

  // create_feature_request
  const title = (part.input as { title?: string } | undefined)?.title?.trim();
  const titleText = title ? ` "${title}"` : "";

  switch (part.state) {
    case "input-streaming":
    case "input-available":
      return `Creating feature request${titleText}`;
    case "output-available": {
      const output = parseOutput(part.output);
      if (!output) return `Creating feature request${titleText}`;
      if (isCreatedOutput(output)) {
        return output.is_new_issue
          ? `Created ${output.issue_identifier}`
          : `Added to ${output.issue_identifier}`;
      }
      if (isErrorOutput(output)) return "Error creating feature request";
      return `Created feature request${titleText}`;
    }
    case "output-error":
      return "Error creating feature request";
    default:
      return "Creating feature request";
  }
}

/* ------------------------------------------------------------------ */
/* Icons */
/* ------------------------------------------------------------------ */

export function ToolIcon({
  toolType,
  isStreaming,
  isError,
}: {
  toolType: FeatureRequestToolType;
  isStreaming?: boolean;
  isError?: boolean;
}) {
  const IconComponent =
    toolType === "tool-create_feature_request"
      ? PlusCircleIcon
      : MagnifyingGlassIcon;

  return (
    <IconComponent
      size={14}
      weight="regular"
      className={
        isError
          ? "text-red-500"
          : isStreaming
            ? "text-neutral-500"
            : "text-neutral-400"
      }
    />
  );
}

export function AccordionIcon({
  toolType,
}: {
  toolType: FeatureRequestToolType;
}) {
  const IconComponent =
    toolType === "tool-create_feature_request"
      ? CheckCircleIcon
      : LightbulbIcon;
  return <IconComponent size={32} weight="light" />;
}
@@ -10495,9 +10495,7 @@
        "operation_started",
        "operation_pending",
        "operation_in_progress",
        "input_validation_error",
        "feature_request_search",
        "feature_request_created"
        "input_validation_error"
      ],
      "title": "ResponseType",
      "description": "Types of tool responses."
@@ -180,14 +180,3 @@ body[data-google-picker-open="true"] [data-dialog-content] {
  z-index: 1 !important;
  pointer-events: none !important;
}

/* CoPilot chat table styling — remove left/right borders, increase padding */
[data-streamdown="table-wrapper"] table {
  border-left: none;
  border-right: none;
}

[data-streamdown="table-wrapper"] th,
[data-streamdown="table-wrapper"] td {
  padding: 0.875rem 1rem; /* py-3.5 px-4 */
}
@@ -30,7 +30,6 @@ export function APIKeyCredentialsModal({
  const {
    form,
    isLoading,
    isSubmitting,
    supportsApiKey,
    providerName,
    schemaDescription,
@@ -139,12 +138,7 @@ export function APIKeyCredentialsModal({
            />
          )}
        />
        <Button
          type="submit"
          className="min-w-68"
          loading={isSubmitting}
          disabled={isSubmitting}
        >
        <Button type="submit" className="min-w-68">
          Add API Key
        </Button>
      </form>
@@ -4,7 +4,6 @@ import {
  CredentialsMetaInput,
} from "@/lib/autogpt-server-api/types";
import { zodResolver } from "@hookform/resolvers/zod";
import { useState } from "react";
import { useForm, type UseFormReturn } from "react-hook-form";
import { z } from "zod";

@@ -27,7 +26,6 @@ export function useAPIKeyCredentialsModal({
}: Args): {
  form: UseFormReturn<APIKeyFormValues>;
  isLoading: boolean;
  isSubmitting: boolean;
  supportsApiKey: boolean;
  provider?: string;
  providerName?: string;
@@ -35,7 +33,6 @@
  onSubmit: (values: APIKeyFormValues) => Promise<void>;
} {
  const credentials = useCredentials(schema, siblingInputs);
  const [isSubmitting, setIsSubmitting] = useState(false);

  const formSchema = z.object({
    apiKey: z.string().min(1, "API Key is required"),
@@ -43,42 +40,48 @@
    expiresAt: z.string().optional(),
  });

  function getDefaultExpirationDate(): string {
    const tomorrow = new Date();
    tomorrow.setDate(tomorrow.getDate() + 1);
    tomorrow.setHours(0, 0, 0, 0);
    const year = tomorrow.getFullYear();
    const month = String(tomorrow.getMonth() + 1).padStart(2, "0");
    const day = String(tomorrow.getDate()).padStart(2, "0");
    const hours = String(tomorrow.getHours()).padStart(2, "0");
    const minutes = String(tomorrow.getMinutes()).padStart(2, "0");
    return `${year}-${month}-${day}T${hours}:${minutes}`;
  }

  const form = useForm<APIKeyFormValues>({
    resolver: zodResolver(formSchema),
    defaultValues: {
      apiKey: "",
      title: "",
      expiresAt: "",
      expiresAt: getDefaultExpirationDate(),
    },
  });

  async function onSubmit(values: APIKeyFormValues) {
    if (!credentials || credentials.isLoading) return;
    setIsSubmitting(true);
    try {
      const expiresAt = values.expiresAt
        ? new Date(values.expiresAt).getTime() / 1000
        : undefined;
      const newCredentials = await credentials.createAPIKeyCredentials({
        api_key: values.apiKey,
        title: values.title,
        expires_at: expiresAt,
      });
      onCredentialsCreate({
        provider: credentials.provider,
        id: newCredentials.id,
        type: "api_key",
        title: newCredentials.title,
      });
    } finally {
      setIsSubmitting(false);
    }
    const expiresAt = values.expiresAt
      ? new Date(values.expiresAt).getTime() / 1000
      : undefined;
    const newCredentials = await credentials.createAPIKeyCredentials({
      api_key: values.apiKey,
      title: values.title,
      expires_at: expiresAt,
    });
    onCredentialsCreate({
      provider: credentials.provider,
      id: newCredentials.id,
      type: "api_key",
      title: newCredentials.title,
    });
  }

  return {
    form,
    isLoading: !credentials || credentials.isLoading,
    isSubmitting,
    supportsApiKey: !!credentials?.supportsApiKey,
    provider: credentials?.provider,
    providerName:
@@ -226,7 +226,7 @@ function renderMarkdown(
    table: ({ children, ...props }) => (
      <div className="my-4 overflow-x-auto">
        <table
          className="min-w-full divide-y divide-gray-200 border-y border-gray-200 dark:divide-gray-700 dark:border-gray-700"
          className="min-w-full divide-y divide-gray-200 rounded-lg border border-gray-200 dark:divide-gray-700 dark:border-gray-700"
          {...props}
        >
          {children}
@@ -235,7 +235,7 @@
    ),
    th: ({ children, ...props }) => (
      <th
        className="bg-gray-50 px-4 py-3.5 text-left text-xs font-semibold uppercase tracking-wider text-gray-700 dark:bg-gray-800 dark:text-gray-300"
        className="bg-gray-50 px-4 py-3 text-left text-xs font-semibold uppercase tracking-wider text-gray-700 dark:bg-gray-800 dark:text-gray-300"
        {...props}
      >
        {children}
@@ -243,7 +243,7 @@
    ),
    td: ({ children, ...props }) => (
      <td
        className="border-t border-gray-200 px-4 py-3.5 text-sm text-gray-600 dark:border-gray-700 dark:text-gray-400"
        className="border-t border-gray-200 px-4 py-3 text-sm text-gray-600 dark:border-gray-700 dark:text-gray-400"
        {...props}
      >
        {children}