Compare commits

..

9 Commits

Author SHA1 Message Date
Nick Tindle
c42aab3482 Merge branch 'dev' into ntindle/google-issues-fix
Resolved conflicts:
- executor/utils.py: kept auto-credentials validation + dev's comment update
- graph_test.py: kept both auto-credentials tests AND MCP deduplication tests
2026-02-16 00:09:35 -06:00
Nick Tindle
e7705427bb fix: add is_auto_credential and input_field_name to auto credentials schema
Post-refactor fix: these fields were moved from data/block.py to blocks/_base.py in #12068
2026-02-12 16:13:23 -06:00
Nicholas Tindle
201ec5aa3a Merge branch 'dev' into ntindle/google-issues-fix 2026-02-12 16:12:20 -06:00
Nicholas Tindle
2b8134a711 Merge branch 'dev' into ntindle/google-issues-fix 2026-02-09 13:46:37 -06:00
Nicholas Tindle
90b3b5ba16 fix(backend): Fix misplaced section header in graph_test.py
Move the _reassign_ids section comment to above the actual _reassign_ids
tests, and label the combine() tests correctly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 16:11:47 -06:00
Nicholas Tindle
f4f81bc4fc fix(backend): Remove _credentials_id key on fork instead of setting to None
Setting _credentials_id to None on fork was ambiguous — both "forked,
needs re-auth" and "chained data from upstream" were represented as None.
This caused _acquire_auto_credentials to silently skip credential
acquisition for forked agents, leading to confusing TypeErrors at runtime.

Now the key is deleted entirely, making the three states unambiguous:
- Present with value: user-selected credentials
- Present as None: chained data from upstream block
- Absent: forked/needs re-authentication

Also adds pre-run validation for the missing key case and makes error
messages provider-agnostic.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 17:34:16 -06:00
Nicholas Tindle
c5abc01f25 fix(backend): Add error handling for auto-credentials store lookup
Wrap get_creds_by_id call in try/except in the auto-credentials
validation path to match the error handling pattern used for regular
credentials.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 16:53:29 -06:00
Nicholas Tindle
8b7053c1de merge: Resolve conflicts with dev (PR #11986 graph model refactor)
Adapt auto-credentials filtering to dev's refactored graph model:
- aggregate_credentials_inputs() now returns 3-tuples (field_info, node_pairs, is_required)
- credentials_input_schema moved to GraphModel, builds JSON schema directly
- Update regular/auto_credentials_inputs properties for 3-tuple format
- Update test mocks and assertions for new tuple format and class hierarchy

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 16:39:57 -06:00
Nicholas Tindle
e00c1202ad fix(platform): Fix Google Drive auto-credentials handling across the platform
- Tag auto-credentials with `is_auto_credential` and `input_field_name` on `CredentialsFieldInfo` to distinguish them from regular user-provided credentials
- Add `regular_credentials_inputs` and `auto_credentials_inputs` properties to `Graph` so UI schemas, CoPilot, and library presets only surface regular credentials
- Extract `_acquire_auto_credentials()` helper in executor to resolve embedded `_credentials_id` at execution time with proper lock management
- Validate auto-credentials ownership in `_validate_node_input_credentials()` to catch stale/missing credentials before execution
- Clear `_credentials_id` in `_reassign_ids()` on graph fork so cloned agents require re-authentication
- Propagate `is_auto_credential` through `combine()` and `discriminate()` on `CredentialsFieldInfo`
- Add `referrerPolicy: "no-referrer-when-downgrade"` to Google API script loading to fix Firefox API key validation
- Comprehensive test coverage for all new behavior

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-06 16:08:53 -06:00
342 changed files with 24831 additions and 22429 deletions

View File

@@ -41,18 +41,13 @@ jobs:
ports:
- 6379:6379
rabbitmq:
image: rabbitmq:4.1.4
image: rabbitmq:3.12-management
ports:
- 5672:5672
- 15672:15672
env:
RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
options: >-
--health-cmd "rabbitmq-diagnostics -q ping"
--health-interval 30s
--health-timeout 10s
--health-retries 5
--health-start-period 10s
clamav:
image: clamav/clamav-debian:latest
ports:

View File

@@ -6,16 +6,10 @@ on:
paths:
- ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**"
- "autogpt_platform/backend/Dockerfile"
- "autogpt_platform/docker-compose.yml"
- "autogpt_platform/docker-compose.platform.yml"
pull_request:
paths:
- ".github/workflows/platform-frontend-ci.yml"
- "autogpt_platform/frontend/**"
- "autogpt_platform/backend/Dockerfile"
- "autogpt_platform/docker-compose.yml"
- "autogpt_platform/docker-compose.platform.yml"
merge_group:
workflow_dispatch:

4
.gitignore vendored
View File

@@ -180,6 +180,4 @@ autogpt_platform/backend/settings.py
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
.next
# Implementation plans (generated by AI agents)
plans/
.next

View File

@@ -1,10 +1,3 @@
default_install_hook_types:
- pre-commit
- pre-push
- post-checkout
default_stages: [pre-commit]
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
@@ -24,7 +17,6 @@ repos:
name: Detect secrets
description: Detects high entropy strings that are likely to be passwords.
files: ^autogpt_platform/
exclude: pnpm-lock\.yaml$
stages: [pre-push]
- repo: local
@@ -34,106 +26,49 @@ repos:
- id: poetry-install
name: Check & Install dependencies - AutoGPT Platform - Backend
alias: poetry-install-platform-backend
entry: poetry -C autogpt_platform/backend install
# include autogpt_libs source (since it's a path dependency)
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$" || exit 0;
poetry -C autogpt_platform/backend install
'
always_run: true
files: ^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - AutoGPT Platform - Libs
alias: poetry-install-platform-libs
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/autogpt_libs/poetry\.lock$" || exit 0;
poetry -C autogpt_platform/autogpt_libs install
'
always_run: true
entry: poetry -C autogpt_platform/autogpt_libs install
files: ^autogpt_platform/autogpt_libs/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: pnpm-install
name: Check & Install dependencies - AutoGPT Platform - Frontend
alias: pnpm-install-platform-frontend
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/frontend/pnpm-lock\.yaml$" || exit 0;
pnpm --prefix autogpt_platform/frontend install
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - AutoGPT
alias: poetry-install-classic-autogpt
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
poetry -C classic/original_autogpt install
'
entry: poetry -C classic/original_autogpt install
# include forge source (since it's a path dependency)
always_run: true
files: ^classic/(original_autogpt|forge)/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Forge
alias: poetry-install-classic-forge
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
poetry -C classic/forge install
'
always_run: true
entry: poetry -C classic/forge install
files: ^classic/forge/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Benchmark
alias: poetry-install-classic-benchmark
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
poetry -C classic/benchmark install
'
always_run: true
entry: poetry -C classic/benchmark install
files: ^classic/benchmark/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- repo: local
# For proper type checking, Prisma client must be up-to-date.
@@ -141,54 +76,12 @@ repos:
- id: prisma-generate
name: Prisma Generate - AutoGPT Platform - Backend
alias: prisma-generate-platform-backend
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema\.prisma)$" || exit 0;
cd autogpt_platform/backend
&& poetry run prisma generate
&& poetry run gen-prisma-stub
'
entry: bash -c 'cd autogpt_platform/backend && poetry run prisma generate'
# include everything that triggers poetry install + the prisma schema
always_run: true
files: ^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema.prisma)$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: export-api-schema
name: Export API schema - AutoGPT Platform - Backend -> Frontend
alias: export-api-schema-platform
entry: >
bash -c '
cd autogpt_platform/backend
&& poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
&& cd ../frontend
&& pnpm prettier --write ./src/app/api/openapi.json
'
files: ^autogpt_platform/backend/
language: system
pass_filenames: false
- id: generate-api-client
name: Generate API client - AutoGPT Platform - Frontend
alias: generate-api-client-platform-frontend
entry: >
bash -c '
SCHEMA=autogpt_platform/frontend/src/app/api/openapi.json;
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --quiet "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF" -- "$SCHEMA" && exit 0
else
git diff --quiet HEAD -- "$SCHEMA" && exit 0
fi;
cd autogpt_platform/frontend && pnpm generate:api
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.2

View File

@@ -190,8 +190,5 @@ ZEROBOUNCE_API_KEY=
POSTHOG_API_KEY=
POSTHOG_HOST=https://eu.i.posthog.com
# Tally Form Integration (pre-populate business understanding on signup)
TALLY_API_KEY=
# Other Services
AUTOMOD_API_KEY=

View File

