Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-17 18:21:46 -05:00)

Compare commits: docs/deplo ... dev (10 commits)
| Author | SHA1 | Date |
|---|---|---|
| | e9ba7e51db | |
| | d23248f065 | |
| | 905373a712 | |
| | ee9d39bc0f | |
| | 05aaf7a85e | |
| | 9d4dcbd9e0 | |
| | 074be7aea6 | |
| | 39d28b24fc | |
| | bf79a7748a | |
| | 649d4ab7f5 | |
.github/workflows/platform-backend-ci.yml (vendored, 9 changes)

@@ -41,13 +41,18 @@ jobs:
         ports:
           - 6379:6379
       rabbitmq:
-        image: rabbitmq:3.12-management
+        image: rabbitmq:4.1.4
         ports:
           - 5672:5672
-          - 15672:15672
         env:
           RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
           RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
+        options: >-
+          --health-cmd "rabbitmq-diagnostics -q ping"
+          --health-interval 30s
+          --health-timeout 10s
+          --health-retries 5
+          --health-start-period 10s
       clamav:
         image: clamav/clamav-debian:latest
         ports:
.github/workflows/platform-frontend-ci.yml (vendored, 6 changes)

@@ -6,10 +6,16 @@ on:
     paths:
       - ".github/workflows/platform-frontend-ci.yml"
      - "autogpt_platform/frontend/**"
+      - "autogpt_platform/backend/Dockerfile"
+      - "autogpt_platform/docker-compose.yml"
+      - "autogpt_platform/docker-compose.platform.yml"
   pull_request:
     paths:
       - ".github/workflows/platform-frontend-ci.yml"
       - "autogpt_platform/frontend/**"
+      - "autogpt_platform/backend/Dockerfile"
+      - "autogpt_platform/docker-compose.yml"
+      - "autogpt_platform/docker-compose.platform.yml"
   merge_group:
   workflow_dispatch:
 
@@ -53,63 +53,6 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
 COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
 RUN poetry run prisma generate && poetry run gen-prisma-stub
 
-# ============================== BACKEND SERVER ============================== #
-
-FROM debian:13-slim AS server
-
-WORKDIR /app
-
-ENV POETRY_HOME=/opt/poetry \
-    POETRY_NO_INTERACTION=1 \
-    POETRY_VIRTUALENVS_CREATE=true \
-    POETRY_VIRTUALENVS_IN_PROJECT=true \
-    DEBIAN_FRONTEND=noninteractive
-ENV PATH=/opt/poetry/bin:$PATH
-
-# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
-# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
-# for the bash_exec MCP tool.
-# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
-RUN apt-get update && apt-get install -y --no-install-recommends \
-    python3.13 \
-    python3-pip \
-    ffmpeg \
-    imagemagick \
-    jq \
-    ripgrep \
-    tree \
-    bubblewrap \
-    && rm -rf /var/lib/apt/lists/*
-
-COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
-COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
-# Copy Node.js installation for Prisma
-COPY --from=builder /usr/bin/node /usr/bin/node
-COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
-COPY --from=builder /usr/bin/npm /usr/bin/npm
-COPY --from=builder /usr/bin/npx /usr/bin/npx
-COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
-
-WORKDIR /app/autogpt_platform/backend
-
-# Copy only the .venv from builder (not the entire /app directory)
-# The .venv includes the generated Prisma client
-COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
-ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
-
-# Copy dependency files + autogpt_libs (path dependency)
-COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
-COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
-
-# Copy backend code + docs (for Copilot docs search)
-COPY autogpt_platform/backend ./
-COPY docs /app/docs
-RUN poetry install --no-ansi --only-root
-
-ENV PORT=8000
-
-CMD ["poetry", "run", "rest"]
-
 # =============================== DB MIGRATOR =============================== #
 
 # Lightweight migrate stage - only needs Prisma CLI, not full Python environment
@@ -141,3 +84,59 @@ COPY autogpt_platform/backend/schema.prisma ./
 COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
 COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
 COPY autogpt_platform/backend/migrations ./migrations
+
+# ============================== BACKEND SERVER ============================== #
+
+FROM debian:13-slim AS server
+
+WORKDIR /app
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
+# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
+# for the bash_exec MCP tool.
+# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3.13 \
+    python3-pip \
+    ffmpeg \
+    imagemagick \
+    jq \
+    ripgrep \
+    tree \
+    bubblewrap \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
+COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
+COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
+# Copy Node.js installation for Prisma
+COPY --from=builder /usr/bin/node /usr/bin/node
+COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
+COPY --from=builder /usr/bin/npm /usr/bin/npm
+COPY --from=builder /usr/bin/npx /usr/bin/npx
+COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
+
+WORKDIR /app/autogpt_platform/backend
+
+# Copy only the .venv from builder (not the entire /app directory)
+# The .venv includes the generated Prisma client
+COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
+ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
+
+# Copy dependency files + autogpt_libs (path dependency)
+COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
+COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
+
+# Copy backend code + docs (for Copilot docs search)
+COPY autogpt_platform/backend ./
+COPY docs /app/docs
+# Install the project package to create entry point scripts in .venv/bin/
+# (e.g., rest, executor, ws, db, scheduler, notification - see [tool.poetry.scripts])
+RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
+    poetry install --no-ansi --only-root
+
+ENV PORT=8000
+
+CMD ["rest"]
@@ -1,4 +1,9 @@
-"""Common test fixtures for server tests."""
+"""Common test fixtures for server tests.
+
+Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
+setup_test_user, and setup_admin_user are defined in the parent conftest.py
+(backend/conftest.py) and are available here automatically.
+"""
 
 import pytest
 from pytest_snapshot.plugin import Snapshot

@@ -11,54 +16,6 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
     return snapshot
 
 
-@pytest.fixture
-def test_user_id() -> str:
-    """Test user ID fixture."""
-    return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
-
-
-@pytest.fixture
-def admin_user_id() -> str:
-    """Admin user ID fixture."""
-    return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
-
-
-@pytest.fixture
-def target_user_id() -> str:
-    """Target user ID fixture."""
-    return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
-
-
-@pytest.fixture
-async def setup_test_user(test_user_id):
-    """Create test user in database before tests."""
-    from backend.data.user import get_or_create_user
-
-    # Create the test user in the database using JWT token format
-    user_data = {
-        "sub": test_user_id,
-        "email": "test@example.com",
-        "user_metadata": {"name": "Test User"},
-    }
-    await get_or_create_user(user_data)
-    return test_user_id
-
-
-@pytest.fixture
-async def setup_admin_user(admin_user_id):
-    """Create admin user in database before tests."""
-    from backend.data.user import get_or_create_user
-
-    # Create the admin user in the database using JWT token format
-    user_data = {
-        "sub": admin_user_id,
-        "email": "test-admin@example.com",
-        "user_metadata": {"name": "Test Admin"},
-    }
-    await get_or_create_user(user_data)
-    return admin_user_id
-
-
 @pytest.fixture
 def mock_jwt_user(test_user_id):
     """Provide mock JWT payload for regular user testing."""
@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
 from pydantic import BaseModel, Field
 
 from backend.api.external.middleware import require_permission
-from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
-from backend.api.features.chat.tools.models import ToolResponseBase
+from backend.copilot.model import ChatSession
+from backend.copilot.tools import find_agent_tool, run_agent_tool
+from backend.copilot.tools.models import ToolResponseBase
 from backend.data.auth.base import APIAuthorizationInfo
 
 logger = logging.getLogger(__name__)
@@ -11,24 +11,25 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response,
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 
-from backend.util.exceptions import NotFoundError
-from backend.util.feature_flag import Flag, is_feature_enabled
-from . import service as chat_service
-from . import stream_registry
-from .completion_handler import process_operation_failure, process_operation_success
-from .config import ChatConfig
-from .model import (
+from backend.copilot import service as chat_service
+from backend.copilot import stream_registry
+from backend.copilot.completion_handler import (
+    process_operation_failure,
+    process_operation_success,
+)
+from backend.copilot.config import ChatConfig
+from backend.copilot.executor.utils import enqueue_copilot_task
+from backend.copilot.model import (
     ChatMessage,
     ChatSession,
     append_and_save_message,
     create_chat_session,
+    delete_chat_session,
     get_chat_session,
     get_user_sessions,
 )
-from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
-from .sdk import service as sdk_service
-from .tools.models import (
+from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
+from backend.copilot.tools.models import (
     AgentDetailsResponse,
     AgentOutputResponse,
     AgentPreviewResponse,

@@ -51,7 +52,8 @@ from .tools.models import (
     SetupRequirementsResponse,
     UnderstandingUpdatedResponse,
 )
-from .tracking import track_user_message
+from backend.copilot.tracking import track_user_message
+from backend.util.exceptions import NotFoundError
 
 config = ChatConfig()
 

@@ -211,6 +213,43 @@ async def create_session(
     )
 
 
+@router.delete(
+    "/sessions/{session_id}",
+    dependencies=[Security(auth.requires_user)],
+    status_code=204,
+    responses={404: {"description": "Session not found or access denied"}},
+)
+async def delete_session(
+    session_id: str,
+    user_id: Annotated[str, Security(auth.get_user_id)],
+) -> Response:
+    """
+    Delete a chat session.
+
+    Permanently removes a chat session and all its messages.
+    Only the owner can delete their sessions.
+
+    Args:
+        session_id: The session ID to delete.
+        user_id: The authenticated user's ID.
+
+    Returns:
+        204 No Content on success.
+
+    Raises:
+        HTTPException: 404 if session not found or not owned by user.
+    """
+    deleted = await delete_chat_session(session_id, user_id)
+
+    if not deleted:
+        raise HTTPException(
+            status_code=404,
+            detail=f"Session {session_id} not found or access denied",
+        )
+
+    return Response(status_code=204)
+
+
 @router.get(
     "/sessions/{session_id}",
 )

@@ -316,7 +355,7 @@ async def stream_chat_post(
         f"user={user_id}, message_len={len(request.message)}",
         extra={"json_fields": log_meta},
     )
-    session = await _validate_and_get_session(session_id, user_id)
+    await _validate_and_get_session(session_id, user_id)
     logger.info(
         f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
         extra={

@@ -343,7 +382,7 @@ async def stream_chat_post(
         message_length=len(request.message),
     )
     logger.info(f"[STREAM] Saving user message to session {session_id}")
-    session = await append_and_save_message(session_id, message)
+    await append_and_save_message(session_id, message)
    logger.info(f"[STREAM] User message saved for session {session_id}")
 
     # Create a task in the stream registry for reconnection support

@@ -370,125 +409,19 @@ async def stream_chat_post(
         },
     )
 
-    # Background task that runs the AI generation independently of SSE connection
-    async def run_ai_generation():
-        import time as time_module
-
-        gen_start_time = time_module.perf_counter()
-        logger.info(
-            f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
-            extra={"json_fields": log_meta},
-        )
-        first_chunk_time, ttfc = None, None
-        chunk_count = 0
-        try:
-            # Emit a start event with task_id for reconnection
-            start_chunk = StreamStart(messageId=task_id, taskId=task_id)
-            await stream_registry.publish_chunk(task_id, start_chunk)
-            logger.info(
-                f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "elapsed_ms": (time_module.perf_counter() - gen_start_time)
-                        * 1000,
-                    }
-                },
-            )
-
-            # Choose service based on LaunchDarkly flag (falls back to config default)
-            use_sdk = await is_feature_enabled(
-                Flag.COPILOT_SDK,
-                user_id or "anonymous",
-                default=config.use_claude_agent_sdk,
-            )
-            stream_fn = (
-                sdk_service.stream_chat_completion_sdk
-                if use_sdk
-                else chat_service.stream_chat_completion
-            )
-            logger.info(
-                f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
-                extra={"json_fields": log_meta},
-            )
-            # Pass message=None since we already added it to the session above
-            async for chunk in stream_fn(
-                session_id,
-                None,  # Message already in session
-                is_user_message=request.is_user_message,
-                user_id=user_id,
-                session=session,  # Pass session with message already added
-                context=request.context,
-            ):
-                # Skip duplicate StreamStart — we already published one above
-                if isinstance(chunk, StreamStart):
-                    continue
-                chunk_count += 1
-                if first_chunk_time is None:
-                    first_chunk_time = time_module.perf_counter()
-                    ttfc = first_chunk_time - gen_start_time
-                    logger.info(
-                        f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
-                        extra={
-                            "json_fields": {
-                                **log_meta,
-                                "chunk_type": type(chunk).__name__,
-                                "time_to_first_chunk_ms": ttfc * 1000,
-                            }
-                        },
-                    )
-                # Write to Redis (subscribers will receive via XREAD)
-                await stream_registry.publish_chunk(task_id, chunk)
-
-            gen_end_time = time_module.perf_counter()
-            total_time = (gen_end_time - gen_start_time) * 1000
-            logger.info(
-                f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
-                f"task={task_id}, session={session_id}, "
-                f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "total_time_ms": total_time,
-                        "time_to_first_chunk_ms": (
-                            ttfc * 1000 if ttfc is not None else None
-                        ),
-                        "n_chunks": chunk_count,
-                    }
-                },
-            )
-            await stream_registry.mark_task_completed(task_id, "completed")
-        except Exception as e:
-            elapsed = time_module.perf_counter() - gen_start_time
-            logger.error(
-                f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
-                extra={
-                    "json_fields": {
-                        **log_meta,
-                        "elapsed_ms": elapsed * 1000,
-                        "error": str(e),
-                    }
-                },
-            )
-            # Publish a StreamError so the frontend can display an error message
-            try:
-                await stream_registry.publish_chunk(
-                    task_id,
-                    StreamError(
-                        errorText="An error occurred. Please try again.",
-                        code="stream_error",
-                    ),
-                )
-            except Exception:
-                pass  # Best-effort; mark_task_completed will publish StreamFinish
-            await stream_registry.mark_task_completed(task_id, "failed")
-
-    # Start the AI generation in a background task
-    bg_task = asyncio.create_task(run_ai_generation())
-    await stream_registry.set_task_asyncio_task(task_id, bg_task)
+    await enqueue_copilot_task(
+        task_id=task_id,
+        session_id=session_id,
+        user_id=user_id,
+        operation_id=operation_id,
+        message=request.message,
+        is_user_message=request.is_user_message,
+        context=request.context,
+    )
     setup_time = (time.perf_counter() - stream_start_time) * 1000
     logger.info(
-        f"[TIMING] Background task started, setup={setup_time:.1f}ms",
+        f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
         extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
     )
 
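The new DELETE /sessions/{session_id} route above returns 204 on success and 404 when the session does not exist or belongs to another user. A minimal client-side sketch of calling it, assuming a local base URL and a bearer token satisfying the `auth.requires_user` dependency (both hypothetical, not taken from this changeset):

```python
# Illustrative client call only; base URL, route prefix, and bearer-token auth
# are assumptions for the sketch.
import asyncio

import httpx


async def delete_copilot_session(session_id: str, token: str) -> bool:
    """Return True if the session was deleted, False if it was not found."""
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.delete(
            f"/sessions/{session_id}",
            headers={"Authorization": f"Bearer {token}"},
        )
        if resp.status_code == 404:
            return False
        resp.raise_for_status()  # anything other than 2xx raises
        return True


if __name__ == "__main__":
    print(asyncio.run(delete_copilot_session("some-session-id", "dev-token")))
```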
@@ -11,7 +11,7 @@ import fastapi
 from autogpt_libs.auth.dependencies import get_user_id, requires_user
 from fastapi.responses import Response
 
-from backend.data.workspace import get_workspace, get_workspace_file
+from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
 from backend.util.workspace_storage import get_workspace_storage
 
 

@@ -44,11 +44,11 @@ router = fastapi.APIRouter(
 )
 
 
-def _create_streaming_response(content: bytes, file) -> Response:
+def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
     """Create a streaming response for file content."""
     return Response(
         content=content,
-        media_type=file.mimeType,
+        media_type=file.mime_type,
         headers={
             "Content-Disposition": _sanitize_filename_for_header(file.name),
             "Content-Length": str(len(content)),

@@ -56,7 +56,7 @@ def _create_streaming_response(content: bytes, file) -> Response:
     )
 
 
-async def _create_file_download_response(file) -> Response:
+async def _create_file_download_response(file: WorkspaceFile) -> Response:
     """
     Create a download response for a workspace file.
 

@@ -66,33 +66,33 @@ async def _create_file_download_response(file) -> Response:
     storage = await get_workspace_storage()
 
     # For local storage, stream the file directly
-    if file.storagePath.startswith("local://"):
-        content = await storage.retrieve(file.storagePath)
+    if file.storage_path.startswith("local://"):
+        content = await storage.retrieve(file.storage_path)
         return _create_streaming_response(content, file)
 
     # For GCS, try to redirect to signed URL, fall back to streaming
     try:
-        url = await storage.get_download_url(file.storagePath, expires_in=300)
+        url = await storage.get_download_url(file.storage_path, expires_in=300)
         # If we got back an API path (fallback), stream directly instead
         if url.startswith("/api/"):
-            content = await storage.retrieve(file.storagePath)
+            content = await storage.retrieve(file.storage_path)
             return _create_streaming_response(content, file)
         return fastapi.responses.RedirectResponse(url=url, status_code=302)
     except Exception as e:
         # Log the signed URL failure with context
         logger.error(
             f"Failed to get signed URL for file {file.id} "
-            f"(storagePath={file.storagePath}): {e}",
+            f"(storagePath={file.storage_path}): {e}",
             exc_info=True,
         )
         # Fall back to streaming directly from GCS
         try:
-            content = await storage.retrieve(file.storagePath)
+            content = await storage.retrieve(file.storage_path)
             return _create_streaming_response(content, file)
         except Exception as fallback_error:
             logger.error(
                 f"Fallback streaming also failed for file {file.id} "
-                f"(storagePath={file.storagePath}): {fallback_error}",
+                f"(storagePath={file.storage_path}): {fallback_error}",
                 exc_info=True,
             )
             raise
@@ -41,11 +41,11 @@ import backend.data.user
 import backend.integrations.webhooks.utils
 import backend.util.service
 import backend.util.settings
-from backend.api.features.chat.completion_consumer import (
+from backend.blocks.llm import DEFAULT_LLM_MODEL
+from backend.copilot.completion_consumer import (
     start_completion_consumer,
     stop_completion_consumer,
 )
-from backend.blocks.llm import DEFAULT_LLM_MODEL
 from backend.data.model import Credentials
 from backend.integrations.providers import ProviderName
 from backend.monitoring.instrumentation import instrument_fastapi
@@ -38,7 +38,9 @@ def main(**kwargs):
 
     from backend.api.rest_api import AgentServer
     from backend.api.ws_api import WebsocketServer
-    from backend.executor import DatabaseManager, ExecutionManager, Scheduler
+    from backend.copilot.executor.manager import CoPilotExecutor
+    from backend.data.db_manager import DatabaseManager
+    from backend.executor import ExecutionManager, Scheduler
     from backend.notifications import NotificationManager
 
     run_processes(

@@ -48,6 +50,7 @@ def main(**kwargs):
         WebsocketServer(),
         AgentServer(),
         ExecutionManager(),
+        CoPilotExecutor(),
         **kwargs,
     )
 
@@ -1,6 +1,7 @@
 import logging
 import os
 
+import pytest
 import pytest_asyncio
 from dotenv import load_dotenv
 

@@ -27,6 +28,54 @@ async def server():
     yield server
 
 
+@pytest.fixture
+def test_user_id() -> str:
+    """Test user ID fixture."""
+    return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
+
+
+@pytest.fixture
+def admin_user_id() -> str:
+    """Admin user ID fixture."""
+    return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
+
+
+@pytest.fixture
+def target_user_id() -> str:
+    """Target user ID fixture."""
+    return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
+
+
+@pytest.fixture
+async def setup_test_user(test_user_id):
+    """Create test user in database before tests."""
+    from backend.data.user import get_or_create_user
+
+    # Create the test user in the database using JWT token format
+    user_data = {
+        "sub": test_user_id,
+        "email": "test@example.com",
+        "user_metadata": {"name": "Test User"},
+    }
+    await get_or_create_user(user_data)
+    return test_user_id
+
+
+@pytest.fixture
+async def setup_admin_user(admin_user_id):
+    """Create admin user in database before tests."""
+    from backend.data.user import get_or_create_user
+
+    # Create the admin user in the database using JWT token format
+    user_data = {
+        "sub": admin_user_id,
+        "email": "test-admin@example.com",
+        "user_metadata": {"name": "Test Admin"},
+    }
+    await get_or_create_user(user_data)
+    return admin_user_id
+
+
 @pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
 async def graph_cleanup(server):
     created_graph_ids = []
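With the user fixtures now defined once in the backend-level conftest.py, test modules under the backend test tree can request them directly without local copies. A small illustrative test consuming them (assuming pytest-asyncio is active, which the conftest's imports suggest; the test itself is not part of this changeset):

```python
# Illustrative test only; not part of the changeset above.
import pytest


@pytest.mark.asyncio
async def test_setup_test_user_returns_user_id(setup_test_user, test_user_id):
    # setup_test_user ensures the user exists via get_or_create_user and
    # returns the same ID provided by the test_user_id fixture.
    assert setup_test_user == test_user_id
```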
autogpt_platform/backend/backend/copilot/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
@@ -37,12 +37,10 @@ stale pending messages from dead consumers.
 
 import asyncio
 import logging
-import os
 import uuid
 from typing import Any
 
 import orjson
-from prisma import Prisma
 from pydantic import BaseModel
 from redis.exceptions import ResponseError
 

@@ -69,8 +67,8 @@ class OperationCompleteMessage(BaseModel):
 class ChatCompletionConsumer:
     """Consumer for chat operation completion messages from Redis Streams.
 
-    This consumer initializes its own Prisma client in start() to ensure
-    database operations work correctly within this async context.
+    Database operations are handled through the chat_db() accessor, which
+    routes through DatabaseManager RPC when Prisma is not directly connected.
 
     Uses Redis consumer groups to allow multiple platform pods to consume
     messages reliably with automatic redelivery on failure.

@@ -79,7 +77,6 @@ class ChatCompletionConsumer:
     def __init__(self):
         self._consumer_task: asyncio.Task | None = None
         self._running = False
-        self._prisma: Prisma | None = None
         self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
 
     async def start(self) -> None:

@@ -115,15 +112,6 @@ class ChatCompletionConsumer:
             f"Chat completion consumer started (consumer: {self._consumer_name})"
         )
 
-    async def _ensure_prisma(self) -> Prisma:
-        """Lazily initialize Prisma client on first use."""
-        if self._prisma is None:
-            database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
-            self._prisma = Prisma(datasource={"url": database_url})
-            await self._prisma.connect()
-            logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
-        return self._prisma
-
     async def stop(self) -> None:
         """Stop the completion consumer."""
         self._running = False

@@ -136,11 +124,6 @@ class ChatCompletionConsumer:
                 pass
             self._consumer_task = None
 
-        if self._prisma:
-            await self._prisma.disconnect()
-            self._prisma = None
-            logger.info("[COMPLETION] Consumer Prisma client disconnected")
-
         logger.info("Chat completion consumer stopped")
 
     async def _consume_messages(self) -> None:

@@ -252,7 +235,7 @@ class ChatCompletionConsumer:
         # XAUTOCLAIM after min_idle_time expires
 
     async def _handle_message(self, body: bytes) -> None:
-        """Handle a completion message using our own Prisma client."""
+        """Handle a completion message."""
         try:
             data = orjson.loads(body)
             message = OperationCompleteMessage(**data)

@@ -302,8 +285,7 @@ class ChatCompletionConsumer:
         message: OperationCompleteMessage,
     ) -> None:
         """Handle successful operation completion."""
-        prisma = await self._ensure_prisma()
-        await process_operation_success(task, message.result, prisma)
+        await process_operation_success(task, message.result)
 
     async def _handle_failure(
         self,

@@ -311,8 +293,7 @@ class ChatCompletionConsumer:
         message: OperationCompleteMessage,
     ) -> None:
         """Handle failed operation completion."""
-        prisma = await self._ensure_prisma()
-        await process_operation_failure(task, message.error, prisma)
+        await process_operation_failure(task, message.error)
 
 
 # Module-level consumer instance
@@ -9,7 +9,8 @@ import logging
 from typing import Any
 
 import orjson
-from prisma import Prisma
+
+from backend.data.db_accessors import chat_db
 
 from . import service as chat_service
 from . import stream_registry

@@ -72,48 +73,40 @@ async def _update_tool_message(
     session_id: str,
     tool_call_id: str,
     content: str,
-    prisma_client: Prisma | None,
 ) -> None:
-    """Update tool message in database.
+    """Update tool message in database using the chat_db accessor.
+
+    Routes through DatabaseManager RPC when Prisma is not directly
+    connected (e.g. in the CoPilot Executor microservice).
 
     Args:
         session_id: The session ID
         tool_call_id: The tool call ID to update
         content: The new content for the message
-        prisma_client: Optional Prisma client. If None, uses chat_service.
 
     Raises:
-        ToolMessageUpdateError: If the database update fails. The caller should
-        handle this to avoid marking the task as completed with inconsistent state.
+        ToolMessageUpdateError: If the database update fails.
     """
     try:
-        if prisma_client:
-            # Use provided Prisma client (for consumer with its own connection)
-            updated_count = await prisma_client.chatmessage.update_many(
-                where={
-                    "sessionId": session_id,
-                    "toolCallId": tool_call_id,
-                },
-                data={"content": content},
-            )
-            # Check if any rows were updated - 0 means message not found
-            if updated_count == 0:
-                raise ToolMessageUpdateError(
-                    f"No message found with tool_call_id={tool_call_id} in session {session_id}"
-                )
-        else:
-            # Use service function (for webhook endpoint)
-            await chat_service._update_pending_operation(
-                session_id=session_id,
-                tool_call_id=tool_call_id,
-                result=content,
+        updated = await chat_db().update_tool_message_content(
+            session_id=session_id,
+            tool_call_id=tool_call_id,
+            new_content=content,
+        )
+        if not updated:
+            raise ToolMessageUpdateError(
+                f"No message found with tool_call_id="
+                f"{tool_call_id} in session {session_id}"
             )
     except ToolMessageUpdateError:
         raise
     except Exception as e:
-        logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
+        logger.error(
+            f"[COMPLETION] Failed to update tool message: {e}",
+            exc_info=True,
+        )
         raise ToolMessageUpdateError(
-            f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
+            f"Failed to update tool message for tool call #{tool_call_id}: {e}"
         ) from e
 
 

@@ -202,7 +195,6 @@ async def _save_agent_from_result(
 async def process_operation_success(
     task: stream_registry.ActiveTask,
     result: dict | str | None,
-    prisma_client: Prisma | None = None,
 ) -> None:
     """Handle successful operation completion.
 

@@ -212,12 +204,10 @@
     Args:
         task: The active task that completed
         result: The result data from the operation
-        prisma_client: Optional Prisma client for database operations.
-            If None, uses chat_service._update_pending_operation instead.
 
     Raises:
-        ToolMessageUpdateError: If the database update fails. The task will be
-            marked as failed instead of completed to avoid inconsistent state.
+        ToolMessageUpdateError: If the database update fails. The task
+            will be marked as failed instead of completed.
     """
     # For agent generation tools, save the agent to library
     if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):

@@ -250,7 +240,6 @@
             session_id=task.session_id,
             tool_call_id=task.tool_call_id,
             content=result_str,
-            prisma_client=prisma_client,
         )
     except ToolMessageUpdateError:
         # DB update failed - mark task as failed to avoid inconsistent state

@@ -293,18 +282,15 @@
 async def process_operation_failure(
     task: stream_registry.ActiveTask,
     error: str | None,
-    prisma_client: Prisma | None = None,
 ) -> None:
     """Handle failed operation completion.
 
-    Publishes the error to the stream registry, updates the database with
-    the error response, and marks the task as failed.
+    Publishes the error to the stream registry, updates the database
+    with the error response, and marks the task as failed.
 
     Args:
         task: The active task that failed
         error: The error message from the operation
-        prisma_client: Optional Prisma client for database operations.
-            If None, uses chat_service._update_pending_operation instead.
     """
     error_msg = error or "Operation failed"
 

@@ -325,7 +311,6 @@
             session_id=task.session_id,
             tool_call_id=task.tool_call_id,
             content=error_response.model_dump_json(),
-            prisma_client=prisma_client,
         )
     except ToolMessageUpdateError:
         # DB update failed - log but continue with cleanup
@@ -14,29 +14,27 @@ from prisma.types import (
     ChatSessionWhereInput,
 )
 
-from backend.data.db import transaction
+from backend.data import db
 from backend.util.json import SafeJson
 
+from .model import ChatMessage, ChatSession, ChatSessionInfo
+
 logger = logging.getLogger(__name__)
 
 
-async def get_chat_session(session_id: str) -> PrismaChatSession | None:
+async def get_chat_session(session_id: str) -> ChatSession | None:
     """Get a chat session by ID from the database."""
     session = await PrismaChatSession.prisma().find_unique(
         where={"id": session_id},
-        include={"Messages": True},
+        include={"Messages": {"order_by": {"sequence": "asc"}}},
     )
-    if session and session.Messages:
-        # Sort messages by sequence in Python - Prisma Python client doesn't support
-        # order_by in include clauses (unlike Prisma JS), so we sort after fetching
-        session.Messages.sort(key=lambda m: m.sequence)
-    return session
+    return ChatSession.from_db(session) if session else None
 
 
 async def create_chat_session(
     session_id: str,
     user_id: str,
-) -> PrismaChatSession:
+) -> ChatSessionInfo:
     """Create a new chat session in the database."""
     data = ChatSessionCreateInput(
         id=session_id,

@@ -45,7 +43,8 @@ async def create_chat_session(
         successfulAgentRuns=SafeJson({}),
         successfulAgentSchedules=SafeJson({}),
     )
-    return await PrismaChatSession.prisma().create(data=data)
+    prisma_session = await PrismaChatSession.prisma().create(data=data)
+    return ChatSessionInfo.from_db(prisma_session)
 
 
 async def update_chat_session(

@@ -56,7 +55,7 @@ async def update_chat_session(
     total_prompt_tokens: int | None = None,
     total_completion_tokens: int | None = None,
     title: str | None = None,
-) -> PrismaChatSession | None:
+) -> ChatSession | None:
     """Update a chat session's metadata."""
     data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
 

@@ -76,12 +75,9 @@
     session = await PrismaChatSession.prisma().update(
         where={"id": session_id},
         data=data,
-        include={"Messages": True},
+        include={"Messages": {"order_by": {"sequence": "asc"}}},
     )
-    if session and session.Messages:
-        # Sort in Python - Prisma Python doesn't support order_by in include clauses
-        session.Messages.sort(key=lambda m: m.sequence)
-    return session
+    return ChatSession.from_db(session) if session else None
 
 
 async def add_chat_message(

@@ -94,7 +90,7 @@ async def add_chat_message(
     refusal: str | None = None,
     tool_calls: list[dict[str, Any]] | None = None,
     function_call: dict[str, Any] | None = None,
-) -> PrismaChatMessage:
+) -> ChatMessage:
     """Add a message to a chat session."""
     # Build input dict dynamically rather than using ChatMessageCreateInput directly
     # because Prisma's TypedDict validation rejects optional fields set to None.

@@ -129,14 +125,14 @@
         ),
         PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
     )
-    return message
+    return ChatMessage.from_db(message)
 
 
 async def add_chat_messages_batch(
     session_id: str,
     messages: list[dict[str, Any]],
     start_sequence: int,
-) -> list[PrismaChatMessage]:
+) -> list[ChatMessage]:
     """Add multiple messages to a chat session in a batch.
 
     Uses a transaction for atomicity - if any message creation fails,

@@ -147,7 +143,7 @@ async def add_chat_messages_batch(
 
     created_messages = []
 
-    async with transaction() as tx:
+    async with db.transaction() as tx:
         for i, msg in enumerate(messages):
             # Build input dict dynamically rather than using ChatMessageCreateInput
             # directly because Prisma's TypedDict validation rejects optional fields

@@ -187,21 +183,22 @@
         data={"updatedAt": datetime.now(UTC)},
     )
 
-    return created_messages
+    return [ChatMessage.from_db(m) for m in created_messages]
 
 
 async def get_user_chat_sessions(
     user_id: str,
     limit: int = 50,
     offset: int = 0,
-) -> list[PrismaChatSession]:
+) -> list[ChatSessionInfo]:
     """Get chat sessions for a user, ordered by most recent."""
-    return await PrismaChatSession.prisma().find_many(
+    prisma_sessions = await PrismaChatSession.prisma().find_many(
         where={"userId": user_id},
         order={"updatedAt": "desc"},
         take=limit,
         skip=offset,
     )
+    return [ChatSessionInfo.from_db(s) for s in prisma_sessions]
 
 
 async def get_user_session_count(user_id: str) -> int:
@@ -0,0 +1,18 @@
+"""Entry point for running the CoPilot Executor service.
+
+Usage:
+    python -m backend.copilot.executor
+"""
+
+from backend.app import run_processes
+
+from .manager import CoPilotExecutor
+
+
+def main():
+    """Run the CoPilot Executor service."""
+    run_processes(CoPilotExecutor())
+
+
+if __name__ == "__main__":
+    main()
autogpt_platform/backend/backend/copilot/executor/manager.py (new file, 519 lines)

@@ -0,0 +1,519 @@
+"""CoPilot Executor Manager - main service for CoPilot task execution.
+
+This module contains the CoPilotExecutor class that consumes chat tasks from
+RabbitMQ and processes them using a thread pool, following the graph executor pattern.
+"""
+
+import asyncio
+import logging
+import os
+import threading
+import time
+import uuid
+from concurrent.futures import Future, ThreadPoolExecutor
+
+from pika.adapters.blocking_connection import BlockingChannel
+from pika.exceptions import AMQPChannelError, AMQPConnectionError
+from pika.spec import Basic, BasicProperties
+from prometheus_client import Gauge, start_http_server
+
+from backend.data import redis_client as redis
+from backend.data.rabbitmq import SyncRabbitMQ
+from backend.executor.cluster_lock import ClusterLock
+from backend.util.decorator import error_logged
+from backend.util.logging import TruncatedLogger
+from backend.util.process import AppProcess
+from backend.util.retry import continuous_retry
+from backend.util.settings import Settings
+
+from .processor import execute_copilot_task, init_worker
+from .utils import (
+    COPILOT_CANCEL_QUEUE_NAME,
+    COPILOT_EXECUTION_QUEUE_NAME,
+    GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
+    CancelCoPilotEvent,
+    CoPilotExecutionEntry,
+    create_copilot_queue_config,
+)
+
+logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
+settings = Settings()
+
+# Prometheus metrics
+active_tasks_gauge = Gauge(
+    "copilot_executor_active_tasks",
+    "Number of active CoPilot tasks",
+)
+pool_size_gauge = Gauge(
+    "copilot_executor_pool_size",
+    "Maximum number of CoPilot executor workers",
+)
+utilization_gauge = Gauge(
+    "copilot_executor_utilization_ratio",
+    "Ratio of active tasks to pool size",
+)
+
+
+class CoPilotExecutor(AppProcess):
+    """CoPilot Executor service for processing chat generation tasks.
+
+    This service consumes tasks from RabbitMQ, processes them using a thread pool,
+    and publishes results to Redis Streams. It follows the graph executor pattern
+    for reliable message handling and graceful shutdown.
+
+    Key features:
+    - RabbitMQ-based task distribution with manual acknowledgment
+    - Thread pool executor for concurrent task processing
+    - Cluster lock for duplicate prevention across pods
+    - Graceful shutdown with timeout for in-flight tasks
+    - FANOUT exchange for cancellation broadcast
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.pool_size = settings.config.num_copilot_workers
+        self.active_tasks: dict[str, tuple[Future, threading.Event]] = {}
+        self.executor_id = str(uuid.uuid4())
+
+        self._executor = None
+        self._stop_consuming = None
+
+        self._cancel_thread = None
+        self._cancel_client = None
+        self._run_thread = None
+        self._run_client = None
+
+        self._task_locks: dict[str, ClusterLock] = {}
+        self._active_tasks_lock = threading.Lock()
+
+    # ============ Main Entry Points (AppProcess interface) ============ #
+
+    def run(self):
+        """Main service loop - consume from RabbitMQ."""
+        logger.info(f"Pod assigned executor_id: {self.executor_id}")
+        logger.info(f"Spawn max-{self.pool_size} workers...")
+
+        pool_size_gauge.set(self.pool_size)
+        self._update_metrics()
+        start_http_server(settings.config.copilot_executor_port)
+
+        self.cancel_thread.start()
+        self.run_thread.start()
+
+        while True:
+            time.sleep(1e5)
+
+    def cleanup(self):
+        """Graceful shutdown with active execution waiting."""
+        pid = os.getpid()
+        logger.info(f"[cleanup {pid}] Starting graceful shutdown...")
+
+        # Signal the consumer thread to stop
+        try:
+            self.stop_consuming.set()
+            run_channel = self.run_client.get_channel()
+            run_channel.connection.add_callback_threadsafe(
+                lambda: run_channel.stop_consuming()
+            )
+            logger.info(f"[cleanup {pid}] Consumer has been signaled to stop")
+        except Exception as e:
+            logger.error(f"[cleanup {pid}] Error stopping consumer: {e}")
+
+        # Wait for active executions to complete
+        if self.active_tasks:
+            logger.info(
+                f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
+            )
+
+            start_time = time.monotonic()
+            last_refresh = start_time
+            lock_refresh_interval = settings.config.cluster_lock_timeout / 10
+
+            while (
+                self.active_tasks
+                and (time.monotonic() - start_time) < GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS
+            ):
+                self._cleanup_completed_tasks()
+                if not self.active_tasks:
+                    break
+
+                # Refresh cluster locks periodically
+                current_time = time.monotonic()
+                if current_time - last_refresh >= lock_refresh_interval:
+                    for lock in list(self._task_locks.values()):
+                        try:
+                            lock.refresh()
+                        except Exception as e:
+                            logger.warning(
+                                f"[cleanup {pid}] Failed to refresh lock: {e}"
+                            )
+                    last_refresh = current_time
+
+                logger.info(
+                    f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..."
+                )
+                time.sleep(10.0)
+
+        # Stop message consumers
+        if self._run_thread:
+            self._stop_message_consumers(
+                self._run_thread, self.run_client, "[cleanup][run]"
+            )
+        if self._cancel_thread:
+            self._stop_message_consumers(
+                self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
+            )
+
+        # Shutdown executor
+        if self._executor:
+            logger.info(f"[cleanup {pid}] Shutting down executor...")
+            self._executor.shutdown(wait=False)
+
+        # Close async resources (workspace storage aiohttp session, etc.)
+        try:
+            from backend.util.workspace_storage import shutdown_workspace_storage
+
+            loop = asyncio.new_event_loop()
+            loop.run_until_complete(shutdown_workspace_storage())
+            loop.close()
+        except Exception as e:
+            logger.warning(f"[cleanup {pid}] Error closing workspace storage: {e}")
+
+        # Release any remaining locks
+        for task_id, lock in list(self._task_locks.items()):
+            try:
+                lock.release()
+                logger.info(f"[cleanup {pid}] Released lock for {task_id}")
+            except Exception as e:
+                logger.error(
+                    f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
+                )
+
+        logger.info(f"[cleanup {pid}] Graceful shutdown completed")
+
+    # ============ RabbitMQ Consumer Methods ============ #
+
+    @continuous_retry()
+    def _consume_cancel(self):
+        """Consume cancellation messages from FANOUT exchange."""
+        if self.stop_consuming.is_set() and not self.active_tasks:
|
||||||
|
logger.info("Stop reconnecting cancel consumer - service cleaned up")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.cancel_client.is_ready:
|
||||||
|
self.cancel_client.disconnect()
|
||||||
|
self.cancel_client.connect()
|
||||||
|
|
||||||
|
# Check again after connect - shutdown may have been requested
|
||||||
|
if self.stop_consuming.is_set() and not self.active_tasks:
|
||||||
|
logger.info("Stop consuming requested during reconnect - disconnecting")
|
||||||
|
self.cancel_client.disconnect()
|
||||||
|
return
|
||||||
|
|
||||||
|
cancel_channel = self.cancel_client.get_channel()
|
||||||
|
cancel_channel.basic_consume(
|
||||||
|
queue=COPILOT_CANCEL_QUEUE_NAME,
|
||||||
|
on_message_callback=self._handle_cancel_message,
|
||||||
|
auto_ack=True,
|
||||||
|
)
|
||||||
|
logger.info("Starting to consume cancel messages...")
|
||||||
|
cancel_channel.start_consuming()
|
||||||
|
if not self.stop_consuming.is_set() or self.active_tasks:
|
||||||
|
raise RuntimeError("Cancel message consumer stopped unexpectedly")
|
||||||
|
logger.info("Cancel message consumer stopped gracefully")
|
||||||
|
|
||||||
|
@continuous_retry()
|
||||||
|
def _consume_run(self):
|
||||||
|
"""Consume run messages from DIRECT exchange."""
|
||||||
|
if self.stop_consuming.is_set():
|
||||||
|
logger.info("Stop reconnecting run consumer - service cleaned up")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.run_client.is_ready:
|
||||||
|
self.run_client.disconnect()
|
||||||
|
self.run_client.connect()
|
||||||
|
|
||||||
|
# Check again after connect - shutdown may have been requested
|
||||||
|
if self.stop_consuming.is_set():
|
||||||
|
logger.info("Stop consuming requested during reconnect - disconnecting")
|
||||||
|
self.run_client.disconnect()
|
||||||
|
return
|
||||||
|
|
||||||
|
run_channel = self.run_client.get_channel()
|
||||||
|
run_channel.basic_qos(prefetch_count=self.pool_size)
|
||||||
|
|
||||||
|
run_channel.basic_consume(
|
||||||
|
queue=COPILOT_EXECUTION_QUEUE_NAME,
|
||||||
|
on_message_callback=self._handle_run_message,
|
||||||
|
auto_ack=False,
|
||||||
|
consumer_tag="copilot_execution_consumer",
|
||||||
|
)
|
||||||
|
logger.info("Starting to consume run messages...")
|
||||||
|
run_channel.start_consuming()
|
||||||
|
if not self.stop_consuming.is_set():
|
||||||
|
raise RuntimeError("Run message consumer stopped unexpectedly")
|
||||||
|
logger.info("Run message consumer stopped gracefully")
|
||||||
|
|
||||||
|
# ============ Message Handlers ============ #
|
||||||
|
|
||||||
|
@error_logged(swallow=True)
|
||||||
|
def _handle_cancel_message(
|
||||||
|
self,
|
||||||
|
_channel: BlockingChannel,
|
||||||
|
_method: Basic.Deliver,
|
||||||
|
_properties: BasicProperties,
|
||||||
|
body: bytes,
|
||||||
|
):
|
||||||
|
"""Handle cancel message from FANOUT exchange."""
|
||||||
|
request = CancelCoPilotEvent.model_validate_json(body)
|
||||||
|
task_id = request.task_id
|
||||||
|
if not task_id:
|
||||||
|
logger.warning("Cancel message missing 'task_id'")
|
||||||
|
return
|
||||||
|
if task_id not in self.active_tasks:
|
||||||
|
logger.debug(f"Cancel received for {task_id} but not active")
|
||||||
|
return
|
||||||
|
|
||||||
|
_, cancel_event = self.active_tasks[task_id]
|
||||||
|
logger.info(f"Received cancel for {task_id}")
|
||||||
|
if not cancel_event.is_set():
|
||||||
|
cancel_event.set()
|
||||||
|
else:
|
||||||
|
logger.debug(f"Cancel already set for {task_id}")
|
||||||
|
|
||||||
|
def _handle_run_message(
|
||||||
|
self,
|
||||||
|
_channel: BlockingChannel,
|
||||||
|
method: Basic.Deliver,
|
||||||
|
_properties: BasicProperties,
|
||||||
|
body: bytes,
|
||||||
|
):
|
||||||
|
"""Handle run message from DIRECT exchange."""
|
||||||
|
delivery_tag = method.delivery_tag
|
||||||
|
# Capture the channel used at message delivery time to ensure we ack
|
||||||
|
# on the correct channel. Delivery tags are channel-scoped and become
|
||||||
|
# invalid if the channel is recreated after reconnection.
|
||||||
|
delivery_channel = _channel
|
||||||
|
|
||||||
|
def ack_message(reject: bool, requeue: bool):
|
||||||
|
"""Acknowledge or reject the message.
|
||||||
|
|
||||||
|
Uses the channel from the original message delivery. If the channel
|
||||||
|
is no longer open (e.g., after reconnection), logs a warning and
|
||||||
|
skips the ack - RabbitMQ will redeliver the message automatically.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not delivery_channel.is_open:
|
||||||
|
logger.warning(
|
||||||
|
f"Channel closed, cannot ack delivery_tag={delivery_tag}. "
|
||||||
|
"Message will be redelivered by RabbitMQ."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if reject:
|
||||||
|
delivery_channel.connection.add_callback_threadsafe(
|
||||||
|
lambda: delivery_channel.basic_nack(
|
||||||
|
delivery_tag, requeue=requeue
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
delivery_channel.connection.add_callback_threadsafe(
|
||||||
|
lambda: delivery_channel.basic_ack(delivery_tag)
|
||||||
|
)
|
||||||
|
except (AMQPChannelError, AMQPConnectionError) as e:
|
||||||
|
# Channel/connection errors indicate stale delivery tag - don't retry
|
||||||
|
logger.warning(
|
||||||
|
f"Cannot ack delivery_tag={delivery_tag} due to channel/connection "
|
||||||
|
f"error: {e}. Message will be redelivered by RabbitMQ."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Other errors might be transient, but log and skip to avoid blocking
|
||||||
|
logger.error(
|
||||||
|
f"Unexpected error acking delivery_tag={delivery_tag}: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if we're shutting down
|
||||||
|
if self.stop_consuming.is_set():
|
||||||
|
logger.info("Rejecting new task during shutdown")
|
||||||
|
ack_message(reject=True, requeue=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if we can accept more tasks
|
||||||
|
self._cleanup_completed_tasks()
|
||||||
|
if len(self.active_tasks) >= self.pool_size:
|
||||||
|
ack_message(reject=True, requeue=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
entry = CoPilotExecutionEntry.model_validate_json(body)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Could not parse run message: {e}, body={body}")
|
||||||
|
ack_message(reject=True, requeue=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
task_id = entry.task_id
|
||||||
|
|
||||||
|
# Check for local duplicate - task is already running on this executor
|
||||||
|
if task_id in self.active_tasks:
|
||||||
|
logger.warning(
|
||||||
|
f"Task {task_id} already running locally, rejecting duplicate"
|
||||||
|
)
|
||||||
|
ack_message(reject=True, requeue=False)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Try to acquire cluster-wide lock
|
||||||
|
cluster_lock = ClusterLock(
|
||||||
|
redis=redis.get_redis(),
|
||||||
|
key=f"copilot:task:{task_id}:lock",
|
||||||
|
owner_id=self.executor_id,
|
||||||
|
timeout=settings.config.cluster_lock_timeout,
|
||||||
|
)
|
||||||
|
current_owner = cluster_lock.try_acquire()
|
||||||
|
if current_owner != self.executor_id:
|
||||||
|
if current_owner is not None:
|
||||||
|
logger.warning(f"Task {task_id} already running on pod {current_owner}")
|
||||||
|
ack_message(reject=True, requeue=False)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Could not acquire lock for {task_id} - Redis unavailable"
|
||||||
|
)
|
||||||
|
ack_message(reject=True, requeue=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Execute the task
|
||||||
|
try:
|
||||||
|
self._task_locks[task_id] = cluster_lock
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
cancel_event = threading.Event()
|
||||||
|
future = self.executor.submit(
|
||||||
|
execute_copilot_task, entry, cancel_event, cluster_lock
|
||||||
|
)
|
||||||
|
self.active_tasks[task_id] = (future, cancel_event)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to setup execution for {task_id}: {e}")
|
||||||
|
cluster_lock.release()
|
||||||
|
if task_id in self._task_locks:
|
||||||
|
del self._task_locks[task_id]
|
||||||
|
ack_message(reject=True, requeue=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._update_metrics()
|
||||||
|
|
||||||
|
def on_run_done(f: Future):
|
||||||
|
logger.info(f"Run completed for {task_id}")
|
||||||
|
try:
|
||||||
|
if exec_error := f.exception():
|
||||||
|
logger.error(f"Execution for {task_id} failed: {exec_error}")
|
||||||
|
# Don't requeue failed tasks - they've been marked as failed
|
||||||
|
# in the stream registry. Requeuing would cause infinite retries
|
||||||
|
# for deterministic failures.
|
||||||
|
ack_message(reject=True, requeue=False)
|
||||||
|
else:
|
||||||
|
ack_message(reject=False, requeue=False)
|
||||||
|
except BaseException as e:
|
||||||
|
logger.exception(f"Error in run completion callback: {e}")
|
||||||
|
finally:
|
||||||
|
# Release the cluster lock
|
||||||
|
if task_id in self._task_locks:
|
||||||
|
logger.info(f"Releasing cluster lock for {task_id}")
|
||||||
|
self._task_locks[task_id].release()
|
||||||
|
del self._task_locks[task_id]
|
||||||
|
self._cleanup_completed_tasks()
|
||||||
|
|
||||||
|
future.add_done_callback(on_run_done)
|
||||||
|
|
||||||
|
# ============ Helper Methods ============ #
|
||||||
|
|
||||||
|
def _cleanup_completed_tasks(self) -> list[str]:
|
||||||
|
"""Remove completed futures from active_tasks and update metrics."""
|
||||||
|
completed_tasks = []
|
||||||
|
with self._active_tasks_lock:
|
||||||
|
for task_id, (future, _) in list(self.active_tasks.items()):
|
||||||
|
if future.done():
|
||||||
|
completed_tasks.append(task_id)
|
||||||
|
self.active_tasks.pop(task_id, None)
|
||||||
|
logger.info(f"Cleaned up completed task {task_id}")
|
||||||
|
|
||||||
|
self._update_metrics()
|
||||||
|
return completed_tasks
|
||||||
|
|
||||||
|
def _update_metrics(self):
|
||||||
|
"""Update Prometheus metrics."""
|
||||||
|
active_count = len(self.active_tasks)
|
||||||
|
active_tasks_gauge.set(active_count)
|
||||||
|
if self.stop_consuming.is_set():
|
||||||
|
utilization_gauge.set(1.0)
|
||||||
|
else:
|
||||||
|
utilization_gauge.set(
|
||||||
|
active_count / self.pool_size if self.pool_size > 0 else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
def _stop_message_consumers(
|
||||||
|
self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str
|
||||||
|
):
|
||||||
|
"""Stop a message consumer thread."""
|
||||||
|
try:
|
||||||
|
channel = client.get_channel()
|
||||||
|
channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())
|
||||||
|
|
||||||
|
thread.join(timeout=300)
|
||||||
|
if thread.is_alive():
|
||||||
|
logger.error(
|
||||||
|
f"{prefix} Thread did not finish in time, forcing disconnect"
|
||||||
|
)
|
||||||
|
|
||||||
|
client.disconnect()
|
||||||
|
logger.info(f"{prefix} Client disconnected")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{prefix} Error disconnecting client: {e}")
|
||||||
|
|
||||||
|
# ============ Lazy-initialized Properties ============ #
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cancel_thread(self) -> threading.Thread:
|
||||||
|
if self._cancel_thread is None:
|
||||||
|
self._cancel_thread = threading.Thread(
|
||||||
|
target=lambda: self._consume_cancel(),
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
return self._cancel_thread
|
||||||
|
|
||||||
|
@property
|
||||||
|
def run_thread(self) -> threading.Thread:
|
||||||
|
if self._run_thread is None:
|
||||||
|
self._run_thread = threading.Thread(
|
||||||
|
target=lambda: self._consume_run(),
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
return self._run_thread
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stop_consuming(self) -> threading.Event:
|
||||||
|
if self._stop_consuming is None:
|
||||||
|
self._stop_consuming = threading.Event()
|
||||||
|
return self._stop_consuming
|
||||||
|
|
||||||
|
@property
|
||||||
|
def executor(self) -> ThreadPoolExecutor:
|
||||||
|
if self._executor is None:
|
||||||
|
self._executor = ThreadPoolExecutor(
|
||||||
|
max_workers=self.pool_size,
|
||||||
|
initializer=init_worker,
|
||||||
|
)
|
||||||
|
return self._executor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cancel_client(self) -> SyncRabbitMQ:
|
||||||
|
if self._cancel_client is None:
|
||||||
|
self._cancel_client = SyncRabbitMQ(create_copilot_queue_config())
|
||||||
|
return self._cancel_client
|
||||||
|
|
||||||
|
@property
|
||||||
|
def run_client(self) -> SyncRabbitMQ:
|
||||||
|
if self._run_client is None:
|
||||||
|
self._run_client = SyncRabbitMQ(create_copilot_queue_config())
|
||||||
|
return self._run_client
|
||||||
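Worth noting how the run consumer above bounds concurrency: basic_qos(prefetch_count=self.pool_size) plus manual acknowledgments means RabbitMQ never hands a pod more unacked tasks than it has worker threads, and a nack with requeue=True hands a task to another pod. A minimal sketch of that pattern with plain pika (illustrative only; the queue name is taken from the PR, connection parameters and the handle() function are placeholders):

    import pika

    def handle(body: bytes) -> None:
        """Placeholder for the real work (the service submits to a thread pool)."""
        print(f"processing {len(body)} bytes")

    def consume_with_bounded_concurrency(pool_size: int) -> None:
        connection = pika.BlockingConnection(pika.ConnectionParameters("localhost"))
        channel = connection.channel()
        # Never deliver more unacked messages than this consumer has workers.
        channel.basic_qos(prefetch_count=pool_size)

        def on_message(ch, method, _properties, body):
            try:
                handle(body)
                ch.basic_ack(method.delivery_tag)
            except Exception:
                # Requeue so another pod can pick the task up.
                ch.basic_nack(method.delivery_tag, requeue=True)

        channel.basic_consume(
            queue="copilot_execution_queue",
            on_message_callback=on_message,
            auto_ack=False,
        )
        channel.start_consuming()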
253  autogpt_platform/backend/backend/copilot/executor/processor.py  (new file)
@@ -0,0 +1,253 @@
"""CoPilot execution processor - per-worker execution logic.
|
||||||
|
|
||||||
|
This module contains the processor class that handles CoPilot task execution
|
||||||
|
in a thread-local context, following the graph executor pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
from backend.copilot import service as copilot_service
|
||||||
|
from backend.copilot import stream_registry
|
||||||
|
from backend.copilot.config import ChatConfig
|
||||||
|
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
|
||||||
|
from backend.copilot.sdk import service as sdk_service
|
||||||
|
from backend.executor.cluster_lock import ClusterLock
|
||||||
|
from backend.util.decorator import error_logged
|
||||||
|
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||||
|
from backend.util.logging import TruncatedLogger, configure_logging
|
||||||
|
from backend.util.process import set_service_name
|
||||||
|
from backend.util.retry import func_retry
|
||||||
|
|
||||||
|
from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
|
||||||
|
|
||||||
|
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Module Entry Points ============ #
|
||||||
|
|
||||||
|
# Thread-local storage for processor instances
|
||||||
|
_tls = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
def execute_copilot_task(
|
||||||
|
entry: CoPilotExecutionEntry,
|
||||||
|
cancel: threading.Event,
|
||||||
|
cluster_lock: ClusterLock,
|
||||||
|
):
|
||||||
|
"""Execute a CoPilot task using the thread-local processor.
|
||||||
|
|
||||||
|
This function is the entry point called by the thread pool executor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entry: The task payload
|
||||||
|
cancel: Threading event to signal cancellation
|
||||||
|
cluster_lock: Distributed lock for this execution
|
||||||
|
"""
|
||||||
|
processor: CoPilotProcessor = _tls.processor
|
||||||
|
return processor.execute(entry, cancel, cluster_lock)
|
||||||
|
|
||||||
|
|
||||||
|
def init_worker():
|
||||||
|
"""Initialize the processor for the current worker thread.
|
||||||
|
|
||||||
|
This function is called by the thread pool executor when a new worker
|
||||||
|
thread is created. It ensures each worker has its own processor instance.
|
||||||
|
"""
|
||||||
|
_tls.processor = CoPilotProcessor()
|
||||||
|
_tls.processor.on_executor_start()
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Processor Class ============ #
|
||||||
|
|
||||||
|
|
||||||
|
class CoPilotProcessor:
|
||||||
|
"""Per-worker execution logic for CoPilot tasks.
|
||||||
|
|
||||||
|
This class is instantiated once per worker thread and handles the execution
|
||||||
|
of CoPilot chat generation tasks. It maintains an async event loop for
|
||||||
|
running the async service code.
|
||||||
|
|
||||||
|
The execution flow:
|
||||||
|
1. CoPilot task is picked from RabbitMQ queue
|
||||||
|
2. Manager submits task to thread pool
|
||||||
|
3. Processor executes the task in its event loop
|
||||||
|
4. Results are published to Redis Streams
|
||||||
|
"""
|
||||||
|
|
||||||
|
@func_retry
|
||||||
|
def on_executor_start(self):
|
||||||
|
"""Initialize the processor when the worker thread starts.
|
||||||
|
|
||||||
|
This method is called once per worker thread to set up the async event
|
||||||
|
loop and initialize any required resources.
|
||||||
|
|
||||||
|
Database is accessed only through DatabaseManager, so we don't need to connect
|
||||||
|
to Prisma directly.
|
||||||
|
"""
|
||||||
|
configure_logging()
|
||||||
|
set_service_name("CoPilotExecutor")
|
||||||
|
self.tid = threading.get_ident()
|
||||||
|
self.execution_loop = asyncio.new_event_loop()
|
||||||
|
self.execution_thread = threading.Thread(
|
||||||
|
target=self.execution_loop.run_forever, daemon=True
|
||||||
|
)
|
||||||
|
self.execution_thread.start()
|
||||||
|
|
||||||
|
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
|
||||||
|
|
||||||
|
@error_logged(swallow=False)
|
||||||
|
def execute(
|
||||||
|
self,
|
||||||
|
entry: CoPilotExecutionEntry,
|
||||||
|
cancel: threading.Event,
|
||||||
|
cluster_lock: ClusterLock,
|
||||||
|
):
|
||||||
|
"""Execute a CoPilot task.
|
||||||
|
|
||||||
|
This is the main entry point for task execution. It runs the async
|
||||||
|
execution logic in the worker's event loop and handles errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entry: The task payload containing session and message info
|
||||||
|
cancel: Threading event to signal cancellation
|
||||||
|
cluster_lock: Distributed lock to prevent duplicate execution
|
||||||
|
"""
|
||||||
|
log = CoPilotLogMetadata(
|
||||||
|
logging.getLogger(__name__),
|
||||||
|
task_id=entry.task_id,
|
||||||
|
session_id=entry.session_id,
|
||||||
|
user_id=entry.user_id,
|
||||||
|
)
|
||||||
|
log.info("Starting execution")
|
||||||
|
|
||||||
|
start_time = time.monotonic()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Run the async execution in our event loop
|
||||||
|
future = asyncio.run_coroutine_threadsafe(
|
||||||
|
self._execute_async(entry, cancel, cluster_lock, log),
|
||||||
|
self.execution_loop,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Wait for completion, checking cancel periodically
|
||||||
|
while not future.done():
|
||||||
|
try:
|
||||||
|
future.result(timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
if cancel.is_set():
|
||||||
|
log.info("Cancellation requested")
|
||||||
|
future.cancel()
|
||||||
|
break
|
||||||
|
# Refresh cluster lock to maintain ownership
|
||||||
|
cluster_lock.refresh()
|
||||||
|
|
||||||
|
if not future.cancelled():
|
||||||
|
# Get result to propagate any exceptions
|
||||||
|
future.result()
|
||||||
|
|
||||||
|
elapsed = time.monotonic() - start_time
|
||||||
|
log.info(f"Execution completed in {elapsed:.2f}s")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
elapsed = time.monotonic() - start_time
|
||||||
|
log.error(f"Execution failed after {elapsed:.2f}s: {e}")
|
||||||
|
# Note: _execute_async already marks the task as failed before re-raising,
|
||||||
|
# so we don't call _mark_task_failed here to avoid duplicate error events.
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _execute_async(
|
||||||
|
self,
|
||||||
|
entry: CoPilotExecutionEntry,
|
||||||
|
cancel: threading.Event,
|
||||||
|
cluster_lock: ClusterLock,
|
||||||
|
log: CoPilotLogMetadata,
|
||||||
|
):
|
||||||
|
"""Async execution logic for CoPilot task.
|
||||||
|
|
||||||
|
This method calls the existing stream_chat_completion service function
|
||||||
|
and publishes results to the stream registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entry: The task payload
|
||||||
|
cancel: Threading event to signal cancellation
|
||||||
|
cluster_lock: Distributed lock for refresh
|
||||||
|
log: Structured logger for this task
|
||||||
|
"""
|
||||||
|
last_refresh = time.monotonic()
|
||||||
|
refresh_interval = 30.0 # Refresh lock every 30 seconds
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Choose service based on LaunchDarkly flag
|
||||||
|
config = ChatConfig()
|
||||||
|
use_sdk = await is_feature_enabled(
|
||||||
|
Flag.COPILOT_SDK,
|
||||||
|
entry.user_id or "anonymous",
|
||||||
|
default=config.use_claude_agent_sdk,
|
||||||
|
)
|
||||||
|
stream_fn = (
|
||||||
|
sdk_service.stream_chat_completion_sdk
|
||||||
|
if use_sdk
|
||||||
|
else copilot_service.stream_chat_completion
|
||||||
|
)
|
||||||
|
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
|
||||||
|
|
||||||
|
# Stream chat completion and publish chunks to Redis
|
||||||
|
async for chunk in stream_fn(
|
||||||
|
session_id=entry.session_id,
|
||||||
|
message=entry.message if entry.message else None,
|
||||||
|
is_user_message=entry.is_user_message,
|
||||||
|
user_id=entry.user_id,
|
||||||
|
context=entry.context,
|
||||||
|
):
|
||||||
|
# Check for cancellation
|
||||||
|
if cancel.is_set():
|
||||||
|
log.info("Cancelled during streaming")
|
||||||
|
await stream_registry.publish_chunk(
|
||||||
|
entry.task_id, StreamError(errorText="Operation cancelled")
|
||||||
|
)
|
||||||
|
await stream_registry.publish_chunk(
|
||||||
|
entry.task_id, StreamFinishStep()
|
||||||
|
)
|
||||||
|
await stream_registry.publish_chunk(entry.task_id, StreamFinish())
|
||||||
|
await stream_registry.mark_task_completed(
|
||||||
|
entry.task_id, status="failed"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Refresh cluster lock periodically
|
||||||
|
current_time = time.monotonic()
|
||||||
|
if current_time - last_refresh >= refresh_interval:
|
||||||
|
cluster_lock.refresh()
|
||||||
|
last_refresh = current_time
|
||||||
|
|
||||||
|
# Publish chunk to stream registry
|
||||||
|
await stream_registry.publish_chunk(entry.task_id, chunk)
|
||||||
|
|
||||||
|
# Mark task as completed
|
||||||
|
await stream_registry.mark_task_completed(entry.task_id, status="completed")
|
||||||
|
log.info("Task completed successfully")
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
log.info("Task cancelled")
|
||||||
|
await stream_registry.mark_task_completed(entry.task_id, status="failed")
|
||||||
|
raise
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
log.error(f"Task failed: {e}")
|
||||||
|
await self._mark_task_failed(entry.task_id, str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _mark_task_failed(self, task_id: str, error_message: str):
|
||||||
|
"""Mark a task as failed and publish error to stream registry."""
|
||||||
|
try:
|
||||||
|
await stream_registry.publish_chunk(
|
||||||
|
task_id, StreamError(errorText=error_message)
|
||||||
|
)
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinishStep())
|
||||||
|
await stream_registry.publish_chunk(task_id, StreamFinish())
|
||||||
|
await stream_registry.mark_task_completed(task_id, status="failed")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to mark task {task_id} as failed: {e}")
|
||||||
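The processor keeps one long-lived asyncio loop per worker thread and polls the cross-thread future in one-second slices so it can honor cancellation and refresh the cluster lock. A stdlib-only sketch of that polling pattern (illustrative; do_work stands in for _execute_async, and the lock refresh is reduced to a comment):

    import asyncio
    import concurrent.futures
    import threading

    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def do_work() -> str:
        await asyncio.sleep(3)  # stands in for the real streaming coroutine
        return "done"

    def run_with_cancel(cancel: threading.Event) -> str | None:
        future = asyncio.run_coroutine_threadsafe(do_work(), loop)
        while not future.done():
            try:
                return future.result(timeout=1.0)  # poll in 1s slices
            except concurrent.futures.TimeoutError:
                if cancel.is_set():
                    future.cancel()  # stop waiting; the coroutine is cancelled
                    return None
                # a real worker refreshes its cluster lock here
        return future.result()

    print(run_with_cancel(threading.Event()))  # prints "done" after ~3s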
207  autogpt_platform/backend/backend/copilot/executor/utils.py  (new file)
@@ -0,0 +1,207 @@
"""RabbitMQ queue configuration for CoPilot executor.
|
||||||
|
|
||||||
|
Defines two exchanges and queues following the graph executor pattern:
|
||||||
|
- 'copilot_execution' (DIRECT) for chat generation tasks
|
||||||
|
- 'copilot_cancel' (FANOUT) for cancellation requests
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||||
|
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Logging Helper ============ #
|
||||||
|
|
||||||
|
|
||||||
|
class CoPilotLogMetadata(TruncatedLogger):
|
||||||
|
"""Structured logging helper for CoPilot executor.
|
||||||
|
|
||||||
|
In cloud environments (structured logging enabled), uses a simple prefix
|
||||||
|
and passes metadata via json_fields. In local environments, uses a detailed
|
||||||
|
prefix with all metadata key-value pairs for easier debugging.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logger: The underlying logger instance
|
||||||
|
max_length: Maximum log message length before truncation
|
||||||
|
**kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
|
||||||
|
These are added to json_fields in cloud mode, or to the prefix in local mode.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
logger: logging.Logger,
|
||||||
|
max_length: int = 1000,
|
||||||
|
**kwargs: str | None,
|
||||||
|
):
|
||||||
|
# Filter out None values
|
||||||
|
metadata = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
metadata["component"] = "CoPilotExecutor"
|
||||||
|
|
||||||
|
if is_structured_logging_enabled():
|
||||||
|
prefix = "[CoPilotExecutor]"
|
||||||
|
else:
|
||||||
|
# Build prefix from metadata key-value pairs
|
||||||
|
meta_parts = "|".join(
|
||||||
|
f"{k}:{v}" for k, v in metadata.items() if k != "component"
|
||||||
|
)
|
||||||
|
prefix = (
|
||||||
|
f"[CoPilotExecutor|{meta_parts}]" if meta_parts else "[CoPilotExecutor]"
|
||||||
|
)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
logger,
|
||||||
|
max_length=max_length,
|
||||||
|
prefix=prefix,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Exchange and Queue Configuration ============ #
|
||||||
|
|
||||||
|
COPILOT_EXECUTION_EXCHANGE = Exchange(
|
||||||
|
name="copilot_execution",
|
||||||
|
type=ExchangeType.DIRECT,
|
||||||
|
durable=True,
|
||||||
|
auto_delete=False,
|
||||||
|
)
|
||||||
|
COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue"
|
||||||
|
COPILOT_EXECUTION_ROUTING_KEY = "copilot.run"
|
||||||
|
|
||||||
|
COPILOT_CANCEL_EXCHANGE = Exchange(
|
||||||
|
name="copilot_cancel",
|
||||||
|
type=ExchangeType.FANOUT,
|
||||||
|
durable=True,
|
||||||
|
auto_delete=False,
|
||||||
|
)
|
||||||
|
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"
|
||||||
|
|
||||||
|
# CoPilot operations can include extended thinking and agent generation
|
||||||
|
# which may take 30+ minutes to complete
|
||||||
|
COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour
|
||||||
|
|
||||||
|
# Graceful shutdown timeout - allow in-flight operations to complete
|
||||||
|
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60 # 30 minutes
|
||||||
|
|
||||||
|
|
||||||
|
def create_copilot_queue_config() -> RabbitMQConfig:
|
||||||
|
"""Create RabbitMQ configuration for CoPilot executor.
|
||||||
|
|
||||||
|
Defines two exchanges and queues:
|
||||||
|
- 'copilot_execution' (DIRECT) for chat generation tasks
|
||||||
|
- 'copilot_cancel' (FANOUT) for cancellation requests
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RabbitMQConfig with exchanges and queues defined
|
||||||
|
"""
|
||||||
|
run_queue = Queue(
|
||||||
|
name=COPILOT_EXECUTION_QUEUE_NAME,
|
||||||
|
exchange=COPILOT_EXECUTION_EXCHANGE,
|
||||||
|
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
|
||||||
|
durable=True,
|
||||||
|
auto_delete=False,
|
||||||
|
arguments={
|
||||||
|
# Extended consumer timeout for long-running LLM operations
|
||||||
|
# Default 30-minute timeout is insufficient for extended thinking
|
||||||
|
# and agent generation which can take 30+ minutes
|
||||||
|
"x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS
|
||||||
|
* 1000,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
cancel_queue = Queue(
|
||||||
|
name=COPILOT_CANCEL_QUEUE_NAME,
|
||||||
|
exchange=COPILOT_CANCEL_EXCHANGE,
|
||||||
|
routing_key="", # not used for FANOUT
|
||||||
|
durable=True,
|
||||||
|
auto_delete=False,
|
||||||
|
)
|
||||||
|
return RabbitMQConfig(
|
||||||
|
vhost="/",
|
||||||
|
exchanges=[COPILOT_EXECUTION_EXCHANGE, COPILOT_CANCEL_EXCHANGE],
|
||||||
|
queues=[run_queue, cancel_queue],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Message Models ============ #
|
||||||
|
|
||||||
|
|
||||||
|
class CoPilotExecutionEntry(BaseModel):
|
||||||
|
"""Task payload for CoPilot AI generation.
|
||||||
|
|
||||||
|
This model represents a chat generation task to be processed by the executor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task_id: str
|
||||||
|
"""Unique identifier for this task (used for stream registry)"""
|
||||||
|
|
||||||
|
session_id: str
|
||||||
|
"""Chat session ID"""
|
||||||
|
|
||||||
|
user_id: str | None
|
||||||
|
"""User ID (may be None for anonymous users)"""
|
||||||
|
|
||||||
|
operation_id: str
|
||||||
|
"""Operation ID for webhook callbacks and completion tracking"""
|
||||||
|
|
||||||
|
message: str
|
||||||
|
"""User's message to process"""
|
||||||
|
|
||||||
|
is_user_message: bool = True
|
||||||
|
"""Whether the message is from the user (vs system/assistant)"""
|
||||||
|
|
||||||
|
context: dict[str, str] | None = None
|
||||||
|
"""Optional context for the message (e.g., {url: str, content: str})"""
|
||||||
|
|
||||||
|
|
||||||
|
class CancelCoPilotEvent(BaseModel):
|
||||||
|
"""Event to cancel a CoPilot operation."""
|
||||||
|
|
||||||
|
task_id: str
|
||||||
|
"""Task ID to cancel"""
|
||||||
|
|
||||||
|
|
||||||
|
# ============ Queue Publishing Helpers ============ #
|
||||||
|
|
||||||
|
|
||||||
|
async def enqueue_copilot_task(
|
||||||
|
task_id: str,
|
||||||
|
session_id: str,
|
||||||
|
user_id: str | None,
|
||||||
|
operation_id: str,
|
||||||
|
message: str,
|
||||||
|
is_user_message: bool = True,
|
||||||
|
context: dict[str, str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Enqueue a CoPilot task for processing by the executor service.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: Unique identifier for this task (used for stream registry)
|
||||||
|
session_id: Chat session ID
|
||||||
|
user_id: User ID (may be None for anonymous users)
|
||||||
|
operation_id: Operation ID for webhook callbacks and completion tracking
|
||||||
|
message: User's message to process
|
||||||
|
is_user_message: Whether the message is from the user (vs system/assistant)
|
||||||
|
context: Optional context for the message (e.g., {url: str, content: str})
|
||||||
|
"""
|
||||||
|
from backend.util.clients import get_async_copilot_queue
|
||||||
|
|
||||||
|
entry = CoPilotExecutionEntry(
|
||||||
|
task_id=task_id,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
operation_id=operation_id,
|
||||||
|
message=message,
|
||||||
|
is_user_message=is_user_message,
|
||||||
|
context=context,
|
||||||
|
)
|
||||||
|
|
||||||
|
queue_client = await get_async_copilot_queue()
|
||||||
|
await queue_client.publish_message(
|
||||||
|
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
|
||||||
|
message=entry.model_dump_json(),
|
||||||
|
exchange=COPILOT_EXECUTION_EXCHANGE,
|
||||||
|
)
|
||||||
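For producers, enqueue_copilot_task is the whole publishing surface: build a CoPilotExecutionEntry, publish it under the copilot.run routing key, and the executor above picks it up. A hedged usage sketch (assumes it runs inside the backend where an event loop is available; the import path is inferred from the file location and the operation_id value is a placeholder):

    import uuid

    from backend.copilot.executor.utils import enqueue_copilot_task

    async def start_chat_turn(session_id: str, user_id: str | None, message: str) -> str:
        task_id = str(uuid.uuid4())
        await enqueue_copilot_task(
            task_id=task_id,
            session_id=session_id,
            user_id=user_id,
            operation_id=str(uuid.uuid4()),
            message=message,
        )
        # Callers can then follow progress in the stream registry under task_id.
        return task_id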
@@ -2,7 +2,7 @@ import asyncio
 import logging
 import uuid
 from datetime import UTC, datetime
-from typing import Any, cast
+from typing import Any, Self, cast
 from weakref import WeakValueDictionary

 from openai.types.chat import (
@@ -23,26 +23,17 @@ from prisma.models import ChatMessage as PrismaChatMessage
 from prisma.models import ChatSession as PrismaChatSession
 from pydantic import BaseModel

+from backend.data.db_accessors import chat_db
 from backend.data.redis_client import get_redis_async
 from backend.util import json
 from backend.util.exceptions import DatabaseError, RedisError

-from . import db as chat_db
 from .config import ChatConfig

 logger = logging.getLogger(__name__)
 config = ChatConfig()


-def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
-    """Parse a JSON field that may be stored as string or already parsed."""
-    if value is None:
-        return default
-    if isinstance(value, str):
-        return json.loads(value)
-    return value
-
-
 # Redis cache key prefix for chat sessions
 CHAT_SESSION_CACHE_PREFIX = "chat:session:"

@@ -52,28 +43,7 @@ def _get_session_cache_key(session_id: str) -> str:
     return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"


-# Session-level locks to prevent race conditions during concurrent upserts.
-# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
-# preventing unbounded memory growth while maintaining lock semantics for active sessions.
-# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
-# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
-_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
-_session_locks_mutex = asyncio.Lock()
-
-
-async def _get_session_lock(session_id: str) -> asyncio.Lock:
-    """Get or create a lock for a specific session to prevent concurrent upserts.
-
-    Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
-    when no coroutine holds a reference to them, preventing memory leaks from
-    unbounded growth of session locks.
-    """
-    async with _session_locks_mutex:
-        lock = _session_locks.get(session_id)
-        if lock is None:
-            lock = asyncio.Lock()
-            _session_locks[session_id] = lock
-        return lock
+# ===================== Chat data models ===================== #


 class ChatMessage(BaseModel):
@@ -85,6 +55,19 @@ class ChatMessage(BaseModel):
     tool_calls: list[dict] | None = None
     function_call: dict | None = None

+    @staticmethod
+    def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
+        """Convert a Prisma ChatMessage to a Pydantic ChatMessage."""
+        return ChatMessage(
+            role=prisma_message.role,
+            content=prisma_message.content,
+            name=prisma_message.name,
+            tool_call_id=prisma_message.toolCallId,
+            refusal=prisma_message.refusal,
+            tool_calls=_parse_json_field(prisma_message.toolCalls),
+            function_call=_parse_json_field(prisma_message.functionCall),
+        )
+

 class Usage(BaseModel):
     prompt_tokens: int
@@ -92,11 +75,10 @@ class Usage(BaseModel):
     total_tokens: int


-class ChatSession(BaseModel):
+class ChatSessionInfo(BaseModel):
     session_id: str
     user_id: str
     title: str | None = None
-    messages: list[ChatMessage]
     usage: list[Usage]
     credentials: dict[str, dict] = {}  # Map of provider -> credential metadata
     started_at: datetime
@@ -104,60 +86,9 @@ class ChatSession(BaseModel):
     successful_agent_runs: dict[str, int] = {}
     successful_agent_schedules: dict[str, int] = {}

-    def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
-        """Attach a tool_call to the current turn's assistant message.
-
-        Searches backwards for the most recent assistant message (stopping at
-        any user message boundary). If found, appends the tool_call to it.
-        Otherwise creates a new assistant message with the tool_call.
-        """
-        for msg in reversed(self.messages):
-            if msg.role == "user":
-                break
-            if msg.role == "assistant":
-                if not msg.tool_calls:
-                    msg.tool_calls = []
-                msg.tool_calls.append(tool_call)
-                return
-
-        self.messages.append(
-            ChatMessage(role="assistant", content="", tool_calls=[tool_call])
-        )
-
-    @staticmethod
-    def new(user_id: str) -> "ChatSession":
-        return ChatSession(
-            session_id=str(uuid.uuid4()),
-            user_id=user_id,
-            title=None,
-            messages=[],
-            usage=[],
-            credentials={},
-            started_at=datetime.now(UTC),
-            updated_at=datetime.now(UTC),
-        )
-
-    @staticmethod
-    def from_db(
-        prisma_session: PrismaChatSession,
-        prisma_messages: list[PrismaChatMessage] | None = None,
-    ) -> "ChatSession":
-        """Convert Prisma models to Pydantic ChatSession."""
-        messages = []
-        if prisma_messages:
-            for msg in prisma_messages:
-                messages.append(
-                    ChatMessage(
-                        role=msg.role,
-                        content=msg.content,
-                        name=msg.name,
-                        tool_call_id=msg.toolCallId,
-                        refusal=msg.refusal,
-                        tool_calls=_parse_json_field(msg.toolCalls),
-                        function_call=_parse_json_field(msg.functionCall),
-                    )
-                )
-
+    @classmethod
+    def from_db(cls, prisma_session: PrismaChatSession) -> Self:
+        """Convert Prisma ChatSession to Pydantic ChatSession."""
         # Parse JSON fields from Prisma
         credentials = _parse_json_field(prisma_session.credentials, default={})
         successful_agent_runs = _parse_json_field(
@@ -179,11 +110,10 @@ class ChatSession(BaseModel):
             )
         )

-        return ChatSession(
+        return cls(
             session_id=prisma_session.id,
             user_id=prisma_session.userId,
             title=prisma_session.title,
-            messages=messages,
             usage=usage,
             credentials=credentials,
             started_at=prisma_session.createdAt,
@@ -192,47 +122,56 @@ class ChatSession(BaseModel):
             successful_agent_schedules=successful_agent_schedules,
         )

-    @staticmethod
-    def _merge_consecutive_assistant_messages(
-        messages: list[ChatCompletionMessageParam],
-    ) -> list[ChatCompletionMessageParam]:
-        """Merge consecutive assistant messages into single messages.
-
-        Long-running tool flows can create split assistant messages: one with
-        text content and another with tool_calls. Anthropic's API requires
-        tool_result blocks to reference a tool_use in the immediately preceding
-        assistant message, so these splits cause 400 errors via OpenRouter.
+
+class ChatSession(ChatSessionInfo):
+    messages: list[ChatMessage]
+
+    @classmethod
+    def new(cls, user_id: str) -> Self:
+        return cls(
+            session_id=str(uuid.uuid4()),
+            user_id=user_id,
+            title=None,
+            messages=[],
+            usage=[],
+            credentials={},
+            started_at=datetime.now(UTC),
+            updated_at=datetime.now(UTC),
+        )
+
+    @classmethod
+    def from_db(cls, prisma_session: PrismaChatSession) -> Self:
+        """Convert Prisma ChatSession to Pydantic ChatSession."""
+        if prisma_session.Messages is None:
+            raise ValueError(
+                f"Prisma session {prisma_session.id} is missing Messages relation"
+            )
+
+        return cls(
+            **ChatSessionInfo.from_db(prisma_session).model_dump(),
+            messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
+        )
+
+    def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
+        """Attach a tool_call to the current turn's assistant message.
+
+        Searches backwards for the most recent assistant message (stopping at
+        any user message boundary). If found, appends the tool_call to it.
+        Otherwise creates a new assistant message with the tool_call.
         """
-        if len(messages) < 2:
-            return messages
+        for msg in reversed(self.messages):
+            if msg.role == "user":
+                break
+            if msg.role == "assistant":
+                if not msg.tool_calls:
+                    msg.tool_calls = []
+                msg.tool_calls.append(tool_call)
+                return

-        result: list[ChatCompletionMessageParam] = [messages[0]]
-        for msg in messages[1:]:
-            prev = result[-1]
-            if prev.get("role") != "assistant" or msg.get("role") != "assistant":
-                result.append(msg)
-                continue
-
-            prev = cast(ChatCompletionAssistantMessageParam, prev)
-            curr = cast(ChatCompletionAssistantMessageParam, msg)
-
-            curr_content = curr.get("content") or ""
-            if curr_content:
-                prev_content = prev.get("content") or ""
-                prev["content"] = (
-                    f"{prev_content}\n{curr_content}" if prev_content else curr_content
-                )
+        self.messages.append(
+            ChatMessage(role="assistant", content="", tool_calls=[tool_call])
+        )

-            curr_tool_calls = curr.get("tool_calls")
-            if curr_tool_calls:
-                prev_tool_calls = prev.get("tool_calls")
-                prev["tool_calls"] = (
-                    list(prev_tool_calls) + list(curr_tool_calls)
-                    if prev_tool_calls
-                    else list(curr_tool_calls)
-                )
-        return result
-
     def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
         messages = []
         for message in self.messages:
@@ -321,38 +260,68 @@ class ChatSession(BaseModel):
             )
         return self._merge_consecutive_assistant_messages(messages)

-
-async def _get_session_from_cache(session_id: str) -> ChatSession | None:
-    """Get a chat session from Redis cache."""
-    redis_key = _get_session_cache_key(session_id)
-    async_redis = await get_redis_async()
-    raw_session: bytes | None = await async_redis.get(redis_key)
-
-    if raw_session is None:
-        return None
-
-    try:
-        session = ChatSession.model_validate_json(raw_session)
-        logger.info(
-            f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
-            f"last_roles={[m.role for m in session.messages[-3:]]}"  # Last 3 roles
-        )
-        return session
-    except Exception as e:
-        logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
-        raise RedisError(f"Corrupted session data for {session_id}") from e
-
-
-async def _cache_session(session: ChatSession) -> None:
-    """Cache a chat session in Redis."""
-    redis_key = _get_session_cache_key(session.session_id)
-    async_redis = await get_redis_async()
-    await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
+    @staticmethod
+    def _merge_consecutive_assistant_messages(
+        messages: list[ChatCompletionMessageParam],
+    ) -> list[ChatCompletionMessageParam]:
+        """Merge consecutive assistant messages into single messages.
+
+        Long-running tool flows can create split assistant messages: one with
+        text content and another with tool_calls. Anthropic's API requires
+        tool_result blocks to reference a tool_use in the immediately preceding
+        assistant message, so these splits cause 400 errors via OpenRouter.
+        """
+        if len(messages) < 2:
+            return messages
+
+        result: list[ChatCompletionMessageParam] = [messages[0]]
+        for msg in messages[1:]:
+            prev = result[-1]
+            if prev.get("role") != "assistant" or msg.get("role") != "assistant":
+                result.append(msg)
+                continue
+
+            prev = cast(ChatCompletionAssistantMessageParam, prev)
+            curr = cast(ChatCompletionAssistantMessageParam, msg)
+
+            curr_content = curr.get("content") or ""
+            if curr_content:
+                prev_content = prev.get("content") or ""
+                prev["content"] = (
+                    f"{prev_content}\n{curr_content}" if prev_content else curr_content
+                )
+
+            curr_tool_calls = curr.get("tool_calls")
+            if curr_tool_calls:
+                prev_tool_calls = prev.get("tool_calls")
+                prev["tool_calls"] = (
+                    list(prev_tool_calls) + list(curr_tool_calls)
+                    if prev_tool_calls
+                    else list(curr_tool_calls)
+                )
+        return result
+
+
+def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
+    """Parse a JSON field that may be stored as string or already parsed."""
+    if value is None:
+        return default
+    if isinstance(value, str):
+        return json.loads(value)
+    return value
+
+
+# ================ Chat cache + DB operations ================ #
+
+# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
+# connected directly.


 async def cache_chat_session(session: ChatSession) -> None:
-    """Cache a chat session without persisting to the database."""
-    await _cache_session(session)
+    """Cache a chat session in Redis (without persisting to the database)."""
+    redis_key = _get_session_cache_key(session.session_id)
+    async_redis = await get_redis_async()
+    await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())


 async def invalidate_session_cache(session_id: str) -> None:
@@ -370,77 +339,6 @@ async def invalidate_session_cache(session_id: str) -> None:
         logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")


-async def _get_session_from_db(session_id: str) -> ChatSession | None:
-    """Get a chat session from the database."""
-    prisma_session = await chat_db.get_chat_session(session_id)
-    if not prisma_session:
-        return None
-
-    messages = prisma_session.Messages
-    logger.debug(
-        f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
-        f"roles={[m.role for m in messages[-3:]] if messages else []}"  # Last 3 roles
-    )
-
-    return ChatSession.from_db(prisma_session, messages)
-
-
-async def _save_session_to_db(
-    session: ChatSession, existing_message_count: int
-) -> None:
-    """Save or update a chat session in the database."""
-    # Check if session exists in DB
-    existing = await chat_db.get_chat_session(session.session_id)
-
-    if not existing:
-        # Create new session
-        await chat_db.create_chat_session(
-            session_id=session.session_id,
-            user_id=session.user_id,
-        )
-        existing_message_count = 0
-
-    # Calculate total tokens from usage
-    total_prompt = sum(u.prompt_tokens for u in session.usage)
-    total_completion = sum(u.completion_tokens for u in session.usage)
-
-    # Update session metadata
-    await chat_db.update_chat_session(
-        session_id=session.session_id,
-        credentials=session.credentials,
-        successful_agent_runs=session.successful_agent_runs,
-        successful_agent_schedules=session.successful_agent_schedules,
-        total_prompt_tokens=total_prompt,
-        total_completion_tokens=total_completion,
-    )
-
-    # Add new messages (only those after existing count)
-    new_messages = session.messages[existing_message_count:]
-    if new_messages:
-        messages_data = []
-        for msg in new_messages:
-            messages_data.append(
-                {
-                    "role": msg.role,
-                    "content": msg.content,
-                    "name": msg.name,
-                    "tool_call_id": msg.tool_call_id,
-                    "refusal": msg.refusal,
-                    "tool_calls": msg.tool_calls,
-                    "function_call": msg.function_call,
-                }
-            )
-        logger.debug(
-            f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
-            f"roles={[m['role'] for m in messages_data]}"
-        )
-        await chat_db.add_chat_messages_batch(
-            session_id=session.session_id,
-            messages=messages_data,
-            start_sequence=existing_message_count,
-        )
-
-
 async def get_chat_session(
     session_id: str,
     user_id: str | None = None,
@@ -488,16 +386,53 @@ async def get_chat_session(

     # Cache the session from DB
     try:
-        await _cache_session(session)
+        await cache_chat_session(session)
+        logger.info(f"Cached session {session_id} from database")
     except Exception as e:
        logger.warning(f"Failed to cache session {session_id}: {e}")

     return session


-async def upsert_chat_session(
-    session: ChatSession,
-) -> ChatSession:
+async def _get_session_from_cache(session_id: str) -> ChatSession | None:
+    """Get a chat session from Redis cache."""
+    redis_key = _get_session_cache_key(session_id)
+    async_redis = await get_redis_async()
+    raw_session: bytes | None = await async_redis.get(redis_key)
+
+    if raw_session is None:
+        return None
+
+    try:
+        session = ChatSession.model_validate_json(raw_session)
+        logger.info(
+            f"Loading session {session_id} from cache: "
+            f"message_count={len(session.messages)}, "
+            f"roles={[m.role for m in session.messages]}"
+        )
+        return session
+    except Exception as e:
+        logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
+        raise RedisError(f"Corrupted session data for {session_id}") from e
+
+
+async def _get_session_from_db(session_id: str) -> ChatSession | None:
+    """Get a chat session from the database."""
+    session = await chat_db().get_chat_session(session_id)
+    if not session:
+        return None
+
+    logger.info(
+        f"Loaded session {session_id} from DB: "
+        f"has_messages={bool(session.messages)}, "
+        f"message_count={len(session.messages)}, "
+        f"roles={[m.role for m in session.messages]}"
+    )
+
+    return session
+
+
+async def upsert_chat_session(session: ChatSession) -> ChatSession:
     """Update a chat session in both cache and database.

     Uses session-level locking to prevent race conditions when concurrent
@@ -515,7 +450,7 @@ async def upsert_chat_session(

     async with lock:
         # Get existing message count from DB for incremental saves
-        existing_message_count = await chat_db.get_chat_session_message_count(
+        existing_message_count = await chat_db().get_chat_session_message_count(
            session.session_id
        )

@@ -532,7 +467,7 @@ async def upsert_chat_session(
|
|||||||
|
|
||||||
# Save to cache (best-effort, even if DB failed)
|
# Save to cache (best-effort, even if DB failed)
|
||||||
try:
|
try:
|
||||||
await _cache_session(session)
|
await cache_chat_session(session)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# If DB succeeded but cache failed, raise cache error
|
# If DB succeeded but cache failed, raise cache error
|
||||||
if db_error is None:
|
if db_error is None:
|
||||||
@@ -553,6 +488,65 @@ async def upsert_chat_session(
|
|||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
async def _save_session_to_db(
|
||||||
|
session: ChatSession, existing_message_count: int
|
||||||
|
) -> None:
|
||||||
|
"""Save or update a chat session in the database."""
|
||||||
|
db = chat_db()
|
||||||
|
|
||||||
|
# Check if session exists in DB
|
||||||
|
existing = await db.get_chat_session(session.session_id)
|
||||||
|
|
||||||
|
if not existing:
|
||||||
|
# Create new session
|
||||||
|
await db.create_chat_session(
|
||||||
|
session_id=session.session_id,
|
||||||
|
user_id=session.user_id,
|
||||||
|
)
|
||||||
|
existing_message_count = 0
|
||||||
|
|
||||||
|
# Calculate total tokens from usage
|
||||||
|
total_prompt = sum(u.prompt_tokens for u in session.usage)
|
||||||
|
total_completion = sum(u.completion_tokens for u in session.usage)
|
||||||
|
|
||||||
|
# Update session metadata
|
||||||
|
await db.update_chat_session(
|
||||||
|
session_id=session.session_id,
|
||||||
|
credentials=session.credentials,
|
||||||
|
successful_agent_runs=session.successful_agent_runs,
|
||||||
|
successful_agent_schedules=session.successful_agent_schedules,
|
||||||
|
total_prompt_tokens=total_prompt,
|
||||||
|
total_completion_tokens=total_completion,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add new messages (only those after existing count)
|
||||||
|
new_messages = session.messages[existing_message_count:]
|
||||||
|
if new_messages:
|
||||||
|
messages_data = []
|
||||||
|
for msg in new_messages:
|
||||||
|
messages_data.append(
|
||||||
|
{
|
||||||
|
"role": msg.role,
|
||||||
|
"content": msg.content,
|
||||||
|
"name": msg.name,
|
||||||
|
"tool_call_id": msg.tool_call_id,
|
||||||
|
"refusal": msg.refusal,
|
||||||
|
"tool_calls": msg.tool_calls,
|
||||||
|
"function_call": msg.function_call,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
|
||||||
|
f"roles={[m['role'] for m in messages_data]}, "
|
||||||
|
f"start_sequence={existing_message_count}"
|
||||||
|
)
|
||||||
|
await db.add_chat_messages_batch(
|
||||||
|
session_id=session.session_id,
|
||||||
|
messages=messages_data,
|
||||||
|
start_sequence=existing_message_count,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
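Note on the refactor above: `_save_session_to_db` only persists the tail of the message list, offset by `start_sequence`. Below is a minimal, runnable sketch of that incremental-save idea; `FakeChatDB` and its method names are illustrative stand-ins, not the project's real accessor API.

```python
# Sketch of incremental message persistence (illustrative only).
import asyncio
from dataclasses import dataclass, field


@dataclass
class FakeChatDB:
    rows: dict[str, list[dict]] = field(default_factory=dict)

    async def get_message_count(self, session_id: str) -> int:
        return len(self.rows.get(session_id, []))

    async def add_messages_batch(
        self, session_id: str, messages: list[dict], start_sequence: int
    ) -> None:
        # Each message carries an explicit sequence number so a retried batch
        # cannot reorder or duplicate already-saved history.
        stored = self.rows.setdefault(session_id, [])
        for offset, msg in enumerate(messages):
            stored.append({"sequence": start_sequence + offset, **msg})


async def save_incrementally(db: FakeChatDB, session_id: str, all_messages: list[dict]) -> None:
    existing = await db.get_message_count(session_id)
    new_messages = all_messages[existing:]  # only persist the unsaved tail
    if new_messages:
        await db.add_messages_batch(session_id, new_messages, start_sequence=existing)


async def _demo() -> None:
    db = FakeChatDB()
    history = [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]
    await save_incrementally(db, "s1", history)
    history.append({"role": "user", "content": "run my agent"})
    await save_incrementally(db, "s1", history)  # writes only the third message
    assert [m["sequence"] for m in db.rows["s1"]] == [0, 1, 2]


asyncio.run(_demo())
```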
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
    """Atomically append a message to a session and persist it.

@@ -568,7 +562,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
        raise ValueError(f"Session {session_id} not found")

        session.messages.append(message)
-        existing_message_count = await chat_db.get_chat_session_message_count(
+        existing_message_count = await chat_db().get_chat_session_message_count(
            session_id
        )

@@ -580,7 +574,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
            ) from e

        try:
-            await _cache_session(session)
+            await cache_chat_session(session)
        except Exception as e:
            logger.warning(f"Cache write failed for session {session_id}: {e}")

@@ -599,7 +593,7 @@ async def create_chat_session(user_id: str) -> ChatSession:

    # Create in database first - fail fast if this fails
    try:
-        await chat_db.create_chat_session(
+        await chat_db().create_chat_session(
            session_id=session.session_id,
            user_id=user_id,
        )
@@ -611,7 +605,7 @@ async def create_chat_session(user_id: str) -> ChatSession:

    # Cache the session (best-effort optimization, DB is source of truth)
    try:
-        await _cache_session(session)
+        await cache_chat_session(session)
    except Exception as e:
        logger.warning(f"Failed to cache new session {session.session_id}: {e}")

@@ -622,20 +616,16 @@ async def get_user_sessions(
    user_id: str,
    limit: int = 50,
    offset: int = 0,
-) -> tuple[list[ChatSession], int]:
+) -> tuple[list[ChatSessionInfo], int]:
    """Get chat sessions for a user from the database with total count.

    Returns:
        A tuple of (sessions, total_count) where total_count is the overall
        number of sessions for the user (not just the current page).
    """
-    prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
-    total_count = await chat_db.get_user_session_count(user_id)
-    sessions = []
-    for prisma_session in prisma_sessions:
-        # Convert without messages for listing (lighter weight)
-        sessions.append(ChatSession.from_db(prisma_session, None))
+    db = chat_db()
+    sessions = await db.get_user_chat_sessions(user_id, limit, offset)
+    total_count = await db.get_user_session_count(user_id)

    return sessions, total_count

@@ -653,7 +643,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
    """
    # Delete from database first (with optional user_id validation)
    # This confirms ownership before invalidating cache
-    deleted = await chat_db.delete_chat_session(session_id, user_id)
+    deleted = await chat_db().delete_chat_session(session_id, user_id)

    if not deleted:
        return False
@@ -688,7 +678,7 @@ async def update_session_title(session_id: str, title: str) -> bool:
        True if updated successfully, False otherwise.
    """
    try:
-        result = await chat_db.update_chat_session(session_id=session_id, title=title)
+        result = await chat_db().update_chat_session(session_id=session_id, title=title)
        if result is None:
            logger.warning(f"Session {session_id} not found for title update")
            return False
@@ -700,7 +690,7 @@ async def update_session_title(session_id: str, title: str) -> bool:
            cached = await _get_session_from_cache(session_id)
            if cached:
                cached.title = title
-                await _cache_session(cached)
+                await cache_chat_session(cached)
        except Exception as e:
            # Not critical - title will be correct on next full cache refresh
            logger.warning(
@@ -711,3 +701,29 @@ async def update_session_title(session_id: str, title: str) -> bool:
    except Exception as e:
        logger.error(f"Failed to update title for session {session_id}: {e}")
        return False


+# ==================== Chat session locks ==================== #

+_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
+_session_locks_mutex = asyncio.Lock()


+async def _get_session_lock(session_id: str) -> asyncio.Lock:
+    """Get or create a lock for a specific session to prevent concurrent upserts.

+    This was originally added to solve the specific problem of race conditions between
+    the session title thread and the conversation thread, which always occurs on the
+    same instance as we prevent rapid request sends on the frontend.

+    Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
+    when no coroutine holds a reference to them, preventing memory leaks from
+    unbounded growth of session locks. Explicit cleanup also occurs
+    in `delete_chat_session()`.
+    """
+    async with _session_locks_mutex:
+        lock = _session_locks.get(session_id)
+        if lock is None:
+            lock = asyncio.Lock()
+            _session_locks[session_id] = lock
+        return lock
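The `_get_session_lock` helper added above keys short-lived `asyncio.Lock` objects by session in a `WeakValueDictionary`, so unused locks are garbage collected. A self-contained sketch of the same pattern, assuming a single event loop; all names below are local to the example.

```python
# Illustrative per-key lock registry backed by weak references.
import asyncio
from weakref import WeakValueDictionary

_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_locks_mutex = asyncio.Lock()


async def get_lock(key: str) -> asyncio.Lock:
    async with _locks_mutex:  # guard the registry against concurrent creation
        lock = _locks.get(key)
        if lock is None:
            lock = asyncio.Lock()
            _locks[key] = lock
        return lock


async def critical_section(key: str, results: list[str]) -> None:
    lock = await get_lock(key)
    async with lock:  # holding this reference keeps the entry alive while in use
        results.append(f"start {key}")
        await asyncio.sleep(0)
        results.append(f"end {key}")


async def _demo() -> None:
    results: list[str] = []
    await asyncio.gather(critical_section("s1", results), critical_section("s1", results))
    # With the lock, the two critical sections never interleave.
    assert results == ["start s1", "end s1", "start s1", "end s1"]


asyncio.run(_demo())
```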
@@ -20,7 +20,7 @@ from claude_agent_sdk import (
    UserMessage,
)

-from backend.api.features.chat.response_model import (
+from backend.copilot.response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
@@ -34,10 +34,8 @@ from backend.api.features.chat.response_model import (
    StreamToolInputStart,
    StreamToolOutputAvailable,
)
-from backend.api.features.chat.sdk.tool_adapter import (
-    MCP_TOOL_PREFIX,
-    pop_pending_tool_output,
-)
+from .tool_adapter import MCP_TOOL_PREFIX, pop_pending_tool_output

logger = logging.getLogger(__name__)

@@ -10,7 +10,7 @@ from claude_agent_sdk import (
    UserMessage,
)

-from backend.api.features.chat.response_model import (
+from backend.copilot.response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
@@ -11,7 +11,7 @@ import re
from collections.abc import Callable
from typing import Any, cast

-from backend.api.features.chat.sdk.tool_adapter import (
+from .tool_adapter import (
    BLOCKED_TOOLS,
    DANGEROUS_PATTERNS,
    MCP_TOOL_PREFIX,
@@ -1,4 +1,9 @@
-"""Unit tests for SDK security hooks."""
+"""Tests for SDK security hooks — workspace paths, tool access, and deny messages.

+These are pure unit tests with no external dependencies (no SDK, no DB, no server).
+They validate that the security hooks correctly block unauthorized paths,
+tool access, and dangerous input patterns.
+"""

import os

@@ -12,6 +17,10 @@ def _is_denied(result: dict) -> bool:
    return hook.get("permissionDecision") == "deny"


+def _reason(result: dict) -> str:
+    return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "")


# -- Blocked tools -----------------------------------------------------------


@@ -163,3 +172,19 @@ def test_non_workspace_tool_passes_isolation():
        "find_agent", {"query": "email"}, user_id="user-1"
    )
    assert result == {}


+# -- Deny message quality ----------------------------------------------------


+def test_blocked_tool_message_clarity():
+    """Deny messages must include [SECURITY] and 'cannot be bypassed'."""
+    reason = _reason(_validate_tool_access("bash", {}))
+    assert "[SECURITY]" in reason
+    assert "cannot be bypassed" in reason


+def test_bash_builtin_blocked_message_clarity():
+    reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"}))
+    assert "[SECURITY]" in reason
+    assert "cannot be bypassed" in reason
@@ -255,7 +255,7 @@ def _build_sdk_env() -> dict[str, str]:
def _make_sdk_cwd(session_id: str) -> str:
    """Create a safe, session-specific working directory path.

-    Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path`
+    Delegates to :func:`~backend.copilot.tools.sandbox.make_session_path`
    (single source of truth for path sanitization) and adds a defence-in-depth
    assertion.
    """
@@ -440,12 +440,16 @@ async def stream_chat_completion_sdk(
            f"Session {session_id} not found. Please create a new session first."
        )

-    if message:
-        session.messages.append(
-            ChatMessage(
-                role="user" if is_user_message else "assistant", content=message
-            )
+    # Append the new message to the session if it's not already there
+    new_message_role = "user" if is_user_message else "assistant"
+    if message and (
+        len(session.messages) == 0
+        or not (
+            session.messages[-1].role == new_message_role
+            and session.messages[-1].content == message
        )
+    ):
+        session.messages.append(ChatMessage(role=new_message_role, content=message))
        if is_user_message:
            track_user_message(
                user_id=user_id, session_id=session_id, message_length=len(message)
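The hunk above replaces an unconditional append with a guard that skips a message duplicating the last stored (role, content) pair; the non-SDK `stream_chat_completion` path receives the same change further down. A small sketch of that guard, using a stand-in `ChatMessage` dataclass rather than the real model:

```python
# Duplicate-append guard, sketched with a minimal stand-in message type.
from dataclasses import dataclass


@dataclass
class ChatMessage:
    role: str
    content: str


def append_if_new(messages: list[ChatMessage], role: str, content: str) -> bool:
    """Append and return True unless the tail already holds the same message."""
    if content and (
        len(messages) == 0
        or not (messages[-1].role == role and messages[-1].content == content)
    ):
        messages.append(ChatMessage(role=role, content=content))
        return True
    return False


history: list[ChatMessage] = []
assert append_if_new(history, "user", "hello") is True
assert append_if_new(history, "user", "hello") is False  # retried request, skipped
assert append_if_new(history, "assistant", "hello") is True  # different role, kept
assert len(history) == 2
```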
@@ -689,11 +693,15 @@ async def stream_chat_completion_sdk(
                    await asyncio.sleep(0.5)
                    raw_transcript = read_transcript_file(captured_transcript.path)
                    if raw_transcript:
-                        task = asyncio.create_task(
-                            _upload_transcript_bg(user_id, session_id, raw_transcript)
-                        )
-                        _background_tasks.add(task)
-                        task.add_done_callback(_background_tasks.discard)
+                        try:
+                            async with asyncio.timeout(30):
+                                await _upload_transcript_bg(
+                                    user_id, session_id, raw_transcript
+                                )
+                        except asyncio.TimeoutError:
+                            logger.warning(
+                                f"[SDK] Transcript upload timed out for {session_id}"
+                            )
                    else:
                        logger.debug("[SDK] Stop hook fired but transcript not usable")

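The transcript upload above moves from a fire-and-forget background task to an awaited call bounded by `asyncio.timeout`. A hedged sketch of that bounded-await pattern (requires Python 3.11+; `upload` below is a placeholder coroutine, not the real uploader):

```python
# Bounded await: cancel the inner coroutine if it overruns the deadline.
import asyncio
import logging

logger = logging.getLogger(__name__)


async def upload(payload: str, delay: float) -> None:
    await asyncio.sleep(delay)  # stands in for a network call


async def upload_with_deadline(payload: str, delay: float, deadline: float = 0.05) -> bool:
    try:
        async with asyncio.timeout(deadline):
            await upload(payload, delay)
        return True
    except asyncio.TimeoutError:
        # The caller keeps going; the slow upload is logged and dropped.
        logger.warning("upload timed out after %.2fs", deadline)
        return False


async def _demo() -> None:
    assert await upload_with_deadline("fast", delay=0.0) is True
    assert await upload_with_deadline("slow", delay=0.2) is False


asyncio.run(_demo())
```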
@@ -18,9 +18,9 @@ from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from typing import Any

-from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tools import TOOL_REGISTRY
-from backend.api.features.chat.tools.base import BaseTool
+from backend.copilot.model import ChatSession
+from backend.copilot.tools import TOOL_REGISTRY
+from backend.copilot.tools.base import BaseTool

logger = logging.getLogger(__name__)

@@ -3,7 +3,7 @@
import json
import os

-from backend.api.features.chat.sdk.transcript import (
+from .transcript import (
    STRIPPABLE_TYPES,
    read_transcript_file,
    strip_progress_entries,
@@ -27,20 +27,18 @@ from openai.types.chat import (
    ChatCompletionToolParam,
)

+from backend.data.db_accessors import chat_db, understanding_db
from backend.data.redis_client import get_redis_async
-from backend.data.understanding import (
-    format_understanding_for_prompt,
-    get_business_understanding,
-)
+from backend.data.understanding import format_understanding_for_prompt
from backend.util.exceptions import NotFoundError
from backend.util.settings import AppEnvironment, Settings

-from . import db as chat_db
from . import stream_registry
from .config import ChatConfig
from .model import (
    ChatMessage,
    ChatSession,
+    ChatSessionInfo,
    Usage,
    cache_chat_session,
    get_chat_session,
@@ -263,7 +261,7 @@ async def _build_system_prompt(
    understanding = None
    if user_id:
        try:
-            understanding = await get_business_understanding(user_id)
+            understanding = await understanding_db().get_business_understanding(user_id)
        except Exception as e:
            logger.warning(f"Failed to fetch business understanding: {e}")
            understanding = None
@@ -339,7 +337,7 @@ async def _generate_session_title(
async def assign_user_to_session(
    session_id: str,
    user_id: str,
-) -> ChatSession:
+) -> ChatSessionInfo:
    """
    Assign a user to a chat session.
    """
@@ -428,12 +426,16 @@ async def stream_chat_completion(
            f"Session {session_id} not found. Please create a new session first."
        )

-    if message:
-        session.messages.append(
-            ChatMessage(
-                role="user" if is_user_message else "assistant", content=message
-            )
+    # Append the new message to the session if it's not already there
+    new_message_role = "user" if is_user_message else "assistant"
+    if message and (
+        len(session.messages) == 0
+        or not (
+            session.messages[-1].role == new_message_role
+            and session.messages[-1].content == message
        )
+    ):
+        session.messages.append(ChatMessage(role=new_message_role, content=message))
        logger.info(
            f"Appended message (role={'user' if is_user_message else 'assistant'}), "
            f"new message_count={len(session.messages)}"
@@ -1770,7 +1772,7 @@ async def _update_pending_operation(
    This is called by background tasks when long-running operations complete.
    """
    # Update the message in database
-    updated = await chat_db.update_tool_message_content(
+    updated = await chat_db().update_tool_message_content(
        session_id=session_id,
        tool_call_id=tool_call_id,
        new_content=result,
@@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Any

from openai.types.chat import ChatCompletionToolParam

-from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tracking import track_tool_called
+from backend.copilot.model import ChatSession
+from backend.copilot.tracking import track_tool_called

from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
@@ -31,7 +31,7 @@ from .workspace_files import (
)

if TYPE_CHECKING:
-    from backend.api.features.chat.response_model import StreamToolOutputAvailable
+    from backend.copilot.response_model import StreamToolOutputAvailable

logger = logging.getLogger(__name__)

@@ -6,11 +6,11 @@ import pytest
from prisma.types import ProfileCreateInput
from pydantic import SecretStr

-from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import AITextGeneratorBlock
+from backend.copilot.model import ChatSession
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import APIKeyCredentials
@@ -3,11 +3,9 @@
import logging
from typing import Any

-from backend.api.features.chat.model import ChatSession
-from backend.data.understanding import (
-    BusinessUnderstandingInput,
-    upsert_business_understanding,
-)
+from backend.copilot.model import ChatSession
+from backend.data.db_accessors import understanding_db
+from backend.data.understanding import BusinessUnderstandingInput

from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
@@ -99,7 +97,9 @@ and automations for the user's specific needs."""
    ]

        # Upsert with merge
-        understanding = await upsert_business_understanding(user_id, input_data)
+        understanding = await understanding_db().upsert_business_understanding(
+            user_id, input_data
+        )

        # Build current understanding summary (filter out empty values)
        current_understanding = {
@@ -5,9 +5,8 @@ import re
import uuid
from typing import Any, NotRequired, TypedDict

-from backend.api.features.library import db as library_db
-from backend.api.features.store import db as store_db
-from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
+from backend.data.db_accessors import graph_db, library_db, store_db
+from backend.data.graph import Graph, Link, Node
from backend.util.exceptions import DatabaseError, NotFoundError

from .service import (
@@ -145,8 +144,9 @@ async def get_library_agent_by_id(
    Returns:
        LibraryAgentSummary if found, None otherwise
    """
+    db = library_db()
    try:
-        agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
+        agent = await db.get_library_agent_by_graph_id(user_id, agent_id)
        if agent:
            logger.debug(f"Found library agent by graph_id: {agent.name}")
            return LibraryAgentSummary(
@@ -163,7 +163,7 @@ async def get_library_agent_by_id(
        logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")

    try:
-        agent = await library_db.get_library_agent(agent_id, user_id)
+        agent = await db.get_library_agent(agent_id, user_id)
        if agent:
            logger.debug(f"Found library agent by library_id: {agent.name}")
            return LibraryAgentSummary(
@@ -215,7 +215,7 @@ async def get_library_agents_for_generation(
        List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
    """
    try:
-        response = await library_db.list_library_agents(
+        response = await library_db().list_library_agents(
            user_id=user_id,
            search_term=search_query,
            page=1,
@@ -272,7 +272,7 @@ async def search_marketplace_agents_for_generation(
        List of LibraryAgentSummary with full input/output schemas
    """
    try:
-        response = await store_db.get_store_agents(
+        response = await store_db().get_store_agents(
            search_query=search_query,
            page=1,
            page_size=max_results,
@@ -286,7 +286,7 @@ async def search_marketplace_agents_for_generation(
        return []

    graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
-    graphs = await get_store_listed_graphs(*graph_ids)
+    graphs = await graph_db().get_store_listed_graphs(graph_ids)

    results: list[LibraryAgentSummary] = []
    for agent in agents_with_graphs:
@@ -673,9 +673,10 @@ async def save_agent_to_library(
        Tuple of (created Graph, LibraryAgent)
    """
    graph = json_to_graph(agent_json)
+    db = library_db()
    if is_update:
-        return await library_db.update_graph_in_library(graph, user_id)
-    return await library_db.create_graph_in_library(graph, user_id)
+        return await db.update_graph_in_library(graph, user_id)
+    return await db.create_graph_in_library(graph, user_id)


def graph_to_json(graph: Graph) -> dict[str, Any]:
@@ -735,12 +736,14 @@ async def get_agent_as_json(
    Returns:
        Agent as JSON dict or None if not found
    """
-    graph = await get_graph(agent_id, version=None, user_id=user_id)
+    db = graph_db()

+    graph = await db.get_graph(agent_id, version=None, user_id=user_id)

    if not graph and user_id:
        try:
-            library_agent = await library_db.get_library_agent(agent_id, user_id)
-            graph = await get_graph(
+            library_agent = await library_db().get_library_agent(agent_id, user_id)
+            graph = await db.get_graph(
                library_agent.graph_id, version=None, user_id=user_id
            )
        except NotFoundError:
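Several hunks above swap module-level `library_db` / `graph_db` / `store_db` imports for factory calls from `backend.data.db_accessors`, resolved at use time. Roughly the shape of that accessor-factory idea, sketched with hypothetical names; the real factories and their client types live in the application code:

```python
# Illustrative accessor-factory pattern: call sites ask for a client when they
# need one, giving tests a single seam to swap.
import asyncio
from dataclasses import dataclass


@dataclass
class LibraryDBClient:
    name: str = "library"

    async def get_library_agent(self, agent_id: str, user_id: str) -> dict:
        return {"id": agent_id, "owner": user_id, "source": self.name}


_default_client = LibraryDBClient()


def library_db() -> LibraryDBClient:
    """Factory resolved at use time; tests may patch this one function."""
    return _default_client


async def load_agent(agent_id: str, user_id: str) -> dict:
    db = library_db()  # resolve the client once per operation
    return await db.get_library_agent(agent_id, user_id)


print(asyncio.run(load_agent("agent-1", "user-1")))
```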
@@ -7,10 +7,9 @@ from typing import Any

from pydantic import BaseModel, field_validator

-from backend.api.features.chat.model import ChatSession
-from backend.api.features.library import db as library_db
from backend.api.features.library.model import LibraryAgent
-from backend.data import execution as execution_db
+from backend.copilot.model import ChatSession
+from backend.data.db_accessors import execution_db, library_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta

from .base import BaseTool
@@ -165,10 +164,12 @@ class AgentOutputTool(BaseTool):
        Resolve agent from provided identifiers.
        Returns (library_agent, error_message).
        """
+        lib_db = library_db()

        # Priority 1: Exact library agent ID
        if library_agent_id:
            try:
-                agent = await library_db.get_library_agent(library_agent_id, user_id)
+                agent = await lib_db.get_library_agent(library_agent_id, user_id)
                return agent, None
            except Exception as e:
                logger.warning(f"Failed to get library agent by ID: {e}")
@@ -182,7 +183,7 @@ class AgentOutputTool(BaseTool):
            return None, f"Agent '{store_slug}' not found in marketplace"

        # Find in user's library by graph_id
-        agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
+        agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id)
        if not agent:
            return (
                None,
@@ -194,7 +195,7 @@ class AgentOutputTool(BaseTool):
        # Priority 3: Fuzzy name search in library
        if agent_name:
            try:
-                response = await library_db.list_library_agents(
+                response = await lib_db.list_library_agents(
                    user_id=user_id,
                    search_term=agent_name,
                    page_size=5,
@@ -228,9 +229,11 @@ class AgentOutputTool(BaseTool):
        Fetch execution(s) based on filters.
        Returns (single_execution, available_executions_meta, error_message).
        """
+        exec_db = execution_db()

        # If specific execution_id provided, fetch it directly
        if execution_id:
-            execution = await execution_db.get_graph_execution(
+            execution = await exec_db.get_graph_execution(
                user_id=user_id,
                execution_id=execution_id,
                include_node_executions=False,
@@ -240,7 +243,7 @@ class AgentOutputTool(BaseTool):
            return execution, [], None

        # Get completed executions with time filters
-        executions = await execution_db.get_graph_executions(
+        executions = await exec_db.get_graph_executions(
            graph_id=graph_id,
            user_id=user_id,
            statuses=[ExecutionStatus.COMPLETED],
@@ -254,7 +257,7 @@ class AgentOutputTool(BaseTool):

        # If only one execution, fetch full details
        if len(executions) == 1:
-            full_execution = await execution_db.get_graph_execution(
+            full_execution = await exec_db.get_graph_execution(
                user_id=user_id,
                execution_id=executions[0].id,
                include_node_executions=False,
@@ -262,7 +265,7 @@ class AgentOutputTool(BaseTool):
            return full_execution, [], None

        # Multiple executions - return latest with full details, plus list of available
-        full_execution = await execution_db.get_graph_execution(
+        full_execution = await exec_db.get_graph_execution(
            user_id=user_id,
            execution_id=executions[0].id,
            include_node_executions=False,
@@ -380,7 +383,7 @@ class AgentOutputTool(BaseTool):
            and not input_data.store_slug
        ):
            # Fetch execution directly to get graph_id
-            execution = await execution_db.get_graph_execution(
+            execution = await execution_db().get_graph_execution(
                user_id=user_id,
                execution_id=input_data.execution_id,
                include_node_executions=False,
@@ -392,7 +395,7 @@ class AgentOutputTool(BaseTool):
            )

            # Find library agent by graph_id
-            agent = await library_db.get_library_agent_by_graph_id(
+            agent = await library_db().get_library_agent_by_graph_id(
                user_id, execution.graph_id
            )
            if not agent:
@@ -4,8 +4,7 @@ import logging
import re
from typing import Literal

-from backend.api.features.library import db as library_db
-from backend.api.features.store import db as store_db
+from backend.data.db_accessors import library_db, store_db
from backend.util.exceptions import DatabaseError, NotFoundError

from .models import (
@@ -45,8 +44,10 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
    Returns:
        AgentInfo if found, None otherwise
    """
+    lib_db = library_db()

    try:
-        agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
+        agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
        if agent:
            logger.debug(f"Found library agent by graph_id: {agent.name}")
            return AgentInfo(
@@ -71,7 +72,7 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
            )

    try:
-        agent = await library_db.get_library_agent(agent_id, user_id)
+        agent = await lib_db.get_library_agent(agent_id, user_id)
        if agent:
            logger.debug(f"Found library agent by library_id: {agent.name}")
            return AgentInfo(
@@ -133,7 +134,7 @@ async def search_agents(
    try:
        if source == "marketplace":
            logger.info(f"Searching marketplace for: {query}")
-            results = await store_db.get_store_agents(search_query=query, page_size=5)
+            results = await store_db().get_store_agents(search_query=query, page_size=5)
            for agent in results.agents:
                agents.append(
                    AgentInfo(
@@ -159,7 +160,7 @@ async def search_agents(

        if not agents:
            logger.info(f"Searching user library for: {query}")
-            results = await library_db.list_library_agents(
+            results = await library_db().list_library_agents(
                user_id=user_id,  # type: ignore[arg-type]
                search_term=query,
                page_size=10,
@@ -5,8 +5,8 @@ from typing import Any

from openai.types.chat import ChatCompletionToolParam

-from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.response_model import StreamToolOutputAvailable
+from backend.copilot.model import ChatSession
+from backend.copilot.response_model import StreamToolOutputAvailable

from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase

@@ -11,18 +11,11 @@ available (e.g. macOS development).
import logging
from typing import Any

-from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tools.base import BaseTool
-from backend.api.features.chat.tools.models import (
-    BashExecResponse,
-    ErrorResponse,
-    ToolResponseBase,
-)
-from backend.api.features.chat.tools.sandbox import (
-    get_workspace_dir,
-    has_full_sandbox,
-    run_sandboxed,
-)
+from backend.copilot.model import ChatSession
+from .base import BaseTool
+from .models import BashExecResponse, ErrorResponse, ToolResponseBase
+from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed

logger = logging.getLogger(__name__)

@@ -3,13 +3,10 @@
import logging
from typing import Any

-from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tools.base import BaseTool
-from backend.api.features.chat.tools.models import (
-    ErrorResponse,
-    ResponseType,
-    ToolResponseBase,
-)
+from backend.copilot.model import ChatSession
+from .base import BaseTool
+from .models import ErrorResponse, ResponseType, ToolResponseBase

logger = logging.getLogger(__name__)

@@ -78,7 +75,7 @@ class CheckOperationStatusTool(BaseTool):
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
-        from backend.api.features.chat import stream_registry
+        from backend.copilot import stream_registry

        operation_id = (kwargs.get("operation_id") or "").strip()
        task_id = (kwargs.get("task_id") or "").strip()
@@ -3,7 +3,7 @@
import logging
from typing import Any

-from backend.api.features.chat.model import ChatSession
+from backend.copilot.model import ChatSession

from .agent_generator import (
    AgentGeneratorNotConfiguredError,
@@ -3,9 +3,9 @@
import logging
from typing import Any

-from backend.api.features.chat.model import ChatSession
-from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError
+from backend.copilot.model import ChatSession
+from backend.data.db_accessors import store_db as get_store_db

from .agent_generator import (
    AgentGeneratorNotConfiguredError,
@@ -137,6 +137,8 @@ class CustomizeAgentTool(BaseTool):

        creator_username, agent_slug = parts

+        store_db = get_store_db()

        # Fetch the marketplace agent details
        try:
            agent_details = await store_db.get_store_agent_details(
@@ -3,7 +3,7 @@
import logging
from typing import Any

-from backend.api.features.chat.model import ChatSession
+from backend.copilot.model import ChatSession

from .agent_generator import (
    AgentGeneratorNotConfiguredError,
@@ -5,9 +5,14 @@ from typing import Any

from pydantic import SecretStr

-from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tools.base import BaseTool
-from backend.api.features.chat.tools.models import (
+from backend.blocks.linear._api import LinearClient
+from backend.copilot.model import ChatSession
+from backend.data.db_accessors import user_db
+from backend.data.model import APIKeyCredentials
+from backend.util.settings import Settings

+from .base import BaseTool
+from .models import (
    ErrorResponse,
    FeatureRequestCreatedResponse,
    FeatureRequestInfo,
@@ -15,10 +20,6 @@ from backend.api.features.chat.tools.models import (
    NoResultsResponse,
    ToolResponseBase,
)
-from backend.blocks.linear._api import LinearClient
-from backend.data.model import APIKeyCredentials
-from backend.data.user import get_user_email_by_id
-from backend.util.settings import Settings

logger = logging.getLogger(__name__)

@@ -104,8 +105,8 @@ def _get_linear_config() -> tuple[LinearClient, str, str]:
    Raises RuntimeError if any required setting is missing.
    """
    secrets = _get_settings().secrets
-    if not secrets.linear_api_key:
-        raise RuntimeError("LINEAR_API_KEY is not configured")
+    if not secrets.copilot_linear_api_key:
+        raise RuntimeError("COPILOT_LINEAR_API_KEY is not configured")
    if not secrets.linear_feature_request_project_id:
        raise RuntimeError("LINEAR_FEATURE_REQUEST_PROJECT_ID is not configured")
    if not secrets.linear_feature_request_team_id:
@@ -114,7 +115,7 @@ def _get_linear_config() -> tuple[LinearClient, str, str]:
    credentials = APIKeyCredentials(
        id="system-linear",
        provider="linear",
-        api_key=SecretStr(secrets.linear_api_key),
+        api_key=SecretStr(secrets.copilot_linear_api_key),
        title="System Linear API Key",
    )
    client = LinearClient(credentials=credentials)
@@ -332,7 +333,9 @@ class CreateFeatureRequestTool(BaseTool):
        # Resolve a human-readable name (email) for the Linear customer record.
        # Fall back to user_id if the lookup fails or returns None.
        try:
-            customer_display_name = await get_user_email_by_id(user_id) or user_id
+            customer_display_name = (
+                await user_db().get_user_email_by_id(user_id) or user_id
+            )
        except Exception:
            customer_display_name = user_id

@@ -1,22 +1,18 @@
"""Tests for SearchFeatureRequestsTool and CreateFeatureRequestTool."""

-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch

import pytest

-from backend.api.features.chat.tools.feature_requests import (
-    CreateFeatureRequestTool,
-    SearchFeatureRequestsTool,
-)
-from backend.api.features.chat.tools.models import (
+from ._test_data import make_session
+from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool
+from .models import (
    ErrorResponse,
    FeatureRequestCreatedResponse,
    FeatureRequestSearchResponse,
    NoResultsResponse,
)

-from ._test_data import make_session

_TEST_USER_ID = "test-user-feature-requests"
_TEST_USER_EMAIL = "testuser@example.com"

@@ -39,7 +35,7 @@ def _mock_linear_config(*, query_return=None, mutate_return=None):
    client.mutate.return_value = mutate_return
    return (
        patch(
-            "backend.api.features.chat.tools.feature_requests._get_linear_config",
+            "backend.copilot.tools.feature_requests._get_linear_config",
            return_value=(client, _FAKE_PROJECT_ID, _FAKE_TEAM_ID),
        ),
        client,
@@ -208,7 +204,7 @@ class TestSearchFeatureRequestsTool:
    async def test_linear_client_init_failure(self):
        session = make_session(user_id=_TEST_USER_ID)
        with patch(
-            "backend.api.features.chat.tools.feature_requests._get_linear_config",
+            "backend.copilot.tools.feature_requests._get_linear_config",
            side_effect=RuntimeError("No API key"),
        ):
            tool = SearchFeatureRequestsTool()
@@ -231,10 +227,11 @@ class TestCreateFeatureRequestTool:

    @pytest.fixture(autouse=True)
    def _patch_email_lookup(self):
+        mock_user_db = MagicMock()
+        mock_user_db.get_user_email_by_id = AsyncMock(return_value=_TEST_USER_EMAIL)
        with patch(
-            "backend.api.features.chat.tools.feature_requests.get_user_email_by_id",
-            new_callable=AsyncMock,
-            return_value=_TEST_USER_EMAIL,
+            "backend.copilot.tools.feature_requests.user_db",
+            return_value=mock_user_db,
        ):
            yield

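The fixture above stubs the accessor factory with a `MagicMock` whose async method is an `AsyncMock`. A self-contained sketch of that testing pattern, using a local `Accessors` stand-in instead of the real module path:

```python
# Stub an accessor factory so code under test never touches a real database.
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch


class Accessors:
    """Stand-in for the module that exposes the user_db() factory."""

    @staticmethod
    def user_db():
        raise RuntimeError("real database not available in tests")


async def resolve_display_name(user_id: str) -> str:
    db = Accessors.user_db()
    return await db.get_user_email_by_id(user_id) or user_id


def test_display_name_uses_stubbed_accessor() -> None:
    mock_user_db = MagicMock()
    mock_user_db.get_user_email_by_id = AsyncMock(return_value="test@example.com")
    with patch.object(Accessors, "user_db", return_value=mock_user_db):
        assert asyncio.run(resolve_display_name("user-1")) == "test@example.com"


test_display_name_uses_stubbed_accessor()
```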
@@ -347,7 +344,7 @@ class TestCreateFeatureRequestTool:
|
|||||||
async def test_linear_client_init_failure(self):
|
async def test_linear_client_init_failure(self):
|
||||||
session = make_session(user_id=_TEST_USER_ID)
|
session = make_session(user_id=_TEST_USER_ID)
|
||||||
with patch(
|
with patch(
|
||||||
"backend.api.features.chat.tools.feature_requests._get_linear_config",
|
"backend.copilot.tools.feature_requests._get_linear_config",
|
||||||
side_effect=RuntimeError("No API key"),
|
side_effect=RuntimeError("No API key"),
|
||||||
):
|
):
|
||||||
tool = CreateFeatureRequestTool()
|
tool = CreateFeatureRequestTool()
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.copilot.model import ChatSession
|
||||||
|
|
||||||
from .agent_search import search_agents
|
from .agent_search import search_agents
|
||||||
from .base import BaseTool
|
from .base import BaseTool
|
||||||
@@ -3,17 +3,18 @@ from typing import Any
|
|||||||
|
|
||||||
from prisma.enums import ContentType
|
from prisma.enums import ContentType
|
||||||
|
|
||||||
from backend.api.features.chat.model import ChatSession
|
from backend.blocks import get_block
|
||||||
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
|
from backend.blocks._base import BlockType
|
||||||
from backend.api.features.chat.tools.models import (
|
from backend.copilot.model import ChatSession
|
||||||
|
from backend.data.db_accessors import search
|
||||||
|
|
||||||
|
from .base import BaseTool, ToolResponseBase
|
||||||
|
from .models import (
|
||||||
BlockInfoSummary,
|
BlockInfoSummary,
|
||||||
BlockListResponse,
|
BlockListResponse,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
NoResultsResponse,
|
NoResultsResponse,
|
||||||
)
|
)
|
||||||
from backend.api.features.store.hybrid_search import unified_hybrid_search
|
|
||||||
from backend.blocks import get_block
|
|
||||||
from backend.blocks._base import BlockType
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -107,7 +108,7 @@ class FindBlockTool(BaseTool):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Search for blocks using hybrid search
|
# Search for blocks using hybrid search
|
||||||
results, total = await unified_hybrid_search(
|
results, total = await search().unified_hybrid_search(
|
||||||
query=query,
|
query=query,
|
||||||
content_types=[ContentType.BLOCK],
|
content_types=[ContentType.BLOCK],
|
||||||
page=1,
|
page=1,
|
||||||
@@ -4,15 +4,15 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from backend.api.features.chat.tools.find_block import (
|
from backend.blocks._base import BlockType
|
||||||
|
|
||||||
|
from ._test_data import make_session
|
||||||
|
from .find_block import (
|
||||||
COPILOT_EXCLUDED_BLOCK_IDS,
|
COPILOT_EXCLUDED_BLOCK_IDS,
|
||||||
COPILOT_EXCLUDED_BLOCK_TYPES,
|
COPILOT_EXCLUDED_BLOCK_TYPES,
|
||||||
FindBlockTool,
|
FindBlockTool,
|
||||||
)
|
)
|
||||||
from backend.api.features.chat.tools.models import BlockListResponse
|
from .models import BlockListResponse
|
||||||
from backend.blocks._base import BlockType
|
|
||||||
|
|
||||||
from ._test_data import make_session
|
|
||||||
|
|
||||||
_TEST_USER_ID = "test-user-find-block"
|
_TEST_USER_ID = "test-user-find-block"
|
||||||
|
|
||||||
@@ -84,13 +84,17 @@ class TestFindBlockFiltering:
             "standard-block-id": standard_block,
         }.get(block_id)

+        mock_search_db = MagicMock()
+        mock_search_db.unified_hybrid_search = AsyncMock(
+            return_value=(search_results, 2)
+        )
+
         with patch(
-            "backend.api.features.chat.tools.find_block.unified_hybrid_search",
-            new_callable=AsyncMock,
-            return_value=(search_results, 2),
+            "backend.copilot.tools.find_block.search",
+            return_value=mock_search_db,
         ):
             with patch(
-                "backend.api.features.chat.tools.find_block.get_block",
+                "backend.copilot.tools.find_block.get_block",
                 side_effect=mock_get_block,
             ):
                 tool = FindBlockTool()

@@ -128,13 +132,17 @@ class TestFindBlockFiltering:
             "normal-block-id": normal_block,
         }.get(block_id)

+        mock_search_db = MagicMock()
+        mock_search_db.unified_hybrid_search = AsyncMock(
+            return_value=(search_results, 2)
+        )
+
         with patch(
-            "backend.api.features.chat.tools.find_block.unified_hybrid_search",
-            new_callable=AsyncMock,
-            return_value=(search_results, 2),
+            "backend.copilot.tools.find_block.search",
+            return_value=mock_search_db,
         ):
             with patch(
-                "backend.api.features.chat.tools.find_block.get_block",
+                "backend.copilot.tools.find_block.get_block",
                 side_effect=mock_get_block,
             ):
                 tool = FindBlockTool()

@@ -353,12 +361,16 @@ class TestFindBlockFiltering:
             for d in block_defs
         }

+        mock_search_db = MagicMock()
+        mock_search_db.unified_hybrid_search = AsyncMock(
+            return_value=(search_results, len(search_results))
+        )
+
         with patch(
-            "backend.api.features.chat.tools.find_block.unified_hybrid_search",
-            new_callable=AsyncMock,
-            return_value=(search_results, len(search_results)),
+            "backend.copilot.tools.find_block.search",
+            return_value=mock_search_db,
         ), patch(
-            "backend.api.features.chat.tools.find_block.get_block",
+            "backend.copilot.tools.find_block.get_block",
             side_effect=lambda bid: mock_blocks.get(bid),
         ):
             tool = FindBlockTool()
@@ -2,7 +2,7 @@

 from typing import Any

-from backend.api.features.chat.model import ChatSession
+from backend.copilot.model import ChatSession

 from .agent_search import search_agents
 from .base import BaseTool

@@ -4,13 +4,10 @@ import logging
 from pathlib import Path
 from typing import Any

-from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tools.base import BaseTool
-from backend.api.features.chat.tools.models import (
-    DocPageResponse,
-    ErrorResponse,
-    ToolResponseBase,
-)
+from backend.copilot.model import ChatSession
+
+from .base import BaseTool
+from .models import DocPageResponse, ErrorResponse, ToolResponseBase

 logger = logging.getLogger(__name__)

@@ -5,16 +5,12 @@ from typing import Any

 from pydantic import BaseModel, Field, field_validator

-from backend.api.features.chat.config import ChatConfig
-from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tracking import (
-    track_agent_run_success,
-    track_agent_scheduled,
-)
-from backend.api.features.library import db as library_db
+from backend.copilot.config import ChatConfig
+from backend.copilot.model import ChatSession
+from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
+from backend.data.db_accessors import graph_db, library_db, user_db
 from backend.data.graph import GraphModel
 from backend.data.model import CredentialsMetaInput
-from backend.data.user import get_user_by_id
 from backend.executor import utils as execution_utils
 from backend.util.clients import get_scheduler_client
 from backend.util.exceptions import DatabaseError, NotFoundError

@@ -200,7 +196,7 @@ class RunAgentTool(BaseTool):

         # Priority: library_agent_id if provided
         if has_library_id:
-            library_agent = await library_db.get_library_agent(
+            library_agent = await library_db().get_library_agent(
                 params.library_agent_id, user_id
             )
             if not library_agent:

@@ -209,9 +205,7 @@ class RunAgentTool(BaseTool):
                     session_id=session_id,
                 )
             # Get the graph from the library agent
-            from backend.data.graph import get_graph
-
-            graph = await get_graph(
+            graph = await graph_db().get_graph(
                 library_agent.graph_id,
                 library_agent.graph_version,
                 user_id=user_id,

@@ -522,7 +516,7 @@ class RunAgentTool(BaseTool):
         library_agent = await get_or_create_library_agent(graph, user_id)

         # Get user timezone
-        user = await get_user_by_id(user_id)
+        user = await user_db().get_user_by_id(user_id)
         user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone)

         # Create schedule
@@ -7,20 +7,17 @@ from typing import Any

 from pydantic_core import PydanticUndefined

-from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tools.find_block import (
-    COPILOT_EXCLUDED_BLOCK_IDS,
-    COPILOT_EXCLUDED_BLOCK_TYPES,
-)
 from backend.blocks import get_block
 from backend.blocks._base import AnyBlockSchema
+from backend.copilot.model import ChatSession
+from backend.data.db_accessors import workspace_db
 from backend.data.execution import ExecutionContext
 from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
-from backend.data.workspace import get_or_create_workspace
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util.exceptions import BlockError

 from .base import BaseTool
+from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
 from .helpers import get_inputs_from_schema
 from .models import (
     BlockDetails,

@@ -276,7 +273,7 @@ class RunBlockTool(BaseTool):

         try:
             # Get or create user's workspace for CoPilot file operations
-            workspace = await get_or_create_workspace(user_id)
+            workspace = await workspace_db().get_or_create_workspace(user_id)

             # Generate synthetic IDs for CoPilot context
             # Each chat session is treated as its own agent with one continuous run
@@ -4,16 +4,16 @@ from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

-from backend.api.features.chat.tools.models import (
+from backend.blocks._base import BlockType
+
+from ._test_data import make_session
+from .models import (
     BlockDetailsResponse,
     BlockOutputResponse,
     ErrorResponse,
     InputValidationErrorResponse,
 )
-from backend.api.features.chat.tools.run_block import RunBlockTool
-from backend.blocks._base import BlockType
-
-from ._test_data import make_session
+from .run_block import RunBlockTool

 _TEST_USER_ID = "test-user-run-block"
@@ -77,7 +77,7 @@ class TestRunBlockFiltering:
         input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)

         with patch(
-            "backend.api.features.chat.tools.run_block.get_block",
+            "backend.copilot.tools.run_block.get_block",
             return_value=input_block,
         ):
             tool = RunBlockTool()

@@ -103,7 +103,7 @@ class TestRunBlockFiltering:
         )

         with patch(
-            "backend.api.features.chat.tools.run_block.get_block",
+            "backend.copilot.tools.run_block.get_block",
             return_value=smart_block,
         ):
             tool = RunBlockTool()

@@ -127,7 +127,7 @@ class TestRunBlockFiltering:
         )

         with patch(
-            "backend.api.features.chat.tools.run_block.get_block",
+            "backend.copilot.tools.run_block.get_block",
             return_value=standard_block,
         ):
             tool = RunBlockTool()

@@ -183,7 +183,7 @@ class TestRunBlockInputValidation:
         )

         with patch(
-            "backend.api.features.chat.tools.run_block.get_block",
+            "backend.copilot.tools.run_block.get_block",
             return_value=mock_block,
         ):
             tool = RunBlockTool()

@@ -222,7 +222,7 @@ class TestRunBlockInputValidation:
         )

         with patch(
-            "backend.api.features.chat.tools.run_block.get_block",
+            "backend.copilot.tools.run_block.get_block",
             return_value=mock_block,
         ):
             tool = RunBlockTool()

@@ -263,7 +263,7 @@ class TestRunBlockInputValidation:
         )

         with patch(
-            "backend.api.features.chat.tools.run_block.get_block",
+            "backend.copilot.tools.run_block.get_block",
             return_value=mock_block,
         ):
             tool = RunBlockTool()

@@ -302,15 +302,19 @@ class TestRunBlockInputValidation:

         mock_block.execute = mock_execute

+        mock_workspace_db = MagicMock()
+        mock_workspace_db.get_or_create_workspace = AsyncMock(
+            return_value=MagicMock(id="test-workspace-id")
+        )
+
         with (
             patch(
-                "backend.api.features.chat.tools.run_block.get_block",
+                "backend.copilot.tools.run_block.get_block",
                 return_value=mock_block,
             ),
             patch(
-                "backend.api.features.chat.tools.run_block.get_or_create_workspace",
-                new_callable=AsyncMock,
-                return_value=MagicMock(id="test-workspace-id"),
+                "backend.copilot.tools.run_block.workspace_db",
+                return_value=mock_workspace_db,
             ),
         ):
             tool = RunBlockTool()

@@ -344,7 +348,7 @@ class TestRunBlockInputValidation:
         )

         with patch(
-            "backend.api.features.chat.tools.run_block.get_block",
+            "backend.copilot.tools.run_block.get_block",
             return_value=mock_block,
         ):
             tool = RunBlockTool()
@@ -5,16 +5,17 @@ from typing import Any

 from prisma.enums import ContentType

-from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tools.base import BaseTool
-from backend.api.features.chat.tools.models import (
+from backend.copilot.model import ChatSession
+from backend.data.db_accessors import search
+
+from .base import BaseTool
+from .models import (
     DocSearchResult,
     DocSearchResultsResponse,
     ErrorResponse,
     NoResultsResponse,
     ToolResponseBase,
 )
-from backend.api.features.store.hybrid_search import unified_hybrid_search

 logger = logging.getLogger(__name__)

@@ -117,7 +118,7 @@ class SearchDocsTool(BaseTool):

         try:
             # Search using hybrid search for DOCUMENTATION content type only
-            results, total = await unified_hybrid_search(
+            results, total = await search().unified_hybrid_search(
                 query=query,
                 content_types=[ContentType.DOCUMENTATION],
                 page=1,
@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

-from backend.api.features.chat.tools.models import BlockDetailsResponse
-from backend.api.features.chat.tools.run_block import RunBlockTool
 from backend.blocks._base import BlockType
 from backend.data.model import CredentialsMetaInput
 from backend.integrations.providers import ProviderName

 from ._test_data import make_session
+from .models import BlockDetailsResponse
+from .run_block import RunBlockTool

 _TEST_USER_ID = "test-user-run-block-details"

@@ -61,7 +61,7 @@ async def test_run_block_returns_details_when_no_input_provided():
     )

     with patch(
-        "backend.api.features.chat.tools.run_block.get_block",
+        "backend.copilot.tools.run_block.get_block",
         return_value=http_block,
     ):
         # Mock credentials check to return no missing credentials

@@ -120,7 +120,7 @@ async def test_run_block_returns_details_when_only_credentials_provided():
     }

     with patch(
-        "backend.api.features.chat.tools.run_block.get_block",
+        "backend.copilot.tools.run_block.get_block",
         return_value=mock,
     ):
         with patch.object(
@@ -3,9 +3,8 @@
 import logging
 from typing import Any

-from backend.api.features.library import db as library_db
 from backend.api.features.library import model as library_model
-from backend.api.features.store import db as store_db
+from backend.data.db_accessors import library_db, store_db
 from backend.data.graph import GraphModel
 from backend.data.model import (
     Credentials,

@@ -39,13 +38,14 @@ async def fetch_graph_from_store_slug(
     Raises:
         DatabaseError: If there's a database error during lookup.
     """
+    sdb = store_db()
     try:
-        store_agent = await store_db.get_store_agent_details(username, agent_name)
+        store_agent = await sdb.get_store_agent_details(username, agent_name)
     except NotFoundError:
         return None, None

     # Get the graph from store listing version
-    graph = await store_db.get_available_graph(
+    graph = await sdb.get_available_graph(
         store_agent.store_listing_version_id, hide_nodes=False
     )
     return graph, store_agent

@@ -210,13 +210,13 @@ async def get_or_create_library_agent(
     Returns:
         LibraryAgent instance
     """
-    existing = await library_db.get_library_agent_by_graph_id(
+    existing = await library_db().get_library_agent_by_graph_id(
         graph_id=graph.id, user_id=user_id
     )
     if existing:
         return existing

-    library_agents = await library_db.create_library_agent(
+    library_agents = await library_db().create_library_agent(
         graph=graph,
         user_id=user_id,
         create_library_agents_for_sub_graphs=False,
@@ -6,15 +6,12 @@ from typing import Any

 import aiohttp
 import html2text

-from backend.api.features.chat.model import ChatSession
-from backend.api.features.chat.tools.base import BaseTool
-from backend.api.features.chat.tools.models import (
-    ErrorResponse,
-    ToolResponseBase,
-    WebFetchResponse,
-)
+from backend.copilot.model import ChatSession
 from backend.util.request import Requests

+from .base import BaseTool
+from .models import ErrorResponse, ToolResponseBase, WebFetchResponse
+
 logger = logging.getLogger(__name__)

 # Limits
@@ -6,8 +6,8 @@ from typing import Any, Optional

 from pydantic import BaseModel

-from backend.api.features.chat.model import ChatSession
-from backend.data.workspace import get_or_create_workspace
+from backend.copilot.model import ChatSession
+from backend.data.db_accessors import workspace_db
 from backend.util.settings import Config
 from backend.util.virus_scanner import scan_content_safe
 from backend.util.workspace import WorkspaceManager

@@ -148,7 +148,7 @@ class ListWorkspaceFilesTool(BaseTool):
         include_all_sessions: bool = kwargs.get("include_all_sessions", False)

         try:
-            workspace = await get_or_create_workspace(user_id)
+            workspace = await workspace_db().get_or_create_workspace(user_id)
             # Pass session_id for session-scoped file access
             manager = WorkspaceManager(user_id, workspace.id, session_id)

@@ -167,8 +167,8 @@ class ListWorkspaceFilesTool(BaseTool):
                     file_id=f.id,
                     name=f.name,
                     path=f.path,
-                    mime_type=f.mimeType,
-                    size_bytes=f.sizeBytes,
+                    mime_type=f.mime_type,
+                    size_bytes=f.size_bytes,
                 )
                 for f in files
             ]

@@ -284,7 +284,7 @@ class ReadWorkspaceFileTool(BaseTool):
             )

         try:
-            workspace = await get_or_create_workspace(user_id)
+            workspace = await workspace_db().get_or_create_workspace(user_id)
             # Pass session_id for session-scoped file access
             manager = WorkspaceManager(user_id, workspace.id, session_id)

@@ -309,8 +309,8 @@ class ReadWorkspaceFileTool(BaseTool):
             target_file_id = file_info.id

         # Decide whether to return inline content or metadata+URL
-        is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
-        is_text_file = self._is_text_mime_type(file_info.mimeType)
+        is_small_file = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
+        is_text_file = self._is_text_mime_type(file_info.mime_type)

         # Return inline content for small text files (unless force_download_url)
         if is_small_file and is_text_file and not force_download_url:

@@ -321,7 +321,7 @@ class ReadWorkspaceFileTool(BaseTool):
                 file_id=file_info.id,
                 name=file_info.name,
                 path=file_info.path,
-                mime_type=file_info.mimeType,
+                mime_type=file_info.mime_type,
                 content_base64=content_b64,
                 message=f"Successfully read file: {file_info.name}",
                 session_id=session_id,

@@ -350,11 +350,11 @@ class ReadWorkspaceFileTool(BaseTool):
             file_id=file_info.id,
             name=file_info.name,
             path=file_info.path,
-            mime_type=file_info.mimeType,
-            size_bytes=file_info.sizeBytes,
+            mime_type=file_info.mime_type,
+            size_bytes=file_info.size_bytes,
             download_url=download_url,
             preview=preview,
-            message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
+            message=f"File: {file_info.name} ({file_info.size_bytes} bytes). Use download_url to retrieve content.",
             session_id=session_id,
         )

@@ -484,7 +484,7 @@ class WriteWorkspaceFileTool(BaseTool):
             # Virus scan
             await scan_content_safe(content, filename=filename)

-            workspace = await get_or_create_workspace(user_id)
+            workspace = await workspace_db().get_or_create_workspace(user_id)
             # Pass session_id for session-scoped file access
             manager = WorkspaceManager(user_id, workspace.id, session_id)

@@ -500,7 +500,7 @@ class WriteWorkspaceFileTool(BaseTool):
                 file_id=file_record.id,
                 name=file_record.name,
                 path=file_record.path,
-                size_bytes=file_record.sizeBytes,
+                size_bytes=file_record.size_bytes,
                 message=f"Successfully wrote file: {file_record.name}",
                 session_id=session_id,
             )

@@ -583,7 +583,7 @@ class DeleteWorkspaceFileTool(BaseTool):
             )

         try:
-            workspace = await get_or_create_workspace(user_id)
+            workspace = await workspace_db().get_or_create_workspace(user_id)
             # Pass session_id for session-scoped file access
             manager = WorkspaceManager(user_id, workspace.id, session_id)

autogpt_platform/backend/backend/data/db_accessors.py (new file, 118 lines)
@@ -0,0 +1,118 @@
+from backend.data import db
+
+
+def chat_db():
+    if db.is_connected():
+        from backend.copilot import db as _chat_db
+
+        chat_db = _chat_db
+    else:
+        from backend.util.clients import get_database_manager_async_client
+
+        chat_db = get_database_manager_async_client()
+
+    return chat_db
+
+
+def graph_db():
+    if db.is_connected():
+        from backend.data import graph as _graph_db
+
+        graph_db = _graph_db
+    else:
+        from backend.util.clients import get_database_manager_async_client
+
+        graph_db = get_database_manager_async_client()
+
+    return graph_db
+
+
+def library_db():
+    if db.is_connected():
+        from backend.api.features.library import db as _library_db
+
+        library_db = _library_db
+    else:
+        from backend.util.clients import get_database_manager_async_client
+
+        library_db = get_database_manager_async_client()
+
+    return library_db
+
+
+def store_db():
+    if db.is_connected():
+        from backend.api.features.store import db as _store_db
+
+        store_db = _store_db
+    else:
+        from backend.util.clients import get_database_manager_async_client
+
+        store_db = get_database_manager_async_client()
+
+    return store_db
+
+
+def search():
+    if db.is_connected():
+        from backend.api.features.store import hybrid_search as _search
+
+        search = _search
+    else:
+        from backend.util.clients import get_database_manager_async_client
+
+        search = get_database_manager_async_client()
+
+    return search
+
+
+def execution_db():
+    if db.is_connected():
+        from backend.data import execution as _execution_db
+
+        execution_db = _execution_db
+    else:
+        from backend.util.clients import get_database_manager_async_client
+
+        execution_db = get_database_manager_async_client()
+
+    return execution_db
+
+
+def user_db():
+    if db.is_connected():
+        from backend.data import user as _user_db
+
+        user_db = _user_db
+    else:
+        from backend.util.clients import get_database_manager_async_client
+
+        user_db = get_database_manager_async_client()
+
+    return user_db
+
+
+def understanding_db():
+    if db.is_connected():
+        from backend.data import understanding as _understanding_db
+
+        understanding_db = _understanding_db
+    else:
+        from backend.util.clients import get_database_manager_async_client
+
+        understanding_db = get_database_manager_async_client()
+
+    return understanding_db
+
+
+def workspace_db():
+    if db.is_connected():
+        from backend.data import workspace as _workspace_db
+
+        workspace_db = _workspace_db
+    else:
+        from backend.util.clients import get_database_manager_async_client
+
+        workspace_db = get_database_manager_async_client()
+
+    return workspace_db
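Note: these accessor functions choose between the in-process DB module and the DatabaseManager RPC client at call time, which is why call sites elsewhere in this diff change from library_db.get_library_agent(...) to library_db().get_library_agent(...). A minimal usage sketch under that assumption; the combining helper itself is hypothetical:

from backend.data.db_accessors import library_db, workspace_db


async def load_agent_and_workspace(library_agent_id: str, user_id: str):
    # Each accessor is invoked as a function: it returns either the local DB
    # module (when Prisma is connected in this process) or the DatabaseManager
    # async RPC client, and both expose the same coroutine names.
    agent = await library_db().get_library_agent(library_agent_id, user_id)
    workspace = await workspace_db().get_or_create_workspace(user_id)
    return agent, workspace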
@@ -4,14 +4,26 @@ from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cas

 from backend.api.features.library.db import (
     add_store_agent_to_library,
+    create_graph_in_library,
+    create_library_agent,
+    get_library_agent,
+    get_library_agent_by_graph_id,
     list_library_agents,
+    update_graph_in_library,
+)
+from backend.api.features.store.db import (
+    get_agent,
+    get_available_graph,
+    get_store_agent_details,
+    get_store_agents,
 )
-from backend.api.features.store.db import get_store_agent_details, get_store_agents
 from backend.api.features.store.embeddings import (
     backfill_missing_embeddings,
     cleanup_orphaned_embeddings,
     get_embedding_stats,
 )
+from backend.api.features.store.hybrid_search import unified_hybrid_search
+from backend.copilot import db as chat_db
 from backend.data import db
 from backend.data.analytics import (
     get_accuracy_trends_and_alerts,

@@ -48,6 +60,7 @@ from backend.data.graph import (
     get_graph_metadata,
     get_graph_settings,
     get_node,
+    get_store_listed_graphs,
     validate_graph_execution_permissions,
 )
 from backend.data.human_review import (

@@ -67,6 +80,10 @@ from backend.data.notifications import (
     remove_notifications_from_batch,
 )
 from backend.data.onboarding import increment_onboarding_runs
+from backend.data.understanding import (
+    get_business_understanding,
+    upsert_business_understanding,
+)
 from backend.data.user import (
     get_active_user_ids_in_timerange,
     get_user_by_id,

@@ -76,6 +93,15 @@ from backend.data.user import (
     get_user_notification_preference,
     update_user_integrations,
 )
+from backend.data.workspace import (
+    count_workspace_files,
+    create_workspace_file,
+    get_or_create_workspace,
+    get_workspace_file,
+    get_workspace_file_by_path,
+    list_workspace_files,
+    soft_delete_workspace_file,
+)
 from backend.util.service import (
     AppService,
     AppServiceClient,

@@ -107,6 +133,13 @@ async def _get_credits(user_id: str) -> int:


 class DatabaseManager(AppService):
+    """Database connection pooling service.
+
+    This service connects to the Prisma engine and exposes database
+    operations via RPC endpoints. It acts as a centralized connection pool
+    for all services that need database access.
+    """
+
     @asynccontextmanager
     async def lifespan(self, app: "FastAPI"):
         async with super().lifespan(app):

@@ -142,11 +175,15 @@ class DatabaseManager(AppService):
     def _(
         f: Callable[P, R], name: str | None = None
     ) -> Callable[Concatenate[object, P], R]:
+        """
+        Exposes a function as an RPC endpoint, and adds a virtual `self` param
+        to the function's type so it can be bound as a method.
+        """
         if name is not None:
             f.__name__ = name
         return cast(Callable[Concatenate[object, P], R], expose(f))

-    # Executions
+    # ============ Graph Executions ============ #
     get_child_graph_executions = _(get_child_graph_executions)
     get_graph_executions = _(get_graph_executions)
     get_graph_executions_count = _(get_graph_executions_count)

@@ -170,36 +207,37 @@ class DatabaseManager(AppService):
     get_frequently_executed_graphs = _(get_frequently_executed_graphs)
     get_marketplace_graphs_for_monitoring = _(get_marketplace_graphs_for_monitoring)

-    # Graphs
+    # ============ Graphs ============ #
     get_node = _(get_node)
     get_graph = _(get_graph)
     get_connected_output_nodes = _(get_connected_output_nodes)
     get_graph_metadata = _(get_graph_metadata)
     get_graph_settings = _(get_graph_settings)
+    get_store_listed_graphs = _(get_store_listed_graphs)

-    # Credits
+    # ============ Credits ============ #
     spend_credits = _(_spend_credits, name="spend_credits")
     get_credits = _(_get_credits, name="get_credits")

-    # User + User Metadata + User Integrations
+    # ============ User + Integrations ============ #
+    get_user_by_id = _(get_user_by_id)
     get_user_integrations = _(get_user_integrations)
     update_user_integrations = _(update_user_integrations)

-    # User Comms - async
+    # ============ User Comms ============ #
     get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
-    get_user_by_id = _(get_user_by_id)
     get_user_email_by_id = _(get_user_email_by_id)
     get_user_email_verification = _(get_user_email_verification)
     get_user_notification_preference = _(get_user_notification_preference)

-    # Human In The Loop
+    # ============ Human In The Loop ============ #
     cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
     check_approval = _(check_approval)
     get_or_create_human_review = _(get_or_create_human_review)
     has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
     update_review_processed_status = _(update_review_processed_status)

-    # Notifications - async
+    # ============ Notifications ============ #
     clear_all_user_notification_batches = _(clear_all_user_notification_batches)
     create_or_add_to_user_notification_batch = _(
         create_or_add_to_user_notification_batch

@@ -212,29 +250,62 @@ class DatabaseManager(AppService):
         get_user_notification_oldest_message_in_batch
     )

-    # Library
+    # ============ Library ============ #
     list_library_agents = _(list_library_agents)
     add_store_agent_to_library = _(add_store_agent_to_library)
+    create_graph_in_library = _(create_graph_in_library)
+    create_library_agent = _(create_library_agent)
+    get_library_agent = _(get_library_agent)
+    get_library_agent_by_graph_id = _(get_library_agent_by_graph_id)
+    update_graph_in_library = _(update_graph_in_library)
     validate_graph_execution_permissions = _(validate_graph_execution_permissions)

-    # Onboarding
+    # ============ Onboarding ============ #
     increment_onboarding_runs = _(increment_onboarding_runs)

-    # OAuth
+    # ============ OAuth ============ #
     cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)

-    # Store
+    # ============ Store ============ #
     get_store_agents = _(get_store_agents)
     get_store_agent_details = _(get_store_agent_details)
+    get_agent = _(get_agent)
+    get_available_graph = _(get_available_graph)

-    # Store Embeddings
+    # ============ Search ============ #
     get_embedding_stats = _(get_embedding_stats)
     backfill_missing_embeddings = _(backfill_missing_embeddings)
     cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings)
+    unified_hybrid_search = _(unified_hybrid_search)

-    # Summary data - async
+    # ============ Summary Data ============ #
     get_user_execution_summary_data = _(get_user_execution_summary_data)

+    # ============ Workspace ============ #
+    count_workspace_files = _(count_workspace_files)
+    create_workspace_file = _(create_workspace_file)
+    get_or_create_workspace = _(get_or_create_workspace)
+    get_workspace_file = _(get_workspace_file)
+    get_workspace_file_by_path = _(get_workspace_file_by_path)
+    list_workspace_files = _(list_workspace_files)
+    soft_delete_workspace_file = _(soft_delete_workspace_file)
+
+    # ============ Understanding ============ #
+    get_business_understanding = _(get_business_understanding)
+    upsert_business_understanding = _(upsert_business_understanding)
+
+    # ============ CoPilot Chat Sessions ============ #
+    get_chat_session = _(chat_db.get_chat_session)
+    create_chat_session = _(chat_db.create_chat_session)
+    update_chat_session = _(chat_db.update_chat_session)
+    add_chat_message = _(chat_db.add_chat_message)
+    add_chat_messages_batch = _(chat_db.add_chat_messages_batch)
+    get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
+    get_user_session_count = _(chat_db.get_user_session_count)
+    delete_chat_session = _(chat_db.delete_chat_session)
+    get_chat_session_message_count = _(chat_db.get_chat_session_message_count)
+    update_tool_message_content = _(chat_db.update_tool_message_content)
+

 class DatabaseManagerClient(AppServiceClient):
     d = DatabaseManager

@@ -296,43 +367,50 @@ class DatabaseManagerAsyncClient(AppServiceClient):
     def get_service_type(cls):
         return DatabaseManager

+    # ============ Graph Executions ============ #
     create_graph_execution = d.create_graph_execution
     get_child_graph_executions = d.get_child_graph_executions
     get_connected_output_nodes = d.get_connected_output_nodes
     get_latest_node_execution = d.get_latest_node_execution
-    get_graph = d.get_graph
-    get_graph_metadata = d.get_graph_metadata
-    get_graph_settings = d.get_graph_settings
     get_graph_execution = d.get_graph_execution
     get_graph_execution_meta = d.get_graph_execution_meta
-    get_node = d.get_node
+    get_graph_executions = d.get_graph_executions
     get_node_execution = d.get_node_execution
     get_node_executions = d.get_node_executions
-    get_user_by_id = d.get_user_by_id
-    get_user_integrations = d.get_user_integrations
-    upsert_execution_input = d.upsert_execution_input
-    upsert_execution_output = d.upsert_execution_output
-    get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
     update_graph_execution_stats = d.update_graph_execution_stats
     update_node_execution_status = d.update_node_execution_status
     update_node_execution_status_batch = d.update_node_execution_status_batch
-    update_user_integrations = d.update_user_integrations
+    upsert_execution_input = d.upsert_execution_input
+    upsert_execution_output = d.upsert_execution_output
+    get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
     get_execution_kv_data = d.get_execution_kv_data
     set_execution_kv_data = d.set_execution_kv_data

-    # Human In The Loop
+    # ============ Graphs ============ #
+    get_graph = d.get_graph
+    get_graph_metadata = d.get_graph_metadata
+    get_graph_settings = d.get_graph_settings
+    get_node = d.get_node
+    get_store_listed_graphs = d.get_store_listed_graphs
+
+    # ============ User + Integrations ============ #
+    get_user_by_id = d.get_user_by_id
+    get_user_integrations = d.get_user_integrations
+    update_user_integrations = d.update_user_integrations
+
+    # ============ Human In The Loop ============ #
     cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
     check_approval = d.check_approval
     get_or_create_human_review = d.get_or_create_human_review
     update_review_processed_status = d.update_review_processed_status

-    # User Comms
+    # ============ User Comms ============ #
     get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
     get_user_email_by_id = d.get_user_email_by_id
     get_user_email_verification = d.get_user_email_verification
     get_user_notification_preference = d.get_user_notification_preference

-    # Notifications
+    # ============ Notifications ============ #
     clear_all_user_notification_batches = d.clear_all_user_notification_batches
     create_or_add_to_user_notification_batch = (
         d.create_or_add_to_user_notification_batch

@@ -345,20 +423,55 @@ class DatabaseManagerAsyncClient(AppServiceClient):
         d.get_user_notification_oldest_message_in_batch
     )

-    # Library
+    # ============ Library ============ #
     list_library_agents = d.list_library_agents
     add_store_agent_to_library = d.add_store_agent_to_library
+    create_graph_in_library = d.create_graph_in_library
+    create_library_agent = d.create_library_agent
+    get_library_agent = d.get_library_agent
+    get_library_agent_by_graph_id = d.get_library_agent_by_graph_id
+    update_graph_in_library = d.update_graph_in_library
     validate_graph_execution_permissions = d.validate_graph_execution_permissions

-    # Onboarding
+    # ============ Onboarding ============ #
     increment_onboarding_runs = d.increment_onboarding_runs

-    # OAuth
+    # ============ OAuth ============ #
     cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens

-    # Store
+    # ============ Store ============ #
     get_store_agents = d.get_store_agents
     get_store_agent_details = d.get_store_agent_details
+    get_agent = d.get_agent
+    get_available_graph = d.get_available_graph

-    # Summary data
+    # ============ Search ============ #
+    unified_hybrid_search = d.unified_hybrid_search
+
+    # ============ Summary Data ============ #
     get_user_execution_summary_data = d.get_user_execution_summary_data

+    # ============ Workspace ============ #
+    count_workspace_files = d.count_workspace_files
+    create_workspace_file = d.create_workspace_file
+    get_or_create_workspace = d.get_or_create_workspace
+    get_workspace_file = d.get_workspace_file
+    get_workspace_file_by_path = d.get_workspace_file_by_path
+    list_workspace_files = d.list_workspace_files
+    soft_delete_workspace_file = d.soft_delete_workspace_file
+
+    # ============ Understanding ============ #
+    get_business_understanding = d.get_business_understanding
+    upsert_business_understanding = d.upsert_business_understanding
+
+    # ============ CoPilot Chat Sessions ============ #
+    get_chat_session = d.get_chat_session
+    create_chat_session = d.create_chat_session
+    update_chat_session = d.update_chat_session
+    add_chat_message = d.add_chat_message
+    add_chat_messages_batch = d.add_chat_messages_batch
+    get_user_chat_sessions = d.get_user_chat_sessions
+    get_user_session_count = d.get_user_session_count
+    delete_chat_session = d.delete_chat_session
+    get_chat_session_message_count = d.get_chat_session_message_count
+    update_tool_message_content = d.update_tool_message_content
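Note: the `_()` helper documented in the hunks above wraps a plain module-level function so it can sit on the service class like a bound method, while the client classes mirror those attributes one-to-one. A self-contained sketch of the idea only, not the real AppService/expose implementation:

from typing import Callable, Concatenate, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def as_method(f: Callable[P, R]) -> Callable[Concatenate[object, P], R]:
    # Pretend the function takes a `self` so it can be assigned as a class
    # attribute and called like a bound method; the instance is ignored.
    # The real helper additionally routes the function through expose() so it
    # becomes an RPC endpoint on the DatabaseManager service.
    def wrapper(_self: object, *args: P.args, **kwargs: P.kwargs) -> R:
        return f(*args, **kwargs)

    return wrapper


def get_thing(thing_id: str) -> str:
    return f"thing:{thing_id}"


class Service:
    get_thing = as_method(get_thing)


print(Service().get_thing("abc"))  # -> thing:abc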
@@ -1147,14 +1147,14 @@ async def get_graph(
     return GraphModel.from_db(graph, for_export)


-async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]:
+async def get_store_listed_graphs(graph_ids: list[str]) -> dict[str, GraphModel]:
     """Batch-fetch multiple store-listed graphs by their IDs.

     Only returns graphs that have approved store listings (publicly available).
     Does not require permission checks since store-listed graphs are public.

     Args:
-        *graph_ids: Variable number of graph IDs to fetch
+        graph_ids: List of graph IDs to fetch

     Returns:
         Dict mapping graph_id to GraphModel for graphs with approved store listings
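Note: this hunk changes get_store_listed_graphs from a variadic signature to a single list parameter, presumably so the argument serializes cleanly when the function is exposed on the DatabaseManager RPC client above. A hypothetical call site under that signature:

from backend.data.db_accessors import graph_db


async def fetch_public_graphs(graph_ids: list[str]):
    # Pass the IDs as one list argument instead of the old *graph_ids varargs
    # form; the result maps graph_id -> GraphModel for approved listings only.
    graphs = await graph_db().get_store_listed_graphs(graph_ids)
    return [graphs[gid] for gid in graph_ids if gid in graphs]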
@@ -8,6 +8,7 @@ import logging
 from datetime import datetime, timezone
 from typing import Optional

+import pydantic
 from prisma.models import UserWorkspace, UserWorkspaceFile
 from prisma.types import UserWorkspaceFileWhereInput

@@ -16,7 +17,61 @@ from backend.util.json import SafeJson
 logger = logging.getLogger(__name__)


-async def get_or_create_workspace(user_id: str) -> UserWorkspace:
+class Workspace(pydantic.BaseModel):
+    """Pydantic model for UserWorkspace, safe for RPC transport."""
+
+    id: str
+    user_id: str
+    created_at: datetime
+    updated_at: datetime
+
+    @staticmethod
+    def from_db(workspace: "UserWorkspace") -> "Workspace":
+        return Workspace(
+            id=workspace.id,
+            user_id=workspace.userId,
+            created_at=workspace.createdAt,
+            updated_at=workspace.updatedAt,
+        )
+
+
+class WorkspaceFile(pydantic.BaseModel):
+    """Pydantic model for UserWorkspaceFile, safe for RPC transport."""
+
+    id: str
+    workspace_id: str
+    created_at: datetime
+    updated_at: datetime
+    name: str
+    path: str
+    storage_path: str
+    mime_type: str
+    size_bytes: int
+    checksum: Optional[str] = None
+    is_deleted: bool = False
+    deleted_at: Optional[datetime] = None
+    metadata: dict = pydantic.Field(default_factory=dict)
+
+    @staticmethod
+    def from_db(file: "UserWorkspaceFile") -> "WorkspaceFile":
+        return WorkspaceFile(
+            id=file.id,
+            workspace_id=file.workspaceId,
+            created_at=file.createdAt,
+            updated_at=file.updatedAt,
+            name=file.name,
+            path=file.path,
+            storage_path=file.storagePath,
+            mime_type=file.mimeType,
+            size_bytes=file.sizeBytes,
+            checksum=file.checksum,
+            is_deleted=file.isDeleted,
+            deleted_at=file.deletedAt,
+            metadata=file.metadata if isinstance(file.metadata, dict) else {},
+        )
+
+
+async def get_or_create_workspace(user_id: str) -> Workspace:
     """
     Get user's workspace, creating one if it doesn't exist.

@@ -27,7 +82,7 @@ async def get_or_create_workspace(user_id: str) -> UserWorkspace:
         user_id: The user's ID

     Returns:
-        UserWorkspace instance
+        Workspace instance
     """
     workspace = await UserWorkspace.prisma().upsert(
         where={"userId": user_id},

@@ -37,10 +92,10 @@ async def get_or_create_workspace(user_id: str) -> UserWorkspace:
         },
     )

-    return workspace
+    return Workspace.from_db(workspace)


-async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
+async def get_workspace(user_id: str) -> Optional[Workspace]:
     """
     Get user's workspace if it exists.

@@ -48,9 +103,10 @@ async def get_workspace(user_id: str) -> Optional[UserWorkspace]:
         user_id: The user's ID

     Returns:
-        UserWorkspace instance or None
+        Workspace instance or None
     """
-    return await UserWorkspace.prisma().find_unique(where={"userId": user_id})
+    workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id})
+    return Workspace.from_db(workspace) if workspace else None


 async def create_workspace_file(

@@ -63,7 +119,7 @@ async def create_workspace_file(
     size_bytes: int,
     checksum: Optional[str] = None,
     metadata: Optional[dict] = None,
-) -> UserWorkspaceFile:
+) -> WorkspaceFile:
     """
     Create a new workspace file record.

@@ -79,7 +135,7 @@ async def create_workspace_file(
         metadata: Optional additional metadata

     Returns:
-        Created UserWorkspaceFile instance
+        Created WorkspaceFile instance
     """
     # Normalize path to start with /
     if not path.startswith("/"):

@@ -103,34 +159,37 @@ async def create_workspace_file(
         f"Created workspace file {file.id} at path {path} "
         f"in workspace {workspace_id}"
     )
-    return file
+    return WorkspaceFile.from_db(file)


 async def get_workspace_file(
     file_id: str,
-    workspace_id: Optional[str] = None,
-) -> Optional[UserWorkspaceFile]:
+    workspace_id: str,
+) -> Optional[WorkspaceFile]:
     """
     Get a workspace file by ID.

     Args:
         file_id: The file ID
-        workspace_id: Optional workspace ID for validation
+        workspace_id: Workspace ID for scoping (required)

     Returns:
-        UserWorkspaceFile instance or None
+        WorkspaceFile instance or None
     """
-    where_clause: dict = {"id": file_id, "isDeleted": False}
-    if workspace_id:
-        where_clause["workspaceId"] = workspace_id
+    where_clause: UserWorkspaceFileWhereInput = {
+        "id": file_id,
+        "isDeleted": False,
+        "workspaceId": workspace_id,
+    }

-    return await UserWorkspaceFile.prisma().find_first(where=where_clause)
+    file = await UserWorkspaceFile.prisma().find_first(where=where_clause)
+    return WorkspaceFile.from_db(file) if file else None


 async def get_workspace_file_by_path(
     workspace_id: str,
     path: str,
-) -> Optional[UserWorkspaceFile]:
+) -> Optional[WorkspaceFile]:
     """
     Get a workspace file by its virtual path.

@@ -139,19 +198,20 @@ async def get_workspace_file_by_path(
         path: Virtual path

     Returns:
-        UserWorkspaceFile instance or None
+        WorkspaceFile instance or None
     """
     # Normalize path
     if not path.startswith("/"):
         path = f"/{path}"

-    return await UserWorkspaceFile.prisma().find_first(
+    file = await UserWorkspaceFile.prisma().find_first(
         where={
             "workspaceId": workspace_id,
             "path": path,
             "isDeleted": False,
         }
     )
+    return WorkspaceFile.from_db(file) if file else None


 async def list_workspace_files(

@@ -160,7 +220,7 @@ async def list_workspace_files(
     include_deleted: bool = False,
     limit: Optional[int] = None,
     offset: int = 0,
-) -> list[UserWorkspaceFile]:
+) -> list[WorkspaceFile]:
     """
     List files in a workspace.

@@ -172,7 +232,7 @@ async def list_workspace_files(
         offset: Number of files to skip

     Returns:
-        List of UserWorkspaceFile instances
+        List of WorkspaceFile instances
     """
     where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}

@@ -185,12 +245,13 @@ async def list_workspace_files(
         path_prefix = f"/{path_prefix}"
         where_clause["path"] = {"startswith": path_prefix}

-    return await UserWorkspaceFile.prisma().find_many(
+    files = await UserWorkspaceFile.prisma().find_many(
         where=where_clause,
         order={"createdAt": "desc"},
         take=limit,
         skip=offset,
     )
+    return [WorkspaceFile.from_db(f) for f in files]


 async def count_workspace_files(

@@ -209,7 +270,7 @@ async def count_workspace_files(
     Returns:
         Number of files
     """
-    where_clause: dict = {"workspaceId": workspace_id}
+    where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id}
     if not include_deleted:
         where_clause["isDeleted"] = False

@@ -224,8 +285,8 @@ async def count_workspace_files(

 async def soft_delete_workspace_file(
|
||||||
file_id: str,
|
file_id: str,
|
||||||
workspace_id: Optional[str] = None,
|
workspace_id: str,
|
||||||
) -> Optional[UserWorkspaceFile]:
|
) -> Optional[WorkspaceFile]:
|
||||||
"""
|
"""
|
||||||
Soft-delete a workspace file.
|
Soft-delete a workspace file.
|
||||||
|
|
||||||
@@ -234,10 +295,10 @@ async def soft_delete_workspace_file(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_id: The file ID
|
file_id: The file ID
|
||||||
workspace_id: Optional workspace ID for validation
|
workspace_id: Workspace ID for scoping (required)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated UserWorkspaceFile instance or None if not found
|
Updated WorkspaceFile instance or None if not found
|
||||||
"""
|
"""
|
||||||
# First verify the file exists and belongs to workspace
|
# First verify the file exists and belongs to workspace
|
||||||
file = await get_workspace_file(file_id, workspace_id)
|
file = await get_workspace_file(file_id, workspace_id)
|
||||||
@@ -259,7 +320,7 @@ async def soft_delete_workspace_file(
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Soft-deleted workspace file {file_id}")
|
logger.info(f"Soft-deleted workspace file {file_id}")
|
||||||
return updated
|
return WorkspaceFile.from_db(updated) if updated else None
|
||||||
|
|
||||||
|
|
||||||
async def get_workspace_total_size(workspace_id: str) -> int:
|
async def get_workspace_total_size(workspace_id: str) -> int:
|
||||||
@@ -273,4 +334,4 @@ async def get_workspace_total_size(workspace_id: str) -> int:
|
|||||||
Total size in bytes
|
Total size in bytes
|
||||||
"""
|
"""
|
||||||
files = await list_workspace_files(workspace_id)
|
files = await list_workspace_files(workspace_id)
|
||||||
return sum(file.sizeBytes for file in files)
|
return sum(file.size_bytes for file in files)
|
||||||
|
|||||||
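The hunks above stop returning raw Prisma rows (UserWorkspace, UserWorkspaceFile) from the data layer and instead wrap them in Workspace/WorkspaceFile models via a from_db constructor, which also maps camelCase columns (sizeBytes, mimeType, storagePath) to snake_case. A minimal sketch of what such a wrapper could look like; only the fields visible in these diffs are used, everything else is an assumption, not the project's actual model:

    # Hypothetical shape of the wrapper model implied by the diff above;
    # field names beyond size_bytes / mime_type / storage_path are guesses.
    from pydantic import BaseModel
    from prisma.models import UserWorkspaceFile


    class WorkspaceFile(BaseModel):
        id: str
        path: str
        size_bytes: int
        mime_type: str
        storage_path: str

        @classmethod
        def from_db(cls, file: UserWorkspaceFile) -> "WorkspaceFile":
            # Map Prisma's camelCase columns onto the snake_case API model
            return cls(
                id=file.id,
                path=file.path,
                size_bytes=file.sizeBytes,
                mime_type=file.mimeType,
                storage_path=file.storagePath,
            )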
@@ -1,5 +1,5 @@
 from backend.app import run_processes
-from backend.executor import DatabaseManager
+from backend.data.db_manager import DatabaseManager


 def main():
@@ -1,11 +1,7 @@
-from .database import DatabaseManager, DatabaseManagerAsyncClient, DatabaseManagerClient
 from .manager import ExecutionManager
 from .scheduler import Scheduler

 __all__ = [
-"DatabaseManager",
-"DatabaseManagerClient",
-"DatabaseManagerAsyncClient",
 "ExecutionManager",
 "Scheduler",
 ]
@@ -22,7 +22,7 @@ from backend.util.settings import Settings
 from backend.util.truncate import truncate

 if TYPE_CHECKING:
-from backend.executor import DatabaseManagerAsyncClient
+from backend.data.db_manager import DatabaseManagerAsyncClient

 logger = logging.getLogger(__name__)

@@ -4,7 +4,7 @@ import logging
 from typing import TYPE_CHECKING, Any, Literal

 if TYPE_CHECKING:
-from backend.executor import DatabaseManagerAsyncClient
+from backend.data.db_manager import DatabaseManagerAsyncClient

 from pydantic import ValidationError

@@ -1,6 +1,7 @@
 """Redis-based distributed locking for cluster coordination."""

 import logging
+import threading
 import time
 from typing import TYPE_CHECKING

@@ -19,6 +20,7 @@ class ClusterLock:
 self.owner_id = owner_id
 self.timeout = timeout
 self._last_refresh = 0.0
+self._refresh_lock = threading.Lock()

 def try_acquire(self) -> str | None:
 """Try to acquire the lock.
@@ -31,6 +33,7 @@ class ClusterLock:
 try:
 success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
 if success:
+with self._refresh_lock:
 self._last_refresh = time.time()
 return self.owner_id # Successfully acquired

@@ -57,22 +60,26 @@ class ClusterLock:
 Rate limited to at most once every timeout/10 seconds (minimum 1 second).
 During rate limiting, still verifies lock existence but skips TTL extension.
 Setting _last_refresh to 0 bypasses rate limiting for testing.

+Thread-safe: uses _refresh_lock to protect _last_refresh access.
 """
 # Calculate refresh interval: max(timeout // 10, 1)
 refresh_interval = max(self.timeout // 10, 1)
 current_time = time.time()

-# Check if we're within the rate limit period
+# Check if we're within the rate limit period (thread-safe read)
 # _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
+with self._refresh_lock:
+last_refresh = self._last_refresh
 is_rate_limited = (
-self._last_refresh > 0
-and (current_time - self._last_refresh) < refresh_interval
+last_refresh > 0 and (current_time - last_refresh) < refresh_interval
 )

 try:
 # Always verify lock existence, even during rate limiting
 current_value = self.redis.get(self.key)
 if not current_value:
+with self._refresh_lock:
 self._last_refresh = 0
 return False

@@ -82,6 +89,7 @@ class ClusterLock:
 else str(current_value)
 )
 if stored_owner != self.owner_id:
+with self._refresh_lock:
 self._last_refresh = 0
 return False

@@ -91,19 +99,23 @@ class ClusterLock:

 # Perform actual refresh
 if self.redis.expire(self.key, self.timeout):
+with self._refresh_lock:
 self._last_refresh = current_time
 return True

+with self._refresh_lock:
 self._last_refresh = 0
 return False

 except Exception as e:
 logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
+with self._refresh_lock:
 self._last_refresh = 0
 return False

 def release(self):
 """Release the lock."""
+with self._refresh_lock:
 if self._last_refresh == 0:
 return

@@ -112,4 +124,5 @@ class ClusterLock:
 except Exception:
 pass

+with self._refresh_lock:
 self._last_refresh = 0.0
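The ClusterLock hunks above wrap every read and write of _last_refresh in a threading.Lock, so a refresh running on a background thread cannot race the owner thread. A stripped-down, self-contained illustration of that pattern (not the project's class; the Redis wiring is omitted):

    import threading
    import time


    class RefreshTimestamp:
        """Share a 'last refreshed' time between threads without races."""

        def __init__(self) -> None:
            self._last_refresh = 0.0
            self._lock = threading.Lock()

        def mark(self) -> None:
            # Writer side: update the timestamp under the lock
            with self._lock:
                self._last_refresh = time.time()

        def is_rate_limited(self, interval: float) -> bool:
            # Reader side: copy the value under the lock, then compute outside it
            with self._lock:
                last = self._last_refresh
            return last > 0 and (time.time() - last) < interval


    ts = RefreshTimestamp()
    ts.mark()
    print(ts.is_rate_limited(10))  # True right after mark(); False once 10s have passed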
@@ -93,7 +93,10 @@ from .utils import (
 )

 if TYPE_CHECKING:
-from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
+from backend.data.db_manager import (
+DatabaseManagerAsyncClient,
+DatabaseManagerClient,
+)


 _logger = logging.getLogger(__name__)
@@ -13,12 +13,15 @@ if TYPE_CHECKING:
 from openai import AsyncOpenAI
 from supabase import AClient, Client

+from backend.data.db_manager import (
+DatabaseManagerAsyncClient,
+DatabaseManagerClient,
+)
 from backend.data.execution import (
 AsyncRedisExecutionEventBus,
 RedisExecutionEventBus,
 )
 from backend.data.rabbitmq import AsyncRabbitMQ, SyncRabbitMQ
-from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
 from backend.executor.scheduler import SchedulerClient
 from backend.integrations.credentials_store import IntegrationCredentialsStore
 from backend.notifications.notifications import NotificationManagerClient
@@ -27,7 +30,7 @@ if TYPE_CHECKING:
 @thread_cached
 def get_database_manager_client() -> "DatabaseManagerClient":
 """Get a thread-cached DatabaseManagerClient with request retry enabled."""
-from backend.executor import DatabaseManagerClient
+from backend.data.db_manager import DatabaseManagerClient
 from backend.util.service import get_service_client

 return get_service_client(DatabaseManagerClient, request_retry=True)
@@ -38,7 +41,7 @@ def get_database_manager_async_client(
 should_retry: bool = True,
 ) -> "DatabaseManagerAsyncClient":
 """Get a thread-cached DatabaseManagerAsyncClient with request retry enabled."""
-from backend.executor import DatabaseManagerAsyncClient
+from backend.data.db_manager import DatabaseManagerAsyncClient
 from backend.util.service import get_service_client

 return get_service_client(DatabaseManagerAsyncClient, request_retry=should_retry)
@@ -106,6 +109,20 @@ async def get_async_execution_queue() -> "AsyncRabbitMQ":
 return client


+# ============ CoPilot Queue Helpers ============ #
+
+
+@thread_cached
+async def get_async_copilot_queue() -> "AsyncRabbitMQ":
+"""Get a thread-cached AsyncRabbitMQ CoPilot queue client."""
+from backend.copilot.executor.utils import create_copilot_queue_config
+from backend.data.rabbitmq import AsyncRabbitMQ
+
+client = AsyncRabbitMQ(create_copilot_queue_config())
+await client.connect()
+return client
+
+
 # ============ Integration Credentials Store ============ #


@@ -383,7 +383,7 @@ async def store_media_file(
 else:
 info = await workspace_manager.get_file_info(ws.file_ref)
 if info:
-return MediaFileType(f"{file}#{info.mimeType}")
+return MediaFileType(f"{file}#{info.mime_type}")
 except Exception:
 pass
 return MediaFileType(file)
@@ -397,7 +397,7 @@ async def store_media_file(
 filename=filename,
 overwrite=True,
 )
-return MediaFileType(f"workspace://{file_record.id}#{file_record.mimeType}")
+return MediaFileType(f"workspace://{file_record.id}#{file_record.mime_type}")

 else:
 raise ValueError(f"Invalid return_format: {return_format}")
@@ -28,7 +28,7 @@ from typing import (
|
|||||||
import httpx
|
import httpx
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, Request, responses
|
from fastapi import FastAPI, Request, responses
|
||||||
from prisma.errors import DataError
|
from prisma.errors import DataError, UniqueViolationError
|
||||||
from pydantic import BaseModel, TypeAdapter, create_model
|
from pydantic import BaseModel, TypeAdapter, create_model
|
||||||
|
|
||||||
import backend.util.exceptions as exceptions
|
import backend.util.exceptions as exceptions
|
||||||
@@ -201,6 +201,7 @@ EXCEPTION_MAPPING = {
|
|||||||
UnhealthyServiceError,
|
UnhealthyServiceError,
|
||||||
HTTPClientError,
|
HTTPClientError,
|
||||||
HTTPServerError,
|
HTTPServerError,
|
||||||
|
UniqueViolationError,
|
||||||
*[
|
*[
|
||||||
ErrorType
|
ErrorType
|
||||||
for _, ErrorType in inspect.getmembers(exceptions)
|
for _, ErrorType in inspect.getmembers(exceptions)
|
||||||
@@ -416,6 +417,9 @@ class AppService(BaseAppService, ABC):
|
|||||||
self.fastapi_app.add_exception_handler(
|
self.fastapi_app.add_exception_handler(
|
||||||
DataError, self._handle_internal_http_error(400)
|
DataError, self._handle_internal_http_error(400)
|
||||||
)
|
)
|
||||||
|
self.fastapi_app.add_exception_handler(
|
||||||
|
UniqueViolationError, self._handle_internal_http_error(400)
|
||||||
|
)
|
||||||
self.fastapi_app.add_exception_handler(
|
self.fastapi_app.add_exception_handler(
|
||||||
Exception, self._handle_internal_http_error(500)
|
Exception, self._handle_internal_http_error(500)
|
||||||
)
|
)
|
||||||
@@ -478,6 +482,7 @@ def get_service_client(
|
|||||||
# Don't retry these specific exceptions that won't be fixed by retrying
|
# Don't retry these specific exceptions that won't be fixed by retrying
|
||||||
ValueError, # Invalid input/parameters
|
ValueError, # Invalid input/parameters
|
||||||
DataError, # Prisma data integrity errors (foreign key, unique constraints)
|
DataError, # Prisma data integrity errors (foreign key, unique constraints)
|
||||||
|
UniqueViolationError, # Unique constraint violations
|
||||||
KeyError, # Missing required data
|
KeyError, # Missing required data
|
||||||
TypeError, # Wrong data types
|
TypeError, # Wrong data types
|
||||||
AttributeError, # Missing attributes
|
AttributeError, # Missing attributes
|
||||||
|
|||||||
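These hunks register Prisma's UniqueViolationError alongside DataError so duplicate-key writes surface as HTTP 400s rather than generic 500s, and exclude it from request retries since retrying cannot fix a constraint violation. The same mapping idea in a bare FastAPI app; this is a simplified stand-in, not the project's AppService:

    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse
    from prisma.errors import UniqueViolationError

    app = FastAPI()


    @app.exception_handler(UniqueViolationError)
    async def handle_unique_violation(request: Request, exc: UniqueViolationError):
        # A duplicate-key insert is a client-side data problem, not a server fault
        return JSONResponse(status_code=400, content={"detail": str(exc)})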
@@ -211,16 +211,23 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
|||||||
description="The port for execution manager daemon to run on",
|
description="The port for execution manager daemon to run on",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
num_copilot_workers: int = Field(
|
||||||
|
default=5,
|
||||||
|
ge=1,
|
||||||
|
le=100,
|
||||||
|
description="Number of concurrent CoPilot executor workers",
|
||||||
|
)
|
||||||
|
|
||||||
|
copilot_executor_port: int = Field(
|
||||||
|
default=8008,
|
||||||
|
description="The port for CoPilot executor daemon to run on",
|
||||||
|
)
|
||||||
|
|
||||||
execution_scheduler_port: int = Field(
|
execution_scheduler_port: int = Field(
|
||||||
default=8003,
|
default=8003,
|
||||||
description="The port for execution scheduler daemon to run on",
|
description="The port for execution scheduler daemon to run on",
|
||||||
)
|
)
|
||||||
|
|
||||||
agent_server_port: int = Field(
|
|
||||||
default=8004,
|
|
||||||
description="The port for agent server daemon to run on",
|
|
||||||
)
|
|
||||||
|
|
||||||
database_api_port: int = Field(
|
database_api_port: int = Field(
|
||||||
default=8005,
|
default=8005,
|
||||||
description="The port for database server API to run on",
|
description="The port for database server API to run on",
|
||||||
@@ -662,7 +669,7 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
|||||||
mem0_api_key: str = Field(default="", description="Mem0 API key")
|
mem0_api_key: str = Field(default="", description="Mem0 API key")
|
||||||
elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")
|
elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")
|
||||||
|
|
||||||
linear_api_key: str = Field(
|
copilot_linear_api_key: str = Field(
|
||||||
default="", description="Linear API key for system-level operations"
|
default="", description="Linear API key for system-level operations"
|
||||||
)
|
)
|
||||||
linear_feature_request_project_id: str = Field(
|
linear_feature_request_project_id: str = Field(
|
||||||
|
|||||||
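The new num_copilot_workers setting is bounded with ge=1, le=100, so a misconfigured worker count fails validation at startup rather than at runtime. A minimal illustration of how pydantic enforces those bounds; this is a standalone model, not the project's Config class:

    from pydantic import BaseModel, Field, ValidationError


    class CopilotSettings(BaseModel):
        num_copilot_workers: int = Field(default=5, ge=1, le=100)
        copilot_executor_port: int = Field(default=8008)


    print(CopilotSettings().num_copilot_workers)  # 5

    try:
        CopilotSettings(num_copilot_workers=0)   # below ge=1
    except ValidationError:
        print("rejected out-of-range worker count")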
@@ -11,6 +11,7 @@ from backend.api.rest_api import AgentServer
 from backend.blocks._base import Block, BlockSchema
 from backend.data import db
 from backend.data.block import initialize_blocks
+from backend.data.db_manager import DatabaseManager
 from backend.data.execution import (
 ExecutionContext,
 ExecutionStatus,
@@ -19,7 +20,7 @@ from backend.data.execution import (
 )
 from backend.data.model import _BaseCredentials
 from backend.data.user import create_default_user
-from backend.executor import DatabaseManager, ExecutionManager, Scheduler
+from backend.executor import ExecutionManager, Scheduler
 from backend.notifications.notifications import NotificationManager

 log = logging.getLogger(__name__)
@@ -11,16 +11,9 @@ import uuid
 from typing import Optional

 from prisma.errors import UniqueViolationError
-from prisma.models import UserWorkspaceFile

-from backend.data.workspace import (
-count_workspace_files,
-create_workspace_file,
-get_workspace_file,
-get_workspace_file_by_path,
-list_workspace_files,
-soft_delete_workspace_file,
-)
+from backend.data.db_accessors import workspace_db
+from backend.data.workspace import WorkspaceFile
 from backend.util.settings import Config
 from backend.util.virus_scanner import scan_content_safe
 from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage
@@ -125,13 +118,14 @@ class WorkspaceManager:
 Raises:
 FileNotFoundError: If file doesn't exist
 """
+db = workspace_db()
 resolved_path = self._resolve_path(path)
-file = await get_workspace_file_by_path(self.workspace_id, resolved_path)
+file = await db.get_workspace_file_by_path(self.workspace_id, resolved_path)
 if file is None:
 raise FileNotFoundError(f"File not found at path: {resolved_path}")

 storage = await get_workspace_storage()
-return await storage.retrieve(file.storagePath)
+return await storage.retrieve(file.storage_path)

 async def read_file_by_id(self, file_id: str) -> bytes:
 """
@@ -146,12 +140,13 @@ class WorkspaceManager:
 Raises:
 FileNotFoundError: If file doesn't exist
 """
-file = await get_workspace_file(file_id, self.workspace_id)
+db = workspace_db()
+file = await db.get_workspace_file(file_id, self.workspace_id)
 if file is None:
 raise FileNotFoundError(f"File not found: {file_id}")

 storage = await get_workspace_storage()
-return await storage.retrieve(file.storagePath)
+return await storage.retrieve(file.storage_path)

 async def write_file(
 self,
@@ -160,7 +155,7 @@ class WorkspaceManager:
 path: Optional[str] = None,
 mime_type: Optional[str] = None,
 overwrite: bool = False,
-) -> UserWorkspaceFile:
+) -> WorkspaceFile:
 """
 Write file to workspace.

@@ -175,7 +170,7 @@ class WorkspaceManager:
 overwrite: Whether to overwrite existing file at path

 Returns:
-Created UserWorkspaceFile instance
+Created WorkspaceFile instance

 Raises:
 ValueError: If file exceeds size limit or path already exists
@@ -204,8 +199,10 @@ class WorkspaceManager:
 # For overwrite=True, we let the write proceed and handle via UniqueViolationError
 # This ensures the new file is written to storage BEFORE the old one is deleted,
 # preventing data loss if the new write fails
+db = workspace_db()
+
 if not overwrite:
-existing = await get_workspace_file_by_path(self.workspace_id, path)
+existing = await db.get_workspace_file_by_path(self.workspace_id, path)
 if existing is not None:
 raise ValueError(f"File already exists at path: {path}")

@@ -232,7 +229,7 @@ class WorkspaceManager:
 # Create database record - handle race condition where another request
 # created a file at the same path between our check and create
 try:
-file = await create_workspace_file(
+file = await db.create_workspace_file(
 workspace_id=self.workspace_id,
 file_id=file_id,
 name=filename,
@@ -246,12 +243,12 @@ class WorkspaceManager:
 # Race condition: another request created a file at this path
 if overwrite:
 # Re-fetch and delete the conflicting file, then retry
-existing = await get_workspace_file_by_path(self.workspace_id, path)
+existing = await db.get_workspace_file_by_path(self.workspace_id, path)
 if existing:
 await self.delete_file(existing.id)
 # Retry the create - if this also fails, clean up storage file
 try:
-file = await create_workspace_file(
+file = await db.create_workspace_file(
 workspace_id=self.workspace_id,
 file_id=file_id,
 name=filename,
@@ -296,7 +293,7 @@ class WorkspaceManager:
 limit: Optional[int] = None,
 offset: int = 0,
 include_all_sessions: bool = False,
-) -> list[UserWorkspaceFile]:
+) -> list[WorkspaceFile]:
 """
 List files in workspace.

@@ -311,11 +308,12 @@ class WorkspaceManager:
 If False (default), only list current session's files.

 Returns:
-List of UserWorkspaceFile instances
+List of WorkspaceFile instances
 """
 effective_path = self._get_effective_path(path, include_all_sessions)
+db = workspace_db()

-return await list_workspace_files(
+return await db.list_workspace_files(
 workspace_id=self.workspace_id,
 path_prefix=effective_path,
 limit=limit,
@@ -332,20 +330,21 @@ class WorkspaceManager:
 Returns:
 True if deleted, False if not found
 """
-file = await get_workspace_file(file_id, self.workspace_id)
+db = workspace_db()
+file = await db.get_workspace_file(file_id, self.workspace_id)
 if file is None:
 return False

 # Delete from storage
 storage = await get_workspace_storage()
 try:
-await storage.delete(file.storagePath)
+await storage.delete(file.storage_path)
 except Exception as e:
 logger.warning(f"Failed to delete file from storage: {e}")
 # Continue with database soft-delete even if storage delete fails

 # Soft-delete database record
-result = await soft_delete_workspace_file(file_id, self.workspace_id)
+result = await db.soft_delete_workspace_file(file_id, self.workspace_id)
 return result is not None

 async def get_download_url(self, file_id: str, expires_in: int = 3600) -> str:
@@ -362,14 +361,15 @@ class WorkspaceManager:
 Raises:
 FileNotFoundError: If file doesn't exist
 """
-file = await get_workspace_file(file_id, self.workspace_id)
+db = workspace_db()
+file = await db.get_workspace_file(file_id, self.workspace_id)
 if file is None:
 raise FileNotFoundError(f"File not found: {file_id}")

 storage = await get_workspace_storage()
-return await storage.get_download_url(file.storagePath, expires_in)
+return await storage.get_download_url(file.storage_path, expires_in)

-async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]:
+async def get_file_info(self, file_id: str) -> Optional[WorkspaceFile]:
 """
 Get file metadata.

@@ -377,11 +377,12 @@ class WorkspaceManager:
 file_id: The file's ID

 Returns:
-UserWorkspaceFile instance or None
+WorkspaceFile instance or None
 """
-return await get_workspace_file(file_id, self.workspace_id)
+db = workspace_db()
+return await db.get_workspace_file(file_id, self.workspace_id)

-async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]:
+async def get_file_info_by_path(self, path: str) -> Optional[WorkspaceFile]:
 """
 Get file metadata by path.

@@ -392,10 +393,11 @@ class WorkspaceManager:
 path: Virtual path

 Returns:
-UserWorkspaceFile instance or None
+WorkspaceFile instance or None
 """
+db = workspace_db()
 resolved_path = self._resolve_path(path)
-return await get_workspace_file_by_path(self.workspace_id, resolved_path)
+return await db.get_workspace_file_by_path(self.workspace_id, resolved_path)

 async def get_file_count(
 self,
@@ -417,7 +419,8 @@ class WorkspaceManager:
 Number of files
 """
 effective_path = self._get_effective_path(path, include_all_sessions)
+db = workspace_db()

-return await count_workspace_files(
+return await db.count_workspace_files(
 self.workspace_id, path_prefix=effective_path
 )
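The WorkspaceManager hunks route all persistence through the workspace_db() accessor and keep the overwrite flow described in the comments: the new bytes reach storage first, and a UniqueViolationError on the path is resolved by deleting the conflicting record and retrying. A condensed sketch of that flow; the db method names mirror the accessor methods shown above, but the signatures and the storage.store call are simplified assumptions:

    from prisma.errors import UniqueViolationError


    async def overwrite_at_path(db, storage, workspace_id: str, path: str, data: bytes):
        # 1. Write the new bytes to storage before touching the old record,
        #    so a failed upload never destroys the existing file.
        storage_path = await storage.store(workspace_id, path, data)  # assumed helper

        async def create():
            return await db.create_workspace_file(
                workspace_id=workspace_id, path=path, storage_path=storage_path
            )

        try:
            return await create()
        except UniqueViolationError:
            # 2. Another writer owns the path: drop its record, then retry once.
            existing = await db.get_workspace_file_by_path(workspace_id, path)
            if existing:
                await db.soft_delete_workspace_file(existing.id, workspace_id)
            return await create()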
@@ -53,7 +53,7 @@ services:

 rabbitmq:
 <<: *agpt-services
-image: rabbitmq:management
+image: rabbitmq:4.1.4
 container_name: rabbitmq
 healthcheck:
 test: rabbitmq-diagnostics -q ping
@@ -66,7 +66,6 @@ services:
 - RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
 ports:
 - "5672:5672"
-- "15672:15672"
 clamav:
 image: clamav/clamav-debian:latest
 ports:
@@ -117,6 +117,7 @@ ws = "backend.ws:main"
|
|||||||
scheduler = "backend.scheduler:main"
|
scheduler = "backend.scheduler:main"
|
||||||
notification = "backend.notification:main"
|
notification = "backend.notification:main"
|
||||||
executor = "backend.exec:main"
|
executor = "backend.exec:main"
|
||||||
|
copilot-executor = "backend.copilot.executor.__main__:main"
|
||||||
cli = "backend.cli:main"
|
cli = "backend.cli:main"
|
||||||
format = "linter:format"
|
format = "linter:format"
|
||||||
lint = "linter:lint"
|
lint = "linter:lint"
|
||||||
|
|||||||
@@ -9,10 +9,8 @@ from unittest.mock import AsyncMock, patch

 import pytest

-from backend.api.features.chat.tools.agent_generator import core
-from backend.api.features.chat.tools.agent_generator.core import (
-AgentGeneratorNotConfiguredError,
-)
+from backend.copilot.tools.agent_generator import core
+from backend.copilot.tools.agent_generator.core import AgentGeneratorNotConfiguredError


 class TestServiceNotConfigured:
@@ -9,7 +9,7 @@ from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

-from backend.api.features.chat.tools.agent_generator import core
+from backend.copilot.tools.agent_generator import core


 class TestGetLibraryAgentsForGeneration:
@@ -31,18 +31,20 @@ class TestGetLibraryAgentsForGeneration:
 mock_response = MagicMock()
 mock_response.agents = [mock_agent]

+mock_db = MagicMock()
+mock_db.list_library_agents = AsyncMock(return_value=mock_response)
+
 with patch.object(
-core.library_db,
-"list_library_agents",
-new_callable=AsyncMock,
-return_value=mock_response,
-) as mock_list:
+core,
+"library_db",
+return_value=mock_db,
+):
 result = await core.get_library_agents_for_generation(
 user_id="user-123",
 search_query="send email",
 )

-mock_list.assert_called_once_with(
+mock_db.list_library_agents.assert_called_once_with(
 user_id="user-123",
 search_term="send email",
 page=1,
@@ -80,11 +82,13 @@ class TestGetLibraryAgentsForGeneration:
 ),
 ]

+mock_db = MagicMock()
+mock_db.list_library_agents = AsyncMock(return_value=mock_response)
+
 with patch.object(
-core.library_db,
-"list_library_agents",
-new_callable=AsyncMock,
-return_value=mock_response,
+core,
+"library_db",
+return_value=mock_db,
 ):
 result = await core.get_library_agents_for_generation(
 user_id="user-123",
@@ -101,18 +105,20 @@ class TestGetLibraryAgentsForGeneration:
 mock_response = MagicMock()
 mock_response.agents = []

+mock_db = MagicMock()
+mock_db.list_library_agents = AsyncMock(return_value=mock_response)
+
 with patch.object(
-core.library_db,
-"list_library_agents",
-new_callable=AsyncMock,
-return_value=mock_response,
-) as mock_list:
+core,
+"library_db",
+return_value=mock_db,
+):
 await core.get_library_agents_for_generation(
 user_id="user-123",
 max_results=5,
 )

-mock_list.assert_called_once_with(
+mock_db.list_library_agents.assert_called_once_with(
 user_id="user-123",
 search_term=None,
 page=1,
@@ -144,24 +150,24 @@ class TestSearchMarketplaceAgentsForGeneration:
 mock_graph.input_schema = {"type": "object"}
 mock_graph.output_schema = {"type": "object"}

+mock_store_db = MagicMock()
+mock_store_db.get_store_agents = AsyncMock(return_value=mock_response)
+
+mock_graph_db = MagicMock()
+mock_graph_db.get_store_listed_graphs = AsyncMock(
+return_value={"graph-123": mock_graph}
+)
+
 with (
-patch(
-"backend.api.features.store.db.get_store_agents",
-new_callable=AsyncMock,
-return_value=mock_response,
-) as mock_search,
-patch(
-"backend.api.features.chat.tools.agent_generator.core.get_store_listed_graphs",
-new_callable=AsyncMock,
-return_value={"graph-123": mock_graph},
-),
+patch.object(core, "store_db", return_value=mock_store_db),
+patch.object(core, "graph_db", return_value=mock_graph_db),
 ):
 result = await core.search_marketplace_agents_for_generation(
 search_query="automation",
 max_results=10,
 )

-mock_search.assert_called_once_with(
+mock_store_db.get_store_agents.assert_called_once_with(
 search_query="automation",
 page=1,
 page_size=10,
@@ -707,7 +713,7 @@ class TestExtractUuidsFromText:


 class TestGetLibraryAgentById:
-"""Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id)."""
+"""Test get_library_agent_by_id function (alias: get_library_agent_by_graph_id)."""

 @pytest.mark.asyncio
 async def test_returns_agent_when_found_by_graph_id(self):
@@ -720,12 +726,10 @@ class TestGetLibraryAgentById:
 mock_agent.input_schema = {"properties": {}}
 mock_agent.output_schema = {"properties": {}}

-with patch.object(
-core.library_db,
-"get_library_agent_by_graph_id",
-new_callable=AsyncMock,
-return_value=mock_agent,
-):
+mock_db = MagicMock()
+mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
+
+with patch.object(core, "library_db", return_value=mock_db):
 result = await core.get_library_agent_by_id("user-123", "agent-123")

 assert result is not None
@@ -743,20 +747,11 @@ class TestGetLibraryAgentById:
 mock_agent.input_schema = {"properties": {}}
 mock_agent.output_schema = {"properties": {}}

-with (
-patch.object(
-core.library_db,
-"get_library_agent_by_graph_id",
-new_callable=AsyncMock,
-return_value=None, # Not found by graph_id
-),
-patch.object(
-core.library_db,
-"get_library_agent",
-new_callable=AsyncMock,
-return_value=mock_agent, # Found by library ID
-),
-):
+mock_db = MagicMock()
+mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=None)
+mock_db.get_library_agent = AsyncMock(return_value=mock_agent)
+
+with patch.object(core, "library_db", return_value=mock_db):
 result = await core.get_library_agent_by_id("user-123", "library-id-123")

 assert result is not None
@@ -766,20 +761,13 @@ class TestGetLibraryAgentById:
 @pytest.mark.asyncio
 async def test_returns_none_when_not_found_by_either_method(self):
 """Test that None is returned when agent not found by either method."""
-with (
-patch.object(
-core.library_db,
-"get_library_agent_by_graph_id",
-new_callable=AsyncMock,
-return_value=None,
-),
-patch.object(
-core.library_db,
-"get_library_agent",
-new_callable=AsyncMock,
-side_effect=core.NotFoundError("Not found"),
-),
-):
+mock_db = MagicMock()
+mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=None)
+mock_db.get_library_agent = AsyncMock(
+side_effect=core.NotFoundError("Not found")
+)
+
+with patch.object(core, "library_db", return_value=mock_db):
 result = await core.get_library_agent_by_id("user-123", "nonexistent")

 assert result is None
@@ -787,27 +775,20 @@ class TestGetLibraryAgentById:
 @pytest.mark.asyncio
 async def test_returns_none_on_exception(self):
 """Test that None is returned when exception occurs in both lookups."""
-with (
-patch.object(
-core.library_db,
-"get_library_agent_by_graph_id",
-new_callable=AsyncMock,
-side_effect=Exception("Database error"),
-),
-patch.object(
-core.library_db,
-"get_library_agent",
-new_callable=AsyncMock,
-side_effect=Exception("Database error"),
-),
-):
+mock_db = MagicMock()
+mock_db.get_library_agent_by_graph_id = AsyncMock(
+side_effect=Exception("Database error")
+)
+mock_db.get_library_agent = AsyncMock(side_effect=Exception("Database error"))
+
+with patch.object(core, "library_db", return_value=mock_db):
 result = await core.get_library_agent_by_id("user-123", "agent-123")

 assert result is None

 @pytest.mark.asyncio
 async def test_alias_works(self):
-"""Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id."""
+"""Test that get_library_agent_by_graph_id is an alias."""
 assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id


@@ -828,20 +809,11 @@ class TestGetAllRelevantAgentsWithUuids:
 mock_response = MagicMock()
 mock_response.agents = []

-with (
-patch.object(
-core.library_db,
-"get_library_agent_by_graph_id",
-new_callable=AsyncMock,
-return_value=mock_agent,
-),
-patch.object(
-core.library_db,
-"list_library_agents",
-new_callable=AsyncMock,
-return_value=mock_response,
-),
-):
+mock_db = MagicMock()
+mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
+mock_db.list_library_agents = AsyncMock(return_value=mock_response)
+
+with patch.object(core, "library_db", return_value=mock_db):
 result = await core.get_all_relevant_agents_for_generation(
 user_id="user-123",
 search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d",
@@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import httpx
 import pytest

-from backend.api.features.chat.tools.agent_generator import service
+from backend.copilot.tools.agent_generator import service


 class TestServiceConfiguration:
@@ -1,133 +0,0 @@
-"""Tests for SDK security hooks — workspace paths, tool access, and deny messages.
-
-These are pure unit tests with no external dependencies (no SDK, no DB, no server).
-They validate that the security hooks correctly block unauthorized paths,
-tool access, and dangerous input patterns.
-
-Note: Bash command validation was removed — the SDK built-in Bash tool is not in
-allowed_tools, and the bash_exec MCP tool has kernel-level network isolation
-(unshare --net) making command-level parsing unnecessary.
-"""
-
-from backend.api.features.chat.sdk.security_hooks import (
-_validate_tool_access,
-_validate_workspace_path,
-)
-
-SDK_CWD = "/tmp/copilot-test-session"
-
-
-def _is_denied(result: dict) -> bool:
-hook = result.get("hookSpecificOutput", {})
-return hook.get("permissionDecision") == "deny"
-
-
-def _reason(result: dict) -> str:
-return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "")
-
-
-# ============================================================
-# Workspace path validation (Read, Write, Edit, etc.)
-# ============================================================
-
-
-class TestWorkspacePathValidation:
-def test_path_in_workspace(self):
-result = _validate_workspace_path(
-"Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD
-)
-assert not _is_denied(result)
-
-def test_path_outside_workspace(self):
-result = _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD)
-assert _is_denied(result)
-
-def test_tool_results_allowed(self):
-result = _validate_workspace_path(
-"Read",
-{"file_path": "~/.claude/projects/abc/tool-results/out.txt"},
-SDK_CWD,
-)
-assert not _is_denied(result)
-
-def test_claude_settings_blocked(self):
-result = _validate_workspace_path(
-"Read", {"file_path": "~/.claude/settings.json"}, SDK_CWD
-)
-assert _is_denied(result)
-
-def test_claude_projects_without_tool_results(self):
-result = _validate_workspace_path(
-"Read", {"file_path": "~/.claude/projects/abc/credentials.json"}, SDK_CWD
-)
-assert _is_denied(result)
-
-def test_no_path_allowed(self):
-"""Glob/Grep without path defaults to cwd — should be allowed."""
-result = _validate_workspace_path("Grep", {"pattern": "foo"}, SDK_CWD)
-assert not _is_denied(result)
-
-def test_path_traversal_with_dotdot(self):
-result = _validate_workspace_path(
-"Read", {"file_path": f"{SDK_CWD}/../../../etc/passwd"}, SDK_CWD
-)
-assert _is_denied(result)
-
-
-# ============================================================
-# Tool access validation
-# ============================================================
-
-
-class TestToolAccessValidation:
-def test_blocked_tools(self):
-for tool in ("bash", "shell", "exec", "terminal", "command"):
-result = _validate_tool_access(tool, {})
-assert _is_denied(result), f"Tool '{tool}' should be blocked"
-
-def test_bash_builtin_blocked(self):
-"""SDK built-in Bash (capital) is blocked as defence-in-depth."""
-result = _validate_tool_access("Bash", {"command": "echo hello"}, SDK_CWD)
-assert _is_denied(result)
-assert "Bash" in _reason(result)
-
-def test_workspace_tools_delegate(self):
-result = _validate_tool_access(
-"Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD
-)
-assert not _is_denied(result)
-
-def test_dangerous_pattern_blocked(self):
-result = _validate_tool_access("SomeUnknownTool", {"data": "sudo rm -rf /"})
-assert _is_denied(result)
-
-def test_safe_unknown_tool_allowed(self):
-result = _validate_tool_access("SomeSafeTool", {"data": "hello world"})
-assert not _is_denied(result)
-
-
-# ============================================================
-# Deny message quality (ntindle feedback)
-# ============================================================
-
-
-class TestDenyMessageClarity:
-"""Deny messages must include [SECURITY] and 'cannot be bypassed'
-so the model knows the restriction is enforced, not a suggestion."""
-
-def test_blocked_tool_message(self):
-reason = _reason(_validate_tool_access("bash", {}))
-assert "[SECURITY]" in reason
-assert "cannot be bypassed" in reason
-
-def test_bash_builtin_blocked_message(self):
-reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"}))
-assert "[SECURITY]" in reason
-assert "cannot be bypassed" in reason
-
-def test_workspace_path_message(self):
-reason = _reason(
-_validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD)
-)
-assert "[SECURITY]" in reason
-assert "cannot be bypassed" in reason
@@ -75,7 +75,7 @@ services:
       timeout: 5s
       retries: 5
   rabbitmq:
-    image: rabbitmq:management
+    image: rabbitmq:4.1.4
     container_name: rabbitmq
     healthcheck:
       test: rabbitmq-diagnostics -q ping
@@ -88,14 +88,13 @@ services:
       <<: *backend-env
     ports:
       - "5672:5672"
-      - "15672:15672"

   rest_server:
     build:
       context: ../
       dockerfile: autogpt_platform/backend/Dockerfile
       target: server
-    command: ["python", "-m", "backend.rest"]
+    command: ["rest"] # points to entry in [tool.poetry.scripts] in pyproject.toml
     develop:
       watch:
         - path: ./
@@ -128,7 +127,7 @@ services:
       context: ../
       dockerfile: autogpt_platform/backend/Dockerfile
       target: server
-    command: ["python", "-m", "backend.exec"]
+    command: ["executor"] # points to entry in [tool.poetry.scripts] in pyproject.toml
     develop:
       watch:
         - path: ./
@@ -158,12 +157,47 @@ services:
         max-size: "10m"
         max-file: "3"

+  copilot_executor:
+    build:
+      context: ../
+      dockerfile: autogpt_platform/backend/Dockerfile
+      target: server
+    command: ["python", "-m", "backend.copilot.executor"]
+    develop:
+      watch:
+        - path: ./
+          target: autogpt_platform/backend/
+          action: rebuild
+    depends_on:
+      redis:
+        condition: service_healthy
+      rabbitmq:
+        condition: service_healthy
+      db:
+        condition: service_healthy
+      migrate:
+        condition: service_completed_successfully
+      database_manager:
+        condition: service_started
+    <<: *backend-env-files
+    environment:
+      <<: *backend-env
+    ports:
+      - "8008:8008"
+    networks:
+      - app-network
+    logging:
+      driver: json-file
+      options:
+        max-size: "10m"
+        max-file: "3"
+
   websocket_server:
     build:
       context: ../
       dockerfile: autogpt_platform/backend/Dockerfile
       target: server
-    command: ["python", "-m", "backend.ws"]
+    command: ["ws"] # points to entry in [tool.poetry.scripts] in pyproject.toml
     develop:
       watch:
         - path: ./
@@ -196,7 +230,7 @@ services:
       context: ../
       dockerfile: autogpt_platform/backend/Dockerfile
       target: server
-    command: ["python", "-m", "backend.db"]
+    command: ["db"] # points to entry in [tool.poetry.scripts] in pyproject.toml
     develop:
       watch:
         - path: ./
@@ -225,7 +259,7 @@ services:
       context: ../
       dockerfile: autogpt_platform/backend/Dockerfile
       target: server
-    command: ["python", "-m", "backend.scheduler"]
+    command: ["scheduler"] # points to entry in [tool.poetry.scripts] in pyproject.toml
     develop:
       watch:
         - path: ./
@@ -273,7 +307,7 @@ services:
       context: ../
       dockerfile: autogpt_platform/backend/Dockerfile
       target: server
-    command: ["python", "-m", "backend.notification"]
+    command: ["notification"] # points to entry in [tool.poetry.scripts] in pyproject.toml
     develop:
       watch:
         - path: ./
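The command: ["rest"]-style entries above rely on Poetry console scripts being installed in the backend image. As a rough illustration only -- the entry-point names come from this diff, but the module:function targets below are guesses based on the old "python -m" commands and may not match the real pyproject.toml -- the corresponding [tool.poetry.scripts] section would look something like:

# Hypothetical sketch; the actual autogpt_platform/backend/pyproject.toml defines these.
[tool.poetry.scripts]
rest = "backend.rest:main"
ws = "backend.ws:main"
executor = "backend.exec:main"
db = "backend.db:main"
scheduler = "backend.scheduler:main"
notification = "backend.notification:main"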
@@ -53,6 +53,12 @@ services:
       file: ./docker-compose.platform.yml
       service: executor

+  copilot_executor:
+    <<: *agpt-services
+    extends:
+      file: ./docker-compose.platform.yml
+      service: copilot_executor
+
   websocket_server:
     <<: *agpt-services
     extends:
@@ -174,5 +180,6 @@ services:
       - deps
       - rest_server
       - executor
+      - copilot_executor
       - websocket_server
       - database_manager
@@ -62,7 +62,6 @@
     "@rjsf/validator-ajv8": "6.1.2",
     "@sentry/nextjs": "10.27.0",
     "@streamdown/cjk": "1.0.1",
-    "@streamdown/code": "1.0.1",
     "@streamdown/math": "1.0.1",
     "@streamdown/mermaid": "1.0.1",
     "@supabase/ssr": "0.7.0",
@@ -116,6 +115,7 @@
     "remark-gfm": "4.0.1",
     "remark-math": "6.0.0",
     "shepherd.js": "14.5.1",
+    "shiki": "^3.21.0",
     "sonner": "2.0.7",
     "streamdown": "2.1.0",
     "tailwind-merge": "2.6.0",
16 autogpt_platform/frontend/pnpm-lock.yaml generated
@@ -108,9 +108,6 @@ importers:
       '@streamdown/cjk':
         specifier: 1.0.1
         version: 1.0.1(@types/mdast@4.0.4)(micromark-util-types@2.0.2)(micromark@4.0.2)(react@18.3.1)(unified@11.0.5)
-      '@streamdown/code':
-        specifier: 1.0.1
-        version: 1.0.1(react@18.3.1)
       '@streamdown/math':
         specifier: 1.0.1
         version: 1.0.1(react@18.3.1)
@@ -270,6 +267,9 @@ importers:
       shepherd.js:
         specifier: 14.5.1
         version: 14.5.1
+      shiki:
+        specifier: ^3.21.0
+        version: 3.21.0
       sonner:
         specifier: 2.0.7
         version: 2.0.7(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
@@ -3307,11 +3307,6 @@ packages:
     peerDependencies:
       react: ^18.0.0 || ^19.0.0

-  '@streamdown/code@1.0.1':
-    resolution: {integrity: sha512-U9LITfQ28tZYAoY922jdtw1ryg4kgRBdURopqK9hph7G2fBUwPeHthjH7SvaV0fvFv7EqjqCzARJuWUljLe9Ag==}
-    peerDependencies:
-      react: ^18.0.0 || ^19.0.0
-
   '@streamdown/math@1.0.1':
     resolution: {integrity: sha512-R9WdHbpERiRU7WeO7oT1aIbnLJ/jraDr89F7X9x2OM//Y8G8UMATRnLD/RUwg4VLr8Nu7QSIJ0Pa8lXd2meM4Q==}
     peerDependencies:
@@ -11907,11 +11902,6 @@ snapshots:
       - micromark-util-types
       - unified

-  '@streamdown/code@1.0.1(react@18.3.1)':
-    dependencies:
-      react: 18.3.1
-      shiki: 3.21.0
-
   '@streamdown/math@1.0.1(react@18.3.1)':
     dependencies:
       katex: 0.16.28
@@ -1,8 +1,16 @@
 "use client";

+import {
+  DropdownMenu,
+  DropdownMenuContent,
+  DropdownMenuItem,
+  DropdownMenuTrigger,
+} from "@/components/molecules/DropdownMenu/DropdownMenu";
 import { SidebarProvider } from "@/components/ui/sidebar";
+import { DotsThree } from "@phosphor-icons/react";
 import { ChatContainer } from "./components/ChatContainer/ChatContainer";
 import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
+import { DeleteChatDialog } from "./components/DeleteChatDialog/DeleteChatDialog";
 import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
 import { MobileHeader } from "./components/MobileHeader/MobileHeader";
 import { ScaleLoader } from "./components/ScaleLoader/ScaleLoader";
@@ -31,6 +39,12 @@ export function CopilotPage() {
     handleDrawerOpenChange,
     handleSelectSession,
     handleNewChat,
+    // Delete functionality
+    sessionToDelete,
+    isDeleting,
+    handleDeleteClick,
+    handleConfirmDelete,
+    handleCancelDelete,
   } = useCopilotPage();

   if (isUserLoading || !isLoggedIn) {
@@ -60,6 +74,38 @@ export function CopilotPage() {
             onCreateSession={createSession}
             onSend={onSend}
             onStop={stop}
+            headerSlot={
+              isMobile && sessionId ? (
+                <div className="flex justify-end">
+                  <DropdownMenu>
+                    <DropdownMenuTrigger asChild>
+                      <button
+                        className="rounded p-1.5 hover:bg-neutral-100"
+                        aria-label="More actions"
+                      >
+                        <DotsThree className="h-5 w-5 text-neutral-600" />
+                      </button>
+                    </DropdownMenuTrigger>
+                    <DropdownMenuContent align="end">
+                      <DropdownMenuItem
+                        onClick={() => {
+                          const session = sessions.find(
+                            (s) => s.id === sessionId,
+                          );
+                          if (session) {
+                            handleDeleteClick(session.id, session.title);
+                          }
+                        }}
+                        disabled={isDeleting}
+                        className="text-red-600 focus:bg-red-50 focus:text-red-600"
+                      >
+                        Delete chat
+                      </DropdownMenuItem>
+                    </DropdownMenuContent>
+                  </DropdownMenu>
+                </div>
+              ) : undefined
+            }
           />
         </div>
       </div>
@@ -75,6 +121,15 @@ export function CopilotPage() {
          onOpenChange={handleDrawerOpenChange}
        />
      )}
+      {/* Delete confirmation dialog - rendered at top level for proper z-index on mobile */}
+      {isMobile && (
+        <DeleteChatDialog
+          session={sessionToDelete}
+          isDeleting={isDeleting}
+          onConfirm={handleConfirmDelete}
+          onCancel={handleCancelDelete}
+        />
+      )}
     </SidebarProvider>
   );
 }
@@ -2,6 +2,7 @@
 import { ChatInput } from "@/app/(platform)/copilot/components/ChatInput/ChatInput";
 import { UIDataTypes, UIMessage, UITools } from "ai";
 import { LayoutGroup, motion } from "framer-motion";
+import { ReactNode } from "react";
 import { ChatMessagesContainer } from "../ChatMessagesContainer/ChatMessagesContainer";
 import { CopilotChatActionsProvider } from "../CopilotChatActionsProvider/CopilotChatActionsProvider";
 import { EmptySession } from "../EmptySession/EmptySession";
@@ -16,6 +17,7 @@ export interface ChatContainerProps {
   onCreateSession: () => void | Promise<string>;
   onSend: (message: string) => void | Promise<void>;
   onStop: () => void;
+  headerSlot?: ReactNode;
 }
 export const ChatContainer = ({
   messages,
@@ -27,6 +29,7 @@ export const ChatContainer = ({
   onCreateSession,
   onSend,
   onStop,
+  headerSlot,
 }: ChatContainerProps) => {
   const inputLayoutId = "copilot-2-chat-input";

@@ -41,6 +44,7 @@ export const ChatContainer = ({
       status={status}
       error={error}
       isLoading={isLoadingSession}
+      headerSlot={headerSlot}
     />
     <motion.div
       initial={{ opacity: 0 }}
Some files were not shown because too many files have changed in this diff.