Compare commits

..

47 Commits

Author SHA1 Message Date
Reinier van der Leer
17b4914140 Merge branch 'dev' into pwuts/open-2995-refactor-move-copilot-ai-generation-tool-execution-to 2026-02-16 14:50:04 +01:00
Reinier van der Leer
9d4dcbd9e0 fix(backend/docker): Make server last (= default) build stage
Without specifying an explicit build target it would build the `migrate` stage because it is the last stage in the Dockerfile. This caused deployment failures.

- Follow-up to #12124 and 074be7ae
2026-02-16 14:49:30 +01:00
Reinier van der Leer
1901861cbc Merge branch 'dev' into pwuts/open-2995-refactor-move-copilot-ai-generation-tool-execution-to 2026-02-16 14:26:37 +01:00
Reinier van der Leer
074be7aea6 fix(backend/docker): Update run commands to match deployment
- Follow-up to #12124

Changes:
- Update `run` commands for all backend services in `docker-compose.platform.yml` to match the deployment commands used in production
- Add trigger on `docker-compose(.platform)?.yml` changes to the Frontend CI workflow
2026-02-16 14:23:29 +01:00
Otto
39d28b24fc ci(backend): Upgrade RabbitMQ from 3.12 (EOL) to 4.1.4 (#12118)
## Summary
Upgrades RabbitMQ from the end-of-life `rabbitmq:3.12-management` to
`rabbitmq:4.1.4`, aligning CI, local dev, and e2e testing with
production.

## Changes

### CI Workflow (`.github/workflows/platform-backend-ci.yml`)
- **Image:** `rabbitmq:3.12-management` → `rabbitmq:4.1.4`
- **Port:** Removed 15672 (management UI) — not used
- **Health check:** Added to prevent flaky tests from race conditions
during startup

### Docker Compose (`docker-compose.platform.yml`,
`docker-compose.test.yaml`)
- **Image:** `rabbitmq:management` → `rabbitmq:4.1.4`
- **Port:** Removed 15672 (management UI) — not used

## Why
- RabbitMQ 3.12 is EOL
- We don't use the management interface, so `-management` variant is
unnecessary
- CI and local dev/e2e should match production (4.1.4)

## Testing
CI validates that backend tests pass against RabbitMQ 4.1.4 on Python
3.11, 3.12, and 3.13.