@@ -53,6 +53,63 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
RUN poetry run prisma generate && poetry run gen-prisma-stub
# ============================== BACKEND SERVER ============================== #
FROM debian:13-slim AS server
WORKDIR /app
ENV POETRY_HOME=/opt/poetry \
POETRY_NO_INTERACTION=1 \
POETRY_VIRTUALENVS_CREATE=true \
POETRY_VIRTUALENVS_IN_PROJECT=true \
DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
# for the bash_exec MCP tool.
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \
python3-pip \
ffmpeg \
imagemagick \
jq \
ripgrep \
tree \
bubblewrap \
&& rm -rf /var/lib/apt/lists/*
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
WORKDIR /app/autogpt_platform/backend
# Copy only the .venv from builder (not the entire /app directory)
# The .venv includes the generated Prisma client
COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
# Copy dependency files + autogpt_libs (path dependency)
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
# Copy backend code + docs (for Copilot docs search)
COPY autogpt_platform/backend ./
COPY docs /app/docs
RUN poetry install --no-ansi --only-root
ENV PORT=8000
CMD ["poetry", "run", "rest"]
# =============================== DB MIGRATOR =============================== #
# Lightweight migrate stage - only needs Prisma CLI, not full Python environment
@@ -84,59 +141,3 @@ COPY autogpt_platform/backend/schema.prisma ./
COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
COPY autogpt_platform/backend/migrations ./migrations
# ============================== BACKEND SERVER ============================== #
FROM debian:13-slim AS server
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
# for the bash_exec MCP tool.
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \
python3-pip \
ffmpeg \
imagemagick \
jq \
ripgrep \
tree \
bubblewrap \
&& rm -rf /var/lib/apt/lists/*
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
WORKDIR /app/autogpt_platform/backend
# Copy only the .venv from builder (not the entire /app directory)
# The .venv includes the generated Prisma client
COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
# Copy dependency files + autogpt_libs (path dependency)
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
# Copy backend code + docs (for Copilot docs search)
COPY autogpt_platform/backend ./
COPY docs /app/docs
# Install the project package to create entry point scripts in .venv/bin/
# (e.g., rest, executor, ws, db, scheduler, notification - see [tool.poetry.scripts])
RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
poetry install --no-ansi --only-root
ENV PORT=8000
CMD ["rest"]

View File

@@ -1,9 +1,4 @@
"""Common test fixtures for server tests.
Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
setup_test_user, and setup_admin_user are defined in the parent conftest.py
(backend/conftest.py) and are available here automatically.
"""
"""Common test fixtures for server tests."""
import pytest
from pytest_snapshot.plugin import Snapshot
@@ -16,6 +11,54 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
return snapshot
@pytest.fixture
def test_user_id() -> str:
"""Test user ID fixture."""
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture
def admin_user_id() -> str:
"""Admin user ID fixture."""
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
@pytest.fixture
def target_user_id() -> str:
"""Target user ID fixture."""
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
@pytest.fixture
async def setup_test_user(test_user_id):
"""Create test user in database before tests."""
from backend.data.user import get_or_create_user
# Create the test user in the database using JWT token format
user_data = {
"sub": test_user_id,
"email": "test@example.com",
"user_metadata": {"name": "Test User"},
}
await get_or_create_user(user_data)
return test_user_id
@pytest.fixture
async def setup_admin_user(admin_user_id):
"""Create admin user in database before tests."""
from backend.data.user import get_or_create_user
# Create the admin user in the database using JWT token format
user_data = {
"sub": admin_user_id,
"email": "test-admin@example.com",
"user_metadata": {"name": "Test Admin"},
}
await get_or_create_user(user_data)
return admin_user_id
@pytest.fixture
def mock_jwt_user(test_user_id):
"""Provide mock JWT payload for regular user testing."""

View File

@@ -88,23 +88,20 @@ async def require_auth(
)
def require_permission(*permissions: APIKeyPermission):
def require_permission(permission: APIKeyPermission):
"""
Dependency function for checking required permissions.
All listed permissions must be present.
Dependency function for checking specific permissions
(works with API keys and OAuth tokens)
"""
async def check_permissions(
async def check_permission(
auth: APIAuthorizationInfo = Security(require_auth),
) -> APIAuthorizationInfo:
missing = [p for p in permissions if p not in auth.scopes]
if missing:
if permission not in auth.scopes:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permission(s): "
f"{', '.join(p.value for p in missing)}",
detail=f"Missing required permission: {permission.value}",
)
return auth
return check_permissions
return check_permission

View File

@@ -18,7 +18,6 @@ from backend.data import user as user_db
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.executor.utils import add_graph_execution
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.settings import Settings
from .integrations import integrations_router
@@ -96,43 +95,6 @@ async def execute_graph_block(
return output
@v1_router.post(
path="/graphs",
tags=["graphs"],
status_code=201,
dependencies=[
Security(
require_permission(
APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY
)
)
],
)
async def create_graph(
graph: graph_db.Graph,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY)
),
) -> graph_db.GraphModel:
"""
Create a new agent graph.
The graph will be validated and assigned a new ID.
It is automatically added to the user's library.
"""
from backend.api.features.library import db as library_db
graph_model = graph_db.make_graph_model(graph, auth.user_id)
graph_model.reassign_ids(user_id=auth.user_id, reassign_graph_id=True)
graph_model.validate_graph(for_run=False)
await graph_db.create_graph(graph_model, user_id=auth.user_id)
await library_db.create_library_agent(graph_model, auth.user_id)
activated_graph = await on_graph_activate(graph_model, user_id=auth.user_id)
return activated_graph
@v1_router.post(
path="/graphs/{graph_id}/execute/{graph_version}",
tags=["graphs"],

View File

@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
from pydantic import BaseModel, Field
from backend.api.external.middleware import require_permission
from backend.copilot.model import ChatSession
from backend.copilot.tools import find_agent_tool, run_agent_tool
from backend.copilot.tools.models import ToolResponseBase
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
from backend.api.features.chat.tools.models import ToolResponseBase
from backend.data.auth.base import APIAuthorizationInfo
logger = logging.getLogger(__name__)

View File

@@ -1,17 +1,15 @@
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from difflib import SequenceMatcher
from typing import Any, Sequence, get_args, get_origin
from typing import Sequence
import prisma
from prisma.enums import ContentType
from prisma.models import mv_suggested_blocks
import backend.api.features.library.db as library_db
import backend.api.features.library.model as library_model
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.blocks import load_all_blocks
from backend.blocks._base import (
AnyBlockSchema,
@@ -21,6 +19,7 @@ from backend.blocks._base import (
BlockType,
)
from backend.blocks.llm import LlmModel
from backend.data.db import query_raw_with_schema
from backend.integrations.providers import ProviderName
from backend.util.cache import cached
from backend.util.models import Pagination
@@ -43,16 +42,6 @@ MAX_LIBRARY_AGENT_RESULTS = 100
MAX_MARKETPLACE_AGENT_RESULTS = 100
MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
# Boost blocks over marketplace agents in search results
BLOCK_SCORE_BOOST = 50.0
# Block IDs to exclude from search results
EXCLUDED_BLOCK_IDS = frozenset(
{
"e189baac-8c20-45a1-94a7-55177ea42565", # AgentExecutorBlock
}
)
SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
@@ -75,8 +64,8 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
# Skip disabled and excluded blocks
if block.disabled or block.id in EXCLUDED_BLOCK_IDS:
# Skip disabled blocks
if block.disabled:
continue
# Skip blocks that don't have categories (all should have at least one)
if not block.categories:
@@ -127,9 +116,6 @@ def get_blocks(
# Skip disabled blocks
if block.disabled:
continue
# Skip excluded blocks
if block.id in EXCLUDED_BLOCK_IDS:
continue
# Skip blocks that don't match the category
if category and category not in {c.name.lower() for c in block.categories}:
continue
@@ -269,25 +255,14 @@ async def _build_cached_search_results(
"my_agents": 0,
}
# Use hybrid search when query is present, otherwise list all blocks
if (include_blocks or include_integrations) and normalized_query:
block_results, block_total, integration_total = await _hybrid_search_blocks(
query=search_query,
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
elif include_blocks or include_integrations:
# No query - list all blocks using in-memory approach
block_results, block_total, integration_total = _collect_block_results(
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
block_results, block_total, integration_total = _collect_block_results(
normalized_query=normalized_query,
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
if include_library_agents:
library_response = await library_db.list_library_agents(
@@ -332,14 +307,10 @@ async def _build_cached_search_results(
def _collect_block_results(
*,
normalized_query: str,
include_blocks: bool,
include_integrations: bool,
) -> tuple[list[_ScoredItem], int, int]:
"""
Collect all blocks for listing (no search query).
All blocks get BLOCK_SCORE_BOOST to prioritize them over marketplace agents.
"""
results: list[_ScoredItem] = []
block_count = 0
integration_count = 0
@@ -352,10 +323,6 @@ def _collect_block_results(
if block.disabled:
continue
# Skip excluded blocks
if block.id in EXCLUDED_BLOCK_IDS:
continue
block_info = block.get_info()
credentials = list(block.input_schema.get_credentials_fields().values())
is_integration = len(credentials) > 0
@@ -365,6 +332,10 @@ def _collect_block_results(
if not is_integration and not include_blocks:
continue
score = _score_block(block, block_info, normalized_query)
if not _should_include_item(score, normalized_query):
continue
filter_type: FilterType = "integrations" if is_integration else "blocks"
if is_integration:
integration_count += 1
@@ -375,122 +346,8 @@ def _collect_block_results(
_ScoredItem(
item=block_info,
filter_type=filter_type,
score=BLOCK_SCORE_BOOST,
sort_key=block_info.name.lower(),
)
)
return results, block_count, integration_count
async def _hybrid_search_blocks(
*,
query: str,
include_blocks: bool,
include_integrations: bool,
) -> tuple[list[_ScoredItem], int, int]:
"""
Search blocks using hybrid search with builder-specific filtering.
Uses unified_hybrid_search for semantic + lexical search, then applies
post-filtering for block/integration types and scoring adjustments.
Scoring:
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
to prioritize blocks over marketplace agents in combined results
- +30 for exact name match, +15 for prefix name match
- +20 if the block has an LlmModel field and the query matches an LLM model name
Args:
query: The search query string
include_blocks: Whether to include regular blocks
include_integrations: Whether to include integration blocks
Returns:
Tuple of (scored_items, block_count, integration_count)
"""
results: list[_ScoredItem] = []
block_count = 0
integration_count = 0
if not include_blocks and not include_integrations:
return results, block_count, integration_count
normalized_query = query.strip().lower()
# Fetch more results to account for post-filtering
search_results, _ = await unified_hybrid_search(
query=query,
content_types=[ContentType.BLOCK],
page=1,
page_size=150,
min_score=0.10,
)
# Load all blocks for getting BlockInfo
all_blocks = load_all_blocks()
for result in search_results:
block_id = result["content_id"]
# Skip excluded blocks
if block_id in EXCLUDED_BLOCK_IDS:
continue
metadata = result.get("metadata", {})
hybrid_score = result.get("relevance", 0.0)
# Get the actual block class
if block_id not in all_blocks:
continue
block_cls = all_blocks[block_id]
block: AnyBlockSchema = block_cls()
if block.disabled:
continue
# Check block/integration filter using metadata
is_integration = metadata.get("is_integration", False)
if is_integration and not include_integrations:
continue
if not is_integration and not include_blocks:
continue
# Get block info
block_info = block.get_info()
# Calculate final score: scale hybrid score and add builder-specific bonuses
# Hybrid scores are 0-1, builder scores were 0-200+
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
# Add LLM model match bonus
has_llm_field = metadata.get("has_llm_model_field", False)
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
final_score += 20
# Add exact/prefix match bonus for deterministic tie-breaking
name = block_info.name.lower()
if name == normalized_query:
final_score += 30
elif name.startswith(normalized_query):
final_score += 15
# Track counts
filter_type: FilterType = "integrations" if is_integration else "blocks"
if is_integration:
integration_count += 1
else:
block_count += 1
results.append(
_ScoredItem(
item=block_info,
filter_type=filter_type,
score=final_score,
sort_key=name,
score=score,
sort_key=_get_item_name(block_info),
)
)
@@ -615,8 +472,6 @@ async def _get_static_counts():
block: AnyBlockSchema = block_type()
if block.disabled:
continue
if block.id in EXCLUDED_BLOCK_IDS:
continue
all_blocks += 1
@@ -643,25 +498,47 @@ async def _get_static_counts():
}
def _contains_type(annotation: Any, target: type) -> bool:
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
if annotation is target:
return True
origin = get_origin(annotation)
if origin is None:
return False
return any(_contains_type(arg, target) for arg in get_args(annotation))
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
for field in schema_cls.model_fields.values():
if _contains_type(field.annotation, LlmModel):
if field.annotation == LlmModel:
# Check if query matches any value in llm_models
if any(query in name for name in llm_models):
return True
return False
def _score_block(
block: AnyBlockSchema,
block_info: BlockInfo,
normalized_query: str,
) -> float:
if not normalized_query:
return 0.0
name = block_info.name.lower()
description = block_info.description.lower()
score = _score_primary_fields(name, description, normalized_query)
category_text = " ".join(
category.get("category", "").lower() for category in block_info.categories
)
score += _score_additional_field(category_text, normalized_query, 12, 6)
credentials_info = block.input_schema.get_credentials_fields_info().values()
provider_names = [
provider.value.lower()
for info in credentials_info
for provider in info.provider
]
provider_text = " ".join(provider_names)
score += _score_additional_field(provider_text, normalized_query, 15, 6)
if _matches_llm_model(block.input_schema, normalized_query):
score += 20
return score
def _score_library_agent(
agent: library_model.LibraryAgent,
normalized_query: str,
@@ -768,20 +645,31 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
return providers
@cached(ttl_seconds=3600, shared_cache=True)
@cached(ttl_seconds=3600)
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
"""Return the most-executed blocks from the last 14 days.
suggested_blocks = []
# Sum the number of executions for each block type
# Prisma cannot group by nested relations, so we do a raw query
# Calculate the cutoff timestamp
timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
Queries the mv_suggested_blocks materialized view (refreshed hourly via pg_cron)
and returns the top `count` blocks sorted by execution count, excluding
Input/Output/Agent block types and blocks in EXCLUDED_BLOCK_IDS.
"""
results = await mv_suggested_blocks.prisma().find_many()
results = await query_raw_with_schema(
"""
SELECT
agent_node."agentBlockId" AS block_id,
COUNT(execution.id) AS execution_count
FROM {schema_prefix}"AgentNodeExecution" execution
JOIN {schema_prefix}"AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
WHERE execution."endedTime" >= $1::timestamp
GROUP BY agent_node."agentBlockId"
ORDER BY execution_count DESC;
""",
timestamp_threshold,
)
# Get the top blocks based on execution count
# But ignore Input, Output, Agent, and excluded blocks
# But ignore Input and Output blocks
blocks: list[tuple[BlockInfo, int]] = []
execution_counts = {row.block_id: row.execution_count for row in results}
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
@@ -791,9 +679,11 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
BlockType.AGENT,
):
continue
if block.id in EXCLUDED_BLOCK_IDS:
continue
execution_count = execution_counts.get(block.id, 0)
# Find the execution count for this block
execution_count = next(
(row["execution_count"] for row in results if row["block_id"] == block.id),
0,
)
blocks.append((block.get_info(), execution_count))
# Sort blocks by execution count
blocks.sort(key=lambda x: x[1], reverse=True)

View File

@@ -27,6 +27,7 @@ class SearchEntry(BaseModel):
# Suggestions
class SuggestionsResponse(BaseModel):
otto_suggestions: list[str]
recent_searches: list[SearchEntry]
providers: list[ProviderName]
top_blocks: list[BlockInfo]

View File

@@ -1,5 +1,5 @@
import logging
from typing import Annotated, Sequence, cast, get_args
from typing import Annotated, Sequence
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
@@ -10,8 +10,6 @@ from backend.util.models import Pagination
from . import db as builder_db
from . import model as builder_model
VALID_FILTER_VALUES = get_args(builder_model.FilterType)
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
@@ -51,6 +49,11 @@ async def get_suggestions(
Get all suggestions for the Blocks Menu.
"""
return builder_model.SuggestionsResponse(
otto_suggestions=[
"What blocks do I need to get started?",
"Help me create a list",
"Help me feed my data to Google Maps",
],
recent_searches=await builder_db.get_recent_searches(user_id),
providers=[
ProviderName.TWITTER,
@@ -148,7 +151,7 @@ async def get_providers(
async def search(
user_id: Annotated[str, fastapi.Security(get_user_id)],
search_query: Annotated[str | None, fastapi.Query()] = None,
filter: Annotated[str | None, fastapi.Query()] = None,
filter: Annotated[list[builder_model.FilterType] | None, fastapi.Query()] = None,
search_id: Annotated[str | None, fastapi.Query()] = None,
by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
page: Annotated[int, fastapi.Query()] = 1,
@@ -157,20 +160,9 @@ async def search(
"""
Search for blocks (including integrations), marketplace agents, and user library agents.
"""
# Parse and validate filter parameter
filters: list[builder_model.FilterType]
if filter:
filter_values = [f.strip() for f in filter.split(",")]
invalid_filters = [f for f in filter_values if f not in VALID_FILTER_VALUES]
if invalid_filters:
raise fastapi.HTTPException(
status_code=400,
detail=f"Invalid filter value(s): {', '.join(invalid_filters)}. "
f"Valid values are: {', '.join(VALID_FILTER_VALUES)}",
)
filters = cast(list[builder_model.FilterType], filter_values)
else:
filters = [
# If no filters are provided, then we will return all types
if not filter:
filter = [
"blocks",
"integrations",
"marketplace_agents",
@@ -182,7 +174,7 @@ async def search(
cached_results = await builder_db.get_sorted_search_results(
user_id=user_id,
search_query=search_query,
filters=filters,
filters=filter,
by_creator=by_creator,
)
@@ -204,7 +196,7 @@ async def search(
user_id,
builder_model.SearchEntry(
search_query=search_query,
filter=filters,
filter=filter,
by_creator=by_creator,
search_id=search_id,
),

View File

@@ -0,0 +1,368 @@
"""Redis Streams consumer for operation completion messages.
This module provides a consumer (ChatCompletionConsumer) that listens for
completion notifications (OperationCompleteMessage) from external services
(like Agent Generator) and triggers the appropriate stream registry and
chat service updates via process_operation_success/process_operation_failure.
Why Redis Streams instead of RabbitMQ?
--------------------------------------
While the project typically uses RabbitMQ for async task queues (e.g., execution
queue), Redis Streams was chosen for chat completion notifications because:
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
Streams (via stream_registry) for message persistence and replay. Using Redis
Streams for completion notifications keeps all chat streaming infrastructure
in one system, simplifying operations and reducing cross-system coordination.
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
allowing consumers to replay missed messages after reconnection. This aligns
with the SSE reconnection pattern where clients can resume from last_message_id.
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
recovering from dead consumers - ideal for the completion callback pattern.
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
stream_registry) provides lower latency than an additional RabbitMQ hop.
5. **Atomicity with Task State**: Completion processing often needs to update
task metadata stored in Redis. Keeping both in Redis enables simpler
transactional semantics without distributed coordination.
The consumer uses Redis Streams with consumer groups for reliable message
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
stale pending messages from dead consumers.
"""
import asyncio
import logging
import os
import uuid
from typing import Any
import orjson
from prisma import Prisma
from pydantic import BaseModel
from redis.exceptions import ResponseError
from backend.data.redis_client import get_redis_async
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
class OperationCompleteMessage(BaseModel):
"""Message format for operation completion notifications."""
operation_id: str
task_id: str
success: bool
result: dict | str | None = None
error: str | None = None
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from Redis Streams.
This consumer initializes its own Prisma client in start() to ensure
database operations work correctly within this async context.
Uses Redis consumer groups to allow multiple platform pods to consume
messages reliably with automatic redelivery on failure.
"""
def __init__(self):
self._consumer_task: asyncio.Task | None = None
self._running = False
self._prisma: Prisma | None = None
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
async def start(self) -> None:
"""Start the completion consumer."""
if self._running:
logger.warning("Completion consumer already running")
return
# Create consumer group if it doesn't exist
try:
redis = await get_redis_async()
await redis.xgroup_create(
config.stream_completion_name,
config.stream_consumer_group,
id="0",
mkstream=True,
)
logger.info(
f"Created consumer group '{config.stream_consumer_group}' "
f"on stream '{config.stream_completion_name}'"
)
except ResponseError as e:
if "BUSYGROUP" in str(e):
logger.debug(
f"Consumer group '{config.stream_consumer_group}' already exists"
)
else:
raise
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info(
f"Chat completion consumer started (consumer: {self._consumer_name})"
)
async def _ensure_prisma(self) -> Prisma:
"""Lazily initialize Prisma client on first use."""
if self._prisma is None:
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
self._prisma = Prisma(datasource={"url": database_url})
await self._prisma.connect()
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
return self._prisma
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
self._consumer_task = None
if self._prisma:
await self._prisma.disconnect()
self._prisma = None
logger.info("[COMPLETION] Consumer Prisma client disconnected")
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
"""Main message consumption loop with retry logic."""
max_retries = 10
retry_delay = 5 # seconds
retry_count = 0
block_timeout = 5000 # milliseconds
while self._running and retry_count < max_retries:
try:
redis = await get_redis_async()
# Reset retry count on successful connection
retry_count = 0
while self._running:
# First, claim any stale pending messages from dead consumers
# Redis does NOT auto-redeliver pending messages; we must explicitly
# claim them using XAUTOCLAIM
try:
claimed_result = await redis.xautoclaim(
name=config.stream_completion_name,
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
min_idle_time=config.stream_claim_min_idle_ms,
start_id="0-0",
count=10,
)
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
if claimed_result and len(claimed_result) >= 2:
claimed_entries = claimed_result[1]
if claimed_entries:
logger.info(
f"Claimed {len(claimed_entries)} stale pending messages"
)
for entry_id, data in claimed_entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except Exception as e:
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
# Read new messages from the stream
messages = await redis.xreadgroup(
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
streams={config.stream_completion_name: ">"},
block=block_timeout,
count=10,
)
if not messages:
continue
for stream_name, entries in messages:
for entry_id, data in entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except asyncio.CancelledError:
logger.info("Consumer cancelled")
return
except Exception as e:
retry_count += 1
logger.error(
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
exc_info=True,
)
if self._running and retry_count < max_retries:
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached, stopping consumer")
return
async def _process_entry(
self, redis: Any, entry_id: str, data: dict[str, Any]
) -> None:
"""Process a single stream entry and acknowledge it on success.
Args:
redis: Redis client connection
entry_id: The stream entry ID
data: The entry data dict
"""
try:
# Handle the message
message_data = data.get("data")
if message_data:
await self._handle_message(
message_data.encode()
if isinstance(message_data, str)
else message_data
)
# Acknowledge the message after successful processing
await redis.xack(
config.stream_completion_name,
config.stream_consumer_group,
entry_id,
)
except Exception as e:
logger.error(
f"Error processing completion message {entry_id}: {e}",
exc_info=True,
)
# Message remains in pending state and will be claimed by
# XAUTOCLAIM after min_idle_time expires
async def _handle_message(self, body: bytes) -> None:
"""Handle a completion message using our own Prisma client."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
except Exception as e:
logger.error(f"Failed to parse completion message: {e}")
return
logger.info(
f"[COMPLETION] Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"[COMPLETION] Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
logger.info(
f"[COMPLETION] Found task: task_id={task.task_id}, "
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
)
# Guard against empty task fields
if not task.task_id or not task.session_id or not task.tool_call_id:
logger.error(
f"[COMPLETION] Task has empty critical fields! "
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
f"tool_call_id={task.tool_call_id!r}"
)
return
if message.success:
await self._handle_success(task, message)
else:
await self._handle_failure(task, message)
async def _handle_success(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
prisma = await self._ensure_prisma()
await process_operation_success(task, message.result, prisma)
async def _handle_failure(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
prisma = await self._ensure_prisma()
await process_operation_failure(task, message.error, prisma)
# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None
async def start_completion_consumer() -> None:
"""Start the global completion consumer."""
global _consumer
if _consumer is None:
_consumer = ChatCompletionConsumer()
await _consumer.start()
async def stop_completion_consumer() -> None:
"""Stop the global completion consumer."""
global _consumer
if _consumer:
await _consumer.stop()
_consumer = None
async def publish_operation_complete(
operation_id: str,
task_id: str,
success: bool,
result: dict | str | None = None,
error: str | None = None,
) -> None:
"""Publish an operation completion message to Redis Streams.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.
success: Whether the operation succeeded.
result: The result data (for success).
error: The error message (for failure).
"""
message = OperationCompleteMessage(
operation_id=operation_id,
task_id=task_id,
success=success,
result=result,
error=error,
)
redis = await get_redis_async()
await redis.xadd(
config.stream_completion_name,
{"data": message.model_dump_json()},
maxlen=config.stream_max_length,
)
logger.info(f"Published completion for operation {operation_id}")

View File

@@ -0,0 +1,344 @@
"""Shared completion handling for operation success and failure.
This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""
import logging
from typing import Any
import orjson
from prisma import Prisma
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
# Keys that should be stripped from agent_json when returning in error responses
SENSITIVE_KEYS = frozenset(
{
"api_key",
"apikey",
"api_secret",
"password",
"secret",
"credentials",
"credential",
"token",
"access_token",
"refresh_token",
"private_key",
"privatekey",
"auth",
"authorization",
}
)
def _sanitize_agent_json(obj: Any) -> Any:
"""Recursively sanitize agent_json by removing sensitive keys.
Args:
obj: The object to sanitize (dict, list, or primitive)
Returns:
Sanitized copy with sensitive keys removed/redacted
"""
if isinstance(obj, dict):
return {
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
for k, v in obj.items()
}
elif isinstance(obj, list):
return [_sanitize_agent_json(item) for item in obj]
else:
return obj
class ToolMessageUpdateError(Exception):
"""Raised when updating a tool message in the database fails."""
pass
async def _update_tool_message(
session_id: str,
tool_call_id: str,
content: str,
prisma_client: Prisma | None,
) -> None:
"""Update tool message in database.
Args:
session_id: The session ID
tool_call_id: The tool call ID to update
content: The new content for the message
prisma_client: Optional Prisma client. If None, uses chat_service.
Raises:
ToolMessageUpdateError: If the database update fails. The caller should
handle this to avoid marking the task as completed with inconsistent state.
"""
try:
if prisma_client:
# Use provided Prisma client (for consumer with its own connection)
updated_count = await prisma_client.chatmessage.update_many(
where={
"sessionId": session_id,
"toolCallId": tool_call_id,
},
data={"content": content},
)
# Check if any rows were updated - 0 means message not found
if updated_count == 0:
raise ToolMessageUpdateError(
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
)
else:
# Use service function (for webhook endpoint)
await chat_service._update_pending_operation(
session_id=session_id,
tool_call_id=tool_call_id,
result=content,
)
except ToolMessageUpdateError:
raise
except Exception as e:
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
raise ToolMessageUpdateError(
f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
) from e
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
"""Serialize result to JSON string with sensible defaults.
Args:
result: The result to serialize. Can be a dict, list, string,
number, boolean, or None.
Returns:
JSON string representation of the result. Returns '{"status": "completed"}'
only when result is explicitly None.
"""
if isinstance(result, str):
return result
if result is None:
return '{"status": "completed"}'
return orjson.dumps(result).decode("utf-8")
async def _save_agent_from_result(
result: dict[str, Any],
user_id: str | None,
tool_name: str,
) -> dict[str, Any]:
"""Save agent to library if result contains agent_json.
Args:
result: The result dict that may contain agent_json
user_id: The user ID to save the agent for
tool_name: The tool name (create_agent or edit_agent)
Returns:
Updated result dict with saved agent details, or original result if no agent_json
"""
if not user_id:
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
return result
agent_json = result.get("agent_json")
if not agent_json:
logger.warning(
f"[COMPLETION] {tool_name} completed but no agent_json in result"
)
return result
try:
from .tools.agent_generator import save_agent_to_library
is_update = tool_name == "edit_agent"
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update
)
logger.info(
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
)
# Return a response similar to AgentSavedResponse
return {
"type": "agent_saved",
"message": f"Agent '{created_graph.name}' has been saved to your library!",
"agent_id": created_graph.id,
"agent_name": created_graph.name,
"library_agent_id": library_agent.id,
"library_agent_link": f"/library/agents/{library_agent.id}",
"agent_page_link": f"/build?flowID={created_graph.id}",
}
except Exception as e:
logger.error(
f"[COMPLETION] Failed to save agent to library: {e}",
exc_info=True,
)
# Return error but don't fail the whole operation
# Sanitize agent_json to remove sensitive keys before returning
return {
"type": "error",
"message": f"Agent was generated but failed to save: {str(e)}",
"error": str(e),
"agent_json": _sanitize_agent_json(agent_json),
}
async def process_operation_success(
task: stream_registry.ActiveTask,
result: dict | str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle successful operation completion.
Publishes the result to the stream registry, updates the database,
generates LLM continuation, and marks the task as completed.
Args:
task: The active task that completed
result: The result data from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
Raises:
ToolMessageUpdateError: If the database update fails. The task will be
marked as failed instead of completed to avoid inconsistent state.
"""
# For agent generation tools, save the agent to library
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
# Serialize result for output (only substitute default when result is exactly None)
result_output = result if result is not None else {"status": "completed"}
output_str = (
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
)
# Publish result to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=output_str,
success=True,
),
)
# Update pending operation in database
# If this fails, we must not continue to mark the task as completed
result_str = serialize_result(result)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=result_str,
prisma_client=prisma_client,
)
except ToolMessageUpdateError:
# DB update failed - mark task as failed to avoid inconsistent state
logger.error(
f"[COMPLETION] DB update failed for task {task.task_id}, "
"marking as failed instead of completed"
)
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText="Failed to save operation result to database"),
)
await stream_registry.mark_task_completed(task.task_id, status="failed")
raise
# Generate LLM continuation with streaming
try:
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to generate LLM continuation: {e}",
exc_info=True,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
)
async def process_operation_failure(
task: stream_registry.ActiveTask,
error: str | None,
prisma_client: Prisma | None = None,
) -> None:
"""Handle failed operation completion.
Publishes the error to the stream registry, updates the database with
the error response, and marks the task as failed.
Args:
task: The active task that failed
error: The error message from the operation
prisma_client: Optional Prisma client for database operations.
If None, uses chat_service._update_pending_operation instead.
"""
error_msg = error or "Operation failed"
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Update pending operation with error
# If this fails, we still continue to mark the task as failed
error_response = ErrorResponse(
message=error_msg,
error=error,
)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=error_response.model_dump_json(),
prisma_client=prisma_client,
)
except ToolMessageUpdateError:
# DB update failed - log but continue with cleanup
logger.error(
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
"continuing with cleanup"
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")

View File

@@ -27,6 +27,7 @@ class ChatConfig(BaseSettings):
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(
default=3,
description="Max retries for fallback path (SDK handles retries internally)",
@@ -36,29 +37,52 @@ class ChatConfig(BaseSettings):
default=30, description="Maximum number of agent schedules"
)
# Long-running operation configuration
long_running_operation_ttl: int = Field(
default=600,
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
)
# Stream registry configuration for SSE reconnection
stream_ttl: int = Field(
default=3600,
description="TTL in seconds for stream data in Redis (1 hour)",
)
stream_lock_ttl: int = Field(
default=120,
description="TTL in seconds for stream lock (2 minutes). Short timeout allows "
"reconnection after refresh/crash without long waits.",
)
stream_max_length: int = Field(
default=10000,
description="Maximum number of messages to store per stream",
)
# Redis key prefixes for stream registry
session_meta_prefix: str = Field(
default="chat:task:meta:",
description="Prefix for session metadata hash keys",
# Redis Streams configuration for completion consumer
stream_completion_name: str = Field(
default="chat:completions",
description="Redis Stream name for operation completions",
)
turn_stream_prefix: str = Field(
stream_consumer_group: str = Field(
default="chat_consumers",
description="Consumer group name for completion stream",
)
stream_claim_min_idle_ms: int = Field(
default=60000,
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
)
# Redis key prefixes for stream registry
task_meta_prefix: str = Field(
default="chat:task:meta:",
description="Prefix for task metadata hash keys",
)
task_stream_prefix: str = Field(
default="chat:stream:",
description="Prefix for turn message stream keys",
description="Prefix for task message stream keys",
)
task_op_prefix: str = Field(
default="chat:task:op:",
description="Prefix for operation ID to task ID mapping keys",
)
internal_api_key: str | None = Field(
default=None,
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
)
# Langfuse Prompt Management Configuration
@@ -85,7 +109,7 @@ class ChatConfig(BaseSettings):
)
claude_agent_max_subtasks: int = Field(
default=10,
description="Max number of concurrent sub-agent Tasks the SDK can run per session.",
description="Max number of sub-agent Tasks the SDK can spawn per session.",
)
claude_agent_use_resume: bool = Field(
default=True,
@@ -130,6 +154,14 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
@field_validator("internal_api_key", mode="before")
@classmethod
def get_internal_api_key(cls, v):
"""Get internal API key from environment if not provided."""
if v is None:
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
@field_validator("use_claude_agent_sdk", mode="before")
@classmethod
def get_use_claude_agent_sdk(cls, v):

View File

@@ -3,9 +3,8 @@
import asyncio
import logging
from datetime import UTC, datetime
from typing import Any
from typing import Any, cast
from prisma.errors import UniqueViolationError
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import (
@@ -15,27 +14,29 @@ from prisma.types import (
ChatSessionWhereInput,
)
from backend.data import db
from backend.data.db import transaction
from backend.util.json import SafeJson
from .model import ChatMessage, ChatSession, ChatSessionInfo
logger = logging.getLogger(__name__)
async def get_chat_session(session_id: str) -> ChatSession | None:
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
"""Get a chat session by ID from the database."""
session = await PrismaChatSession.prisma().find_unique(
where={"id": session_id},
include={"Messages": {"order_by": {"sequence": "asc"}}},
include={"Messages": True},
)
return ChatSession.from_db(session) if session else None
if session and session.Messages:
# Sort messages by sequence in Python - Prisma Python client doesn't support
# order_by in include clauses (unlike Prisma JS), so we sort after fetching
session.Messages.sort(key=lambda m: m.sequence)
return session
async def create_chat_session(
session_id: str,
user_id: str,
) -> ChatSessionInfo:
) -> PrismaChatSession:
"""Create a new chat session in the database."""
data = ChatSessionCreateInput(
id=session_id,
@@ -44,8 +45,7 @@ async def create_chat_session(
successfulAgentRuns=SafeJson({}),
successfulAgentSchedules=SafeJson({}),
)
prisma_session = await PrismaChatSession.prisma().create(data=data)
return ChatSessionInfo.from_db(prisma_session)
return await PrismaChatSession.prisma().create(data=data)
async def update_chat_session(
@@ -56,7 +56,7 @@ async def update_chat_session(
total_prompt_tokens: int | None = None,
total_completion_tokens: int | None = None,
title: str | None = None,
) -> ChatSession | None:
) -> PrismaChatSession | None:
"""Update a chat session's metadata."""
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
@@ -76,9 +76,12 @@ async def update_chat_session(
session = await PrismaChatSession.prisma().update(
where={"id": session_id},
data=data,
include={"Messages": {"order_by": {"sequence": "asc"}}},
include={"Messages": True},
)
return ChatSession.from_db(session) if session else None
if session and session.Messages:
# Sort in Python - Prisma Python doesn't support order_by in include clauses
session.Messages.sort(key=lambda m: m.sequence)
return session
async def add_chat_message(
@@ -91,11 +94,12 @@ async def add_chat_message(
refusal: str | None = None,
tool_calls: list[dict[str, Any]] | None = None,
function_call: dict[str, Any] | None = None,
) -> ChatMessage:
) -> PrismaChatMessage:
"""Add a message to a chat session."""
# Build ChatMessageCreateInput with only non-None values
# (Prisma TypedDict rejects optional fields set to None)
data: ChatMessageCreateInput = {
# Build input dict dynamically rather than using ChatMessageCreateInput directly
# because Prisma's TypedDict validation rejects optional fields set to None.
# We only include fields that have values, then cast at the end.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": role,
"sequence": sequence,
@@ -123,117 +127,81 @@ async def add_chat_message(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
),
PrismaChatMessage.prisma().create(data=data),
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
)
return ChatMessage.from_db(message)
return message
async def add_chat_messages_batch(
session_id: str,
messages: list[dict[str, Any]],
start_sequence: int,
) -> int:
) -> list[PrismaChatMessage]:
"""Add multiple messages to a chat session in a batch.
Uses collision detection with retry: tries to create messages starting
at start_sequence. If a unique constraint violation occurs (e.g., the
streaming loop and long-running callback race), queries the latest
sequence and retries with the correct offset. This avoids unnecessary
upserts and DB queries in the common case (no collision).
Returns:
Next sequence number for the next message to be inserted. This equals
start_sequence + len(messages) and allows callers to update their
counters even when collision detection adjusts start_sequence.
Uses a transaction for atomicity - if any message creation fails,
the entire batch is rolled back.
"""
if not messages:
# No messages to add - return current count
return start_sequence
return []
max_retries = 5
for attempt in range(max_retries):
try:
# Single timestamp for all messages and session update
now = datetime.now(UTC)
created_messages = []
async with db.transaction() as tx:
# Build all message data
messages_data = []
for i, msg in enumerate(messages):
# Build ChatMessageCreateInput with only non-None values
# (Prisma TypedDict rejects optional fields set to None)
# Note: create_many doesn't support nested creates, use sessionId directly
data: ChatMessageCreateInput = {
"sessionId": session_id,
"role": msg["role"],
"sequence": start_sequence + i,
"createdAt": now,
}
async with transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
messages_data.append(data)
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
)
created_messages.append(created)
# Run create_many and session update in parallel within transaction
# Both use the same timestamp for consistency
await asyncio.gather(
PrismaChatMessage.prisma(tx).create_many(data=messages_data),
PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": now},
),
)
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
# Return next sequence number for counter sync
return start_sequence + len(messages)
except UniqueViolationError:
if attempt < max_retries - 1:
# Collision detected - query MAX(sequence)+1 and retry with correct offset
logger.info(
f"Collision detected for session {session_id} at sequence "
f"{start_sequence}, querying DB for latest sequence"
)
start_sequence = await get_next_sequence(session_id)
logger.info(
f"Retrying batch insert with start_sequence={start_sequence}"
)
continue
else:
# Max retries exceeded - propagate error
raise
# Should never reach here due to raise in exception handler
raise RuntimeError(f"Failed to insert messages after {max_retries} attempts")
return created_messages
async def get_user_chat_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
) -> list[ChatSessionInfo]:
) -> list[PrismaChatSession]:
"""Get chat sessions for a user, ordered by most recent."""
prisma_sessions = await PrismaChatSession.prisma().find_many(
return await PrismaChatSession.prisma().find_many(
where={"userId": user_id},
order={"updatedAt": "desc"},
take=limit,
skip=offset,
)
return [ChatSessionInfo.from_db(s) for s in prisma_sessions]
async def get_user_session_count(user_id: str) -> int:
@@ -272,20 +240,10 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
return False
async def get_next_sequence(session_id: str) -> int:
"""Get the next sequence number for a new message in this session.
Uses MAX(sequence) + 1 for robustness. Returns 0 if no messages exist.
More robust than COUNT(*) because it's immune to deleted messages.
Optimized to select only the sequence column using raw SQL.
The unique index on (sessionId, sequence) makes this query fast.
"""
results = await db.query_raw_with_schema(
'SELECT "sequence" FROM {schema_prefix}"ChatMessage" WHERE "sessionId" = $1 ORDER BY "sequence" DESC LIMIT 1',
session_id,
)
return 0 if not results else results[0]["sequence"] + 1
async def get_chat_session_message_count(session_id: str) -> int:
"""Get the number of messages in a chat session."""
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
return count
async def update_tool_message_content(

View File

@@ -2,7 +2,7 @@ import asyncio
import logging
import uuid
from datetime import UTC, datetime
from typing import Any, Self, cast
from typing import Any, cast
from weakref import WeakValueDictionary
from openai.types.chat import (
@@ -23,17 +23,26 @@ from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from pydantic import BaseModel
from backend.data.db_accessors import chat_db
from backend.data.redis_client import get_redis_async
from backend.util import json
from backend.util.exceptions import DatabaseError, RedisError
from . import db as chat_db
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
"""Parse a JSON field that may be stored as string or already parsed."""
if value is None:
return default
if isinstance(value, str):
return json.loads(value)
return value
# Redis cache key prefix for chat sessions
CHAT_SESSION_CACHE_PREFIX = "chat:session:"
@@ -43,7 +52,28 @@ def _get_session_cache_key(session_id: str) -> str:
return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"
# ===================== Chat data models ===================== #
# Session-level locks to prevent race conditions during concurrent upserts.
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock
class ChatMessage(BaseModel):
@@ -55,19 +85,6 @@ class ChatMessage(BaseModel):
tool_calls: list[dict] | None = None
function_call: dict | None = None
@staticmethod
def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
"""Convert a Prisma ChatMessage to a Pydantic ChatMessage."""
return ChatMessage(
role=prisma_message.role,
content=prisma_message.content,
name=prisma_message.name,
tool_call_id=prisma_message.toolCallId,
refusal=prisma_message.refusal,
tool_calls=_parse_json_field(prisma_message.toolCalls),
function_call=_parse_json_field(prisma_message.functionCall),
)
class Usage(BaseModel):
prompt_tokens: int
@@ -75,10 +92,11 @@ class Usage(BaseModel):
total_tokens: int
class ChatSessionInfo(BaseModel):
class ChatSession(BaseModel):
session_id: str
user_id: str
title: str | None = None
messages: list[ChatMessage]
usage: list[Usage]
credentials: dict[str, dict] = {} # Map of provider -> credential metadata
started_at: datetime
@@ -86,9 +104,60 @@ class ChatSessionInfo(BaseModel):
successful_agent_runs: dict[str, int] = {}
successful_agent_schedules: dict[str, int] = {}
@classmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
"""Convert Prisma ChatSession to Pydantic ChatSession."""
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
"""Attach a tool_call to the current turn's assistant message.
Searches backwards for the most recent assistant message (stopping at
any user message boundary). If found, appends the tool_call to it.
Otherwise creates a new assistant message with the tool_call.
"""
for msg in reversed(self.messages):
if msg.role == "user":
break
if msg.role == "assistant":
if not msg.tool_calls:
msg.tool_calls = []
msg.tool_calls.append(tool_call)
return
self.messages.append(
ChatMessage(role="assistant", content="", tool_calls=[tool_call])
)
@staticmethod
def new(user_id: str) -> "ChatSession":
return ChatSession(
session_id=str(uuid.uuid4()),
user_id=user_id,
title=None,
messages=[],
usage=[],
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
@staticmethod
def from_db(
prisma_session: PrismaChatSession,
prisma_messages: list[PrismaChatMessage] | None = None,
) -> "ChatSession":
"""Convert Prisma models to Pydantic ChatSession."""
messages = []
if prisma_messages:
for msg in prisma_messages:
messages.append(
ChatMessage(
role=msg.role,
content=msg.content,
name=msg.name,
tool_call_id=msg.toolCallId,
refusal=msg.refusal,
tool_calls=_parse_json_field(msg.toolCalls),
function_call=_parse_json_field(msg.functionCall),
)
)
# Parse JSON fields from Prisma
credentials = _parse_json_field(prisma_session.credentials, default={})
successful_agent_runs = _parse_json_field(
@@ -110,10 +179,11 @@ class ChatSessionInfo(BaseModel):
)
)
return cls(
return ChatSession(
session_id=prisma_session.id,
user_id=prisma_session.userId,
title=prisma_session.title,
messages=messages,
usage=usage,
credentials=credentials,
started_at=prisma_session.createdAt,
@@ -122,55 +192,46 @@ class ChatSessionInfo(BaseModel):
successful_agent_schedules=successful_agent_schedules,
)
@staticmethod
def _merge_consecutive_assistant_messages(
messages: list[ChatCompletionMessageParam],
) -> list[ChatCompletionMessageParam]:
"""Merge consecutive assistant messages into single messages.
class ChatSession(ChatSessionInfo):
messages: list[ChatMessage]
@classmethod
def new(cls, user_id: str) -> Self:
return cls(
session_id=str(uuid.uuid4()),
user_id=user_id,
title=None,
messages=[],
usage=[],
credentials={},
started_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
)
@classmethod
def from_db(cls, prisma_session: PrismaChatSession) -> Self:
"""Convert Prisma ChatSession to Pydantic ChatSession."""
if prisma_session.Messages is None:
raise ValueError(
f"Prisma session {prisma_session.id} is missing Messages relation"
)
return cls(
**ChatSessionInfo.from_db(prisma_session).model_dump(),
messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
)
def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
"""Attach a tool_call to the current turn's assistant message.
Searches backwards for the most recent assistant message (stopping at
any user message boundary). If found, appends the tool_call to it.
Otherwise creates a new assistant message with the tool_call.
Long-running tool flows can create split assistant messages: one with
text content and another with tool_calls. Anthropic's API requires
tool_result blocks to reference a tool_use in the immediately preceding
assistant message, so these splits cause 400 errors via OpenRouter.
"""
for msg in reversed(self.messages):
if msg.role == "user":
break
if msg.role == "assistant":
if not msg.tool_calls:
msg.tool_calls = []
msg.tool_calls.append(tool_call)
return
if len(messages) < 2:
return messages
self.messages.append(
ChatMessage(role="assistant", content="", tool_calls=[tool_call])
)
result: list[ChatCompletionMessageParam] = [messages[0]]
for msg in messages[1:]:
prev = result[-1]
if prev.get("role") != "assistant" or msg.get("role") != "assistant":
result.append(msg)
continue
prev = cast(ChatCompletionAssistantMessageParam, prev)
curr = cast(ChatCompletionAssistantMessageParam, msg)
curr_content = curr.get("content") or ""
if curr_content:
prev_content = prev.get("content") or ""
prev["content"] = (
f"{prev_content}\n{curr_content}" if prev_content else curr_content
)
curr_tool_calls = curr.get("tool_calls")
if curr_tool_calls:
prev_tool_calls = prev.get("tool_calls")
prev["tool_calls"] = (
list(prev_tool_calls) + list(curr_tool_calls)
if prev_tool_calls
else list(curr_tool_calls)
)
return result
def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
messages = []
@@ -260,70 +321,40 @@ class ChatSession(ChatSessionInfo):
)
return self._merge_consecutive_assistant_messages(messages)
@staticmethod
def _merge_consecutive_assistant_messages(
messages: list[ChatCompletionMessageParam],
) -> list[ChatCompletionMessageParam]:
"""Merge consecutive assistant messages into single messages.
Long-running tool flows can create split assistant messages: one with
text content and another with tool_calls. Anthropic's API requires
tool_result blocks to reference a tool_use in the immediately preceding
assistant message, so these splits cause 400 errors via OpenRouter.
"""
if len(messages) < 2:
return messages
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
"""Get a chat session from Redis cache."""
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
raw_session: bytes | None = await async_redis.get(redis_key)
result: list[ChatCompletionMessageParam] = [messages[0]]
for msg in messages[1:]:
prev = result[-1]
if prev.get("role") != "assistant" or msg.get("role") != "assistant":
result.append(msg)
continue
if raw_session is None:
return None
prev = cast(ChatCompletionAssistantMessageParam, prev)
curr = cast(ChatCompletionAssistantMessageParam, msg)
curr_content = curr.get("content") or ""
if curr_content:
prev_content = prev.get("content") or ""
prev["content"] = (
f"{prev_content}\n{curr_content}" if prev_content else curr_content
)
curr_tool_calls = curr.get("tool_calls")
if curr_tool_calls:
prev_tool_calls = prev.get("tool_calls")
prev["tool_calls"] = (
list(prev_tool_calls) + list(curr_tool_calls)
if prev_tool_calls
else list(curr_tool_calls)
)
return result
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles
)
return session
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
"""Parse a JSON field that may be stored as string or already parsed."""
if value is None:
return default
if isinstance(value, str):
return json.loads(value)
return value
# ================ Chat cache + DB operations ================ #
# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
# connected directly.
async def cache_chat_session(session: ChatSession) -> None:
"""Cache a chat session in Redis (without persisting to the database)."""
async def _cache_session(session: ChatSession) -> None:
"""Cache a chat session in Redis."""
redis_key = _get_session_cache_key(session.session_id)
async_redis = await get_redis_async()
await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())
async def cache_chat_session(session: ChatSession) -> None:
"""Cache a chat session without persisting to the database."""
await _cache_session(session)
async def invalidate_session_cache(session_id: str) -> None:
"""Invalidate a chat session from Redis cache.
@@ -339,6 +370,77 @@ async def invalidate_session_cache(session_id: str) -> None:
logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
prisma_session = await chat_db.get_chat_session(session_id)
if not prisma_session:
return None
messages = prisma_session.Messages
logger.debug(
f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles
)
return ChatSession.from_db(prisma_session, messages)
async def _save_session_to_db(
session: ChatSession, existing_message_count: int
) -> None:
"""Save or update a chat session in the database."""
# Check if session exists in DB
existing = await chat_db.get_chat_session(session.session_id)
if not existing:
# Create new session
await chat_db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
# Update session metadata
await chat_db.update_chat_session(
session_id=session.session_id,
credentials=session.credentials,
successful_agent_runs=session.successful_agent_runs,
successful_agent_schedules=session.successful_agent_schedules,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
if new_messages:
messages_data = []
for msg in new_messages:
messages_data.append(
{
"role": msg.role,
"content": msg.content,
"name": msg.name,
"tool_call_id": msg.tool_call_id,
"refusal": msg.refusal,
"tool_calls": msg.tool_calls,
"function_call": msg.function_call,
}
)
logger.debug(
f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
f"roles={[m['role'] for m in messages_data]}"
)
await chat_db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
async def get_chat_session(
session_id: str,
user_id: str | None = None,
@@ -386,52 +488,13 @@ async def get_chat_session(
# Cache the session from DB
try:
await cache_chat_session(session)
logger.info(f"Cached session {session_id} from database")
await _cache_session(session)
except Exception as e:
logger.warning(f"Failed to cache session {session_id}: {e}")
return session
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
"""Get a chat session from Redis cache."""
redis_key = _get_session_cache_key(session_id)
async_redis = await get_redis_async()
raw_session: bytes | None = await async_redis.get(redis_key)
if raw_session is None:
return None
try:
session = ChatSession.model_validate_json(raw_session)
logger.info(
f"Loading session {session_id} from cache: "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
except Exception as e:
logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
raise RedisError(f"Corrupted session data for {session_id}") from e
async def _get_session_from_db(session_id: str) -> ChatSession | None:
"""Get a chat session from the database."""
session = await chat_db().get_chat_session(session_id)
if not session:
return None
logger.info(
f"Loaded session {session_id} from DB: "
f"has_messages={bool(session.messages)}, "
f"message_count={len(session.messages)}, "
f"roles={[m.role for m in session.messages]}"
)
return session
async def upsert_chat_session(
session: ChatSession,
) -> ChatSession:
@@ -451,18 +514,16 @@ async def upsert_chat_session(
lock = await _get_session_lock(session.session_id)
async with lock:
# Always query DB for existing message count to ensure consistency
existing_message_count = await chat_db().get_next_sequence(session.session_id)
# Get existing message count from DB for incremental saves
existing_message_count = await chat_db.get_chat_session_message_count(
session.session_id
)
db_error: Exception | None = None
# Save to database (primary storage)
try:
await _save_session_to_db(
session,
existing_message_count,
skip_existence_check=existing_message_count > 0,
)
await _save_session_to_db(session, existing_message_count)
except Exception as e:
logger.error(
f"Failed to save session {session.session_id} to database: {e}"
@@ -471,7 +532,7 @@ async def upsert_chat_session(
# Save to cache (best-effort, even if DB failed)
try:
await cache_chat_session(session)
await _cache_session(session)
except Exception as e:
# If DB succeeded but cache failed, raise cache error
if db_error is None:
@@ -492,75 +553,6 @@ async def upsert_chat_session(
return session
async def _save_session_to_db(
session: ChatSession,
existing_message_count: int,
*,
skip_existence_check: bool = False,
) -> None:
"""Save or update a chat session in the database.
Args:
skip_existence_check: When True, skip the ``get_chat_session`` query
and assume the session row already exists. Saves one DB round trip
for incremental saves during streaming.
"""
db = chat_db()
if not skip_existence_check:
# Check if session exists in DB
existing = await db.get_chat_session(session.session_id)
if not existing:
# Create new session
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
total_completion = sum(u.completion_tokens for u in session.usage)
# Update session metadata
await db.update_chat_session(
session_id=session.session_id,
credentials=session.credentials,
successful_agent_runs=session.successful_agent_runs,
successful_agent_schedules=session.successful_agent_schedules,
total_prompt_tokens=total_prompt,
total_completion_tokens=total_completion,
)
# Add new messages (only those after existing count)
new_messages = session.messages[existing_message_count:]
if new_messages:
messages_data = []
for msg in new_messages:
messages_data.append(
{
"role": msg.role,
"content": msg.content,
"name": msg.name,
"tool_call_id": msg.tool_call_id,
"refusal": msg.refusal,
"tool_calls": msg.tool_calls,
"function_call": msg.function_call,
}
)
logger.info(
f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
f"roles={[m['role'] for m in messages_data]}, "
f"start_sequence={existing_message_count}"
)
await db.add_chat_messages_batch(
session_id=session.session_id,
messages=messages_data,
start_sequence=existing_message_count,
)
async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
"""Atomically append a message to a session and persist it.
@@ -576,7 +568,9 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db().get_next_sequence(session_id)
existing_message_count = await chat_db.get_chat_session_message_count(
session_id
)
try:
await _save_session_to_db(session, existing_message_count)
@@ -586,7 +580,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
) from e
try:
await cache_chat_session(session)
await _cache_session(session)
except Exception as e:
logger.warning(f"Cache write failed for session {session_id}: {e}")
@@ -605,7 +599,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
# Create in database first - fail fast if this fails
try:
await chat_db().create_chat_session(
await chat_db.create_chat_session(
session_id=session.session_id,
user_id=user_id,
)
@@ -617,7 +611,7 @@ async def create_chat_session(user_id: str) -> ChatSession:
# Cache the session (best-effort optimization, DB is source of truth)
try:
await cache_chat_session(session)
await _cache_session(session)
except Exception as e:
logger.warning(f"Failed to cache new session {session.session_id}: {e}")
@@ -628,16 +622,20 @@ async def get_user_sessions(
user_id: str,
limit: int = 50,
offset: int = 0,
) -> tuple[list[ChatSessionInfo], int]:
) -> tuple[list[ChatSession], int]:
"""Get chat sessions for a user from the database with total count.
Returns:
A tuple of (sessions, total_count) where total_count is the overall
number of sessions for the user (not just the current page).
"""
db = chat_db()
sessions = await db.get_user_chat_sessions(user_id, limit, offset)
total_count = await db.get_user_session_count(user_id)
prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
total_count = await chat_db.get_user_session_count(user_id)
sessions = []
for prisma_session in prisma_sessions:
# Convert without messages for listing (lighter weight)
sessions.append(ChatSession.from_db(prisma_session, None))
return sessions, total_count
@@ -655,7 +653,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
"""
# Delete from database first (with optional user_id validation)
# This confirms ownership before invalidating cache
deleted = await chat_db().delete_chat_session(session_id, user_id)
deleted = await chat_db.delete_chat_session(session_id, user_id)
if not deleted:
return False
@@ -690,7 +688,7 @@ async def update_session_title(session_id: str, title: str) -> bool:
True if updated successfully, False otherwise.
"""
try:
result = await chat_db().update_chat_session(session_id=session_id, title=title)
result = await chat_db.update_chat_session(session_id=session_id, title=title)
if result is None:
logger.warning(f"Session {session_id} not found for title update")
return False
@@ -702,7 +700,7 @@ async def update_session_title(session_id: str, title: str) -> bool:
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await cache_chat_session(cached)
await _cache_session(cached)
except Exception as e:
# Not critical - title will be correct on next full cache refresh
logger.warning(
@@ -713,29 +711,3 @@ async def update_session_title(session_id: str, title: str) -> bool:
except Exception as e:
logger.error(f"Failed to update title for session {session_id}: {e}")
return False
# ==================== Chat session locks ==================== #
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()
async def _get_session_lock(session_id: str) -> asyncio.Lock:
"""Get or create a lock for a specific session to prevent concurrent upserts.
This was originally added to solve the specific problem of race conditions between
the session title thread and the conversation thread, which always occurs on the
same instance as we prevent rapid request sends on the frontend.
Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
when no coroutine holds a reference to them, preventing memory leaks from
unbounded growth of session locks. Explicit cleanup also occurs
in `delete_chat_session()`.
"""
async with _session_locks_mutex:
lock = _session_locks.get(session_id)
if lock is None:
lock = asyncio.Lock()
_session_locks[session_id] = lock
return lock

View File

@@ -331,96 +331,3 @@ def test_to_openai_messages_merges_split_assistants():
tc_list = merged.get("tool_calls")
assert tc_list is not None and len(list(tc_list)) == 1
assert list(tc_list)[0]["id"] == "tc1"
# --------------------------------------------------------------------------- #
# Concurrent save collision detection #
# --------------------------------------------------------------------------- #
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_saves_collision_detection(setup_test_user, test_user_id):
"""Test that concurrent saves from streaming loop and callback handle collisions correctly.
Simulates the race condition where:
1. Streaming loop starts with saved_msg_count=5
2. Long-running callback appends message #5 and saves
3. Streaming loop tries to save with stale count=5
The collision detection should handle this gracefully.
"""
import asyncio
# Create a session with initial messages
session = ChatSession.new(user_id=test_user_id)
for i in range(3):
session.messages.append(
ChatMessage(
role="user" if i % 2 == 0 else "assistant", content=f"Message {i}"
)
)
# Save initial messages
session = await upsert_chat_session(session)
# Simulate streaming loop and callback saving concurrently
async def streaming_loop_save():
"""Simulates streaming loop saving messages."""
# Add 2 messages
session.messages.append(ChatMessage(role="user", content="Streaming message 1"))
session.messages.append(
ChatMessage(role="assistant", content="Streaming message 2")
)
# Wait a bit to let callback potentially save first
await asyncio.sleep(0.01)
# Save (will query DB for existing count)
return await upsert_chat_session(session)
async def callback_save():
"""Simulates long-running callback saving a message."""
# Add 1 message
session.messages.append(
ChatMessage(role="tool", content="Callback result", tool_call_id="tc1")
)
# Save immediately (will query DB for existing count)
return await upsert_chat_session(session)
# Run both saves concurrently - one will hit collision detection
results = await asyncio.gather(streaming_loop_save(), callback_save())
# Both should succeed
assert all(r is not None for r in results)
# Reload session from DB to verify
from backend.data.redis_client import get_redis_async
redis_key = f"chat:session:{session.session_id}"
async_redis = await get_redis_async()
await async_redis.delete(redis_key) # Clear cache to force DB load
loaded_session = await get_chat_session(session.session_id, test_user_id)
assert loaded_session is not None
# Should have all 6 messages (3 initial + 2 streaming + 1 callback)
assert len(loaded_session.messages) == 6
# Verify no duplicate sequences
sequences = []
for i, msg in enumerate(loaded_session.messages):
# Messages should have sequential sequence numbers starting from 0
sequences.append(i)
# All sequences should be unique and sequential
assert sequences == list(range(6))
# Verify message content is preserved
contents = [m.content for m in loaded_session.messages]
assert "Message 0" in contents
assert "Message 1" in contents
assert "Message 2" in contents
assert "Streaming message 1" in contents
assert "Streaming message 2" in contents
assert "Callback result" in contents

View File

@@ -5,8 +5,6 @@ This module implements the AI SDK UI Stream Protocol (v1) for streaming chat res
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
"""
import json
import logging
from enum import Enum
from typing import Any
@@ -14,8 +12,6 @@ from pydantic import BaseModel, Field
from backend.util.json import dumps as json_dumps
logger = logging.getLogger(__name__)
class ResponseType(str, Enum):
"""Types of streaming responses following AI SDK protocol."""
@@ -51,8 +47,7 @@ class StreamBaseResponse(BaseModel):
def to_sse(self) -> str:
"""Convert to SSE format."""
json_str = self.model_dump_json(exclude_none=True)
return f"data: {json_str}\n\n"
return f"data: {self.model_dump_json()}\n\n"
# ========== Message Lifecycle ==========
@@ -63,13 +58,15 @@ class StreamStart(StreamBaseResponse):
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
sessionId: str | None = Field(
taskId: str | None = Field(
default=None,
description="Session ID for SSE reconnection.",
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-protocol fields like sessionId."""
"""Convert to SSE format, excluding non-protocol fields like taskId."""
import json
data: dict[str, Any] = {
"type": self.type.value,
"messageId": self.messageId,
@@ -166,6 +163,8 @@ class StreamToolOutputAvailable(StreamBaseResponse):
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
import json
data = {
"type": self.type.value,
"toolCallId": self.toolCallId,

View File

@@ -2,30 +2,33 @@
import asyncio
import logging
import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.model import (
from backend.util.exceptions import NotFoundError
from backend.util.feature_flag import Flag, is_feature_enabled
from . import service as chat_service
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
from .model import (
ChatMessage,
ChatSession,
append_and_save_message,
create_chat_session,
delete_chat_session,
get_chat_session,
get_user_sessions,
)
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
from backend.copilot.tools.models import (
from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
from .sdk import service as sdk_service
from .tools.models import (
AgentDetailsResponse,
AgentOutputResponse,
AgentPreviewResponse,
@@ -42,12 +45,13 @@ from backend.copilot.tools.models import (
InputValidationErrorResponse,
NeedLoginResponse,
NoResultsResponse,
OperationInProgressResponse,
OperationPendingResponse,
OperationStartedResponse,
SetupRequirementsResponse,
SuggestedGoalResponse,
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from .tracking import track_user_message
config = ChatConfig()
@@ -92,8 +96,10 @@ class CreateSessionResponse(BaseModel):
class ActiveStreamInfo(BaseModel):
"""Information about an active stream for reconnection."""
turn_id: str
task_id: str
last_message_id: str # Redis Stream message ID for resumption
operation_id: str # Operation ID for completion tracking
tool_name: str # Name of the tool being executed
class SessionDetailResponse(BaseModel):
@@ -123,11 +129,12 @@ class ListSessionsResponse(BaseModel):
total: int
class CancelSessionResponse(BaseModel):
"""Response model for the cancel session endpoint."""
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
cancelled: bool
reason: str | None = None
success: bool
result: dict | str | None = None
error: str | None = None
# ========== Routes ==========
@@ -204,43 +211,6 @@ async def create_session(
)
@router.delete(
"/sessions/{session_id}",
dependencies=[Security(auth.requires_user)],
status_code=204,
responses={404: {"description": "Session not found or access denied"}},
)
async def delete_session(
session_id: str,
user_id: Annotated[str, Security(auth.get_user_id)],
) -> Response:
"""
Delete a chat session.
Permanently removes a chat session and all its messages.
Only the owner can delete their sessions.
Args:
session_id: The session ID to delete.
user_id: The authenticated user's ID.
Returns:
204 No Content on success.
Raises:
HTTPException: 404 if session not found or not owned by user.
"""
deleted = await delete_chat_session(session_id, user_id)
if not deleted:
raise HTTPException(
status_code=404,
detail=f"Session {session_id} not found or access denied",
)
return Response(status_code=204)
@router.get(
"/sessions/{session_id}",
)
@@ -252,7 +222,7 @@ async def get_session(
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns active_stream info for reconnection.
If there's an active stream for this session, returns the task_id for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
@@ -270,21 +240,28 @@ async def get_session(
# Check if there's an active stream for this session
active_stream_info = None
active_session, last_message_id = await stream_registry.get_active_session(
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_session:
# Keep the assistant message (including tool_calls) so the frontend can
# render the correct tool UI (e.g. CreateAgent with mini game).
# convertChatSessionToUiMessages handles isComplete=false by setting
# tool parts without output to state "input-available".
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
# stream replay instead, preventing duplicate content.
if messages and messages[-1].get("role") == "assistant":
messages = messages[:-1]
# Use "0-0" as last_message_id to replay the stream from the beginning.
# Since we filtered out the cached assistant message, the client needs
# the full stream to reconstruct the response.
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
task_id=active_task.task_id,
last_message_id="0-0",
operation_id=active_task.operation_id,
tool_name=active_task.tool_name,
)
return SessionDetailResponse(
@@ -297,51 +274,6 @@ async def get_session(
)
@router.post(
"/sessions/{session_id}/cancel",
status_code=200,
)
async def cancel_session_task(
session_id: str,
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CancelSessionResponse:
"""Cancel the active streaming task for a session.
Publishes a cancel event to the executor via RabbitMQ FANOUT, then
polls Redis until the task status flips from ``running`` or a timeout
(5 s) is reached. Returns only after the cancellation is confirmed.
"""
await _validate_and_get_session(session_id, user_id)
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
if not active_session:
return CancelSessionResponse(cancelled=True, reason="no_active_session")
await enqueue_cancel_task(session_id)
logger.info(f"[CANCEL] Published cancel for session ...{session_id[-8:]}")
# Poll until the executor confirms the task is no longer running.
poll_interval = 0.5
max_wait = 5.0
waited = 0.0
while waited < max_wait:
await asyncio.sleep(poll_interval)
waited += poll_interval
session_state = await stream_registry.get_session(session_id)
if session_state is None or session_state.status != "running":
logger.info(
f"[CANCEL] Session ...{session_id[-8:]} confirmed stopped "
f"(status={session_state.status if session_state else 'gone'}) after {waited:.1f}s"
)
return CancelSessionResponse(cancelled=True)
logger.warning(
f"[CANCEL] Session ...{session_id[-8:]} not confirmed after {max_wait}s, force-completing"
)
await stream_registry.mark_session_completed(session_id, error_message="Cancelled")
return CancelSessionResponse(cancelled=True)
@router.post(
"/sessions/{session_id}/stream",
)
@@ -359,15 +291,16 @@ async def stream_chat_post(
- Tool execution results
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to a per-turn Redis stream for reconnection support. If the client
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
All chunks are written to Redis for reconnection support. If the client disconnects,
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
containing the task_id for reconnection.
"""
import asyncio
@@ -383,7 +316,7 @@ async def stream_chat_post(
f"user={user_id}, message_len={len(request.message)}",
extra={"json_fields": log_meta},
)
await _validate_and_get_session(session_id, user_id)
session = await _validate_and_get_session(session_id, user_id)
logger.info(
f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
extra={
@@ -410,46 +343,152 @@ async def stream_chat_post(
message_length=len(request.message),
)
logger.info(f"[STREAM] Saving user message to session {session_id}")
await append_and_save_message(session_id, message)
session = await append_and_save_message(session_id, message)
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
turn_id = str(uuid4())
log_meta["turn_id"] = turn_id
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
log_meta["task_id"] = task_id
session_create_start = time.perf_counter()
await stream_registry.create_session(
task_create_start = time.perf_counter()
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream",
tool_call_id="chat_stream", # Not a tool call, but needed for the model
tool_name="chat",
turn_id=turn_id,
operation_id=operation_id,
)
logger.info(
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
}
},
)
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
subscribe_from_id = "0-0"
# Background task that runs the AI generation independently of SSE connection
async def run_ai_generation():
import time as time_module
await enqueue_copilot_turn(
session_id=session_id,
user_id=user_id,
message=request.message,
turn_id=turn_id,
is_user_message=request.is_user_message,
context=request.context,
)
gen_start_time = time_module.perf_counter()
logger.info(
f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
extra={"json_fields": log_meta},
)
first_chunk_time, ttfc = None, None
chunk_count = 0
try:
# Emit a start event with task_id for reconnection
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
await stream_registry.publish_chunk(task_id, start_chunk)
logger.info(
f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": (time_module.perf_counter() - gen_start_time)
* 1000,
}
},
)
# Choose service based on LaunchDarkly flag (falls back to config default)
use_sdk = await is_feature_enabled(
Flag.COPILOT_SDK,
user_id or "anonymous",
default=config.use_claude_agent_sdk,
)
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else chat_service.stream_chat_completion
)
logger.info(
f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
extra={"json_fields": log_meta},
)
# Pass message=None since we already added it to the session above
async for chunk in stream_fn(
session_id,
None, # Message already in session
is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass session with message already added
context=request.context,
):
# Skip duplicate StreamStart — we already published one above
if isinstance(chunk, StreamStart):
continue
chunk_count += 1
if first_chunk_time is None:
first_chunk_time = time_module.perf_counter()
ttfc = first_chunk_time - gen_start_time
logger.info(
f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
extra={
"json_fields": {
**log_meta,
"chunk_type": type(chunk).__name__,
"time_to_first_chunk_ms": ttfc * 1000,
}
},
)
# Write to Redis (subscribers will receive via XREAD)
await stream_registry.publish_chunk(task_id, chunk)
gen_end_time = time_module.perf_counter()
total_time = (gen_end_time - gen_start_time) * 1000
logger.info(
f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
f"task={task_id}, session={session_id}, "
f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
extra={
"json_fields": {
**log_meta,
"total_time_ms": total_time,
"time_to_first_chunk_ms": (
ttfc * 1000 if ttfc is not None else None
),
"n_chunks": chunk_count,
}
},
)
await stream_registry.mark_task_completed(task_id, "completed")
except Exception as e:
elapsed = time_module.perf_counter() - gen_start_time
logger.error(
f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
extra={
"json_fields": {
**log_meta,
"elapsed_ms": elapsed * 1000,
"error": str(e),
}
},
)
# Publish a StreamError so the frontend can display an error message
try:
await stream_registry.publish_chunk(
task_id,
StreamError(
errorText="An error occurred. Please try again.",
code="stream_error",
),
)
except Exception:
pass # Best-effort; mark_task_completed will publish StreamFinish
await stream_registry.mark_task_completed(task_id, "failed")
# Start the AI generation in a background task
bg_task = asyncio.create_task(run_ai_generation())
await stream_registry.set_task_asyncio_task(task_id, bg_task)
setup_time = (time.perf_counter() - stream_start_time) * 1000
logger.info(
f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
f"[TIMING] Background task started, setup={setup_time:.1f}ms",
extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
)
@@ -459,7 +498,7 @@ async def stream_chat_post(
event_gen_start = time_module.perf_counter()
logger.info(
f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
f"user={user_id}",
extra={"json_fields": log_meta},
)
@@ -467,12 +506,11 @@ async def stream_chat_post(
first_chunk_yielded = False
chunks_yielded = 0
try:
# Subscribe from the position we captured before enqueuing
# This avoids replaying old messages while catching all new ones
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=subscribe_from_id,
last_message_id="0-0", # Get all messages from the beginning
)
if subscriber_queue is None:
@@ -555,19 +593,19 @@ async def stream_chat_post(
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_session(
session_id, subscriber_queue
await stream_registry.unsubscribe_from_task(
task_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from session {session_id}: {unsub_err}",
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
@@ -614,21 +652,17 @@ async def resume_session_stream(
"""
import asyncio
active_session, last_message_id = await stream_registry.get_active_session(
active_task, _last_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if not active_session:
if not active_task:
return Response(status_code=204)
# Always replay from the beginning ("0-0") on resume.
# We can't use last_message_id because it's the latest ID in the backend
# stream, not the latest the frontend received — the gap causes lost
# messages. The frontend deduplicates replayed content.
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=active_task.task_id,
user_id=user_id,
last_message_id="0-0",
last_message_id="0-0", # Full replay so useChat rebuilds the message
)
if subscriber_queue is None:
@@ -664,12 +698,12 @@ async def resume_session_stream(
logger.error(f"Error in resume stream for session {session_id}: {e}")
finally:
try:
await stream_registry.unsubscribe_from_session(
session_id, subscriber_queue
await stream_registry.unsubscribe_from_task(
active_task.task_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
exc_info=True,
)
logger.info(
@@ -720,6 +754,229 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_message_id: str = Query(
default="0-0",
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
Raises:
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
"""
# Check task existence and expiry before subscribing
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
if error_code == "TASK_EXPIRED":
raise HTTPException(
status_code=410,
detail={
"code": "TASK_EXPIRED",
"message": "This operation has expired. Please try again.",
},
)
if error_code == "TASK_NOT_FOUND":
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found.",
},
)
# Validate ownership if task has an owner
if task and task.user_id and user_id != task.user_id:
raise HTTPException(
status_code=403,
detail={
"code": "ACCESS_DENIED",
"message": "You do not have access to this task.",
},
)
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
if subscriber_queue is None:
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found or access denied.",
},
)
async def event_generator() -> AsyncGenerator[str, None]:
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
try:
# Wait for next chunk with timeout for heartbeats
chunk = await asyncio.wait_for(
subscriber_queue.get(), timeout=heartbeat_interval
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
finally:
# Unsubscribe when client disconnects or stream ends
try:
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership - if task has an owner, requester must match
if task.user_id and user_id != task.user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key - reject if not configured or invalid
if not config.internal_api_key:
logger.error(
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
)
raise HTTPException(
status_code=503,
detail="Webhook not available: internal API key not configured",
)
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
await process_operation_success(task, request.result)
else:
await process_operation_failure(task, request.error)
return {"status": "ok", "task_id": task.task_id}
# ========== Configuration ==========
@@ -794,12 +1051,14 @@ ToolResponseUnion = (
| AgentPreviewResponse
| AgentSavedResponse
| ClarificationNeededResponse
| SuggestedGoalResponse
| BlockListResponse
| BlockDetailsResponse
| BlockOutputResponse
| DocSearchResultsResponse
| DocPageResponse
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
)

View File

@@ -0,0 +1,203 @@
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This module provides the adapter layer that converts streaming messages from
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
the frontend expects.
"""
import json
import logging
import uuid
from claude_agent_sdk import (
AssistantMessage,
Message,
ResultMessage,
SystemMessage,
TextBlock,
ToolResultBlock,
ToolUseBlock,
UserMessage,
)
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.api.features.chat.sdk.tool_adapter import (
MCP_TOOL_PREFIX,
pop_pending_tool_output,
)
logger = logging.getLogger(__name__)
class SDKResponseAdapter:
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
This class maintains state during a streaming session to properly track
text blocks, tool calls, and message lifecycle.
"""
def __init__(self, message_id: str | None = None):
self.message_id = message_id or str(uuid.uuid4())
self.text_block_id = str(uuid.uuid4())
self.has_started_text = False
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.task_id: str | None = None
self.step_open = False
def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
self.task_id = task_id
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
"""Convert a single SDK message to Vercel AI SDK format."""
responses: list[StreamBaseResponse] = []
if isinstance(sdk_message, SystemMessage):
if sdk_message.subtype == "init":
responses.append(
StreamStart(messageId=self.message_id, taskId=self.task_id)
)
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())
self.step_open = True
elif isinstance(sdk_message, AssistantMessage):
# After tool results, the SDK sends a new AssistantMessage for the
# next LLM turn. Open a new step if the previous one was closed.
if not self.step_open:
responses.append(StreamStartStep())
self.step_open = True
for block in sdk_message.content:
if isinstance(block, TextBlock):
if block.text:
self._ensure_text_started(responses)
responses.append(
StreamTextDelta(id=self.text_block_id, delta=block.text)
)
elif isinstance(block, ToolUseBlock):
self._end_text_if_open(responses)
# Strip MCP prefix so frontend sees "find_block"
# instead of "mcp__copilot__find_block".
tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
responses.append(
StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
)
responses.append(
StreamToolInputAvailable(
toolCallId=block.id,
toolName=tool_name,
input=block.input,
)
)
self.current_tool_calls[block.id] = {"name": tool_name}
elif isinstance(sdk_message, UserMessage):
# UserMessage carries tool results back from tool execution.
content = sdk_message.content
blocks = content if isinstance(content, list) else []
for block in blocks:
if isinstance(block, ToolResultBlock) and block.tool_use_id:
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
tool_name = tool_info.get("name", "unknown")
# Prefer the stashed full output over the SDK's
# (potentially truncated) ToolResultBlock content.
# The SDK truncates large results, writing them to disk,
# which breaks frontend widget parsing.
output = pop_pending_tool_output(tool_name) or (
_extract_tool_output(block.content)
)
responses.append(
StreamToolOutputAvailable(
toolCallId=block.tool_use_id,
toolName=tool_name,
output=output,
success=not (block.is_error or False),
)
)
# Close the current step after tool results — the next
# AssistantMessage will open a new step for the continuation.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
elif isinstance(sdk_message, ResultMessage):
self._end_text_if_open(responses)
# Close the step before finishing.
if self.step_open:
responses.append(StreamFinishStep())
self.step_open = False
if sdk_message.subtype == "success":
responses.append(StreamFinish())
elif sdk_message.subtype in ("error", "error_during_execution"):
error_msg = getattr(sdk_message, "result", None) or "Unknown error"
responses.append(
StreamError(errorText=str(error_msg), code="sdk_error")
)
responses.append(StreamFinish())
else:
logger.warning(
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
)
responses.append(StreamFinish())
else:
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
return responses
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
"""Start (or restart) a text block if needed."""
if not self.has_started_text or self.has_ended_text:
if self.has_ended_text:
self.text_block_id = str(uuid.uuid4())
self.has_ended_text = False
responses.append(StreamTextStart(id=self.text_block_id))
self.has_started_text = True
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
"""End the current text block if one is open."""
if self.has_started_text and not self.has_ended_text:
responses.append(StreamTextEnd(id=self.text_block_id))
self.has_ended_text = True
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
"""Extract a string output from a ToolResultBlock's content field."""
if isinstance(content, str):
return content
if isinstance(content, list):
parts = [item.get("text", "") for item in content if item.get("type") == "text"]
if parts:
return "".join(parts)
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)
if content is None:
return ""
try:
return json.dumps(content)
except (TypeError, ValueError):
return str(content)

View File

@@ -1,8 +1,5 @@
"""Unit tests for the SDK response adapter."""
import asyncio
import pytest
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
@@ -13,7 +10,7 @@ from claude_agent_sdk import (
UserMessage,
)
from backend.copilot.response_model import (
from backend.api.features.chat.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
@@ -30,14 +27,12 @@ from backend.copilot.response_model import (
from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX
from .tool_adapter import _pending_tool_outputs as _pto
from .tool_adapter import _stash_event
from .tool_adapter import stash_pending_tool_output as _stash
from .tool_adapter import wait_for_stash
def _adapter() -> SDKResponseAdapter:
return SDKResponseAdapter(message_id="msg-1", session_id="session-1")
a = SDKResponseAdapter(message_id="msg-1")
a.set_task_id("task-1")
return a
# -- SystemMessage -----------------------------------------------------------
@@ -49,7 +44,7 @@ def test_system_init_emits_start_and_step():
assert len(results) == 2
assert isinstance(results[0], StreamStart)
assert results[0].messageId == "msg-1"
assert results[0].sessionId == "session-1"
assert results[0].taskId == "task-1"
assert isinstance(results[1], StreamStartStep)
@@ -369,310 +364,3 @@ def test_full_conversation_flow():
"StreamFinishStep", # step 2 closed
"StreamFinish",
]
# -- Flush unresolved tool calls --------------------------------------------
def test_flush_unresolved_at_result_message():
"""Built-in tools (WebSearch) without UserMessage results get flushed at ResultMessage."""
adapter = _adapter()
all_responses: list[StreamBaseResponse] = []
# 1. Init
all_responses.extend(
adapter.convert_message(SystemMessage(subtype="init", data={}))
)
# 2. Tool use (built-in tool — no MCP prefix)
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
],
model="test",
)
)
)
# 3. No UserMessage for this tool — go straight to ResultMessage
all_responses.extend(
adapter.convert_message(
ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
)
)
types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamStartStep",
"StreamToolInputStart",
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # flushed with empty output
"StreamFinishStep", # step closed by flush
"StreamFinish",
]
# The flushed output should be empty (no stash available)
output_event = [
r for r in all_responses if isinstance(r, StreamToolOutputAvailable)
][0]
assert output_event.toolCallId == "ws-1"
assert output_event.toolName == "WebSearch"
assert output_event.output == ""
def test_flush_unresolved_at_next_assistant_message():
"""Built-in tools get flushed when the next AssistantMessage arrives."""
adapter = _adapter()
all_responses: list[StreamBaseResponse] = []
# 1. Init
all_responses.extend(
adapter.convert_message(SystemMessage(subtype="init", data={}))
)
# 2. Tool use (built-in — no UserMessage will come)
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
],
model="test",
)
)
)
# 3. Next AssistantMessage triggers flush before processing its blocks
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[TextBlock(text="Here are the results")], model="test"
)
)
)
types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamStartStep",
"StreamToolInputStart",
"StreamToolInputAvailable",
# Flush at next AssistantMessage:
"StreamToolOutputAvailable",
"StreamFinishStep", # step closed by flush
# New step for continuation text:
"StreamStartStep",
"StreamTextStart",
"StreamTextDelta",
]
def test_flush_with_stashed_output():
"""Stashed output from PostToolUse hook is used when flushing."""
adapter = _adapter()
# Simulate PostToolUse hook stashing output
_pto.set({})
_stash("WebSearch", "Search result: 5 items found")
all_responses: list[StreamBaseResponse] = []
# Tool use
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
],
model="test",
)
)
)
# ResultMessage triggers flush
all_responses.extend(
adapter.convert_message(
ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
)
)
output_events = [
r for r in all_responses if isinstance(r, StreamToolOutputAvailable)
]
assert len(output_events) == 1
assert output_events[0].output == "Search result: 5 items found"
# Cleanup
_pto.set({}) # type: ignore[arg-type]
# -- wait_for_stash synchronisation tests --
@pytest.mark.asyncio
async def test_wait_for_stash_signaled():
"""wait_for_stash returns True when stash_pending_tool_output signals."""
_pto.set({})
event = asyncio.Event()
_stash_event.set(event)
# Simulate a PostToolUse hook that stashes output after a short delay
async def delayed_stash():
await asyncio.sleep(0.01)
_stash("WebSearch", "result data")
asyncio.create_task(delayed_stash())
result = await wait_for_stash(timeout=1.0)
assert result is True
assert _pto.get({}).get("WebSearch") == ["result data"]
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)
@pytest.mark.asyncio
async def test_wait_for_stash_timeout():
"""wait_for_stash returns False on timeout when no stash occurs."""
_pto.set({})
event = asyncio.Event()
_stash_event.set(event)
result = await wait_for_stash(timeout=0.05)
assert result is False
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)
@pytest.mark.asyncio
async def test_wait_for_stash_already_stashed():
"""wait_for_stash picks up a stash that happened just before the wait."""
_pto.set({})
event = asyncio.Event()
_stash_event.set(event)
# Stash before waiting — simulates hook completing before message arrives
_stash("Read", "file contents")
# Event is now set; wait_for_stash detects the fast path and returns
# immediately without timing out.
result = await wait_for_stash(timeout=0.05)
assert result is True
# But the stash itself is populated
assert _pto.get({}).get("Read") == ["file contents"]
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)
# -- Parallel tool call tests --
def test_parallel_tool_calls_not_flushed_prematurely():
"""Parallel tool calls should NOT be flushed when the next AssistantMessage
only contains ToolUseBlocks (parallel continuation)."""
adapter = SDKResponseAdapter()
# Init
adapter.convert_message(SystemMessage(subtype="init", data={}))
# First AssistantMessage: tool call #1
msg1 = AssistantMessage(
content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
model="test",
)
r1 = adapter.convert_message(msg1)
assert any(isinstance(r, StreamToolInputAvailable) for r in r1)
assert adapter.has_unresolved_tool_calls
# Second AssistantMessage: tool call #2 (parallel continuation)
msg2 = AssistantMessage(
content=[ToolUseBlock(id="t2", name="WebSearch", input={"q": "bar"})],
model="test",
)
r2 = adapter.convert_message(msg2)
# No flush should have happened — t1 should NOT have StreamToolOutputAvailable
output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
assert len(output_events) == 0, (
f"Tool-only AssistantMessage should not flush prior tools, "
f"but got {len(output_events)} output events"
)
# Both t1 and t2 should still be unresolved
assert "t1" not in adapter.resolved_tool_calls
assert "t2" not in adapter.resolved_tool_calls
def test_text_assistant_message_flushes_prior_tools():
"""An AssistantMessage with text (new turn) should flush unresolved tools."""
adapter = SDKResponseAdapter()
# Init
adapter.convert_message(SystemMessage(subtype="init", data={}))
# Tool call
msg1 = AssistantMessage(
content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
model="test",
)
adapter.convert_message(msg1)
assert adapter.has_unresolved_tool_calls
# Text AssistantMessage (new turn after tools completed)
msg2 = AssistantMessage(
content=[TextBlock(text="Here are the results")],
model="test",
)
r2 = adapter.convert_message(msg2)
# Flush SHOULD have happened — t1 gets empty output
output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
assert len(output_events) == 1
assert output_events[0].toolCallId == "t1"
assert "t1" in adapter.resolved_tool_calls
def test_already_resolved_tool_skipped_in_user_message():
"""A tool result in UserMessage should be skipped if already resolved by flush."""
adapter = SDKResponseAdapter()
adapter.convert_message(SystemMessage(subtype="init", data={}))
# Tool call + flush via text message
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name="WebSearch", input={})],
model="test",
)
)
adapter.convert_message(
AssistantMessage(
content=[TextBlock(text="Done")],
model="test",
)
)
assert "t1" in adapter.resolved_tool_calls
# Now UserMessage arrives with the real result — should be skipped
user_msg = UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="real")])
r = adapter.convert_message(user_msg)
output_events = [r_ for r_ in r if isinstance(r_, StreamToolOutputAvailable)]
assert (
len(output_events) == 0
), "Already-resolved tool should not emit duplicate output"

View File

@@ -11,12 +11,11 @@ import re
from collections.abc import Callable
from typing import Any, cast
from .tool_adapter import (
from backend.api.features.chat.sdk.tool_adapter import (
BLOCKED_TOOLS,
DANGEROUS_PATTERNS,
MCP_TOOL_PREFIX,
WORKSPACE_SCOPED_TOOLS,
stash_pending_tool_output,
)
logger = logging.getLogger(__name__)
@@ -124,20 +123,20 @@ def _validate_user_isolation(
"""Validate that tool calls respect user isolation."""
# For workspace file tools, ensure path doesn't escape
if "workspace" in tool_name.lower():
# The "path" param is a cloud storage key (e.g. "/ASEAN/report.md")
# where a leading "/" is normal. Only check for ".." traversal.
# Filesystem paths (source_path, save_to_path) are validated inside
# the tool itself via _validate_ephemeral_path.
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path and ".." in path:
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": "Path traversal not allowed",
if path:
# Check for path traversal
if ".." in path or path.startswith("/"):
logger.warning(
f"Blocked path traversal attempt: {path} by user {user_id}"
)
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": "Path traversal not allowed",
}
}
}
return {}
@@ -160,7 +159,7 @@ def create_security_hooks(
Args:
user_id: Current user ID for isolation validation
sdk_cwd: SDK working directory for workspace-scoped tool validation
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
max_subtasks: Maximum Task (sub-agent) spawns allowed per session
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
the SDK finishes processing used to read the JSONL transcript
before the CLI process exits.
@@ -172,9 +171,8 @@ def create_security_hooks(
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
# Per-session tracking for Task sub-agent concurrency.
# Set of tool_use_ids that consumed a slot — len() is the active count.
task_tool_use_ids: set[str] = set()
# Per-session counter for Task sub-agent spawns
task_spawn_count = 0
async def pre_tool_use_hook(
input_data: HookInput,
@@ -182,34 +180,23 @@ def create_security_hooks(
context: HookContext,
) -> SyncHookJSONOutput:
"""Combined pre-tool-use validation hook."""
nonlocal task_spawn_count
_ = context # unused but required by signature
tool_name = cast(str, input_data.get("tool_name", ""))
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
# Rate-limit Task (sub-agent) spawns per session
if tool_name == "Task":
# Block background task execution first — denied calls
# should not consume a subtask slot.
if tool_input.get("run_in_background"):
logger.info(f"[SDK] Blocked background Task, user={user_id}")
return cast(
SyncHookJSONOutput,
_deny(
"Background task execution is not supported. "
"Run tasks in the foreground instead "
"(remove the run_in_background parameter)."
),
)
if len(task_tool_use_ids) >= max_subtasks:
task_spawn_count += 1
if task_spawn_count > max_subtasks:
logger.warning(
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
)
return cast(
SyncHookJSONOutput,
_deny(
f"Maximum {max_subtasks} concurrent sub-tasks. "
"Wait for running sub-tasks to finish, "
"or continue in the main conversation."
f"Maximum {max_subtasks} sub-tasks per session. "
"Please continue in the main conversation."
),
)
@@ -229,68 +216,18 @@ def create_security_hooks(
if result:
return cast(SyncHookJSONOutput, result)
# Reserve the Task slot only after all validations pass
if tool_name == "Task" and tool_use_id is not None:
task_tool_use_ids.add(tool_use_id)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
"""Release a Task concurrency slot if one was reserved."""
if tool_name == "Task" and tool_use_id in task_tool_use_ids:
task_tool_use_ids.discard(tool_use_id)
logger.info(
"[SDK] Task slot released, active=%d/%d, user=%s",
len(task_tool_use_ids),
max_subtasks,
user_id,
)
async def post_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
context: HookContext,
) -> SyncHookJSONOutput:
"""Log successful tool executions and stash SDK built-in tool outputs.
MCP tools stash their output in ``_execute_tool_sync`` before the
SDK can truncate it. SDK built-in tools (WebSearch, Read, etc.)
are executed by the CLI internally this hook captures their
output so the response adapter can forward it to the frontend.
"""
"""Log successful tool executions for observability."""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
_release_task_slot(tool_name, tool_use_id)
is_builtin = not tool_name.startswith(MCP_TOOL_PREFIX)
logger.info(
"[SDK] PostToolUse: %s (builtin=%s, tool_use_id=%s)",
tool_name,
is_builtin,
(tool_use_id or "")[:12],
)
# Stash output for SDK built-in tools so the response adapter can
# emit StreamToolOutputAvailable even when the CLI doesn't surface
# a separate UserMessage with ToolResultBlock content.
if is_builtin:
tool_response = input_data.get("tool_response")
if tool_response is not None:
resp_preview = str(tool_response)[:100]
logger.info(
"[SDK] Stashing builtin output for %s (%d chars): %s...",
tool_name,
len(str(tool_response)),
resp_preview,
)
stash_pending_tool_output(tool_name, tool_response)
else:
logger.warning(
"[SDK] PostToolUse for builtin %s but tool_response is None",
tool_name,
)
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
return cast(SyncHookJSONOutput, {})
async def post_tool_failure_hook(
@@ -306,9 +243,6 @@ def create_security_hooks(
f"[SDK] Tool failed: {tool_name}, error={error}, "
f"user={user_id}, tool_use_id={tool_use_id}"
)
_release_task_slot(tool_name, tool_use_id)
return cast(SyncHookJSONOutput, {})
async def pre_compact_hook(

View File

@@ -0,0 +1,165 @@
"""Unit tests for SDK security hooks."""
import os
from .security_hooks import _validate_tool_access, _validate_user_isolation
SDK_CWD = "/tmp/copilot-abc123"
def _is_denied(result: dict) -> bool:
hook = result.get("hookSpecificOutput", {})
return hook.get("permissionDecision") == "deny"
# -- Blocked tools -----------------------------------------------------------
def test_blocked_tools_denied():
for tool in ("bash", "shell", "exec", "terminal", "command"):
result = _validate_tool_access(tool, {})
assert _is_denied(result), f"{tool} should be blocked"
def test_unknown_tool_allowed():
result = _validate_tool_access("SomeCustomTool", {})
assert result == {}
# -- Workspace-scoped tools --------------------------------------------------
def test_read_within_workspace_allowed():
result = _validate_tool_access(
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_write_within_workspace_allowed():
result = _validate_tool_access(
"Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_edit_within_workspace_allowed():
result = _validate_tool_access(
"Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
)
assert result == {}
def test_glob_within_workspace_allowed():
result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
assert result == {}
def test_grep_within_workspace_allowed():
result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_outside_workspace_denied():
result = _validate_tool_access(
"Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_write_outside_workspace_denied():
result = _validate_tool_access(
"Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
)
assert _is_denied(result)
def test_traversal_attack_denied():
result = _validate_tool_access(
"Read",
{"file_path": f"{SDK_CWD}/../../etc/passwd"},
sdk_cwd=SDK_CWD,
)
assert _is_denied(result)
def test_no_path_allowed():
"""Glob/Grep without a path argument defaults to cwd — should pass."""
result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_no_cwd_denies_absolute():
"""If no sdk_cwd is set, absolute paths are denied."""
result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
assert _is_denied(result)
# -- Tool-results directory --------------------------------------------------
def test_read_tool_results_allowed():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_claude_projects_without_tool_results_denied():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
def test_bash_builtin_always_blocked():
"""SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead."""
result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Dangerous patterns ------------------------------------------------------
def test_dangerous_pattern_blocked():
result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
assert _is_denied(result)
def test_subprocess_pattern_blocked():
result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
assert _is_denied(result)
# -- User isolation ----------------------------------------------------------
def test_workspace_path_traversal_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
)
assert _is_denied(result)
def test_workspace_absolute_path_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
)
assert _is_denied(result)
def test_workspace_normal_path_allowed():
result = _validate_user_isolation(
"workspace_read", {"path": "src/main.py"}, user_id="user-1"
)
assert result == {}
def test_non_workspace_tool_passes_isolation():
result = _validate_user_isolation(
"find_agent", {"query": "email"}, user_id="user-1"
)
assert result == {}

View File

@@ -0,0 +1,752 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""
import asyncio
import json
import logging
import os
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any
from backend.util.exceptions import NotFoundError
from .. import stream_registry
from ..config import ChatConfig
from ..model import (
ChatMessage,
ChatSession,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from ..response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamStart,
StreamTextDelta,
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from ..service import (
_build_system_prompt,
_execute_long_running_tool_with_streaming,
_generate_session_title,
)
from ..tools.models import OperationPendingResponse, OperationStartedResponse
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
COPILOT_TOOL_NAMES,
SDK_DISALLOWED_TOOLS,
LongRunningCallback,
create_copilot_mcp_server,
set_execution_context,
)
from .transcript import (
download_transcript,
read_transcript_file,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
)
logger = logging.getLogger(__name__)
config = ChatConfig()
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
@dataclass
class CapturedTranscript:
"""Info captured by the SDK Stop hook for stateless --resume."""
path: str = ""
sdk_session_id: str = ""
@property
def available(self) -> bool:
return bool(self.path)
_SDK_CWD_PREFIX = WORKSPACE_PREFIX
# Appended to the system prompt to inform the agent about available tools.
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
# which has kernel-level network isolation (unshare --net).
_SDK_TOOL_SUPPLEMENT = """
## Tool notes
- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
for shell commands — it runs in a network-isolated sandbox.
- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the
same working directory. Files created by one are readable by the other.
These files are **ephemeral** — they exist only for the current session.
- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file`
for files that should persist across sessions (stored in cloud storage).
- Long-running tools (create_agent, edit_agent, etc.) are handled
asynchronously. You will receive an immediate response; the actual result
is delivered to the user via a background stream.
"""
def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
"""Build a callback that delegates long-running tools to the non-SDK infrastructure.
Long-running tools (create_agent, edit_agent, etc.) are delegated to the
existing background infrastructure: stream_registry (Redis Streams),
database persistence, and SSE reconnection. This means results survive
page refreshes / pod restarts, and the frontend shows the proper loading
widget with progress updates.
The returned callback matches the ``LongRunningCallback`` signature:
``(tool_name, args, session) -> MCP response dict``.
"""
async def _callback(
tool_name: str, args: dict[str, Any], session: ChatSession
) -> dict[str, Any]:
operation_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}"
session_id = session.session_id
# --- Build user-friendly messages (matches non-SDK service) ---
if tool_name == "create_agent":
desc = args.get("description", "")
desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc
pending_msg = (
f"Creating your agent: {desc_preview}"
if desc_preview
else "Creating agent... This may take a few minutes."
)
started_msg = (
"Agent creation started. You can close this tab - "
"check your library in a few minutes."
)
elif tool_name == "edit_agent":
changes = args.get("changes", "")
changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
pending_msg = (
f"Editing agent: {changes_preview}"
if changes_preview
else "Editing agent... This may take a few minutes."
)
started_msg = (
"Agent edit started. You can close this tab - "
"check your library in a few minutes."
)
else:
pending_msg = f"Running {tool_name}... This may take a few minutes."
started_msg = (
f"{tool_name} started. You can close this tab - "
"check back in a few minutes."
)
# --- Register task in Redis for SSE reconnection ---
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id=tool_call_id,
tool_name=tool_name,
operation_id=operation_id,
)
# --- Save OperationPendingResponse to chat history ---
pending_message = ChatMessage(
role="tool",
content=OperationPendingResponse(
message=pending_msg,
operation_id=operation_id,
tool_name=tool_name,
).model_dump_json(),
tool_call_id=tool_call_id,
)
session.messages.append(pending_message)
await upsert_chat_session(session)
# --- Spawn background task (reuses non-SDK infrastructure) ---
bg_task = asyncio.create_task(
_execute_long_running_tool_with_streaming(
tool_name=tool_name,
parameters=args,
tool_call_id=tool_call_id,
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
user_id=user_id,
)
)
_background_tasks.add(bg_task)
bg_task.add_done_callback(_background_tasks.discard)
await stream_registry.set_task_asyncio_task(task_id, bg_task)
logger.info(
f"[SDK] Long-running tool {tool_name} delegated to background "
f"(operation_id={operation_id}, task_id={task_id})"
)
# --- Return OperationStartedResponse as MCP tool result ---
# This flows through SDK → response adapter → frontend, triggering
# the loading widget with SSE reconnection support.
started_json = OperationStartedResponse(
message=started_msg,
operation_id=operation_id,
tool_name=tool_name,
task_id=task_id,
).model_dump_json()
return {
"content": [{"type": "text", "text": started_json}],
"isError": False,
}
return _callback
def _resolve_sdk_model() -> str | None:
"""Resolve the model name for the Claude Agent SDK CLI.
Uses ``config.claude_agent_model`` if set, otherwise derives from
``config.model`` by stripping the OpenRouter provider prefix (e.g.,
``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
"""
if config.claude_agent_model:
return config.claude_agent_model
model = config.model
if "/" in model:
return model.split("/", 1)[1]
return model
def _build_sdk_env() -> dict[str, str]:
"""Build env vars for the SDK CLI process.
Routes API calls through OpenRouter (or a custom base_url) using
the same ``config.api_key`` / ``config.base_url`` as the non-SDK path.
This gives per-call token and cost tracking on the OpenRouter dashboard.
Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth
token are both present — otherwise returns an empty dict so the SDK
falls back to its default credentials.
"""
env: dict[str, str] = {}
if config.api_key and config.base_url:
# Strip /v1 suffix — SDK expects the base URL without a version path
base = config.base_url.rstrip("/")
if base.endswith("/v1"):
base = base[:-3]
if not base or not base.startswith("http"):
# Invalid base_url — don't override SDK defaults
return env
env["ANTHROPIC_BASE_URL"] = base
env["ANTHROPIC_AUTH_TOKEN"] = config.api_key
# Must be explicitly empty so the CLI uses AUTH_TOKEN instead
env["ANTHROPIC_API_KEY"] = ""
return env
def _make_sdk_cwd(session_id: str) -> str:
"""Create a safe, session-specific working directory path.
Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path`
(single source of truth for path sanitization) and adds a defence-in-depth
assertion.
"""
cwd = make_session_path(session_id)
# Defence-in-depth: normpath + startswith is a CodeQL-recognised sanitizer
cwd = os.path.normpath(cwd)
if not cwd.startswith(_SDK_CWD_PREFIX):
raise ValueError(f"SDK cwd escaped prefix: {cwd}")
return cwd
def _cleanup_sdk_tool_results(cwd: str) -> None:
"""Remove SDK tool-result files for a specific session working directory.
The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
We clean only the specific cwd's results to avoid race conditions between
concurrent sessions.
Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
"""
import shutil
# Validate cwd is under the expected prefix
normalized = os.path.normpath(cwd)
if not normalized.startswith(_SDK_CWD_PREFIX):
logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
return
# SDK encodes the cwd path by replacing '/' with '-'
encoded_cwd = normalized.replace("/", "-")
# Construct the project directory path (known-safe home expansion)
claude_projects = os.path.expanduser("~/.claude/projects")
project_dir = os.path.join(claude_projects, encoded_cwd)
# Security check 3: Validate project_dir is under ~/.claude/projects
project_dir = os.path.normpath(project_dir)
if not project_dir.startswith(claude_projects):
logger.warning(
f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
)
return
results_dir = os.path.join(project_dir, "tool-results")
if os.path.isdir(results_dir):
for filename in os.listdir(results_dir):
file_path = os.path.join(results_dir, filename)
try:
if os.path.isfile(file_path):
os.remove(file_path)
except OSError:
pass
# Also clean up the temp cwd directory itself
try:
shutil.rmtree(normalized, ignore_errors=True)
except OSError:
pass
async def _compress_conversation_history(
session: ChatSession,
) -> list[ChatMessage]:
"""Compress prior conversation messages if they exceed the token threshold.
Uses the shared compress_context() from prompt.py which supports:
- LLM summarization of old messages (keeps recent ones intact)
- Progressive content truncation as fallback
- Middle-out deletion as last resort
Returns the compressed prior messages (everything except the current message).
"""
prior = session.messages[:-1]
if len(prior) < 2:
return prior
from backend.util.prompt import compress_context
# Convert ChatMessages to dicts for compress_context
messages_dict = []
for msg in prior:
msg_dict: dict[str, Any] = {"role": msg.role}
if msg.content:
msg_dict["content"] = msg.content
if msg.tool_calls:
msg_dict["tool_calls"] = msg.tool_calls
if msg.tool_call_id:
msg_dict["tool_call_id"] = msg.tool_call_id
messages_dict.append(msg_dict)
try:
import openai
async with openai.AsyncOpenAI(
api_key=config.api_key, base_url=config.base_url, timeout=30.0
) as client:
result = await compress_context(
messages=messages_dict,
model=config.model,
client=client,
)
except Exception as e:
logger.warning(f"[SDK] Context compression with LLM failed: {e}")
# Fall back to truncation-only (no LLM summarization)
result = await compress_context(
messages=messages_dict,
model=config.model,
client=None,
)
if result.was_compacted:
logger.info(
f"[SDK] Context compacted: {result.original_token_count} -> "
f"{result.token_count} tokens "
f"({result.messages_summarized} summarized, "
f"{result.messages_dropped} dropped)"
)
# Convert compressed dicts back to ChatMessages
return [
ChatMessage(
role=m["role"],
content=m.get("content"),
tool_calls=m.get("tool_calls"),
tool_call_id=m.get("tool_call_id"),
)
for m in result.messages
]
return prior
def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
"""Format conversation messages into a context prefix for the user message.
Returns a string like:
<conversation_history>
User: hello
You responded: Hi! How can I help?
</conversation_history>
Returns None if there are no messages to format.
"""
if not messages:
return None
lines: list[str] = []
for msg in messages:
if not msg.content:
continue
if msg.role == "user":
lines.append(f"User: {msg.content}")
elif msg.role == "assistant":
lines.append(f"You responded: {msg.content}")
# Skip tool messages — they're internal details
if not lines:
return None
return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"
async def stream_chat_completion_sdk(
session_id: str,
message: str | None = None,
tool_call_response: str | None = None, # noqa: ARG001
is_user_message: bool = True,
user_id: str | None = None,
retry_count: int = 0, # noqa: ARG001
session: ChatSession | None = None,
context: dict[str, str] | None = None, # noqa: ARG001
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream chat completion using Claude Agent SDK.
Drop-in replacement for stream_chat_completion with improved reliability.
"""
if session is None:
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(
f"Session {session_id} not found. Please create a new session first."
)
if message:
session.messages.append(
ChatMessage(
role="user" if is_user_message else "assistant", content=message
)
)
if is_user_message:
track_user_message(
user_id=user_id, session_id=session_id, message_length=len(message)
)
session = await upsert_chat_session(session)
# Generate title for new sessions (first user message)
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
if first_message:
task = asyncio.create_task(
_update_title_async(session_id, first_message, user_id)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# Build system prompt (reuses non-SDK path with Langfuse support)
has_history = len(session.messages) > 1
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=has_history
)
system_prompt += _SDK_TOOL_SUPPLEMENT
message_id = str(uuid.uuid4())
task_id = str(uuid.uuid4())
yield StreamStart(messageId=message_id, taskId=task_id)
stream_completed = False
# Initialise sdk_cwd before the try so the finally can reference it
# even if _make_sdk_cwd raises (in that case it stays as "").
sdk_cwd = ""
use_resume = False
try:
# Use a session-specific temp dir to avoid cleanup race conditions
# between concurrent sessions.
sdk_cwd = _make_sdk_cwd(session_id)
os.makedirs(sdk_cwd, exist_ok=True)
set_execution_context(
user_id,
session,
long_running_callback=_build_long_running_callback(user_id),
)
try:
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
# Fail fast when no API credentials are available at all
sdk_env = _build_sdk_env()
if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
raise RuntimeError(
"No API key configured. Set OPEN_ROUTER_API_KEY "
"(or CHAT_API_KEY) for OpenRouter routing, "
"or ANTHROPIC_API_KEY for direct Anthropic access."
)
mcp_server = create_copilot_mcp_server()
sdk_model = _resolve_sdk_model()
# --- Transcript capture via Stop hook ---
captured_transcript = CapturedTranscript()
def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
captured_transcript.path = transcript_path
captured_transcript.sdk_session_id = sdk_session_id
security_hooks = create_security_hooks(
user_id,
sdk_cwd=sdk_cwd,
max_subtasks=config.claude_agent_max_subtasks,
on_stop=_on_stop if config.claude_agent_use_resume else None,
)
# --- Resume strategy: download transcript from bucket ---
resume_file: str | None = None
use_resume = False
if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
transcript_content = await download_transcript(user_id, session_id)
if transcript_content and validate_transcript(transcript_content):
resume_file = write_transcript_to_tempfile(
transcript_content, session_id, sdk_cwd
)
if resume_file:
use_resume = True
logger.info(
f"[SDK] Using --resume with transcript "
f"({len(transcript_content)} bytes)"
)
sdk_options_kwargs: dict[str, Any] = {
"system_prompt": system_prompt,
"mcp_servers": {"copilot": mcp_server},
"allowed_tools": COPILOT_TOOL_NAMES,
"disallowed_tools": SDK_DISALLOWED_TOOLS,
"hooks": security_hooks,
"cwd": sdk_cwd,
"max_buffer_size": config.claude_agent_max_buffer_size,
}
if sdk_env:
sdk_options_kwargs["model"] = sdk_model
sdk_options_kwargs["env"] = sdk_env
if use_resume and resume_file:
sdk_options_kwargs["resume"] = resume_file
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type]
adapter = SDKResponseAdapter(message_id=message_id)
adapter.set_task_id(task_id)
async with ClaudeSDKClient(options=options) as client:
current_message = message or ""
if not current_message and session.messages:
last_user = [m for m in session.messages if m.role == "user"]
if last_user:
current_message = last_user[-1].content or ""
if not current_message.strip():
yield StreamError(
errorText="Message cannot be empty.",
code="empty_prompt",
)
yield StreamFinish()
return
# Build query: with --resume the CLI already has full
# context, so we only send the new message. Without
# resume, compress history into a context prefix.
query_message = current_message
if not use_resume and len(session.messages) > 1:
logger.warning(
f"[SDK] Using compression fallback for session "
f"{session_id} ({len(session.messages)} messages) — "
f"no transcript available for --resume"
)
compressed = await _compress_conversation_history(session)
history_context = _format_conversation_context(compressed)
if history_context:
query_message = (
f"{history_context}\n\n"
f"Now, the user says:\n{current_message}"
)
logger.info(
f"[SDK] Sending query ({len(session.messages)} msgs in session)"
)
logger.debug(f"[SDK] Query preview: {current_message[:80]!r}")
await client.query(query_message, session_id=session_id)
assistant_response = ChatMessage(role="assistant", content="")
accumulated_tool_calls: list[dict[str, Any]] = []
has_appended_assistant = False
has_tool_results = False
async for sdk_msg in client.receive_messages():
logger.debug(
f"[SDK] Received: {type(sdk_msg).__name__} "
f"{getattr(sdk_msg, 'subtype', '')}"
)
for response in adapter.convert_message(sdk_msg):
if isinstance(response, StreamStart):
continue
yield response
if isinstance(response, StreamTextDelta):
delta = response.delta or ""
# After tool results, start a new assistant
# message for the post-tool text.
if has_tool_results and has_appended_assistant:
assistant_response = ChatMessage(
role="assistant", content=delta
)
accumulated_tool_calls = []
has_appended_assistant = False
has_tool_results = False
session.messages.append(assistant_response)
has_appended_assistant = True
else:
assistant_response.content = (
assistant_response.content or ""
) + delta
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolInputAvailable):
accumulated_tool_calls.append(
{
"id": response.toolCallId,
"type": "function",
"function": {
"name": response.toolName,
"arguments": json.dumps(response.input or {}),
},
}
)
assistant_response.tool_calls = accumulated_tool_calls
if not has_appended_assistant:
session.messages.append(assistant_response)
has_appended_assistant = True
elif isinstance(response, StreamToolOutputAvailable):
session.messages.append(
ChatMessage(
role="tool",
content=(
response.output
if isinstance(response.output, str)
else str(response.output)
),
tool_call_id=response.toolCallId,
)
)
has_tool_results = True
elif isinstance(response, StreamFinish):
stream_completed = True
if stream_completed:
break
if (
assistant_response.content or assistant_response.tool_calls
) and not has_appended_assistant:
session.messages.append(assistant_response)
# --- Capture transcript while CLI is still alive ---
# Must happen INSIDE async with: close() sends SIGTERM
# which kills the CLI before it can flush the JSONL.
if (
config.claude_agent_use_resume
and user_id
and captured_transcript.available
):
# Give CLI time to flush JSONL writes before we read
await asyncio.sleep(0.5)
raw_transcript = read_transcript_file(captured_transcript.path)
if raw_transcript:
task = asyncio.create_task(
_upload_transcript_bg(user_id, session_id, raw_transcript)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
else:
logger.debug("[SDK] Stop hook fired but transcript not usable")
except ImportError:
raise RuntimeError(
"claude-agent-sdk is not installed. "
"Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) "
"to use the OpenAI-compatible fallback."
)
await upsert_chat_session(session)
logger.debug(
f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
)
if not stream_completed:
yield StreamFinish()
except Exception as e:
logger.error(f"[SDK] Error: {e}", exc_info=True)
try:
await upsert_chat_session(session)
except Exception as save_err:
logger.error(f"[SDK] Failed to save session on error: {save_err}")
yield StreamError(
errorText="An error occurred. Please try again.",
code="sdk_error",
)
yield StreamFinish()
finally:
if sdk_cwd:
_cleanup_sdk_tool_results(sdk_cwd)
async def _upload_transcript_bg(
user_id: str, session_id: str, raw_content: str
) -> None:
"""Background task to strip progress entries and upload transcript."""
try:
await upload_transcript(user_id, session_id, raw_content)
except Exception as e:
logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}")
async def _update_title_async(
session_id: str, message: str, user_id: str | None = None
) -> None:
"""Background task to update session title."""
try:
title = await _generate_session_title(
message, user_id=user_id, session_id=session_id
)
if title:
await update_session_title(session_id, title)
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
except Exception as e:
logger.warning(f"[SDK] Failed to update session title: {e}")

View File

@@ -2,20 +2,25 @@
This module provides the adapter layer that converts existing BaseTool implementations
into in-process MCP tools that can be used with the Claude Agent SDK.
Long-running tools (``is_long_running=True``) are delegated to the non-SDK
background infrastructure (stream_registry, Redis persistence, SSE reconnection)
via a callback provided by the service layer. This avoids wasteful SDK polling
and makes results survive page refreshes.
"""
import asyncio
import itertools
import json
import logging
import os
import uuid
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from typing import Any
from backend.copilot.model import ChatSession
from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import TOOL_REGISTRY
from backend.api.features.chat.tools.base import BaseTool
logger = logging.getLogger(__name__)
@@ -36,23 +41,26 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
"pending_tool_outputs",
default=None, # type: ignore[arg-type]
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
"pending_tool_outputs", default=None # type: ignore[arg-type]
)
# Event signaled whenever stash_pending_tool_output() adds a new entry.
# Used by the streaming loop to wait for PostToolUse hooks to complete
# instead of sleeping an arbitrary duration. The SDK fires hooks via
# start_soon (fire-and-forget) so the next message can arrive before
# the hook stashes its output — this event bridges that gap.
_stash_event: ContextVar[asyncio.Event | None] = ContextVar(
"_stash_event", default=None
# Callback type for delegating long-running tools to the non-SDK infrastructure.
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
LongRunningCallback = Callable[
[str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]]
]
# ContextVar so the service layer can inject the callback per-request.
_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
"long_running_callback", default=None
)
def set_execution_context(
user_id: str | None,
session: ChatSession,
long_running_callback: LongRunningCallback | None = None,
) -> None:
"""Set the execution context for tool calls.
@@ -62,11 +70,13 @@ def set_execution_context(
Args:
user_id: Current user's ID.
session: Current chat session.
long_running_callback: Optional callback to delegate long-running tools
to the non-SDK background infrastructure (stream_registry + Redis).
"""
_current_user_id.set(user_id)
_current_session.set(session)
_pending_tool_outputs.set({})
_stash_event.set(asyncio.Event())
_long_running_callback.set(long_running_callback)
def get_execution_context() -> tuple[str | None, ChatSession | None]:
@@ -78,89 +88,19 @@ def get_execution_context() -> tuple[str | None, ChatSession | None]:
def pop_pending_tool_output(tool_name: str) -> str | None:
"""Pop and return the oldest stashed output for *tool_name*.
"""Pop and return the stashed full output for *tool_name*.
The SDK CLI may truncate large tool results (writing them to disk and
replacing the content with a file reference). This stash keeps the
original MCP output so the response adapter can forward it to the
frontend for proper widget rendering.
Uses a FIFO queue per tool name so duplicate calls to the same tool
in one turn each get their own output.
Returns ``None`` if nothing was stashed for *tool_name*.
"""
pending = _pending_tool_outputs.get(None)
if pending is None:
return None
queue = pending.get(tool_name)
if not queue:
pending.pop(tool_name, None)
return None
value = queue.pop(0)
if not queue:
del pending[tool_name]
return value
def stash_pending_tool_output(tool_name: str, output: Any) -> None:
"""Stash tool output for later retrieval by the response adapter.
Used by the PostToolUse hook to capture SDK built-in tool outputs
(WebSearch, Read, etc.) that aren't available through the MCP stash
mechanism in ``_execute_tool_sync``.
Appends to a FIFO queue per tool name so multiple calls to the same
tool in one turn are all preserved.
"""
pending = _pending_tool_outputs.get(None)
if pending is None:
return
if isinstance(output, str):
text = output
else:
try:
text = json.dumps(output)
except (TypeError, ValueError):
text = str(output)
pending.setdefault(tool_name, []).append(text)
# Signal any waiters that new output is available.
event = _stash_event.get(None)
if event is not None:
event.set()
async def wait_for_stash(timeout: float = 0.5) -> bool:
"""Wait for a PostToolUse hook to stash tool output.
The SDK fires PostToolUse hooks asynchronously via ``start_soon()``
the next message (AssistantMessage/ResultMessage) can arrive before the
hook completes and stashes its output. This function bridges that gap
by waiting on the ``_stash_event``, which is signaled by
:func:`stash_pending_tool_output`.
After the event fires, callers should ``await asyncio.sleep(0)`` to
give any remaining concurrent hooks a chance to complete.
Returns ``True`` if a stash signal was received, ``False`` on timeout.
The timeout is a safety net normally the stash happens within
microseconds of yielding to the event loop.
"""
event = _stash_event.get(None)
if event is None:
return False
# Fast path: hook already completed before we got here.
if event.is_set():
event.clear()
return True
# Slow path: wait for the hook to signal.
try:
async with asyncio.timeout(timeout):
await event.wait()
event.clear()
return True
except TimeoutError:
return False
return pending.pop(tool_name, None)
async def _execute_tool_sync(
@@ -185,63 +125,14 @@ async def _execute_tool_sync(
# Stash the full output before the SDK potentially truncates it.
pending = _pending_tool_outputs.get(None)
if pending is not None:
pending.setdefault(base_tool.name, []).append(text)
content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]
# If the tool result contains inline image data, add an MCP image block
# so Claude can "see" the image (e.g. read_workspace_file on a small PNG).
image_block = _extract_image_block(text)
if image_block:
content_blocks.append(image_block)
pending[base_tool.name] = text
return {
"content": content_blocks,
"content": [{"type": "text", "text": text}],
"isError": not result.success,
}
# MIME types that Claude can process as image content blocks.
_SUPPORTED_IMAGE_TYPES = frozenset(
{"image/png", "image/jpeg", "image/gif", "image/webp"}
)
def _extract_image_block(text: str) -> dict[str, str] | None:
"""Extract an MCP image content block from a tool result JSON string.
Detects workspace file responses with ``content_base64`` and an image
MIME type, returning an MCP-format image block that allows Claude to
"see" the image. Returns ``None`` if the result is not an inline image.
"""
try:
data = json.loads(text)
except (json.JSONDecodeError, TypeError):
return None
if not isinstance(data, dict):
return None
mime_type = data.get("mime_type", "")
base64_content = data.get("content_base64", "")
# Only inline small images — large ones would exceed Claude's limits.
# 32 KB raw ≈ ~43 KB base64.
_MAX_IMAGE_BASE64_BYTES = 43_000
if (
mime_type in _SUPPORTED_IMAGE_TYPES
and base64_content
and len(base64_content) <= _MAX_IMAGE_BASE64_BYTES
):
return {
"type": "image",
"data": base64_content,
"mimeType": mime_type,
}
return None
def _mcp_error(message: str) -> dict[str, Any]:
return {
"content": [
@@ -256,6 +147,11 @@ def create_tool_handler(base_tool: BaseTool):
This wraps the existing BaseTool._execute method to be compatible
with the Claude Agent SDK MCP tool format.
Long-running tools (``is_long_running=True``) are delegated to the
non-SDK background infrastructure via a callback set in the execution
context. The callback persists the operation in Redis (stream_registry)
so results survive page refreshes and pod restarts.
"""
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
@@ -265,6 +161,25 @@ def create_tool_handler(base_tool: BaseTool):
if session is None:
return _mcp_error("No session context available")
# --- Long-running: delegate to non-SDK background infrastructure ---
if base_tool.is_long_running:
callback = _long_running_callback.get(None)
if callback:
try:
return await callback(base_tool.name, args, session)
except Exception as e:
logger.error(
f"Long-running callback failed for {base_tool.name}: {e}",
exc_info=True,
)
return _mcp_error(f"Failed to start {base_tool.name}: {e}")
# No callback — fall through to synchronous execution
logger.warning(
f"[SDK] No long-running callback for {base_tool.name}, "
f"executing synchronously (may block)"
)
# --- Normal (fast) tool: execute synchronously ---
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:
@@ -396,29 +311,14 @@ def create_copilot_mcp_server():
# which provides kernel-level network isolation via unshare --net.
# Task allows spawning sub-agents (rate-limited by security hooks).
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
# TodoWrite manages the task checklist shown in the UI — no security concern.
_SDK_BUILTIN_TOOLS = [
"Read",
"Write",
"Edit",
"Glob",
"Grep",
"Task",
"WebSearch",
"TodoWrite",
]
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task", "WebSearch"]
# SDK built-in tools that must be explicitly blocked.
# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
# network isolation (unshare --net) instead.
# WebFetch: SSRF risk — can reach internal network (localhost, 10.x, etc.).
# Agent uses the SSRF-protected mcp__copilot__web_fetch tool instead.
# AskUserQuestion: interactive CLI tool — no terminal in copilot context.
SDK_DISALLOWED_TOOLS = [
"Bash",
"WebFetch",
"AskUserQuestion",
]
SDK_DISALLOWED_TOOLS = ["Bash", "WebFetch"]
# Tools that are blocked entirely in security hooks (defence-in-depth).
# Includes SDK_DISALLOWED_TOOLS plus common aliases/synonyms.

View File

@@ -14,8 +14,6 @@ import json
import logging
import os
import re
import time
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@@ -33,16 +31,6 @@ STRIPPABLE_TYPES = frozenset(
{"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)
@dataclass
class TranscriptDownload:
"""Result of downloading a transcript with its metadata."""
content: str
message_count: int = 0 # session.messages length when uploaded
uploaded_at: float = 0.0 # epoch timestamp of upload
# Workspace storage constants — deterministic path from session_id.
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"
@@ -131,21 +119,22 @@ def read_transcript_file(transcript_path: str) -> str | None:
content = f.read()
if not content.strip():
logger.debug("[Transcript] File is empty: %s", transcript_path)
logger.debug(f"[Transcript] Empty file: {transcript_path}")
return None
lines = content.strip().split("\n")
# Validate that the transcript has real conversation content
# (not just metadata like queue-operation entries).
if not validate_transcript(content):
if len(lines) < 3:
# Raw files with ≤2 lines are metadata-only
# (queue-operation + file-history-snapshot, no conversation).
logger.debug(
"[Transcript] No conversation content (%d lines) in %s",
len(lines),
transcript_path,
f"[Transcript] Too few lines ({len(lines)}): {transcript_path}"
)
return None
# Quick structural validation — parse first and last lines.
json.loads(lines[0])
json.loads(lines[-1])
logger.info(
f"[Transcript] Read {len(lines)} lines, "
f"{len(content)} bytes from {transcript_path}"
@@ -171,41 +160,6 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The CLI replaces all non-alphanumeric characters with ``-``.
"""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def cleanup_cli_project_dir(sdk_cwd: str) -> None:
"""Remove the CLI's project directory for a specific working directory.
The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
safe to remove entirely after the transcript has been uploaded.
"""
import shutil
cwd_encoded = _encode_cwd_for_cli(sdk_cwd)
config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))
if not project_dir.startswith(projects_base + os.sep):
logger.warning(
f"[Transcript] Cleanup path escaped projects base: {project_dir}"
)
return
if os.path.isdir(project_dir):
shutil.rmtree(project_dir, ignore_errors=True)
logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
else:
logger.debug(f"[Transcript] Project dir not found: {project_dir}")
def write_transcript_to_tempfile(
transcript_content: str,
session_id: str,
@@ -294,15 +248,6 @@ def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
)
def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
"""Return (workspace_id, file_id, filename) for a session's transcript metadata."""
return (
TRANSCRIPT_STORAGE_PREFIX,
_sanitize_id(user_id),
f"{_sanitize_id(session_id)}.meta.json",
)
def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
"""Build the full storage path string that ``retrieve()`` expects.
@@ -323,30 +268,21 @@ def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
return f"local://{wid}/{fid}/{fname}"
async def upload_transcript(
user_id: str,
session_id: str,
content: str,
message_count: int = 0,
) -> None:
async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
"""Strip progress entries and upload transcript to bucket storage.
Safety: only overwrites when the new (stripped) transcript is larger than
what is already stored. Since JSONL is append-only, the latest transcript
is always the longest. This prevents a slow/stale background task from
clobbering a newer upload from a concurrent turn.
Args:
message_count: ``len(session.messages)`` at upload time used by
the next turn to detect staleness and compress only the gap.
"""
from backend.util.workspace_storage import get_workspace_storage
stripped = strip_progress_entries(content)
if not validate_transcript(stripped):
logger.warning(
f"[Transcript] Skipping upload — stripped content not valid "
f"for session {session_id}"
f"[Transcript] Skipping upload — stripped content is not a valid "
f"transcript for session {session_id}"
)
return
@@ -361,8 +297,9 @@ async def upload_transcript(
existing = await storage.retrieve(path)
if len(existing) >= new_size:
logger.info(
f"[Transcript] Skipping upload — existing ({len(existing)}B) "
f">= new ({new_size}B) for session {session_id}"
f"[Transcript] Skipping upload — existing transcript "
f"({len(existing)}B) >= new ({new_size}B) for session "
f"{session_id}"
)
return
except (FileNotFoundError, Exception):
@@ -374,38 +311,16 @@ async def upload_transcript(
filename=fname,
content=encoded,
)
# Store metadata alongside the transcript so the next turn can detect
# staleness and only compress the gap instead of the full history.
# Wrapped in try/except so a metadata write failure doesn't orphan
# the already-uploaded transcript — the next turn will just fall back
# to full gap fill (msg_count=0).
try:
meta = {"message_count": message_count, "uploaded_at": time.time()}
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
await storage.store(
workspace_id=mwid,
file_id=mfid,
filename=mfname,
content=json.dumps(meta).encode("utf-8"),
)
except Exception as e:
logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")
logger.info(
f"[Transcript] Uploaded {new_size}B "
f"(stripped from {len(content)}B, msg_count={message_count}) "
f"for session {session_id}"
f"[Transcript] Uploaded {new_size} bytes "
f"(stripped from {len(content)}) for session {session_id}"
)
async def download_transcript(
user_id: str, session_id: str
) -> TranscriptDownload | None:
"""Download transcript and metadata from bucket storage.
async def download_transcript(user_id: str, session_id: str) -> str | None:
"""Download transcript from bucket storage.
Returns a ``TranscriptDownload`` with the JSONL content and the
``message_count`` watermark from the upload, or ``None`` if not found.
Returns the JSONL content string, or ``None`` if not found.
"""
from backend.util.workspace_storage import get_workspace_storage
@@ -415,6 +330,10 @@ async def download_transcript(
try:
data = await storage.retrieve(path)
content = data.decode("utf-8")
logger.info(
f"[Transcript] Downloaded {len(content)} bytes for session {session_id}"
)
return content
except FileNotFoundError:
logger.debug(f"[Transcript] No transcript in storage for {session_id}")
return None
@@ -422,36 +341,6 @@ async def download_transcript(
logger.warning(f"[Transcript] Failed to download transcript: {e}")
return None
# Try to load metadata (best-effort — old transcripts won't have it)
message_count = 0
uploaded_at = 0.0
try:
from backend.util.workspace_storage import GCSWorkspaceStorage
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
if isinstance(storage, GCSWorkspaceStorage):
blob = f"workspaces/{mwid}/{mfid}/{mfname}"
meta_path = f"gcs://{storage.bucket_name}/{blob}"
else:
meta_path = f"local://{mwid}/{mfid}/{mfname}"
meta_data = await storage.retrieve(meta_path)
meta = json.loads(meta_data.decode("utf-8"))
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
except (FileNotFoundError, json.JSONDecodeError, Exception):
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
logger.info(
f"[Transcript] Downloaded {len(content)}B "
f"(msg_count={message_count}) for session {session_id}"
)
return TranscriptDownload(
content=content,
message_count=message_count,
uploaded_at=uploaded_at,
)
async def delete_transcript(user_id: str, session_id: str) -> None:
"""Delete transcript from bucket storage (e.g. after resume failure)."""

View File

@@ -6,7 +6,12 @@ import pytest
from . import service as chat_service
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import StreamError, StreamTextDelta, StreamToolOutputAvailable
from .response_model import (
StreamError,
StreamFinish,
StreamTextDelta,
StreamToolOutputAvailable,
)
from .sdk import service as sdk_service
from .sdk.transcript import download_transcript
@@ -25,6 +30,7 @@ async def test_stream_chat_completion(setup_test_user, test_user_id):
session = await create_chat_session(test_user_id)
has_errors = False
has_ended = False
assistant_message = ""
async for chunk in chat_service.stream_chat_completion(
session.session_id, "Hello, how are you?", user_id=session.user_id
@@ -34,9 +40,10 @@ async def test_stream_chat_completion(setup_test_user, test_user_id):
has_errors = True
if isinstance(chunk, StreamTextDelta):
assistant_message += chunk.delta
if isinstance(chunk, StreamFinish):
has_ended = True
# StreamFinish is published by mark_session_completed (processor layer),
# not by the service. The generator completing means the stream ended.
assert has_ended, "Chat completion did not end"
assert not has_errors, "Error occurred while streaming chat completion"
assert assistant_message, "Assistant message is empty"
@@ -54,6 +61,7 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
session = await upsert_chat_session(session)
has_errors = False
has_ended = False
had_tool_calls = False
async for chunk in chat_service.stream_chat_completion(
session.session_id,
@@ -63,9 +71,13 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamFinish):
has_ended = True
if isinstance(chunk, StreamToolOutputAvailable):
had_tool_calls = True
assert has_ended, "Chat completion did not end"
assert not has_errors, "Error occurred while streaming chat completion"
assert had_tool_calls, "Tool calls did not occur"
session = await get_chat_session(session.session_id)
@@ -102,6 +114,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
)
turn1_text = ""
turn1_errors: list[str] = []
turn1_ended = False
async for chunk in sdk_service.stream_chat_completion_sdk(
session.session_id,
@@ -112,28 +125,25 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
turn1_text += chunk.delta
elif isinstance(chunk, StreamError):
turn1_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
turn1_ended = True
assert turn1_ended, "Turn 1 did not finish"
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
assert turn1_text, "Turn 1 produced no text"
# Wait for background upload task to complete (retry up to 5s).
# The CLI may not produce a usable transcript for very short
# conversations (only metadata entries) — this is environment-dependent
# (CLI version, platform). When that happens, multi-turn still works
# via conversation compression (non-resume path), but we can't test
# the --resume round-trip.
# Wait for background upload task to complete (retry up to 5s)
transcript = None
for _ in range(10):
await asyncio.sleep(0.5)
transcript = await download_transcript(test_user_id, session.session_id)
if transcript:
break
if not transcript:
return pytest.skip(
"CLI did not produce a usable transcript — "
"cannot test --resume round-trip in this environment"
)
logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")
assert transcript, (
"Transcript was not uploaded to bucket after turn 1 — "
"Stop hook may not have fired or transcript was too small"
)
logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes")
# Reload session for turn 2
session = await get_chat_session(session.session_id, test_user_id)
@@ -143,6 +153,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
turn2_msg = "What was the special keyword I asked you to remember?"
turn2_text = ""
turn2_errors: list[str] = []
turn2_ended = False
async for chunk in sdk_service.stream_chat_completion_sdk(
session.session_id,
@@ -154,7 +165,10 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
turn2_text += chunk.delta
elif isinstance(chunk, StreamError):
turn2_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
turn2_ended = True
assert turn2_ended, "Turn 2 did not finish"
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
assert turn2_text, "Turn 2 produced no text"
assert keyword in turn2_text, (

View File

@@ -3,14 +3,14 @@ from typing import TYPE_CHECKING, Any
from openai.types.chat import ChatCompletionToolParam
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_tool_called
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .browse_web import BrowseWebTool
from .check_operation_status import CheckOperationStatusTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
@@ -31,7 +31,7 @@ from .workspace_files import (
)
if TYPE_CHECKING:
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.api.features.chat.response_model import StreamToolOutputAvailable
logger = logging.getLogger(__name__)
@@ -47,12 +47,11 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"view_agent_output": AgentOutputTool(),
"check_operation_status": CheckOperationStatusTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Browser-based browsing for JS-rendered pages (Stagehand + Browserbase)
"browse_web": BrowseWebTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
# Persistent workspace tools (cloud storage, survives across sessions)

View File

@@ -3,15 +3,14 @@ from datetime import UTC, datetime
from os import getenv
import pytest
import pytest_asyncio
from prisma.types import ProfileCreateInput
from pydantic import SecretStr
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import AITextGeneratorBlock
from backend.copilot.model import ChatSession
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import APIKeyCredentials
@@ -32,16 +31,14 @@ def make_session(user_id: str):
)
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_test_data(server):
@pytest.fixture(scope="session")
async def setup_test_data():
"""
Set up test data for run_agent tests:
1. Create a test user
2. Create a test graph (agent input -> agent output)
3. Create a store listing and store listing version
4. Approve the store listing version
Depends on ``server`` to ensure Prisma is connected.
"""
# 1. Create a test user
user_data = {
@@ -153,16 +150,14 @@ async def setup_test_data(server):
}
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_llm_test_data(server):
@pytest.fixture(scope="session")
async def setup_llm_test_data():
"""
Set up test data for LLM agent tests:
1. Create a test user
2. Create test OpenAI credentials for the user
3. Create a test graph with input -> LLM block -> output
4. Create and approve a store listing
Depends on ``server`` to ensure Prisma is connected.
"""
key = getenv("OPENAI_API_KEY")
if not key:
@@ -320,15 +315,13 @@ async def setup_llm_test_data(server):
}
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_firecrawl_test_data(server):
@pytest.fixture(scope="session")
async def setup_firecrawl_test_data():
"""
Set up test data for Firecrawl agent tests (missing credentials scenario):
1. Create a test user (WITHOUT Firecrawl credentials)
2. Create a test graph with input -> Firecrawl block -> output
3. Create and approve a store listing
Depends on ``server`` to ensure Prisma is connected.
"""
# 1. Create a test user
user_data = {

View File

@@ -3,9 +3,11 @@
import logging
from typing import Any
from backend.copilot.model import ChatSession
from backend.data.db_accessors import understanding_db
from backend.data.understanding import BusinessUnderstandingInput
from backend.api.features.chat.model import ChatSession
from backend.data.understanding import (
BusinessUnderstandingInput,
upsert_business_understanding,
)
from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
@@ -97,9 +99,7 @@ and automations for the user's specific needs."""
]
# Upsert with merge
understanding = await understanding_db().upsert_business_understanding(
user_id, input_data
)
understanding = await upsert_business_understanding(user_id, input_data)
# Build current understanding summary (filter out empty values)
current_understanding = {

View File

@@ -19,7 +19,6 @@ from .core import (
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
get_library_agent_by_id,
get_library_agents_by_ids,
get_library_agents_for_generation,
graph_to_json,
json_to_graph,
@@ -50,7 +49,6 @@ __all__ = [
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
"get_library_agent_by_id",
"get_library_agents_by_ids",
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",

View File

@@ -3,11 +3,11 @@
import logging
import re
import uuid
from collections.abc import Sequence
from typing import Any, NotRequired, TypedDict
from backend.data.db_accessors import graph_db, library_db, store_db
from backend.data.graph import Graph, Link, Node
from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
from backend.util.exceptions import DatabaseError, NotFoundError
from .service import (
@@ -79,7 +79,7 @@ AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
def _to_dict_list(
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
agents: list[AgentSummary] | list[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
@@ -145,9 +145,8 @@ async def get_library_agent_by_id(
Returns:
LibraryAgentSummary if found, None otherwise
"""
db = library_db()
try:
agent = await db.get_library_agent_by_graph_id(user_id, agent_id)
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return LibraryAgentSummary(
@@ -164,7 +163,7 @@ async def get_library_agent_by_id(
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
try:
agent = await db.get_library_agent(agent_id, user_id)
agent = await library_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return LibraryAgentSummary(
@@ -191,36 +190,6 @@ async def get_library_agent_by_id(
get_library_agent_by_graph_id = get_library_agent_by_id
async def get_library_agents_by_ids(
user_id: str,
agent_ids: list[str],
) -> list[LibraryAgentSummary]:
"""Fetch multiple library agents by their IDs.
Args:
user_id: The user ID
agent_ids: List of agent IDs (can be graph_ids or library agent IDs)
Returns:
List of LibraryAgentSummary for found agents (silently skips not found)
"""
agents: list[LibraryAgentSummary] = []
for agent_id in agent_ids:
try:
agent = await get_library_agent_by_id(user_id, agent_id)
if agent:
agents.append(agent)
logger.debug(f"Fetched library agent by ID: {agent['name']}")
else:
logger.warning(f"Library agent not found for ID: {agent_id}")
except Exception as e:
logger.warning(f"Failed to fetch library agent {agent_id}: {e}")
continue
logger.info(f"Fetched {len(agents)}/{len(agent_ids)} library agents by ID")
return agents
async def get_library_agents_for_generation(
user_id: str,
search_query: str | None = None,
@@ -245,17 +214,10 @@ async def get_library_agents_for_generation(
Returns:
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
"""
search_term = search_query.strip() if search_query else None
if search_term and len(search_term) > 100:
raise ValueError(
f"Search query is too long ({len(search_term)} chars, max 100). "
f"Please use a shorter, more specific search term."
)
try:
response = await library_db().list_library_agents(
response = await library_db.list_library_agents(
user_id=user_id,
search_term=search_term,
search_term=search_query,
page=1,
page_size=max_results,
include_executions=True,
@@ -309,16 +271,9 @@ async def search_marketplace_agents_for_generation(
Returns:
List of LibraryAgentSummary with full input/output schemas
"""
search_term = search_query.strip()
if len(search_term) > 100:
raise ValueError(
f"Search query is too long ({len(search_term)} chars, max 100). "
f"Please use a shorter, more specific search term."
)
try:
response = await store_db().get_store_agents(
search_query=search_term,
response = await store_db.get_store_agents(
search_query=search_query,
page=1,
page_size=max_results,
)
@@ -331,7 +286,7 @@ async def search_marketplace_agents_for_generation(
return []
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
graphs = await graph_db().get_store_listed_graphs(graph_ids)
graphs = await get_store_listed_graphs(*graph_ids)
results: list[LibraryAgentSummary] = []
for agent in agents_with_graphs:
@@ -469,7 +424,7 @@ def extract_search_terms_from_steps(
async def enrich_library_agents_from_steps(
user_id: str,
decomposition_result: DecompositionResult | dict[str, Any],
existing_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]],
existing_agents: list[AgentSummary] | list[dict[str, Any]],
exclude_graph_id: str | None = None,
include_marketplace: bool = True,
max_additional_results: int = 10,
@@ -493,7 +448,7 @@ async def enrich_library_agents_from_steps(
search_terms = extract_search_terms_from_steps(decomposition_result)
if not search_terms:
return list(existing_agents)
return existing_agents
existing_ids: set[str] = set()
existing_names: set[str] = set()
@@ -556,7 +511,7 @@ async def enrich_library_agents_from_steps(
async def decompose_goal(
description: str,
context: str = "",
library_agents: Sequence[AgentSummary] | None = None,
library_agents: list[AgentSummary] | None = None,
) -> DecompositionResult | None:
"""Break down a goal into steps or return clarifying questions.
@@ -584,16 +539,22 @@ async def decompose_goal(
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams
completion notification)
task_id: Task ID for async processing (enables Redis Streams persistence
and SSE delivery)
Returns:
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -601,9 +562,13 @@ async def generate_agent(
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents)
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
)
# Don't modify async response
if result and result.get("status") == "accepted":
return result
if result:
if isinstance(result, dict) and result.get("type") == "error":
return result
@@ -708,10 +673,9 @@ async def save_agent_to_library(
Tuple of (created Graph, LibraryAgent)
"""
graph = json_to_graph(agent_json)
db = library_db()
if is_update:
return await db.update_graph_in_library(graph, user_id)
return await db.create_graph_in_library(graph, user_id)
return await library_db.update_graph_in_library(graph, user_id)
return await library_db.create_graph_in_library(graph, user_id)
def graph_to_json(graph: Graph) -> dict[str, Any]:
@@ -771,14 +735,12 @@ async def get_agent_as_json(
Returns:
Agent as JSON dict or None if not found
"""
db = graph_db()
graph = await db.get_graph(agent_id, version=None, user_id=user_id)
graph = await get_graph(agent_id, version=None, user_id=user_id)
if not graph and user_id:
try:
library_agent = await library_db().get_library_agent(agent_id, user_id)
graph = await db.get_graph(
library_agent = await library_db.get_library_agent(agent_id, user_id)
graph = await get_graph(
library_agent.graph_id, version=None, user_id=user_id
)
except NotFoundError:
@@ -793,7 +755,9 @@ async def get_agent_as_json(
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: Sequence[AgentSummary] | None = None,
library_agents: list[AgentSummary] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
@@ -806,10 +770,12 @@ async def generate_agent_patch(
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on error
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -820,6 +786,8 @@ async def generate_agent_patch(
update_request,
current_agent,
_to_dict_list(library_agents),
operation_id,
task_id,
)

View File

@@ -102,15 +102,10 @@ async def generate_agent_dummy(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
)
"""Return dummy agent JSON after a simulated delay."""
logger.info("Using dummy agent generator for generate_agent (30s delay)")
await asyncio.sleep(30)
return _generate_dummy_agent_json()
@@ -120,16 +115,10 @@ async def generate_agent_patch_dummy(
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
)
await asyncio.sleep(30)
"""Return dummy patched agent (returns the current agent with updated description)."""
logger.info("Using dummy agent generator for generate_agent_patch")
patched = current_agent.copy()
patched["description"] = (
f"{current_agent.get('description', '')} (updated: {update_request})"

View File

@@ -242,18 +242,24 @@ async def decompose_goal_external(
async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Agent JSON dict or error dict {"type": "error", ...} on error
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
"""
if _is_dummy_mode():
return await generate_agent_dummy(instructions, library_agents)
return await generate_agent_dummy(
instructions, library_agents, operation_id, task_id
)
client = _get_client()
@@ -261,9 +267,25 @@ async def generate_agent_external(
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/generate-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
@@ -295,6 +317,8 @@ async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
@@ -303,14 +327,14 @@ async def generate_agent_patch_external(
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
"""
if _is_dummy_mode():
return await generate_agent_patch_dummy(
update_request, current_agent, library_agents
update_request, current_agent, library_agents, operation_id, task_id
)
client = _get_client()
@@ -322,9 +346,25 @@ async def generate_agent_patch_external(
}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/update-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async update request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
@@ -379,8 +419,6 @@ async def customize_template_external(
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
Returns:
Customized agent JSON, clarifying questions dict, or error dict on error

View File

@@ -5,15 +5,15 @@ import re
from datetime import datetime, timedelta, timezone
from typing import Any
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, field_validator
from backend.api.features.chat.model import ChatSession
from backend.api.features.library import db as library_db
from backend.api.features.library.model import LibraryAgent
from backend.copilot.model import ChatSession
from backend.data.db_accessors import execution_db, library_db
from backend.data import execution as execution_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
from .base import BaseTool
from .execution_utils import TERMINAL_STATUSES, wait_for_execution
from .models import (
AgentOutputResponse,
ErrorResponse,
@@ -34,7 +34,6 @@ class AgentOutputInput(BaseModel):
store_slug: str = ""
execution_id: str = ""
run_time: str = "latest"
wait_if_running: int = Field(default=0, ge=0, le=300)
@field_validator(
"agent_name",
@@ -118,11 +117,6 @@ class AgentOutputTool(BaseTool):
Select which run to retrieve using:
- execution_id: Specific execution ID
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
Wait for completion (optional):
- wait_if_running: Max seconds to wait if execution is still running (0-300).
If the execution is running/queued, waits up to this many seconds for completion.
Returns current status on timeout. If already finished, returns immediately.
"""
@property
@@ -152,13 +146,6 @@ class AgentOutputTool(BaseTool):
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
),
},
"wait_if_running": {
"type": "integer",
"description": (
"Max seconds to wait if execution is still running (0-300). "
"If running, waits for completion. Returns current state on timeout."
),
},
},
"required": [],
}
@@ -178,12 +165,10 @@ class AgentOutputTool(BaseTool):
Resolve agent from provided identifiers.
Returns (library_agent, error_message).
"""
lib_db = library_db()
# Priority 1: Exact library agent ID
if library_agent_id:
try:
agent = await lib_db.get_library_agent(library_agent_id, user_id)
agent = await library_db.get_library_agent(library_agent_id, user_id)
return agent, None
except Exception as e:
logger.warning(f"Failed to get library agent by ID: {e}")
@@ -197,7 +182,7 @@ class AgentOutputTool(BaseTool):
return None, f"Agent '{store_slug}' not found in marketplace"
# Find in user's library by graph_id
agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id)
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
if not agent:
return (
None,
@@ -209,7 +194,7 @@ class AgentOutputTool(BaseTool):
# Priority 3: Fuzzy name search in library
if agent_name:
try:
response = await lib_db.list_library_agents(
response = await library_db.list_library_agents(
user_id=user_id,
search_term=agent_name,
page_size=5,
@@ -238,20 +223,14 @@ class AgentOutputTool(BaseTool):
execution_id: str | None,
time_start: datetime | None,
time_end: datetime | None,
include_running: bool = False,
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
"""
Fetch execution(s) based on filters.
Returns (single_execution, available_executions_meta, error_message).
Args:
include_running: If True, also look for running/queued executions (for waiting)
"""
exec_db = execution_db()
# If specific execution_id provided, fetch it directly
if execution_id:
execution = await exec_db.get_graph_execution(
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
@@ -260,25 +239,11 @@ class AgentOutputTool(BaseTool):
return None, [], f"Execution '{execution_id}' not found"
return execution, [], None
# Determine which statuses to query
statuses = [ExecutionStatus.COMPLETED]
if include_running:
statuses.extend(
[
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
ExecutionStatus.REVIEW,
ExecutionStatus.FAILED,
ExecutionStatus.TERMINATED,
]
)
# Get executions with time filters
executions = await exec_db.get_graph_executions(
# Get completed executions with time filters
executions = await execution_db.get_graph_executions(
graph_id=graph_id,
user_id=user_id,
statuses=statuses,
statuses=[ExecutionStatus.COMPLETED],
created_time_gte=time_start,
created_time_lte=time_end,
limit=10,
@@ -289,7 +254,7 @@ class AgentOutputTool(BaseTool):
# If only one execution, fetch full details
if len(executions) == 1:
full_execution = await exec_db.get_graph_execution(
full_execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
@@ -297,7 +262,7 @@ class AgentOutputTool(BaseTool):
return full_execution, [], None
# Multiple executions - return latest with full details, plus list of available
full_execution = await exec_db.get_graph_execution(
full_execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=executions[0].id,
include_node_executions=False,
@@ -345,33 +310,10 @@ class AgentOutputTool(BaseTool):
for e in available_executions[:5]
]
# Build appropriate message based on execution status
if execution.status == ExecutionStatus.COMPLETED:
message = f"Found execution outputs for agent '{agent.name}'"
elif execution.status == ExecutionStatus.FAILED:
message = f"Execution for agent '{agent.name}' failed"
elif execution.status == ExecutionStatus.TERMINATED:
message = f"Execution for agent '{agent.name}' was terminated"
elif execution.status == ExecutionStatus.REVIEW:
message = (
f"Execution for agent '{agent.name}' is awaiting human review. "
"The user needs to approve it before it can continue."
)
elif execution.status in (
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
):
message = (
f"Execution for agent '{agent.name}' is still {execution.status.value}. "
"Results may be incomplete. Use wait_if_running to wait for completion."
)
else:
message = f"Found execution for agent '{agent.name}' (status: {execution.status.value})"
message = f"Found execution outputs for agent '{agent.name}'"
if len(available_executions) > 1:
message += (
f" Showing latest of {len(available_executions)} matching executions."
f". Showing latest of {len(available_executions)} matching executions."
)
return AgentOutputResponse(
@@ -438,7 +380,7 @@ class AgentOutputTool(BaseTool):
and not input_data.store_slug
):
# Fetch execution directly to get graph_id
execution = await execution_db().get_graph_execution(
execution = await execution_db.get_graph_execution(
user_id=user_id,
execution_id=input_data.execution_id,
include_node_executions=False,
@@ -450,7 +392,7 @@ class AgentOutputTool(BaseTool):
)
# Find library agent by graph_id
agent = await library_db().get_library_agent_by_graph_id(
agent = await library_db.get_library_agent_by_graph_id(
user_id, execution.graph_id
)
if not agent:
@@ -486,17 +428,13 @@ class AgentOutputTool(BaseTool):
# Parse time expression
time_start, time_end = parse_time_expression(input_data.run_time)
# Check if we should wait for running executions
wait_timeout = input_data.wait_if_running
# Fetch execution(s) - include running if we're going to wait
# Fetch execution(s)
execution, available_executions, exec_error = await self._get_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=input_data.execution_id or None,
time_start=time_start,
time_end=time_end,
include_running=wait_timeout > 0,
)
if exec_error:
@@ -505,17 +443,4 @@ class AgentOutputTool(BaseTool):
session_id=session_id,
)
# If we have an execution that's still running and we should wait
if execution and wait_timeout > 0 and execution.status not in TERMINAL_STATUSES:
logger.info(
f"Execution {execution.id} is {execution.status}, "
f"waiting up to {wait_timeout}s for completion"
)
execution = await wait_for_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=execution.id,
timeout_seconds=wait_timeout,
)
return self._build_response(agent, execution, available_executions, session_id)

View File

@@ -1,15 +1,11 @@
"""Shared agent search functionality for find_agent and find_library_agent tools."""
from __future__ import annotations
import logging
import re
from typing import TYPE_CHECKING, Literal
from typing import Literal
if TYPE_CHECKING:
from backend.api.features.library.model import LibraryAgent
from backend.data.db_accessors import library_db, store_db
from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.util.exceptions import DatabaseError, NotFoundError
from .models import (
@@ -29,24 +25,92 @@ _UUID_PATTERN = re.compile(
re.IGNORECASE,
)
# Keywords that should be treated as "list all" rather than a literal search
_LIST_ALL_KEYWORDS = frozenset({"all", "*", "everything", "any", ""})
def _is_uuid(text: str) -> bool:
"""Check if text is a valid UUID v4."""
return bool(_UUID_PATTERN.match(text.strip()))
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
"""Fetch a library agent by ID (library agent ID or graph_id).
Tries multiple lookup strategies:
1. First by graph_id (AgentGraph primary key)
2. Then by library agent ID (LibraryAgent primary key)
Args:
user_id: The user ID
agent_id: The ID to look up (can be graph_id or library agent ID)
Returns:
AgentInfo if found, None otherwise
"""
try:
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by graph_id {agent_id}: {e}",
exc_info=True,
)
try:
agent = await library_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None
async def search_agents(
query: str,
source: SearchSource,
session_id: str | None = None,
session_id: str | None,
user_id: str | None = None,
) -> ToolResponseBase:
"""
Search for agents in marketplace or user library.
For library searches, keywords like "all", "*", "everything", or an empty
query will list all agents without filtering.
Args:
query: Search query string. Special keywords list all library agents.
query: Search query string
source: "marketplace" or "library"
session_id: Chat session ID
user_id: User ID (required for library search)
@@ -54,11 +118,7 @@ async def search_agents(
Returns:
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
"""
# Normalize list-all keywords to empty string for library searches
if source == "library" and query.lower().strip() in _LIST_ALL_KEYWORDS:
query = ""
if source == "marketplace" and not query:
if not query:
return ErrorResponse(
message="Please provide a search query", session_id=session_id
)
@@ -73,7 +133,7 @@ async def search_agents(
try:
if source == "marketplace":
logger.info(f"Searching marketplace for: {query}")
results = await store_db().get_store_agents(search_query=query, page_size=5)
results = await store_db.get_store_agents(search_query=query, page_size=5)
for agent in results.agents:
agents.append(
AgentInfo(
@@ -98,18 +158,28 @@ async def search_agents(
logger.info(f"Found agent by direct ID lookup: {agent.name}")
if not agents:
search_term = query or None
logger.info(
f"{'Listing all agents in' if not query else 'Searching'} "
f"user library{'' if not query else f' for: {query}'}"
)
results = await library_db().list_library_agents(
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=search_term,
page_size=50 if not query else 10,
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(_library_agent_to_info(agent))
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
)
logger.info(f"Found {len(agents)} agents in {source}")
except NotFoundError:
pass
@@ -122,62 +192,42 @@ async def search_agents(
)
if not agents:
if source == "marketplace":
suggestions = [
suggestions = (
[
"Try more general terms",
"Browse categories in the marketplace",
"Check spelling",
]
no_results_msg = (
f"No agents found matching '{query}'. Let the user know they can "
"try different keywords or browse the marketplace. Also let them "
"know you can create a custom agent for them based on their needs."
)
elif not query:
# User asked to list all but library is empty
suggestions = [
"Browse the marketplace to find and add agents",
"Use find_agent to search the marketplace",
]
no_results_msg = (
"Your library is empty. Let the user know they can browse the "
"marketplace to find agents, or you can create a custom agent "
"for them based on their needs."
)
else:
suggestions = [
if source == "marketplace"
else [
"Try different keywords",
"Use find_agent to search the marketplace",
"Check your library at /library",
]
no_results_msg = (
f"No agents matching '{query}' found in your library. Let the "
"user know you can create a custom agent for them based on "
"their needs."
)
)
no_results_msg = (
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
if source == "marketplace"
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
)
return NoResultsResponse(
message=no_results_msg, session_id=session_id, suggestions=suggestions
)
if source == "marketplace":
title = (
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
)
elif not query:
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library"
else:
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library for '{query}'"
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
title += (
f"for '{query}'"
if source == "marketplace"
else f"in your library for '{query}'"
)
message = (
"Now you have found some options for the user to choose from. "
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
"Please ask the user if they would like to use any of these agents. "
"Let the user know we can create a custom agent for them based on their needs."
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
if source == "marketplace"
else "Found agents in the user's library. You can provide a link to view "
"an agent at: /library/agents/{agent_id}. Use agent_output to get "
"execution results, or run_agent to execute. Let the user know we can "
"create a custom agent for them based on their needs."
else "Found agents in the user's library. You can provide a link to view an agent at: "
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
)
return AgentsFoundResponse(
@@ -187,67 +237,3 @@ async def search_agents(
count=len(agents),
session_id=session_id,
)
def _is_uuid(text: str) -> bool:
"""Check if text is a valid UUID v4."""
return bool(_UUID_PATTERN.match(text.strip()))
def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
"""Convert a library agent model to an AgentInfo."""
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
"""Fetch a library agent by ID (library agent ID or graph_id).
Tries multiple lookup strategies:
1. First by graph_id (AgentGraph primary key)
2. Then by library agent ID (LibraryAgent primary key)
"""
lib_db = library_db()
try:
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return _library_agent_to_info(agent)
except NotFoundError:
logger.debug(f"Library agent not found by graph_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by graph_id {agent_id}: {e}",
exc_info=True,
)
try:
agent = await lib_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return _library_agent_to_info(agent)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None

View File

@@ -5,8 +5,8 @@ from typing import Any
from openai.types.chat import ChatCompletionToolParam
from backend.copilot.model import ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.response_model import StreamToolOutputAvailable
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
@@ -36,6 +36,16 @@ class BaseTool:
"""Whether this tool requires authentication."""
return False
@property
def is_long_running(self) -> bool:
"""Whether this tool is long-running and should execute in background.
Long-running tools (like agent generation) are executed via background
tasks to survive SSE disconnections. The result is persisted to chat
history and visible when the user refreshes.
"""
return False
def as_openai_tool(self) -> ChatCompletionToolParam:
"""Convert to OpenAI tool format."""
return ChatCompletionToolParam(

View File

@@ -11,11 +11,18 @@ available (e.g. macOS development).
import logging
from typing import Any
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
BashExecResponse,
ErrorResponse,
ToolResponseBase,
)
from backend.api.features.chat.tools.sandbox import (
get_workspace_dir,
has_full_sandbox,
run_sandboxed,
)
logger = logging.getLogger(__name__)

View File

@@ -0,0 +1,127 @@
"""CheckOperationStatusTool — query the status of a long-running operation."""
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
ErrorResponse,
ResponseType,
ToolResponseBase,
)
logger = logging.getLogger(__name__)
class OperationStatusResponse(ToolResponseBase):
"""Response for check_operation_status tool."""
type: ResponseType = ResponseType.OPERATION_STATUS
task_id: str
operation_id: str
status: str # "running", "completed", "failed"
tool_name: str | None = None
message: str = ""
class CheckOperationStatusTool(BaseTool):
"""Check the status of a long-running operation (create_agent, edit_agent, etc.).
The CoPilot uses this tool to report back to the user whether an
operation that was started earlier has completed, failed, or is still
running.
"""
@property
def name(self) -> str:
return "check_operation_status"
@property
def description(self) -> str:
return (
"Check the current status of a long-running operation such as "
"create_agent or edit_agent. Accepts either an operation_id or "
"task_id from a previous operation_started response. "
"Returns the current status: running, completed, or failed."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"operation_id": {
"type": "string",
"description": (
"The operation_id from an operation_started response."
),
},
"task_id": {
"type": "string",
"description": (
"The task_id from an operation_started response. "
"Used as fallback if operation_id is not provided."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
from backend.api.features.chat import stream_registry
operation_id = (kwargs.get("operation_id") or "").strip()
task_id = (kwargs.get("task_id") or "").strip()
if not operation_id and not task_id:
return ErrorResponse(
message="Please provide an operation_id or task_id.",
error="missing_parameter",
)
task = None
if operation_id:
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None and task_id:
task = await stream_registry.get_task(task_id)
if task is None:
# Task not in Redis — it may have already expired (TTL).
# Check conversation history for the result instead.
return ErrorResponse(
message=(
"Operation not found — it may have already completed and "
"expired from the status tracker. Check the conversation "
"history for the result."
),
error="not_found",
)
status_messages = {
"running": (
f"The {task.tool_name or 'operation'} is still running. "
"Please wait for it to complete."
),
"completed": (
f"The {task.tool_name or 'operation'} has completed successfully."
),
"failed": f"The {task.tool_name or 'operation'} has failed.",
}
return OperationStatusResponse(
task_id=task.task_id,
operation_id=task.operation_id,
status=task.status,
tool_name=task.tool_name,
message=status_messages.get(task.status, f"Status: {task.status}"),
)

View File

@@ -3,13 +3,14 @@
import logging
from typing import Any
from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
@@ -17,10 +18,10 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
SuggestedGoalResponse,
ToolResponseBase,
)
@@ -38,16 +39,17 @@ class CreateAgentTool(BaseTool):
def description(self) -> str:
return (
"Create a new agent workflow from a natural language description. "
"First generates a preview, then saves to library if save=true. "
"\n\nIMPORTANT: Before calling this tool, search for relevant existing agents "
"using find_library_agent that could be used as building blocks. "
"Pass their IDs in the library_agent_ids parameter so the generator can compose them."
"First generates a preview, then saves to library if save=true."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -67,15 +69,6 @@ class CreateAgentTool(BaseTool):
"Include any preferences or constraints mentioned by the user."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks. "
"Search for relevant agents using find_library_agent first, "
"then pass their IDs here so they can be composed into the new agent."
),
},
"save": {
"type": "boolean",
"description": (
@@ -103,14 +96,12 @@ class CreateAgentTool(BaseTool):
"""
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
session_id = session.session_id if session else None
logger.info(
f"[AGENT_CREATE_DEBUG] START - description_len={len(description)}, "
f"library_agent_ids={library_agent_ids}, save={save}, user_id={user_id}, session_id={session_id}"
)
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not description:
return ErrorResponse(
@@ -119,34 +110,25 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id and library_agent_ids:
if user_id:
try:
from .agent_generator import get_library_agents_by_ids
library_agents = await get_library_agents_by_ids(
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
agent_ids=library_agent_ids,
search_query=description,
include_marketplace=True,
)
logger.debug(
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
logger.warning(f"Failed to fetch library agents: {e}")
try:
decomposition_result = await decompose_goal(
description, context, library_agents
)
logger.info(
f"[AGENT_CREATE_DEBUG] DECOMPOSE - type={decomposition_result.get('type') if decomposition_result else None}, "
f"session_id={session_id}"
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
@@ -204,28 +186,26 @@ class CreateAgentTool(BaseTool):
if decomposition_result.get("type") == "unachievable_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get("reason", "")
return SuggestedGoalResponse(
return ErrorResponse(
message=(
f"This goal cannot be accomplished with the available blocks. {reason}"
f"This goal cannot be accomplished with the available blocks. "
f"{reason} "
f"Suggestion: {suggested}"
),
suggested_goal=suggested,
reason=reason,
original_goal=description,
goal_type="unachievable",
error="unachievable_goal",
details={"suggested_goal": suggested, "reason": reason},
session_id=session_id,
)
if decomposition_result.get("type") == "vague_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get(
"reason", "The goal needs more specific details"
)
return SuggestedGoalResponse(
message="The goal is too vague to create a specific workflow.",
suggested_goal=suggested,
reason=reason,
original_goal=description,
goal_type="vague",
return ErrorResponse(
message=(
f"The goal is too vague to create a specific workflow. "
f"Suggestion: {suggested}"
),
error="vague_goal",
details={"suggested_goal": suggested},
session_id=session_id,
)
@@ -247,17 +227,10 @@ class CreateAgentTool(BaseTool):
agent_json = await generate_agent(
decomposition_result,
library_agents,
)
logger.info(
f"[AGENT_CREATE_DEBUG] GENERATE - "
f"success={agent_json is not None}, "
f"is_error={isinstance(agent_json, dict) and agent_json.get('type') == 'error'}, "
f"session_id={session_id}"
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured during generation, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
@@ -300,20 +273,25 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if agent_json.get("status") == "accepted":
logger.info(
f"Agent generation delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent generation started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
logger.info(
f"[AGENT_CREATE_DEBUG] AGENT_JSON - name={agent_name}, "
f"nodes={node_count}, links={link_count}, save={save}, session_id={session_id}"
)
if not save:
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentPreviewResponse, session_id={session_id}"
)
return AgentPreviewResponse(
message=(
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
@@ -339,13 +317,6 @@ class CreateAgentTool(BaseTool):
agent_json, user_id
)
logger.info(
f"[AGENT_CREATE_DEBUG] SAVED - graph_id={created_graph.id}, "
f"library_agent_id={library_agent.id}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentSavedResponse, session_id={session_id}"
)
return AgentSavedResponse(
message=f"Agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
@@ -356,12 +327,6 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
except Exception as e:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - save_failed: {str(e)}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - ErrorResponse (save_failed), session_id={session_id}"
)
return ErrorResponse(
message=f"Failed to save the agent: {str(e)}",
error="save_failed",

View File

@@ -3,9 +3,9 @@
import logging
from typing import Any
from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError
from backend.copilot.model import ChatSession
from backend.data.db_accessors import store_db as get_store_db
from .agent_generator import (
AgentGeneratorNotConfiguredError,
@@ -46,6 +46,10 @@ class CustomizeAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -133,8 +137,6 @@ class CustomizeAgentTool(BaseTool):
creator_username, agent_slug = parts
store_db = get_store_db()
# Fetch the marketplace agent details
try:
agent_details = await store_db.get_store_agent_details(

View File

@@ -3,12 +3,13 @@
import logging
from typing import Any
from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
@@ -16,6 +17,7 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -36,16 +38,17 @@ class EditAgentTool(BaseTool):
def description(self) -> str:
return (
"Edit an existing agent from the user's library using natural language. "
"Generates updates to the agent while preserving unchanged parts. "
"\n\nIMPORTANT: Before calling this tool, if the changes involve adding new "
"functionality, search for relevant existing agents using find_library_agent "
"that could be used as building blocks. Pass their IDs in library_agent_ids."
"Generates updates to the agent while preserving unchanged parts."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -71,15 +74,6 @@ class EditAgentTool(BaseTool):
"Additional context or answers to previous clarifying questions."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks for the changes. "
"If adding new functionality, search for relevant agents using "
"find_library_agent first, then pass their IDs here."
),
},
"save": {
"type": "boolean",
"description": (
@@ -108,10 +102,13 @@ class EditAgentTool(BaseTool):
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
@@ -135,25 +132,21 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id and library_agent_ids:
if user_id:
try:
from .agent_generator import get_library_agents_by_ids
graph_id = current_agent.get("id")
# Filter out the current agent being edited
filtered_ids = [id for id in library_agent_ids if id != graph_id]
library_agents = await get_library_agents_by_ids(
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
agent_ids=filtered_ids,
search_query=changes,
exclude_graph_id=graph_id,
include_marketplace=True,
)
logger.debug(
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
logger.warning(f"Failed to fetch library agents: {e}")
update_request = changes
if context:
@@ -164,6 +157,8 @@ class EditAgentTool(BaseTool):
update_request,
current_agent,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
@@ -183,6 +178,19 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if result.get("status") == "accepted":
logger.info(
f"Agent edit delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent edit started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")

View File

@@ -5,14 +5,9 @@ from typing import Any
from pydantic import SecretStr
from backend.blocks.linear._api import LinearClient
from backend.copilot.model import ChatSession
from backend.data.db_accessors import user_db
from backend.data.model import APIKeyCredentials
from backend.util.settings import Settings
from .base import BaseTool
from .models import (
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
ErrorResponse,
FeatureRequestCreatedResponse,
FeatureRequestInfo,
@@ -20,6 +15,10 @@ from .models import (
NoResultsResponse,
ToolResponseBase,
)
from backend.blocks.linear._api import LinearClient
from backend.data.model import APIKeyCredentials
from backend.data.user import get_user_email_by_id
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -33,6 +32,7 @@ query SearchFeatureRequests($term: String!, $filter: IssueFilter, $first: Int) {
id
identifier
title
description
}
}
}
@@ -104,8 +104,8 @@ def _get_linear_config() -> tuple[LinearClient, str, str]:
Raises RuntimeError if any required setting is missing.
"""
secrets = _get_settings().secrets
if not secrets.copilot_linear_api_key:
raise RuntimeError("COPILOT_LINEAR_API_KEY is not configured")
if not secrets.linear_api_key:
raise RuntimeError("LINEAR_API_KEY is not configured")
if not secrets.linear_feature_request_project_id:
raise RuntimeError("LINEAR_FEATURE_REQUEST_PROJECT_ID is not configured")
if not secrets.linear_feature_request_team_id:
@@ -114,7 +114,7 @@ def _get_linear_config() -> tuple[LinearClient, str, str]:
credentials = APIKeyCredentials(
id="system-linear",
provider="linear",
api_key=SecretStr(secrets.copilot_linear_api_key),
api_key=SecretStr(secrets.linear_api_key),
title="System Linear API Key",
)
client = LinearClient(credentials=credentials)
@@ -204,6 +204,7 @@ class SearchFeatureRequestsTool(BaseTool):
id=node["id"],
identifier=node["identifier"],
title=node["title"],
description=node.get("description"),
)
for node in nodes
]
@@ -237,11 +238,7 @@ class CreateFeatureRequestTool(BaseTool):
"Create a new feature request or add a customer need to an existing one. "
"Always search first with search_feature_requests to avoid duplicates. "
"If a matching request exists, pass its ID as existing_issue_id to add "
"the user's need to it instead of creating a duplicate. "
"IMPORTANT: Never include personally identifiable information (PII) in "
"the title or description — no names, emails, phone numbers, company "
"names, or other identifying details. Write titles and descriptions in "
"generic, feature-focused language."
"the user's need to it instead of creating a duplicate."
)
@property
@@ -251,20 +248,11 @@ class CreateFeatureRequestTool(BaseTool):
"properties": {
"title": {
"type": "string",
"description": (
"Title for the feature request. Must be generic and "
"feature-focused — do not include any user names, emails, "
"company names, or other PII."
),
"description": "Title for the feature request.",
},
"description": {
"type": "string",
"description": (
"Detailed description of what the user wants and why. "
"Must not contain any personally identifiable information "
"(PII) — describe the feature need generically without "
"referencing specific users, companies, or contact details."
),
"description": "Detailed description of what the user wants and why.",
},
"existing_issue_id": {
"type": "string",
@@ -344,9 +332,7 @@ class CreateFeatureRequestTool(BaseTool):
# Resolve a human-readable name (email) for the Linear customer record.
# Fall back to user_id if the lookup fails or returns None.
try:
customer_display_name = (
await user_db().get_user_email_by_id(user_id) or user_id
)
customer_display_name = await get_user_email_by_id(user_id) or user_id
except Exception:
customer_display_name = user_id

View File

@@ -1,18 +1,22 @@
"""Tests for SearchFeatureRequestsTool and CreateFeatureRequestTool."""
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch
import pytest
from ._test_data import make_session
from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool
from .models import (
from backend.api.features.chat.tools.feature_requests import (
CreateFeatureRequestTool,
SearchFeatureRequestsTool,
)
from backend.api.features.chat.tools.models import (
ErrorResponse,
FeatureRequestCreatedResponse,
FeatureRequestSearchResponse,
NoResultsResponse,
)
from ._test_data import make_session
_TEST_USER_ID = "test-user-feature-requests"
_TEST_USER_EMAIL = "testuser@example.com"
@@ -35,7 +39,7 @@ def _mock_linear_config(*, query_return=None, mutate_return=None):
client.mutate.return_value = mutate_return
return (
patch(
"backend.copilot.tools.feature_requests._get_linear_config",
"backend.api.features.chat.tools.feature_requests._get_linear_config",
return_value=(client, _FAKE_PROJECT_ID, _FAKE_TEAM_ID),
),
client,
@@ -117,11 +121,13 @@ class TestSearchFeatureRequestsTool:
"id": "id-1",
"identifier": "FR-1",
"title": "Dark mode",
"description": "Add dark mode support",
},
{
"id": "id-2",
"identifier": "FR-2",
"title": "Dark theme",
"description": None,
},
]
patcher, _ = _mock_linear_config(query_return=_search_response(nodes))
@@ -202,7 +208,7 @@ class TestSearchFeatureRequestsTool:
async def test_linear_client_init_failure(self):
session = make_session(user_id=_TEST_USER_ID)
with patch(
"backend.copilot.tools.feature_requests._get_linear_config",
"backend.api.features.chat.tools.feature_requests._get_linear_config",
side_effect=RuntimeError("No API key"),
):
tool = SearchFeatureRequestsTool()
@@ -225,11 +231,10 @@ class TestCreateFeatureRequestTool:
@pytest.fixture(autouse=True)
def _patch_email_lookup(self):
mock_user_db = MagicMock()
mock_user_db.get_user_email_by_id = AsyncMock(return_value=_TEST_USER_EMAIL)
with patch(
"backend.copilot.tools.feature_requests.user_db",
return_value=mock_user_db,
"backend.api.features.chat.tools.feature_requests.get_user_email_by_id",
new_callable=AsyncMock,
return_value=_TEST_USER_EMAIL,
):
yield
@@ -342,7 +347,7 @@ class TestCreateFeatureRequestTool:
async def test_linear_client_init_failure(self):
session = make_session(user_id=_TEST_USER_ID)
with patch(
"backend.copilot.tools.feature_requests._get_linear_config",
"backend.api.features.chat.tools.feature_requests._get_linear_config",
side_effect=RuntimeError("No API key"),
):
tool = CreateFeatureRequestTool()

View File

@@ -2,7 +2,7 @@
from typing import Any
from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
from .base import BaseTool

View File

@@ -3,18 +3,17 @@ from typing import Any
from prisma.enums import ContentType
from backend.blocks import get_block
from backend.blocks._base import BlockType
from backend.copilot.model import ChatSession
from backend.data.db_accessors import search
from .base import BaseTool, ToolResponseBase
from .models import (
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
from backend.api.features.chat.tools.models import (
BlockInfoSummary,
BlockListResponse,
ErrorResponse,
NoResultsResponse,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.blocks import get_block
from backend.blocks._base import BlockType
logger = logging.getLogger(__name__)
@@ -108,7 +107,7 @@ class FindBlockTool(BaseTool):
try:
# Search for blocks using hybrid search
results, total = await search().unified_hybrid_search(
results, total = await unified_hybrid_search(
query=query,
content_types=[ContentType.BLOCK],
page=1,

View File

@@ -4,15 +4,15 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks._base import BlockType
from ._test_data import make_session
from .find_block import (
from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
FindBlockTool,
)
from .models import BlockListResponse
from backend.api.features.chat.tools.models import BlockListResponse
from backend.blocks._base import BlockType
from ._test_data import make_session
_TEST_USER_ID = "test-user-find-block"
@@ -84,17 +84,13 @@ class TestFindBlockFiltering:
"standard-block-id": standard_block,
}.get(block_id)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 2)
)
with patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, 2),
):
with patch(
"backend.copilot.tools.find_block.get_block",
"backend.api.features.chat.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
@@ -132,17 +128,13 @@ class TestFindBlockFiltering:
"normal-block-id": normal_block,
}.get(block_id)
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, 2)
)
with patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, 2),
):
with patch(
"backend.copilot.tools.find_block.get_block",
"backend.api.features.chat.tools.find_block.get_block",
side_effect=mock_get_block,
):
tool = FindBlockTool()
@@ -361,20 +353,13 @@ class TestFindBlockFiltering:
for d in block_defs
}
mock_search_db = MagicMock()
mock_search_db.unified_hybrid_search = AsyncMock(
return_value=(search_results, len(search_results))
)
with (
patch(
"backend.copilot.tools.find_block.search",
return_value=mock_search_db,
),
patch(
"backend.copilot.tools.find_block.get_block",
side_effect=lambda bid: mock_blocks.get(bid),
),
with patch(
"backend.api.features.chat.tools.find_block.unified_hybrid_search",
new_callable=AsyncMock,
return_value=(search_results, len(search_results)),
), patch(
"backend.api.features.chat.tools.find_block.get_block",
side_effect=lambda bid: mock_blocks.get(bid),
):
tool = FindBlockTool()
response = await tool._execute(

View File

@@ -2,7 +2,7 @@
from typing import Any
from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession
from .agent_search import search_agents
from .base import BaseTool
@@ -19,10 +19,9 @@ class FindLibraryAgentTool(BaseTool):
@property
def description(self) -> str:
return (
"Search for or list agents in the user's library. Use this to find "
"agents the user has already added to their library, including agents "
"they created or added from the marketplace. "
"Omit the query to list all agents."
"Search for agents in the user's library. Use this to find agents "
"the user has already added to their library, including agents they "
"created or added from the marketplace."
)
@property
@@ -32,13 +31,10 @@ class FindLibraryAgentTool(BaseTool):
"properties": {
"query": {
"type": "string",
"description": (
"Search query to find agents by name or description. "
"Omit to list all agents in the library."
),
"description": "Search query to find agents by name or description.",
},
},
"required": [],
"required": ["query"],
}
@property
@@ -49,7 +45,7 @@ class FindLibraryAgentTool(BaseTool):
self, user_id: str | None, session: ChatSession, **kwargs
) -> ToolResponseBase:
return await search_agents(
query=(kwargs.get("query") or "").strip(),
query=kwargs.get("query", "").strip(),
source="library",
session_id=session.session_id,
user_id=user_id,

View File

@@ -4,10 +4,13 @@ import logging
from pathlib import Path
from typing import Any
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import DocPageResponse, ErrorResponse, ToolResponseBase
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
DocPageResponse,
ErrorResponse,
ToolResponseBase,
)
logger = logging.getLogger(__name__)

View File

@@ -2,7 +2,7 @@
from datetime import datetime
from enum import Enum
from typing import Any, Literal
from typing import Any
from pydantic import BaseModel, Field
@@ -36,20 +36,20 @@ class ResponseType(str, Enum):
WORKSPACE_FILE_WRITTEN = "workspace_file_written"
WORKSPACE_FILE_DELETED = "workspace_file_deleted"
# Long-running operation types
OPERATION_STARTED = "operation_started"
OPERATION_PENDING = "operation_pending"
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"
# Web fetch
WEB_FETCH = "web_fetch"
# Browser-based web browsing (JS-rendered pages)
BROWSE_WEB = "browse_web"
# Code execution
BASH_EXEC = "bash_exec"
# Operation status check
OPERATION_STATUS = "operation_status"
# Feature request types
FEATURE_REQUEST_SEARCH = "feature_request_search"
FEATURE_REQUEST_CREATED = "feature_request_created"
# Goal refinement
SUGGESTED_GOAL = "suggested_goal"
# Base response model
@@ -296,22 +296,6 @@ class ClarificationNeededResponse(ToolResponseBase):
questions: list[ClarifyingQuestion] = Field(default_factory=list)
class SuggestedGoalResponse(ToolResponseBase):
"""Response when the goal needs refinement with a suggested alternative."""
type: ResponseType = ResponseType.SUGGESTED_GOAL
suggested_goal: str = Field(description="The suggested alternative goal")
reason: str = Field(
default="", description="Why the original goal needs refinement"
)
original_goal: str = Field(
default="", description="The user's original goal for context"
)
goal_type: Literal["vague", "unachievable"] = Field(
default="vague", description="Type: 'vague' or 'unachievable'"
)
# Documentation search models
class DocSearchResult(BaseModel):
"""A single documentation search result."""
@@ -418,6 +402,34 @@ class BlockOutputResponse(ToolResponseBase):
# Long-running operation models
class OperationStartedResponse(ToolResponseBase):
"""Response when a long-running operation has been started in the background.
This is returned immediately to the client while the operation continues
to execute. The user can close the tab and check back later.
The task_id can be used to reconnect to the SSE stream via
GET /chat/tasks/{task_id}/stream?last_idx=0
"""
type: ResponseType = ResponseType.OPERATION_STARTED
operation_id: str
tool_name: str
task_id: str | None = None # For SSE reconnection
class OperationPendingResponse(ToolResponseBase):
"""Response stored in chat history while a long-running operation is executing.
This is persisted to the database so users see a pending state when they
refresh before the operation completes.
"""
type: ResponseType = ResponseType.OPERATION_PENDING
operation_id: str
tool_name: str
class OperationInProgressResponse(ToolResponseBase):
"""Response when an operation is already in progress.
@@ -429,6 +441,23 @@ class OperationInProgressResponse(ToolResponseBase):
tool_call_id: str
class AsyncProcessingResponse(ToolResponseBase):
"""Response when an operation has been delegated to async processing.
This is returned by tools when the external service accepts the request
for async processing (HTTP 202 Accepted). The Redis Streams completion
consumer will handle the result when the external service completes.
The status field is specifically "accepted" to allow the long-running tool
handler to detect this response and skip LLM continuation.
"""
type: ResponseType = ResponseType.OPERATION_STARTED
status: str = "accepted" # Must be "accepted" for detection
operation_id: str | None = None
task_id: str | None = None
class WebFetchResponse(ToolResponseBase):
"""Response for web_fetch tool."""
@@ -440,15 +469,6 @@ class WebFetchResponse(ToolResponseBase):
truncated: bool = False
class BrowseWebResponse(ToolResponseBase):
"""Response for browse_web tool."""
type: ResponseType = ResponseType.BROWSE_WEB
url: str
content: str
truncated: bool = False
class BashExecResponse(ToolResponseBase):
"""Response for bash_exec tool."""
@@ -466,6 +486,7 @@ class FeatureRequestInfo(BaseModel):
id: str
identifier: str
title: str
description: str | None = None
class FeatureRequestSearchResponse(ToolResponseBase):

View File

@@ -5,13 +5,16 @@ from typing import Any
from pydantic import BaseModel, Field, field_validator
from backend.copilot.config import ChatConfig
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
from backend.data.db_accessors import graph_db, library_db, user_db
from backend.data.execution import ExecutionStatus
from backend.api.features.chat.config import ChatConfig
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tracking import (
track_agent_run_success,
track_agent_scheduled,
)
from backend.api.features.library import db as library_db
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.data.user import get_user_by_id
from backend.executor import utils as execution_utils
from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, NotFoundError
@@ -21,15 +24,12 @@ from backend.util.timezone_utils import (
)
from .base import BaseTool
from .execution_utils import get_execution_outputs, wait_for_execution
from .helpers import get_inputs_from_schema
from .models import (
AgentDetails,
AgentDetailsResponse,
AgentOutputResponse,
ErrorResponse,
ExecutionOptions,
ExecutionOutputInfo,
ExecutionStartedResponse,
InputValidationErrorResponse,
SetupInfo,
@@ -70,7 +70,6 @@ class RunAgentInput(BaseModel):
schedule_name: str = ""
cron: str = ""
timezone: str = "UTC"
wait_for_result: int = Field(default=0, ge=0, le=300)
@field_validator(
"username_agent_slug",
@@ -152,14 +151,6 @@ class RunAgentTool(BaseTool):
"type": "string",
"description": "IANA timezone for schedule (default: UTC)",
},
"wait_for_result": {
"type": "integer",
"description": (
"Max seconds to wait for execution to complete (0-300). "
"If >0, blocks until the execution finishes or times out. "
"Returns execution outputs when complete."
),
},
},
"required": [],
}
@@ -209,7 +200,7 @@ class RunAgentTool(BaseTool):
# Priority: library_agent_id if provided
if has_library_id:
library_agent = await library_db().get_library_agent(
library_agent = await library_db.get_library_agent(
params.library_agent_id, user_id
)
if not library_agent:
@@ -218,7 +209,9 @@ class RunAgentTool(BaseTool):
session_id=session_id,
)
# Get the graph from the library agent
graph = await graph_db().get_graph(
from backend.data.graph import get_graph
graph = await get_graph(
library_agent.graph_id,
library_agent.graph_version,
user_id=user_id,
@@ -354,7 +347,6 @@ class RunAgentTool(BaseTool):
graph=graph,
graph_credentials=graph_credentials,
inputs=params.inputs,
wait_for_result=params.wait_for_result,
)
except NotFoundError as e:
@@ -438,9 +430,8 @@ class RunAgentTool(BaseTool):
graph: GraphModel,
graph_credentials: dict[str, CredentialsMetaInput],
inputs: dict[str, Any],
wait_for_result: int = 0,
) -> ToolResponseBase:
"""Execute an agent immediately, optionally waiting for completion."""
"""Execute an agent immediately."""
session_id = session.session_id
# Check rate limits
@@ -477,91 +468,6 @@ class RunAgentTool(BaseTool):
)
library_agent_link = f"/library/agents/{library_agent.id}"
# If wait_for_result is requested, wait for execution to complete
if wait_for_result > 0:
logger.info(
f"Waiting up to {wait_for_result}s for execution {execution.id}"
)
completed = await wait_for_execution(
user_id=user_id,
graph_id=library_agent.graph_id,
execution_id=execution.id,
timeout_seconds=wait_for_result,
)
if completed and completed.status == ExecutionStatus.COMPLETED:
outputs = get_execution_outputs(completed)
return AgentOutputResponse(
message=(
f"Agent '{library_agent.name}' completed successfully. "
f"View at {library_agent_link}."
),
session_id=session_id,
agent_name=library_agent.name,
agent_id=library_agent.graph_id,
library_agent_id=library_agent.id,
library_agent_link=library_agent_link,
execution=ExecutionOutputInfo(
execution_id=execution.id,
status=completed.status.value,
started_at=completed.started_at,
ended_at=completed.ended_at,
outputs=outputs or {},
),
)
elif completed and completed.status == ExecutionStatus.FAILED:
error_detail = completed.stats.error if completed.stats else None
return ErrorResponse(
message=(
f"Agent '{library_agent.name}' execution failed. "
f"View details at {library_agent_link}."
),
session_id=session_id,
error=error_detail,
)
elif completed and completed.status == ExecutionStatus.TERMINATED:
error_detail = completed.stats.error if completed.stats else None
return ErrorResponse(
message=(
f"Agent '{library_agent.name}' execution was terminated. "
f"View details at {library_agent_link}."
),
session_id=session_id,
error=error_detail,
)
elif completed and completed.status == ExecutionStatus.REVIEW:
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' is awaiting human review. "
f"Check at {library_agent_link}."
),
session_id=session_id,
execution_id=execution.id,
graph_id=library_agent.graph_id,
graph_name=library_agent.name,
library_agent_id=library_agent.id,
library_agent_link=library_agent_link,
status=ExecutionStatus.REVIEW.value,
)
else:
status = completed.status.value if completed else "unknown"
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' is still {status} after "
f"{wait_for_result}s. Check results later at "
f"{library_agent_link}. "
f"Use view_agent_output with wait_if_running to check again."
),
session_id=session_id,
execution_id=execution.id,
graph_id=library_agent.graph_id,
graph_name=library_agent.name,
library_agent_id=library_agent.id,
library_agent_link=library_agent_link,
status=status,
)
return ExecutionStartedResponse(
message=(
f"Agent '{library_agent.name}' execution started successfully. "
@@ -616,7 +522,7 @@ class RunAgentTool(BaseTool):
library_agent = await get_or_create_library_agent(graph, user_id)
# Get user timezone
user = await user_db().get_user_by_id(user_id)
user = await get_user_by_id(user_id)
user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone)
# Create schedule

View File

@@ -7,17 +7,20 @@ from typing import Any
from pydantic_core import PydanticUndefined
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.find_block import (
COPILOT_EXCLUDED_BLOCK_IDS,
COPILOT_EXCLUDED_BLOCK_TYPES,
)
from backend.blocks import get_block
from backend.blocks._base import AnyBlockSchema
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError
from .base import BaseTool
from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
from .helpers import get_inputs_from_schema
from .models import (
BlockDetails,
@@ -160,10 +163,9 @@ class RunBlockTool(BaseTool):
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
creds_manager = IntegrationCredentialsManager()
(
matched_credentials,
missing_credentials,
) = await self._resolve_block_credentials(user_id, block, input_data)
matched_credentials, missing_credentials = (
await self._resolve_block_credentials(user_id, block, input_data)
)
# Get block schemas for details/validation
try:
@@ -274,7 +276,7 @@ class RunBlockTool(BaseTool):
try:
# Get or create user's workspace for CoPilot file operations
workspace = await workspace_db().get_or_create_workspace(user_id)
workspace = await get_or_create_workspace(user_id)
# Generate synthetic IDs for CoPilot context
# Each chat session is treated as its own agent with one continuous run

View File

@@ -4,16 +4,16 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks._base import BlockType
from ._test_data import make_session
from .models import (
from backend.api.features.chat.tools.models import (
BlockDetailsResponse,
BlockOutputResponse,
ErrorResponse,
InputValidationErrorResponse,
)
from .run_block import RunBlockTool
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.blocks._base import BlockType
from ._test_data import make_session
_TEST_USER_ID = "test-user-run-block"
@@ -77,7 +77,7 @@ class TestRunBlockFiltering:
input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)
with patch(
"backend.copilot.tools.run_block.get_block",
"backend.api.features.chat.tools.run_block.get_block",
return_value=input_block,
):
tool = RunBlockTool()
@@ -103,7 +103,7 @@ class TestRunBlockFiltering:
)
with patch(
"backend.copilot.tools.run_block.get_block",
"backend.api.features.chat.tools.run_block.get_block",
return_value=smart_block,
):
tool = RunBlockTool()
@@ -127,7 +127,7 @@ class TestRunBlockFiltering:
)
with patch(
"backend.copilot.tools.run_block.get_block",
"backend.api.features.chat.tools.run_block.get_block",
return_value=standard_block,
):
tool = RunBlockTool()
@@ -183,7 +183,7 @@ class TestRunBlockInputValidation:
)
with patch(
"backend.copilot.tools.run_block.get_block",
"backend.api.features.chat.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()
@@ -222,7 +222,7 @@ class TestRunBlockInputValidation:
)
with patch(
"backend.copilot.tools.run_block.get_block",
"backend.api.features.chat.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()
@@ -263,7 +263,7 @@ class TestRunBlockInputValidation:
)
with patch(
"backend.copilot.tools.run_block.get_block",
"backend.api.features.chat.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()
@@ -302,19 +302,15 @@ class TestRunBlockInputValidation:
mock_block.execute = mock_execute
mock_workspace_db = MagicMock()
mock_workspace_db.get_or_create_workspace = AsyncMock(
return_value=MagicMock(id="test-workspace-id")
)
with (
patch(
"backend.copilot.tools.run_block.get_block",
"backend.api.features.chat.tools.run_block.get_block",
return_value=mock_block,
),
patch(
"backend.copilot.tools.run_block.workspace_db",
return_value=mock_workspace_db,
"backend.api.features.chat.tools.run_block.get_or_create_workspace",
new_callable=AsyncMock,
return_value=MagicMock(id="test-workspace-id"),
),
):
tool = RunBlockTool()
@@ -348,7 +344,7 @@ class TestRunBlockInputValidation:
)
with patch(
"backend.copilot.tools.run_block.get_block",
"backend.api.features.chat.tools.run_block.get_block",
return_value=mock_block,
):
tool = RunBlockTool()

View File

@@ -13,7 +13,6 @@ import logging
import os
import platform
import shutil
import signal
logger = logging.getLogger(__name__)
@@ -246,7 +245,6 @@ async def run_sandboxed(
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
env=safe_env,
start_new_session=True, # Own process group for clean kill
)
try:
@@ -257,18 +255,7 @@ async def run_sandboxed(
stderr = stderr_bytes.decode("utf-8", errors="replace")
return stdout, stderr, proc.returncode or 0, False
except asyncio.TimeoutError:
# Kill entire process group (bwrap + all children).
# proc.kill() alone only kills the bwrap parent, leaving
# children running until they finish naturally.
try:
os.killpg(proc.pid, signal.SIGKILL)
except ProcessLookupError:
pass # Already exited
except OSError as kill_err:
logger.warning(
"Failed to kill process group %d: %s", proc.pid, kill_err
)
# Always reap the subprocess regardless of killpg outcome.
proc.kill()
await proc.communicate()
return "", f"Execution timed out after {timeout}s", -1, True

View File

@@ -5,17 +5,16 @@ from typing import Any
from prisma.enums import ContentType
from backend.copilot.model import ChatSession
from backend.data.db_accessors import search
from .base import BaseTool
from .models import (
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
DocSearchResult,
DocSearchResultsResponse,
ErrorResponse,
NoResultsResponse,
ToolResponseBase,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
logger = logging.getLogger(__name__)
@@ -118,7 +117,7 @@ class SearchDocsTool(BaseTool):
try:
# Search using hybrid search for DOCUMENTATION content type only
results, total = await search().unified_hybrid_search(
results, total = await unified_hybrid_search(
query=query,
content_types=[ContentType.DOCUMENTATION],
page=1,

View File

@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.api.features.chat.tools.models import BlockDetailsResponse
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.blocks._base import BlockType
from backend.data.model import CredentialsMetaInput
from backend.integrations.providers import ProviderName
from ._test_data import make_session
from .models import BlockDetailsResponse
from .run_block import RunBlockTool
_TEST_USER_ID = "test-user-run-block-details"
@@ -61,7 +61,7 @@ async def test_run_block_returns_details_when_no_input_provided():
)
with patch(
"backend.copilot.tools.run_block.get_block",
"backend.api.features.chat.tools.run_block.get_block",
return_value=http_block,
):
# Mock credentials check to return no missing credentials
@@ -120,7 +120,7 @@ async def test_run_block_returns_details_when_only_credentials_provided():
}
with patch(
"backend.copilot.tools.run_block.get_block",
"backend.api.features.chat.tools.run_block.get_block",
return_value=mock,
):
with patch.object(

View File

@@ -3,8 +3,9 @@
import logging
from typing import Any
from backend.api.features.library import db as library_db
from backend.api.features.library import model as library_model
from backend.data.db_accessors import library_db, store_db
from backend.api.features.store import db as store_db
from backend.data.graph import GraphModel
from backend.data.model import (
Credentials,
@@ -38,14 +39,13 @@ async def fetch_graph_from_store_slug(
Raises:
DatabaseError: If there's a database error during lookup.
"""
sdb = store_db()
try:
store_agent = await sdb.get_store_agent_details(username, agent_name)
store_agent = await store_db.get_store_agent_details(username, agent_name)
except NotFoundError:
return None, None
# Get the graph from store listing version
graph = await sdb.get_available_graph(
graph = await store_db.get_available_graph(
store_agent.store_listing_version_id, hide_nodes=False
)
return graph, store_agent
@@ -119,7 +119,7 @@ def build_missing_credentials_from_graph(
preserving all supported credential types for each field.
"""
matched_keys = set(matched_credentials.keys()) if matched_credentials else set()
aggregated_fields = graph.aggregate_credentials_inputs()
aggregated_fields = graph.regular_credentials_inputs
return {
field_key: _serialize_missing_credential(field_key, field_info)
@@ -210,13 +210,13 @@ async def get_or_create_library_agent(
Returns:
LibraryAgent instance
"""
existing = await library_db().get_library_agent_by_graph_id(
existing = await library_db.get_library_agent_by_graph_id(
graph_id=graph.id, user_id=user_id
)
if existing:
return existing
library_agents = await library_db().create_library_agent(
library_agents = await library_db.create_library_agent(
graph=graph,
user_id=user_id,
create_library_agents_for_sub_graphs=False,
@@ -339,7 +339,7 @@ async def match_user_credentials_to_graph(
missing_creds: list[str] = []
# Get aggregated credentials requirements from the graph
aggregated_creds = graph.aggregate_credentials_inputs()
aggregated_creds = graph.regular_credentials_inputs
logger.debug(
f"Matching credentials for graph {graph.id}: {len(aggregated_creds)} required"
)

View File

@@ -0,0 +1,78 @@
"""Tests for chat tools utility functions."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.data.model import CredentialsFieldInfo
def _make_regular_field() -> CredentialsFieldInfo:
return CredentialsFieldInfo.model_validate(
{
"credentials_provider": ["github"],
"credentials_types": ["api_key"],
"is_auto_credential": False,
},
by_alias=True,
)
def test_build_missing_credentials_excludes_auto_creds():
"""
build_missing_credentials_from_graph() should use regular_credentials_inputs
and thus exclude auto_credentials from the "missing" set.
"""
from backend.api.features.chat.tools.utils import (
build_missing_credentials_from_graph,
)
regular_field = _make_regular_field()
mock_graph = MagicMock()
# regular_credentials_inputs should only return the non-auto field
mock_graph.regular_credentials_inputs = {
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
}
result = build_missing_credentials_from_graph(mock_graph, matched_credentials=None)
# Should include the regular credential
assert "github_api_key" in result
# Should NOT include the auto_credential (not in regular_credentials_inputs)
assert "google_oauth2" not in result
@pytest.mark.asyncio
async def test_match_user_credentials_excludes_auto_creds():
"""
match_user_credentials_to_graph() should use regular_credentials_inputs
and thus exclude auto_credentials from matching.
"""
from backend.api.features.chat.tools.utils import match_user_credentials_to_graph
regular_field = _make_regular_field()
mock_graph = MagicMock()
mock_graph.id = "test-graph"
# regular_credentials_inputs returns only non-auto fields
mock_graph.regular_credentials_inputs = {
"github_api_key": (regular_field, {("node-1", "credentials")}, True),
}
# Mock the credentials manager to return no credentials
with patch(
"backend.api.features.chat.tools.utils.IntegrationCredentialsManager"
) as MockCredsMgr:
mock_store = AsyncMock()
mock_store.get_all_creds.return_value = []
MockCredsMgr.return_value.store = mock_store
matched, missing = await match_user_credentials_to_graph(
user_id="test-user", graph=mock_graph
)
# No credentials available, so github should be missing
assert len(matched) == 0
assert len(missing) == 1
assert "github_api_key" in missing[0]

View File

@@ -6,12 +6,15 @@ from typing import Any
import aiohttp
import html2text
from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
ErrorResponse,
ToolResponseBase,
WebFetchResponse,
)
from backend.util.request import Requests
from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, WebFetchResponse
logger = logging.getLogger(__name__)
# Limits

View File

@@ -0,0 +1,626 @@
"""CoPilot tools for workspace file operations."""
import base64
import logging
from typing import Any, Optional
from pydantic import BaseModel
from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
class WorkspaceFileInfoData(BaseModel):
"""Data model for workspace file information (not a response itself)."""
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
class WorkspaceFileListResponse(ToolResponseBase):
"""Response containing list of workspace files."""
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
files: list[WorkspaceFileInfoData]
total_count: int
class WorkspaceFileContentResponse(ToolResponseBase):
"""Response containing workspace file content (legacy, for small text files)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
file_id: str
name: str
path: str
mime_type: str
content_base64: str
class WorkspaceFileMetadataResponse(ToolResponseBase):
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
download_url: str
preview: str | None = None # First 500 chars for text files
class WorkspaceWriteResponse(ToolResponseBase):
"""Response after writing a file to workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
file_id: str
name: str
path: str
size_bytes: int
class WorkspaceDeleteResponse(ToolResponseBase):
"""Response after deleting a file from workspace."""
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
file_id: str
success: bool
class ListWorkspaceFilesTool(BaseTool):
"""Tool for listing files in user's workspace."""
@property
def name(self) -> str:
return "list_workspace_files"
@property
def description(self) -> str:
return (
"List files in the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Read/Glob tools instead. "
"Returns file names, paths, sizes, and metadata. "
"Optionally filter by path prefix."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"path_prefix": {
"type": "string",
"description": (
"Optional path prefix to filter files "
"(e.g., '/documents/' to list only files in documents folder). "
"By default, only files from the current session are listed."
),
},
"limit": {
"type": "integer",
"description": "Maximum number of files to return (default 50, max 100)",
"minimum": 1,
"maximum": 100,
},
"include_all_sessions": {
"type": "boolean",
"description": (
"If true, list files from all sessions. "
"Default is false (only current session's files)."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
path_prefix: Optional[str] = kwargs.get("path_prefix")
limit = min(kwargs.get("limit", 50), 100)
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
files = await manager.list_files(
path=path_prefix,
limit=limit,
include_all_sessions=include_all_sessions,
)
total = await manager.get_file_count(
path=path_prefix,
include_all_sessions=include_all_sessions,
)
file_infos = [
WorkspaceFileInfoData(
file_id=f.id,
name=f.name,
path=f.path,
mime_type=f.mimeType,
size_bytes=f.sizeBytes,
)
for f in files
]
scope_msg = "all sessions" if include_all_sessions else "current session"
return WorkspaceFileListResponse(
files=file_infos,
total_count=total,
message=f"Found {len(files)} files in workspace ({scope_msg})",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error listing workspace files: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to list workspace files: {str(e)}",
error=str(e),
session_id=session_id,
)
class ReadWorkspaceFileTool(BaseTool):
"""Tool for reading file content from workspace."""
# Size threshold for returning full content vs metadata+URL
# Files larger than this return metadata with download URL to prevent context bloat
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
# Preview size for text files
PREVIEW_SIZE = 500
@property
def name(self) -> str:
return "read_workspace_file"
@property
def description(self) -> str:
return (
"Read a file from the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Read tool instead. "
"Specify either file_id or path to identify the file. "
"For small text files, returns content directly. "
"For large or binary files, returns metadata and a download URL. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
"force_download_url": {
"type": "boolean",
"description": (
"If true, always return metadata+URL instead of inline content. "
"Default is false (auto-selects based on file size/type)."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
def _is_text_mime_type(self, mime_type: str) -> bool:
"""Check if the MIME type is a text-based type."""
text_types = [
"text/",
"application/json",
"application/xml",
"application/javascript",
"application/x-python",
"application/x-sh",
]
return any(mime_type.startswith(t) for t in text_types)
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
force_download_url: bool = kwargs.get("force_download_url", False)
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Get file info
if file_id:
file_info = await manager.get_file_info(file_id)
if file_info is None:
return ErrorResponse(
message=f"File not found: {file_id}",
session_id=session_id,
)
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
# Decide whether to return inline content or metadata+URL
is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
is_text_file = self._is_text_mime_type(file_info.mimeType)
# Return inline content for small text files (unless force_download_url)
if is_small_file and is_text_file and not force_download_url:
content = await manager.read_file_by_id(target_file_id)
content_b64 = base64.b64encode(content).decode("utf-8")
return WorkspaceFileContentResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
content_base64=content_b64,
message=f"Successfully read file: {file_info.name}",
session_id=session_id,
)
# Return metadata + workspace:// reference for large or binary files
# This prevents context bloat (100KB file = ~133KB as base64)
# Use workspace:// format so frontend urlTransform can add proxy prefix
download_url = f"workspace://{target_file_id}"
# Generate preview for text files
preview: str | None = None
if is_text_file:
try:
content = await manager.read_file_by_id(target_file_id)
preview_text = content[: self.PREVIEW_SIZE].decode(
"utf-8", errors="replace"
)
if len(content) > self.PREVIEW_SIZE:
preview_text += "..."
preview = preview_text
except Exception:
pass # Preview is optional
return WorkspaceFileMetadataResponse(
file_id=file_info.id,
name=file_info.name,
path=file_info.path,
mime_type=file_info.mimeType,
size_bytes=file_info.sizeBytes,
download_url=download_url,
preview=preview,
message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
session_id=session_id,
)
except FileNotFoundError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error reading workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to read workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class WriteWorkspaceFileTool(BaseTool):
"""Tool for writing files to workspace."""
@property
def name(self) -> str:
return "write_workspace_file"
@property
def description(self) -> str:
return (
"Write or create a file in the user's persistent workspace (cloud storage). "
"These files survive across sessions. "
"For ephemeral session files, use the SDK Write tool instead. "
"Provide the content as a base64-encoded string. "
f"Maximum file size is {Config().max_file_size_mb}MB. "
"Files are saved to the current session's folder by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "Name for the file (e.g., 'report.pdf')",
},
"content_base64": {
"type": "string",
"description": "Base64-encoded file content",
},
"path": {
"type": "string",
"description": (
"Optional virtual path where to save the file "
"(e.g., '/documents/report.pdf'). "
"Defaults to '/{filename}'. Scoped to current session."
),
},
"mime_type": {
"type": "string",
"description": (
"Optional MIME type of the file. "
"Auto-detected from filename if not provided."
),
},
"overwrite": {
"type": "boolean",
"description": "Whether to overwrite if file exists at path (default: false)",
},
},
"required": ["filename", "content_base64"],
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
filename: str = kwargs.get("filename", "")
content_b64: str = kwargs.get("content_base64", "")
path: Optional[str] = kwargs.get("path")
mime_type: Optional[str] = kwargs.get("mime_type")
overwrite: bool = kwargs.get("overwrite", False)
if not filename:
return ErrorResponse(
message="Please provide a filename",
session_id=session_id,
)
if not content_b64:
return ErrorResponse(
message="Please provide content_base64",
session_id=session_id,
)
# Decode content
try:
content = base64.b64decode(content_b64)
except Exception:
return ErrorResponse(
message="Invalid base64-encoded content",
session_id=session_id,
)
# Check size
max_file_size = Config().max_file_size_mb * 1024 * 1024
if len(content) > max_file_size:
return ErrorResponse(
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
session_id=session_id,
)
try:
# Virus scan
await scan_content_safe(content, filename=filename)
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
file_record = await manager.write_file(
content=content,
filename=filename,
path=path,
mime_type=mime_type,
overwrite=overwrite,
)
return WorkspaceWriteResponse(
file_id=file_record.id,
name=file_record.name,
path=file_record.path,
size_bytes=file_record.sizeBytes,
message=f"Successfully wrote file: {file_record.name}",
session_id=session_id,
)
except ValueError as e:
return ErrorResponse(
message=str(e),
session_id=session_id,
)
except Exception as e:
logger.error(f"Error writing workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to write workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)
class DeleteWorkspaceFileTool(BaseTool):
"""Tool for deleting files from workspace."""
@property
def name(self) -> str:
return "delete_workspace_file"
@property
def description(self) -> str:
return (
"Delete a file from the user's persistent workspace (cloud storage). "
"Specify either file_id or path to identify the file. "
"Paths are scoped to the current session by default. "
"Use /sessions/<session_id>/... for cross-session access."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"file_id": {
"type": "string",
"description": "The file's unique ID (from list_workspace_files)",
},
"path": {
"type": "string",
"description": (
"The virtual file path (e.g., '/documents/report.pdf'). "
"Scoped to current session by default."
),
},
},
"required": [], # At least one must be provided
}
@property
def requires_auth(self) -> bool:
return True
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
session_id = session.session_id
if not user_id:
return ErrorResponse(
message="Authentication required",
session_id=session_id,
)
file_id: Optional[str] = kwargs.get("file_id")
path: Optional[str] = kwargs.get("path")
if not file_id and not path:
return ErrorResponse(
message="Please provide either file_id or path",
session_id=session_id,
)
try:
workspace = await get_or_create_workspace(user_id)
# Pass session_id for session-scoped file access
manager = WorkspaceManager(user_id, workspace.id, session_id)
# Determine the file_id to delete
target_file_id: str
if file_id:
target_file_id = file_id
else:
# path is guaranteed to be non-None here due to the check above
assert path is not None
file_info = await manager.get_file_info_by_path(path)
if file_info is None:
return ErrorResponse(
message=f"File not found at path: {path}",
session_id=session_id,
)
target_file_id = file_info.id
success = await manager.delete_file(target_file_id)
if not success:
return ErrorResponse(
message=f"File not found: {target_file_id}",
session_id=session_id,
)
return WorkspaceDeleteResponse(
file_id=target_file_id,
success=True,
message="File deleted successfully",
session_id=session_id,
)
except Exception as e:
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
return ErrorResponse(
message=f"Failed to delete workspace file: {str(e)}",
error=str(e),
session_id=session_id,
)

File diff suppressed because it is too large Load Diff

View File

@@ -144,7 +144,6 @@ async def test_add_agent_to_library(mocker):
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
@@ -179,6 +178,7 @@ async def test_add_agent_to_library(mocker):
"agentGraphVersion": 1,
}
},
include={"AgentGraph": True},
)
# Check that create was called with the expected data including settings
create_call_args = mock_library_agent.return_value.create.call_args

View File

@@ -1,10 +0,0 @@
class FolderValidationError(Exception):
"""Raised when folder operations fail validation."""
pass
class FolderAlreadyExistsError(FolderValidationError):
"""Raised when a folder with the same name already exists in the location."""
pass

View File

@@ -26,95 +26,6 @@ class LibraryAgentStatus(str, Enum):
ERROR = "ERROR"
# === Folder Models ===
class LibraryFolder(pydantic.BaseModel):
"""Represents a folder for organizing library agents."""
id: str
user_id: str
name: str
icon: str | None = None
color: str | None = None
parent_id: str | None = None
created_at: datetime.datetime
updated_at: datetime.datetime
agent_count: int = 0 # Direct agents in folder
subfolder_count: int = 0 # Direct child folders
@staticmethod
def from_db(
folder: prisma.models.LibraryFolder,
agent_count: int = 0,
subfolder_count: int = 0,
) -> "LibraryFolder":
"""Factory method that constructs a LibraryFolder from a Prisma model."""
return LibraryFolder(
id=folder.id,
user_id=folder.userId,
name=folder.name,
icon=folder.icon,
color=folder.color,
parent_id=folder.parentId,
created_at=folder.createdAt,
updated_at=folder.updatedAt,
agent_count=agent_count,
subfolder_count=subfolder_count,
)
class LibraryFolderTree(LibraryFolder):
"""Folder with nested children for tree view."""
children: list["LibraryFolderTree"] = []
class FolderCreateRequest(pydantic.BaseModel):
"""Request model for creating a folder."""
name: str = pydantic.Field(..., min_length=1, max_length=100)
icon: str | None = None
color: str | None = pydantic.Field(
None, pattern=r"^#[0-9A-Fa-f]{6}$", description="Hex color code (#RRGGBB)"
)
parent_id: str | None = None
class FolderUpdateRequest(pydantic.BaseModel):
"""Request model for updating a folder."""
name: str | None = pydantic.Field(None, min_length=1, max_length=100)
icon: str | None = None
color: str | None = None
class FolderMoveRequest(pydantic.BaseModel):
"""Request model for moving a folder to a new parent."""
target_parent_id: str | None = None # None = move to root
class BulkMoveAgentsRequest(pydantic.BaseModel):
"""Request model for moving multiple agents to a folder."""
agent_ids: list[str]
folder_id: str | None = None # None = move to root
class FolderListResponse(pydantic.BaseModel):
"""Response schema for a list of folders."""
folders: list[LibraryFolder]
pagination: Pagination
class FolderTreeResponse(pydantic.BaseModel):
"""Response schema for folder tree structure."""
tree: list[LibraryFolderTree]
class MarketplaceListingCreator(pydantic.BaseModel):
"""Creator information for a marketplace listing."""
@@ -209,9 +120,6 @@ class LibraryAgent(pydantic.BaseModel):
can_access_graph: bool
is_latest_version: bool
is_favorite: bool
folder_id: str | None = None
folder_name: str | None = None # Denormalized for display
recommended_schedule_cron: str | None = None
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
marketplace_listing: Optional["MarketplaceListing"] = None
@@ -351,8 +259,6 @@ class LibraryAgent(pydantic.BaseModel):
can_access_graph=can_access_graph,
is_latest_version=is_latest_version,
is_favorite=agent.isFavorite,
folder_id=agent.folderId,
folder_name=agent.Folder.name if agent.Folder else None,
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
settings=_parse_settings(agent.settings),
marketplace_listing=marketplace_listing_data,
@@ -564,7 +470,3 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
settings: Optional[GraphSettings] = pydantic.Field(
default=None, description="User-specific settings for this library agent"
)
folder_id: Optional[str] = pydantic.Field(
default=None,
description="Folder ID to move agent to (None to move to root)",
)

View File

@@ -1,11 +1,9 @@
import fastapi
from .agents import router as agents_router
from .folders import router as folders_router
from .presets import router as presets_router
router = fastapi.APIRouter()
router.include_router(presets_router)
router.include_router(folders_router)
router.include_router(agents_router)

View File

@@ -41,14 +41,6 @@ async def list_library_agents(
ge=1,
description="Number of agents per page (must be >= 1)",
),
folder_id: Optional[str] = Query(
None,
description="Filter by folder ID",
),
include_root_only: bool = Query(
False,
description="Only return agents without a folder (root-level agents)",
),
) -> library_model.LibraryAgentResponse:
"""
Get all agents in the user's library (both created and saved).
@@ -59,8 +51,6 @@ async def list_library_agents(
sort_by=sort_by,
page=page,
page_size=page_size,
folder_id=folder_id,
include_root_only=include_root_only,
)
@@ -178,7 +168,6 @@ async def update_library_agent(
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
settings=payload.settings,
folder_id=payload.folder_id,
)

View File

@@ -1,287 +0,0 @@
from typing import Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Query, Security, status
from fastapi.responses import Response
from .. import db as library_db
from .. import model as library_model
router = APIRouter(
prefix="/folders",
tags=["library", "folders", "private"],
dependencies=[Security(autogpt_auth_lib.requires_user)],
)
@router.get(
"",
summary="List Library Folders",
response_model=library_model.FolderListResponse,
responses={
200: {"description": "List of folders"},
500: {"description": "Server error"},
},
)
async def list_folders(
user_id: str = Security(autogpt_auth_lib.get_user_id),
parent_id: Optional[str] = Query(
None,
description="Filter by parent folder ID. If not provided, returns root-level folders.",
),
include_relations: bool = Query(
True,
description="Include agent and subfolder relations (for counts)",
),
) -> library_model.FolderListResponse:
"""
List folders for the authenticated user.
Args:
user_id: ID of the authenticated user.
parent_id: Optional parent folder ID to filter by.
include_relations: Whether to include agent and subfolder relations for counts.
Returns:
A FolderListResponse containing folders.
"""
folders = await library_db.list_folders(
user_id=user_id,
parent_id=parent_id,
include_relations=include_relations,
)
return library_model.FolderListResponse(
folders=folders,
pagination=library_model.Pagination(
total_items=len(folders),
total_pages=1,
current_page=1,
page_size=len(folders),
),
)
@router.get(
"/tree",
summary="Get Folder Tree",
response_model=library_model.FolderTreeResponse,
responses={
200: {"description": "Folder tree structure"},
500: {"description": "Server error"},
},
)
async def get_folder_tree(
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.FolderTreeResponse:
"""
Get the full folder tree for the authenticated user.
Args:
user_id: ID of the authenticated user.
Returns:
A FolderTreeResponse containing the nested folder structure.
"""
tree = await library_db.get_folder_tree(user_id=user_id)
return library_model.FolderTreeResponse(tree=tree)
@router.get(
"/{folder_id}",
summary="Get Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder details"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def get_folder(
folder_id: str,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Get a specific folder.
Args:
folder_id: ID of the folder to retrieve.
user_id: ID of the authenticated user.
Returns:
The requested LibraryFolder.
"""
return await library_db.get_folder(folder_id=folder_id, user_id=user_id)
@router.post(
"",
summary="Create Folder",
status_code=status.HTTP_201_CREATED,
response_model=library_model.LibraryFolder,
responses={
201: {"description": "Folder created successfully"},
400: {"description": "Validation error"},
404: {"description": "Parent folder not found"},
409: {"description": "Folder name conflict"},
500: {"description": "Server error"},
},
)
async def create_folder(
payload: library_model.FolderCreateRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Create a new folder.
Args:
payload: The folder creation request.
user_id: ID of the authenticated user.
Returns:
The created LibraryFolder.
"""
return await library_db.create_folder(
user_id=user_id,
name=payload.name,
parent_id=payload.parent_id,
icon=payload.icon,
color=payload.color,
)
@router.patch(
"/{folder_id}",
summary="Update Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder updated successfully"},
400: {"description": "Validation error"},
404: {"description": "Folder not found"},
409: {"description": "Folder name conflict"},
500: {"description": "Server error"},
},
)
async def update_folder(
folder_id: str,
payload: library_model.FolderUpdateRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Update a folder's properties.
Args:
folder_id: ID of the folder to update.
payload: The folder update request.
user_id: ID of the authenticated user.
Returns:
The updated LibraryFolder.
"""
return await library_db.update_folder(
folder_id=folder_id,
user_id=user_id,
name=payload.name,
icon=payload.icon,
color=payload.color,
)
@router.post(
"/{folder_id}/move",
summary="Move Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder moved successfully"},
400: {"description": "Validation error (circular reference)"},
404: {"description": "Folder or target parent not found"},
409: {"description": "Folder name conflict in target location"},
500: {"description": "Server error"},
},
)
async def move_folder(
folder_id: str,
payload: library_model.FolderMoveRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Move a folder to a new parent.
Args:
folder_id: ID of the folder to move.
payload: The move request with target parent.
user_id: ID of the authenticated user.
Returns:
The moved LibraryFolder.
"""
return await library_db.move_folder(
folder_id=folder_id,
user_id=user_id,
target_parent_id=payload.target_parent_id,
)
@router.delete(
"/{folder_id}",
summary="Delete Folder",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "Folder deleted successfully"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def delete_folder(
folder_id: str,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> Response:
"""
Soft-delete a folder and all its contents.
Args:
folder_id: ID of the folder to delete.
user_id: ID of the authenticated user.
Returns:
204 No Content if successful.
"""
await library_db.delete_folder(
folder_id=folder_id,
user_id=user_id,
soft_delete=True,
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
# === Bulk Agent Operations ===
@router.post(
"/agents/bulk-move",
summary="Bulk Move Agents",
response_model=list[library_model.LibraryAgent],
responses={
200: {"description": "Agents moved successfully"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def bulk_move_agents(
payload: library_model.BulkMoveAgentsRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> list[library_model.LibraryAgent]:
"""
Move multiple agents to a folder.
Args:
payload: The bulk move request with agent IDs and target folder.
user_id: ID of the authenticated user.
Returns:
The updated LibraryAgents.
"""
return await library_db.bulk_move_agents_to_folder(
agent_ids=payload.agent_ids,
folder_id=payload.folder_id,
user_id=user_id,
)

View File

@@ -115,8 +115,6 @@ async def test_get_library_agents_success(
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
page=1,
page_size=15,
folder_id=None,
include_root_only=False,
)

View File

@@ -9,26 +9,15 @@ import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, get_args, get_origin
from typing import Any
from prisma.enums import ContentType
from backend.blocks.llm import LlmModel
from backend.data.db import query_raw_with_schema
logger = logging.getLogger(__name__)
def _contains_type(annotation: Any, target: type) -> bool:
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
if annotation is target:
return True
origin = get_origin(annotation)
if origin is None:
return False
return any(_contains_type(arg, target) for arg in get_args(annotation))
@dataclass
class ContentItem:
"""Represents a piece of content to be embedded."""
@@ -199,51 +188,45 @@ class BlockHandler(ContentHandler):
try:
block_instance = block_cls()
# Skip disabled blocks - they shouldn't be indexed
if block_instance.disabled:
continue
# Build searchable text from block metadata
parts = []
if block_instance.name:
if hasattr(block_instance, "name") and block_instance.name:
parts.append(block_instance.name)
if block_instance.description:
if (
hasattr(block_instance, "description")
and block_instance.description
):
parts.append(block_instance.description)
if block_instance.categories:
if hasattr(block_instance, "categories") and block_instance.categories:
# Convert BlockCategory enum to strings
parts.append(
" ".join(str(cat.value) for cat in block_instance.categories)
)
# Add input schema field descriptions
block_input_fields = block_instance.input_schema.model_fields
parts += [
f"{field_name}: {field_info.description}"
for field_name, field_info in block_input_fields.items()
if field_info.description
]
# Add input/output schema info
if hasattr(block_instance, "input_schema"):
schema = block_instance.input_schema
if hasattr(schema, "model_json_schema"):
schema_dict = schema.model_json_schema()
if "properties" in schema_dict:
for prop_name, prop_info in schema_dict[
"properties"
].items():
if "description" in prop_info:
parts.append(
f"{prop_name}: {prop_info['description']}"
)
searchable_text = " ".join(parts)
# Convert categories set of enums to list of strings for JSON serialization
categories = getattr(block_instance, "categories", set())
categories_list = (
[cat.value for cat in block_instance.categories]
if block_instance.categories
else []
)
# Extract provider names from credentials fields
credentials_info = (
block_instance.input_schema.get_credentials_fields_info()
)
is_integration = len(credentials_info) > 0
provider_names = [
provider.value.lower()
for info in credentials_info.values()
for provider in info.provider
]
# Check if block has LlmModel field in input schema
has_llm_model_field = any(
_contains_type(field.annotation, LlmModel)
for field in block_instance.input_schema.model_fields.values()
[cat.value for cat in categories] if categories else []
)
items.append(
@@ -252,11 +235,8 @@ class BlockHandler(ContentHandler):
content_type=ContentType.BLOCK,
searchable_text=searchable_text,
metadata={
"name": block_instance.name,
"name": getattr(block_instance, "name", ""),
"categories": categories_list,
"providers": provider_names,
"has_llm_model_field": has_llm_model_field,
"is_integration": is_integration,
},
user_id=None, # Blocks are public
)

View File

@@ -82,10 +82,9 @@ async def test_block_handler_get_missing_items(mocker):
mock_block_instance.description = "Performs calculations"
mock_block_instance.categories = [MagicMock(value="MATH")]
mock_block_instance.disabled = False
mock_field = MagicMock()
mock_field.description = "Math expression to evaluate"
mock_block_instance.input_schema.model_fields = {"expression": mock_field}
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
mock_block_instance.input_schema.model_json_schema.return_value = {
"properties": {"expression": {"description": "Math expression to evaluate"}}
}
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-uuid-1": mock_block_class}
@@ -310,19 +309,19 @@ async def test_content_handlers_registry():
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_handles_empty_attributes():
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
async def test_block_handler_handles_missing_attributes():
"""Test BlockHandler gracefully handles blocks with missing attributes."""
handler = BlockHandler()
# Mock block with empty values (all attributes exist but are falsy)
# Mock block with minimal attributes
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Minimal Block"
mock_block_instance.disabled = False
mock_block_instance.description = ""
mock_block_instance.categories = set()
mock_block_instance.input_schema.model_fields = {}
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
# No description, categories, or schema
del mock_block_instance.description
del mock_block_instance.categories
del mock_block_instance.input_schema
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-minimal": mock_block_class}
@@ -353,8 +352,6 @@ async def test_block_handler_skips_failed_blocks():
good_instance.description = "Works fine"
good_instance.categories = []
good_instance.disabled = False
good_instance.input_schema.model_fields = {}
good_instance.input_schema.get_credentials_fields_info.return_value = {}
good_block.return_value = good_instance
bad_block = MagicMock()

View File

@@ -126,9 +126,6 @@ v1_router = APIRouter()
########################################################
_tally_background_tasks: set[asyncio.Task] = set()
@v1_router.post(
"/auth/user",
summary="Get or create user",
@@ -137,24 +134,6 @@ _tally_background_tasks: set[asyncio.Task] = set()
)
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
user = await get_or_create_user(user_data)
# Fire-and-forget: populate business understanding from Tally form.
# We use created_at proximity instead of an is_new flag because
# get_or_create_user is cached — a separate is_new return value would be
# unreliable on repeated calls within the cache TTL.
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
if age_seconds < 30:
try:
from backend.data.tally import populate_understanding_from_tally
task = asyncio.create_task(
populate_understanding_from_tally(user.id, user.email)
)
_tally_background_tasks.add(task)
task.add_done_callback(_tally_background_tasks.discard)
except Exception:
logger.debug("Failed to start Tally population task", exc_info=True)
return user.model_dump()

View File

@@ -1,5 +1,5 @@
import json
from datetime import datetime, timezone
from datetime import datetime
from io import BytesIO
from unittest.mock import AsyncMock, Mock, patch
@@ -43,7 +43,6 @@ def test_get_or_create_user_route(
) -> None:
"""Test get or create user endpoint"""
mock_user = Mock()
mock_user.created_at = datetime.now(timezone.utc)
mock_user.model_dump.return_value = {
"id": test_user_id,
"email": "test@example.com",

View File

@@ -11,7 +11,7 @@ import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi.responses import Response
from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
from backend.data.workspace import get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage
@@ -44,11 +44,11 @@ router = fastapi.APIRouter(
)
def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
def _create_streaming_response(content: bytes, file) -> Response:
"""Create a streaming response for file content."""
return Response(
content=content,
media_type=file.mime_type,
media_type=file.mimeType,
headers={
"Content-Disposition": _sanitize_filename_for_header(file.name),
"Content-Length": str(len(content)),
@@ -56,7 +56,7 @@ def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
)
async def _create_file_download_response(file: WorkspaceFile) -> Response:
async def _create_file_download_response(file) -> Response:
"""
Create a download response for a workspace file.
@@ -66,33 +66,33 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
storage = await get_workspace_storage()
# For local storage, stream the file directly
if file.storage_path.startswith("local://"):
content = await storage.retrieve(file.storage_path)
if file.storagePath.startswith("local://"):
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
# For GCS, try to redirect to signed URL, fall back to streaming
try:
url = await storage.get_download_url(file.storage_path, expires_in=300)
url = await storage.get_download_url(file.storagePath, expires_in=300)
# If we got back an API path (fallback), stream directly instead
if url.startswith("/api/"):
content = await storage.retrieve(file.storage_path)
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
return fastapi.responses.RedirectResponse(url=url, status_code=302)
except Exception as e:
# Log the signed URL failure with context
logger.error(
f"Failed to get signed URL for file {file.id} "
f"(storagePath={file.storage_path}): {e}",
f"(storagePath={file.storagePath}): {e}",
exc_info=True,
)
# Fall back to streaming directly from GCS
try:
content = await storage.retrieve(file.storage_path)
content = await storage.retrieve(file.storagePath)
return _create_streaming_response(content, file)
except Exception as fallback_error:
logger.error(
f"Fallback streaming also failed for file {file.id} "
f"(storagePath={file.storage_path}): {fallback_error}",
f"(storagePath={file.storagePath}): {fallback_error}",
exc_info=True,
)
raise

View File

@@ -41,9 +41,9 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.api.features.library.exceptions import (
FolderAlreadyExistsError,
FolderValidationError,
from backend.api.features.chat.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.data.model import Credentials
@@ -123,9 +123,21 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for Redis Streams notifications
try:
await start_completion_consumer()
except Exception as e:
logger.warning(f"Could not start chat completion consumer: {e}")
with launch_darkly_context():
yield
# Stop chat completion consumer
try:
await stop_completion_consumer()
except Exception as e:
logger.warning(f"Error stopping chat completion consumer: {e}")
try:
await shutdown_cloud_storage_handler()
except Exception as e:
@@ -265,10 +277,6 @@ async def validation_error_handler(
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
app.add_exception_handler(
FolderAlreadyExistsError, handle_internal_http_error(409, False)
)
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
app.add_exception_handler(RequestValidationError, validation_error_handler)

View File

@@ -24,7 +24,7 @@ def run_processes(*processes: "AppProcess", **kwargs):
# Run the last process in the foreground.
processes[-1].start(background=False, **kwargs)
finally:
for process in reversed(processes):
for process in processes:
try:
process.stop()
except Exception as e:
@@ -38,9 +38,7 @@ def main(**kwargs):
from backend.api.rest_api import AgentServer
from backend.api.ws_api import WebsocketServer
from backend.copilot.executor.manager import CoPilotExecutor
from backend.data.db_manager import DatabaseManager
from backend.executor import ExecutionManager, Scheduler
from backend.executor import DatabaseManager, ExecutionManager, Scheduler
from backend.notifications import NotificationManager
run_processes(
@@ -50,7 +48,6 @@ def main(**kwargs):
WebsocketServer(),
AgentServer(),
ExecutionManager(),
CoPilotExecutor(),
**kwargs,
)

View File

@@ -310,6 +310,8 @@ class BlockSchema(BaseModel):
"credentials_provider": [config.get("provider", "google")],
"credentials_types": [config.get("type", "oauth2")],
"credentials_scopes": config.get("scopes"),
"is_auto_credential": True,
"input_field_name": info["field_name"],
}
result[kwarg_name] = CredentialsFieldInfo.model_validate(
auto_schema, by_alias=True

View File

@@ -1,182 +0,0 @@
"""
Telegram Bot API helper functions.
Provides utilities for making authenticated requests to the Telegram Bot API.
"""
import logging
from io import BytesIO
from typing import Any, Optional
from pydantic import BaseModel
from backend.data.model import APIKeyCredentials
from backend.util.request import Requests
logger = logging.getLogger(__name__)
TELEGRAM_API_BASE = "https://api.telegram.org"
class TelegramMessageResult(BaseModel, extra="allow"):
"""Result from Telegram send/edit message API calls."""
message_id: int = 0
chat: dict[str, Any] = {}
date: int = 0
text: str = ""
class TelegramFileResult(BaseModel, extra="allow"):
"""Result from Telegram getFile API call."""
file_id: str = ""
file_unique_id: str = ""
file_size: int = 0
file_path: str = ""
class TelegramAPIException(ValueError):
"""Exception raised for Telegram API errors."""
def __init__(self, message: str, error_code: int = 0):
super().__init__(message)
self.error_code = error_code
def get_bot_api_url(bot_token: str, method: str) -> str:
"""Construct Telegram Bot API URL for a method."""
return f"{TELEGRAM_API_BASE}/bot{bot_token}/{method}"
def get_file_url(bot_token: str, file_path: str) -> str:
"""Construct Telegram file download URL."""
return f"{TELEGRAM_API_BASE}/file/bot{bot_token}/{file_path}"
async def call_telegram_api(
credentials: APIKeyCredentials,
method: str,
data: Optional[dict[str, Any]] = None,
) -> TelegramMessageResult:
"""
Make a request to the Telegram Bot API.
Args:
credentials: Bot token credentials
method: API method name (e.g., "sendMessage", "getFile")
data: Request parameters
Returns:
API response result
Raises:
TelegramAPIException: If the API returns an error
"""
token = credentials.api_key.get_secret_value()
url = get_bot_api_url(token, method)
response = await Requests().post(url, json=data or {})
result = response.json()
if not result.get("ok"):
error_code = result.get("error_code", 0)
description = result.get("description", "Unknown error")
raise TelegramAPIException(description, error_code)
return TelegramMessageResult(**result.get("result", {}))
async def call_telegram_api_with_file(
credentials: APIKeyCredentials,
method: str,
file_field: str,
file_data: bytes,
filename: str,
content_type: str,
data: Optional[dict[str, Any]] = None,
) -> TelegramMessageResult:
"""
Make a multipart/form-data request to the Telegram Bot API with a file upload.
Args:
credentials: Bot token credentials
method: API method name (e.g., "sendPhoto", "sendVoice")
file_field: Form field name for the file (e.g., "photo", "voice")
file_data: Raw file bytes
filename: Filename for the upload
content_type: MIME type of the file
data: Additional form parameters
Returns:
API response result
Raises:
TelegramAPIException: If the API returns an error
"""
token = credentials.api_key.get_secret_value()
url = get_bot_api_url(token, method)
files = [(file_field, (filename, BytesIO(file_data), content_type))]
response = await Requests().post(url, files=files, data=data or {})
result = response.json()
if not result.get("ok"):
error_code = result.get("error_code", 0)
description = result.get("description", "Unknown error")
raise TelegramAPIException(description, error_code)
return TelegramMessageResult(**result.get("result", {}))
async def get_file_info(
credentials: APIKeyCredentials, file_id: str
) -> TelegramFileResult:
"""
Get file information from Telegram.
Args:
credentials: Bot token credentials
file_id: Telegram file_id from message
Returns:
File info dict containing file_id, file_unique_id, file_size, file_path
"""
result = await call_telegram_api(credentials, "getFile", {"file_id": file_id})
return TelegramFileResult(**result.model_dump())
async def get_file_download_url(credentials: APIKeyCredentials, file_id: str) -> str:
"""
Get the download URL for a Telegram file.
Args:
credentials: Bot token credentials
file_id: Telegram file_id from message
Returns:
Full download URL
"""
token = credentials.api_key.get_secret_value()
result = await get_file_info(credentials, file_id)
file_path = result.file_path
if not file_path:
raise TelegramAPIException("No file_path returned from getFile")
return get_file_url(token, file_path)
async def download_telegram_file(credentials: APIKeyCredentials, file_id: str) -> bytes:
"""
Download a file from Telegram servers.
Args:
credentials: Bot token credentials
file_id: Telegram file_id
Returns:
File content as bytes
"""
url = await get_file_download_url(credentials, file_id)
response = await Requests().get(url)
return response.content

View File

@@ -1,43 +0,0 @@
"""
Telegram Bot credentials handling.
Telegram bots use an API key (bot token) obtained from @BotFather.
"""
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
# Bot token credentials (API key style)
TelegramCredentials = APIKeyCredentials
TelegramCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.TELEGRAM], Literal["api_key"]
]
def TelegramCredentialsField() -> TelegramCredentialsInput:
"""Creates a Telegram bot token credentials field."""
return CredentialsField(
description="Telegram Bot API token from @BotFather. "
"Create a bot at https://t.me/BotFather to get your token."
)
# Test credentials for unit tests
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="telegram",
api_key=SecretStr("test_telegram_bot_token"),
title="Mock Telegram Bot Token",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,377 +0,0 @@
"""
Telegram trigger blocks for receiving messages via webhooks.
"""
import logging
from pydantic import BaseModel
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
BlockWebhookConfig,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.telegram import TelegramWebhookType
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TelegramCredentialsField,
TelegramCredentialsInput,
)
logger = logging.getLogger(__name__)
# Example payload for testing
EXAMPLE_MESSAGE_PAYLOAD = {
"update_id": 123456789,
"message": {
"message_id": 1,
"from": {
"id": 12345678,
"is_bot": False,
"first_name": "John",
"last_name": "Doe",
"username": "johndoe",
"language_code": "en",
},
"chat": {
"id": 12345678,
"first_name": "John",
"last_name": "Doe",
"username": "johndoe",
"type": "private",
},
"date": 1234567890,
"text": "Hello, bot!",
},
}
class TelegramTriggerBase:
"""Base class for Telegram trigger blocks."""
class Input(BlockSchemaInput):
credentials: TelegramCredentialsInput = TelegramCredentialsField()
payload: dict = SchemaField(hidden=True, default_factory=dict)
class TelegramMessageTriggerBlock(TelegramTriggerBase, Block):
"""
Triggers when a message is received or edited in your Telegram bot.
Supports text, photos, voice messages, audio files, documents, and videos.
Connect the outputs to other blocks to process messages and send responses.
"""
class Input(TelegramTriggerBase.Input):
class EventsFilter(BaseModel):
"""Filter for message types to receive."""
text: bool = True
photo: bool = False
voice: bool = False
audio: bool = False
document: bool = False
video: bool = False
edited_message: bool = False
events: EventsFilter = SchemaField(
title="Message Types", description="Types of messages to receive"
)
class Output(BlockSchemaOutput):
payload: dict = SchemaField(
description="The complete webhook payload from Telegram"
)
chat_id: int = SchemaField(
description="The chat ID where the message was received. "
"Use this to send replies."
)
message_id: int = SchemaField(description="The unique message ID")
user_id: int = SchemaField(description="The user ID who sent the message")
username: str = SchemaField(description="Username of the sender (may be empty)")
first_name: str = SchemaField(description="First name of the sender")
event: str = SchemaField(
description="The message type (text, photo, voice, audio, etc.)"
)
text: str = SchemaField(
description="Text content of the message (for text messages)"
)
photo_file_id: str = SchemaField(
description="File ID of the photo (for photo messages). "
"Use GetTelegramFileBlock to download."
)
voice_file_id: str = SchemaField(
description="File ID of the voice message (for voice messages). "
"Use GetTelegramFileBlock to download."
)
audio_file_id: str = SchemaField(
description="File ID of the audio file (for audio messages). "
"Use GetTelegramFileBlock to download."
)
file_id: str = SchemaField(
description="File ID for document/video messages. "
"Use GetTelegramFileBlock to download."
)
file_name: str = SchemaField(
description="Original filename (for document/audio messages)"
)
caption: str = SchemaField(description="Caption for media messages")
is_edited: bool = SchemaField(
description="Whether this is an edit of a previously sent message"
)
def __init__(self):
super().__init__(
id="4435e4e0-df6e-4301-8f35-ad70b12fc9ec",
description="Triggers when a message is received or edited in your Telegram bot. "
"Supports text, photos, voice messages, audio files, documents, and videos.",
categories={BlockCategory.SOCIAL},
input_schema=TelegramMessageTriggerBlock.Input,
output_schema=TelegramMessageTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.TELEGRAM,
webhook_type=TelegramWebhookType.BOT,
resource_format="bot",
event_filter_input="events",
event_format="message.{event}",
),
test_input={
"events": {"text": True, "photo": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": EXAMPLE_MESSAGE_PAYLOAD,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", EXAMPLE_MESSAGE_PAYLOAD),
("chat_id", 12345678),
("message_id", 1),
("user_id", 12345678),
("username", "johndoe"),
("first_name", "John"),
("is_edited", False),
("event", "text"),
("text", "Hello, bot!"),
("photo_file_id", ""),
("voice_file_id", ""),
("audio_file_id", ""),
("file_id", ""),
("file_name", ""),
("caption", ""),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
payload = input_data.payload
is_edited = "edited_message" in payload
message = payload.get("message") or payload.get("edited_message", {})
# Extract common fields
chat = message.get("chat", {})
sender = message.get("from", {})
yield "payload", payload
yield "chat_id", chat.get("id", 0)
yield "message_id", message.get("message_id", 0)
yield "user_id", sender.get("id", 0)
yield "username", sender.get("username", "")
yield "first_name", sender.get("first_name", "")
yield "is_edited", is_edited
# For edited messages, yield event as "edited_message" and extract
# all content fields from the edited message body
if is_edited:
yield "event", "edited_message"
yield "text", message.get("text", "")
photos = message.get("photo", [])
yield "photo_file_id", photos[-1].get("file_id", "") if photos else ""
voice = message.get("voice", {})
yield "voice_file_id", voice.get("file_id", "")
audio = message.get("audio", {})
yield "audio_file_id", audio.get("file_id", "")
document = message.get("document", {})
video = message.get("video", {})
yield "file_id", (document.get("file_id", "") or video.get("file_id", ""))
yield "file_name", (
document.get("file_name", "") or audio.get("file_name", "")
)
yield "caption", message.get("caption", "")
# Determine message type and extract content
elif "text" in message:
yield "event", "text"
yield "text", message.get("text", "")
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", ""
elif "photo" in message:
# Get the largest photo (last in array)
photos = message.get("photo", [])
photo_fid = photos[-1].get("file_id", "") if photos else ""
yield "event", "photo"
yield "text", ""
yield "photo_file_id", photo_fid
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", message.get("caption", "")
elif "voice" in message:
voice = message.get("voice", {})
yield "event", "voice"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", voice.get("file_id", "")
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", message.get("caption", "")
elif "audio" in message:
audio = message.get("audio", {})
yield "event", "audio"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", audio.get("file_id", "")
yield "file_id", ""
yield "file_name", audio.get("file_name", "")
yield "caption", message.get("caption", "")
elif "document" in message:
document = message.get("document", {})
yield "event", "document"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", document.get("file_id", "")
yield "file_name", document.get("file_name", "")
yield "caption", message.get("caption", "")
elif "video" in message:
video = message.get("video", {})
yield "event", "video"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", video.get("file_id", "")
yield "file_name", video.get("file_name", "")
yield "caption", message.get("caption", "")
else:
yield "event", "other"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", ""
# Example payload for reaction trigger testing
EXAMPLE_REACTION_PAYLOAD = {
"update_id": 123456790,
"message_reaction": {
"chat": {
"id": 12345678,
"first_name": "John",
"last_name": "Doe",
"username": "johndoe",
"type": "private",
},
"message_id": 42,
"user": {
"id": 12345678,
"is_bot": False,
"first_name": "John",
"username": "johndoe",
},
"date": 1234567890,
"new_reaction": [{"type": "emoji", "emoji": "👍"}],
"old_reaction": [],
},
}
class TelegramMessageReactionTriggerBlock(TelegramTriggerBase, Block):
"""
Triggers when a reaction to a message is changed.
Works automatically in private chats. In group chats, the bot must be
an administrator to receive reaction updates.
"""
class Input(TelegramTriggerBase.Input):
pass
class Output(BlockSchemaOutput):
payload: dict = SchemaField(
description="The complete webhook payload from Telegram"
)
chat_id: int = SchemaField(
description="The chat ID where the reaction occurred"
)
message_id: int = SchemaField(description="The message ID that was reacted to")
user_id: int = SchemaField(description="The user ID who changed the reaction")
username: str = SchemaField(description="Username of the user (may be empty)")
new_reactions: list = SchemaField(
description="List of new reactions on the message"
)
old_reactions: list = SchemaField(
description="List of previous reactions on the message"
)
def __init__(self):
super().__init__(
id="82525328-9368-4966-8f0c-cd78e80181fd",
description="Triggers when a reaction to a message is changed. "
"Works in private chats automatically. "
"In groups, the bot must be an administrator.",
categories={BlockCategory.SOCIAL},
input_schema=TelegramMessageReactionTriggerBlock.Input,
output_schema=TelegramMessageReactionTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.TELEGRAM,
webhook_type=TelegramWebhookType.BOT,
resource_format="bot",
event_filter_input="",
event_format="message_reaction",
),
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"payload": EXAMPLE_REACTION_PAYLOAD,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", EXAMPLE_REACTION_PAYLOAD),
("chat_id", 12345678),
("message_id", 42),
("user_id", 12345678),
("username", "johndoe"),
("new_reactions", [{"type": "emoji", "emoji": "👍"}]),
("old_reactions", []),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
payload = input_data.payload
reaction = payload.get("message_reaction", {})
chat = reaction.get("chat", {})
user = reaction.get("user", {})
yield "payload", payload
yield "chat_id", chat.get("id", 0)
yield "message_id", reaction.get("message_id", 0)
yield "user_id", user.get("id", 0)
yield "username", user.get("username", "")
yield "new_reactions", reaction.get("new_reaction", [])
yield "old_reactions", reaction.get("old_reaction", [])

View File

@@ -34,12 +34,10 @@ def main(output: Path, pretty: bool):
"""Generate and output the OpenAPI JSON specification."""
openapi_schema = get_openapi_schema()
json_output = json.dumps(
openapi_schema, indent=2 if pretty else None, ensure_ascii=False
)
json_output = json.dumps(openapi_schema, indent=2 if pretty else None)
if output:
output.write_text(json_output, encoding="utf-8")
output.write_text(json_output)
click.echo(f"✅ OpenAPI specification written to {output}\n\nPreview:")
click.echo(f"\n{json_output[:500]} ...")
else:

View File

@@ -1,7 +1,6 @@
import logging
import os
import pytest
import pytest_asyncio
from dotenv import load_dotenv
@@ -28,54 +27,6 @@ async def server():
yield server
@pytest.fixture
def test_user_id() -> str:
"""Test user ID fixture."""
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture
def admin_user_id() -> str:
"""Admin user ID fixture."""
return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
@pytest.fixture
def target_user_id() -> str:
"""Target user ID fixture."""
return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
@pytest.fixture
async def setup_test_user(test_user_id):
"""Create test user in database before tests."""
from backend.data.user import get_or_create_user
# Create the test user in the database using JWT token format
user_data = {
"sub": test_user_id,
"email": "test@example.com",
"user_metadata": {"name": "Test User"},
}
await get_or_create_user(user_data)
return test_user_id
@pytest.fixture
async def setup_admin_user(admin_user_id):
"""Create admin user in database before tests."""
from backend.data.user import get_or_create_user
# Create the admin user in the database using JWT token format
user_data = {
"sub": admin_user_id,
"email": "test-admin@example.com",
"user_metadata": {"name": "Test Admin"},
}
await get_or_create_user(user_data)
return admin_user_id
@pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
async def graph_cleanup(server):
created_graph_ids = []

View File

@@ -1,18 +0,0 @@
"""Entry point for running the CoPilot Executor service.
Usage:
python -m backend.copilot.executor
"""
from backend.app import run_processes
from .manager import CoPilotExecutor
def main():
"""Run the CoPilot Executor service."""
run_processes(CoPilotExecutor())
if __name__ == "__main__":
main()

View File

@@ -1,526 +0,0 @@
"""CoPilot Executor Manager - main service for CoPilot task execution.
This module contains the CoPilotExecutor class that consumes chat tasks from
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
"""
import asyncio
import logging
import os
import threading
import time
import uuid
from concurrent.futures import Future, ThreadPoolExecutor
from pika.adapters.blocking_connection import BlockingChannel
from pika.exceptions import AMQPChannelError, AMQPConnectionError
from pika.spec import Basic, BasicProperties
from prometheus_client import Gauge, start_http_server
from backend.data import redis_client as redis
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.cluster_lock import ClusterLock
from backend.util.decorator import error_logged
from backend.util.logging import TruncatedLogger
from backend.util.process import AppProcess
from backend.util.retry import continuous_retry
from backend.util.settings import Settings
from .processor import execute_copilot_turn, init_worker
from .utils import (
COPILOT_CANCEL_QUEUE_NAME,
COPILOT_EXECUTION_QUEUE_NAME,
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
CancelCoPilotEvent,
CoPilotExecutionEntry,
create_copilot_queue_config,
)
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
settings = Settings()
# Prometheus metrics
active_tasks_gauge = Gauge(
"copilot_executor_active_tasks",
"Number of active CoPilot tasks",
)
pool_size_gauge = Gauge(
"copilot_executor_pool_size",
"Maximum number of CoPilot executor workers",
)
utilization_gauge = Gauge(
"copilot_executor_utilization_ratio",
"Ratio of active tasks to pool size",
)
class CoPilotExecutor(AppProcess):
"""CoPilot Executor service for processing chat generation tasks.
This service consumes tasks from RabbitMQ, processes them using a thread pool,
and publishes results to Redis Streams. It follows the graph executor pattern
for reliable message handling and graceful shutdown.
Key features:
- RabbitMQ-based task distribution with manual acknowledgment
- Thread pool executor for concurrent task processing
- Cluster lock for duplicate prevention across pods
- Graceful shutdown with timeout for in-flight tasks
- FANOUT exchange for cancellation broadcast
"""
def __init__(self):
super().__init__()
self.pool_size = settings.config.num_copilot_workers
self.active_tasks: dict[str, tuple[Future, threading.Event]] = {}
self.executor_id = str(uuid.uuid4())
self._executor = None
self._stop_consuming = None
self._cancel_thread = None
self._cancel_client = None
self._run_thread = None
self._run_client = None
self._task_locks: dict[str, ClusterLock] = {}
self._active_tasks_lock = threading.Lock()
# ============ Main Entry Points (AppProcess interface) ============ #
def run(self):
"""Main service loop - consume from RabbitMQ."""
logger.info(f"Pod assigned executor_id: {self.executor_id}")
logger.info(f"Spawn max-{self.pool_size} workers...")
pool_size_gauge.set(self.pool_size)
self._update_metrics()
start_http_server(settings.config.copilot_executor_port)
self.cancel_thread.start()
self.run_thread.start()
while True:
time.sleep(1e5)
def cleanup(self):
"""Graceful shutdown with active execution waiting."""
pid = os.getpid()
logger.info(f"[cleanup {pid}] Starting graceful shutdown...")
# Signal the consumer thread to stop
try:
self.stop_consuming.set()
run_channel = self.run_client.get_channel()
run_channel.connection.add_callback_threadsafe(
lambda: run_channel.stop_consuming()
)
logger.info(f"[cleanup {pid}] Consumer has been signaled to stop")
except Exception as e:
logger.error(f"[cleanup {pid}] Error stopping consumer: {e}")
# Wait for active executions to complete
if self.active_tasks:
logger.info(
f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
)
start_time = time.monotonic()
last_refresh = start_time
lock_refresh_interval = settings.config.cluster_lock_timeout / 10
while (
self.active_tasks
and (time.monotonic() - start_time) < GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS
):
self._cleanup_completed_tasks()
if not self.active_tasks:
break
# Refresh cluster locks periodically
current_time = time.monotonic()
if current_time - last_refresh >= lock_refresh_interval:
for lock in list(self._task_locks.values()):
try:
lock.refresh()
except Exception as e:
logger.warning(
f"[cleanup {pid}] Failed to refresh lock: {e}"
)
last_refresh = current_time
logger.info(
f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..."
)
time.sleep(10.0)
# Stop message consumers
if self._run_thread:
self._stop_message_consumers(
self._run_thread, self.run_client, "[cleanup][run]"
)
if self._cancel_thread:
self._stop_message_consumers(
self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
)
# Clean up worker threads (closes per-loop workspace storage sessions)
if self._executor:
from .processor import cleanup_worker
logger.info(f"[cleanup {pid}] Cleaning up workers...")
futures = []
for _ in range(self._executor._max_workers):
futures.append(self._executor.submit(cleanup_worker))
for f in futures:
try:
f.result(timeout=10)
except Exception as e:
logger.warning(f"[cleanup {pid}] Worker cleanup error: {e}")
logger.info(f"[cleanup {pid}] Shutting down executor...")
self._executor.shutdown(wait=False)
# Release any remaining locks
for session_id, lock in list(self._task_locks.items()):
try:
lock.release()
logger.info(f"[cleanup {pid}] Released lock for {session_id}")
except Exception as e:
logger.error(
f"[cleanup {pid}] Failed to release lock for {session_id}: {e}"
)
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
# ============ RabbitMQ Consumer Methods ============ #
@continuous_retry()
def _consume_cancel(self):
"""Consume cancellation messages from FANOUT exchange."""
if self.stop_consuming.is_set() and not self.active_tasks:
logger.info("Stop reconnecting cancel consumer - service cleaned up")
return
if not self.cancel_client.is_ready:
self.cancel_client.disconnect()
self.cancel_client.connect()
# Check again after connect - shutdown may have been requested
if self.stop_consuming.is_set() and not self.active_tasks:
logger.info("Stop consuming requested during reconnect - disconnecting")
self.cancel_client.disconnect()
return
cancel_channel = self.cancel_client.get_channel()
cancel_channel.basic_consume(
queue=COPILOT_CANCEL_QUEUE_NAME,
on_message_callback=self._handle_cancel_message,
auto_ack=True,
)
logger.info("Starting to consume cancel messages...")
cancel_channel.start_consuming()
if not self.stop_consuming.is_set() or self.active_tasks:
raise RuntimeError("Cancel message consumer stopped unexpectedly")
logger.info("Cancel message consumer stopped gracefully")
@continuous_retry()
def _consume_run(self):
"""Consume run messages from DIRECT exchange."""
if self.stop_consuming.is_set():
logger.info("Stop reconnecting run consumer - service cleaned up")
return
if not self.run_client.is_ready:
self.run_client.disconnect()
self.run_client.connect()
# Check again after connect - shutdown may have been requested
if self.stop_consuming.is_set():
logger.info("Stop consuming requested during reconnect - disconnecting")
self.run_client.disconnect()
return
run_channel = self.run_client.get_channel()
run_channel.basic_qos(prefetch_count=self.pool_size)
run_channel.basic_consume(
queue=COPILOT_EXECUTION_QUEUE_NAME,
on_message_callback=self._handle_run_message,
auto_ack=False,
consumer_tag="copilot_execution_consumer",
)
logger.info("Starting to consume run messages...")
run_channel.start_consuming()
if not self.stop_consuming.is_set():
raise RuntimeError("Run message consumer stopped unexpectedly")
logger.info("Run message consumer stopped gracefully")
# ============ Message Handlers ============ #
@error_logged(swallow=True)
def _handle_cancel_message(
self,
_channel: BlockingChannel,
_method: Basic.Deliver,
_properties: BasicProperties,
body: bytes,
):
"""Handle cancel message from FANOUT exchange."""
request = CancelCoPilotEvent.model_validate_json(body)
session_id = request.session_id
if not session_id:
logger.warning("Cancel message missing 'session_id'")
return
if session_id not in self.active_tasks:
logger.debug(f"Cancel received for {session_id} but not active")
return
_, cancel_event = self.active_tasks[session_id]
logger.info(f"Received cancel for {session_id}")
if not cancel_event.is_set():
cancel_event.set()
else:
logger.debug(f"Cancel already set for {session_id}")
def _handle_run_message(
self,
_channel: BlockingChannel,
method: Basic.Deliver,
_properties: BasicProperties,
body: bytes,
):
"""Handle run message from DIRECT exchange."""
delivery_tag = method.delivery_tag
# Capture the channel used at message delivery time to ensure we ack
# on the correct channel. Delivery tags are channel-scoped and become
# invalid if the channel is recreated after reconnection.
delivery_channel = _channel
def ack_message(reject: bool, requeue: bool):
"""Acknowledge or reject the message.
Uses the channel from the original message delivery. If the channel
is no longer open (e.g., after reconnection), logs a warning and
skips the ack - RabbitMQ will redeliver the message automatically.
"""
try:
if not delivery_channel.is_open:
logger.warning(
f"Channel closed, cannot ack delivery_tag={delivery_tag}. "
"Message will be redelivered by RabbitMQ."
)
return
if reject:
delivery_channel.connection.add_callback_threadsafe(
lambda: delivery_channel.basic_nack(
delivery_tag, requeue=requeue
)
)
else:
delivery_channel.connection.add_callback_threadsafe(
lambda: delivery_channel.basic_ack(delivery_tag)
)
except (AMQPChannelError, AMQPConnectionError) as e:
# Channel/connection errors indicate stale delivery tag - don't retry
logger.warning(
f"Cannot ack delivery_tag={delivery_tag} due to channel/connection "
f"error: {e}. Message will be redelivered by RabbitMQ."
)
except Exception as e:
# Other errors might be transient, but log and skip to avoid blocking
logger.error(
f"Unexpected error acking delivery_tag={delivery_tag}: {e}"
)
# Check if we're shutting down
if self.stop_consuming.is_set():
logger.info("Rejecting new task during shutdown")
ack_message(reject=True, requeue=True)
return
# Check if we can accept more tasks
self._cleanup_completed_tasks()
if len(self.active_tasks) >= self.pool_size:
ack_message(reject=True, requeue=True)
return
try:
entry = CoPilotExecutionEntry.model_validate_json(body)
except Exception as e:
logger.error(f"Could not parse run message: {e}, body={body}")
ack_message(reject=True, requeue=False)
return
session_id = entry.session_id
# Check for local duplicate - session is already running on this executor
if session_id in self.active_tasks:
logger.warning(
f"Session {session_id} already running locally, rejecting duplicate"
)
ack_message(reject=True, requeue=False)
return
# Try to acquire cluster-wide lock
cluster_lock = ClusterLock(
redis=redis.get_redis(),
key=f"copilot:session:{session_id}:lock",
owner_id=self.executor_id,
timeout=settings.config.cluster_lock_timeout,
)
current_owner = cluster_lock.try_acquire()
if current_owner != self.executor_id:
if current_owner is not None:
logger.warning(
f"Session {session_id} already running on pod {current_owner}"
)
ack_message(reject=True, requeue=False)
else:
logger.warning(
f"Could not acquire lock for {session_id} - Redis unavailable"
)
ack_message(reject=True, requeue=True)
return
# Execute the task
try:
self._task_locks[session_id] = cluster_lock
logger.info(
f"Acquired cluster lock for {session_id}, "
f"executor_id={self.executor_id}"
)
cancel_event = threading.Event()
future = self.executor.submit(
execute_copilot_turn, entry, cancel_event, cluster_lock
)
self.active_tasks[session_id] = (future, cancel_event)
except Exception as e:
logger.warning(f"Failed to setup execution for {session_id}: {e}")
cluster_lock.release()
if session_id in self._task_locks:
del self._task_locks[session_id]
ack_message(reject=True, requeue=True)
return
self._update_metrics()
def on_run_done(f: Future):
logger.info(f"Run completed for {session_id}")
error_msg = None
try:
if exec_error := f.exception():
error_msg = str(exec_error) or type(exec_error).__name__
logger.error(f"Execution for {session_id} failed: {error_msg}")
ack_message(reject=True, requeue=False)
else:
ack_message(reject=False, requeue=False)
except asyncio.CancelledError:
logger.info(f"Run completion callback cancelled for {session_id}")
except BaseException as e:
error_msg = str(e) or type(e).__name__
logger.exception(f"Error in run completion callback: {error_msg}")
finally:
# Release the cluster lock
if session_id in self._task_locks:
logger.info(f"Releasing cluster lock for {session_id}")
self._task_locks[session_id].release()
del self._task_locks[session_id]
self._cleanup_completed_tasks()
future.add_done_callback(on_run_done)
# ============ Helper Methods ============ #
def _cleanup_completed_tasks(self) -> list[str]:
"""Remove completed futures from active_tasks and update metrics."""
completed_tasks = []
with self._active_tasks_lock:
for session_id, (future, _) in list(self.active_tasks.items()):
if future.done():
completed_tasks.append(session_id)
self.active_tasks.pop(session_id, None)
logger.info(f"Cleaned up completed session {session_id}")
self._update_metrics()
return completed_tasks
def _update_metrics(self):
"""Update Prometheus metrics."""
active_count = len(self.active_tasks)
active_tasks_gauge.set(active_count)
if self.stop_consuming.is_set():
utilization_gauge.set(1.0)
else:
utilization_gauge.set(
active_count / self.pool_size if self.pool_size > 0 else 0
)
def _stop_message_consumers(
self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str
):
"""Stop a message consumer thread."""
try:
channel = client.get_channel()
channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())
thread.join(timeout=300)
if thread.is_alive():
logger.error(
f"{prefix} Thread did not finish in time, forcing disconnect"
)
client.disconnect()
logger.info(f"{prefix} Client disconnected")
except Exception as e:
logger.error(f"{prefix} Error disconnecting client: {e}")
# ============ Lazy-initialized Properties ============ #
@property
def cancel_thread(self) -> threading.Thread:
if self._cancel_thread is None:
self._cancel_thread = threading.Thread(
target=lambda: self._consume_cancel(),
daemon=True,
)
return self._cancel_thread
@property
def run_thread(self) -> threading.Thread:
if self._run_thread is None:
self._run_thread = threading.Thread(
target=lambda: self._consume_run(),
daemon=True,
)
return self._run_thread
@property
def stop_consuming(self) -> threading.Event:
if self._stop_consuming is None:
self._stop_consuming = threading.Event()
return self._stop_consuming
@property
def executor(self) -> ThreadPoolExecutor:
if self._executor is None:
self._executor = ThreadPoolExecutor(
max_workers=self.pool_size,
initializer=init_worker,
)
return self._executor
@property
def cancel_client(self) -> SyncRabbitMQ:
if self._cancel_client is None:
self._cancel_client = SyncRabbitMQ(create_copilot_queue_config())
return self._cancel_client
@property
def run_client(self) -> SyncRabbitMQ:
if self._run_client is None:
self._run_client = SyncRabbitMQ(create_copilot_queue_config())
return self._run_client

View File

@@ -1,276 +0,0 @@
"""CoPilot execution processor - per-worker execution logic.
This module contains the processor class that handles CoPilot session execution
in a thread-local context, following the graph executor pattern.
"""
import asyncio
import logging
import threading
import time
from backend.copilot import service as copilot_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig
from backend.copilot.response_model import StreamFinish
from backend.copilot.sdk import service as sdk_service
from backend.executor.cluster_lock import ClusterLock
from backend.util.decorator import error_logged
from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.process import set_service_name
from backend.util.retry import func_retry
from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
# ============ Module Entry Points ============ #
# Thread-local storage for processor instances
_tls = threading.local()
def execute_copilot_turn(
entry: CoPilotExecutionEntry,
cancel: threading.Event,
cluster_lock: ClusterLock,
):
"""Execute a single CoPilot turn (user message → AI response).
This function is the entry point called by the thread pool executor.
Args:
entry: The turn payload
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock for this execution
"""
processor: CoPilotProcessor = _tls.processor
return processor.execute(entry, cancel, cluster_lock)
def init_worker():
"""Initialize the processor for the current worker thread.
This function is called by the thread pool executor when a new worker
thread is created. It ensures each worker has its own processor instance.
"""
_tls.processor = CoPilotProcessor()
_tls.processor.on_executor_start()
def cleanup_worker():
"""Clean up the processor for the current worker thread.
Should be called before the worker thread's event loop is destroyed so
that event-loop-bound resources (e.g. ``aiohttp.ClientSession``) are
closed on the correct loop.
"""
processor: CoPilotProcessor | None = getattr(_tls, "processor", None)
if processor is not None:
processor.cleanup()
# ============ Processor Class ============ #
class CoPilotProcessor:
"""Per-worker execution logic for CoPilot sessions.
This class is instantiated once per worker thread and handles the execution
of CoPilot chat generation sessions. It maintains an async event loop for
running the async service code.
The execution flow:
1. Session entry is picked from RabbitMQ queue
2. Manager submits to thread pool
3. Processor executes in its event loop
4. Results are published to Redis Streams
"""
@func_retry
def on_executor_start(self):
"""Initialize the processor when the worker thread starts.
This method is called once per worker thread to set up the async event
loop and initialize any required resources.
Database is accessed only through DatabaseManager, so we don't need to connect
to Prisma directly.
"""
configure_logging()
set_service_name("CoPilotExecutor")
self.tid = threading.get_ident()
self.execution_loop = asyncio.new_event_loop()
self.execution_thread = threading.Thread(
target=self.execution_loop.run_forever, daemon=True
)
self.execution_thread.start()
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
def cleanup(self):
"""Clean up event-loop-bound resources before the loop is destroyed.
Shuts down the workspace storage instance that belongs to this
worker's event loop, ensuring ``aiohttp.ClientSession.close()``
runs on the same loop that created the session.
"""
from backend.util.workspace_storage import shutdown_workspace_storage
try:
future = asyncio.run_coroutine_threadsafe(
shutdown_workspace_storage(), self.execution_loop
)
future.result(timeout=5)
except Exception as e:
error_msg = str(e) or type(e).__name__
logger.warning(
f"[CoPilotExecutor] Worker {self.tid} cleanup error: {error_msg}"
)
# Stop the event loop
self.execution_loop.call_soon_threadsafe(self.execution_loop.stop)
self.execution_thread.join(timeout=5)
logger.info(f"[CoPilotExecutor] Worker {self.tid} cleaned up")
@error_logged(swallow=False)
def execute(
self,
entry: CoPilotExecutionEntry,
cancel: threading.Event,
cluster_lock: ClusterLock,
):
"""Execute a CoPilot turn.
Runs the async logic in the worker's event loop and handles errors.
Args:
entry: The turn payload containing session and message info
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock to prevent duplicate execution
"""
log = CoPilotLogMetadata(
logging.getLogger(__name__),
session_id=entry.session_id,
user_id=entry.user_id,
)
log.info("Starting execution")
start_time = time.monotonic()
# Run the async execution in our event loop
future = asyncio.run_coroutine_threadsafe(
self._execute_async(entry, cancel, cluster_lock, log),
self.execution_loop,
)
# Wait for completion, checking cancel periodically
while not future.done():
try:
future.result(timeout=1.0)
except asyncio.TimeoutError:
if cancel.is_set():
log.info("Cancellation requested")
future.cancel()
break
# Refresh cluster lock to maintain ownership
cluster_lock.refresh()
if not future.cancelled():
# Get result to propagate any exceptions
future.result()
elapsed = time.monotonic() - start_time
log.info(f"Execution completed in {elapsed:.2f}s")
async def _execute_async(
self,
entry: CoPilotExecutionEntry,
cancel: threading.Event,
cluster_lock: ClusterLock,
log: CoPilotLogMetadata,
):
"""Async execution logic for a CoPilot turn.
Calls the stream_chat_completion service function and publishes
results to the stream registry.
Args:
entry: The turn payload
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock for refresh
log: Structured logger
"""
last_refresh = time.monotonic()
refresh_interval = 30.0 # Refresh lock every 30 seconds
error_msg = None
try:
# Choose service based on LaunchDarkly flag
config = ChatConfig()
use_sdk = await is_feature_enabled(
Flag.COPILOT_SDK,
entry.user_id or "anonymous",
default=config.use_claude_agent_sdk,
)
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else copilot_service.stream_chat_completion
)
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
# Stream chat completion and publish chunks to Redis.
async for chunk in stream_fn(
session_id=entry.session_id,
message=entry.message if entry.message else None,
is_user_message=entry.is_user_message,
user_id=entry.user_id,
context=entry.context,
):
if cancel.is_set():
log.info("Cancel requested, breaking stream")
break
current_time = time.monotonic()
if current_time - last_refresh >= refresh_interval:
cluster_lock.refresh()
last_refresh = current_time
# Skip StreamFinish — mark_session_completed publishes it.
if isinstance(chunk, StreamFinish):
continue
try:
await stream_registry.publish_chunk(entry.turn_id, chunk)
except Exception as e:
log.error(
f"Error publishing chunk {type(chunk).__name__}: {e}",
exc_info=True,
)
# Stream loop completed
if cancel.is_set():
log.info("Stream cancelled by user")
except BaseException as e:
# Handle all exceptions (including CancelledError) with appropriate logging
if isinstance(e, asyncio.CancelledError):
log.info("Turn cancelled")
error_msg = "Operation cancelled"
else:
error_msg = str(e) or type(e).__name__
log.error(f"Turn failed: {error_msg}")
raise
finally:
# If no exception but user cancelled, still mark as cancelled
if not error_msg and cancel.is_set():
error_msg = "Operation cancelled"
try:
await stream_registry.mark_session_completed(
entry.session_id, error_message=error_msg
)
except Exception as mark_err:
log.error(f"Failed to mark session completed: {mark_err}")

View File

@@ -1,218 +0,0 @@
"""RabbitMQ queue configuration for CoPilot executor.
Defines two exchanges and queues following the graph executor pattern:
- 'copilot_execution' (DIRECT) for chat generation tasks
- 'copilot_cancel' (FANOUT) for cancellation requests
"""
import logging
from pydantic import BaseModel
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
logger = logging.getLogger(__name__)
# ============ Logging Helper ============ #
class CoPilotLogMetadata(TruncatedLogger):
"""Structured logging helper for CoPilot executor.
In cloud environments (structured logging enabled), uses a simple prefix
and passes metadata via json_fields. In local environments, uses a detailed
prefix with all metadata key-value pairs for easier debugging.
Args:
logger: The underlying logger instance
max_length: Maximum log message length before truncation
**kwargs: Metadata key-value pairs (e.g., session_id="xyz", turn_id="abc")
These are added to json_fields in cloud mode, or to the prefix in local mode.
"""
def __init__(
self,
logger: logging.Logger,
max_length: int = 1000,
**kwargs: str | None,
):
# Filter out None values
metadata = {k: v for k, v in kwargs.items() if v is not None}
metadata["component"] = "CoPilotExecutor"
if is_structured_logging_enabled():
prefix = "[CoPilotExecutor]"
else:
# Build prefix from metadata key-value pairs
meta_parts = "|".join(
f"{k}:{v}" for k, v in metadata.items() if k != "component"
)
prefix = (
f"[CoPilotExecutor|{meta_parts}]" if meta_parts else "[CoPilotExecutor]"
)
super().__init__(
logger,
max_length=max_length,
prefix=prefix,
metadata=metadata,
)
# ============ Exchange and Queue Configuration ============ #
COPILOT_EXECUTION_EXCHANGE = Exchange(
name="copilot_execution",
type=ExchangeType.DIRECT,
durable=True,
auto_delete=False,
)
COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue"
COPILOT_EXECUTION_ROUTING_KEY = "copilot.run"
COPILOT_CANCEL_EXCHANGE = Exchange(
name="copilot_cancel",
type=ExchangeType.FANOUT,
durable=True,
auto_delete=False,
)
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"
# CoPilot operations can include extended thinking and agent generation
# which may take 30+ minutes to complete
COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour
# Graceful shutdown timeout - allow in-flight operations to complete
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60 # 30 minutes
def create_copilot_queue_config() -> RabbitMQConfig:
"""Create RabbitMQ configuration for CoPilot executor.
Defines two exchanges and queues:
- 'copilot_execution' (DIRECT) for chat generation tasks
- 'copilot_cancel' (FANOUT) for cancellation requests
Returns:
RabbitMQConfig with exchanges and queues defined
"""
run_queue = Queue(
name=COPILOT_EXECUTION_QUEUE_NAME,
exchange=COPILOT_EXECUTION_EXCHANGE,
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
durable=True,
auto_delete=False,
arguments={
# Extended consumer timeout for long-running LLM operations
# Default 30-minute timeout is insufficient for extended thinking
# and agent generation which can take 30+ minutes
"x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS
* 1000,
},
)
cancel_queue = Queue(
name=COPILOT_CANCEL_QUEUE_NAME,
exchange=COPILOT_CANCEL_EXCHANGE,
routing_key="", # not used for FANOUT
durable=True,
auto_delete=False,
)
return RabbitMQConfig(
vhost="/",
exchanges=[COPILOT_EXECUTION_EXCHANGE, COPILOT_CANCEL_EXCHANGE],
queues=[run_queue, cancel_queue],
)
# ============ Message Models ============ #
class CoPilotExecutionEntry(BaseModel):
"""Task payload for CoPilot AI generation.
This model represents a chat generation task to be processed by the executor.
"""
session_id: str
"""Chat session ID (also used for dedup/locking)"""
turn_id: str = ""
"""Per-turn UUID for Redis stream isolation"""
user_id: str | None
"""User ID (may be None for anonymous users)"""
message: str
"""User's message to process"""
is_user_message: bool = True
"""Whether the message is from the user (vs system/assistant)"""
context: dict[str, str] | None = None
"""Optional context for the message (e.g., {url: str, content: str})"""
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
session_id: str
"""Session ID to cancel"""
# ============ Queue Publishing Helpers ============ #
async def enqueue_copilot_turn(
session_id: str,
user_id: str | None,
message: str,
turn_id: str,
is_user_message: bool = True,
context: dict[str, str] | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
Args:
session_id: Chat session ID (also used for dedup/locking)
user_id: User ID (may be None for anonymous users)
message: User's message to process
turn_id: Per-turn UUID for Redis stream isolation
is_user_message: Whether the message is from the user (vs system/assistant)
context: Optional context for the message (e.g., {url: str, content: str})
"""
from backend.util.clients import get_async_copilot_queue
entry = CoPilotExecutionEntry(
session_id=session_id,
turn_id=turn_id,
user_id=user_id,
message=message,
is_user_message=is_user_message,
context=context,
)
queue_client = await get_async_copilot_queue()
await queue_client.publish_message(
routing_key=COPILOT_EXECUTION_ROUTING_KEY,
message=entry.model_dump_json(),
exchange=COPILOT_EXECUTION_EXCHANGE,
)
async def enqueue_cancel_task(session_id: str) -> None:
"""Publish a cancel request for a running CoPilot session.
Sends a ``CancelCoPilotEvent`` to the FANOUT exchange so all executor
pods receive the cancellation signal.
"""
from backend.util.clients import get_async_copilot_queue
event = CancelCoPilotEvent(session_id=session_id)
queue_client = await get_async_copilot_queue()
await queue_client.publish_message(
routing_key="", # FANOUT ignores routing key
message=event.model_dump_json(),
exchange=COPILOT_CANCEL_EXCHANGE,
)

View File

@@ -1,269 +0,0 @@
"""Tests for parallel tool call execution in CoPilot.
These tests mock _yield_tool_call to avoid importing the full copilot stack
which requires Prisma, DB connections, etc.
"""
import asyncio
import time
from typing import Any, cast
import pytest
@pytest.mark.asyncio
async def test_parallel_tool_calls_run_concurrently():
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
n_tools = 3
delay_per_tool = 0.2
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"tool_{i}", "arguments": "{}"},
}
for i in range(n_tools)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
original_yield = None
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"],
toolName=tc_list[idx]["function"]["name"],
input={},
)
await asyncio.sleep(delay_per_tool)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"],
toolName=tc_list[idx]["function"]["name"],
output="{}",
)
import backend.copilot.service as svc
original_yield = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
start = time.monotonic()
events = []
async for event in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
events.append(event)
elapsed = time.monotonic() - start
finally:
svc._yield_tool_call = original_yield
assert len(events) == n_tools * 2
# Parallel: should take ~delay, not ~n*delay
assert elapsed < delay_per_tool * (
n_tools - 0.5
), f"Took {elapsed:.2f}s, expected parallel (~{delay_per_tool}s)"
@pytest.mark.asyncio
async def test_single_tool_call_works():
"""Single tool call should work identically."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": "call_0",
"type": "function",
"function": {"name": "t", "arguments": "{}"},
}
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
events = [
e
async for e in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
)
]
finally:
svc._yield_tool_call = orig
assert len(events) == 2
@pytest.mark.asyncio
async def test_retryable_error_propagates():
"""Retryable errors should be raised after all tools finish."""
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(2)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess):
if idx == 1:
raise KeyError("bad")
from backend.copilot.response_model import StreamToolInputAvailable
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName="t_0", input={}
)
await asyncio.sleep(0.05)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"], toolName="t_0", output="{}"
)
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
events = []
with pytest.raises(KeyError):
async for event in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
events.append(event)
# First tool's events should still be yielded
assert any(isinstance(e, StreamToolOutputAvailable) for e in events)
finally:
svc._yield_tool_call = orig
@pytest.mark.asyncio
async def test_session_shared_across_parallel_tools():
"""All parallel tools should receive the same session instance."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(3)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
observed_sessions = []
async def fake_yield(tc_list, idx, sess):
observed_sessions.append(sess)
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", output="{}"
)
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
async for _ in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
pass
finally:
svc._yield_tool_call = orig
assert len(observed_sessions) == 3
assert observed_sessions[0] is observed_sessions[1] is observed_sessions[2]
@pytest.mark.asyncio
async def test_cancellation_cleans_up():
"""Generator close should cancel in-flight tasks."""
from backend.copilot.response_model import StreamToolInputAvailable
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(2)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
started = asyncio.Event()
async def fake_yield(tc_list, idx, sess):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)
started.set()
await asyncio.sleep(10) # simulate long-running
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
gen = _execute_tool_calls_parallel(tool_calls, cast(Any, FakeSession()))
await gen.__anext__() # get first event
await started.wait()
await gen.aclose() # close generator
finally:
svc._yield_tool_call = orig
# If we get here without hanging, cleanup worked

View File

@@ -1,57 +0,0 @@
"""Dummy SDK service for testing copilot streaming.
Returns mock streaming responses without calling Claude Agent SDK.
Enable via COPILOT_TEST_MODE=true environment variable.
WARNING: This is for testing only. Do not use in production.
"""
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from ..model import ChatSession
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
logger = logging.getLogger(__name__)
async def stream_chat_completion_dummy(
session_id: str,
message: str | None = None,
tool_call_response: str | None = None,
is_user_message: bool = True,
user_id: str | None = None,
retry_count: int = 0,
session: ChatSession | None = None,
context: dict[str, str] | None = None,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream dummy chat completion for testing.
Returns a simple streaming response with text deltas to test:
- Streaming infrastructure works
- No timeout occurs
- Text arrives in chunks
- StreamFinish is sent by mark_session_completed
"""
logger.warning(
f"[TEST MODE] Using dummy copilot streaming for session {session_id}"
)
message_id = str(uuid.uuid4())
text_block_id = str(uuid.uuid4())
# Start the stream
yield StreamStart(messageId=message_id, sessionId=session_id)
# Simulate streaming text response with delays
dummy_response = "I counted: 1... 2... 3. All done!"
words = dummy_response.split()
for i, word in enumerate(words):
# Add space except for last word
text = word if i == len(words) - 1 else f"{word} "
yield StreamTextDelta(id=text_block_id, delta=text)
# Small delay to simulate real streaming
await asyncio.sleep(0.1)

View File

@@ -1,221 +0,0 @@
"""Tests for _format_conversation_context and _build_query_message."""
from datetime import UTC, datetime
import pytest
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.sdk.service import (
_build_query_message,
_format_conversation_context,
)
# ---------------------------------------------------------------------------
# _format_conversation_context
# ---------------------------------------------------------------------------
def test_format_empty_list():
assert _format_conversation_context([]) is None
def test_format_none_content_messages():
msgs = [ChatMessage(role="user", content=None)]
assert _format_conversation_context(msgs) is None
def test_format_user_message():
msgs = [ChatMessage(role="user", content="hello")]
result = _format_conversation_context(msgs)
assert result is not None
assert "User: hello" in result
assert result.startswith("<conversation_history>")
assert result.endswith("</conversation_history>")
def test_format_assistant_text():
msgs = [ChatMessage(role="assistant", content="hi there")]
result = _format_conversation_context(msgs)
assert result is not None
assert "You responded: hi there" in result
def test_format_assistant_tool_calls():
msgs = [
ChatMessage(
role="assistant",
content=None,
tool_calls=[{"function": {"name": "search", "arguments": '{"q": "test"}'}}],
)
]
result = _format_conversation_context(msgs)
assert result is not None
assert 'You called tool: search({"q": "test"})' in result
def test_format_tool_result():
msgs = [ChatMessage(role="tool", content='{"result": "ok"}')]
result = _format_conversation_context(msgs)
assert result is not None
assert 'Tool result: {"result": "ok"}' in result
def test_format_tool_result_none_content():
msgs = [ChatMessage(role="tool", content=None)]
result = _format_conversation_context(msgs)
assert result is not None
assert "Tool result: " in result
def test_format_full_conversation():
msgs = [
ChatMessage(role="user", content="find agents"),
ChatMessage(
role="assistant",
content="I'll search for agents.",
tool_calls=[
{"function": {"name": "find_agents", "arguments": '{"q": "test"}'}}
],
),
ChatMessage(role="tool", content='[{"id": "1", "name": "Agent1"}]'),
ChatMessage(role="assistant", content="Found Agent1."),
]
result = _format_conversation_context(msgs)
assert result is not None
assert "User: find agents" in result
assert "You responded: I'll search for agents." in result
assert "You called tool: find_agents" in result
assert "Tool result:" in result
assert "You responded: Found Agent1." in result
# ---------------------------------------------------------------------------
# _build_query_message
# ---------------------------------------------------------------------------
def _make_session(messages: list[ChatMessage]) -> ChatSession:
"""Build a minimal ChatSession with the given messages."""
now = datetime.now(UTC)
return ChatSession(
session_id="test-session",
user_id="user-1",
messages=messages,
title="test",
usage=[],
started_at=now,
updated_at=now,
)
@pytest.mark.asyncio
async def test_build_query_resume_up_to_date():
"""With --resume and transcript covers all messages, return raw message."""
session = _make_session(
[
ChatMessage(role="user", content="hello"),
ChatMessage(role="assistant", content="hi"),
ChatMessage(role="user", content="what's new?"),
]
)
result = await _build_query_message(
"what's new?",
session,
use_resume=True,
transcript_msg_count=2,
session_id="test-session",
)
# transcript_msg_count == msg_count - 1, so no gap
assert result == "what's new?"
@pytest.mark.asyncio
async def test_build_query_resume_stale_transcript():
"""With --resume and stale transcript, gap context is prepended."""
session = _make_session(
[
ChatMessage(role="user", content="turn 1"),
ChatMessage(role="assistant", content="reply 1"),
ChatMessage(role="user", content="turn 2"),
ChatMessage(role="assistant", content="reply 2"),
ChatMessage(role="user", content="turn 3"),
]
)
result = await _build_query_message(
"turn 3",
session,
use_resume=True,
transcript_msg_count=2,
session_id="test-session",
)
assert "<conversation_history>" in result
assert "turn 2" in result
assert "reply 2" in result
assert "Now, the user says:\nturn 3" in result
@pytest.mark.asyncio
async def test_build_query_resume_zero_msg_count():
"""With --resume but transcript_msg_count=0, return raw message."""
session = _make_session(
[
ChatMessage(role="user", content="hello"),
ChatMessage(role="assistant", content="hi"),
ChatMessage(role="user", content="new msg"),
]
)
result = await _build_query_message(
"new msg",
session,
use_resume=True,
transcript_msg_count=0,
session_id="test-session",
)
assert result == "new msg"
@pytest.mark.asyncio
async def test_build_query_no_resume_single_message():
"""Without --resume and only 1 message, return raw message."""
session = _make_session([ChatMessage(role="user", content="first")])
result = await _build_query_message(
"first",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
)
assert result == "first"
@pytest.mark.asyncio
async def test_build_query_no_resume_multi_message(monkeypatch):
"""Without --resume and multiple messages, compress and prepend."""
session = _make_session(
[
ChatMessage(role="user", content="older question"),
ChatMessage(role="assistant", content="older answer"),
ChatMessage(role="user", content="new question"),
]
)
# Mock _compress_conversation_history to return the messages as-is
async def _mock_compress(sess):
return sess.messages[:-1]
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_conversation_history",
_mock_compress,
)
result = await _build_query_message(
"new question",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
)
assert "<conversation_history>" in result
assert "older question" in result
assert "older answer" in result
assert "Now, the user says:\nnew question" in result

Some files were not shown because too many files have changed in this diff Show More