---
Closes SECRT-1703
2026-02-16 12:45:39 +00:00
Reinier van der Leer
bf79a7748a fix(backend/build): Update stale Poetry usage in Dockerfile (#12124)
[SECRT-2006: Dev deployment failing: poetry not found in container
PATH](https://linear.app/autogpt/issue/SECRT-2006)

- Follow-up to #12090

### Changes 🏗️

- Remove now-broken Poetry path config values
- Remove usage of now-broken `poetry run` in container run command
- Add trigger on `backend/Dockerfile` changes to Frontend CI workflow

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - If it works, CI will pass
2026-02-16 13:54:20 +01:00
Otto
649d4ab7f5 feat(chat): Add delete chat session endpoint and UI (#12112)
## Summary

Adds the ability to delete chat sessions from the CoPilot interface.

## Changes

### Backend
- Add `DELETE /api/chat/sessions/{session_id}` endpoint in `routes.py`
- Returns 204 on success, 404 if not found or not owned by user
- Reuses existing `delete_chat_session` function from `model.py`

### Frontend
- Add delete button (trash icon) that appears on hover for each chat
session
- Add confirmation dialog before deletion using existing
`DeleteConfirmDialog` component
- Refresh session list after successful delete
- Clear current session selection if the deleted session was active
- Update OpenAPI spec with new endpoint

## Testing

1. Hover over a chat session in sidebar → trash icon appears
2. Click trash icon → confirmation dialog
3. Confirm deletion → session removed, list refreshes
4. If deleted session was active, selection is cleared

## Screenshots

Delete button appears on hover, confirmation dialog on click.

## Related Issues

Closes SECRT-1928

<!-- greptile_comment -->

<h2>Greptile Overview</h2>

<details><summary><h3>Greptile Summary</h3></summary>

Adds the ability to delete chat sessions from the CoPilot interface — a
new `DELETE /api/chat/sessions/{session_id}` backend endpoint and a
corresponding delete button with confirmation dialog in the
`ChatSidebar` frontend component.

- **Backend route** (`routes.py`): Clean implementation reusing the
existing `delete_chat_session` model function with proper auth guards
and 204/404 responses. No issues.
- **Frontend** (`ChatSidebar.tsx`): Adds hover-visible trash icon per
session, confirmation dialog, mutation with cache invalidation, and
active session clearing on delete. However, it uses a `__legacy__`
component (`DeleteConfirmDialog`) which violates the project's style
guide — new code should use the modern design system components. Error
handling only logs to console without user-facing feedback (project
convention is to use toast notifications for mutation errors).
`isDeleting` is destructured but unused.
- **OpenAPI spec** updated correctly.
- **Unrelated file included**:
`notes/plan-SECRT-1959-graph-edge-desync.md` is a planning document for
a different ticket and should be removed from this PR. The `notes/`
directory is newly introduced and both plan files should be reconsidered
for inclusion.
</details>


<details><summary><h3>Confidence Score: 3/5</h3></summary>

- Functionally correct but has style guide violations and includes
unrelated files that should be addressed before merge.
- The core feature implementation (backend DELETE endpoint and frontend
mutation logic) is sound and follows existing patterns. Score is lowered
because: (1) the frontend uses a legacy component explicitly prohibited
by the project's style guide, (2) mutation errors are not surfaced to
the user, and (3) the PR includes an unrelated planning document for a
different ticket.
- Pay close attention to `ChatSidebar.tsx` for the legacy component
import and error handling, and
`notes/plan-SECRT-1959-graph-edge-desync.md` which should be removed.
</details>


<details><summary><h3>Sequence Diagram</h3></summary>

```mermaid
sequenceDiagram
    participant User
    participant ChatSidebar as ChatSidebar (Frontend)
    participant ReactQuery as React Query
    participant API as DELETE /api/chat/sessions/{id}
    participant Model as model.delete_chat_session
    participant DB as db.delete_chat_session (Prisma)
    participant Redis as Redis Cache

    User->>ChatSidebar: Click trash icon on session
    ChatSidebar->>ChatSidebar: Show DeleteConfirmDialog
    User->>ChatSidebar: Confirm deletion
    ChatSidebar->>ReactQuery: deleteSession({ sessionId })
    ReactQuery->>API: DELETE /api/chat/sessions/{session_id}
    API->>Model: delete_chat_session(session_id, user_id)
    Model->>DB: delete_many(where: {id, userId})
    DB-->>Model: bool (deleted count > 0)
    Model->>Redis: Delete session cache key
    Model->>Model: Clean up session lock
    Model-->>API: True
    API-->>ReactQuery: 204 No Content
    ReactQuery->>ChatSidebar: onSuccess callback
    ChatSidebar->>ReactQuery: invalidateQueries(sessions list)
    ChatSidebar->>ChatSidebar: Clear sessionId if deleted was active
```
</details>


<sub>Last reviewed commit: 44a92c6</sub>

<!-- greptile_other_comments_section -->

<details><summary><h4>Context used (3)</h4></summary>

- Context from `dashboard` - autogpt_platform/frontend/CLAUDE.md
([source](https://app.greptile.com/review/custom-context?memory=39861924-d320-41ba-a1a7-a8bff44f780a))
- Context from `dashboard` - autogpt_platform/frontend/CONTRIBUTING.md
([source](https://app.greptile.com/review/custom-context?memory=cc4f1b17-cb5c-4b63-b218-c772b48e20ee))
- Context from `dashboard` - autogpt_platform/CLAUDE.md
([source](https://app.greptile.com/review/custom-context?memory=6e9dc5dc-8942-47df-8677-e60062ec8c3a))
</details>


<!-- /greptile_comment -->

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
2026-02-16 12:19:18 +00:00
Reinier van der Leer
65e020d32b Merge branch 'dev' into pwuts/open-2995-refactor-move-copilot-ai-generation-tool-execution-to 2026-02-16 12:17:54 +01:00
Reinier van der Leer
87189e23b5 fix duplicate message appending in copilot service and sdk 2026-02-14 13:06:57 +01:00
Reinier van der Leer
35a37257a6 address thread safety comment 2026-02-13 22:57:52 +01:00
Reinier van der Leer
ee45f56310 fix tests 2026-02-13 21:35:38 +01:00
Reinier van der Leer
bfd04dcf04 address comments 2026-02-13 21:17:32 +01:00
Reinier van der Leer
608db31508 address minor comments 2026-02-13 18:40:44 +01:00
Reinier van der Leer
62db72500e Merge commit '9ddcaa884c86ca2bed0735ef37b2a711f3f20755' into pwuts/open-2995-refactor-move-copilot-ai-generation-tool-execution-to 2026-02-13 18:26:56 +01:00
Reinier van der Leer
c9efc3f51c Merge branch 'dev' into pwuts/open-2995-refactor-move-copilot-ai-generation-tool-execution-to 2026-02-13 18:05:19 +01:00
Zamil Majdy
9ddcaa884c Merge remote-tracking branch 'origin/dev' into pwuts/open-2995-refactor-move-copilot-ai-generation-tool-execution-to 2026-02-13 20:47:24 +04:00
Zamil Majdy
b3173ed91f Merge branch 'dev' and integrate SDK into copilot microservice
- Resolve merge conflicts from merged SDK changes (PR #12103)
- Move sdk/ files from api/features/chat/sdk/ to copilot/sdk/
- Fix all imports to use backend.copilot.* paths
- Move new tools (bash_exec, sandbox, web_fetch, feature_requests,
  check_operation_status) to copilot/tools/ with updated imports
- Add append_and_save_message to model.py (adapted to chat_db() pattern)
- Wire SDK service into copilot executor processor with feature flag
- Add track_user_message to routes.py stream handler
2026-02-13 20:24:36 +04:00
Reinier van der Leer
648eb9638a fix bodged merge 2026-02-13 14:59:01 +01:00
Reinier van der Leer
74477bbbf3 Merge branch 'dev' into pwuts/open-2995-refactor-move-copilot-ai-generation-tool-execution-to 2026-02-13 14:39:02 +01:00
Reinier van der Leer
cabda535ea Merge branch 'dev' into pwuts/open-2995-refactor-move-copilot-ai-generation-tool-execution-to 2026-02-13 13:06:39 +01:00
Reinier van der Leer
746a36822d Merge branch 'dev' into pwuts/open-2995-refactor-move-copilot-ai-generation-tool-execution-to 2026-02-13 00:47:00 +01:00
Reinier van der Leer
2a46d3fbf4 address more comments 2026-02-12 15:57:35 +01:00
Reinier van der Leer
ab25516a46 fix _consume_run check 2026-02-12 15:23:18 +01:00
Reinier van der Leer
6e2f595c7d address comments 2026-02-12 15:10:11 +01:00
Reinier van der Leer
e523eb62b5 fix lint 2026-02-12 14:30:13 +01:00
Reinier van der Leer
97ff65ef6a fix test 2026-02-12 14:29:48 +01:00
Reinier van der Leer
e8b81f71ef fix tests 2026-02-12 14:20:02 +01:00
Reinier van der Leer
d652821ed5 Merge branch 'dev' into pwuts/open-2995-copilot-microservice-with-block-refactor 2026-02-12 13:31:04 +01:00
Reinier van der Leer
80659d90e4 Merge branch 'pwuts/move-block-base-to-fix-circular-imports' into pwuts/open-2995-copilot-microservice-with-block-refactor 2026-02-12 13:01:31 +01:00
Reinier van der Leer
eef892893c untangle some more 2026-02-12 11:23:42 +01:00
Reinier van der Leer
23175708e6 update OpenAPI schema and add back model_rebuild 2026-02-12 10:38:00 +01:00
Reinier van der Leer
f02c00374e Merge branch 'pwuts/move-block-base-to-fix-circular-imports' into pwuts/open-2995-copilot-microservice-with-block-refactor 2026-02-12 10:28:21 +01:00
Reinier van der Leer
2fa166d839 fix tests 2026-02-12 10:24:20 +01:00
Reinier van der Leer
d927e4b611 fix generate_block_docs.py 2026-02-12 10:21:10 +01:00
Reinier van der Leer
6591b2171c fixed! :) 2026-02-12 10:07:43 +01:00
Reinier van der Leer
85d97a9d5c or this? 2026-02-11 17:40:29 +01:00
Reinier van der Leer
16c8b2a6e3 maybe this works? 2026-02-11 17:31:27 +01:00
Reinier van der Leer
cad54a9f3e eliminate more cross-imports 2026-02-11 17:12:43 +01:00
Reinier van der Leer
ca0620b102 fix type import and image generator 2026-02-11 16:13:52 +01:00
Reinier van der Leer
7a4cf4e186 Merge branch 'pwuts/move-block-base-to-fix-circular-imports' into pwuts/open-2995-copilot-microservice-with-block-refactor 2026-02-11 15:53:22 +01:00
Reinier van der Leer
fe9debd80f refactor(backend/blocks): Extract backend.blocks._base from backend.data.block
I'm getting circular import issues because there is a lot of cross-importing between `backend.data`, `backend.blocks`, and other components. This change reduces block-related cross-imports and thus risk of breaking circular imports.
2026-02-11 15:38:42 +01:00
Reinier van der Leer
7083dcf226 fix part of broken tests 2026-02-11 14:54:21 +01:00
Reinier van der Leer
ee2805d14c fix(backend/copilot): Use DatabaseManager where needed 2026-02-11 13:43:58 +01:00
Reinier van der Leer
f15362d619 reduce default copilot executor pool size 2026-02-11 01:02:32 +01:00
Reinier van der Leer
6c2374593f add entrypoint to pyproject.toml 2026-02-10 23:38:21 +01:00
Reinier van der Leer
0f4c33308f add to docker compose files 2026-02-10 22:55:47 +01:00
Reinier van der Leer
ecb9fdae25 feat(backend/copilot): Copilot Executor Microservice 2026-02-10 22:48:01 +01:00
101 changed files with 2457 additions and 1198 deletions

View File

@@ -41,13 +41,18 @@ jobs:
ports:
- 6379:6379
rabbitmq:
image: rabbitmq:3.12-management
image: rabbitmq:4.1.4
ports:
- 5672:5672
- 15672:15672
env:
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
options: >-
--health-cmd "rabbitmq-diagnostics -q ping"
--health-interval 30s
--health-timeout 10s
--health-retries 5
--health-start-period 10s
clamav:
image: clamav/clamav-debian:latest
ports:

View File

@@ -6,10 +6,16 @@ on:
paths:
- ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**"
- "autogpt_platform/backend/Dockerfile"
- "autogpt_platform/docker-compose.yml"
- "autogpt_platform/docker-compose.platform.yml"
pull_request:
paths:
- ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**"
- "autogpt_platform/backend/Dockerfile"
- "autogpt_platform/docker-compose.yml"
- "autogpt_platform/docker-compose.platform.yml"
merge_group:
workflow_dispatch:

View File

@@ -53,63 +53,6 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub
# ============================== BACKEND SERVER ============================== #
FROM debian:13-slim AS server
WORKDIR /app
ENV POETRY_HOME=/opt/poetry \
POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_CREATE=true \
POETRY_VIRTUALENVS_IN_PROJECT=true \
DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
# for the bash_exec MCP tool.
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \
python3-pip \
ffmpeg \
imagemagick \
jq \
ripgrep \
tree \
bubblewrap \
&& rm -rf /var/lib/apt/lists/*
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
WORKDIR /app/autogpt_platform/backend
# Copy only the .venv from builder (not the entire /app directory)
# The .venv includes the generated Prisma client
COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
# Copy dependency files + autogpt_libs (path dependency)
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
# Copy backend code + docs (for Copilot docs search)
COPY autogpt_platform/backend ./
COPY docs /app/docs
RUN poetry install --no-ansi --only-root
ENV PORT=8000
CMD ["poetry", "run", "rest"]
# =============================== DB MIGRATOR =============================== #
# Lightweight migrate stage - only needs Prisma CLI, not full Python environment
@@ -141,3 +84,59 @@ COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
COPY autogpt_platform/backend/migrations ./migrations
# ============================== BACKEND SERVER ============================== #
FROM debian:13-slim AS server
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
# for the bash_exec MCP tool.
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \
python3-pip \
ffmpeg \
imagemagick \
jq \
ripgrep \
tree \
bubblewrap \
&& rm -rf /var/lib/apt/lists/*
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
WORKDIR /app/autogpt_platform/backend
# Copy only the .venv from builder (not the entire /app directory)
# The .venv includes the generated Prisma client
COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
# Copy dependency files + autogpt_libs (path dependency)
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
# Copy backend code + docs (for Copilot docs search)
COPY autogpt_platform/backend ./
COPY docs /app/docs
# Install the project package to create entry point scripts in .venv/bin/
# (e.g., rest, executor, ws, db, scheduler, notification - see [tool.poetry.scripts])
RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
poetry install --no-ansi --only-root
ENV PORT=8000
CMD ["rest"]

View File

@@ -1,4 +1,9 @@
"""Common test fixtures for server tests."""
"""Common test fixtures for server tests.
Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
setup_test_user, and setup_admin_user are defined in the parent conftest.py
(backend/conftest.py) and are available here automatically.
"""
import pytest
from pytest_snapshot.plugin import Snapshot
@@ -11,54 +16,6 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
return snapshot
@pytest.fixture
def test_user_id() -> str:
"""Test user ID fixture."""
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture
def admin_user_id() -> str:
"""Admin user ID fixture."""
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
@pytest.fixture
def target_user_id() -> str:
"""Target user ID fixture."""
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
@pytest.fixture
async def setup_test_user(test_user_id):
"""Create test user in database before tests."""
from backend.data.user import get_or_create_user
# Create the test user in the database using JWT token format
user_data = {
"sub": test_user_id,
"email": "test@example.com",
"user_metadata": {"name": "Test User"},
}
await get_or_create_user(user_data)
return test_user_id
@pytest.fixture
async def setup_admin_user(admin_user_id):
"""Create admin user in database before tests."""
from backend.data.user import get_or_create_user
# Create the admin user in the database using JWT token format
user_data = {
"sub": admin_user_id,
"email": "test-admin@example.com",
"user_metadata": {"name": "Test Admin"},
}
await get_or_create_user(user_data)
return admin_user_id
@pytest.fixture
def mock_jwt_user(test_user_id):
"""Provide mock JWT payload for regular user testing."""

View File

@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.api.external.middleware import require_permission
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
from backend.api.features.chat.tools.models import ToolResponseBase
from backend.copilot.model import ChatSession
from backend.copilot.tools import find_agent_tool, run_agent_tool
from backend.copilot.tools.models import ToolResponseBase
from backend.data.auth.base import APIAuthorizationInfo
logger = logging.getLogger(__name__)

View File

@@ -11,24 +11,25 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response,
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.util.exceptions import NotFoundError
from backend.util.feature_flag import Flag, is_feature_enabled
from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import (
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.completion_handler import (
process_operation_failure,
process_operation_success,
)
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_copilot_task
from backend.copilot.model import (
ChatMessage,
ChatSession,
append_and_save_message,
create_chat_session,
delete_chat_session,
get_chat_session,
get_user_sessions,
)
from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
from .sdk import service as sdk_service
from .tools.models import (
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
AgentPreviewResponse,
@@ -51,7 +52,8 @@ from .tools.models import (
SetupRequirementsResponse,
UnderstandingUpdatedResponse,
)
from .tracking import track_user_message
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
config = ChatConfig()
@@ -211,6 +213,43 @@ async def create_session(
)
@router.delete(
"/sessions/{session_id}",
dependencies=[Security(auth.requires_user)],
status_code=204,
responses={404: {"description": "Session not found or access denied"}},
)
async def delete_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> Response:
"""
Delete a chat session.
Permanently removes a chat session and all its messages.
Only the owner can delete their sessions.
Args:
session_id: The session ID to delete.
user_id: The authenticated user's ID.
Returns:
204 No Content on success.
Raises:
HTTPException: 404 if session not found or not owned by user.
"""
deleted = await delete_chat_session(session_id, user_id)
if not deleted:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
return Response(status_code=204)
@router.get(
"/sessions/{session_id}",
)
@@ -316,7 +355,7 @@ async def stream_chat_post(
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
session = await _validate_and_get_session(session_id, user_id)
await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
extra={
@@ -343,7 +382,7 @@ async def stream_chat_post(
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
session = await append_and_save_message(session_id, message)
await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
@@ -370,125 +409,19 @@ async def stream_chat_post(
},
)
# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
import time as time_module
await enqueue_copilot_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
operation_id=operation_id,
message=request.message,
is_user_message=request.is_user_message,
context=request.context,
)
gen_start_time = time_module.perf_counter()
logger.info(
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
first_chunk_time, ttfc = None, None
chunk_count = 0
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
logger.info(
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
* 1000,
}
},
)
# Choose service based on LaunchDarkly flag (falls back to config default)
use_sdk = await is_feature_enabled(
Flag.COPILOT_SDK,
user_id or "anonymous",
default=config.use_claude_agent_sdk,
)
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else chat_service.stream_chat_completion
)
logger.info(
f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
extra={"json_fields": log_meta},
)
# Pass message=None since we already added it to the session above
async for chunk in stream_fn(
session_id,
None, # Message already in session
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass session with message already added
context=request.context,
):
# Skip duplicate StreamStart — we already published one above
if isinstance(chunk, StreamStart):
continue
chunk_count += 1
if first_chunk_time is None:
first_chunk_time = time_module.perf_counter()
ttfc = first_chunk_time - gen_start_time
logger.info(
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"time_to_first_chunk_ms": ttfc * 1000,
}
},
)
# Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk)
gen_end_time = time_module.perf_counter()
total_time = (gen_end_time - gen_start_time) * 1000
logger.info(
f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
f"task={task_id}, session={session_id}, "
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"time_to_first_chunk_ms": (
ttfc * 1000 if ttfc is not None else None
),
"n_chunks": chunk_count,
}
},
)
await stream_registry.mark_task_completed(task_id, "completed")
except Exception as e:
elapsed = time_module.perf_counter() - gen_start_time
logger.error(
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed * 1000,
"error": str(e),
}
},
)
# Publish a StreamError so the frontend can display an error message
try:
await stream_registry.publish_chunk(
task_id,
StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
),
)
except Exception:
pass # Best-effort; mark_task_completed will publish StreamFinish
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task)
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)

View File

@@ -41,11 +41,11 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.api.features.chat.completion_consumer import (
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.copilot.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi

View File

@@ -38,7 +38,9 @@ def main(**kwargs):
from backend.api.rest_api import AgentServer
from backend.api.ws_api import WebsocketServer
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
from backend.copilot.executor.manager import CoPilotExecutor
from backend.data.db_manager import DatabaseManager
from backend.executor import ExecutionManager, Scheduler
from backend.notifications import NotificationManager
run_processes(
@@ -48,6 +50,7 @@ def main(**kwargs):
WebsocketServer(),
AgentServer(),
ExecutionManager(),
CoPilotExecutor(),
**kwargs,
)

View File

@@ -106,6 +106,8 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
GPT4O_MINI = "gpt-4o-mini"
GPT4O = "gpt-4o"
GPT4_TURBO = "gpt-4-turbo"
GPT3_5_TURBO = "gpt-3.5-turbo"
# Anthropic models
CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
CLAUDE_4_OPUS = "claude-opus-4-20250514"
@@ -253,6 +255,12 @@ MODEL_METADATA = {
LlmModel.GPT4O: ModelMetadata(
"openai", 128000, 16384, "GPT-4o", "OpenAI", "OpenAI", 2
), # gpt-4o-2024-08-06
LlmModel.GPT4_TURBO: ModelMetadata(
"openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
), # gpt-4-turbo-2024-04-09
LlmModel.GPT3_5_TURBO: ModelMetadata(
"openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
), # gpt-3.5-turbo-0125
# https://docs.anthropic.com/en/docs/about-claude/models
LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
"anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3

View File

@@ -1,6 +1,7 @@
import logging
import os
import pytest
import pytest_asyncio
from dotenv import load_dotenv
@@ -27,6 +28,54 @@ async def server():
yield server
@pytest.fixture
def test_user_id() -> str:
"""Test user ID fixture."""
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture
def admin_user_id() -> str:
"""Admin user ID fixture."""
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
@pytest.fixture
def target_user_id() -> str:
"""Target user ID fixture."""
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
@pytest.fixture
async def setup_test_user(test_user_id):
"""Create test user in database before tests."""
from backend.data.user import get_or_create_user
# Create the test user in the database using JWT token format
user_data = {
"sub": test_user_id,
"email": "test@example.com",
"user_metadata": {"name": "Test User"},
}
await get_or_create_user(user_data)
return test_user_id
@pytest.fixture
async def setup_admin_user(admin_user_id):
"""Create admin user in database before tests."""
from backend.data.user import get_or_create_user
# Create the admin user in the database using JWT token format
user_data = {
"sub": admin_user_id,
"email": "test-admin@example.com",
"user_metadata": {"name": "Test Admin"},
}
await get_or_create_user(user_data)
return admin_user_id
@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
async def graph_cleanup(server):
created_graph_ids = []

View File

@@ -0,0 +1 @@

View File

@@ -37,12 +37,10 @@ stale pending messages from dead consumers.
import asyncio
import logging
import os
import uuid
from typing import Any
import orjson
from prisma import Prisma
from pydantic import BaseModel
from redis.exceptions import ResponseError
@@ -69,8 +67,8 @@ class OperationCompleteMessage(BaseModel):
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from Redis Streams.
This consumer initializes its own Prisma client in start() to ensure
database operations work correctly within this async context.
Database operations are handled through the chat_db() accessor, which
routes through DatabaseManager RPC when Prisma is not directly connected.
Uses Redis consumer groups to allow multiple platform pods to consume
messages reliably with automatic redelivery on failure.
@@ -79,7 +77,6 @@ class ChatCompletionConsumer:
def __init__(self):
self._consumer_task: asyncio.Task | None = None
self._running = False
self._prisma: Prisma | None = None
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
async def start(self) -> None:
@@ -115,15 +112,6 @@ class ChatCompletionConsumer:
f"Chat completion consumer started (consumer: {self._consumer_name})"
)
async def _ensure_prisma(self) -> Prisma:
"""Lazily initialize Prisma client on first use."""
if self._prisma is None:
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
self._prisma = Prisma(datasource={"url": database_url})
await self._prisma.connect()
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
return self._prisma
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
@@ -136,11 +124,6 @@ class ChatCompletionConsumer:
pass
self._consumer_task = None
if self._prisma:
await self._prisma.disconnect()
self._prisma = None
logger.info("[COMPLETION] Consumer Prisma client disconnected")
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
@@ -252,7 +235,7 @@ class ChatCompletionConsumer:
# XAUTOCLAIM after min_idle_time expires
async def _handle_message(self, body: bytes) -> None:
"""Handle a completion message using our own Prisma client."""
"""Handle a completion message."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
@@ -302,8 +285,7 @@ class ChatCompletionConsumer:
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
prisma = await self._ensure_prisma()
await process_operation_success(task, message.result, prisma)
await process_operation_success(task, message.result)
async def _handle_failure(
self,
@@ -311,8 +293,7 @@ class ChatCompletionConsumer:
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
prisma = await self._ensure_prisma()
await process_operation_failure(task, message.error, prisma)
await process_operation_failure(task, message.error)
# Module-level consumer instance

View File

@@ -9,7 +9,8 @@ import logging
from typing import Any
import orjson
from prisma import Prisma
from backend.data.db_accessors import chat_db
from . import service as chat_service
from . import stream_registry
@@ -72,48 +73,40 @@ async def _update_tool_message(
session_id: str,
tool_call_id: str,
content: str,
prisma_client: Prisma | None,
) -> None:
"""Update tool message in database.
"""Update tool message in database using the chat_db accessor.
Routes through DatabaseManager RPC when Prisma is not directly
connected (e.g. in the CoPilot Executor microservice).
Args:
session_id: The session ID
tool_call_id: The tool call ID to update
content: The new content for the message
prisma_client: Optional Prisma client. If None, uses chat_service.
Raises:
ToolMessageUpdateError: If the database update fails. The caller should
handle this to avoid marking the task as completed with inconsistent state.
ToolMessageUpdateError: If the database update fails.
"""
try:
if prisma_client:
# Use provided Prisma client (for consumer with its own connection)
updated_count = await prisma_client.chatmessage.update_many(
where={
"sessionId": session_id,
"toolCallId": tool_call_id,
},
data={"content": content},
)
# Check if any rows were updated - 0 means message not found
if updated_count == 0:
raise ToolMessageUpdateError(
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
)
else:
# Use service function (for webhook endpoint)
await chat_service._update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=content,
updated = await chat_db().update_tool_message_content(
session_id=session_id,
tool_call_id=tool_call_id,
new_content=content,
)
if not updated:
raise ToolMessageUpdateError(
f"No message found with tool_call_id="
f"{tool_call_id} in session {session_id}"
)
except ToolMessageUpdateError:
raise
except Exception as e:
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
logger.error(
f"[COMPLETION] Failed to update tool message: {e}",
exc_info=True,
)
raise ToolMessageUpdateError(
f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
f"Failed to update tool message for tool call #{tool_call_id}: {e}"
) from e
@@ -202,7 +195,6 @@ async def _save_agent_from_result(
async def process_operation_success(
task: stream_registry.ActiveTask,
result: dict | str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle successful operation completion.
@@ -212,12 +204,10 @@ async def process_operation_success(
Args:
task: The active task that completed
result: The result data from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
Raises:
ToolMessageUpdateError: If the database update fails. The task will be
marked as failed instead of completed to avoid inconsistent state.
ToolMessageUpdateError: If the database update fails. The task
will be marked as failed instead of completed.
"""
# For agent generation tools, save the agent to library
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
@@ -250,7 +240,6 @@ async def process_operation_success(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=result_str,
prisma_client=prisma_client,
)
except ToolMessageUpdateError:
# DB update failed - mark task as failed to avoid inconsistent state
@@ -293,18 +282,15 @@ async def process_operation_success(
async def process_operation_failure(
task: stream_registry.ActiveTask,
error: str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle failed operation completion.
Publishes the error to the stream registry, updates the database with
the error response, and marks the task as failed.
Publishes the error to the stream registry, updates the database
with the error response, and marks the task as failed.
Args:
task: The active task that failed
error: The error message from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
"""
error_msg = error or "Operation failed"
@@ -325,7 +311,6 @@ async def process_operation_failure(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=error_response.model_dump_json(),
prisma_client=prisma_client,
)
except ToolMessageUpdateError:
# DB update failed - log but continue with cleanup

View File

@@ -14,29 +14,27 @@ from prisma.types import (
ChatSessionWhereInput,
)
from backend.data.db import transaction
from backend.data import db
from backend.util.json import SafeJson
from .model import ChatMessage, ChatSession
logger = logging.getLogger(__name__)
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
async def get_chat_session(session_id: str) -> ChatSession | None:
"""Get a chat session by ID from the database."""
session = await PrismaChatSession.prisma().find_unique(
where={"id": session_id},
include={"Messages": True},
include={"Messages": {"order_by": {"sequence": "asc"}}},
)
if session and session.Messages:
# Sort messages by sequence in Python - Prisma Python client doesn't support
# order_by in include clauses (unlike Prisma JS), so we sort after fetching
session.Messages.sort(key=lambda m: m.sequence)
return session
return ChatSession.from_db(session) if session else None
async def create_chat_session(
session_id: str,
user_id: str,
) -> PrismaChatSession:
) -> ChatSession:
"""Create a new chat session in the database."""
data = ChatSessionCreateInput(
id=session_id,
@@ -45,7 +43,8 @@ async def create_chat_session(
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
)
return await PrismaChatSession.prisma().create(data=data)
prisma_session = await PrismaChatSession.prisma().create(data=data)
return ChatSession.from_db(prisma_session)
async def update_chat_session(
@@ -56,7 +55,7 @@ async def update_chat_session(
total_prompt_tokens: int | None = None,
total_completion_tokens: int | None = None,
title: str | None = None,
) -> PrismaChatSession | None:
) -> ChatSession | None:
"""Update a chat session's metadata."""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
@@ -76,12 +75,9 @@ async def update_chat_session(
session = await PrismaChatSession.prisma().update(
where={"id": session_id},
data=data,
include={"Messages": True},
include={"Messages": {"order_by": {"sequence": "asc"}}},
)
if session and session.Messages:
# Sort in Python - Prisma Python doesn't support order_by in include clauses
session.Messages.sort(key=lambda m: m.sequence)
return session
return ChatSession.from_db(session) if session else None
async def add_chat_message(
@@ -94,7 +90,7 @@ async def add_chat_message(
refusal: str | None = None,
tool_calls: list[dict[str, Any]] | None = None,
function_call: dict[str, Any] | None = None,
) -> PrismaChatMessage:
) -> ChatMessage:
"""Add a message to a chat session."""
# Build input dict dynamically rather than using ChatMessageCreateInput directly
# because Prisma's TypedDict validation rejects optional fields set to None.
@@ -129,14 +125,14 @@ async def add_chat_message(
),
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
)
return message
return ChatMessage.from_db(message)
async def add_chat_messages_batch(
session_id: str,
messages: list[dict[str, Any]],
start_sequence: int,
) -> list[PrismaChatMessage]:
) -> list[ChatMessage]:
"""Add multiple messages to a chat session in a batch.
Uses a transaction for atomicity - if any message creation fails,
@@ -147,7 +143,7 @@ async def add_chat_messages_batch(
created_messages = []
async with transaction() as tx:
async with db.transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
@@ -187,21 +183,22 @@ async def add_chat_messages_batch(
data={"updatedAt": datetime.now(UTC)},
)
return created_messages
return [ChatMessage.from_db(m) for m in created_messages]
async def get_user_chat_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
) -> list[PrismaChatSession]:
) -> list[ChatSession]:
"""Get chat sessions for a user, ordered by most recent."""
return await PrismaChatSession.prisma().find_many(
prisma_sessions = await PrismaChatSession.prisma().find_many(
where={"userId": user_id},
order={"updatedAt": "desc"},
take=limit,
skip=offset,
)
return [ChatSession.from_db(s) for s in prisma_sessions]
async def get_user_session_count(user_id: str) -> int:

View File

@@ -0,0 +1,18 @@
"""Entry point for running the CoPilot Executor service.
Usage:
python -m backend.copilot.executor
"""
from backend.app import run_processes
from .manager import CoPilotExecutor
def main():
"""Run the CoPilot Executor service."""
run_processes(CoPilotExecutor())
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,508 @@
"""CoPilot Executor Manager - main service for CoPilot task execution.
This module contains the CoPilotExecutor class that consumes chat tasks from
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
"""
import logging
import os
import threading
import time
import uuid
from concurrent.futures import Future, ThreadPoolExecutor
from pika.adapters.blocking_connection import BlockingChannel
from pika.exceptions import AMQPChannelError, AMQPConnectionError
from pika.spec import Basic, BasicProperties
from prometheus_client import Gauge, start_http_server
from backend.data import redis_client as redis
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.cluster_lock import ClusterLock
from backend.util.decorator import error_logged
from backend.util.logging import TruncatedLogger
from backend.util.process import AppProcess
from backend.util.retry import continuous_retry
from backend.util.settings import Settings
from .processor import execute_copilot_task, init_worker
from .utils import (
COPILOT_CANCEL_QUEUE_NAME,
COPILOT_EXECUTION_QUEUE_NAME,
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
CancelCoPilotEvent,
CoPilotExecutionEntry,
create_copilot_queue_config,
)
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
settings = Settings()
# Prometheus metrics
active_tasks_gauge = Gauge(
"copilot_executor_active_tasks",
"Number of active CoPilot tasks",
)
pool_size_gauge = Gauge(
"copilot_executor_pool_size",
"Maximum number of CoPilot executor workers",
)
utilization_gauge = Gauge(
"copilot_executor_utilization_ratio",
"Ratio of active tasks to pool size",
)
class CoPilotExecutor(AppProcess):
"""CoPilot Executor service for processing chat generation tasks.
This service consumes tasks from RabbitMQ, processes them using a thread pool,
and publishes results to Redis Streams. It follows the graph executor pattern
for reliable message handling and graceful shutdown.
Key features:
- RabbitMQ-based task distribution with manual acknowledgment
- Thread pool executor for concurrent task processing
- Cluster lock for duplicate prevention across pods
- Graceful shutdown with timeout for in-flight tasks
- FANOUT exchange for cancellation broadcast
"""
def __init__(self):
super().__init__()
self.pool_size = settings.config.num_copilot_workers
self.active_tasks: dict[str, tuple[Future, threading.Event]] = {}
self.executor_id = str(uuid.uuid4())
self._executor = None
self._stop_consuming = None
self._cancel_thread = None
self._cancel_client = None
self._run_thread = None
self._run_client = None
self._task_locks: dict[str, ClusterLock] = {}
self._active_tasks_lock = threading.Lock()
# ============ Main Entry Points (AppProcess interface) ============ #
def run(self):
"""Main service loop - consume from RabbitMQ."""
logger.info(f"Pod assigned executor_id: {self.executor_id}")
logger.info(f"Spawn max-{self.pool_size} workers...")
pool_size_gauge.set(self.pool_size)
self._update_metrics()
start_http_server(settings.config.copilot_executor_port)
self.cancel_thread.start()
self.run_thread.start()
while True:
time.sleep(1e5)
def cleanup(self):
"""Graceful shutdown with active execution waiting."""
pid = os.getpid()
logger.info(f"[cleanup {pid}] Starting graceful shutdown...")
# Signal the consumer thread to stop
try:
self.stop_consuming.set()
run_channel = self.run_client.get_channel()
run_channel.connection.add_callback_threadsafe(
lambda: run_channel.stop_consuming()
)
logger.info(f"[cleanup {pid}] Consumer has been signaled to stop")
except Exception as e:
logger.error(f"[cleanup {pid}] Error stopping consumer: {e}")
# Wait for active executions to complete
if self.active_tasks:
logger.info(
f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
)
start_time = time.monotonic()
last_refresh = start_time
lock_refresh_interval = settings.config.cluster_lock_timeout / 10
while (
self.active_tasks
and (time.monotonic() - start_time) < GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS
):
self._cleanup_completed_tasks()
if not self.active_tasks:
break
# Refresh cluster locks periodically
current_time = time.monotonic()
if current_time - last_refresh >= lock_refresh_interval:
for lock in list(self._task_locks.values()):
try:
lock.refresh()
except Exception as e:
logger.warning(
f"[cleanup {pid}] Failed to refresh lock: {e}"
)
last_refresh = current_time
logger.info(
f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..."
)
time.sleep(10.0)
# Stop message consumers
if self._run_thread:
self._stop_message_consumers(
self._run_thread, self.run_client, "[cleanup][run]"
)
if self._cancel_thread:
self._stop_message_consumers(
self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
)
# Shutdown executor
if self._executor:
logger.info(f"[cleanup {pid}] Shutting down executor...")
self._executor.shutdown(wait=False)
# Release any remaining locks
for task_id, lock in list(self._task_locks.items()):
try:
lock.release()
logger.info(f"[cleanup {pid}] Released lock for {task_id}")
except Exception as e:
logger.error(
f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
)
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
# ============ RabbitMQ Consumer Methods ============ #
@continuous_retry()
def _consume_cancel(self):
"""Consume cancellation messages from FANOUT exchange."""
if self.stop_consuming.is_set() and not self.active_tasks:
logger.info("Stop reconnecting cancel consumer - service cleaned up")
return
if not self.cancel_client.is_ready:
self.cancel_client.disconnect()
self.cancel_client.connect()
# Check again after connect - shutdown may have been requested
if self.stop_consuming.is_set() and not self.active_tasks:
logger.info("Stop consuming requested during reconnect - disconnecting")
self.cancel_client.disconnect()
return
cancel_channel = self.cancel_client.get_channel()
cancel_channel.basic_consume(
queue=COPILOT_CANCEL_QUEUE_NAME,
on_message_callback=self._handle_cancel_message,
auto_ack=True,
)
logger.info("Starting cancel message consumer...")
cancel_channel.start_consuming()
if not self.stop_consuming.is_set() or self.active_tasks:
raise RuntimeError("Cancel message consumer stopped unexpectedly")
logger.info("Cancel message consumer stopped gracefully")
@continuous_retry()
def _consume_run(self):
"""Consume run messages from DIRECT exchange."""
if self.stop_consuming.is_set():
logger.info("Stop reconnecting run consumer - service cleaned up")
return
if not self.run_client.is_ready:
self.run_client.disconnect()
self.run_client.connect()
# Check again after connect - shutdown may have been requested
if self.stop_consuming.is_set():
logger.info("Stop consuming requested during reconnect - disconnecting")
self.run_client.disconnect()
return
run_channel = self.run_client.get_channel()
run_channel.basic_qos(prefetch_count=self.pool_size)
run_channel.basic_consume(
queue=COPILOT_EXECUTION_QUEUE_NAME,
on_message_callback=self._handle_run_message,
auto_ack=False,
consumer_tag="copilot_execution_consumer",
)
logger.info("Starting to consume run messages...")
run_channel.start_consuming()
if not self.stop_consuming.is_set():
raise RuntimeError("Run message consumer stopped unexpectedly")
logger.info("Run message consumer stopped gracefully")
# ============ Message Handlers ============ #
@error_logged(swallow=True)
def _handle_cancel_message(
self,
_channel: BlockingChannel,
_method: Basic.Deliver,
_properties: BasicProperties,
body: bytes,
):
"""Handle cancel message from FANOUT exchange."""
request = CancelCoPilotEvent.model_validate_json(body)
task_id = request.task_id
if not task_id:
logger.warning("Cancel message missing 'task_id'")
return
if task_id not in self.active_tasks:
logger.debug(f"Cancel received for {task_id} but not active")
return
_, cancel_event = self.active_tasks[task_id]
logger.info(f"Received cancel for {task_id}")
if not cancel_event.is_set():
cancel_event.set()
else:
logger.debug(f"Cancel already set for {task_id}")
def _handle_run_message(
self,
_channel: BlockingChannel,
method: Basic.Deliver,
_properties: BasicProperties,
body: bytes,
):
"""Handle run message from DIRECT exchange."""
delivery_tag = method.delivery_tag
# Capture the channel used at message delivery time to ensure we ack
# on the correct channel. Delivery tags are channel-scoped and become
# invalid if the channel is recreated after reconnection.
delivery_channel = _channel
def ack_message(reject: bool, requeue: bool):
"""Acknowledge or reject the message.
Uses the channel from the original message delivery. If the channel
is no longer open (e.g., after reconnection), logs a warning and
skips the ack - RabbitMQ will redeliver the message automatically.
"""
try:
if not delivery_channel.is_open:
logger.warning(
f"Channel closed, cannot ack delivery_tag={delivery_tag}. "
"Message will be redelivered by RabbitMQ."
)
return
if reject:
delivery_channel.connection.add_callback_threadsafe(
lambda: delivery_channel.basic_nack(
delivery_tag, requeue=requeue
)
)
else:
delivery_channel.connection.add_callback_threadsafe(
lambda: delivery_channel.basic_ack(delivery_tag)
)
except (AMQPChannelError, AMQPConnectionError) as e:
# Channel/connection errors indicate stale delivery tag - don't retry
logger.warning(
f"Cannot ack delivery_tag={delivery_tag} due to channel/connection "
f"error: {e}. Message will be redelivered by RabbitMQ."
)
except Exception as e:
# Other errors might be transient, but log and skip to avoid blocking
logger.error(
f"Unexpected error acking delivery_tag={delivery_tag}: {e}"
)
# Check if we're shutting down
if self.stop_consuming.is_set():
logger.info("Rejecting new task during shutdown")
ack_message(reject=True, requeue=True)
return
# Check if we can accept more tasks
self._cleanup_completed_tasks()
if len(self.active_tasks) >= self.pool_size:
ack_message(reject=True, requeue=True)
return
try:
entry = CoPilotExecutionEntry.model_validate_json(body)
except Exception as e:
logger.error(f"Could not parse run message: {e}, body={body}")
ack_message(reject=True, requeue=False)
return
task_id = entry.task_id
# Check for local duplicate - task is already running on this executor
if task_id in self.active_tasks:
logger.warning(
f"Task {task_id} already running locally, rejecting duplicate"
)
ack_message(reject=True, requeue=False)
return
# Try to acquire cluster-wide lock
cluster_lock = ClusterLock(
redis=redis.get_redis(),
key=f"copilot:task:{task_id}:lock",
owner_id=self.executor_id,
timeout=settings.config.cluster_lock_timeout,
)
current_owner = cluster_lock.try_acquire()
if current_owner != self.executor_id:
if current_owner is not None:
logger.warning(f"Task {task_id} already running on pod {current_owner}")
ack_message(reject=True, requeue=False)
else:
logger.warning(
f"Could not acquire lock for {task_id} - Redis unavailable"
)
ack_message(reject=True, requeue=True)
return
# Execute the task
try:
self._task_locks[task_id] = cluster_lock
logger.info(
f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
)
cancel_event = threading.Event()
future = self.executor.submit(
execute_copilot_task, entry, cancel_event, cluster_lock
)
self.active_tasks[task_id] = (future, cancel_event)
except Exception as e:
logger.warning(f"Failed to setup execution for {task_id}: {e}")
cluster_lock.release()
if task_id in self._task_locks:
del self._task_locks[task_id]
ack_message(reject=True, requeue=True)
return
self._update_metrics()
def on_run_done(f: Future):
logger.info(f"Run completed for {task_id}")
try:
if exec_error := f.exception():
logger.error(f"Execution for {task_id} failed: {exec_error}")
# Don't requeue failed tasks - they've been marked as failed
# in the stream registry. Requeuing would cause infinite retries
# for deterministic failures.
ack_message(reject=True, requeue=False)
else:
ack_message(reject=False, requeue=False)
except BaseException as e:
logger.exception(f"Error in run completion callback: {e}")
finally:
# Release the cluster lock
if task_id in self._task_locks:
logger.info(f"Releasing cluster lock for {task_id}")
self._task_locks[task_id].release()
del self._task_locks[task_id]
self._cleanup_completed_tasks()
future.add_done_callback(on_run_done)
# ============ Helper Methods ============ #
def _cleanup_completed_tasks(self) -> list[str]:
"""Remove completed futures from active_tasks and update metrics."""
completed_tasks = []
with self._active_tasks_lock:
for task_id, (future, _) in list(self.active_tasks.items()):
if future.done():
completed_tasks.append(task_id)
self.active_tasks.pop(task_id, None)
logger.info(f"Cleaned up completed task {task_id}")
self._update_metrics()
return completed_tasks
def _update_metrics(self):
"""Update Prometheus metrics."""
active_count = len(self.active_tasks)
active_tasks_gauge.set(active_count)
if self.stop_consuming.is_set():
utilization_gauge.set(1.0)
else:
utilization_gauge.set(
active_count / self.pool_size if self.pool_size > 0 else 0
)
def _stop_message_consumers(
self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str
):
"""Stop a message consumer thread."""
try:
channel = client.get_channel()
channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())
thread.join(timeout=300)
if thread.is_alive():
logger.error(
f"{prefix} Thread did not finish in time, forcing disconnect"
)
client.disconnect()
logger.info(f"{prefix} Client disconnected")
except Exception as e:
logger.error(f"{prefix} Error disconnecting client: {e}")
# ============ Lazy-initialized Properties ============ #
@property
def cancel_thread(self) -> threading.Thread:
if self._cancel_thread is None:
self._cancel_thread = threading.Thread(
target=lambda: self._consume_cancel(),
daemon=True,
)
return self._cancel_thread
@property
def run_thread(self) -> threading.Thread:
if self._run_thread is None:
self._run_thread = threading.Thread(
target=lambda: self._consume_run(),
daemon=True,
)
return self._run_thread
@property
def stop_consuming(self) -> threading.Event:
if self._stop_consuming is None:
self._stop_consuming = threading.Event()
return self._stop_consuming
@property
def executor(self) -> ThreadPoolExecutor:
if self._executor is None:
self._executor = ThreadPoolExecutor(
max_workers=self.pool_size,
initializer=init_worker,
)
return self._executor
@property
def cancel_client(self) -> SyncRabbitMQ:
if self._cancel_client is None:
self._cancel_client = SyncRabbitMQ(create_copilot_queue_config())
return self._cancel_client
@property
def run_client(self) -> SyncRabbitMQ:
if self._run_client is None:
self._run_client = SyncRabbitMQ(create_copilot_queue_config())
return self._run_client

View File

@@ -0,0 +1,253 @@
"""CoPilot execution processor - per-worker execution logic.
This module contains the processor class that handles CoPilot task execution
in a thread-local context, following the graph executor pattern.
"""
import asyncio
import logging
import threading
import time
from backend.copilot import service as copilot_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
from backend.copilot.sdk import service as sdk_service
from backend.executor.cluster_lock import ClusterLock
from backend.util.decorator import error_logged
from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.process import set_service_name
from backend.util.retry import func_retry
from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
# ============ Module Entry Points ============ #
# Thread-local storage for processor instances
_tls = threading.local()
def execute_copilot_task(
entry: CoPilotExecutionEntry,
cancel: threading.Event,
cluster_lock: ClusterLock,
):
"""Execute a CoPilot task using the thread-local processor.
This function is the entry point called by the thread pool executor.
Args:
entry: The task payload
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock for this execution
"""
processor: CoPilotProcessor = _tls.processor
return processor.execute(entry, cancel, cluster_lock)
def init_worker():
"""Initialize the processor for the current worker thread.
This function is called by the thread pool executor when a new worker
thread is created. It ensures each worker has its own processor instance.
"""
_tls.processor = CoPilotProcessor()
_tls.processor.on_executor_start()
# ============ Processor Class ============ #
class CoPilotProcessor:
"""Per-worker execution logic for CoPilot tasks.
This class is instantiated once per worker thread and handles the execution
of CoPilot chat generation tasks. It maintains an async event loop for
running the async service code.
The execution flow:
1. CoPilot task is picked from RabbitMQ queue
2. Manager submits task to thread pool
3. Processor executes the task in its event loop
4. Results are published to Redis Streams
"""
@func_retry
def on_executor_start(self):
"""Initialize the processor when the worker thread starts.
This method is called once per worker thread to set up the async event
loop and initialize any required resources.
Database is accessed only through DatabaseManager, so we don't need to connect
to Prisma directly.
"""
configure_logging()
set_service_name("CoPilotExecutor")
self.tid = threading.get_ident()
self.execution_loop = asyncio.new_event_loop()
self.execution_thread = threading.Thread(
target=self.execution_loop.run_forever, daemon=True
)
self.execution_thread.start()
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
@error_logged(swallow=False)
def execute(
self,
entry: CoPilotExecutionEntry,
cancel: threading.Event,
cluster_lock: ClusterLock,
):
"""Execute a CoPilot task.
This is the main entry point for task execution. It runs the async
execution logic in the worker's event loop and handles errors.
Args:
entry: The task payload containing session and message info
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock to prevent duplicate execution
"""
log = CoPilotLogMetadata(
logging.getLogger(__name__),
task_id=entry.task_id,
session_id=entry.session_id,
user_id=entry.user_id,
)
log.info("Starting execution")
start_time = time.monotonic()
try:
# Run the async execution in our event loop
future = asyncio.run_coroutine_threadsafe(
self._execute_async(entry, cancel, cluster_lock, log),
self.execution_loop,
)
# Wait for completion, checking cancel periodically
while not future.done():
try:
future.result(timeout=1.0)
except asyncio.TimeoutError:
if cancel.is_set():
log.info("Cancellation requested")
future.cancel()
break
# Refresh cluster lock to maintain ownership
cluster_lock.refresh()
if not future.cancelled():
# Get result to propagate any exceptions
future.result()
elapsed = time.monotonic() - start_time
log.info(f"Execution completed in {elapsed:.2f}s")
except Exception as e:
elapsed = time.monotonic() - start_time
log.error(f"Execution failed after {elapsed:.2f}s: {e}")
# Note: _execute_async already marks the task as failed before re-raising,
# so we don't call _mark_task_failed here to avoid duplicate error events.
raise
async def _execute_async(
self,
entry: CoPilotExecutionEntry,
cancel: threading.Event,
cluster_lock: ClusterLock,
log: CoPilotLogMetadata,
):
"""Async execution logic for CoPilot task.
This method calls the existing stream_chat_completion service function
and publishes results to the stream registry.
Args:
entry: The task payload
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock for refresh
log: Structured logger for this task
"""
last_refresh = time.monotonic()
refresh_interval = 30.0 # Refresh lock every 30 seconds
try:
# Choose service based on LaunchDarkly flag
config = ChatConfig()
use_sdk = await is_feature_enabled(
Flag.COPILOT_SDK,
entry.user_id or "anonymous",
default=config.use_claude_agent_sdk,
)
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else copilot_service.stream_chat_completion
)
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
# Stream chat completion and publish chunks to Redis
async for chunk in stream_fn(
session_id=entry.session_id,
message=entry.message if entry.message else None,
is_user_message=entry.is_user_message,
user_id=entry.user_id,
context=entry.context,
):
# Check for cancellation
if cancel.is_set():
log.info("Cancelled during streaming")
await stream_registry.publish_chunk(
entry.task_id, StreamError(errorText="Operation cancelled")
)
await stream_registry.publish_chunk(
entry.task_id, StreamFinishStep()
)
await stream_registry.publish_chunk(entry.task_id, StreamFinish())
await stream_registry.mark_task_completed(
entry.task_id, status="failed"
)
return
# Refresh cluster lock periodically
current_time = time.monotonic()
if current_time - last_refresh >= refresh_interval:
cluster_lock.refresh()
last_refresh = current_time
# Publish chunk to stream registry
await stream_registry.publish_chunk(entry.task_id, chunk)
# Mark task as completed
await stream_registry.mark_task_completed(entry.task_id, status="completed")
log.info("Task completed successfully")
except asyncio.CancelledError:
log.info("Task cancelled")
await stream_registry.mark_task_completed(entry.task_id, status="failed")
raise
except Exception as e:
log.error(f"Task failed: {e}")
await self._mark_task_failed(entry.task_id, str(e))
raise
async def _mark_task_failed(self, task_id: str, error_message: str):
"""Mark a task as failed and publish error to stream registry."""
try:
await stream_registry.publish_chunk(
task_id, StreamError(errorText=error_message)
)
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish())
await stream_registry.mark_task_completed(task_id, status="failed")
except Exception as e:
logger.error(f"Failed to mark task {task_id} as failed: {e}")

View File

@@ -0,0 +1,207 @@
"""RabbitMQ queue configuration for CoPilot executor.
Defines two exchanges and queues following the graph executor pattern:
- 'copilot_execution' (DIRECT) for chat generation tasks
- 'copilot_cancel' (FANOUT) for cancellation requests
"""
import logging
from pydantic import BaseModel
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
logger = logging.getLogger(__name__)
# ============ Logging Helper ============ #
class CoPilotLogMetadata(TruncatedLogger):
"""Structured logging helper for CoPilot executor.
In cloud environments (structured logging enabled), uses a simple prefix
and passes metadata via json_fields. In local environments, uses a detailed
prefix with all metadata key-value pairs for easier debugging.
Args:
logger: The underlying logger instance
max_length: Maximum log message length before truncation
**kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
These are added to json_fields in cloud mode, or to the prefix in local mode.
"""
def __init__(
self,
logger: logging.Logger,
max_length: int = 1000,
**kwargs: str | None,
):
# Filter out None values
metadata = {k: v for k, v in kwargs.items() if v is not None}
metadata["component"] = "CoPilotExecutor"
if is_structured_logging_enabled():
prefix = "[CoPilotExecutor]"
else:
# Build prefix from metadata key-value pairs
meta_parts = "|".join(
f"{k}:{v}" for k, v in metadata.items() if k != "component"
)
prefix = (
f"[CoPilotExecutor|{meta_parts}]" if meta_parts else "[CoPilotExecutor]"
)
super().__init__(
logger,
max_length=max_length,
prefix=prefix,
metadata=metadata,
)
# ============ Exchange and Queue Configuration ============ #
COPILOT_EXECUTION_EXCHANGE = Exchange(
name="copilot_execution",
type=ExchangeType.DIRECT,
durable=True,
auto_delete=False,
)
COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue"
COPILOT_EXECUTION_ROUTING_KEY = "copilot.run"
COPILOT_CANCEL_EXCHANGE = Exchange(
name="copilot_cancel",
type=ExchangeType.FANOUT,
durable=True,
auto_delete=False,
)
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"
# CoPilot operations can include extended thinking and agent generation
# which may take 30+ minutes to complete
COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour
# Graceful shutdown timeout - allow in-flight operations to complete
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60 # 30 minutes
def create_copilot_queue_config() -> RabbitMQConfig:
"""Create RabbitMQ configuration for CoPilot executor.
Defines two exchanges and queues:
- 'copilot_execution' (DIRECT) for chat generation tasks
- 'copilot_cancel' (FANOUT) for cancellation requests
Returns:
RabbitMQConfig with exchanges and queues defined
"""
run_queue = Queue(
name=COPILOT_EXECUTION_QUEUE_NAME,
exchange=COPILOT_EXECUTION_EXCHANGE,
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
durable=True,
auto_delete=False,
arguments={
# Extended consumer timeout for long-running LLM operations
# Default 30-minute timeout is insufficient for extended thinking
# and agent generation which can take 30+ minutes
"x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS
* 1000,
},
)
cancel_queue = Queue(
name=COPILOT_CANCEL_QUEUE_NAME,
exchange=COPILOT_CANCEL_EXCHANGE,
routing_key="", # not used for FANOUT
durable=True,
auto_delete=False,
)
return RabbitMQConfig(
vhost="/",
exchanges=[COPILOT_EXECUTION_EXCHANGE, COPILOT_CANCEL_EXCHANGE],
queues=[run_queue, cancel_queue],
)
# ============ Message Models ============ #
class CoPilotExecutionEntry(BaseModel):
"""Task payload for CoPilot AI generation.
This model represents a chat generation task to be processed by the executor.
"""
task_id: str
"""Unique identifier for this task (used for stream registry)"""
session_id: str
"""Chat session ID"""
user_id: str | None
"""User ID (may be None for anonymous users)"""
operation_id: str
"""Operation ID for webhook callbacks and completion tracking"""
message: str
"""User's message to process"""
is_user_message: bool = True
"""Whether the message is from the user (vs system/assistant)"""
context: dict[str, str] | None = None
"""Optional context for the message (e.g., {url: str, content: str})"""
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
task_id: str
"""Task ID to cancel"""
# ============ Queue Publishing Helpers ============ #
async def enqueue_copilot_task(
task_id: str,
session_id: str,
user_id: str | None,
operation_id: str,
message: str,
is_user_message: bool = True,
context: dict[str, str] | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
Args:
task_id: Unique identifier for this task (used for stream registry)
session_id: Chat session ID
user_id: User ID (may be None for anonymous users)
operation_id: Operation ID for webhook callbacks and completion tracking
message: User's message to process
is_user_message: Whether the message is from the user (vs system/assistant)
context: Optional context for the message (e.g., {url: str, content: str})
"""
from backend.util.clients import get_async_copilot_queue
entry = CoPilotExecutionEntry(
task_id=task_id,
session_id=session_id,
user_id=user_id,
operation_id=operation_id,
message=message,
is_user_message=is_user_message,
context=context,
)
queue_client = await get_async_copilot_queue()
await queue_client.publish_message(
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
message=entry.model_dump_json(),
exchange=COPILOT_EXECUTION_EXCHANGE,
)

View File

@@ -23,26 +23,17 @@ from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from pydantic import BaseModel
from backend.data.db_accessors import chat_db
from backend.data.redis_client import get_redis_async
from backend.util import json
from backend.util.exceptions import DatabaseError, RedisError
from . import db as chat_db
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
"""Parse a JSON field that may be stored as string or already parsed."""
if value is None:
return default
if isinstance(value, str):
return json.loads(value)
return value
# Redis cache key prefix for chat sessions
CHAT_SESSION_CACHE_PREFIX = "chat:session:"
@@ -52,28 +43,7 @@ def _get_session_cache_key(session_id: str) -> str:
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
# Session-level locks to prevent race conditions during concurrent upserts.
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock
# ===================== Chat data models ===================== #
class ChatMessage(BaseModel):
@@ -85,6 +55,19 @@ class ChatMessage(BaseModel):
tool_calls: list[dict] | None = None
function_call: dict | None = None
@staticmethod
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
"""Convert a Prisma ChatMessage to a Pydantic ChatMessage."""
return ChatMessage(
role=prisma_message.role,
content=prisma_message.content,
name=prisma_message.name,
tool_call_id=prisma_message.toolCallId,
refusal=prisma_message.refusal,
tool_calls=_parse_json_field(prisma_message.toolCalls),
function_call=_parse_json_field(prisma_message.functionCall),
)
class Usage(BaseModel):
prompt_tokens: int
@@ -138,26 +121,8 @@ class ChatSession(BaseModel):
)
@staticmethod
def from_db(
prisma_session: PrismaChatSession,
prisma_messages: list[PrismaChatMessage] | None = None,
) -> "ChatSession":
"""Convert Prisma models to Pydantic ChatSession."""
messages = []
if prisma_messages:
for msg in prisma_messages:
messages.append(
ChatMessage(
role=msg.role,
content=msg.content,
name=msg.name,
tool_call_id=msg.toolCallId,
refusal=msg.refusal,
tool_calls=_parse_json_field(msg.toolCalls),
function_call=_parse_json_field(msg.functionCall),
)
)
def from_db(prisma_session: PrismaChatSession) -> "ChatSession":
"""Convert Prisma ChatSession to Pydantic ChatSession."""
# Parse JSON fields from Prisma
credentials = _parse_json_field(prisma_session.credentials, default={})
successful_agent_runs = _parse_json_field(
@@ -183,7 +148,11 @@ class ChatSession(BaseModel):
session_id=prisma_session.id,
user_id=prisma_session.userId,
title=prisma_session.title,
messages=messages,
messages=(
[ChatMessage.from_db(m) for m in prisma_session.Messages]
if prisma_session.Messages
else []
),
usage=usage,
credentials=credentials,
started_at=prisma_session.createdAt,
@@ -322,37 +291,26 @@ class ChatSession(BaseModel):
return self._merge_consecutive_assistant_messages(messages)
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
"""Get a chat session from Redis cache."""
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
raw_session: bytes | None = await async_redis.get(redis_key)
if raw_session is None:
return None
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
)
return session
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
"""Parse a JSON field that may be stored as string or already parsed."""
if value is None:
return default
if isinstance(value, str):
return json.loads(value)
return value
async def _cache_session(session: ChatSession) -> None:
"""Cache a chat session in Redis."""
redis_key = _get_session_cache_key(session.session_id)
async_redis = await get_redis_async()
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
# ================ Chat cache + DB operations ================ #
# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
# connected directly.
async def cache_chat_session(session: ChatSession) -> None:
"""Cache a chat session without persisting to the database."""
await _cache_session(session)
"""Cache a chat session in Redis (without persisting to the database)."""
redis_key = _get_session_cache_key(session.session_id)
async_redis = await get_redis_async()
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
async def invalidate_session_cache(session_id: str) -> None:
@@ -370,77 +328,6 @@ async def invalidate_session_cache(session_id: str) -> None:
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
prisma_session = await chat_db.get_chat_session(session_id)
if not prisma_session:
return None
messages = prisma_session.Messages
logger.debug(
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
)
return ChatSession.from_db(prisma_session, messages)
async def _save_session_to_db(
session: ChatSession, existing_message_count: int
) -> None:
"""Save or update a chat session in the database."""
# Check if session exists in DB
existing = await chat_db.get_chat_session(session.session_id)
if not existing:
# Create new session
await chat_db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
# Update session metadata
await chat_db.update_chat_session(
session_id=session.session_id,
credentials=session.credentials,
successful_agent_runs=session.successful_agent_runs,
successful_agent_schedules=session.successful_agent_schedules,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
if new_messages:
messages_data = []
for msg in new_messages:
messages_data.append(
{
"role": msg.role,
"content": msg.content,
"name": msg.name,
"tool_call_id": msg.tool_call_id,
"refusal": msg.refusal,
"tool_calls": msg.tool_calls,
"function_call": msg.function_call,
}
)
logger.debug(
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
f"roles={[m['role'] for m in messages_data]}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
async def get_chat_session(
session_id: str,
user_id: str | None = None,
@@ -488,13 +375,52 @@ async def get_chat_session(
# Cache the session from DB
try:
await _cache_session(session)
await cache_chat_session(session)
logger.info(f"Cached session {session_id} from database")
except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}")
return session
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
"""Get a chat session from Redis cache."""
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
raw_session: bytes | None = await async_redis.get(redis_key)
if raw_session is None:
return None
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
session = await chat_db().get_chat_session(session_id)
if not session:
return None
logger.info(
f"Loaded session {session_id} from DB: "
f"has_messages={bool(session.messages)}, "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
async def upsert_chat_session(
session: ChatSession,
) -> ChatSession:
@@ -515,7 +441,7 @@ async def upsert_chat_session(
async with lock:
# Get existing message count from DB for incremental saves
existing_message_count = await chat_db.get_chat_session_message_count(
existing_message_count = await chat_db().get_chat_session_message_count(
session.session_id
)
@@ -532,7 +458,7 @@ async def upsert_chat_session(
# Save to cache (best-effort, even if DB failed)
try:
await _cache_session(session)
await cache_chat_session(session)
except Exception as e:
# If DB succeeded but cache failed, raise cache error
if db_error is None:
@@ -553,6 +479,65 @@ async def upsert_chat_session(
return session
async def _save_session_to_db(
session: ChatSession, existing_message_count: int
) -> None:
"""Save or update a chat session in the database."""
db = chat_db()
# Check if session exists in DB
existing = await db.get_chat_session(session.session_id)
if not existing:
# Create new session
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
# Update session metadata
await db.update_chat_session(
session_id=session.session_id,
credentials=session.credentials,
successful_agent_runs=session.successful_agent_runs,
successful_agent_schedules=session.successful_agent_schedules,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
if new_messages:
messages_data = []
for msg in new_messages:
messages_data.append(
{
"role": msg.role,
"content": msg.content,
"name": msg.name,
"tool_call_id": msg.tool_call_id,
"refusal": msg.refusal,
"tool_calls": msg.tool_calls,
"function_call": msg.function_call,
}
)
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
@@ -568,7 +553,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db.get_chat_session_message_count(
existing_message_count = await chat_db().get_chat_session_message_count(
session_id
)
@@ -580,7 +565,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
) from e
try:
await _cache_session(session)
await cache_chat_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
@@ -599,7 +584,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
# Create in database first - fail fast if this fails
try:
await chat_db.create_chat_session(
await chat_db().create_chat_session(
session_id=session.session_id,
user_id=user_id,
)
@@ -611,7 +596,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
# Cache the session (best-effort optimization, DB is source of truth)
try:
await _cache_session(session)
await cache_chat_session(session)
except Exception as e:
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
@@ -629,13 +614,9 @@ async def get_user_sessions(
A tuple of (sessions, total_count) where total_count is the overall
number of sessions for the user (not just the current page).
"""
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
total_count = await chat_db.get_user_session_count(user_id)
sessions = []
for prisma_session in prisma_sessions:
# Convert without messages for listing (lighter weight)
sessions.append(ChatSession.from_db(prisma_session, None))
db = chat_db()
sessions = await db.get_user_chat_sessions(user_id, limit, offset)
total_count = await db.get_user_session_count(user_id)
return sessions, total_count
@@ -653,7 +634,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
"""
# Delete from database first (with optional user_id validation)
# This confirms ownership before invalidating cache
deleted = await chat_db.delete_chat_session(session_id, user_id)
deleted = await chat_db().delete_chat_session(session_id, user_id)
if not deleted:
return False
@@ -688,7 +669,7 @@ async def update_session_title(session_id: str, title: str) -> bool:
True if updated successfully, False otherwise.
"""
try:
result = await chat_db.update_chat_session(session_id=session_id, title=title)
result = await chat_db().update_chat_session(session_id=session_id, title=title)
if result is None:
logger.warning(f"Session {session_id} not found for title update")
return False
@@ -700,7 +681,7 @@ async def update_session_title(session_id: str, title: str) -> bool:
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await _cache_session(cached)
await cache_chat_session(cached)
except Exception as e:
# Not critical - title will be correct on next full cache refresh
logger.warning(
@@ -711,3 +692,29 @@ async def update_session_title(session_id: str, title: str) -> bool:
except Exception as e:
logger.error(f"Failed to update title for session {session_id}: {e}")
return False
# ==================== Chat session locks ==================== #
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
This was originally added to solve the specific problem of race conditions between
the session title thread and the conversation thread, which always occurs on the
same instance as we prevent rapid request sends on the frontend.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks. Explicit cleanup also occurs
in `delete_chat_session()`.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock

View File

@@ -20,7 +20,7 @@ from claude_agent_sdk import (
UserMessage,
)
from backend.api.features.chat.response_model import (
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
@@ -34,10 +34,8 @@ from backend.api.features.chat.response_model import (
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.api.features.chat.sdk.tool_adapter import (
MCP_TOOL_PREFIX,
pop_pending_tool_output,
)
from .tool_adapter import MCP_TOOL_PREFIX, pop_pending_tool_output
logger = logging.getLogger(__name__)

View File

@@ -10,7 +10,7 @@ from claude_agent_sdk import (
UserMessage,
)
from backend.api.features.chat.response_model import (
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,

View File

@@ -11,7 +11,7 @@ import re
from collections.abc import Callable
from typing import Any, cast
from backend.api.features.chat.sdk.tool_adapter import (
from .tool_adapter import (
BLOCKED_TOOLS,
DANGEROUS_PATTERNS,
MCP_TOOL_PREFIX,

View File

@@ -1,4 +1,9 @@
"""Unit tests for SDK security hooks."""
"""Tests for SDK security hooks — workspace paths, tool access, and deny messages.
These are pure unit tests with no external dependencies (no SDK, no DB, no server).
They validate that the security hooks correctly block unauthorized paths,
tool access, and dangerous input patterns.
"""
import os
@@ -12,6 +17,10 @@ def _is_denied(result: dict) -> bool:
return hook.get("permissionDecision") == "deny"
def _reason(result: dict) -> str:
return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "")
# -- Blocked tools -----------------------------------------------------------
@@ -163,3 +172,19 @@ def test_non_workspace_tool_passes_isolation():
"find_agent", {"query": "email"}, user_id="user-1"
)
assert result == {}
# -- Deny message quality ----------------------------------------------------
def test_blocked_tool_message_clarity():
"""Deny messages must include [SECURITY] and 'cannot be bypassed'."""
reason = _reason(_validate_tool_access("bash", {}))
assert "[SECURITY]" in reason
assert "cannot be bypassed" in reason
def test_bash_builtin_blocked_message_clarity():
reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"}))
assert "[SECURITY]" in reason
assert "cannot be bypassed" in reason

View File

@@ -255,7 +255,7 @@ def _build_sdk_env() -> dict[str, str]:
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path`
Delegates to :func:`~backend.copilot.tools.sandbox.make_session_path`
(single source of truth for path sanitization) and adds a defence-in-depth
assertion.
"""
@@ -440,12 +440,16 @@ async def stream_chat_completion_sdk(
f"Session {session_id} not found. Please create a new session first."
)
if message:
session.messages.append(
ChatMessage(
role="user" if is_user_message else "assistant", content=message
)
# Append the new message to the session if it's not already there
new_message_role = "user" if is_user_message else "assistant"
if message and (
len(session.messages) == 0
or not (
session.messages[-1].role == new_message_role
and session.messages[-1].content == message
)
):
session.messages.append(ChatMessage(role=new_message_role, content=message))
if is_user_message:
track_user_message(
user_id=user_id, session_id=session_id, message_length=len(message)

View File

@@ -18,9 +18,9 @@ from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import TOOL_REGISTRY
from backend.api.features.chat.tools.base import BaseTool
from backend.copilot.model import ChatSession
from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
logger = logging.getLogger(__name__)

View File

@@ -3,7 +3,7 @@
import json
import os
from backend.api.features.chat.sdk.transcript import (
from .transcript import (
STRIPPABLE_TYPES,
read_transcript_file,
strip_progress_entries,

View File

@@ -27,6 +27,7 @@ from openai.types.chat import (
ChatCompletionToolParam,
)
from backend.data.db_accessors import chat_db
from backend.data.redis_client import get_redis_async
from backend.data.understanding import (
format_understanding_for_prompt,
@@ -35,7 +36,6 @@ from backend.data.understanding import (
from backend.util.exceptions import NotFoundError
from backend.util.settings import AppEnvironment, Settings
from . import db as chat_db
from . import stream_registry
from .config import ChatConfig
from .model import (
@@ -428,12 +428,16 @@ async def stream_chat_completion(
f"Session {session_id} not found. Please create a new session first."
)
if message:
session.messages.append(
ChatMessage(
role="user" if is_user_message else "assistant", content=message
)
# Append the new message to the session if it's not already there
new_message_role = "user" if is_user_message else "assistant"
if message and (
len(session.messages) == 0
or not (
session.messages[-1].role == new_message_role
and session.messages[-1].content == message
)
):
session.messages.append(ChatMessage(role=new_message_role, content=message))
logger.info(
f"Appended message (role={'user' if is_user_message else 'assistant'}), "
f"new message_count={len(session.messages)}"
@@ -1770,7 +1774,7 @@ async def _update_pending_operation(
This is called by background tasks when long-running operations complete.
"""
# Update the message in database
updated = await chat_db.update_tool_message_content(
updated = await chat_db().update_tool_message_content(
session_id=session_id,
tool_call_id=tool_call_id,
new_content=result,

View File

@@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Any
from openai.types.chat import ChatCompletionToolParam
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tracking import track_tool_called
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
@@ -31,7 +31,7 @@ from .workspace_files import (
)
if TYPE_CHECKING:
from backend.api.features.chat.response_model import StreamToolOutputAvailable
from backend.copilot.response_model import StreamToolOutputAvailable
logger = logging.getLogger(__name__)

View File

@@ -6,11 +6,11 @@ import pytest
from prisma.types import ProfileCreateInput
from pydantic import SecretStr
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import AITextGeneratorBlock
from backend.copilot.model import ChatSession
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import APIKeyCredentials

View File

@@ -3,11 +3,9 @@
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.data.understanding import (
BusinessUnderstandingInput,
upsert_business_understanding,
)
from backend.copilot.model import ChatSession
from backend.data.db_accessors import understanding_db
from backend.data.understanding import BusinessUnderstandingInput
from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
@@ -99,7 +97,9 @@ and automations for the user's specific needs."""
]
# Upsert with merge
understanding = await upsert_business_understanding(user_id, input_data)
understanding = await understanding_db().upsert_business_understanding(
user_id, input_data
)
# Build current understanding summary (filter out empty values)
current_understanding = {

View File

@@ -5,9 +5,8 @@ import re
import uuid
from typing import Any, NotRequired, TypedDict
from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
from backend.data.db_accessors import graph_db, library_db, store_db
from backend.data.graph import Graph, Link, Node
from backend.util.exceptions import DatabaseError, NotFoundError
from .service import (
@@ -145,8 +144,9 @@ async def get_library_agent_by_id(
Returns:
LibraryAgentSummary if found, None otherwise
"""
db = library_db()
try:
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
agent = await db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return LibraryAgentSummary(
@@ -163,7 +163,7 @@ async def get_library_agent_by_id(
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
try:
agent = await library_db.get_library_agent(agent_id, user_id)
agent = await db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return LibraryAgentSummary(
@@ -215,7 +215,7 @@ async def get_library_agents_for_generation(
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
"""
try:
response = await library_db.list_library_agents(
response = await library_db().list_library_agents(
user_id=user_id,
search_term=search_query,
page=1,
@@ -272,7 +272,7 @@ async def search_marketplace_agents_for_generation(
List of LibraryAgentSummary with full input/output schemas
"""
try:
response = await store_db.get_store_agents(
response = await store_db().get_store_agents(
search_query=search_query,
page=1,
page_size=max_results,
@@ -286,7 +286,7 @@ async def search_marketplace_agents_for_generation(
return []
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
graphs = await get_store_listed_graphs(*graph_ids)
graphs = await graph_db().get_store_listed_graphs(*graph_ids)
results: list[LibraryAgentSummary] = []
for agent in agents_with_graphs:
@@ -673,9 +673,10 @@ async def save_agent_to_library(
Tuple of (created Graph, LibraryAgent)
"""
graph = json_to_graph(agent_json)
db = library_db()
if is_update:
return await library_db.update_graph_in_library(graph, user_id)
return await library_db.create_graph_in_library(graph, user_id)
return await db.update_graph_in_library(graph, user_id)
return await db.create_graph_in_library(graph, user_id)
def graph_to_json(graph: Graph) -> dict[str, Any]:
@@ -735,12 +736,14 @@ async def get_agent_as_json(
Returns:
Agent as JSON dict or None if not found
"""
graph = await get_graph(agent_id, version=None, user_id=user_id)
db = graph_db()
graph = await db.get_graph(agent_id, version=None, user_id=user_id)
if not graph and user_id:
try:
library_agent = await library_db.get_library_agent(agent_id, user_id)
graph = await get_graph(
library_agent = await library_db().get_library_agent(agent_id, user_id)
graph = await db.get_graph(
library_agent.graph_id, version=None, user_id=user_id
)
except NotFoundError:

View File

@@ -7,10 +7,9 @@ from typing import Any
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
from backend.api.features.library.model import LibraryAgent
from backend.data import execution as execution_db
from backend.copilot.model import ChatSession
from backend.data.db_accessors import execution_db, library_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
from .base import BaseTool
@@ -165,10 +164,12 @@ class AgentOutputTool(BaseTool):
Resolve agent from provided identifiers.
Returns (library_agent, error_message).
"""
lib_db = library_db()
# Priority 1: Exact library agent ID
if library_agent_id:
try:
agent = await library_db.get_library_agent(library_agent_id, user_id)
agent = await lib_db.get_library_agent(library_agent_id, user_id)
return agent, None
except Exception as e:
logger.warning(f"Failed to get library agent by ID: {e}")
@@ -182,7 +183,7 @@ class AgentOutputTool(BaseTool):
return None, f"Agent '{store_slug}' not found in marketplace"
# Find in user's library by graph_id
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id)
if not agent:
return (
None,
@@ -194,7 +195,7 @@ class AgentOutputTool(BaseTool):
# Priority 3: Fuzzy name search in library
if agent_name:
try:
response = await library_db.list_library_agents(
response = await lib_db.list_library_agents(
user_id=user_id,
search_term=agent_name,
page_size=5,
@@ -228,9 +229,11 @@ class AgentOutputTool(BaseTool):
Fetch execution(s) based on filters.
Returns (single_execution, available_executions_meta, error_message).
"""
exec_db = execution_db()
# If specific execution_id provided, fetch it directly
if execution_id:
execution = await execution_db.get_graph_execution(
execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
@@ -240,7 +243,7 @@ class AgentOutputTool(BaseTool):
return execution, [], None
# Get completed executions with time filters
executions = await execution_db.get_graph_executions(
executions = await exec_db.get_graph_executions(
graph_id=graph_id,
user_id=user_id,
statuses=[ExecutionStatus.COMPLETED],
@@ -254,7 +257,7 @@ class AgentOutputTool(BaseTool):
# If only one execution, fetch full details
if len(executions) == 1:
full_execution = await execution_db.get_graph_execution(
full_execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
@@ -262,7 +265,7 @@ class AgentOutputTool(BaseTool):
return full_execution, [], None
# Multiple executions - return latest with full details, plus list of available
full_execution = await execution_db.get_graph_execution(
full_execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
@@ -380,7 +383,7 @@ class AgentOutputTool(BaseTool):
and not input_data.store_slug
):
# Fetch execution directly to get graph_id
execution = await execution_db.get_graph_execution(
execution = await execution_db().get_graph_execution(
user_id=user_id,
execution_id=input_data.execution_id,
include_node_executions=False,
@@ -392,7 +395,7 @@ class AgentOutputTool(BaseTool):
)
# Find library agent by graph_id
agent = await library_db.get_library_agent_by_graph_id(
agent = await library_db().get_library_agent_by_graph_id(
user_id, execution.graph_id
)
if not agent:

View File

@@ -4,8 +4,7 @@ import logging
import re
from typing import Literal
from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.data.db_accessors import library_db, store_db
from backend.util.exceptions import DatabaseError, NotFoundError
from .models import (
@@ -45,8 +44,10 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
Returns:
AgentInfo if found, None otherwise
"""
lib_db = library_db()
try:
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return AgentInfo(
@@ -71,7 +72,7 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
)
try:
agent = await library_db.get_library_agent(agent_id, user_id)
agent = await lib_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return AgentInfo(
@@ -133,7 +134,7 @@ async def search_agents(
try:
if source == "marketplace":
logger.info(f"Searching marketplace for: {query}")
results = await store_db.get_store_agents(search_query=query, page_size=5)
results = await store_db().get_store_agents(search_query=query, page_size=5)
for agent in results.agents:
agents.append(
AgentInfo(
@@ -159,7 +160,7 @@ async def search_agents(
if not agents:
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
results = await library_db().list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,

View File

@@ -5,8 +5,8 @@ from typing import Any
from openai.types.chat import ChatCompletionToolParam
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.response_model import StreamToolOutputAvailable
from backend.copilot.model import ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase

View File

@@ -11,18 +11,11 @@ available (e.g. macOS development).
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
BashExecResponse,
ErrorResponse,
ToolResponseBase,
)
from backend.api.features.chat.tools.sandbox import (
get_workspace_dir,
has_full_sandbox,
run_sandboxed,
)
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
logger = logging.getLogger(__name__)

View File

@@ -3,13 +3,10 @@
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
ErrorResponse,
ResponseType,
ToolResponseBase,
)
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
@@ -78,7 +75,7 @@ class CheckOperationStatusTool(BaseTool):
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
from backend.api.features.chat import stream_registry
from backend.copilot import stream_registry
operation_id = (kwargs.get("operation_id") or "").strip()
task_id = (kwargs.get("task_id") or "").strip()

View File

@@ -3,7 +3,7 @@
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.copilot.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,

View File

@@ -3,9 +3,9 @@
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError
from backend.copilot.model import ChatSession
from backend.data.db_accessors import store_db as get_store_db
from .agent_generator import (
AgentGeneratorNotConfiguredError,
@@ -137,6 +137,8 @@ class CustomizeAgentTool(BaseTool):
creator_username, agent_slug = parts
store_db = get_store_db()
# Fetch the marketplace agent details
try:
agent_details = await store_db.get_store_agent_details(

View File

@@ -3,7 +3,7 @@
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.copilot.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,

View File

@@ -5,9 +5,14 @@ from typing import Any
from pydantic import SecretStr
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
from backend.blocks.linear._api import LinearClient
from backend.copilot.model import ChatSession
from backend.data.db_accessors import user_db
from backend.data.model import APIKeyCredentials
from backend.util.settings import Settings
from .base import BaseTool
from .models import (
ErrorResponse,
FeatureRequestCreatedResponse,
FeatureRequestInfo,
@@ -15,10 +20,6 @@ from backend.api.features.chat.tools.models import (
NoResultsResponse,
ToolResponseBase,
)
from backend.blocks.linear._api import LinearClient
from backend.data.model import APIKeyCredentials
from backend.data.user import get_user_email_by_id
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -332,7 +333,9 @@ class CreateFeatureRequestTool(BaseTool):
# Resolve a human-readable name (email) for the Linear customer record.
# Fall back to user_id if the lookup fails or returns None.
try:
customer_display_name = await get_user_email_by_id(user_id) or user_id
customer_display_name = (
await user_db().get_user_email_by_id(user_id) or user_id
)
except Exception:
customer_display_name = user_id

View File

@@ -1,22 +1,18 @@
"""Tests for SearchFeatureRequestsTool and CreateFeatureRequestTool."""
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.features.chat.tools.feature_requests import (
CreateFeatureRequestTool,
SearchFeatureRequestsTool,
)
from backend.api.features.chat.tools.models import (
from ._test_data import make_session
from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool
from .models import (
ErrorResponse,
FeatureRequestCreatedResponse,
FeatureRequestSearchResponse,
NoResultsResponse,
)
from ._test_data import make_session
_TEST_USER_ID = "test-user-feature-requests"
_TEST_USER_EMAIL = "testuser@example.com"
@@ -39,7 +35,7 @@ def _mock_linear_config(*, query_return=None, mutate_return=None):
client.mutate.return_value = mutate_return
return (
patch(
"backend.api.features.chat.tools.feature_requests._get_linear_config",
"backend.copilot.tools.feature_requests._get_linear_config",
return_value=(client, _FAKE_PROJECT_ID, _FAKE_TEAM_ID),
),
client,
@@ -208,7 +204,7 @@ class TestSearchFeatureRequestsTool:
async def test_linear_client_init_failure(self):
session = make_session(user_id=_TEST_USER_ID)
with patch(
"backend.api.features.chat.tools.feature_requests._get_linear_config",
"backend.copilot.tools.feature_requests._get_linear_config",
side_effect=RuntimeError("No API key"),
):
tool = SearchFeatureRequestsTool()
@@ -231,10 +227,11 @@ class TestCreateFeatureRequestTool:
@pytest.fixture(autouse=True)
def _patch_email_lookup(self):
mock_user_db = MagicMock()
mock_user_db.get_user_email_by_id = AsyncMock(return_value=_TEST_USER_EMAIL)
with patch(
"backend.api.features.chat.tools.feature_requests.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TEST_USER_EMAIL,
"backend.copilot.tools.feature_requests.user_db",
return_value=mock_user_db,
):
yield
@@ -347,7 +344,7 @@ class TestCreateFeatureRequestTool:
async def test_linear_client_init_failure(self):
session = make_session(user_id=_TEST_USER_ID)
with patch(
"backend.api.features.chat.tools.feature_requests._get_linear_config",
"backend.copilot.tools.feature_requests._get_linear_config",
side_effect=RuntimeError("No API key"),
):
tool = CreateFeatureRequestTool()

View File

@@ -2,7 +2,7 @@
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.copilot.model import ChatSession
from .agent_search import search_agents
from .base import BaseTool

View File

@@ -3,17 +3,18 @@ from typing import Any
from prisma.enums import ContentType
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
from backend.api.features.chat.tools.models import (
from backend.blocks import get_block
from backend.blocks._base import BlockType
from backend.copilot.model import ChatSession
from backend.data.db_accessors import search
from .base import BaseTool, ToolResponseBase
from .models import (
BlockInfoSummary,
BlockListResponse,
ErrorResponse,
NoResultsResponse,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.blocks import get_block
from backend.blocks._base import BlockType
logger = logging.getLogger(__name__)
@@ -107,7 +108,7 @@ class FindBlockTool(BaseTool):
try:
# Search for blocks using hybrid search
results, total = await unified_hybrid_search(
results, total = await search().unified_hybrid_search(
query=query,
content_types=[ContentType.BLOCK],
page=1,

View File

@@ -4,15 +4,15 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.features.chat.tools.find_block import (
from backend.blocks._base import BlockType
from ._test_data import make_session
from .find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
FindBlockTool,
)
from backend.api.features.chat.tools.models import BlockListResponse
from backend.blocks._base import BlockType
from ._test_data import make_session
from .models import BlockListResponse
_TEST_USER_ID = "test-user-find-block"
@@ -84,13 +84,17 @@ class TestFindBlockFiltering:
"standard-block-id": standard_block,
}.get(block_id)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 2)
)
with patch(
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, 2),
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
):
with patch(
"backend.api.features.chat.tools.find_block.get_block",
"backend.copilot.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
@@ -128,13 +132,17 @@ class TestFindBlockFiltering:
"normal-block-id": normal_block,
}.get(block_id)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 2)
)
with patch(
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, 2),
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
):
with patch(
"backend.api.features.chat.tools.find_block.get_block",
"backend.copilot.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
@@ -353,12 +361,16 @@ class TestFindBlockFiltering:
for d in block_defs
}
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, len(search_results))
)
with patch(
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, len(search_results)),
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
), patch(
"backend.api.features.chat.tools.find_block.get_block",
"backend.copilot.tools.find_block.get_block",
side_effect=lambda bid: mock_blocks.get(bid),
):
tool = FindBlockTool()

View File

@@ -2,7 +2,7 @@
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.copilot.model import ChatSession
from .agent_search import search_agents
from .base import BaseTool

View File

@@ -4,13 +4,10 @@ import logging
from pathlib import Path
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
DocPageResponse,
ErrorResponse,
ToolResponseBase,
)
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import DocPageResponse, ErrorResponse, ToolResponseBase
logger = logging.getLogger(__name__)

View File

@@ -5,16 +5,12 @@ from typing import Any
from pydantic import BaseModel, Field, field_validator
from backend.api.features.chat.config import ChatConfig
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tracking import (
track_agent_run_success,
track_agent_scheduled,
)
from backend.api.features.library import db as library_db
from backend.copilot.config import ChatConfig
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
from backend.data.db_accessors import graph_db, library_db, user_db
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.data.user import get_user_by_id
from backend.executor import utils as execution_utils
from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, NotFoundError
@@ -200,7 +196,7 @@ class RunAgentTool(BaseTool):
# Priority: library_agent_id if provided
if has_library_id:
library_agent = await library_db.get_library_agent(
library_agent = await library_db().get_library_agent(
params.library_agent_id, user_id
)
if not library_agent:
@@ -209,9 +205,7 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
# Get the graph from the library agent
from backend.data.graph import get_graph
graph = await get_graph(
graph = await graph_db().get_graph(
library_agent.graph_id,
library_agent.graph_version,
user_id=user_id,
@@ -522,7 +516,7 @@ class RunAgentTool(BaseTool):
library_agent = await get_or_create_library_agent(graph, user_id)
# Get user timezone
user = await get_user_by_id(user_id)
user = await user_db().get_user_by_id(user_id)
user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone)
# Create schedule

View File

@@ -7,20 +7,17 @@ from typing import Any
from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
)
from backend.blocks import get_block
from backend.blocks._base import AnyBlockSchema
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from .base import BaseTool
from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
from .helpers import get_inputs_from_schema
from .models import (
BlockDetails,
@@ -276,7 +273,7 @@ class RunBlockTool(BaseTool):
try:
# Get or create user's workspace for CoPilot file operations
workspace = await get_or_create_workspace(user_id)
workspace = await workspace_db().get_or_create_workspace(user_id)
# Generate synthetic IDs for CoPilot context
# Each chat session is treated as its own agent with one continuous run

View File

@@ -4,16 +4,16 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.features.chat.tools.models import (
from backend.blocks._base import BlockType
from ._test_data import make_session
from .models import (
BlockDetailsResponse,
BlockOutputResponse,
ErrorResponse,
InputValidationErrorResponse,
)
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.blocks._base import BlockType
from ._test_data import make_session
from .run_block import RunBlockTool
_TEST_USER_ID = "test-user-run-block"
@@ -77,7 +77,7 @@ class TestRunBlockFiltering:
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
"backend.copilot.tools.run_block.get_block",
return_value=input_block,
):
tool = RunBlockTool()
@@ -103,7 +103,7 @@ class TestRunBlockFiltering:
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
"backend.copilot.tools.run_block.get_block",
return_value=smart_block,
):
tool = RunBlockTool()
@@ -127,7 +127,7 @@ class TestRunBlockFiltering:
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
"backend.copilot.tools.run_block.get_block",
return_value=standard_block,
):
tool = RunBlockTool()
@@ -183,7 +183,7 @@ class TestRunBlockInputValidation:
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()
@@ -222,7 +222,7 @@ class TestRunBlockInputValidation:
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()
@@ -263,7 +263,7 @@ class TestRunBlockInputValidation:
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()
@@ -302,15 +302,19 @@ class TestRunBlockInputValidation:
mock_block.execute = mock_execute
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="test-workspace-id")
)
with (
patch(
"backend.api.features.chat.tools.run_block.get_block",
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.api.features.chat.tools.run_block.get_or_create_workspace",
new_callable=AsyncMock,
return_value=MagicMock(id="test-workspace-id"),
"backend.copilot.tools.run_block.workspace_db",
return_value=mock_workspace_db,
),
):
tool = RunBlockTool()
@@ -344,7 +348,7 @@ class TestRunBlockInputValidation:
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
"backend.copilot.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()

View File

@@ -5,16 +5,17 @@ from typing import Any
from prisma.enums import ContentType
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
from backend.copilot.model import ChatSession
from backend.data.db_accessors import search
from .base import BaseTool
from .models import (
DocSearchResult,
DocSearchResultsResponse,
ErrorResponse,
NoResultsResponse,
ToolResponseBase,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
logger = logging.getLogger(__name__)
@@ -117,7 +118,7 @@ class SearchDocsTool(BaseTool):
try:
# Search using hybrid search for DOCUMENTATION content type only
results, total = await unified_hybrid_search(
results, total = await search().unified_hybrid_search(
query=query,
content_types=[ContentType.DOCUMENTATION],
page=1,

View File

@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.features.chat.tools.models import BlockDetailsResponse
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.blocks._base import BlockType
from backend.data.model import CredentialsMetaInput
from backend.integrations.providers import ProviderName
from ._test_data import make_session
from .models import BlockDetailsResponse
from .run_block import RunBlockTool
_TEST_USER_ID = "test-user-run-block-details"
@@ -61,7 +61,7 @@ async def test_run_block_returns_details_when_no_input_provided():
)
with patch(
"backend.api.features.chat.tools.run_block.get_block",
"backend.copilot.tools.run_block.get_block",
return_value=http_block,
):
# Mock credentials check to return no missing credentials
@@ -120,7 +120,7 @@ async def test_run_block_returns_details_when_only_credentials_provided():
}
with patch(
"backend.api.features.chat.tools.run_block.get_block",
"backend.copilot.tools.run_block.get_block",
return_value=mock,
):
with patch.object(

View File

@@ -3,9 +3,8 @@
import logging
from typing import Any
from backend.api.features.library import db as library_db
from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data.db_accessors import library_db, store_db
from backend.data.graph import GraphModel
from backend.data.model import (
Credentials,
@@ -39,13 +38,14 @@ async def fetch_graph_from_store_slug(
Raises:
DatabaseError: If there's a database error during lookup.
"""
sdb = store_db()
try:
store_agent = await store_db.get_store_agent_details(username, agent_name)
store_agent = await sdb.get_store_agent_details(username, agent_name)
except NotFoundError:
return None, None
# Get the graph from store listing version
graph = await store_db.get_available_graph(
graph = await sdb.get_available_graph(
store_agent.store_listing_version_id, hide_nodes=False
)
return graph, store_agent
@@ -210,13 +210,13 @@ async def get_or_create_library_agent(
Returns:
LibraryAgent instance
"""
existing = await library_db.get_library_agent_by_graph_id(
existing = await library_db().get_library_agent_by_graph_id(
graph_id=graph.id, user_id=user_id
)
if existing:
return existing
library_agents = await library_db.create_library_agent(
library_agents = await library_db().create_library_agent(
graph=graph,
user_id=user_id,
create_library_agents_for_sub_graphs=False,

View File

@@ -6,15 +6,12 @@ from typing import Any
import aiohttp
import html2text
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
ErrorResponse,
ToolResponseBase,
WebFetchResponse,
)
from backend.copilot.model import ChatSession
from backend.util.request import Requests
from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, WebFetchResponse
logger = logging.getLogger(__name__)
# Limits

View File

@@ -6,8 +6,8 @@ from typing import Any, Optional
from pydantic import BaseModel
from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
@@ -148,7 +148,7 @@ class ListWorkspaceFilesTool(BaseTool):
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
workspace = await get_or_create_workspace(user_id)
workspace = await workspace_db().get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
@@ -284,7 +284,7 @@ class ReadWorkspaceFileTool(BaseTool):
)
try:
workspace = await get_or_create_workspace(user_id)
workspace = await workspace_db().get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
@@ -484,7 +484,7 @@ class WriteWorkspaceFileTool(BaseTool):
# Virus scan
await scan_content_safe(content, filename=filename)
workspace = await get_or_create_workspace(user_id)
workspace = await workspace_db().get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
@@ -583,7 +583,7 @@ class DeleteWorkspaceFileTool(BaseTool):
)
try:
workspace = await get_or_create_workspace(user_id)
workspace = await workspace_db().get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)

View File

@@ -75,6 +75,8 @@ MODEL_COST: dict[LlmModel, int] = {
LlmModel.GPT41_MINI: 1,
LlmModel.GPT4O_MINI: 1,
LlmModel.GPT4O: 3,
LlmModel.GPT4_TURBO: 10,
LlmModel.GPT3_5_TURBO: 1,
LlmModel.CLAUDE_4_1_OPUS: 21,
LlmModel.CLAUDE_4_OPUS: 21,
LlmModel.CLAUDE_4_SONNET: 5,

View File

@@ -79,7 +79,7 @@ async def test_block_credit_usage(server: SpinTestServer):
node_exec_id="test_node_exec",
block_id=AITextGeneratorBlock().id,
inputs={
"model": "gpt-4o",
"model": "gpt-4-turbo",
"credentials": {
"id": openai_credentials.id,
"provider": openai_credentials.provider,
@@ -100,7 +100,7 @@ async def test_block_credit_usage(server: SpinTestServer):
graph_exec_id="test_graph_exec",
node_exec_id="test_node_exec",
block_id=AITextGeneratorBlock().id,
inputs={"model": "gpt-4o", "api_key": "owned_api_key"},
inputs={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
execution_context=ExecutionContext(user_timezone="UTC"),
),
)

View File

@@ -0,0 +1,118 @@
from backend.data import db
def chat_db():
if db.is_connected():
from backend.copilot import db as _chat_db
chat_db = _chat_db
else:
from backend.util.clients import get_database_manager_async_client
chat_db = get_database_manager_async_client()
return chat_db
def graph_db():
if db.is_connected():
from backend.data import graph as _graph_db
graph_db = _graph_db
else:
from backend.util.clients import get_database_manager_async_client
graph_db = get_database_manager_async_client()
return graph_db
def library_db():
if db.is_connected():
from backend.api.features.library import db as _library_db
library_db = _library_db
else:
from backend.util.clients import get_database_manager_async_client
library_db = get_database_manager_async_client()
return library_db
def store_db():
if db.is_connected():
from backend.api.features.store import db as _store_db
store_db = _store_db
else:
from backend.util.clients import get_database_manager_async_client
store_db = get_database_manager_async_client()
return store_db
def search():
if db.is_connected():
from backend.api.features.store import hybrid_search as _search
search = _search
else:
from backend.util.clients import get_database_manager_async_client
search = get_database_manager_async_client()
return search
def execution_db():
if db.is_connected():
from backend.data import execution as _execution_db
execution_db = _execution_db
else:
from backend.util.clients import get_database_manager_async_client
execution_db = get_database_manager_async_client()
return execution_db
def user_db():
if db.is_connected():
from backend.data import user as _user_db
user_db = _user_db
else:
from backend.util.clients import get_database_manager_async_client
user_db = get_database_manager_async_client()
return user_db
def understanding_db():
if db.is_connected():
from backend.data import understanding as _understanding_db
understanding_db = _understanding_db
else:
from backend.util.clients import get_database_manager_async_client
understanding_db = get_database_manager_async_client()
return understanding_db
def workspace_db():
if db.is_connected():
from backend.data import workspace as _workspace_db
workspace_db = _workspace_db
else:
from backend.util.clients import get_database_manager_async_client
workspace_db = get_database_manager_async_client()
return workspace_db

View File

@@ -4,14 +4,26 @@ from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cas
from backend.api.features.library.db import (
add_store_agent_to_library,
create_graph_in_library,
create_library_agent,
get_library_agent,
get_library_agent_by_graph_id,
list_library_agents,
update_graph_in_library,
)
from backend.api.features.store.db import (
get_agent,
get_available_graph,
get_store_agent_details,
get_store_agents,
)
from backend.api.features.store.db import get_store_agent_details, get_store_agents
from backend.api.features.store.embeddings import (
backfill_missing_embeddings,
cleanup_orphaned_embeddings,
get_embedding_stats,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.copilot import db as chat_db
from backend.data import db
from backend.data.analytics import (
get_accuracy_trends_and_alerts,
@@ -48,6 +60,7 @@ from backend.data.graph import (
get_graph_metadata,
get_graph_settings,
get_node,
get_store_listed_graphs,
validate_graph_execution_permissions,
)
from backend.data.human_review import (
@@ -67,6 +80,10 @@ from backend.data.notifications import (
remove_notifications_from_batch,
)
from backend.data.onboarding import increment_onboarding_runs
from backend.data.understanding import (
get_business_understanding,
upsert_business_understanding,
)
from backend.data.user import (
get_active_user_ids_in_timerange,
get_user_by_id,
@@ -76,6 +93,7 @@ from backend.data.user import (
get_user_notification_preference,
update_user_integrations,
)
from backend.data.workspace import get_or_create_workspace
from backend.util.service import (
AppService,
AppServiceClient,
@@ -107,6 +125,13 @@ async def _get_credits(user_id: str) -> int:
class DatabaseManager(AppService):
"""Database connection pooling service.
This service connects to the Prisma engine and exposes database
operations via RPC endpoints. It acts as a centralized connection pool
for all services that need database access.
"""
@asynccontextmanager
async def lifespan(self, app: "FastAPI"):
async with super().lifespan(app):
@@ -142,11 +167,15 @@ class DatabaseManager(AppService):
def _(
f: Callable[P, R], name: str | None = None
) -> Callable[Concatenate[object, P], R]:
"""
Exposes a function as an RPC endpoint, and adds a virtual `self` param
to the function's type so it can be bound as a method.
"""
if name is not None:
f.__name__ = name
return cast(Callable[Concatenate[object, P], R], expose(f))
# Executions
# ============ Graph Executions ============ #
get_child_graph_executions = _(get_child_graph_executions)
get_graph_executions = _(get_graph_executions)
get_graph_executions_count = _(get_graph_executions_count)
@@ -170,36 +199,37 @@ class DatabaseManager(AppService):
get_frequently_executed_graphs = _(get_frequently_executed_graphs)
get_marketplace_graphs_for_monitoring = _(get_marketplace_graphs_for_monitoring)
# Graphs
# ============ Graphs ============ #
get_node = _(get_node)
get_graph = _(get_graph)
get_connected_output_nodes = _(get_connected_output_nodes)
get_graph_metadata = _(get_graph_metadata)
get_graph_settings = _(get_graph_settings)
get_store_listed_graphs = _(get_store_listed_graphs)
# Credits
# ============ Credits ============ #
spend_credits = _(_spend_credits, name="spend_credits")
get_credits = _(_get_credits, name="get_credits")
# User + User Metadata + User Integrations
# ============ User + Integrations ============ #
get_user_by_id = _(get_user_by_id)
get_user_integrations = _(get_user_integrations)
update_user_integrations = _(update_user_integrations)
# User Comms - async
# ============ User Comms ============ #
get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange)
get_user_by_id = _(get_user_by_id)
get_user_email_by_id = _(get_user_email_by_id)
get_user_email_verification = _(get_user_email_verification)
get_user_notification_preference = _(get_user_notification_preference)
# Human In The Loop
# ============ Human In The Loop ============ #
cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution)
check_approval = _(check_approval)
get_or_create_human_review = _(get_or_create_human_review)
has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec)
update_review_processed_status = _(update_review_processed_status)
# Notifications - async
# ============ Notifications ============ #
clear_all_user_notification_batches = _(clear_all_user_notification_batches)
create_or_add_to_user_notification_batch = _(
create_or_add_to_user_notification_batch
@@ -212,29 +242,56 @@ class DatabaseManager(AppService):
get_user_notification_oldest_message_in_batch
)
# Library
# ============ Library ============ #
list_library_agents = _(list_library_agents)
add_store_agent_to_library = _(add_store_agent_to_library)
create_graph_in_library = _(create_graph_in_library)
create_library_agent = _(create_library_agent)
get_library_agent = _(get_library_agent)
get_library_agent_by_graph_id = _(get_library_agent_by_graph_id)
update_graph_in_library = _(update_graph_in_library)
validate_graph_execution_permissions = _(validate_graph_execution_permissions)
# Onboarding
# ============ Onboarding ============ #
increment_onboarding_runs = _(increment_onboarding_runs)
# OAuth
# ============ OAuth ============ #
cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
# Store
# ============ Store ============ #
get_store_agents = _(get_store_agents)
get_store_agent_details = _(get_store_agent_details)
get_agent = _(get_agent)
get_available_graph = _(get_available_graph)
# Store Embeddings
# ============ Search ============ #
get_embedding_stats = _(get_embedding_stats)
backfill_missing_embeddings = _(backfill_missing_embeddings)
cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings)
unified_hybrid_search = _(unified_hybrid_search)
# Summary data - async
# ============ Summary Data ============ #
get_user_execution_summary_data = _(get_user_execution_summary_data)
# ============ Workspace ============ #
get_or_create_workspace = _(get_or_create_workspace)
# ============ Understanding ============ #
get_business_understanding = _(get_business_understanding)
upsert_business_understanding = _(upsert_business_understanding)
# ============ CoPilot Chat Sessions ============ #
get_chat_session = _(chat_db.get_chat_session)
create_chat_session = _(chat_db.create_chat_session)
update_chat_session = _(chat_db.update_chat_session)
add_chat_message = _(chat_db.add_chat_message)
add_chat_messages_batch = _(chat_db.add_chat_messages_batch)
get_user_chat_sessions = _(chat_db.get_user_chat_sessions)
get_user_session_count = _(chat_db.get_user_session_count)
delete_chat_session = _(chat_db.delete_chat_session)
get_chat_session_message_count = _(chat_db.get_chat_session_message_count)
update_tool_message_content = _(chat_db.update_tool_message_content)
class DatabaseManagerClient(AppServiceClient):
d = DatabaseManager
@@ -296,43 +353,50 @@ class DatabaseManagerAsyncClient(AppServiceClient):
def get_service_type(cls):
return DatabaseManager
# ============ Graph Executions ============ #
create_graph_execution = d.create_graph_execution
get_child_graph_executions = d.get_child_graph_executions
get_connected_output_nodes = d.get_connected_output_nodes
get_latest_node_execution = d.get_latest_node_execution
get_graph = d.get_graph
get_graph_metadata = d.get_graph_metadata
get_graph_settings = d.get_graph_settings
get_graph_execution = d.get_graph_execution
get_graph_execution_meta = d.get_graph_execution_meta
get_node = d.get_node
get_graph_executions = d.get_graph_executions
get_node_execution = d.get_node_execution
get_node_executions = d.get_node_executions
get_user_by_id = d.get_user_by_id
get_user_integrations = d.get_user_integrations
upsert_execution_input = d.upsert_execution_input
upsert_execution_output = d.upsert_execution_output
get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
update_graph_execution_stats = d.update_graph_execution_stats
update_node_execution_status = d.update_node_execution_status
update_node_execution_status_batch = d.update_node_execution_status_batch
update_user_integrations = d.update_user_integrations
upsert_execution_input = d.upsert_execution_input
upsert_execution_output = d.upsert_execution_output
get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id
get_execution_kv_data = d.get_execution_kv_data
set_execution_kv_data = d.set_execution_kv_data
# Human In The Loop
# ============ Graphs ============ #
get_graph = d.get_graph
get_graph_metadata = d.get_graph_metadata
get_graph_settings = d.get_graph_settings
get_node = d.get_node
get_store_listed_graphs = d.get_store_listed_graphs
# ============ User + Integrations ============ #
get_user_by_id = d.get_user_by_id
get_user_integrations = d.get_user_integrations
update_user_integrations = d.update_user_integrations
# ============ Human In The Loop ============ #
cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution
check_approval = d.check_approval
get_or_create_human_review = d.get_or_create_human_review
update_review_processed_status = d.update_review_processed_status
# User Comms
# ============ User Comms ============ #
get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange
get_user_email_by_id = d.get_user_email_by_id
get_user_email_verification = d.get_user_email_verification
get_user_notification_preference = d.get_user_notification_preference
# Notifications
# ============ Notifications ============ #
clear_all_user_notification_batches = d.clear_all_user_notification_batches
create_or_add_to_user_notification_batch = (
d.create_or_add_to_user_notification_batch
@@ -345,20 +409,49 @@ class DatabaseManagerAsyncClient(AppServiceClient):
d.get_user_notification_oldest_message_in_batch
)
# Library
# ============ Library ============ #
list_library_agents = d.list_library_agents
add_store_agent_to_library = d.add_store_agent_to_library
create_graph_in_library = d.create_graph_in_library
create_library_agent = d.create_library_agent
get_library_agent = d.get_library_agent
get_library_agent_by_graph_id = d.get_library_agent_by_graph_id
update_graph_in_library = d.update_graph_in_library
validate_graph_execution_permissions = d.validate_graph_execution_permissions
# Onboarding
# ============ Onboarding ============ #
increment_onboarding_runs = d.increment_onboarding_runs
# OAuth
# ============ OAuth ============ #
cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
# Store
# ============ Store ============ #
get_store_agents = d.get_store_agents
get_store_agent_details = d.get_store_agent_details
get_agent = d.get_agent
get_available_graph = d.get_available_graph
# Summary data
# ============ Search ============ #
unified_hybrid_search = d.unified_hybrid_search
# ============ Summary Data ============ #
get_user_execution_summary_data = d.get_user_execution_summary_data
# ============ Workspace ============ #
get_or_create_workspace = d.get_or_create_workspace
# ============ Understanding ============ #
get_business_understanding = d.get_business_understanding
upsert_business_understanding = d.upsert_business_understanding
# ============ CoPilot Chat Sessions ============ #
get_chat_session = d.get_chat_session
create_chat_session = d.create_chat_session
update_chat_session = d.update_chat_session
add_chat_message = d.add_chat_message
add_chat_messages_batch = d.add_chat_messages_batch
get_user_chat_sessions = d.get_user_chat_sessions
get_user_session_count = d.get_user_session_count
delete_chat_session = d.delete_chat_session
get_chat_session_message_count = d.get_chat_session_message_count
update_tool_message_content = d.update_tool_message_content

View File

@@ -1,5 +1,5 @@
from backend.app import run_processes
from backend.executor import DatabaseManager
from backend.data.db_manager import DatabaseManager
def main():

View File

@@ -1,11 +1,7 @@
from .database import DatabaseManager, DatabaseManagerAsyncClient, DatabaseManagerClient
from .manager import ExecutionManager
from .scheduler import Scheduler
__all__ = [
"DatabaseManager",
"DatabaseManagerClient",
"DatabaseManagerAsyncClient",
"ExecutionManager",
"Scheduler",
]

View File

@@ -22,7 +22,7 @@ from backend.util.settings import Settings
from backend.util.truncate import truncate
if TYPE_CHECKING:
from backend.executor import DatabaseManagerAsyncClient
from backend.data.db_manager import DatabaseManagerAsyncClient
logger = logging.getLogger(__name__)

View File

@@ -4,7 +4,7 @@ import logging
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
from backend.executor import DatabaseManagerAsyncClient
from backend.data.db_manager import DatabaseManagerAsyncClient
from pydantic import ValidationError

View File

@@ -1,6 +1,7 @@
"""Redis-based distributed locking for cluster coordination."""
import logging
import threading
import time
from typing import TYPE_CHECKING
@@ -19,6 +20,7 @@ class ClusterLock:
self.owner_id = owner_id
self.timeout = timeout
self._last_refresh = 0.0
self._refresh_lock = threading.Lock()
def try_acquire(self) -> str | None:
"""Try to acquire the lock.
@@ -31,7 +33,8 @@ class ClusterLock:
try:
success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout)
if success:
self._last_refresh = time.time()
with self._refresh_lock:
self._last_refresh = time.time()
return self.owner_id # Successfully acquired
# Failed to acquire, get current owner
@@ -57,23 +60,27 @@ class ClusterLock:
Rate limited to at most once every timeout/10 seconds (minimum 1 second).
During rate limiting, still verifies lock existence but skips TTL extension.
Setting _last_refresh to 0 bypasses rate limiting for testing.
Thread-safe: uses _refresh_lock to protect _last_refresh access.
"""
# Calculate refresh interval: max(timeout // 10, 1)
refresh_interval = max(self.timeout // 10, 1)
current_time = time.time()
# Check if we're within the rate limit period
# Check if we're within the rate limit period (thread-safe read)
# _last_refresh == 0 forces a refresh (bypasses rate limiting for testing)
with self._refresh_lock:
last_refresh = self._last_refresh
is_rate_limited = (
self._last_refresh > 0
and (current_time - self._last_refresh) < refresh_interval
last_refresh > 0 and (current_time - last_refresh) < refresh_interval
)
try:
# Always verify lock existence, even during rate limiting
current_value = self.redis.get(self.key)
if not current_value:
self._last_refresh = 0
with self._refresh_lock:
self._last_refresh = 0
return False
stored_owner = (
@@ -82,7 +89,8 @@ class ClusterLock:
else str(current_value)
)
if stored_owner != self.owner_id:
self._last_refresh = 0
with self._refresh_lock:
self._last_refresh = 0
return False
# If rate limited, return True but don't update TTL or timestamp
@@ -91,25 +99,30 @@ class ClusterLock:
# Perform actual refresh
if self.redis.expire(self.key, self.timeout):
self._last_refresh = current_time
with self._refresh_lock:
self._last_refresh = current_time
return True
self._last_refresh = 0
with self._refresh_lock:
self._last_refresh = 0
return False
except Exception as e:
logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}")
self._last_refresh = 0
with self._refresh_lock:
self._last_refresh = 0
return False
def release(self):
"""Release the lock."""
if self._last_refresh == 0:
return
with self._refresh_lock:
if self._last_refresh == 0:
return
try:
self.redis.delete(self.key)
except Exception:
pass
self._last_refresh = 0.0
with self._refresh_lock:
self._last_refresh = 0.0

View File

@@ -93,7 +93,10 @@ from .utils import (
)
if TYPE_CHECKING:
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
from backend.data.db_manager import (
DatabaseManagerAsyncClient,
DatabaseManagerClient,
)
_logger = logging.getLogger(__name__)

View File

@@ -13,12 +13,15 @@ if TYPE_CHECKING:
from openai import AsyncOpenAI
from supabase import AClient, Client
from backend.data.db_manager import (
DatabaseManagerAsyncClient,
DatabaseManagerClient,
)
from backend.data.execution import (
AsyncRedisExecutionEventBus,
RedisExecutionEventBus,
)
from backend.data.rabbitmq import AsyncRabbitMQ, SyncRabbitMQ
from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient
from backend.executor.scheduler import SchedulerClient
from backend.integrations.credentials_store import IntegrationCredentialsStore
from backend.notifications.notifications import NotificationManagerClient
@@ -27,7 +30,7 @@ if TYPE_CHECKING:
@thread_cached
def get_database_manager_client() -> "DatabaseManagerClient":
"""Get a thread-cached DatabaseManagerClient with request retry enabled."""
from backend.executor import DatabaseManagerClient
from backend.data.db_manager import DatabaseManagerClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManagerClient, request_retry=True)
@@ -38,7 +41,7 @@ def get_database_manager_async_client(
should_retry: bool = True,
) -> "DatabaseManagerAsyncClient":
"""Get a thread-cached DatabaseManagerAsyncClient with request retry enabled."""
from backend.executor import DatabaseManagerAsyncClient
from backend.data.db_manager import DatabaseManagerAsyncClient
from backend.util.service import get_service_client
return get_service_client(DatabaseManagerAsyncClient, request_retry=should_retry)
@@ -106,6 +109,20 @@ async def get_async_execution_queue() -> "AsyncRabbitMQ":
return client
# ============ CoPilot Queue Helpers ============ #
@thread_cached
async def get_async_copilot_queue() -> "AsyncRabbitMQ":
"""Get a thread-cached AsyncRabbitMQ CoPilot queue client."""
from backend.copilot.executor.utils import create_copilot_queue_config
from backend.data.rabbitmq import AsyncRabbitMQ
client = AsyncRabbitMQ(create_copilot_queue_config())
await client.connect()
return client
# ============ Integration Credentials Store ============ #

View File

@@ -211,16 +211,23 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="The port for execution manager daemon to run on",
)
num_copilot_workers: int = Field(
default=5,
ge=1,
le=100,
description="Number of concurrent CoPilot executor workers",
)
copilot_executor_port: int = Field(
default=8008,
description="The port for CoPilot executor daemon to run on",
)
execution_scheduler_port: int = Field(
default=8003,
description="The port for execution scheduler daemon to run on",
)
agent_server_port: int = Field(
default=8004,
description="The port for agent server daemon to run on",
)
database_api_port: int = Field(
default=8005,
description="The port for database server API to run on",

View File

@@ -11,6 +11,7 @@ from backend.api.rest_api import AgentServer
from backend.blocks._base import Block, BlockSchema
from backend.data import db
from backend.data.block import initialize_blocks
from backend.data.db_manager import DatabaseManager
from backend.data.execution import (
ExecutionContext,
ExecutionStatus,
@@ -19,7 +20,7 @@ from backend.data.execution import (
)
from backend.data.model import _BaseCredentials
from backend.data.user import create_default_user
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
from backend.executor import ExecutionManager, Scheduler
from backend.notifications.notifications import NotificationManager
log = logging.getLogger(__name__)

View File

@@ -53,7 +53,7 @@ services:
rabbitmq:
<<: *agpt-services
image: rabbitmq:management
image: rabbitmq:4.1.4
container_name: rabbitmq
healthcheck:
test: rabbitmq-diagnostics -q ping
@@ -66,7 +66,6 @@ services:
- RABBITMQ_DEFAULT_PASS=k0VMxyIJF9S35f3x2uaw5IWAl6Y536O7
ports:
- "5672:5672"
- "15672:15672"
clamav:
image: clamav/clamav-debian:latest
ports:

View File

@@ -1,42 +0,0 @@
-- Migrate deprecated OpenAI GPT-4-turbo and GPT-3.5-turbo models
-- This updates all AgentNode blocks that use deprecated models
-- OpenAI is retiring these models:
-- - gpt-4-turbo: March 26, 2026 -> migrate to gpt-4o
-- - gpt-3.5-turbo: September 28, 2026 -> migrate to gpt-4o-mini
-- Update gpt-4-turbo to gpt-4o (staying in same capability tier)
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
'"gpt-4o"'::jsonb
)
WHERE "constantInput"::jsonb->>'model' = 'gpt-4-turbo';
-- Update gpt-3.5-turbo to gpt-4o-mini (appropriate replacement for lightweight model)
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
"constantInput"::jsonb,
'{model}',
'"gpt-4o-mini"'::jsonb
)
WHERE "constantInput"::jsonb->>'model' = 'gpt-3.5-turbo';
-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput)
UPDATE "AgentNodeExecutionInputOutput"
SET "data" = JSONB_SET(
"data"::jsonb,
'{model}',
'"gpt-4o"'::jsonb
)
WHERE "agentPresetId" IS NOT NULL
AND "data"::jsonb->>'model' = 'gpt-4-turbo';
UPDATE "AgentNodeExecutionInputOutput"
SET "data" = JSONB_SET(
"data"::jsonb,
'{model}',
'"gpt-4o-mini"'::jsonb
)
WHERE "agentPresetId" IS NOT NULL
AND "data"::jsonb->>'model' = 'gpt-3.5-turbo';

View File

@@ -117,6 +117,7 @@ ws = "backend.ws:main"
scheduler = "backend.scheduler:main"
notification = "backend.notification:main"
executor = "backend.exec:main"
copilot-executor = "backend.copilot.executor.__main__:main"
cli = "backend.cli:main"
format = "linter:format"
lint = "linter:lint"

View File

@@ -9,10 +9,8 @@ from unittest.mock import AsyncMock, patch
import pytest
from backend.api.features.chat.tools.agent_generator import core
from backend.api.features.chat.tools.agent_generator.core import (
AgentGeneratorNotConfiguredError,
)
from backend.copilot.tools.agent_generator import core
from backend.copilot.tools.agent_generator.core import AgentGeneratorNotConfiguredError
class TestServiceNotConfigured:

View File

@@ -9,7 +9,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.features.chat.tools.agent_generator import core
from backend.copilot.tools.agent_generator import core
class TestGetLibraryAgentsForGeneration:
@@ -31,18 +31,20 @@ class TestGetLibraryAgentsForGeneration:
mock_response = MagicMock()
mock_response.agents = [mock_agent]
mock_db = MagicMock()
mock_db.list_library_agents = AsyncMock(return_value=mock_response)
with patch.object(
core.library_db,
"list_library_agents",
new_callable=AsyncMock,
return_value=mock_response,
) as mock_list:
core,
"library_db",
return_value=mock_db,
):
result = await core.get_library_agents_for_generation(
user_id="user-123",
search_query="send email",
)
mock_list.assert_called_once_with(
mock_db.list_library_agents.assert_called_once_with(
user_id="user-123",
search_term="send email",
page=1,
@@ -80,11 +82,13 @@ class TestGetLibraryAgentsForGeneration:
),
]
mock_db = MagicMock()
mock_db.list_library_agents = AsyncMock(return_value=mock_response)
with patch.object(
core.library_db,
"list_library_agents",
new_callable=AsyncMock,
return_value=mock_response,
core,
"library_db",
return_value=mock_db,
):
result = await core.get_library_agents_for_generation(
user_id="user-123",
@@ -101,18 +105,20 @@ class TestGetLibraryAgentsForGeneration:
mock_response = MagicMock()
mock_response.agents = []
mock_db = MagicMock()
mock_db.list_library_agents = AsyncMock(return_value=mock_response)
with patch.object(
core.library_db,
"list_library_agents",
new_callable=AsyncMock,
return_value=mock_response,
) as mock_list:
core,
"library_db",
return_value=mock_db,
):
await core.get_library_agents_for_generation(
user_id="user-123",
max_results=5,
)
mock_list.assert_called_once_with(
mock_db.list_library_agents.assert_called_once_with(
user_id="user-123",
search_term=None,
page=1,
@@ -144,24 +150,24 @@ class TestSearchMarketplaceAgentsForGeneration:
mock_graph.input_schema = {"type": "object"}
mock_graph.output_schema = {"type": "object"}
mock_store_db = MagicMock()
mock_store_db.get_store_agents = AsyncMock(return_value=mock_response)
mock_graph_db = MagicMock()
mock_graph_db.get_store_listed_graphs = AsyncMock(
return_value={"graph-123": mock_graph}
)
with (
patch(
"backend.api.features.store.db.get_store_agents",
new_callable=AsyncMock,
return_value=mock_response,
) as mock_search,
patch(
"backend.api.features.chat.tools.agent_generator.core.get_store_listed_graphs",
new_callable=AsyncMock,
return_value={"graph-123": mock_graph},
),
patch.object(core, "store_db", return_value=mock_store_db),
patch.object(core, "graph_db", return_value=mock_graph_db),
):
result = await core.search_marketplace_agents_for_generation(
search_query="automation",
max_results=10,
)
mock_search.assert_called_once_with(
mock_store_db.get_store_agents.assert_called_once_with(
search_query="automation",
page=1,
page_size=10,
@@ -707,7 +713,7 @@ class TestExtractUuidsFromText:
class TestGetLibraryAgentById:
"""Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id)."""
"""Test get_library_agent_by_id function (alias: get_library_agent_by_graph_id)."""
@pytest.mark.asyncio
async def test_returns_agent_when_found_by_graph_id(self):
@@ -720,12 +726,10 @@ class TestGetLibraryAgentById:
mock_agent.input_schema = {"properties": {}}
mock_agent.output_schema = {"properties": {}}
with patch.object(
core.library_db,
"get_library_agent_by_graph_id",
new_callable=AsyncMock,
return_value=mock_agent,
):
mock_db = MagicMock()
mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
with patch.object(core, "library_db", return_value=mock_db):
result = await core.get_library_agent_by_id("user-123", "agent-123")
assert result is not None
@@ -743,20 +747,11 @@ class TestGetLibraryAgentById:
mock_agent.input_schema = {"properties": {}}
mock_agent.output_schema = {"properties": {}}
with (
patch.object(
core.library_db,
"get_library_agent_by_graph_id",
new_callable=AsyncMock,
return_value=None, # Not found by graph_id
),
patch.object(
core.library_db,
"get_library_agent",
new_callable=AsyncMock,
return_value=mock_agent, # Found by library ID
),
):
mock_db = MagicMock()
mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=None)
mock_db.get_library_agent = AsyncMock(return_value=mock_agent)
with patch.object(core, "library_db", return_value=mock_db):
result = await core.get_library_agent_by_id("user-123", "library-id-123")
assert result is not None
@@ -766,20 +761,13 @@ class TestGetLibraryAgentById:
@pytest.mark.asyncio
async def test_returns_none_when_not_found_by_either_method(self):
"""Test that None is returned when agent not found by either method."""
with (
patch.object(
core.library_db,
"get_library_agent_by_graph_id",
new_callable=AsyncMock,
return_value=None,
),
patch.object(
core.library_db,
"get_library_agent",
new_callable=AsyncMock,
side_effect=core.NotFoundError("Not found"),
),
):
mock_db = MagicMock()
mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=None)
mock_db.get_library_agent = AsyncMock(
side_effect=core.NotFoundError("Not found")
)
with patch.object(core, "library_db", return_value=mock_db):
result = await core.get_library_agent_by_id("user-123", "nonexistent")
assert result is None
@@ -787,27 +775,20 @@ class TestGetLibraryAgentById:
@pytest.mark.asyncio
async def test_returns_none_on_exception(self):
"""Test that None is returned when exception occurs in both lookups."""
with (
patch.object(
core.library_db,
"get_library_agent_by_graph_id",
new_callable=AsyncMock,
side_effect=Exception("Database error"),
),
patch.object(
core.library_db,
"get_library_agent",
new_callable=AsyncMock,
side_effect=Exception("Database error"),
),
):
mock_db = MagicMock()
mock_db.get_library_agent_by_graph_id = AsyncMock(
side_effect=Exception("Database error")
)
mock_db.get_library_agent = AsyncMock(side_effect=Exception("Database error"))
with patch.object(core, "library_db", return_value=mock_db):
result = await core.get_library_agent_by_id("user-123", "agent-123")
assert result is None
@pytest.mark.asyncio
async def test_alias_works(self):
"""Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id."""
"""Test that get_library_agent_by_graph_id is an alias."""
assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id
@@ -828,20 +809,11 @@ class TestGetAllRelevantAgentsWithUuids:
mock_response = MagicMock()
mock_response.agents = []
with (
patch.object(
core.library_db,
"get_library_agent_by_graph_id",
new_callable=AsyncMock,
return_value=mock_agent,
),
patch.object(
core.library_db,
"list_library_agents",
new_callable=AsyncMock,
return_value=mock_response,
),
):
mock_db = MagicMock()
mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent)
mock_db.list_library_agents = AsyncMock(return_value=mock_response)
with patch.object(core, "library_db", return_value=mock_db):
result = await core.get_all_relevant_agents_for_generation(
user_id="user-123",
search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d",

View File

@@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from backend.api.features.chat.tools.agent_generator import service
from backend.copilot.tools.agent_generator import service
class TestServiceConfiguration:

View File

@@ -1,133 +0,0 @@
"""Tests for SDK security hooks — workspace paths, tool access, and deny messages.
These are pure unit tests with no external dependencies (no SDK, no DB, no server).
They validate that the security hooks correctly block unauthorized paths,
tool access, and dangerous input patterns.
Note: Bash command validation was removed — the SDK built-in Bash tool is not in
allowed_tools, and the bash_exec MCP tool has kernel-level network isolation
(unshare --net) making command-level parsing unnecessary.
"""
from backend.api.features.chat.sdk.security_hooks import (
_validate_tool_access,
_validate_workspace_path,
)
SDK_CWD = "/tmp/copilot-test-session"
def _is_denied(result: dict) -> bool:
hook = result.get("hookSpecificOutput", {})
return hook.get("permissionDecision") == "deny"
def _reason(result: dict) -> str:
return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "")
# ============================================================
# Workspace path validation (Read, Write, Edit, etc.)
# ============================================================
class TestWorkspacePathValidation:
def test_path_in_workspace(self):
result = _validate_workspace_path(
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD
)
assert not _is_denied(result)
def test_path_outside_workspace(self):
result = _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD)
assert _is_denied(result)
def test_tool_results_allowed(self):
result = _validate_workspace_path(
"Read",
{"file_path": "~/.claude/projects/abc/tool-results/out.txt"},
SDK_CWD,
)
assert not _is_denied(result)
def test_claude_settings_blocked(self):
result = _validate_workspace_path(
"Read", {"file_path": "~/.claude/settings.json"}, SDK_CWD
)
assert _is_denied(result)
def test_claude_projects_without_tool_results(self):
result = _validate_workspace_path(
"Read", {"file_path": "~/.claude/projects/abc/credentials.json"}, SDK_CWD
)
assert _is_denied(result)
def test_no_path_allowed(self):
"""Glob/Grep without path defaults to cwd — should be allowed."""
result = _validate_workspace_path("Grep", {"pattern": "foo"}, SDK_CWD)
assert not _is_denied(result)
def test_path_traversal_with_dotdot(self):
result = _validate_workspace_path(
"Read", {"file_path": f"{SDK_CWD}/../../../etc/passwd"}, SDK_CWD
)
assert _is_denied(result)
# ============================================================
# Tool access validation
# ============================================================
class TestToolAccessValidation:
def test_blocked_tools(self):
for tool in ("bash", "shell", "exec", "terminal", "command"):
result = _validate_tool_access(tool, {})
assert _is_denied(result), f"Tool '{tool}' should be blocked"
def test_bash_builtin_blocked(self):
"""SDK built-in Bash (capital) is blocked as defence-in-depth."""
result = _validate_tool_access("Bash", {"command": "echo hello"}, SDK_CWD)
assert _is_denied(result)
assert "Bash" in _reason(result)
def test_workspace_tools_delegate(self):
result = _validate_tool_access(
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD
)
assert not _is_denied(result)
def test_dangerous_pattern_blocked(self):
result = _validate_tool_access("SomeUnknownTool", {"data": "sudo rm -rf /"})
assert _is_denied(result)
def test_safe_unknown_tool_allowed(self):
result = _validate_tool_access("SomeSafeTool", {"data": "hello world"})
assert not _is_denied(result)
# ============================================================
# Deny message quality (ntindle feedback)
# ============================================================
class TestDenyMessageClarity:
"""Deny messages must include [SECURITY] and 'cannot be bypassed'
so the model knows the restriction is enforced, not a suggestion."""
def test_blocked_tool_message(self):
reason = _reason(_validate_tool_access("bash", {}))
assert "[SECURITY]" in reason
assert "cannot be bypassed" in reason
def test_bash_builtin_blocked_message(self):
reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"}))
assert "[SECURITY]" in reason
assert "cannot be bypassed" in reason
def test_workspace_path_message(self):
reason = _reason(
_validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD)
)
assert "[SECURITY]" in reason
assert "cannot be bypassed" in reason

View File

@@ -75,7 +75,7 @@ services:
timeout: 5s
retries: 5
rabbitmq:
image: rabbitmq:management
image: rabbitmq:4.1.4
container_name: rabbitmq
healthcheck:
test: rabbitmq-diagnostics -q ping
@@ -88,14 +88,13 @@ services:
<<: *backend-env
ports:
- "5672:5672"
- "15672:15672"
rest_server:
build:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: server
command: ["python", "-m", "backend.rest"]
command: ["rest"] # points to entry in [tool.poetry.scripts] in pyproject.toml
develop:
watch:
- path: ./
@@ -128,7 +127,7 @@ services:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: server
command: ["python", "-m", "backend.exec"]
command: ["executor"] # points to entry in [tool.poetry.scripts] in pyproject.toml
develop:
watch:
- path: ./
@@ -158,12 +157,47 @@ services:
max-size: "10m"
max-file: "3"
copilot_executor:
build:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: server
command: ["python", "-m", "backend.copilot.executor"]
develop:
watch:
- path: ./
target: autogpt_platform/backend/
action: rebuild
depends_on:
redis:
condition: service_healthy
rabbitmq:
condition: service_healthy
db:
condition: service_healthy
migrate:
condition: service_completed_successfully
database_manager:
condition: service_started
<<: *backend-env-files
environment:
<<: *backend-env
ports:
- "8008:8008"
networks:
- app-network
logging:
driver: json-file
options:
max-size: "10m"
max-file: "3"
websocket_server:
build:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: server
command: ["python", "-m", "backend.ws"]
command: ["ws"] # points to entry in [tool.poetry.scripts] in pyproject.toml
develop:
watch:
- path: ./
@@ -196,7 +230,7 @@ services:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: server
command: ["python", "-m", "backend.db"]
command: ["db"] # points to entry in [tool.poetry.scripts] in pyproject.toml
develop:
watch:
- path: ./
@@ -225,7 +259,7 @@ services:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: server
command: ["python", "-m", "backend.scheduler"]
command: ["scheduler"] # points to entry in [tool.poetry.scripts] in pyproject.toml
develop:
watch:
- path: ./
@@ -273,7 +307,7 @@ services:
context: ../
dockerfile: autogpt_platform/backend/Dockerfile
target: server
command: ["python", "-m", "backend.notification"]
command: ["notification"] # points to entry in [tool.poetry.scripts] in pyproject.toml
develop:
watch:
- path: ./

View File

@@ -53,6 +53,12 @@ services:
file: ./docker-compose.platform.yml
service: executor
copilot_executor:
<<: *agpt-services
extends:
file: ./docker-compose.platform.yml
service: copilot_executor
websocket_server:
<<: *agpt-services
extends:
@@ -174,5 +180,6 @@ services:
- deps
- rest_server
- executor
- copilot_executor
- websocket_server
- database_manager

View File

@@ -1,6 +1,8 @@
"use client";
import { SidebarProvider } from "@/components/ui/sidebar";
// TODO: Replace with modern Dialog component when available
import DeleteConfirmDialog from "@/components/__legacy__/delete-confirm-dialog";
import { ChatContainer } from "./components/ChatContainer/ChatContainer";
import { ChatSidebar } from "./components/ChatSidebar/ChatSidebar";
import { MobileDrawer } from "./components/MobileDrawer/MobileDrawer";
@@ -31,6 +33,12 @@ export function CopilotPage() {
handleDrawerOpenChange,
handleSelectSession,
handleNewChat,
// Delete functionality
sessionToDelete,
isDeleting,
handleDeleteClick,
handleConfirmDelete,
handleCancelDelete,
} = useCopilotPage();
if (isUserLoading || !isLoggedIn) {
@@ -48,7 +56,19 @@ export function CopilotPage() {
>
{!isMobile && <ChatSidebar />}
<div className="relative flex h-full w-full flex-col overflow-hidden bg-[#f8f8f9] px-0">
{isMobile && <MobileHeader onOpenDrawer={handleOpenDrawer} />}
{isMobile && (
<MobileHeader
onOpenDrawer={handleOpenDrawer}
showDelete={!!sessionId}
isDeleting={isDeleting}
onDelete={() => {
const session = sessions.find((s) => s.id === sessionId);
if (session) {
handleDeleteClick(session.id, session.title);
}
}}
/>
)}
<div className="flex-1 overflow-hidden">
<ChatContainer
messages={messages}
@@ -75,6 +95,16 @@ export function CopilotPage() {
onOpenChange={handleDrawerOpenChange}
/>
)}
{/* Delete confirmation dialog - rendered at top level for proper z-index on mobile */}
{isMobile && (
<DeleteConfirmDialog
entityType="chat"
entityName={sessionToDelete?.title || "Untitled chat"}
open={!!sessionToDelete}
onOpenChange={(open) => !open && handleCancelDelete()}
onDoDelete={handleConfirmDelete}
/>
)}
</SidebarProvider>
);
}

View File

@@ -1,8 +1,15 @@
"use client";
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
import {
getGetV2ListSessionsQueryKey,
useDeleteV2DeleteSession,
useGetV2ListSessions,
} from "@/app/api/__generated__/endpoints/chat/chat";
import { Button } from "@/components/atoms/Button/Button";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { Text } from "@/components/atoms/Text/Text";
import { toast } from "@/components/molecules/Toast/use-toast";
// TODO: Replace with modern Dialog component when available
import DeleteConfirmDialog from "@/components/__legacy__/delete-confirm-dialog";
import {
Sidebar,
SidebarContent,
@@ -12,18 +19,52 @@ import {
useSidebar,
} from "@/components/ui/sidebar";
import { cn } from "@/lib/utils";
import { PlusCircleIcon, PlusIcon } from "@phosphor-icons/react";
import { PlusCircleIcon, PlusIcon, TrashIcon } from "@phosphor-icons/react";
import { useQueryClient } from "@tanstack/react-query";
import { motion } from "framer-motion";
import { useState } from "react";
import { parseAsString, useQueryState } from "nuqs";
export function ChatSidebar() {
const { state } = useSidebar();
const isCollapsed = state === "collapsed";
const [sessionId, setSessionId] = useQueryState("sessionId", parseAsString);
const [sessionToDelete, setSessionToDelete] = useState<{
id: string;
title: string | null | undefined;
} | null>(null);
const queryClient = useQueryClient();
const { data: sessionsResponse, isLoading: isLoadingSessions } =
useGetV2ListSessions({ limit: 50 });
const { mutate: deleteSession, isPending: isDeleting } =
useDeleteV2DeleteSession({
mutation: {
onSuccess: () => {
// Invalidate sessions list to refetch
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
// If we deleted the current session, clear selection
if (sessionToDelete?.id === sessionId) {
setSessionId(null);
}
setSessionToDelete(null);
},
onError: (error) => {
toast({
title: "Failed to delete chat",
description:
error instanceof Error ? error.message : "An error occurred",
variant: "destructive",
});
setSessionToDelete(null);
},
},
});
const sessions =
sessionsResponse?.status === 200 ? sessionsResponse.data.sessions : [];
@@ -35,6 +76,22 @@ export function ChatSidebar() {
setSessionId(id);
}
function handleDeleteClick(
e: React.MouseEvent,
id: string,
title: string | null | undefined,
) {
e.stopPropagation(); // Prevent session selection
if (isDeleting) return; // Prevent double-click during deletion
setSessionToDelete({ id, title });
}
function handleConfirmDelete() {
if (sessionToDelete) {
deleteSession({ sessionId: sessionToDelete.id });
}
}
function formatDate(dateString: string) {
const date = new Date(dateString);
const now = new Date();
@@ -61,128 +118,152 @@ export function ChatSidebar() {
}
return (
<Sidebar
variant="inset"
collapsible="icon"
className="!top-[50px] !h-[calc(100vh-50px)] border-r border-zinc-100 px-0"
>
{isCollapsed && (
<SidebarHeader
className={cn(
"flex",
isCollapsed
? "flex-row items-center justify-between gap-y-4 md:flex-col md:items-start md:justify-start"
: "flex-row items-center justify-between",
)}
>
<motion.div
key={isCollapsed ? "header-collapsed" : "header-expanded"}
className="flex flex-col items-center gap-3 pt-4"
initial={{ opacity: 0, filter: "blur(3px)" }}
animate={{ opacity: 1, filter: "blur(0px)" }}
transition={{ type: "spring", bounce: 0.2 }}
>
<div className="flex flex-col items-center gap-2">
<SidebarTrigger />
<Button
variant="ghost"
onClick={handleNewChat}
style={{ minWidth: "auto", width: "auto" }}
>
<PlusCircleIcon className="!size-5" />
<span className="sr-only">New Chat</span>
</Button>
</div>
</motion.div>
</SidebarHeader>
)}
<SidebarContent className="gap-4 overflow-y-auto px-4 py-4 [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{!isCollapsed && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.1 }}
className="flex items-center justify-between px-3"
>
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="relative left-6">
<SidebarTrigger />
</div>
</motion.div>
)}
{!isCollapsed && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.15 }}
className="mt-4 flex flex-col gap-1"
>
{isLoadingSessions ? (
<div className="flex min-h-[30rem] items-center justify-center py-4">
<LoadingSpinner size="small" className="text-neutral-600" />
</div>
) : sessions.length === 0 ? (
<p className="py-4 text-center text-sm text-neutral-500">
No conversations yet
</p>
) : (
sessions.map((session) => (
<button
key={session.id}
onClick={() => handleSelectSession(session.id)}
className={cn(
"w-full rounded-lg px-3 py-2.5 text-left transition-colors",
session.id === sessionId
? "bg-zinc-100"
: "hover:bg-zinc-50",
)}
>
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="min-w-0 max-w-full">
<Text
variant="body"
className={cn(
"truncate font-normal",
session.id === sessionId
? "text-zinc-600"
: "text-zinc-800",
)}
>
{session.title || `Untitled chat`}
</Text>
</div>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
</button>
))
<>
<Sidebar
variant="inset"
collapsible="icon"
className="!top-[50px] !h-[calc(100vh-50px)] border-r border-zinc-100 px-0"
>
{isCollapsed && (
<SidebarHeader
className={cn(
"flex",
isCollapsed
? "flex-row items-center justify-between gap-y-4 md:flex-col md:items-start md:justify-start"
: "flex-row items-center justify-between",
)}
</motion.div>
)}
</SidebarContent>
{!isCollapsed && sessionId && (
<SidebarFooter className="shrink-0 bg-zinc-50 p-3 pb-1 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.2 }}
>
<Button
variant="primary"
size="small"
onClick={handleNewChat}
className="w-full"
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
<motion.div
key={isCollapsed ? "header-collapsed" : "header-expanded"}
className="flex flex-col items-center gap-3 pt-4"
initial={{ opacity: 0, filter: "blur(3px)" }}
animate={{ opacity: 1, filter: "blur(0px)" }}
transition={{ type: "spring", bounce: 0.2 }}
>
New Chat
</Button>
</motion.div>
</SidebarFooter>
)}
</Sidebar>
<div className="flex flex-col items-center gap-2">
<SidebarTrigger />
<Button
variant="ghost"
onClick={handleNewChat}
style={{ minWidth: "auto", width: "auto" }}
>
<PlusCircleIcon className="!size-5" />
<span className="sr-only">New Chat</span>
</Button>
</div>
</motion.div>
</SidebarHeader>
)}
<SidebarContent className="gap-4 overflow-y-auto px-4 py-4 [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
{!isCollapsed && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.1 }}
className="flex items-center justify-between px-3"
>
<Text variant="h3" size="body-medium">
Your chats
</Text>
<div className="relative left-6">
<SidebarTrigger />
</div>
</motion.div>
)}
{!isCollapsed && (
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.15 }}
className="mt-4 flex flex-col gap-1"
>
{isLoadingSessions ? (
<div className="flex min-h-[30rem] items-center justify-center py-4">
<LoadingSpinner size="small" className="text-neutral-600" />
</div>
) : sessions.length === 0 ? (
<p className="py-4 text-center text-sm text-neutral-500">
No conversations yet
</p>
) : (
sessions.map((session) => (
<div
key={session.id}
className={cn(
"group relative w-full rounded-lg transition-colors",
session.id === sessionId
? "bg-zinc-100"
: "hover:bg-zinc-50",
)}
>
<button
onClick={() => handleSelectSession(session.id)}
className="w-full px-3 py-2.5 pr-10 text-left"
>
<div className="flex min-w-0 max-w-full flex-col overflow-hidden">
<div className="min-w-0 max-w-full">
<Text
variant="body"
className={cn(
"truncate font-normal",
session.id === sessionId
? "text-zinc-600"
: "text-zinc-800",
)}
>
{session.title || `Untitled chat`}
</Text>
</div>
<Text variant="small" className="text-neutral-400">
{formatDate(session.updated_at)}
</Text>
</div>
</button>
<button
onClick={(e) =>
handleDeleteClick(e, session.id, session.title)
}
disabled={isDeleting}
className="absolute right-2 top-1/2 -translate-y-1/2 rounded p-1.5 text-zinc-400 opacity-0 transition-all group-hover:opacity-100 hover:bg-red-100 hover:text-red-600 focus-visible:opacity-100 disabled:cursor-not-allowed disabled:opacity-50"
aria-label="Delete chat"
>
<TrashIcon className="h-4 w-4" />
</button>
</div>
))
)}
</motion.div>
)}
</SidebarContent>
{!isCollapsed && sessionId && (
<SidebarFooter className="shrink-0 bg-zinc-50 p-3 pb-1 shadow-[0_-4px_6px_-1px_rgba(0,0,0,0.05)]">
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ duration: 0.2, delay: 0.2 }}
>
<Button
variant="primary"
size="small"
onClick={handleNewChat}
className="w-full"
leftIcon={<PlusIcon className="h-4 w-4" weight="bold" />}
>
New Chat
</Button>
</motion.div>
</SidebarFooter>
)}
</Sidebar>
<DeleteConfirmDialog
entityType="chat"
entityName={sessionToDelete?.title || "Untitled chat"}
open={!!sessionToDelete}
onOpenChange={(open) => !open && setSessionToDelete(null)}
onDoDelete={handleConfirmDelete}
/>
</>
);
}

View File

@@ -1,22 +1,46 @@
import { Button } from "@/components/atoms/Button/Button";
import { NAVBAR_HEIGHT_PX } from "@/lib/constants";
import { ListIcon } from "@phosphor-icons/react";
import { ListIcon, TrashIcon } from "@phosphor-icons/react";
interface Props {
onOpenDrawer: () => void;
showDelete?: boolean;
isDeleting?: boolean;
onDelete?: () => void;
}
export function MobileHeader({ onOpenDrawer }: Props) {
export function MobileHeader({
onOpenDrawer,
showDelete,
isDeleting,
onDelete,
}: Props) {
return (
<Button
variant="icon"
size="icon"
aria-label="Open sessions"
onClick={onOpenDrawer}
className="fixed z-50 bg-white shadow-md"
<div
className="fixed z-50 flex gap-2"
style={{ left: "1rem", top: `${NAVBAR_HEIGHT_PX + 20}px` }}
>
<ListIcon width="1.25rem" height="1.25rem" />
</Button>
<Button
variant="icon"
size="icon"
aria-label="Open sessions"
onClick={onOpenDrawer}
className="bg-white shadow-md"
>
<ListIcon width="1.25rem" height="1.25rem" />
</Button>
{showDelete && onDelete && (
<Button
variant="icon"
size="icon"
aria-label="Delete current chat"
onClick={onDelete}
disabled={isDeleting}
className="bg-white text-red-500 shadow-md hover:bg-red-50 hover:text-red-600 disabled:opacity-50"
>
<TrashIcon width="1.25rem" height="1.25rem" />
</Button>
)}
</div>
);
}

View File

@@ -1,10 +1,15 @@
import { useGetV2ListSessions } from "@/app/api/__generated__/endpoints/chat/chat";
import {
getGetV2ListSessionsQueryKey,
useDeleteV2DeleteSession,
useGetV2ListSessions,
} from "@/app/api/__generated__/endpoints/chat/chat";
import { toast } from "@/components/molecules/Toast/use-toast";
import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useChat } from "@ai-sdk/react";
import { useQueryClient } from "@tanstack/react-query";
import { DefaultChatTransport } from "ai";
import { useEffect, useMemo, useRef, useState } from "react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useChatSession } from "./useChatSession";
import { useLongRunningToolPolling } from "./hooks/useLongRunningToolPolling";
@@ -14,6 +19,11 @@ export function useCopilotPage() {
const { isUserLoading, isLoggedIn } = useSupabase();
const [isDrawerOpen, setIsDrawerOpen] = useState(false);
const [pendingMessage, setPendingMessage] = useState<string | null>(null);
const [sessionToDelete, setSessionToDelete] = useState<{
id: string;
title: string | null | undefined;
} | null>(null);
const queryClient = useQueryClient();
const {
sessionId,
@@ -24,6 +34,30 @@ export function useCopilotPage() {
isCreatingSession,
} = useChatSession();
const { mutate: deleteSessionMutation, isPending: isDeleting } =
useDeleteV2DeleteSession({
mutation: {
onSuccess: () => {
queryClient.invalidateQueries({
queryKey: getGetV2ListSessionsQueryKey(),
});
if (sessionToDelete?.id === sessionId) {
setSessionId(null);
}
setSessionToDelete(null);
},
onError: (error) => {
toast({
title: "Failed to delete chat",
description:
error instanceof Error ? error.message : "An error occurred",
variant: "destructive",
});
setSessionToDelete(null);
},
},
});
const breakpoint = useBreakpoint();
const isMobile =
breakpoint === "base" || breakpoint === "sm" || breakpoint === "md";
@@ -143,6 +177,24 @@ export function useCopilotPage() {
if (isMobile) setIsDrawerOpen(false);
}
const handleDeleteClick = useCallback(
(id: string, title: string | null | undefined) => {
if (isDeleting) return;
setSessionToDelete({ id, title });
},
[isDeleting],
);
const handleConfirmDelete = useCallback(() => {
if (sessionToDelete) {
deleteSessionMutation({ sessionId: sessionToDelete.id });
}
}, [sessionToDelete, deleteSessionMutation]);
const handleCancelDelete = useCallback(() => {
setSessionToDelete(null);
}, []);
return {
sessionId,
messages,
@@ -165,5 +217,11 @@ export function useCopilotPage() {
handleDrawerOpenChange,
handleSelectSession,
handleNewChat,
// Delete functionality
sessionToDelete,
isDeleting,
handleDeleteClick,
handleConfirmDelete,
handleCancelDelete,
};
}

View File

@@ -1151,6 +1151,36 @@
}
},
"/api/chat/sessions/{session_id}": {
"delete": {
"tags": ["v2", "chat", "chat"],
"summary": "Delete Session",
"description": "Delete a chat session.\n\nPermanently removes a chat session and all its messages.\nOnly the owner can delete their sessions.\n\nArgs:\n session_id: The session ID to delete.\n user_id: The authenticated user's ID.\n\nReturns:\n 204 No Content on success.\n\nRaises:\n HTTPException: 404 if session not found or not owned by user.",
"operationId": "deleteV2DeleteSession",
"security": [{ "HTTPBearerJWT": [] }],
"parameters": [
{
"name": "session_id",
"in": "path",
"required": true,
"schema": { "type": "string", "title": "Session Id" }
}
],
"responses": {
"204": { "description": "Successful Response" },
"401": {
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
},
"404": { "description": "Session not found or access denied" },
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
}
}
}
}
},
"get": {
"tags": ["v2", "chat", "chat"],
"summary": "Get Session",

View File

@@ -115,7 +115,7 @@ const DialogFooter = ({
}: React.HTMLAttributes<HTMLDivElement>) => (
<div
className={cn(
"flex flex-col-reverse sm:flex-row sm:justify-end sm:space-x-2",
"flex flex-col-reverse gap-2 sm:flex-row sm:justify-end",
className,
)}
{...props}

Some files were not shown because too many files have changed in this diff Show More