Compare commits

..

9 Commits

Author SHA1 Message Date
Zamil Majdy
11e6fca8c3 fix(copilot): resolve dangling tool spinners when stream finishes
When the backend doesn't emit StreamToolOutputAvailable for all tool
calls before StreamFinish (e.g. SDK built-in tools like WebSearch),
the frontend spinners would spin forever.

Add a useEffect that watches for the streaming→ready transition and
marks any remaining input-available/input-streaming tool parts as
output-available. Extract shared resolveInProgressTools helper used
by both the stop handler (cancelled) and stream-end (completed).
2026-02-20 03:48:20 +07:00
Zamil Majdy
6e737e0b74 style: fix Black formatting on cancel endpoint 2026-02-20 02:46:03 +07:00
Zamil Majdy
5ce002803d fix(copilot): toast when cancel confirmation times out
Check the reason field in the cancel response — if
"cancel_published_not_confirmed", show a non-destructive toast so the
user knows the stop was sent but not yet confirmed by the executor.
2026-02-20 02:31:18 +07:00
Zamil Majdy
f8ad8484ee refactor(copilot): convert stop to plain function declaration
Remove useCallback wrapper per project guidelines — stopRef.current
captures the latest closure on every render regardless.
2026-02-20 02:25:43 +07:00
Zamil Majdy
b6064d0155 fix(copilot): address round-2 PR review and fix tool loading on stop
Backend:
- Add _validate_and_get_session() call to cancel endpoint (404 for
  invalid sessions, consistent with other endpoints)
- Reduce polling max_wait from 10s to 5s (stay below reverse-proxy
  read timeouts)
- Return cancelled=True with reason="cancel_published_not_confirmed"
  on timeout (cancel event IS published, just not yet confirmed)

Frontend:
- Mark in-progress tool parts as output-error on stop so spinners
  clear immediately instead of spinning forever
- Toast on cancel API failure (network error / 5xx)
2026-02-20 02:21:28 +07:00
Zamil Majdy
76e0c96aa9 feat: fix openapi.json 2026-02-20 02:14:07 +07:00
Zamil Majdy
3364a8e415 refactor(copilot): use generated client for cancel API call
Replace raw fetch() with generated postV2CancelSessionTask() and
remove the now-unnecessary dedicated cancel proxy route — the general
/api/proxy handles auth and forwarding.

Toast on cancel failure so the user knows the backend may still be running.
2026-02-20 02:10:10 +07:00
Zamil Majdy
9f4f2749a4 fix(copilot): address PR review comments for cancel endpoint
- Add CancelTaskResponse Pydantic model with typed return annotation
- Handle non-JSON backend responses in cancel proxy route
- Check for "no-token-found" token before forwarding auth header
- Truncate IDs in log messages for consistency
- Add cancel endpoint to openapi.json for frontend codegen
2026-02-20 02:02:14 +07:00
Zamil Majdy
2b0f457985 feat(copilot): wire up stop button to cancel executor tasks
The stop button was completely disconnected — clicking it only aborted the
client-side SSE fetch while the executor kept running indefinitely.

- Add `enqueue_cancel_task()` to publish `CancelCoPilotEvent` to the
  existing RabbitMQ FANOUT exchange that the executor already consumes
- Add `POST /sessions/{session_id}/cancel` endpoint that finds the active
  task, publishes the cancel event, and polls Redis until the task status
  confirms stopped (up to 10s)
- Add Next.js API proxy route for the cancel endpoint
- Wrap the AI SDK's `stop()` to also call the cancel API so the executor
  actually terminates
2026-02-20 01:20:19 +07:00
342 changed files with 21295 additions and 29536 deletions

View File

@@ -149,7 +149,7 @@ jobs:
driver-opts: network=host
- name: Set up Platform - Expose GHA cache to docker buildx CLI
uses: crazy-max/ghaction-github-runtime@v4
uses: crazy-max/ghaction-github-runtime@v3
- name: Set up Platform - Build Docker images (with cache)
working-directory: autogpt_platform

4
.gitignore vendored
View File

@@ -180,6 +180,4 @@ autogpt_platform/backend/settings.py
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
.next
# Implementation plans (generated by AI agents)
plans/
.next

1
.nvmrc
View File

@@ -1 +0,0 @@
22

View File

@@ -1,10 +1,3 @@
default_install_hook_types:
- pre-commit
- pre-push
- post-checkout
default_stages: [pre-commit]
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
@@ -24,7 +17,6 @@ repos:
name: Detect secrets
description: Detects high entropy strings that are likely to be passwords.
files: ^autogpt_platform/
exclude: pnpm-lock\.yaml$
stages: [pre-push]
- repo: local
@@ -34,106 +26,49 @@ repos:
- id: poetry-install
name: Check & Install dependencies - AutoGPT Platform - Backend
alias: poetry-install-platform-backend
entry: poetry -C autogpt_platform/backend install
# include autogpt_libs source (since it's a path dependency)
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$" || exit 0;
poetry -C autogpt_platform/backend install
'
always_run: true
files: ^autogpt_platform/(backend|autogpt_libs)/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - AutoGPT Platform - Libs
alias: poetry-install-platform-libs
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/autogpt_libs/poetry\.lock$" || exit 0;
poetry -C autogpt_platform/autogpt_libs install
'
always_run: true
entry: poetry -C autogpt_platform/autogpt_libs install
files: ^autogpt_platform/autogpt_libs/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: pnpm-install
name: Check & Install dependencies - AutoGPT Platform - Frontend
alias: pnpm-install-platform-frontend
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/frontend/pnpm-lock\.yaml$" || exit 0;
pnpm --prefix autogpt_platform/frontend install
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - AutoGPT
alias: poetry-install-classic-autogpt
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/(original_autogpt|forge)/poetry\.lock$" || exit 0;
poetry -C classic/original_autogpt install
'
entry: poetry -C classic/original_autogpt install
# include forge source (since it's a path dependency)
always_run: true
files: ^classic/(original_autogpt|forge)/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Forge
alias: poetry-install-classic-forge
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/forge/poetry\.lock$" || exit 0;
poetry -C classic/forge install
'
always_run: true
entry: poetry -C classic/forge install
files: ^classic/forge/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: poetry-install
name: Check & Install dependencies - Classic - Benchmark
alias: poetry-install-classic-benchmark
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^classic/benchmark/poetry\.lock$" || exit 0;
poetry -C classic/benchmark install
'
always_run: true
entry: poetry -C classic/benchmark install
files: ^classic/benchmark/poetry\.lock$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- repo: local
# For proper type checking, Prisma client must be up-to-date.
@@ -141,54 +76,12 @@ repos:
- id: prisma-generate
name: Prisma Generate - AutoGPT Platform - Backend
alias: prisma-generate-platform-backend
entry: >
bash -c '
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --name-only "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF"
else
git diff --cached --name-only
fi | grep -qE "^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema\.prisma)$" || exit 0;
cd autogpt_platform/backend
&& poetry run prisma generate
&& poetry run gen-prisma-stub
'
entry: bash -c 'cd autogpt_platform/backend && poetry run prisma generate'
# include everything that triggers poetry install + the prisma schema
always_run: true
files: ^autogpt_platform/((backend|autogpt_libs)/poetry\.lock|backend/schema.prisma)$
types: [file]
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- id: export-api-schema
name: Export API schema - AutoGPT Platform - Backend -> Frontend
alias: export-api-schema-platform
entry: >
bash -c '
cd autogpt_platform/backend
&& poetry run export-api-schema --output ../frontend/src/app/api/openapi.json
&& cd ../frontend
&& pnpm prettier --write ./src/app/api/openapi.json
'
files: ^autogpt_platform/backend/
language: system
pass_filenames: false
- id: generate-api-client
name: Generate API client - AutoGPT Platform - Frontend
alias: generate-api-client-platform-frontend
entry: >
bash -c '
SCHEMA=autogpt_platform/frontend/src/app/api/openapi.json;
if [ -n "$PRE_COMMIT_FROM_REF" ]; then
git diff --quiet "$PRE_COMMIT_FROM_REF" "$PRE_COMMIT_TO_REF" -- "$SCHEMA" && exit 0
else
git diff --quiet HEAD -- "$SCHEMA" && exit 0
fi;
cd autogpt_platform/frontend && pnpm generate:api
'
always_run: true
language: system
pass_filenames: false
stages: [pre-commit, post-checkout]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.2

View File

@@ -1,3 +1,2 @@
*.ignore.*
*.ign.*
.application.logs
*.ign.*

View File

@@ -190,8 +190,5 @@ ZEROBOUNCE_API_KEY=
POSTHOG_API_KEY=
POSTHOG_HOST=https://eu.i.posthog.com
# Tally Form Integration (pre-populate business understanding on signup)
TALLY_API_KEY=
# Other Services
AUTOMOD_API_KEY=

View File

@@ -95,7 +95,7 @@ ENV DEBIAN_FRONTEND=noninteractive
# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
# for the bash_exec MCP tool (fallback when E2B is not configured).
# for the bash_exec MCP tool.
# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.13 \
@@ -111,29 +111,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
# Copy Node.js installation for Prisma and agent-browser.
# npm/npx are symlinks in the builder (-> ../lib/node_modules/npm/bin/*-cli.js);
# COPY resolves them to regular files, breaking require() paths. Recreate as
# proper symlinks so npm/npx can find their modules.
# Copy Node.js installation for Prisma
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
COPY --from=builder /usr/bin/npm /usr/bin/npm
COPY --from=builder /usr/bin/npx /usr/bin/npx
COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
# Install agent-browser (Copilot browser tool) + Chromium runtime dependencies.
# These are the runtime libraries Chromium/Playwright needs on Debian 13 (trixie).
RUN apt-get update && apt-get install -y --no-install-recommends \
libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 libcups2 libdrm2 \
libdbus-1-3 libxkbcommon0 libatspi2.0-0t64 libxcomposite1 libxdamage1 \
libxfixes3 libxrandr2 libgbm1 libasound2t64 libpango-1.0-0 libcairo2 \
libx11-6 libx11-xcb1 libxcb1 libxext6 libglib2.0-0t64 \
fonts-liberation libfontconfig1 \
&& rm -rf /var/lib/apt/lists/* \
&& npm install -g agent-browser \
&& agent-browser install \
&& rm -rf /tmp/* /root/.npm
WORKDIR /app/autogpt_platform/backend
# Copy only the .venv from builder (not the entire /app directory)

View File

@@ -88,23 +88,20 @@ async def require_auth(
)
def require_permission(*permissions: APIKeyPermission):
def require_permission(permission: APIKeyPermission):
"""
Dependency function for checking required permissions.
All listed permissions must be present.
Dependency function for checking specific permissions
(works with API keys and OAuth tokens)
"""
async def check_permissions(
async def check_permission(
auth: APIAuthorizationInfo = Security(require_auth),
) -> APIAuthorizationInfo:
missing = [p for p in permissions if p not in auth.scopes]
if missing:
if permission not in auth.scopes:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Missing required permission(s): "
f"{', '.join(p.value for p in missing)}",
detail=f"Missing required permission: {permission.value}",
)
return auth
return check_permissions
return check_permission

View File

@@ -18,7 +18,6 @@ from backend.data import user as user_db
from backend.data.auth.base import APIAuthorizationInfo
from backend.data.block import BlockInput, CompletedBlockOutput
from backend.executor.utils import add_graph_execution
from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate
from backend.util.settings import Settings
from .integrations import integrations_router
@@ -96,43 +95,6 @@ async def execute_graph_block(
return output
@v1_router.post(
path="/graphs",
tags=["graphs"],
status_code=201,
dependencies=[
Security(
require_permission(
APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY
)
)
],
)
async def create_graph(
graph: graph_db.Graph,
auth: APIAuthorizationInfo = Security(
require_permission(APIKeyPermission.WRITE_GRAPH, APIKeyPermission.WRITE_LIBRARY)
),
) -> graph_db.GraphModel:
"""
Create a new agent graph.
The graph will be validated and assigned a new ID.
It is automatically added to the user's library.
"""
from backend.api.features.library import db as library_db
graph_model = graph_db.make_graph_model(graph, auth.user_id)
graph_model.reassign_ids(user_id=auth.user_id, reassign_graph_id=True)
graph_model.validate_graph(for_run=False)
await graph_db.create_graph(graph_model, user_id=auth.user_id)
await library_db.create_library_agent(graph_model, auth.user_id)
activated_graph = await on_graph_activate(graph_model, user_id=auth.user_id)
return activated_graph
@v1_router.post(
path="/graphs/{graph_id}/execute/{graph_version}",
tags=["graphs"],

View File

@@ -1,17 +1,15 @@
import logging
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from difflib import SequenceMatcher
from typing import Any, Sequence, get_args, get_origin
from typing import Sequence
import prisma
from prisma.enums import ContentType
from prisma.models import mv_suggested_blocks
import backend.api.features.library.db as library_db
import backend.api.features.library.model as library_model
import backend.api.features.store.db as store_db
import backend.api.features.store.model as store_model
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.blocks import load_all_blocks
from backend.blocks._base import (
AnyBlockSchema,
@@ -21,6 +19,7 @@ from backend.blocks._base import (
BlockType,
)
from backend.blocks.llm import LlmModel
from backend.data.db import query_raw_with_schema
from backend.integrations.providers import ProviderName
from backend.util.cache import cached
from backend.util.models import Pagination
@@ -43,16 +42,6 @@ MAX_LIBRARY_AGENT_RESULTS = 100
MAX_MARKETPLACE_AGENT_RESULTS = 100
MIN_SCORE_FOR_FILTERED_RESULTS = 10.0
# Boost blocks over marketplace agents in search results
BLOCK_SCORE_BOOST = 50.0
# Block IDs to exclude from search results
EXCLUDED_BLOCK_IDS = frozenset(
{
"e189baac-8c20-45a1-94a7-55177ea42565", # AgentExecutorBlock
}
)
SearchResultItem = BlockInfo | library_model.LibraryAgent | store_model.StoreAgent
@@ -75,8 +64,8 @@ def get_block_categories(category_blocks: int = 3) -> list[BlockCategoryResponse
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
# Skip disabled and excluded blocks
if block.disabled or block.id in EXCLUDED_BLOCK_IDS:
# Skip disabled blocks
if block.disabled:
continue
# Skip blocks that don't have categories (all should have at least one)
if not block.categories:
@@ -127,9 +116,6 @@ def get_blocks(
# Skip disabled blocks
if block.disabled:
continue
# Skip excluded blocks
if block.id in EXCLUDED_BLOCK_IDS:
continue
# Skip blocks that don't match the category
if category and category not in {c.name.lower() for c in block.categories}:
continue
@@ -269,25 +255,14 @@ async def _build_cached_search_results(
"my_agents": 0,
}
# Use hybrid search when query is present, otherwise list all blocks
if (include_blocks or include_integrations) and normalized_query:
block_results, block_total, integration_total = await _hybrid_search_blocks(
query=search_query,
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
elif include_blocks or include_integrations:
# No query - list all blocks using in-memory approach
block_results, block_total, integration_total = _collect_block_results(
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
block_results, block_total, integration_total = _collect_block_results(
normalized_query=normalized_query,
include_blocks=include_blocks,
include_integrations=include_integrations,
)
scored_items.extend(block_results)
total_items["blocks"] = block_total
total_items["integrations"] = integration_total
if include_library_agents:
library_response = await library_db.list_library_agents(
@@ -332,14 +307,10 @@ async def _build_cached_search_results(
def _collect_block_results(
*,
normalized_query: str,
include_blocks: bool,
include_integrations: bool,
) -> tuple[list[_ScoredItem], int, int]:
"""
Collect all blocks for listing (no search query).
All blocks get BLOCK_SCORE_BOOST to prioritize them over marketplace agents.
"""
results: list[_ScoredItem] = []
block_count = 0
integration_count = 0
@@ -352,10 +323,6 @@ def _collect_block_results(
if block.disabled:
continue
# Skip excluded blocks
if block.id in EXCLUDED_BLOCK_IDS:
continue
block_info = block.get_info()
credentials = list(block.input_schema.get_credentials_fields().values())
is_integration = len(credentials) > 0
@@ -365,6 +332,10 @@ def _collect_block_results(
if not is_integration and not include_blocks:
continue
score = _score_block(block, block_info, normalized_query)
if not _should_include_item(score, normalized_query):
continue
filter_type: FilterType = "integrations" if is_integration else "blocks"
if is_integration:
integration_count += 1
@@ -375,122 +346,8 @@ def _collect_block_results(
_ScoredItem(
item=block_info,
filter_type=filter_type,
score=BLOCK_SCORE_BOOST,
sort_key=block_info.name.lower(),
)
)
return results, block_count, integration_count
async def _hybrid_search_blocks(
*,
query: str,
include_blocks: bool,
include_integrations: bool,
) -> tuple[list[_ScoredItem], int, int]:
"""
Search blocks using hybrid search with builder-specific filtering.
Uses unified_hybrid_search for semantic + lexical search, then applies
post-filtering for block/integration types and scoring adjustments.
Scoring:
- Base: hybrid relevance score (0-1) scaled to 0-100, plus BLOCK_SCORE_BOOST
to prioritize blocks over marketplace agents in combined results
- +30 for exact name match, +15 for prefix name match
- +20 if the block has an LlmModel field and the query matches an LLM model name
Args:
query: The search query string
include_blocks: Whether to include regular blocks
include_integrations: Whether to include integration blocks
Returns:
Tuple of (scored_items, block_count, integration_count)
"""
results: list[_ScoredItem] = []
block_count = 0
integration_count = 0
if not include_blocks and not include_integrations:
return results, block_count, integration_count
normalized_query = query.strip().lower()
# Fetch more results to account for post-filtering
search_results, _ = await unified_hybrid_search(
query=query,
content_types=[ContentType.BLOCK],
page=1,
page_size=150,
min_score=0.10,
)
# Load all blocks for getting BlockInfo
all_blocks = load_all_blocks()
for result in search_results:
block_id = result["content_id"]
# Skip excluded blocks
if block_id in EXCLUDED_BLOCK_IDS:
continue
metadata = result.get("metadata", {})
hybrid_score = result.get("relevance", 0.0)
# Get the actual block class
if block_id not in all_blocks:
continue
block_cls = all_blocks[block_id]
block: AnyBlockSchema = block_cls()
if block.disabled:
continue
# Check block/integration filter using metadata
is_integration = metadata.get("is_integration", False)
if is_integration and not include_integrations:
continue
if not is_integration and not include_blocks:
continue
# Get block info
block_info = block.get_info()
# Calculate final score: scale hybrid score and add builder-specific bonuses
# Hybrid scores are 0-1, builder scores were 0-200+
# Add BLOCK_SCORE_BOOST to prioritize blocks over marketplace agents
final_score = hybrid_score * 100 + BLOCK_SCORE_BOOST
# Add LLM model match bonus
has_llm_field = metadata.get("has_llm_model_field", False)
if has_llm_field and _matches_llm_model(block.input_schema, normalized_query):
final_score += 20
# Add exact/prefix match bonus for deterministic tie-breaking
name = block_info.name.lower()
if name == normalized_query:
final_score += 30
elif name.startswith(normalized_query):
final_score += 15
# Track counts
filter_type: FilterType = "integrations" if is_integration else "blocks"
if is_integration:
integration_count += 1
else:
block_count += 1
results.append(
_ScoredItem(
item=block_info,
filter_type=filter_type,
score=final_score,
sort_key=name,
score=score,
sort_key=_get_item_name(block_info),
)
)
@@ -615,8 +472,6 @@ async def _get_static_counts():
block: AnyBlockSchema = block_type()
if block.disabled:
continue
if block.id in EXCLUDED_BLOCK_IDS:
continue
all_blocks += 1
@@ -643,25 +498,47 @@ async def _get_static_counts():
}
def _contains_type(annotation: Any, target: type) -> bool:
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
if annotation is target:
return True
origin = get_origin(annotation)
if origin is None:
return False
return any(_contains_type(arg, target) for arg in get_args(annotation))
def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
for field in schema_cls.model_fields.values():
if _contains_type(field.annotation, LlmModel):
if field.annotation == LlmModel:
# Check if query matches any value in llm_models
if any(query in name for name in llm_models):
return True
return False
def _score_block(
block: AnyBlockSchema,
block_info: BlockInfo,
normalized_query: str,
) -> float:
if not normalized_query:
return 0.0
name = block_info.name.lower()
description = block_info.description.lower()
score = _score_primary_fields(name, description, normalized_query)
category_text = " ".join(
category.get("category", "").lower() for category in block_info.categories
)
score += _score_additional_field(category_text, normalized_query, 12, 6)
credentials_info = block.input_schema.get_credentials_fields_info().values()
provider_names = [
provider.value.lower()
for info in credentials_info
for provider in info.provider
]
provider_text = " ".join(provider_names)
score += _score_additional_field(provider_text, normalized_query, 15, 6)
if _matches_llm_model(block.input_schema, normalized_query):
score += 20
return score
def _score_library_agent(
agent: library_model.LibraryAgent,
normalized_query: str,
@@ -768,20 +645,31 @@ def _get_all_providers() -> dict[ProviderName, Provider]:
return providers
@cached(ttl_seconds=3600, shared_cache=True)
@cached(ttl_seconds=3600)
async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
"""Return the most-executed blocks from the last 14 days.
suggested_blocks = []
# Sum the number of executions for each block type
# Prisma cannot group by nested relations, so we do a raw query
# Calculate the cutoff timestamp
timestamp_threshold = datetime.now(timezone.utc) - timedelta(days=30)
Queries the mv_suggested_blocks materialized view (refreshed hourly via pg_cron)
and returns the top `count` blocks sorted by execution count, excluding
Input/Output/Agent block types and blocks in EXCLUDED_BLOCK_IDS.
"""
results = await mv_suggested_blocks.prisma().find_many()
results = await query_raw_with_schema(
"""
SELECT
agent_node."agentBlockId" AS block_id,
COUNT(execution.id) AS execution_count
FROM {schema_prefix}"AgentNodeExecution" execution
JOIN {schema_prefix}"AgentNode" agent_node ON execution."agentNodeId" = agent_node.id
WHERE execution."endedTime" >= $1::timestamp
GROUP BY agent_node."agentBlockId"
ORDER BY execution_count DESC;
""",
timestamp_threshold,
)
# Get the top blocks based on execution count
# But ignore Input, Output, Agent, and excluded blocks
# But ignore Input and Output blocks
blocks: list[tuple[BlockInfo, int]] = []
execution_counts = {row.block_id: row.execution_count for row in results}
for block_type in load_all_blocks().values():
block: AnyBlockSchema = block_type()
@@ -791,9 +679,11 @@ async def get_suggested_blocks(count: int = 5) -> list[BlockInfo]:
BlockType.AGENT,
):
continue
if block.id in EXCLUDED_BLOCK_IDS:
continue
execution_count = execution_counts.get(block.id, 0)
# Find the execution count for this block
execution_count = next(
(row["execution_count"] for row in results if row["block_id"] == block.id),
0,
)
blocks.append((block.get_info(), execution_count))
# Sort blocks by execution count
blocks.sort(key=lambda x: x[1], reverse=True)

View File

@@ -27,6 +27,7 @@ class SearchEntry(BaseModel):
# Suggestions
class SuggestionsResponse(BaseModel):
otto_suggestions: list[str]
recent_searches: list[SearchEntry]
providers: list[ProviderName]
top_blocks: list[BlockInfo]

View File

@@ -1,5 +1,5 @@
import logging
from typing import Annotated, Sequence, cast, get_args
from typing import Annotated, Sequence
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
@@ -10,8 +10,6 @@ from backend.util.models import Pagination
from . import db as builder_db
from . import model as builder_model
VALID_FILTER_VALUES = get_args(builder_model.FilterType)
logger = logging.getLogger(__name__)
router = fastapi.APIRouter(
@@ -51,6 +49,11 @@ async def get_suggestions(
Get all suggestions for the Blocks Menu.
"""
return builder_model.SuggestionsResponse(
otto_suggestions=[
"What blocks do I need to get started?",
"Help me create a list",
"Help me feed my data to Google Maps",
],
recent_searches=await builder_db.get_recent_searches(user_id),
providers=[
ProviderName.TWITTER,
@@ -148,7 +151,7 @@ async def get_providers(
async def search(
user_id: Annotated[str, fastapi.Security(get_user_id)],
search_query: Annotated[str | None, fastapi.Query()] = None,
filter: Annotated[str | None, fastapi.Query()] = None,
filter: Annotated[list[builder_model.FilterType] | None, fastapi.Query()] = None,
search_id: Annotated[str | None, fastapi.Query()] = None,
by_creator: Annotated[list[str] | None, fastapi.Query()] = None,
page: Annotated[int, fastapi.Query()] = 1,
@@ -157,20 +160,9 @@ async def search(
"""
Search for blocks (including integrations), marketplace agents, and user library agents.
"""
# Parse and validate filter parameter
filters: list[builder_model.FilterType]
if filter:
filter_values = [f.strip() for f in filter.split(",")]
invalid_filters = [f for f in filter_values if f not in VALID_FILTER_VALUES]
if invalid_filters:
raise fastapi.HTTPException(
status_code=400,
detail=f"Invalid filter value(s): {', '.join(invalid_filters)}. "
f"Valid values are: {', '.join(VALID_FILTER_VALUES)}",
)
filters = cast(list[builder_model.FilterType], filter_values)
else:
filters = [
# If no filters are provided, then we will return all types
if not filter:
filter = [
"blocks",
"integrations",
"marketplace_agents",
@@ -182,7 +174,7 @@ async def search(
cached_results = await builder_db.get_sorted_search_results(
user_id=user_id,
search_query=search_query,
filters=filters,
filters=filter,
by_creator=by_creator,
)
@@ -204,7 +196,7 @@ async def search(
user_id,
builder_model.SearchEntry(
search_query=search_query,
filter=filters,
filter=filter,
by_creator=by_creator,
search_id=search_id,
),

View File

@@ -2,21 +2,23 @@
import asyncio
import logging
import re
import uuid as uuid_module
from collections.abc import AsyncGenerator
from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
from fastapi import APIRouter, Depends, HTTPException, Query, Response, Security
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, Security
from fastapi.responses import StreamingResponse
from prisma.models import UserWorkspaceFile
from pydantic import BaseModel, Field
from pydantic import BaseModel
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.completion_handler import (
process_operation_failure,
process_operation_success,
)
from backend.copilot.config import ChatConfig
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_task
from backend.copilot.model import (
ChatMessage,
ChatSession,
@@ -42,23 +44,20 @@ from backend.copilot.tools.models import (
ErrorResponse,
ExecutionStartedResponse,
InputValidationErrorResponse,
MCPToolOutputResponse,
MCPToolsDiscoveredResponse,
NeedLoginResponse,
NoResultsResponse,
OperationInProgressResponse,
OperationPendingResponse,
OperationStartedResponse,
SetupRequirementsResponse,
SuggestedGoalResponse,
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
from backend.data.workspace import get_or_create_workspace
from backend.util.exceptions import NotFoundError
config = ChatConfig()
_UUID_RE = re.compile(
r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", re.I
)
logger = logging.getLogger(__name__)
@@ -87,9 +86,6 @@ class StreamChatRequest(BaseModel):
message: str
is_user_message: bool = True
context: dict[str, str] | None = None # {url: str, content: str}
file_ids: list[str] | None = Field(
default=None, max_length=20
) # Workspace file IDs attached to this message
class CreateSessionResponse(BaseModel):
@@ -103,8 +99,10 @@ class CreateSessionResponse(BaseModel):
class ActiveStreamInfo(BaseModel):
"""Information about an active stream for reconnection."""
turn_id: str
task_id: str
last_message_id: str # Redis Stream message ID for resumption
operation_id: str # Operation ID for completion tracking
tool_name: str # Name of the tool being executed
class SessionDetailResponse(BaseModel):
@@ -134,13 +132,22 @@ class ListSessionsResponse(BaseModel):
total: int
class CancelSessionResponse(BaseModel):
"""Response model for the cancel session endpoint."""
class CancelTaskResponse(BaseModel):
"""Response model for the cancel task endpoint."""
cancelled: bool
task_id: str | None = None
reason: str | None = None
class OperationCompleteRequest(BaseModel):
"""Request model for external completion webhook."""
success: bool
result: dict | str | None = None
error: str | None = None
# ========== Routes ==========
@@ -249,18 +256,6 @@ async def delete_session(
detail=f"Session {session_id} not found or access denied",
)
# Best-effort cleanup of the E2B sandbox (if any).
config = ChatConfig()
if config.use_e2b_sandbox and config.e2b_api_key:
from backend.copilot.tools.e2b_sandbox import kill_sandbox
try:
await kill_sandbox(session_id, config.e2b_api_key)
except Exception:
logger.warning(
"[E2B] Failed to kill sandbox for session %s", session_id[:12]
)
return Response(status_code=204)
@@ -275,7 +270,7 @@ async def get_session(
Retrieve the details of a specific chat session.
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
If there's an active stream for this session, returns active_stream info for reconnection.
If there's an active stream for this session, returns the task_id for reconnection.
Args:
session_id: The unique identifier for the desired chat session.
@@ -293,21 +288,28 @@ async def get_session(
# Check if there's an active stream for this session
active_stream_info = None
active_session, last_message_id = await stream_registry.get_active_session(
active_task, last_message_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
logger.info(
f"[GET_SESSION] session={session_id}, active_session={active_session is not None}, "
f"[GET_SESSION] session={session_id}, active_task={active_task is not None}, "
f"msg_count={len(messages)}, last_role={messages[-1].get('role') if messages else 'none'}"
)
if active_session:
# Keep the assistant message (including tool_calls) so the frontend can
# render the correct tool UI (e.g. CreateAgent with mini game).
# convertChatSessionToUiMessages handles isComplete=false by setting
# tool parts without output to state "input-available".
if active_task:
# Filter out the in-progress assistant message from the session response.
# The client will receive the complete assistant response through the SSE
# stream replay instead, preventing duplicate content.
if messages and messages[-1].get("role") == "assistant":
messages = messages[:-1]
# Use "0-0" as last_message_id to replay the stream from the beginning.
# Since we filtered out the cached assistant message, the client needs
# the full stream to reconstruct the response.
active_stream_info = ActiveStreamInfo(
turn_id=active_session.turn_id,
last_message_id=last_message_id,
task_id=active_task.task_id,
last_message_id="0-0",
operation_id=active_task.operation_id,
tool_name=active_task.tool_name,
)
return SessionDetailResponse(
@@ -327,7 +329,7 @@ async def get_session(
async def cancel_session_task(
session_id: str,
user_id: Annotated[str | None, Depends(auth.get_user_id)],
) -> CancelSessionResponse:
) -> CancelTaskResponse:
"""Cancel the active streaming task for a session.
Publishes a cancel event to the executor via RabbitMQ FANOUT, then
@@ -336,33 +338,39 @@ async def cancel_session_task(
"""
await _validate_and_get_session(session_id, user_id)
active_session, _ = await stream_registry.get_active_session(session_id, user_id)
if not active_session:
return CancelSessionResponse(cancelled=True, reason="no_active_session")
active_task, _ = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if not active_task:
return CancelTaskResponse(cancelled=False, reason="no_active_task")
await enqueue_cancel_task(session_id)
logger.info(f"[CANCEL] Published cancel for session ...{session_id[-8:]}")
task_id = active_task.task_id
await enqueue_cancel_task(task_id)
logger.info(
f"[CANCEL] Published cancel for task ...{task_id[-8:]} "
f"session ...{session_id[-8:]}"
)
# Poll until the executor confirms the task is no longer running.
# Keep max_wait below typical reverse-proxy read timeouts.
poll_interval = 0.5
max_wait = 5.0
waited = 0.0
while waited < max_wait:
await asyncio.sleep(poll_interval)
waited += poll_interval
session_state = await stream_registry.get_session(session_id)
if session_state is None or session_state.status != "running":
task = await stream_registry.get_task(task_id)
if task is None or task.status != "running":
logger.info(
f"[CANCEL] Session ...{session_id[-8:]} confirmed stopped "
f"(status={session_state.status if session_state else 'gone'}) after {waited:.1f}s"
f"[CANCEL] Task ...{task_id[-8:]} confirmed stopped "
f"(status={task.status if task else 'gone'}) after {waited:.1f}s"
)
return CancelSessionResponse(cancelled=True)
return CancelTaskResponse(cancelled=True, task_id=task_id)
logger.warning(
f"[CANCEL] Session ...{session_id[-8:]} not confirmed after {max_wait}s, force-completing"
logger.warning(f"[CANCEL] Task ...{task_id[-8:]} not confirmed after {max_wait}s")
return CancelTaskResponse(
cancelled=True, task_id=task_id, reason="cancel_published_not_confirmed"
)
await stream_registry.mark_session_completed(session_id, error_message="Cancelled")
return CancelSessionResponse(cancelled=True)
@router.post(
@@ -382,15 +390,16 @@ async def stream_chat_post(
- Tool execution results
The AI generation runs in a background task that continues even if the client disconnects.
All chunks are written to a per-turn Redis stream for reconnection support. If the client
disconnects, they can reconnect using GET /sessions/{session_id}/stream to resume.
All chunks are written to Redis for reconnection support. If the client disconnects,
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
Args:
session_id: The chat session identifier to associate with the streamed messages.
request: Request body containing message, is_user_message, and optional context.
user_id: Optional authenticated user ID.
Returns:
StreamingResponse: SSE-formatted response chunks.
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
containing the task_id for reconnection.
"""
import asyncio
@@ -417,38 +426,6 @@ async def stream_chat_post(
},
)
# Enrich message with file metadata if file_ids are provided.
# Also sanitise file_ids so only validated, workspace-scoped IDs are
# forwarded downstream (e.g. to the executor via enqueue_copilot_turn).
sanitized_file_ids: list[str] | None = None
if request.file_ids and user_id:
# Filter to valid UUIDs only to prevent DB abuse
valid_ids = [fid for fid in request.file_ids if _UUID_RE.match(fid)]
if valid_ids:
workspace = await get_or_create_workspace(user_id)
# Batch query instead of N+1
files = await UserWorkspaceFile.prisma().find_many(
where={
"id": {"in": valid_ids},
"workspaceId": workspace.id,
"isDeleted": False,
}
)
# Only keep IDs that actually exist in the user's workspace
sanitized_file_ids = [wf.id for wf in files] or None
file_lines: list[str] = [
f"- {wf.name} ({wf.mimeType}, {round(wf.sizeBytes / 1024, 1)} KB), file_id={wf.id}"
for wf in files
]
if file_lines:
files_block = (
"\n\n[Attached files]\n"
+ "\n".join(file_lines)
+ "\nUse read_workspace_file with the file_id to access file contents."
)
request.message += files_block
# Atomically append user message to session BEFORE creating task to avoid
# race condition where GET_SESSION sees task as "running" but message isn't
# saved yet. append_and_save_message re-fetches inside a lock to prevent
@@ -469,38 +446,37 @@ async def stream_chat_post(
logger.info(f"[STREAM] User message saved for session {session_id}")
# Create a task in the stream registry for reconnection support
turn_id = str(uuid4())
log_meta["turn_id"] = turn_id
task_id = str(uuid_module.uuid4())
operation_id = str(uuid_module.uuid4())
log_meta["task_id"] = task_id
session_create_start = time.perf_counter()
await stream_registry.create_session(
task_create_start = time.perf_counter()
await stream_registry.create_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
tool_call_id="chat_stream",
tool_call_id="chat_stream", # Not a tool call, but needed for the model
tool_name="chat",
turn_id=turn_id,
operation_id=operation_id,
)
logger.info(
f"[TIMING] create_session completed in {(time.perf_counter() - session_create_start) * 1000:.1f}ms",
f"[TIMING] create_task completed in {(time.perf_counter() - task_create_start) * 1000:.1f}ms",
extra={
"json_fields": {
**log_meta,
"duration_ms": (time.perf_counter() - session_create_start) * 1000,
"duration_ms": (time.perf_counter() - task_create_start) * 1000,
}
},
)
# Per-turn stream is always fresh (unique turn_id), subscribe from beginning
subscribe_from_id = "0-0"
await enqueue_copilot_turn(
await enqueue_copilot_task(
task_id=task_id,
session_id=session_id,
user_id=user_id,
operation_id=operation_id,
message=request.message,
turn_id=turn_id,
is_user_message=request.is_user_message,
context=request.context,
file_ids=sanitized_file_ids,
)
setup_time = (time.perf_counter() - stream_start_time) * 1000
@@ -515,7 +491,7 @@ async def stream_chat_post(
event_gen_start = time_module.perf_counter()
logger.info(
f"[TIMING] event_generator STARTED, turn={turn_id}, session={session_id}, "
f"[TIMING] event_generator STARTED, task={task_id}, session={session_id}, "
f"user={user_id}",
extra={"json_fields": log_meta},
)
@@ -523,12 +499,11 @@ async def stream_chat_post(
first_chunk_yielded = False
chunks_yielded = 0
try:
# Subscribe from the position we captured before enqueuing
# This avoids replaying old messages while catching all new ones
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
# Subscribe to the task stream (this replays existing messages + live updates)
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=subscribe_from_id,
last_message_id="0-0", # Get all messages from the beginning
)
if subscriber_queue is None:
@@ -543,7 +518,7 @@ async def stream_chat_post(
)
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
chunks_yielded += 1
if not first_chunk_yielded:
@@ -611,19 +586,19 @@ async def stream_chat_post(
# Unsubscribe when client disconnects or stream ends
if subscriber_queue is not None:
try:
await stream_registry.unsubscribe_from_session(
session_id, subscriber_queue
await stream_registry.unsubscribe_from_task(
task_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from session {session_id}: {unsub_err}",
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
total_time = time_module.perf_counter() - event_gen_start
logger.info(
f"[TIMING] event_generator FINISHED in {total_time:.2f}s; "
f"turn={turn_id}, session={session_id}, n_chunks={chunks_yielded}",
f"task={task_id}, session={session_id}, n_chunks={chunks_yielded}",
extra={
"json_fields": {
**log_meta,
@@ -670,21 +645,17 @@ async def resume_session_stream(
"""
import asyncio
active_session, last_message_id = await stream_registry.get_active_session(
active_task, _last_id = await stream_registry.get_active_task_for_session(
session_id, user_id
)
if not active_session:
if not active_task:
return Response(status_code=204)
# Always replay from the beginning ("0-0") on resume.
# We can't use last_message_id because it's the latest ID in the backend
# stream, not the latest the frontend received — the gap causes lost
# messages. The frontend deduplicates replayed content.
subscriber_queue = await stream_registry.subscribe_to_session(
session_id=session_id,
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=active_task.task_id,
user_id=user_id,
last_message_id="0-0",
last_message_id="0-0", # Full replay so useChat rebuilds the message
)
if subscriber_queue is None:
@@ -696,7 +667,7 @@ async def resume_session_stream(
try:
while True:
try:
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=10.0)
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
if chunk_count < 3:
logger.info(
"Resume stream chunk",
@@ -720,12 +691,12 @@ async def resume_session_stream(
logger.error(f"Error in resume stream for session {session_id}: {e}")
finally:
try:
await stream_registry.unsubscribe_from_session(
session_id, subscriber_queue
await stream_registry.unsubscribe_from_task(
active_task.task_id, subscriber_queue
)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from session {active_session.session_id}: {unsub_err}",
f"Error unsubscribing from task {active_task.task_id}: {unsub_err}",
exc_info=True,
)
logger.info(
@@ -776,6 +747,229 @@ async def session_assign_user(
return {"status": "ok"}
# ========== Task Streaming (SSE Reconnection) ==========
@router.get(
"/tasks/{task_id}/stream",
)
async def stream_task(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
last_message_id: str = Query(
default="0-0",
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
),
):
"""
Reconnect to a long-running task's SSE stream.
When a long-running operation (like agent generation) starts, the client
receives a task_id. If the connection drops, the client can reconnect
using this endpoint to resume receiving updates.
Args:
task_id: The task ID from the operation_started response.
user_id: Authenticated user ID for ownership validation.
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
Returns:
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
Raises:
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
"""
# Check task existence and expiry before subscribing
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
if error_code == "TASK_EXPIRED":
raise HTTPException(
status_code=410,
detail={
"code": "TASK_EXPIRED",
"message": "This operation has expired. Please try again.",
},
)
if error_code == "TASK_NOT_FOUND":
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found.",
},
)
# Validate ownership if task has an owner
if task and task.user_id and user_id != task.user_id:
raise HTTPException(
status_code=403,
detail={
"code": "ACCESS_DENIED",
"message": "You do not have access to this task.",
},
)
# Get subscriber queue from stream registry
subscriber_queue = await stream_registry.subscribe_to_task(
task_id=task_id,
user_id=user_id,
last_message_id=last_message_id,
)
if subscriber_queue is None:
raise HTTPException(
status_code=404,
detail={
"code": "TASK_NOT_FOUND",
"message": f"Task {task_id} not found or access denied.",
},
)
async def event_generator() -> AsyncGenerator[str, None]:
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
try:
while True:
try:
# Wait for next chunk with timeout for heartbeats
chunk = await asyncio.wait_for(
subscriber_queue.get(), timeout=heartbeat_interval
)
yield chunk.to_sse()
# Check for finish signal
if isinstance(chunk, StreamFinish):
break
except asyncio.TimeoutError:
# Send heartbeat to keep connection alive
yield StreamHeartbeat().to_sse()
except Exception as e:
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
finally:
# Unsubscribe when client disconnects or stream ends
try:
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
except Exception as unsub_err:
logger.error(
f"Error unsubscribing from task {task_id}: {unsub_err}",
exc_info=True,
)
# AI SDK protocol termination - always yield even if unsubscribe fails
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
},
)
@router.get(
"/tasks/{task_id}",
)
async def get_task_status(
task_id: str,
user_id: str | None = Depends(auth.get_user_id),
) -> dict:
"""
Get the status of a long-running task.
Args:
task_id: The task ID to check.
user_id: Authenticated user ID for ownership validation.
Returns:
dict: Task status including task_id, status, tool_name, and operation_id.
Raises:
NotFoundError: If task_id is not found or user doesn't have access.
"""
task = await stream_registry.get_task(task_id)
if task is None:
raise NotFoundError(f"Task {task_id} not found.")
# Validate ownership - if task has an owner, requester must match
if task.user_id and user_id != task.user_id:
raise NotFoundError(f"Task {task_id} not found.")
return {
"task_id": task.task_id,
"session_id": task.session_id,
"status": task.status,
"tool_name": task.tool_name,
"operation_id": task.operation_id,
"created_at": task.created_at.isoformat(),
}
# ========== External Completion Webhook ==========
@router.post(
"/operations/{operation_id}/complete",
status_code=200,
)
async def complete_operation(
operation_id: str,
request: OperationCompleteRequest,
x_api_key: str | None = Header(default=None),
) -> dict:
"""
External completion webhook for long-running operations.
Called by Agent Generator (or other services) when an operation completes.
This triggers the stream registry to publish completion and continue LLM generation.
Args:
operation_id: The operation ID to complete.
request: Completion payload with success status and result/error.
x_api_key: Internal API key for authentication.
Returns:
dict: Status of the completion.
Raises:
HTTPException: If API key is invalid or operation not found.
"""
# Validate internal API key - reject if not configured or invalid
if not config.internal_api_key:
logger.error(
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
)
raise HTTPException(
status_code=503,
detail="Webhook not available: internal API key not configured",
)
if x_api_key != config.internal_api_key:
raise HTTPException(status_code=401, detail="Invalid API key")
# Find task by operation_id
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None:
raise HTTPException(
status_code=404,
detail=f"Operation {operation_id} not found",
)
logger.info(
f"Received completion webhook for operation {operation_id} "
f"(task_id={task.task_id}, success={request.success})"
)
if request.success:
await process_operation_success(task, request.result)
else:
await process_operation_failure(task, request.error)
return {"status": "ok", "task_id": task.task_id}
# ========== Configuration ==========
@@ -856,8 +1050,9 @@ ToolResponseUnion = (
| BlockOutputResponse
| DocSearchResultsResponse
| DocPageResponse
| MCPToolsDiscoveredResponse
| MCPToolOutputResponse
| OperationStartedResponse
| OperationPendingResponse
| OperationInProgressResponse
)

View File

@@ -1,160 +0,0 @@
"""Tests for chat route file_ids validation and enrichment."""
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from backend.api.features.chat import routes as chat_routes
app = fastapi.FastAPI()
app.include_router(chat_routes.router)
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
# ---- file_ids Pydantic validation (B1) ----
def test_stream_chat_rejects_too_many_file_ids():
"""More than 20 file_ids should be rejected by Pydantic validation (422)."""
response = client.post(
"/sessions/sess-1/stream",
json={
"message": "hello",
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(21)],
},
)
assert response.status_code == 422
def _mock_stream_internals(mocker: pytest_mock.MockFixture):
"""Mock the async internals of stream_chat_post so tests can exercise
validation and enrichment logic without needing Redis/RabbitMQ."""
mocker.patch(
"backend.api.features.chat.routes._validate_and_get_session",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.append_and_save_message",
return_value=None,
)
mock_registry = mocker.MagicMock()
mock_registry.create_session = mocker.AsyncMock(return_value=None)
mocker.patch(
"backend.api.features.chat.routes.stream_registry",
mock_registry,
)
mocker.patch(
"backend.api.features.chat.routes.enqueue_copilot_turn",
return_value=None,
)
mocker.patch(
"backend.api.features.chat.routes.track_user_message",
return_value=None,
)
def test_stream_chat_accepts_20_file_ids(mocker: pytest_mock.MockFixture):
"""Exactly 20 file_ids should be accepted (not rejected by validation)."""
_mock_stream_internals(mocker)
# Patch workspace lookup as imported by the routes module
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
response = client.post(
"/sessions/sess-1/stream",
json={
"message": "hello",
"file_ids": [f"00000000-0000-0000-0000-{i:012d}" for i in range(20)],
},
)
# Should get past validation — 200 streaming response expected
assert response.status_code == 200
# ---- UUID format filtering ----
def test_file_ids_filters_invalid_uuids(mocker: pytest_mock.MockFixture):
"""Non-UUID strings in file_ids should be silently filtered out
and NOT passed to the database query."""
_mock_stream_internals(mocker)
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "ws-1"})(),
)
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
valid_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
client.post(
"/sessions/sess-1/stream",
json={
"message": "hello",
"file_ids": [
valid_id,
"not-a-uuid",
"../../../etc/passwd",
"",
],
},
)
# The find_many call should only receive the one valid UUID
mock_prisma.find_many.assert_called_once()
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["id"]["in"] == [valid_id]
# ---- Cross-workspace file_ids ----
def test_file_ids_scoped_to_workspace(mocker: pytest_mock.MockFixture):
"""The batch query should scope to the user's workspace."""
_mock_stream_internals(mocker)
mocker.patch(
"backend.api.features.chat.routes.get_or_create_workspace",
return_value=type("W", (), {"id": "my-workspace-id"})(),
)
mock_prisma = mocker.MagicMock()
mock_prisma.find_many = mocker.AsyncMock(return_value=[])
mocker.patch(
"prisma.models.UserWorkspaceFile.prisma",
return_value=mock_prisma,
)
fid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
client.post(
"/sessions/sess-1/stream",
json={"message": "hi", "file_ids": [fid]},
)
call_kwargs = mock_prisma.find_many.call_args[1]
assert call_kwargs["where"]["workspaceId"] == "my-workspace-id"
assert call_kwargs["where"]["isDeleted"] is False

View File

@@ -22,7 +22,6 @@ from backend.data.human_review import (
)
from backend.data.model import USER_TIMEZONE_NOT_SET
from backend.data.user import get_user_by_id
from backend.data.workspace import get_or_create_workspace
from backend.executor.utils import add_graph_execution
from .model import PendingHumanReviewModel, ReviewRequest, ReviewResponse
@@ -322,13 +321,10 @@ async def process_review_action(
user.timezone if user.timezone != USER_TIMEZONE_NOT_SET else "UTC"
)
workspace = await get_or_create_workspace(user_id)
execution_context = ExecutionContext(
human_in_the_loop_safe_mode=settings.human_in_the_loop_safe_mode,
sensitive_action_safe_mode=settings.sensitive_action_safe_mode,
user_timezone=user_timezone,
workspace_id=workspace.id,
)
await add_graph_execution(

File diff suppressed because it is too large Load Diff

View File

@@ -144,7 +144,6 @@ async def test_add_agent_to_library(mocker):
)
mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
mock_library_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.find_unique = mocker.AsyncMock(return_value=None)
mock_library_agent.return_value.create = mocker.AsyncMock(
return_value=mock_library_agent_data
@@ -179,6 +178,7 @@ async def test_add_agent_to_library(mocker):
"agentGraphVersion": 1,
}
},
include={"AgentGraph": True},
)
# Check that create was called with the expected data including settings
create_call_args = mock_library_agent.return_value.create.call_args

View File

@@ -1,10 +0,0 @@
class FolderValidationError(Exception):
"""Raised when folder operations fail validation."""
pass
class FolderAlreadyExistsError(FolderValidationError):
"""Raised when a folder with the same name already exists in the location."""
pass

View File

@@ -26,95 +26,6 @@ class LibraryAgentStatus(str, Enum):
ERROR = "ERROR"
# === Folder Models ===
class LibraryFolder(pydantic.BaseModel):
"""Represents a folder for organizing library agents."""
id: str
user_id: str
name: str
icon: str | None = None
color: str | None = None
parent_id: str | None = None
created_at: datetime.datetime
updated_at: datetime.datetime
agent_count: int = 0 # Direct agents in folder
subfolder_count: int = 0 # Direct child folders
@staticmethod
def from_db(
folder: prisma.models.LibraryFolder,
agent_count: int = 0,
subfolder_count: int = 0,
) -> "LibraryFolder":
"""Factory method that constructs a LibraryFolder from a Prisma model."""
return LibraryFolder(
id=folder.id,
user_id=folder.userId,
name=folder.name,
icon=folder.icon,
color=folder.color,
parent_id=folder.parentId,
created_at=folder.createdAt,
updated_at=folder.updatedAt,
agent_count=agent_count,
subfolder_count=subfolder_count,
)
class LibraryFolderTree(LibraryFolder):
"""Folder with nested children for tree view."""
children: list["LibraryFolderTree"] = []
class FolderCreateRequest(pydantic.BaseModel):
"""Request model for creating a folder."""
name: str = pydantic.Field(..., min_length=1, max_length=100)
icon: str | None = None
color: str | None = pydantic.Field(
None, pattern=r"^#[0-9A-Fa-f]{6}$", description="Hex color code (#RRGGBB)"
)
parent_id: str | None = None
class FolderUpdateRequest(pydantic.BaseModel):
"""Request model for updating a folder."""
name: str | None = pydantic.Field(None, min_length=1, max_length=100)
icon: str | None = None
color: str | None = None
class FolderMoveRequest(pydantic.BaseModel):
"""Request model for moving a folder to a new parent."""
target_parent_id: str | None = None # None = move to root
class BulkMoveAgentsRequest(pydantic.BaseModel):
"""Request model for moving multiple agents to a folder."""
agent_ids: list[str]
folder_id: str | None = None # None = move to root
class FolderListResponse(pydantic.BaseModel):
"""Response schema for a list of folders."""
folders: list[LibraryFolder]
pagination: Pagination
class FolderTreeResponse(pydantic.BaseModel):
"""Response schema for folder tree structure."""
tree: list[LibraryFolderTree]
class MarketplaceListingCreator(pydantic.BaseModel):
"""Creator information for a marketplace listing."""
@@ -209,9 +120,6 @@ class LibraryAgent(pydantic.BaseModel):
can_access_graph: bool
is_latest_version: bool
is_favorite: bool
folder_id: str | None = None
folder_name: str | None = None # Denormalized for display
recommended_schedule_cron: str | None = None
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
marketplace_listing: Optional["MarketplaceListing"] = None
@@ -351,8 +259,6 @@ class LibraryAgent(pydantic.BaseModel):
can_access_graph=can_access_graph,
is_latest_version=is_latest_version,
is_favorite=agent.isFavorite,
folder_id=agent.folderId,
folder_name=agent.Folder.name if agent.Folder else None,
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
settings=_parse_settings(agent.settings),
marketplace_listing=marketplace_listing_data,
@@ -564,7 +470,3 @@ class LibraryAgentUpdateRequest(pydantic.BaseModel):
settings: Optional[GraphSettings] = pydantic.Field(
default=None, description="User-specific settings for this library agent"
)
folder_id: Optional[str] = pydantic.Field(
default=None,
description="Folder ID to move agent to (None to move to root)",
)

View File

@@ -1,11 +1,9 @@
import fastapi
from .agents import router as agents_router
from .folders import router as folders_router
from .presets import router as presets_router
router = fastapi.APIRouter()
router.include_router(presets_router)
router.include_router(folders_router)
router.include_router(agents_router)

View File

@@ -41,14 +41,6 @@ async def list_library_agents(
ge=1,
description="Number of agents per page (must be >= 1)",
),
folder_id: Optional[str] = Query(
None,
description="Filter by folder ID",
),
include_root_only: bool = Query(
False,
description="Only return agents without a folder (root-level agents)",
),
) -> library_model.LibraryAgentResponse:
"""
Get all agents in the user's library (both created and saved).
@@ -59,8 +51,6 @@ async def list_library_agents(
sort_by=sort_by,
page=page,
page_size=page_size,
folder_id=folder_id,
include_root_only=include_root_only,
)
@@ -178,7 +168,6 @@ async def update_library_agent(
is_favorite=payload.is_favorite,
is_archived=payload.is_archived,
settings=payload.settings,
folder_id=payload.folder_id,
)

View File

@@ -1,287 +0,0 @@
from typing import Optional
import autogpt_libs.auth as autogpt_auth_lib
from fastapi import APIRouter, Query, Security, status
from fastapi.responses import Response
from .. import db as library_db
from .. import model as library_model
router = APIRouter(
prefix="/folders",
tags=["library", "folders", "private"],
dependencies=[Security(autogpt_auth_lib.requires_user)],
)
@router.get(
"",
summary="List Library Folders",
response_model=library_model.FolderListResponse,
responses={
200: {"description": "List of folders"},
500: {"description": "Server error"},
},
)
async def list_folders(
user_id: str = Security(autogpt_auth_lib.get_user_id),
parent_id: Optional[str] = Query(
None,
description="Filter by parent folder ID. If not provided, returns root-level folders.",
),
include_relations: bool = Query(
True,
description="Include agent and subfolder relations (for counts)",
),
) -> library_model.FolderListResponse:
"""
List folders for the authenticated user.
Args:
user_id: ID of the authenticated user.
parent_id: Optional parent folder ID to filter by.
include_relations: Whether to include agent and subfolder relations for counts.
Returns:
A FolderListResponse containing folders.
"""
folders = await library_db.list_folders(
user_id=user_id,
parent_id=parent_id,
include_relations=include_relations,
)
return library_model.FolderListResponse(
folders=folders,
pagination=library_model.Pagination(
total_items=len(folders),
total_pages=1,
current_page=1,
page_size=len(folders),
),
)
@router.get(
"/tree",
summary="Get Folder Tree",
response_model=library_model.FolderTreeResponse,
responses={
200: {"description": "Folder tree structure"},
500: {"description": "Server error"},
},
)
async def get_folder_tree(
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.FolderTreeResponse:
"""
Get the full folder tree for the authenticated user.
Args:
user_id: ID of the authenticated user.
Returns:
A FolderTreeResponse containing the nested folder structure.
"""
tree = await library_db.get_folder_tree(user_id=user_id)
return library_model.FolderTreeResponse(tree=tree)
@router.get(
"/{folder_id}",
summary="Get Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder details"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def get_folder(
folder_id: str,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Get a specific folder.
Args:
folder_id: ID of the folder to retrieve.
user_id: ID of the authenticated user.
Returns:
The requested LibraryFolder.
"""
return await library_db.get_folder(folder_id=folder_id, user_id=user_id)
@router.post(
"",
summary="Create Folder",
status_code=status.HTTP_201_CREATED,
response_model=library_model.LibraryFolder,
responses={
201: {"description": "Folder created successfully"},
400: {"description": "Validation error"},
404: {"description": "Parent folder not found"},
409: {"description": "Folder name conflict"},
500: {"description": "Server error"},
},
)
async def create_folder(
payload: library_model.FolderCreateRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Create a new folder.
Args:
payload: The folder creation request.
user_id: ID of the authenticated user.
Returns:
The created LibraryFolder.
"""
return await library_db.create_folder(
user_id=user_id,
name=payload.name,
parent_id=payload.parent_id,
icon=payload.icon,
color=payload.color,
)
@router.patch(
"/{folder_id}",
summary="Update Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder updated successfully"},
400: {"description": "Validation error"},
404: {"description": "Folder not found"},
409: {"description": "Folder name conflict"},
500: {"description": "Server error"},
},
)
async def update_folder(
folder_id: str,
payload: library_model.FolderUpdateRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Update a folder's properties.
Args:
folder_id: ID of the folder to update.
payload: The folder update request.
user_id: ID of the authenticated user.
Returns:
The updated LibraryFolder.
"""
return await library_db.update_folder(
folder_id=folder_id,
user_id=user_id,
name=payload.name,
icon=payload.icon,
color=payload.color,
)
@router.post(
"/{folder_id}/move",
summary="Move Folder",
response_model=library_model.LibraryFolder,
responses={
200: {"description": "Folder moved successfully"},
400: {"description": "Validation error (circular reference)"},
404: {"description": "Folder or target parent not found"},
409: {"description": "Folder name conflict in target location"},
500: {"description": "Server error"},
},
)
async def move_folder(
folder_id: str,
payload: library_model.FolderMoveRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> library_model.LibraryFolder:
"""
Move a folder to a new parent.
Args:
folder_id: ID of the folder to move.
payload: The move request with target parent.
user_id: ID of the authenticated user.
Returns:
The moved LibraryFolder.
"""
return await library_db.move_folder(
folder_id=folder_id,
user_id=user_id,
target_parent_id=payload.target_parent_id,
)
@router.delete(
"/{folder_id}",
summary="Delete Folder",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "Folder deleted successfully"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def delete_folder(
folder_id: str,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> Response:
"""
Soft-delete a folder and all its contents.
Args:
folder_id: ID of the folder to delete.
user_id: ID of the authenticated user.
Returns:
204 No Content if successful.
"""
await library_db.delete_folder(
folder_id=folder_id,
user_id=user_id,
soft_delete=True,
)
return Response(status_code=status.HTTP_204_NO_CONTENT)
# === Bulk Agent Operations ===
@router.post(
"/agents/bulk-move",
summary="Bulk Move Agents",
response_model=list[library_model.LibraryAgent],
responses={
200: {"description": "Agents moved successfully"},
404: {"description": "Folder not found"},
500: {"description": "Server error"},
},
)
async def bulk_move_agents(
payload: library_model.BulkMoveAgentsRequest,
user_id: str = Security(autogpt_auth_lib.get_user_id),
) -> list[library_model.LibraryAgent]:
"""
Move multiple agents to a folder.
Args:
payload: The bulk move request with agent IDs and target folder.
user_id: ID of the authenticated user.
Returns:
The updated LibraryAgents.
"""
return await library_db.bulk_move_agents_to_folder(
agent_ids=payload.agent_ids,
folder_id=payload.folder_id,
user_id=user_id,
)

View File

@@ -115,8 +115,6 @@ async def test_get_library_agents_success(
sort_by=library_model.LibraryAgentSort.UPDATED_AT,
page=1,
page_size=15,
folder_id=None,
include_root_only=False,
)

View File

@@ -7,24 +7,20 @@ frontend can list available tools on an MCP server before placing a block.
import logging
from typing import Annotated, Any
from urllib.parse import urlparse
import fastapi
from autogpt_libs.auth import get_user_id
from fastapi import Security
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field
from backend.api.features.integrations.router import CredentialsMetaResponse
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.helpers import (
auto_lookup_mcp_credential,
normalize_mcp_url,
server_host,
)
from backend.blocks.mcp.oauth import MCPOAuthHandler
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
from backend.util.request import HTTPClientError, Requests, validate_url
from backend.util.request import HTTPClientError, Requests
from backend.util.settings import Settings
logger = logging.getLogger(__name__)
@@ -78,20 +74,32 @@ async def discover_tools(
If the user has a stored MCP credential for this server URL, it will be
used automatically — no need to pass an explicit auth token.
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
auth_token = request.auth_token
# Auto-use stored MCP credential when no explicit token is provided.
if not auth_token:
best_cred = await auto_lookup_mcp_credential(
user_id, normalize_mcp_url(request.server_url)
mcp_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
# Find the freshest credential for this server URL
best_cred: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == request.server_url
):
if best_cred is None or (
(cred.access_token_expires_at or 0)
> (best_cred.access_token_expires_at or 0)
):
best_cred = cred
if best_cred:
# Refresh the token if expired before using it
best_cred = await creds_manager.refresh_if_needed(user_id, best_cred)
logger.info(
f"Using MCP credential {best_cred.id} for {request.server_url}, "
f"expires_at={best_cred.access_token_expires_at}"
)
auth_token = best_cred.access_token.get_secret_value()
client = MCPClient(request.server_url, auth_token=auth_token)
@@ -126,7 +134,7 @@ async def discover_tools(
],
server_name=(
init_result.get("serverInfo", {}).get("name")
or server_host(request.server_url)
or urlparse(request.server_url).hostname
or "MCP"
),
protocol_version=init_result.get("protocolVersion"),
@@ -165,16 +173,7 @@ async def mcp_oauth_login(
3. Performs Dynamic Client Registration (RFC 7591) if available
4. Returns the authorization URL for the frontend to open in a popup
"""
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
# Normalize the URL so that credentials stored here are matched consistently
# by auto_lookup_mcp_credential (which also uses normalized URLs).
server_url = normalize_mcp_url(request.server_url)
client = MCPClient(server_url)
client = MCPClient(request.server_url)
# Step 1: Discover protected-resource metadata (RFC 9728)
protected_resource = await client.discover_auth()
@@ -183,16 +182,7 @@ async def mcp_oauth_login(
if protected_resource and protected_resource.get("authorization_servers"):
auth_server_url = protected_resource["authorization_servers"][0]
resource_url = protected_resource.get("resource", server_url)
# Validate the auth server URL from metadata to prevent SSRF.
try:
await validate_url(auth_server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(
status_code=400,
detail=f"Invalid authorization server URL in metadata: {e}",
)
resource_url = protected_resource.get("resource", request.server_url)
# Step 2a: Discover auth-server metadata (RFC 8414)
metadata = await client.discover_auth_server_metadata(auth_server_url)
@@ -202,7 +192,7 @@ async def mcp_oauth_login(
# Don't assume a resource_url — omitting it lets the auth server choose
# the correct audience for the token (RFC 8707 resource is optional).
resource_url = None
metadata = await client.discover_auth_server_metadata(server_url)
metadata = await client.discover_auth_server_metadata(request.server_url)
if (
not metadata
@@ -232,18 +222,12 @@ async def mcp_oauth_login(
client_id = ""
client_secret = ""
if registration_endpoint:
# Validate the registration endpoint to prevent SSRF via metadata.
try:
await validate_url(registration_endpoint, trusted_origins=[])
except ValueError:
pass # Skip registration, fall back to default client_id
else:
reg_result = await _register_mcp_client(
registration_endpoint, redirect_uri, server_url
)
if reg_result:
client_id = reg_result.get("client_id", "")
client_secret = reg_result.get("client_secret", "")
reg_result = await _register_mcp_client(
registration_endpoint, redirect_uri, request.server_url
)
if reg_result:
client_id = reg_result.get("client_id", "")
client_secret = reg_result.get("client_secret", "")
if not client_id:
client_id = "autogpt-platform"
@@ -261,7 +245,7 @@ async def mcp_oauth_login(
"token_url": token_url,
"revoke_url": revoke_url,
"resource_url": resource_url,
"server_url": server_url,
"server_url": request.server_url,
"client_id": client_id,
"client_secret": client_secret,
},
@@ -358,7 +342,7 @@ async def mcp_oauth_callback(
credentials.metadata["mcp_token_url"] = meta["token_url"]
credentials.metadata["mcp_resource_url"] = meta.get("resource_url", "")
hostname = server_host(meta["server_url"])
hostname = urlparse(meta["server_url"]).hostname or meta["server_url"]
credentials.title = f"MCP: {hostname}"
# Remove old MCP credentials for the same server to prevent stale token buildup.
@@ -373,9 +357,7 @@ async def mcp_oauth_callback(
):
await creds_manager.store.delete_creds_by_id(user_id, old.id)
logger.info(
"Removed old MCP credential %s for %s",
old.id,
server_host(meta["server_url"]),
f"Removed old MCP credential {old.id} for {meta['server_url']}"
)
except Exception:
logger.debug("Could not clean up old MCP credentials", exc_info=True)
@@ -393,93 +375,6 @@ async def mcp_oauth_callback(
)
# ======================== Bearer Token ======================== #
class MCPStoreTokenRequest(BaseModel):
"""Request to store a bearer token for an MCP server that doesn't support OAuth."""
server_url: str = Field(
description="MCP server URL the token authenticates against"
)
token: SecretStr = Field(
min_length=1, description="Bearer token / API key for the MCP server"
)
@router.post(
"/token",
summary="Store a bearer token for an MCP server",
)
async def mcp_store_token(
request: MCPStoreTokenRequest,
user_id: Annotated[str, Security(get_user_id)],
) -> CredentialsMetaResponse:
"""
Store a manually provided bearer token as an MCP credential.
Used by the Copilot MCPSetupCard when the server doesn't support the MCP
OAuth discovery flow (returns 400 from /oauth/login). Subsequent
``run_mcp_tool`` calls will automatically pick up the token via
``_auto_lookup_credential``.
"""
token = request.token.get_secret_value().strip()
if not token:
raise fastapi.HTTPException(status_code=422, detail="Token must not be blank.")
# Validate URL to prevent SSRF — blocks loopback and private IP ranges.
try:
await validate_url(request.server_url, trusted_origins=[])
except ValueError as e:
raise fastapi.HTTPException(status_code=400, detail=f"Invalid server URL: {e}")
# Normalize URL so trailing-slash variants match existing credentials.
server_url = normalize_mcp_url(request.server_url)
hostname = server_host(server_url)
# Collect IDs of old credentials to clean up after successful create.
old_cred_ids: list[str] = []
try:
old_creds = await creds_manager.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
old_cred_ids = [
old.id
for old in old_creds
if isinstance(old, OAuth2Credentials)
and normalize_mcp_url((old.metadata or {}).get("mcp_server_url", ""))
== server_url
]
except Exception:
logger.debug("Could not query old MCP token credentials", exc_info=True)
credentials = OAuth2Credentials(
provider=ProviderName.MCP.value,
title=f"MCP: {hostname}",
access_token=SecretStr(token),
scopes=[],
metadata={"mcp_server_url": server_url},
)
await creds_manager.create(user_id, credentials)
# Only delete old credentials after the new one is safely stored.
for old_id in old_cred_ids:
try:
await creds_manager.store.delete_creds_by_id(user_id, old_id)
except Exception:
logger.debug("Could not clean up old MCP token credential", exc_info=True)
return CredentialsMetaResponse(
id=credentials.id,
provider=credentials.provider,
type=credentials.type,
title=credentials.title,
scopes=credentials.scopes,
username=credentials.username,
host=hostname,
)
# ======================== Helpers ======================== #
@@ -505,7 +400,5 @@ async def _register_mcp_client(
return data
return None
except Exception as e:
logger.warning(
"Dynamic client registration failed for %s: %s", server_host(server_url), e
)
logger.warning(f"Dynamic client registration failed for {server_url}: {e}")
return None

View File

@@ -11,11 +11,9 @@ import httpx
import pytest
import pytest_asyncio
from autogpt_libs.auth import get_user_id
from pydantic import SecretStr
from backend.api.features.mcp.routes import router
from backend.blocks.mcp.client import MCPClientError, MCPTool
from backend.data.model import OAuth2Credentials
from backend.util.request import HTTPClientError
app = fastapi.FastAPI()
@@ -30,16 +28,6 @@ async def client():
yield c
@pytest.fixture(autouse=True)
def _bypass_ssrf_validation():
"""Bypass validate_url in all route tests (test URLs don't resolve)."""
with patch(
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
):
yield
class TestDiscoverTools:
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_success(self, client):
@@ -68,12 +56,9 @@ class TestDiscoverTools:
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={
@@ -122,6 +107,10 @@ class TestDiscoverTools:
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_auto_uses_stored_credential(self, client):
"""When no explicit token is given, stored MCP credentials are used."""
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
stored_cred = OAuth2Credentials(
provider="mcp",
title="MCP: example.com",
@@ -135,12 +124,10 @@ class TestDiscoverTools:
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=stored_cred,
),
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[stored_cred])
mock_cm.refresh_if_needed = AsyncMock(return_value=stored_cred)
instance = MockClient.return_value
instance.initialize = AsyncMock(
return_value={"serverInfo": {}, "protocolVersion": "2025-03-26"}
@@ -162,12 +149,9 @@ class TestDiscoverTools:
async def test_discover_tools_mcp_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=MCPClientError("Connection refused")
@@ -185,12 +169,9 @@ class TestDiscoverTools:
async def test_discover_tools_generic_error(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(side_effect=Exception("Network timeout"))
@@ -206,12 +187,9 @@ class TestDiscoverTools:
async def test_discover_tools_auth_required(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 401 Error: Unauthorized", 401)
@@ -229,12 +207,9 @@ class TestDiscoverTools:
async def test_discover_tools_forbidden(self, client):
with (
patch("backend.api.features.mcp.routes.MCPClient") as MockClient,
patch(
"backend.api.features.mcp.routes.auto_lookup_mcp_credential",
new_callable=AsyncMock,
return_value=None,
),
patch("backend.api.features.mcp.routes.creds_manager") as mock_cm,
):
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
instance = MockClient.return_value
instance.initialize = AsyncMock(
side_effect=HTTPClientError("HTTP 403 Error: Forbidden", 403)
@@ -356,6 +331,10 @@ class TestOAuthLogin:
class TestOAuthCallback:
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_callback_success(self, client):
from pydantic import SecretStr
from backend.data.model import OAuth2Credentials
mock_creds = OAuth2Credentials(
provider="mcp",
title=None,
@@ -455,118 +434,3 @@ class TestOAuthCallback:
assert response.status_code == 400
assert "token exchange failed" in response.json()["detail"].lower()
class TestStoreToken:
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_success(self, client):
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[])
mock_cm.create = AsyncMock()
response = await client.post(
"/token",
json={
"server_url": "https://mcp.example.com/mcp",
"token": "my-api-key-123",
},
)
assert response.status_code == 200
data = response.json()
assert data["provider"] == "mcp"
assert data["type"] == "oauth2"
assert data["host"] == "mcp.example.com"
mock_cm.create.assert_called_once()
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_blank_rejected(self, client):
"""Blank token string (after stripping) should return 422."""
response = await client.post(
"/token",
json={
"server_url": "https://mcp.example.com/mcp",
"token": " ",
},
)
# Pydantic min_length=1 catches the whitespace-only token
assert response.status_code == 422
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_replaces_old_credential(self, client):
old_cred = OAuth2Credentials(
provider="mcp",
title="MCP: mcp.example.com",
access_token=SecretStr("old-token"),
scopes=[],
metadata={"mcp_server_url": "https://mcp.example.com/mcp"},
)
with patch("backend.api.features.mcp.routes.creds_manager") as mock_cm:
mock_cm.store.get_creds_by_provider = AsyncMock(return_value=[old_cred])
mock_cm.create = AsyncMock()
mock_cm.store.delete_creds_by_id = AsyncMock()
response = await client.post(
"/token",
json={
"server_url": "https://mcp.example.com/mcp",
"token": "new-token",
},
)
assert response.status_code == 200
mock_cm.store.delete_creds_by_id.assert_called_once_with(
"test-user-id", old_cred.id
)
class TestSSRFValidation:
"""Verify that validate_url is enforced on all endpoints."""
@pytest.mark.asyncio(loop_scope="session")
async def test_discover_tools_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
response = await client.post(
"/discover-tools",
json={"server_url": "http://localhost/mcp"},
)
assert response.status_code == 400
assert "blocked loopback" in response.json()["detail"].lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_oauth_login_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked private IP"),
):
response = await client.post(
"/oauth/login",
json={"server_url": "http://10.0.0.1/mcp"},
)
assert response.status_code == 400
assert "blocked private ip" in response.json()["detail"].lower()
@pytest.mark.asyncio(loop_scope="session")
async def test_store_token_ssrf_blocked(self, client):
with patch(
"backend.api.features.mcp.routes.validate_url",
new_callable=AsyncMock,
side_effect=ValueError("blocked loopback"),
):
response = await client.post(
"/token",
json={
"server_url": "http://127.0.0.1/mcp",
"token": "some-token",
},
)
assert response.status_code == 400
assert "blocked loopback" in response.json()["detail"].lower()

View File

@@ -9,26 +9,15 @@ import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, get_args, get_origin
from typing import Any
from prisma.enums import ContentType
from backend.blocks.llm import LlmModel
from backend.data.db import query_raw_with_schema
logger = logging.getLogger(__name__)
def _contains_type(annotation: Any, target: type) -> bool:
"""Check if an annotation is or contains the target type (handles Optional/Union/Annotated)."""
if annotation is target:
return True
origin = get_origin(annotation)
if origin is None:
return False
return any(_contains_type(arg, target) for arg in get_args(annotation))
@dataclass
class ContentItem:
"""Represents a piece of content to be embedded."""
@@ -199,51 +188,45 @@ class BlockHandler(ContentHandler):
try:
block_instance = block_cls()
# Skip disabled blocks - they shouldn't be indexed
if block_instance.disabled:
continue
# Build searchable text from block metadata
parts = []
if block_instance.name:
if hasattr(block_instance, "name") and block_instance.name:
parts.append(block_instance.name)
if block_instance.description:
if (
hasattr(block_instance, "description")
and block_instance.description
):
parts.append(block_instance.description)
if block_instance.categories:
if hasattr(block_instance, "categories") and block_instance.categories:
# Convert BlockCategory enum to strings
parts.append(
" ".join(str(cat.value) for cat in block_instance.categories)
)
# Add input schema field descriptions
block_input_fields = block_instance.input_schema.model_fields
parts += [
f"{field_name}: {field_info.description}"
for field_name, field_info in block_input_fields.items()
if field_info.description
]
# Add input/output schema info
if hasattr(block_instance, "input_schema"):
schema = block_instance.input_schema
if hasattr(schema, "model_json_schema"):
schema_dict = schema.model_json_schema()
if "properties" in schema_dict:
for prop_name, prop_info in schema_dict[
"properties"
].items():
if "description" in prop_info:
parts.append(
f"{prop_name}: {prop_info['description']}"
)
searchable_text = " ".join(parts)
# Convert categories set of enums to list of strings for JSON serialization
categories = getattr(block_instance, "categories", set())
categories_list = (
[cat.value for cat in block_instance.categories]
if block_instance.categories
else []
)
# Extract provider names from credentials fields
credentials_info = (
block_instance.input_schema.get_credentials_fields_info()
)
is_integration = len(credentials_info) > 0
provider_names = [
provider.value.lower()
for info in credentials_info.values()
for provider in info.provider
]
# Check if block has LlmModel field in input schema
has_llm_model_field = any(
_contains_type(field.annotation, LlmModel)
for field in block_instance.input_schema.model_fields.values()
[cat.value for cat in categories] if categories else []
)
items.append(
@@ -252,11 +235,8 @@ class BlockHandler(ContentHandler):
content_type=ContentType.BLOCK,
searchable_text=searchable_text,
metadata={
"name": block_instance.name,
"name": getattr(block_instance, "name", ""),
"categories": categories_list,
"providers": provider_names,
"has_llm_model_field": has_llm_model_field,
"is_integration": is_integration,
},
user_id=None, # Blocks are public
)

View File

@@ -82,10 +82,9 @@ async def test_block_handler_get_missing_items(mocker):
mock_block_instance.description = "Performs calculations"
mock_block_instance.categories = [MagicMock(value="MATH")]
mock_block_instance.disabled = False
mock_field = MagicMock()
mock_field.description = "Math expression to evaluate"
mock_block_instance.input_schema.model_fields = {"expression": mock_field}
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
mock_block_instance.input_schema.model_json_schema.return_value = {
"properties": {"expression": {"description": "Math expression to evaluate"}}
}
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-uuid-1": mock_block_class}
@@ -310,19 +309,19 @@ async def test_content_handlers_registry():
@pytest.mark.asyncio(loop_scope="session")
async def test_block_handler_handles_empty_attributes():
"""Test BlockHandler handles blocks with empty/falsy attribute values."""
async def test_block_handler_handles_missing_attributes():
"""Test BlockHandler gracefully handles blocks with missing attributes."""
handler = BlockHandler()
# Mock block with empty values (all attributes exist but are falsy)
# Mock block with minimal attributes
mock_block_class = MagicMock()
mock_block_instance = MagicMock()
mock_block_instance.name = "Minimal Block"
mock_block_instance.disabled = False
mock_block_instance.description = ""
mock_block_instance.categories = set()
mock_block_instance.input_schema.model_fields = {}
mock_block_instance.input_schema.get_credentials_fields_info.return_value = {}
# No description, categories, or schema
del mock_block_instance.description
del mock_block_instance.categories
del mock_block_instance.input_schema
mock_block_class.return_value = mock_block_instance
mock_blocks = {"block-minimal": mock_block_class}
@@ -353,8 +352,6 @@ async def test_block_handler_skips_failed_blocks():
good_instance.description = "Works fine"
good_instance.categories = []
good_instance.disabled = False
good_instance.input_schema.model_fields = {}
good_instance.input_schema.get_credentials_fields_info.return_value = {}
good_block.return_value = good_instance
bad_block = MagicMock()

View File

@@ -126,9 +126,6 @@ v1_router = APIRouter()
########################################################
_tally_background_tasks: set[asyncio.Task] = set()
@v1_router.post(
"/auth/user",
summary="Get or create user",
@@ -137,24 +134,6 @@ _tally_background_tasks: set[asyncio.Task] = set()
)
async def get_or_create_user_route(user_data: dict = Security(get_jwt_payload)):
user = await get_or_create_user(user_data)
# Fire-and-forget: populate business understanding from Tally form.
# We use created_at proximity instead of an is_new flag because
# get_or_create_user is cached — a separate is_new return value would be
# unreliable on repeated calls within the cache TTL.
age_seconds = (datetime.now(timezone.utc) - user.created_at).total_seconds()
if age_seconds < 30:
try:
from backend.data.tally import populate_understanding_from_tally
task = asyncio.create_task(
populate_understanding_from_tally(user.id, user.email)
)
_tally_background_tasks.add(task)
task.add_done_callback(_tally_background_tasks.discard)
except Exception:
logger.debug("Failed to start Tally population task", exc_info=True)
return user.model_dump()

View File

@@ -1,5 +1,5 @@
import json
from datetime import datetime, timezone
from datetime import datetime
from io import BytesIO
from unittest.mock import AsyncMock, Mock, patch
@@ -43,7 +43,6 @@ def test_get_or_create_user_route(
) -> None:
"""Test get or create user endpoint"""
mock_user = Mock()
mock_user.created_at = datetime.now(timezone.utc)
mock_user.model_dump.return_value = {
"id": test_user_id,
"email": "test@example.com",

View File

@@ -3,29 +3,15 @@ Workspace API routes for managing user file storage.
"""
import logging
import os
import re
from typing import Annotated
from urllib.parse import quote
import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi import Query, UploadFile
from fastapi.responses import Response
from pydantic import BaseModel
from backend.data.workspace import (
WorkspaceFile,
count_workspace_files,
get_or_create_workspace,
get_workspace,
get_workspace_file,
get_workspace_total_size,
soft_delete_workspace_file,
)
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager
from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage
@@ -112,25 +98,6 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
raise
class UploadFileResponse(BaseModel):
file_id: str
name: str
path: str
mime_type: str
size_bytes: int
class DeleteFileResponse(BaseModel):
deleted: bool
class StorageUsageResponse(BaseModel):
used_bytes: int
limit_bytes: int
used_percent: float
file_count: int
@router.get(
"/files/{file_id}/download",
summary="Download file by ID",
@@ -153,148 +120,3 @@ async def download_file(
raise fastapi.HTTPException(status_code=404, detail="File not found")
return await _create_file_download_response(file)
@router.delete(
"/files/{file_id}",
summary="Delete a workspace file",
)
async def delete_workspace_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file_id: str,
) -> DeleteFileResponse:
"""
Soft-delete a workspace file and attempt to remove it from storage.
Used when a user clears a file input in the builder.
"""
workspace = await get_workspace(user_id)
if workspace is None:
raise fastapi.HTTPException(status_code=404, detail="Workspace not found")
manager = WorkspaceManager(user_id, workspace.id)
deleted = await manager.delete_file(file_id)
if not deleted:
raise fastapi.HTTPException(status_code=404, detail="File not found")
return DeleteFileResponse(deleted=True)
@router.post(
"/files/upload",
summary="Upload file to workspace",
)
async def upload_file(
user_id: Annotated[str, fastapi.Security(get_user_id)],
file: UploadFile,
session_id: str | None = Query(default=None),
) -> UploadFileResponse:
"""
Upload a file to the user's workspace.
Files are stored in session-scoped paths when session_id is provided,
so the agent's session-scoped tools can discover them automatically.
"""
config = Config()
# Sanitize filename — strip any directory components
filename = os.path.basename(file.filename or "upload") or "upload"
# Read file content with early abort on size limit
max_file_bytes = config.max_file_size_mb * 1024 * 1024
chunks: list[bytes] = []
total_size = 0
while chunk := await file.read(64 * 1024): # 64KB chunks
total_size += len(chunk)
if total_size > max_file_bytes:
raise fastapi.HTTPException(
status_code=413,
detail=f"File exceeds maximum size of {config.max_file_size_mb} MB",
)
chunks.append(chunk)
content = b"".join(chunks)
# Get or create workspace
workspace = await get_or_create_workspace(user_id)
# Pre-write storage cap check (soft check — final enforcement is post-write)
storage_limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
current_usage = await get_workspace_total_size(workspace.id)
if storage_limit_bytes and current_usage + len(content) > storage_limit_bytes:
used_percent = (current_usage / storage_limit_bytes) * 100
raise fastapi.HTTPException(
status_code=413,
detail={
"message": "Storage limit exceeded",
"used_bytes": current_usage,
"limit_bytes": storage_limit_bytes,
"used_percent": round(used_percent, 1),
},
)
# Warn at 80% usage
if (
storage_limit_bytes
and (usage_ratio := (current_usage + len(content)) / storage_limit_bytes) >= 0.8
):
logger.warning(
f"User {user_id} workspace storage at {usage_ratio * 100:.1f}% "
f"({current_usage + len(content)} / {storage_limit_bytes} bytes)"
)
# Virus scan
await scan_content_safe(content, filename=filename)
# Write file via WorkspaceManager
manager = WorkspaceManager(user_id, workspace.id, session_id)
try:
workspace_file = await manager.write_file(content, filename)
except ValueError as e:
raise fastapi.HTTPException(status_code=409, detail=str(e)) from e
# Post-write storage check — eliminates TOCTOU race on the quota.
# If a concurrent upload pushed us over the limit, undo this write.
new_total = await get_workspace_total_size(workspace.id)
if storage_limit_bytes and new_total > storage_limit_bytes:
await soft_delete_workspace_file(workspace_file.id, workspace.id)
raise fastapi.HTTPException(
status_code=413,
detail={
"message": "Storage limit exceeded (concurrent upload)",
"used_bytes": new_total,
"limit_bytes": storage_limit_bytes,
},
)
return UploadFileResponse(
file_id=workspace_file.id,
name=workspace_file.name,
path=workspace_file.path,
mime_type=workspace_file.mime_type,
size_bytes=workspace_file.size_bytes,
)
@router.get(
"/storage/usage",
summary="Get workspace storage usage",
)
async def get_storage_usage(
user_id: Annotated[str, fastapi.Security(get_user_id)],
) -> StorageUsageResponse:
"""
Get storage usage information for the user's workspace.
"""
config = Config()
workspace = await get_or_create_workspace(user_id)
used_bytes = await get_workspace_total_size(workspace.id)
file_count = await count_workspace_files(workspace.id)
limit_bytes = config.max_workspace_storage_mb * 1024 * 1024
return StorageUsageResponse(
used_bytes=used_bytes,
limit_bytes=limit_bytes,
used_percent=round((used_bytes / limit_bytes) * 100, 1) if limit_bytes else 0,
file_count=file_count,
)

View File

@@ -1,359 +0,0 @@
"""Tests for workspace file upload and download routes."""
import io
from datetime import datetime, timezone
import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from backend.api.features.workspace import routes as workspace_routes
from backend.data.workspace import WorkspaceFile
app = fastapi.FastAPI()
app.include_router(workspace_routes.router)
@app.exception_handler(ValueError)
async def _value_error_handler(
request: fastapi.Request, exc: ValueError
) -> fastapi.responses.JSONResponse:
"""Mirror the production ValueError → 400 mapping from rest_api.py."""
return fastapi.responses.JSONResponse(status_code=400, content={"detail": str(exc)})
client = fastapi.testclient.TestClient(app)
TEST_USER_ID = "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
MOCK_WORKSPACE = type("W", (), {"id": "ws-1"})()
_NOW = datetime(2023, 1, 1, tzinfo=timezone.utc)
MOCK_FILE = WorkspaceFile(
id="file-aaa-bbb",
workspace_id="ws-1",
created_at=_NOW,
updated_at=_NOW,
name="hello.txt",
path="/session/hello.txt",
mime_type="text/plain",
size_bytes=13,
storage_path="local://hello.txt",
)
@pytest.fixture(autouse=True)
def setup_app_auth(mock_jwt_user):
from autogpt_libs.auth.jwt_utils import get_jwt_payload
app.dependency_overrides[get_jwt_payload] = mock_jwt_user["get_jwt_payload"]
yield
app.dependency_overrides.clear()
def _upload(
filename: str = "hello.txt",
content: bytes = b"Hello, world!",
content_type: str = "text/plain",
):
"""Helper to POST a file upload."""
return client.post(
"/files/upload?session_id=sess-1",
files={"file": (filename, io.BytesIO(content), content_type)},
)
# ---- Happy path ----
def test_upload_happy_path(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload()
assert response.status_code == 200
data = response.json()
assert data["file_id"] == "file-aaa-bbb"
assert data["name"] == "hello.txt"
assert data["size_bytes"] == 13
# ---- Per-file size limit ----
def test_upload_exceeds_max_file_size(mocker: pytest_mock.MockFixture):
"""Files larger than max_file_size_mb should be rejected with 413."""
cfg = mocker.patch("backend.api.features.workspace.routes.Config")
cfg.return_value.max_file_size_mb = 0 # 0 MB → any content is too big
cfg.return_value.max_workspace_storage_mb = 500
response = _upload(content=b"x" * 1024)
assert response.status_code == 413
# ---- Storage quota exceeded ----
def test_upload_storage_quota_exceeded(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
# Current usage already at limit
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=500 * 1024 * 1024,
)
response = _upload()
assert response.status_code == 413
assert "Storage limit exceeded" in response.text
# ---- Post-write quota race (B2) ----
def test_upload_post_write_quota_race(mocker: pytest_mock.MockFixture):
"""If a concurrent upload tips the total over the limit after write,
the file should be soft-deleted and 413 returned."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
# Pre-write check passes (under limit), but post-write check fails
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
side_effect=[0, 600 * 1024 * 1024], # first call OK, second over limit
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
mock_delete = mocker.patch(
"backend.api.features.workspace.routes.soft_delete_workspace_file",
return_value=None,
)
response = _upload()
assert response.status_code == 413
mock_delete.assert_called_once_with("file-aaa-bbb", "ws-1")
# ---- Any extension accepted (no allowlist) ----
def test_upload_any_extension(mocker: pytest_mock.MockFixture):
"""Any file extension should be accepted — ClamAV is the security layer."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload(filename="data.xyz", content=b"arbitrary")
assert response.status_code == 200
# ---- Virus scan rejection ----
def test_upload_blocked_by_virus_scan(mocker: pytest_mock.MockFixture):
"""Files flagged by ClamAV should be rejected and never written to storage."""
from backend.api.features.store.exceptions import VirusDetectedError
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
side_effect=VirusDetectedError("Eicar-Test-Signature"),
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload(filename="evil.exe", content=b"X5O!P%@AP...")
assert response.status_code == 400
assert "Virus detected" in response.text
mock_manager.write_file.assert_not_called()
# ---- No file extension ----
def test_upload_file_without_extension(mocker: pytest_mock.MockFixture):
"""Files without an extension should be accepted and stored as-is."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = _upload(
filename="Makefile",
content=b"all:\n\techo hello",
content_type="application/octet-stream",
)
assert response.status_code == 200
mock_manager.write_file.assert_called_once()
assert mock_manager.write_file.call_args[0][1] == "Makefile"
# ---- Filename sanitization (SF5) ----
def test_upload_strips_path_components(mocker: pytest_mock.MockFixture):
"""Path-traversal filenames should be reduced to their basename."""
mocker.patch(
"backend.api.features.workspace.routes.get_or_create_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_total_size",
return_value=0,
)
mocker.patch(
"backend.api.features.workspace.routes.scan_content_safe",
return_value=None,
)
mock_manager = mocker.MagicMock()
mock_manager.write_file = mocker.AsyncMock(return_value=MOCK_FILE)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
# Filename with traversal
_upload(filename="../../etc/passwd.txt")
# write_file should have been called with just the basename
mock_manager.write_file.assert_called_once()
call_args = mock_manager.write_file.call_args
assert call_args[0][1] == "passwd.txt"
# ---- Download ----
def test_download_file_not_found(mocker: pytest_mock.MockFixture):
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
)
mocker.patch(
"backend.api.features.workspace.routes.get_workspace_file",
return_value=None,
)
response = client.get("/files/some-file-id/download")
assert response.status_code == 404
# ---- Delete ----
def test_delete_file_success(mocker: pytest_mock.MockFixture):
"""Deleting an existing file should return {"deleted": true}."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
)
mock_manager = mocker.MagicMock()
mock_manager.delete_file = mocker.AsyncMock(return_value=True)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = client.delete("/files/file-aaa-bbb")
assert response.status_code == 200
assert response.json() == {"deleted": True}
mock_manager.delete_file.assert_called_once_with("file-aaa-bbb")
def test_delete_file_not_found(mocker: pytest_mock.MockFixture):
"""Deleting a non-existent file should return 404."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=MOCK_WORKSPACE,
)
mock_manager = mocker.MagicMock()
mock_manager.delete_file = mocker.AsyncMock(return_value=False)
mocker.patch(
"backend.api.features.workspace.routes.WorkspaceManager",
return_value=mock_manager,
)
response = client.delete("/files/nonexistent-id")
assert response.status_code == 404
assert "File not found" in response.text
def test_delete_file_no_workspace(mocker: pytest_mock.MockFixture):
"""Deleting when user has no workspace should return 404."""
mocker.patch(
"backend.api.features.workspace.routes.get_workspace",
return_value=None,
)
response = client.delete("/files/file-aaa-bbb")
assert response.status_code == 404
assert "Workspace not found" in response.text

View File

@@ -41,11 +41,11 @@ import backend.data.user
import backend.integrations.webhooks.utils
import backend.util.service
import backend.util.settings
from backend.api.features.library.exceptions import (
FolderAlreadyExistsError,
FolderValidationError,
)
from backend.blocks.llm import DEFAULT_LLM_MODEL
from backend.copilot.completion_consumer import (
start_completion_consumer,
stop_completion_consumer,
)
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.monitoring.instrumentation import instrument_fastapi
@@ -123,9 +123,21 @@ async def lifespan_context(app: fastapi.FastAPI):
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
# Start chat completion consumer for Redis Streams notifications
try:
await start_completion_consumer()
except Exception as e:
logger.warning(f"Could not start chat completion consumer: {e}")
with launch_darkly_context():
yield
# Stop chat completion consumer
try:
await stop_completion_consumer()
except Exception as e:
logger.warning(f"Error stopping chat completion consumer: {e}")
try:
await shutdown_cloud_storage_handler()
except Exception as e:
@@ -265,10 +277,6 @@ async def validation_error_handler(
app.add_exception_handler(PrismaError, handle_internal_http_error(500))
app.add_exception_handler(
FolderAlreadyExistsError, handle_internal_http_error(409, False)
)
app.add_exception_handler(FolderValidationError, handle_internal_http_error(400, False))
app.add_exception_handler(NotFoundError, handle_internal_http_error(404, False))
app.add_exception_handler(NotAuthorizedError, handle_internal_http_error(403, False))
app.add_exception_handler(RequestValidationError, validation_error_handler)

View File

@@ -24,7 +24,7 @@ def run_processes(*processes: "AppProcess", **kwargs):
# Run the last process in the foreground.
processes[-1].start(background=False, **kwargs)
finally:
for process in reversed(processes):
for process in processes:
try:
process.stop()
except Exception as e:

View File

@@ -6,6 +6,7 @@ and execute them. Works like AgentExecutorBlock — the user selects a tool from
dropdown and the input/output schema adapts dynamically.
"""
import json
import logging
from typing import Any, Literal
@@ -19,11 +20,6 @@ from backend.blocks._base import (
BlockType,
)
from backend.blocks.mcp.client import MCPClient, MCPClientError
from backend.blocks.mcp.helpers import (
auto_lookup_mcp_credential,
normalize_mcp_url,
parse_mcp_content,
)
from backend.data.block import BlockInput, BlockOutput
from backend.data.model import (
CredentialsField,
@@ -183,7 +179,31 @@ class MCPToolBlock(Block):
f"{error_text or 'Unknown error'}"
)
return parse_mcp_content(result.content)
# Extract text content from the result
output_parts = []
for item in result.content:
if item.get("type") == "text":
text = item.get("text", "")
# Try to parse as JSON for structured output
try:
output_parts.append(json.loads(text))
except (json.JSONDecodeError, ValueError):
output_parts.append(text)
elif item.get("type") == "image":
output_parts.append(
{
"type": "image",
"data": item.get("data"),
"mimeType": item.get("mimeType"),
}
)
elif item.get("type") == "resource":
output_parts.append(item.get("resource", {}))
# If single result, unwrap
if len(output_parts) == 1:
return output_parts[0]
return output_parts if output_parts else None
@staticmethod
async def _auto_lookup_credential(
@@ -191,10 +211,37 @@ class MCPToolBlock(Block):
) -> "OAuth2Credentials | None":
"""Auto-lookup stored MCP credential for a server URL.
Delegates to :func:`~backend.blocks.mcp.helpers.auto_lookup_mcp_credential`.
The caller should pass a normalized URL.
This is a fallback for nodes that don't have ``credentials`` explicitly
set (e.g. nodes created before the credential field was wired up).
"""
return await auto_lookup_mcp_credential(user_id, server_url)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
try:
mgr = IntegrationCredentialsManager()
mcp_creds = await mgr.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
best: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == server_url
):
if best is None or (
(cred.access_token_expires_at or 0)
> (best.access_token_expires_at or 0)
):
best = cred
if best:
best = await mgr.refresh_if_needed(user_id, best)
logger.info(
"Auto-resolved MCP credential %s for %s", best.id, server_url
)
return best
except Exception:
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
return None
async def run(
self,
@@ -231,7 +278,7 @@ class MCPToolBlock(Block):
# the stored MCP credential for this server URL.
if credentials is None:
credentials = await self._auto_lookup_credential(
user_id, normalize_mcp_url(input_data.server_url)
user_id, input_data.server_url
)
auth_token = (

View File

@@ -55,9 +55,7 @@ class MCPClient:
server_url: str,
auth_token: str | None = None,
):
from backend.blocks.mcp.helpers import normalize_mcp_url
self.server_url = normalize_mcp_url(server_url)
self.server_url = server_url.rstrip("/")
self.auth_token = auth_token
self._request_id = 0
self._session_id: str | None = None

View File

@@ -1,117 +0,0 @@
"""Shared MCP helpers used by blocks, copilot tools, and API routes."""
from __future__ import annotations
import json
import logging
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
if TYPE_CHECKING:
from backend.data.model import OAuth2Credentials
logger = logging.getLogger(__name__)
def normalize_mcp_url(url: str) -> str:
"""Normalize an MCP server URL for consistent credential matching.
Strips leading/trailing whitespace and a single trailing slash so that
``https://mcp.example.com/`` and ``https://mcp.example.com`` resolve to
the same stored credential.
"""
return url.strip().rstrip("/")
def server_host(server_url: str) -> str:
"""Extract the hostname from a server URL for display purposes.
Uses ``parsed.hostname`` (never ``netloc``) to strip any embedded
username/password before surfacing the value in UI messages.
"""
try:
parsed = urlparse(server_url)
return parsed.hostname or server_url
except Exception:
return server_url
def parse_mcp_content(content: list[dict[str, Any]]) -> Any:
"""Parse MCP tool response content into a plain Python value.
- text items: parsed as JSON when possible, kept as str otherwise
- image items: kept as ``{type, data, mimeType}`` dict for frontend rendering
- resource items: unwrapped to their resource payload dict
Single-item responses are unwrapped from the list; multiple items are
returned as a list; empty content returns ``None``.
"""
output_parts: list[Any] = []
for item in content:
item_type = item.get("type")
if item_type == "text":
text = item.get("text", "")
try:
output_parts.append(json.loads(text))
except (json.JSONDecodeError, ValueError):
output_parts.append(text)
elif item_type == "image":
output_parts.append(
{
"type": "image",
"data": item.get("data"),
"mimeType": item.get("mimeType"),
}
)
elif item_type == "resource":
output_parts.append(item.get("resource", {}))
if len(output_parts) == 1:
return output_parts[0]
return output_parts or None
async def auto_lookup_mcp_credential(
user_id: str, server_url: str
) -> OAuth2Credentials | None:
"""Look up the best stored MCP credential for *server_url*.
The caller should pass a **normalized** URL (via :func:`normalize_mcp_url`)
so the comparison with ``mcp_server_url`` in credential metadata matches.
Returns the credential with the latest ``access_token_expires_at``, refreshed
if needed, or ``None`` when no match is found.
"""
from backend.data.model import OAuth2Credentials
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.integrations.providers import ProviderName
try:
mgr = IntegrationCredentialsManager()
mcp_creds = await mgr.store.get_creds_by_provider(
user_id, ProviderName.MCP.value
)
# Collect all matching credentials and pick the best one.
# Primary sort: latest access_token_expires_at (tokens with expiry
# are preferred over non-expiring ones). Secondary sort: last in
# iteration order, which corresponds to the most recently created
# row — this acts as a tiebreaker when multiple bearer tokens have
# no expiry (e.g. after a failed old-credential cleanup).
best: OAuth2Credentials | None = None
for cred in mcp_creds:
if (
isinstance(cred, OAuth2Credentials)
and (cred.metadata or {}).get("mcp_server_url") == server_url
):
if best is None or (
(cred.access_token_expires_at or 0)
>= (best.access_token_expires_at or 0)
):
best = cred
if best:
best = await mgr.refresh_if_needed(user_id, best)
logger.info("Auto-resolved MCP credential %s for %s", best.id, server_url)
return best
except Exception:
logger.warning("Auto-lookup MCP credential failed", exc_info=True)
return None

View File

@@ -1,98 +0,0 @@
"""Unit tests for the shared MCP helpers."""
from backend.blocks.mcp.helpers import normalize_mcp_url, parse_mcp_content, server_host
# ---------------------------------------------------------------------------
# normalize_mcp_url
# ---------------------------------------------------------------------------
def test_normalize_trailing_slash():
assert normalize_mcp_url("https://mcp.example.com/") == "https://mcp.example.com"
def test_normalize_whitespace():
assert normalize_mcp_url(" https://mcp.example.com ") == "https://mcp.example.com"
def test_normalize_both():
assert (
normalize_mcp_url(" https://mcp.example.com/ ") == "https://mcp.example.com"
)
def test_normalize_noop():
assert normalize_mcp_url("https://mcp.example.com") == "https://mcp.example.com"
def test_normalize_path_with_trailing_slash():
assert (
normalize_mcp_url("https://mcp.example.com/path/")
== "https://mcp.example.com/path"
)
# ---------------------------------------------------------------------------
# server_host
# ---------------------------------------------------------------------------
def test_server_host_standard_url():
assert server_host("https://mcp.example.com/mcp") == "mcp.example.com"
def test_server_host_strips_credentials():
"""hostname must not expose user:pass."""
assert server_host("https://user:secret@mcp.example.com/mcp") == "mcp.example.com"
def test_server_host_with_port():
"""Port should not appear in hostname (hostname strips it)."""
assert server_host("https://mcp.example.com:8080/mcp") == "mcp.example.com"
def test_server_host_fallback():
"""Falls back to the raw string for un-parseable URLs."""
assert server_host("not-a-url") == "not-a-url"
# ---------------------------------------------------------------------------
# parse_mcp_content
# ---------------------------------------------------------------------------
def test_parse_text_plain():
assert parse_mcp_content([{"type": "text", "text": "hello world"}]) == "hello world"
def test_parse_text_json():
content = [{"type": "text", "text": '{"status": "ok", "count": 42}'}]
assert parse_mcp_content(content) == {"status": "ok", "count": 42}
def test_parse_image():
content = [{"type": "image", "data": "abc123==", "mimeType": "image/png"}]
assert parse_mcp_content(content) == {
"type": "image",
"data": "abc123==",
"mimeType": "image/png",
}
def test_parse_resource():
content = [
{"type": "resource", "resource": {"uri": "file:///tmp/out.txt", "text": "hi"}}
]
assert parse_mcp_content(content) == {"uri": "file:///tmp/out.txt", "text": "hi"}
def test_parse_multi_item():
content = [
{"type": "text", "text": "first"},
{"type": "text", "text": "second"},
]
assert parse_mcp_content(content) == ["first", "second"]
def test_parse_empty():
assert parse_mcp_content([]) is None

View File

@@ -1,182 +0,0 @@
"""
Telegram Bot API helper functions.
Provides utilities for making authenticated requests to the Telegram Bot API.
"""
import logging
from io import BytesIO
from typing import Any, Optional
from pydantic import BaseModel
from backend.data.model import APIKeyCredentials
from backend.util.request import Requests
logger = logging.getLogger(__name__)
TELEGRAM_API_BASE = "https://api.telegram.org"
class TelegramMessageResult(BaseModel, extra="allow"):
"""Result from Telegram send/edit message API calls."""
message_id: int = 0
chat: dict[str, Any] = {}
date: int = 0
text: str = ""
class TelegramFileResult(BaseModel, extra="allow"):
"""Result from Telegram getFile API call."""
file_id: str = ""
file_unique_id: str = ""
file_size: int = 0
file_path: str = ""
class TelegramAPIException(ValueError):
"""Exception raised for Telegram API errors."""
def __init__(self, message: str, error_code: int = 0):
super().__init__(message)
self.error_code = error_code
def get_bot_api_url(bot_token: str, method: str) -> str:
"""Construct Telegram Bot API URL for a method."""
return f"{TELEGRAM_API_BASE}/bot{bot_token}/{method}"
def get_file_url(bot_token: str, file_path: str) -> str:
"""Construct Telegram file download URL."""
return f"{TELEGRAM_API_BASE}/file/bot{bot_token}/{file_path}"
async def call_telegram_api(
credentials: APIKeyCredentials,
method: str,
data: Optional[dict[str, Any]] = None,
) -> TelegramMessageResult:
"""
Make a request to the Telegram Bot API.
Args:
credentials: Bot token credentials
method: API method name (e.g., "sendMessage", "getFile")
data: Request parameters
Returns:
API response result
Raises:
TelegramAPIException: If the API returns an error
"""
token = credentials.api_key.get_secret_value()
url = get_bot_api_url(token, method)
response = await Requests().post(url, json=data or {})
result = response.json()
if not result.get("ok"):
error_code = result.get("error_code", 0)
description = result.get("description", "Unknown error")
raise TelegramAPIException(description, error_code)
return TelegramMessageResult(**result.get("result", {}))
async def call_telegram_api_with_file(
credentials: APIKeyCredentials,
method: str,
file_field: str,
file_data: bytes,
filename: str,
content_type: str,
data: Optional[dict[str, Any]] = None,
) -> TelegramMessageResult:
"""
Make a multipart/form-data request to the Telegram Bot API with a file upload.
Args:
credentials: Bot token credentials
method: API method name (e.g., "sendPhoto", "sendVoice")
file_field: Form field name for the file (e.g., "photo", "voice")
file_data: Raw file bytes
filename: Filename for the upload
content_type: MIME type of the file
data: Additional form parameters
Returns:
API response result
Raises:
TelegramAPIException: If the API returns an error
"""
token = credentials.api_key.get_secret_value()
url = get_bot_api_url(token, method)
files = [(file_field, (filename, BytesIO(file_data), content_type))]
response = await Requests().post(url, files=files, data=data or {})
result = response.json()
if not result.get("ok"):
error_code = result.get("error_code", 0)
description = result.get("description", "Unknown error")
raise TelegramAPIException(description, error_code)
return TelegramMessageResult(**result.get("result", {}))
async def get_file_info(
credentials: APIKeyCredentials, file_id: str
) -> TelegramFileResult:
"""
Get file information from Telegram.
Args:
credentials: Bot token credentials
file_id: Telegram file_id from message
Returns:
File info dict containing file_id, file_unique_id, file_size, file_path
"""
result = await call_telegram_api(credentials, "getFile", {"file_id": file_id})
return TelegramFileResult(**result.model_dump())
async def get_file_download_url(credentials: APIKeyCredentials, file_id: str) -> str:
"""
Get the download URL for a Telegram file.
Args:
credentials: Bot token credentials
file_id: Telegram file_id from message
Returns:
Full download URL
"""
token = credentials.api_key.get_secret_value()
result = await get_file_info(credentials, file_id)
file_path = result.file_path
if not file_path:
raise TelegramAPIException("No file_path returned from getFile")
return get_file_url(token, file_path)
async def download_telegram_file(credentials: APIKeyCredentials, file_id: str) -> bytes:
"""
Download a file from Telegram servers.
Args:
credentials: Bot token credentials
file_id: Telegram file_id
Returns:
File content as bytes
"""
url = await get_file_download_url(credentials, file_id)
response = await Requests().get(url)
return response.content

View File

@@ -1,43 +0,0 @@
"""
Telegram Bot credentials handling.
Telegram bots use an API key (bot token) obtained from @BotFather.
"""
from typing import Literal
from pydantic import SecretStr
from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName
# Bot token credentials (API key style)
TelegramCredentials = APIKeyCredentials
TelegramCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.TELEGRAM], Literal["api_key"]
]
def TelegramCredentialsField() -> TelegramCredentialsInput:
"""Creates a Telegram bot token credentials field."""
return CredentialsField(
description="Telegram Bot API token from @BotFather. "
"Create a bot at https://t.me/BotFather to get your token."
)
# Test credentials for unit tests
TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="telegram",
api_key=SecretStr("test_telegram_bot_token"),
title="Mock Telegram Bot Token",
expires_at=None,
)
TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,377 +0,0 @@
"""
Telegram trigger blocks for receiving messages via webhooks.
"""
import logging
from pydantic import BaseModel
from backend.blocks._base import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
BlockWebhookConfig,
)
from backend.data.model import SchemaField
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.telegram import TelegramWebhookType
from ._auth import (
TEST_CREDENTIALS,
TEST_CREDENTIALS_INPUT,
TelegramCredentialsField,
TelegramCredentialsInput,
)
logger = logging.getLogger(__name__)
# Example payload for testing
EXAMPLE_MESSAGE_PAYLOAD = {
"update_id": 123456789,
"message": {
"message_id": 1,
"from": {
"id": 12345678,
"is_bot": False,
"first_name": "John",
"last_name": "Doe",
"username": "johndoe",
"language_code": "en",
},
"chat": {
"id": 12345678,
"first_name": "John",
"last_name": "Doe",
"username": "johndoe",
"type": "private",
},
"date": 1234567890,
"text": "Hello, bot!",
},
}
class TelegramTriggerBase:
"""Base class for Telegram trigger blocks."""
class Input(BlockSchemaInput):
credentials: TelegramCredentialsInput = TelegramCredentialsField()
payload: dict = SchemaField(hidden=True, default_factory=dict)
class TelegramMessageTriggerBlock(TelegramTriggerBase, Block):
"""
Triggers when a message is received or edited in your Telegram bot.
Supports text, photos, voice messages, audio files, documents, and videos.
Connect the outputs to other blocks to process messages and send responses.
"""
class Input(TelegramTriggerBase.Input):
class EventsFilter(BaseModel):
"""Filter for message types to receive."""
text: bool = True
photo: bool = False
voice: bool = False
audio: bool = False
document: bool = False
video: bool = False
edited_message: bool = False
events: EventsFilter = SchemaField(
title="Message Types", description="Types of messages to receive"
)
class Output(BlockSchemaOutput):
payload: dict = SchemaField(
description="The complete webhook payload from Telegram"
)
chat_id: int = SchemaField(
description="The chat ID where the message was received. "
"Use this to send replies."
)
message_id: int = SchemaField(description="The unique message ID")
user_id: int = SchemaField(description="The user ID who sent the message")
username: str = SchemaField(description="Username of the sender (may be empty)")
first_name: str = SchemaField(description="First name of the sender")
event: str = SchemaField(
description="The message type (text, photo, voice, audio, etc.)"
)
text: str = SchemaField(
description="Text content of the message (for text messages)"
)
photo_file_id: str = SchemaField(
description="File ID of the photo (for photo messages). "
"Use GetTelegramFileBlock to download."
)
voice_file_id: str = SchemaField(
description="File ID of the voice message (for voice messages). "
"Use GetTelegramFileBlock to download."
)
audio_file_id: str = SchemaField(
description="File ID of the audio file (for audio messages). "
"Use GetTelegramFileBlock to download."
)
file_id: str = SchemaField(
description="File ID for document/video messages. "
"Use GetTelegramFileBlock to download."
)
file_name: str = SchemaField(
description="Original filename (for document/audio messages)"
)
caption: str = SchemaField(description="Caption for media messages")
is_edited: bool = SchemaField(
description="Whether this is an edit of a previously sent message"
)
def __init__(self):
super().__init__(
id="4435e4e0-df6e-4301-8f35-ad70b12fc9ec",
description="Triggers when a message is received or edited in your Telegram bot. "
"Supports text, photos, voice messages, audio files, documents, and videos.",
categories={BlockCategory.SOCIAL},
input_schema=TelegramMessageTriggerBlock.Input,
output_schema=TelegramMessageTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.TELEGRAM,
webhook_type=TelegramWebhookType.BOT,
resource_format="bot",
event_filter_input="events",
event_format="message.{event}",
),
test_input={
"events": {"text": True, "photo": True},
"credentials": TEST_CREDENTIALS_INPUT,
"payload": EXAMPLE_MESSAGE_PAYLOAD,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", EXAMPLE_MESSAGE_PAYLOAD),
("chat_id", 12345678),
("message_id", 1),
("user_id", 12345678),
("username", "johndoe"),
("first_name", "John"),
("is_edited", False),
("event", "text"),
("text", "Hello, bot!"),
("photo_file_id", ""),
("voice_file_id", ""),
("audio_file_id", ""),
("file_id", ""),
("file_name", ""),
("caption", ""),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
payload = input_data.payload
is_edited = "edited_message" in payload
message = payload.get("message") or payload.get("edited_message", {})
# Extract common fields
chat = message.get("chat", {})
sender = message.get("from", {})
yield "payload", payload
yield "chat_id", chat.get("id", 0)
yield "message_id", message.get("message_id", 0)
yield "user_id", sender.get("id", 0)
yield "username", sender.get("username", "")
yield "first_name", sender.get("first_name", "")
yield "is_edited", is_edited
# For edited messages, yield event as "edited_message" and extract
# all content fields from the edited message body
if is_edited:
yield "event", "edited_message"
yield "text", message.get("text", "")
photos = message.get("photo", [])
yield "photo_file_id", photos[-1].get("file_id", "") if photos else ""
voice = message.get("voice", {})
yield "voice_file_id", voice.get("file_id", "")
audio = message.get("audio", {})
yield "audio_file_id", audio.get("file_id", "")
document = message.get("document", {})
video = message.get("video", {})
yield "file_id", (document.get("file_id", "") or video.get("file_id", ""))
yield "file_name", (
document.get("file_name", "") or audio.get("file_name", "")
)
yield "caption", message.get("caption", "")
# Determine message type and extract content
elif "text" in message:
yield "event", "text"
yield "text", message.get("text", "")
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", ""
elif "photo" in message:
# Get the largest photo (last in array)
photos = message.get("photo", [])
photo_fid = photos[-1].get("file_id", "") if photos else ""
yield "event", "photo"
yield "text", ""
yield "photo_file_id", photo_fid
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", message.get("caption", "")
elif "voice" in message:
voice = message.get("voice", {})
yield "event", "voice"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", voice.get("file_id", "")
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", message.get("caption", "")
elif "audio" in message:
audio = message.get("audio", {})
yield "event", "audio"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", audio.get("file_id", "")
yield "file_id", ""
yield "file_name", audio.get("file_name", "")
yield "caption", message.get("caption", "")
elif "document" in message:
document = message.get("document", {})
yield "event", "document"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", document.get("file_id", "")
yield "file_name", document.get("file_name", "")
yield "caption", message.get("caption", "")
elif "video" in message:
video = message.get("video", {})
yield "event", "video"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", video.get("file_id", "")
yield "file_name", video.get("file_name", "")
yield "caption", message.get("caption", "")
else:
yield "event", "other"
yield "text", ""
yield "photo_file_id", ""
yield "voice_file_id", ""
yield "audio_file_id", ""
yield "file_id", ""
yield "file_name", ""
yield "caption", ""
# Example payload for reaction trigger testing
EXAMPLE_REACTION_PAYLOAD = {
"update_id": 123456790,
"message_reaction": {
"chat": {
"id": 12345678,
"first_name": "John",
"last_name": "Doe",
"username": "johndoe",
"type": "private",
},
"message_id": 42,
"user": {
"id": 12345678,
"is_bot": False,
"first_name": "John",
"username": "johndoe",
},
"date": 1234567890,
"new_reaction": [{"type": "emoji", "emoji": "👍"}],
"old_reaction": [],
},
}
class TelegramMessageReactionTriggerBlock(TelegramTriggerBase, Block):
"""
Triggers when a reaction to a message is changed.
Works automatically in private chats. In group chats, the bot must be
an administrator to receive reaction updates.
"""
class Input(TelegramTriggerBase.Input):
pass
class Output(BlockSchemaOutput):
payload: dict = SchemaField(
description="The complete webhook payload from Telegram"
)
chat_id: int = SchemaField(
description="The chat ID where the reaction occurred"
)
message_id: int = SchemaField(description="The message ID that was reacted to")
user_id: int = SchemaField(description="The user ID who changed the reaction")
username: str = SchemaField(description="Username of the user (may be empty)")
new_reactions: list = SchemaField(
description="List of new reactions on the message"
)
old_reactions: list = SchemaField(
description="List of previous reactions on the message"
)
def __init__(self):
super().__init__(
id="82525328-9368-4966-8f0c-cd78e80181fd",
description="Triggers when a reaction to a message is changed. "
"Works in private chats automatically. "
"In groups, the bot must be an administrator.",
categories={BlockCategory.SOCIAL},
input_schema=TelegramMessageReactionTriggerBlock.Input,
output_schema=TelegramMessageReactionTriggerBlock.Output,
webhook_config=BlockWebhookConfig(
provider=ProviderName.TELEGRAM,
webhook_type=TelegramWebhookType.BOT,
resource_format="bot",
event_filter_input="",
event_format="message_reaction",
),
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"payload": EXAMPLE_REACTION_PAYLOAD,
},
test_credentials=TEST_CREDENTIALS,
test_output=[
("payload", EXAMPLE_REACTION_PAYLOAD),
("chat_id", 12345678),
("message_id", 42),
("user_id", 12345678),
("username", "johndoe"),
("new_reactions", [{"type": "emoji", "emoji": "👍"}]),
("old_reactions", []),
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
payload = input_data.payload
reaction = payload.get("message_reaction", {})
chat = reaction.get("chat", {})
user = reaction.get("user", {})
yield "payload", payload
yield "chat_id", chat.get("id", 0)
yield "message_id", reaction.get("message_id", 0)
yield "user_id", user.get("id", 0)
yield "username", user.get("username", "")
yield "new_reactions", reaction.get("new_reaction", [])
yield "old_reactions", reaction.get("old_reaction", [])

View File

@@ -34,12 +34,10 @@ def main(output: Path, pretty: bool):
"""Generate and output the OpenAPI JSON specification."""
openapi_schema = get_openapi_schema()
json_output = json.dumps(
openapi_schema, indent=2 if pretty else None, ensure_ascii=False
)
json_output = json.dumps(openapi_schema, indent=2 if pretty else None)
if output:
output.write_text(json_output, encoding="utf-8")
output.write_text(json_output)
click.echo(f"✅ OpenAPI specification written to {output}\n\nPreview:")
click.echo(f"\n{json_output[:500]} ...")
else:

View File

@@ -1,3 +0,0 @@
from .service import stream_chat_completion_baseline
__all__ = ["stream_chat_completion_baseline"]

View File

@@ -1,420 +0,0 @@
"""Baseline LLM fallback — OpenAI-compatible streaming with tool calling.
Used when ``CHAT_USE_CLAUDE_AGENT_SDK=false``, e.g. as a fallback when the
Claude Agent SDK / Anthropic API is unavailable. Routes through any
OpenAI-compatible provider (OpenRouter by default) and reuses the same
shared tool registry as the SDK path.
"""
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any
import orjson
from langfuse import propagate_attributes
from backend.copilot.model import (
ChatMessage,
ChatSession,
get_chat_session,
update_session_title,
upsert_chat_session,
)
from backend.copilot.response_model import (
StreamBaseResponse,
StreamError,
StreamFinish,
StreamFinishStep,
StreamStart,
StreamStartStep,
StreamTextDelta,
StreamTextEnd,
StreamTextStart,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.copilot.service import (
_build_system_prompt,
_generate_session_title,
client,
config,
)
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.util.exceptions import NotFoundError
from backend.util.prompt import compress_context
logger = logging.getLogger(__name__)
# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()
# Maximum number of tool-call rounds before forcing a text response.
_MAX_TOOL_ROUNDS = 30
async def _update_title_async(
session_id: str, message: str, user_id: str | None
) -> None:
"""Generate and persist a session title in the background."""
try:
title = await _generate_session_title(message, user_id, session_id)
if title:
await update_session_title(session_id, title)
except Exception as e:
logger.warning("[Baseline] Failed to update session title: %s", e)
async def _compress_session_messages(
messages: list[ChatMessage],
) -> list[ChatMessage]:
"""Compress session messages if they exceed the model's token limit.
Uses the shared compress_context() utility which supports LLM-based
summarization of older messages while keeping recent ones intact,
with progressive truncation and middle-out deletion as fallbacks.
"""
messages_dict = []
for msg in messages:
msg_dict: dict[str, Any] = {"role": msg.role}
if msg.content:
msg_dict["content"] = msg.content
messages_dict.append(msg_dict)
try:
result = await compress_context(
messages=messages_dict,
model=config.model,
client=client,
)
except Exception as e:
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
result = await compress_context(
messages=messages_dict,
model=config.model,
client=None,
)
if result.was_compacted:
logger.info(
"[Baseline] Context compacted: %d -> %d tokens "
"(%d summarized, %d dropped)",
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
return [
ChatMessage(role=m["role"], content=m.get("content"))
for m in result.messages
]
return messages
async def stream_chat_completion_baseline(
session_id: str,
message: str | None = None,
is_user_message: bool = True,
user_id: str | None = None,
session: ChatSession | None = None,
**_kwargs: Any,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Baseline LLM with tool calling via OpenAI-compatible API.
Designed as a fallback when the Claude Agent SDK is unavailable.
Uses the same tool registry as the SDK path but routes through any
OpenAI-compatible provider (e.g. OpenRouter).
Flow: stream response -> if tool_calls, execute them -> feed results back -> repeat.
"""
if session is None:
session = await get_chat_session(session_id, user_id)
if not session:
raise NotFoundError(
f"Session {session_id} not found. Please create a new session first."
)
# Append user message
new_role = "user" if is_user_message else "assistant"
if message and (
len(session.messages) == 0
or not (
session.messages[-1].role == new_role
and session.messages[-1].content == message
)
):
session.messages.append(ChatMessage(role=new_role, content=message))
if is_user_message:
track_user_message(
user_id=user_id,
session_id=session_id,
message_length=len(message),
)
session = await upsert_chat_session(session)
# Generate title for new sessions
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
if len(user_messages) == 1:
first_message = user_messages[0].content or message or ""
if first_message:
task = asyncio.create_task(
_update_title_async(session_id, first_message, user_id)
)
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
message_id = str(uuid.uuid4())
# Build system prompt only on the first turn to avoid mid-conversation
# changes from concurrent chats updating business understanding.
is_first_turn = len(session.messages) <= 1
if is_first_turn:
system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=False
)
else:
system_prompt, _ = await _build_system_prompt(
user_id=None, has_conversation_history=True
)
# Compress context if approaching the model's token limit
messages_for_context = await _compress_session_messages(session.messages)
# Build OpenAI message list from session history
openai_messages: list[dict[str, Any]] = [
{"role": "system", "content": system_prompt}
]
for msg in messages_for_context:
if msg.role in ("user", "assistant") and msg.content:
openai_messages.append({"role": msg.role, "content": msg.content})
tools = get_available_tools()
yield StreamStart(messageId=message_id, sessionId=session_id)
# Propagate user/session context to Langfuse so all LLM calls within
# this request are grouped under a single trace with proper attribution.
_trace_ctx: Any = None
try:
_trace_ctx = propagate_attributes(
user_id=user_id,
session_id=session_id,
trace_name="copilot-baseline",
tags=["baseline"],
)
_trace_ctx.__enter__()
except Exception:
logger.warning("[Baseline] Langfuse trace context setup failed")
assistant_text = ""
text_block_id = str(uuid.uuid4())
text_started = False
step_open = False
try:
for _round in range(_MAX_TOOL_ROUNDS):
# Open a new step for each LLM round
yield StreamStartStep()
step_open = True
# Stream a response from the model
create_kwargs: dict[str, Any] = dict(
model=config.model,
messages=openai_messages,
stream=True,
)
if tools:
create_kwargs["tools"] = tools
response = await client.chat.completions.create(**create_kwargs) # type: ignore[arg-type] # dynamic kwargs
# Accumulate streamed response (text + tool calls)
round_text = ""
tool_calls_by_index: dict[int, dict[str, str]] = {}
async for chunk in response:
delta = chunk.choices[0].delta if chunk.choices else None
if not delta:
continue
# Text content
if delta.content:
if not text_started:
yield StreamTextStart(id=text_block_id)
text_started = True
round_text += delta.content
yield StreamTextDelta(id=text_block_id, delta=delta.content)
# Tool call fragments (streamed incrementally)
if delta.tool_calls:
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls_by_index:
tool_calls_by_index[idx] = {
"id": "",
"name": "",
"arguments": "",
}
entry = tool_calls_by_index[idx]
if tc.id:
entry["id"] = tc.id
if tc.function and tc.function.name:
entry["name"] = tc.function.name
if tc.function and tc.function.arguments:
entry["arguments"] += tc.function.arguments
# Close text block if we had one this round
if text_started:
yield StreamTextEnd(id=text_block_id)
text_started = False
text_block_id = str(uuid.uuid4())
# Accumulate text for session persistence
assistant_text += round_text
# No tool calls -> model is done
if not tool_calls_by_index:
yield StreamFinishStep()
step_open = False
break
# Close step before tool execution
yield StreamFinishStep()
step_open = False
# Append the assistant message with tool_calls to context.
assistant_msg: dict[str, Any] = {"role": "assistant"}
if round_text:
assistant_msg["content"] = round_text
assistant_msg["tool_calls"] = [
{
"id": tc["id"],
"type": "function",
"function": {
"name": tc["name"],
"arguments": tc["arguments"] or "{}",
},
}
for tc in tool_calls_by_index.values()
]
openai_messages.append(assistant_msg)
# Execute each tool call and stream events
for tc in tool_calls_by_index.values():
tool_call_id = tc["id"]
tool_name = tc["name"]
raw_args = tc["arguments"] or "{}"
try:
tool_args = orjson.loads(raw_args)
except orjson.JSONDecodeError as parse_err:
parse_error = (
f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
)
logger.warning("[Baseline] %s", parse_error)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=parse_error,
success=False,
)
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": parse_error,
}
)
continue
yield StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
yield StreamToolInputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
input=tool_args,
)
# Execute via shared tool registry
try:
result: StreamToolOutputAvailable = await execute_tool(
tool_name=tool_name,
parameters=tool_args,
user_id=user_id,
session=session,
tool_call_id=tool_call_id,
)
yield result
tool_output = (
result.output
if isinstance(result.output, str)
else str(result.output)
)
except Exception as e:
error_output = f"Tool execution error: {e}"
logger.error(
"[Baseline] Tool %s failed: %s",
tool_name,
error_output,
exc_info=True,
)
yield StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=error_output,
success=False,
)
tool_output = error_output
# Append tool result to context for next round
openai_messages.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": tool_output,
}
)
else:
# for-loop exhausted without break -> tool-round limit hit
limit_msg = (
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
"without a final response."
)
logger.error("[Baseline] %s", limit_msg)
yield StreamError(
errorText=limit_msg,
code="baseline_tool_round_limit",
)
except Exception as e:
error_msg = str(e) or type(e).__name__
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
# Close any open text/step before emitting error
if text_started:
yield StreamTextEnd(id=text_block_id)
if step_open:
yield StreamFinishStep()
yield StreamError(errorText=error_msg, code="baseline_error")
# Still persist whatever we got
finally:
# Close Langfuse trace context
if _trace_ctx is not None:
try:
_trace_ctx.__exit__(None, None, None)
except Exception:
logger.warning("[Baseline] Langfuse trace context teardown failed")
# Persist assistant response
if assistant_text:
session.messages.append(
ChatMessage(role="assistant", content=assistant_text)
)
try:
await upsert_chat_session(session)
except Exception as persist_err:
logger.error("[Baseline] Failed to persist session: %s", persist_err)
yield StreamFinish()

View File

@@ -1,99 +0,0 @@
import logging
from os import getenv
import pytest
from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.model import (
create_chat_session,
get_chat_session,
upsert_chat_session,
)
from backend.copilot.response_model import (
StreamError,
StreamFinish,
StreamStart,
StreamTextDelta,
)
logger = logging.getLogger(__name__)
@pytest.mark.asyncio(loop_scope="session")
async def test_baseline_multi_turn(setup_test_user, test_user_id):
"""Test that the baseline LLM path streams responses and maintains history.
Turn 1: Send a message with a unique keyword.
Turn 2: Ask the model to recall the keyword — proving conversation history
is correctly passed to the single-call LLM.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
# --- Turn 1: send a message with a unique keyword ---
keyword = "QUASAR99"
turn1_msg = (
f"Please remember this special keyword: {keyword}. "
"Just confirm you've noted it, keep your response brief."
)
turn1_text = ""
turn1_errors: list[str] = []
got_start = False
got_finish = False
async for chunk in stream_chat_completion_baseline(
session.session_id,
turn1_msg,
user_id=test_user_id,
):
if isinstance(chunk, StreamStart):
got_start = True
elif isinstance(chunk, StreamTextDelta):
turn1_text += chunk.delta
elif isinstance(chunk, StreamError):
turn1_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
got_finish = True
assert got_start, "Turn 1 did not yield StreamStart"
assert got_finish, "Turn 1 did not yield StreamFinish"
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
assert turn1_text, "Turn 1 produced no text"
logger.info(f"Turn 1 response: {turn1_text[:100]}")
# Reload session for turn 2
session = await get_chat_session(session.session_id, test_user_id)
assert session, "Session not found after turn 1"
# Verify messages were persisted (user + assistant)
assert (
len(session.messages) >= 2
), f"Expected at least 2 messages after turn 1, got {len(session.messages)}"
# --- Turn 2: ask model to recall the keyword ---
turn2_msg = "What was the special keyword I asked you to remember?"
turn2_text = ""
turn2_errors: list[str] = []
async for chunk in stream_chat_completion_baseline(
session.session_id,
turn2_msg,
user_id=test_user_id,
session=session,
):
if isinstance(chunk, StreamTextDelta):
turn2_text += chunk.delta
elif isinstance(chunk, StreamError):
turn2_errors.append(chunk.errorText)
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
assert turn2_text, "Turn 2 produced no text"
assert keyword in turn2_text, (
f"Model did not recall keyword '{keyword}' in turn 2. "
f"Response: {turn2_text[:200]}"
)
logger.info(f"Turn 2 recalled keyword successfully: {turn2_text[:100]}")

View File

@@ -0,0 +1,349 @@
"""Redis Streams consumer for operation completion messages.
This module provides a consumer (ChatCompletionConsumer) that listens for
completion notifications (OperationCompleteMessage) from external services
(like Agent Generator) and triggers the appropriate stream registry and
chat service updates via process_operation_success/process_operation_failure.
Why Redis Streams instead of RabbitMQ?
--------------------------------------
While the project typically uses RabbitMQ for async task queues (e.g., execution
queue), Redis Streams was chosen for chat completion notifications because:
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
Streams (via stream_registry) for message persistence and replay. Using Redis
Streams for completion notifications keeps all chat streaming infrastructure
in one system, simplifying operations and reducing cross-system coordination.
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
allowing consumers to replay missed messages after reconnection. This aligns
with the SSE reconnection pattern where clients can resume from last_message_id.
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
recovering from dead consumers - ideal for the completion callback pattern.
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
stream_registry) provides lower latency than an additional RabbitMQ hop.
5. **Atomicity with Task State**: Completion processing often needs to update
task metadata stored in Redis. Keeping both in Redis enables simpler
transactional semantics without distributed coordination.
The consumer uses Redis Streams with consumer groups for reliable message
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
stale pending messages from dead consumers.
"""
import asyncio
import logging
import uuid
from typing import Any
import orjson
from pydantic import BaseModel
from redis.exceptions import ResponseError
from backend.data.redis_client import get_redis_async
from . import stream_registry
from .completion_handler import process_operation_failure, process_operation_success
from .config import ChatConfig
logger = logging.getLogger(__name__)
config = ChatConfig()
class OperationCompleteMessage(BaseModel):
"""Message format for operation completion notifications."""
operation_id: str
task_id: str
success: bool
result: dict | str | None = None
error: str | None = None
class ChatCompletionConsumer:
"""Consumer for chat operation completion messages from Redis Streams.
Database operations are handled through the chat_db() accessor, which
routes through DatabaseManager RPC when Prisma is not directly connected.
Uses Redis consumer groups to allow multiple platform pods to consume
messages reliably with automatic redelivery on failure.
"""
def __init__(self):
self._consumer_task: asyncio.Task | None = None
self._running = False
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
async def start(self) -> None:
"""Start the completion consumer."""
if self._running:
logger.warning("Completion consumer already running")
return
# Create consumer group if it doesn't exist
try:
redis = await get_redis_async()
await redis.xgroup_create(
config.stream_completion_name,
config.stream_consumer_group,
id="0",
mkstream=True,
)
logger.info(
f"Created consumer group '{config.stream_consumer_group}' "
f"on stream '{config.stream_completion_name}'"
)
except ResponseError as e:
if "BUSYGROUP" in str(e):
logger.debug(
f"Consumer group '{config.stream_consumer_group}' already exists"
)
else:
raise
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info(
f"Chat completion consumer started (consumer: {self._consumer_name})"
)
async def stop(self) -> None:
"""Stop the completion consumer."""
self._running = False
if self._consumer_task:
self._consumer_task.cancel()
try:
await self._consumer_task
except asyncio.CancelledError:
pass
self._consumer_task = None
logger.info("Chat completion consumer stopped")
async def _consume_messages(self) -> None:
"""Main message consumption loop with retry logic."""
max_retries = 10
retry_delay = 5 # seconds
retry_count = 0
block_timeout = 5000 # milliseconds
while self._running and retry_count < max_retries:
try:
redis = await get_redis_async()
# Reset retry count on successful connection
retry_count = 0
while self._running:
# First, claim any stale pending messages from dead consumers
# Redis does NOT auto-redeliver pending messages; we must explicitly
# claim them using XAUTOCLAIM
try:
claimed_result = await redis.xautoclaim(
name=config.stream_completion_name,
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
min_idle_time=config.stream_claim_min_idle_ms,
start_id="0-0",
count=10,
)
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
if claimed_result and len(claimed_result) >= 2:
claimed_entries = claimed_result[1]
if claimed_entries:
logger.info(
f"Claimed {len(claimed_entries)} stale pending messages"
)
for entry_id, data in claimed_entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except Exception as e:
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
# Read new messages from the stream
messages = await redis.xreadgroup(
groupname=config.stream_consumer_group,
consumername=self._consumer_name,
streams={config.stream_completion_name: ">"},
block=block_timeout,
count=10,
)
if not messages:
continue
for stream_name, entries in messages:
for entry_id, data in entries:
if not self._running:
return
await self._process_entry(redis, entry_id, data)
except asyncio.CancelledError:
logger.info("Consumer cancelled")
return
except Exception as e:
retry_count += 1
logger.error(
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
exc_info=True,
)
if self._running and retry_count < max_retries:
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached, stopping consumer")
return
async def _process_entry(
self, redis: Any, entry_id: str, data: dict[str, Any]
) -> None:
"""Process a single stream entry and acknowledge it on success.
Args:
redis: Redis client connection
entry_id: The stream entry ID
data: The entry data dict
"""
try:
# Handle the message
message_data = data.get("data")
if message_data:
await self._handle_message(
message_data.encode()
if isinstance(message_data, str)
else message_data
)
# Acknowledge the message after successful processing
await redis.xack(
config.stream_completion_name,
config.stream_consumer_group,
entry_id,
)
except Exception as e:
logger.error(
f"Error processing completion message {entry_id}: {e}",
exc_info=True,
)
# Message remains in pending state and will be claimed by
# XAUTOCLAIM after min_idle_time expires
async def _handle_message(self, body: bytes) -> None:
"""Handle a completion message."""
try:
data = orjson.loads(body)
message = OperationCompleteMessage(**data)
except Exception as e:
logger.error(f"Failed to parse completion message: {e}")
return
logger.info(
f"[COMPLETION] Received completion for operation {message.operation_id} "
f"(task_id={message.task_id}, success={message.success})"
)
# Find task in registry
task = await stream_registry.find_task_by_operation_id(message.operation_id)
if task is None:
task = await stream_registry.get_task(message.task_id)
if task is None:
logger.warning(
f"[COMPLETION] Task not found for operation {message.operation_id} "
f"(task_id={message.task_id})"
)
return
logger.info(
f"[COMPLETION] Found task: task_id={task.task_id}, "
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
)
# Guard against empty task fields
if not task.task_id or not task.session_id or not task.tool_call_id:
logger.error(
f"[COMPLETION] Task has empty critical fields! "
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
f"tool_call_id={task.tool_call_id!r}"
)
return
if message.success:
await self._handle_success(task, message)
else:
await self._handle_failure(task, message)
async def _handle_success(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle successful operation completion."""
await process_operation_success(task, message.result)
async def _handle_failure(
self,
task: stream_registry.ActiveTask,
message: OperationCompleteMessage,
) -> None:
"""Handle failed operation completion."""
await process_operation_failure(task, message.error)
# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None
async def start_completion_consumer() -> None:
"""Start the global completion consumer."""
global _consumer
if _consumer is None:
_consumer = ChatCompletionConsumer()
await _consumer.start()
async def stop_completion_consumer() -> None:
"""Stop the global completion consumer."""
global _consumer
if _consumer:
await _consumer.stop()
_consumer = None
async def publish_operation_complete(
operation_id: str,
task_id: str,
success: bool,
result: dict | str | None = None,
error: str | None = None,
) -> None:
"""Publish an operation completion message to Redis Streams.
Args:
operation_id: The operation ID that completed.
task_id: The task ID associated with the operation.
success: Whether the operation succeeded.
result: The result data (for success).
error: The error message (for failure).
"""
message = OperationCompleteMessage(
operation_id=operation_id,
task_id=task_id,
success=success,
result=result,
error=error,
)
redis = await get_redis_async()
await redis.xadd(
config.stream_completion_name,
{"data": message.model_dump_json()},
maxlen=config.stream_max_length,
)
logger.info(f"Published completion for operation {operation_id}")

View File

@@ -0,0 +1,329 @@
"""Shared completion handling for operation success and failure.
This module provides common logic for handling operation completion from both:
- The Redis Streams consumer (completion_consumer.py)
- The HTTP webhook endpoint (routes.py)
"""
import logging
from typing import Any
import orjson
from backend.data.db_accessors import chat_db
from . import service as chat_service
from . import stream_registry
from .response_model import StreamError, StreamToolOutputAvailable
from .tools.models import ErrorResponse
logger = logging.getLogger(__name__)
# Tools that produce agent_json that needs to be saved to library
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
# Keys that should be stripped from agent_json when returning in error responses
SENSITIVE_KEYS = frozenset(
{
"api_key",
"apikey",
"api_secret",
"password",
"secret",
"credentials",
"credential",
"token",
"access_token",
"refresh_token",
"private_key",
"privatekey",
"auth",
"authorization",
}
)
def _sanitize_agent_json(obj: Any) -> Any:
"""Recursively sanitize agent_json by removing sensitive keys.
Args:
obj: The object to sanitize (dict, list, or primitive)
Returns:
Sanitized copy with sensitive keys removed/redacted
"""
if isinstance(obj, dict):
return {
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
for k, v in obj.items()
}
elif isinstance(obj, list):
return [_sanitize_agent_json(item) for item in obj]
else:
return obj
class ToolMessageUpdateError(Exception):
"""Raised when updating a tool message in the database fails."""
pass
async def _update_tool_message(
session_id: str,
tool_call_id: str,
content: str,
) -> None:
"""Update tool message in database using the chat_db accessor.
Routes through DatabaseManager RPC when Prisma is not directly
connected (e.g. in the CoPilot Executor microservice).
Args:
session_id: The session ID
tool_call_id: The tool call ID to update
content: The new content for the message
Raises:
ToolMessageUpdateError: If the database update fails.
"""
try:
updated = await chat_db().update_tool_message_content(
session_id=session_id,
tool_call_id=tool_call_id,
new_content=content,
)
if not updated:
raise ToolMessageUpdateError(
f"No message found with tool_call_id="
f"{tool_call_id} in session {session_id}"
)
except ToolMessageUpdateError:
raise
except Exception as e:
logger.error(
f"[COMPLETION] Failed to update tool message: {e}",
exc_info=True,
)
raise ToolMessageUpdateError(
f"Failed to update tool message for tool call #{tool_call_id}: {e}"
) from e
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
"""Serialize result to JSON string with sensible defaults.
Args:
result: The result to serialize. Can be a dict, list, string,
number, boolean, or None.
Returns:
JSON string representation of the result. Returns '{"status": "completed"}'
only when result is explicitly None.
"""
if isinstance(result, str):
return result
if result is None:
return '{"status": "completed"}'
return orjson.dumps(result).decode("utf-8")
async def _save_agent_from_result(
result: dict[str, Any],
user_id: str | None,
tool_name: str,
) -> dict[str, Any]:
"""Save agent to library if result contains agent_json.
Args:
result: The result dict that may contain agent_json
user_id: The user ID to save the agent for
tool_name: The tool name (create_agent or edit_agent)
Returns:
Updated result dict with saved agent details, or original result if no agent_json
"""
if not user_id:
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
return result
agent_json = result.get("agent_json")
if not agent_json:
logger.warning(
f"[COMPLETION] {tool_name} completed but no agent_json in result"
)
return result
try:
from .tools.agent_generator import save_agent_to_library
is_update = tool_name == "edit_agent"
created_graph, library_agent = await save_agent_to_library(
agent_json, user_id, is_update=is_update
)
logger.info(
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
)
# Return a response similar to AgentSavedResponse
return {
"type": "agent_saved",
"message": f"Agent '{created_graph.name}' has been saved to your library!",
"agent_id": created_graph.id,
"agent_name": created_graph.name,
"library_agent_id": library_agent.id,
"library_agent_link": f"/library/agents/{library_agent.id}",
"agent_page_link": f"/build?flowID={created_graph.id}",
}
except Exception as e:
logger.error(
f"[COMPLETION] Failed to save agent to library: {e}",
exc_info=True,
)
# Return error but don't fail the whole operation
# Sanitize agent_json to remove sensitive keys before returning
return {
"type": "error",
"message": f"Agent was generated but failed to save: {str(e)}",
"error": str(e),
"agent_json": _sanitize_agent_json(agent_json),
}
async def process_operation_success(
task: stream_registry.ActiveTask,
result: dict | str | None,
) -> None:
"""Handle successful operation completion.
Publishes the result to the stream registry, updates the database,
generates LLM continuation, and marks the task as completed.
Args:
task: The active task that completed
result: The result data from the operation
Raises:
ToolMessageUpdateError: If the database update fails. The task
will be marked as failed instead of completed.
"""
# For agent generation tools, save the agent to library
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
# Serialize result for output (only substitute default when result is exactly None)
result_output = result if result is not None else {"status": "completed"}
output_str = (
result_output
if isinstance(result_output, str)
else orjson.dumps(result_output).decode("utf-8")
)
# Publish result to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamToolOutputAvailable(
toolCallId=task.tool_call_id,
toolName=task.tool_name,
output=output_str,
success=True,
),
)
# Update pending operation in database
# If this fails, we must not continue to mark the task as completed
result_str = serialize_result(result)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=result_str,
)
except ToolMessageUpdateError:
# DB update failed - mark task as failed to avoid inconsistent state
logger.error(
f"[COMPLETION] DB update failed for task {task.task_id}, "
"marking as failed instead of completed"
)
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText="Failed to save operation result to database"),
)
await stream_registry.mark_task_completed(task.task_id, status="failed")
raise
# Generate LLM continuation with streaming
try:
await chat_service._generate_llm_continuation_with_streaming(
session_id=task.session_id,
user_id=task.user_id,
task_id=task.task_id,
)
except Exception as e:
logger.error(
f"[COMPLETION] Failed to generate LLM continuation: {e}",
exc_info=True,
)
# Mark task as completed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="completed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
)
async def process_operation_failure(
task: stream_registry.ActiveTask,
error: str | None,
) -> None:
"""Handle failed operation completion.
Publishes the error to the stream registry, updates the database
with the error response, and marks the task as failed.
Args:
task: The active task that failed
error: The error message from the operation
"""
error_msg = error or "Operation failed"
# Publish error to stream registry
await stream_registry.publish_chunk(
task.task_id,
StreamError(errorText=error_msg),
)
# Update pending operation with error
# If this fails, we still continue to mark the task as failed
error_response = ErrorResponse(
message=error_msg,
error=error,
)
try:
await _update_tool_message(
session_id=task.session_id,
tool_call_id=task.tool_call_id,
content=error_response.model_dump_json(),
)
except ToolMessageUpdateError:
# DB update failed - log but continue with cleanup
logger.error(
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
"continuing with cleanup"
)
# Mark task as failed and release Redis lock
await stream_registry.mark_task_completed(task.task_id, status="failed")
try:
await chat_service._mark_operation_completed(task.tool_call_id)
except Exception as e:
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")

View File

@@ -26,34 +26,63 @@ class ChatConfig(BaseSettings):
# Session TTL Configuration - 12 hours
session_ttl: int = Field(default=43200, description="Session TTL in seconds")
# Streaming Configuration
stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
max_retries: int = Field(
default=3,
description="Max retries for fallback path (SDK handles retries internally)",
)
max_agent_runs: int = Field(default=30, description="Maximum number of agent runs")
max_agent_schedules: int = Field(
default=30, description="Maximum number of agent schedules"
)
# Long-running operation configuration
long_running_operation_ttl: int = Field(
default=600,
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
)
# Stream registry configuration for SSE reconnection
stream_ttl: int = Field(
default=3600,
description="TTL in seconds for stream data in Redis (1 hour)",
)
stream_lock_ttl: int = Field(
default=120,
description="TTL in seconds for stream lock (2 minutes). Short timeout allows "
"reconnection after refresh/crash without long waits.",
)
stream_max_length: int = Field(
default=10000,
description="Maximum number of messages to store per stream",
)
# Redis key prefixes for stream registry
session_meta_prefix: str = Field(
default="chat:task:meta:",
description="Prefix for session metadata hash keys",
# Redis Streams configuration for completion consumer
stream_completion_name: str = Field(
default="chat:completions",
description="Redis Stream name for operation completions",
)
turn_stream_prefix: str = Field(
stream_consumer_group: str = Field(
default="chat_consumers",
description="Consumer group name for completion stream",
)
stream_claim_min_idle_ms: int = Field(
default=60000,
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
)
# Redis key prefixes for stream registry
task_meta_prefix: str = Field(
default="chat:task:meta:",
description="Prefix for task metadata hash keys",
)
task_stream_prefix: str = Field(
default="chat:stream:",
description="Prefix for turn message stream keys",
description="Prefix for task message stream keys",
)
task_op_prefix: str = Field(
default="chat:task:op:",
description="Prefix for operation ID to task ID mapping keys",
)
internal_api_key: str | None = Field(
default=None,
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
)
# Langfuse Prompt Management Configuration
@@ -62,15 +91,11 @@ class ChatConfig(BaseSettings):
default="CoPilot Prompt",
description="Name of the prompt in Langfuse to fetch",
)
langfuse_prompt_cache_ttl: int = Field(
default=300,
description="Cache TTL in seconds for Langfuse prompt (0 to disable caching)",
)
# Claude Agent SDK Configuration
use_claude_agent_sdk: bool = Field(
default=True,
description="Use Claude Agent SDK (True) or OpenAI-compatible LLM baseline (False)",
description="Use Claude Agent SDK for chat completions",
)
claude_agent_model: str | None = Field(
default=None,
@@ -84,60 +109,25 @@ class ChatConfig(BaseSettings):
)
claude_agent_max_subtasks: int = Field(
default=10,
description="Max number of concurrent sub-agent Tasks the SDK can run per session.",
description="Max number of sub-agent Tasks the SDK can spawn per session.",
)
claude_agent_use_resume: bool = Field(
default=True,
description="Use --resume for multi-turn conversations instead of "
"history compression. Falls back to compression when unavailable.",
)
use_claude_code_subscription: bool = Field(
default=False,
description="For personal/dev use: use Claude Code CLI subscription auth instead of API keys. Requires `claude login` on the host. Only works with SDK mode.",
)
# E2B Sandbox Configuration
use_e2b_sandbox: bool = Field(
# Extended thinking configuration for Claude models
thinking_enabled: bool = Field(
default=True,
description="Use E2B cloud sandboxes for persistent bash/python execution. "
"When enabled, bash_exec routes commands to E2B and SDK file tools "
"operate directly on the sandbox via E2B's filesystem API.",
description="Enable adaptive thinking for Claude models via OpenRouter",
)
e2b_api_key: str | None = Field(
default=None,
description="E2B API key. Falls back to E2B_API_KEY environment variable.",
)
e2b_sandbox_template: str = Field(
default="base",
description="E2B sandbox template to use for copilot sessions.",
)
e2b_sandbox_timeout: int = Field(
default=43200, # 12 hours — same as session_ttl
description="E2B sandbox keepalive timeout in seconds.",
)
@field_validator("use_e2b_sandbox", mode="before")
@classmethod
def get_use_e2b_sandbox(cls, v):
"""Get use_e2b_sandbox from environment if not provided."""
env_val = os.getenv("CHAT_USE_E2B_SANDBOX", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
return True if v is None else v
@field_validator("e2b_api_key", mode="before")
@classmethod
def get_e2b_api_key(cls, v):
"""Get E2B API key from environment if not provided."""
if not v:
v = os.getenv("CHAT_E2B_API_KEY") or os.getenv("E2B_API_KEY")
return v
@field_validator("api_key", mode="before")
@classmethod
def get_api_key(cls, v):
"""Get API key from environment if not provided."""
if not v:
if v is None:
# Try to get from environment variables
# First check for CHAT_API_KEY (Pydantic prefix)
v = os.getenv("CHAT_API_KEY")
@@ -147,16 +137,13 @@ class ChatConfig(BaseSettings):
if not v:
# Fall back to OPENAI_API_KEY
v = os.getenv("OPENAI_API_KEY")
# Note: ANTHROPIC_API_KEY is intentionally NOT included here.
# The SDK CLI picks it up from the env directly. Including it
# would pair it with the OpenRouter base_url, causing auth failures.
return v
@field_validator("base_url", mode="before")
@classmethod
def get_base_url(cls, v):
"""Get base URL from environment if not provided."""
if not v:
if v is None:
# Check for OpenRouter or custom base URL
v = os.getenv("CHAT_BASE_URL")
if not v:
@@ -167,6 +154,14 @@ class ChatConfig(BaseSettings):
v = "https://openrouter.ai/api/v1"
return v
@field_validator("internal_api_key", mode="before")
@classmethod
def get_internal_api_key(cls, v):
"""Get internal API key from environment if not provided."""
if v is None:
v = os.getenv("CHAT_INTERNAL_API_KEY")
return v
@field_validator("use_claude_agent_sdk", mode="before")
@classmethod
def get_use_claude_agent_sdk(cls, v):
@@ -178,15 +173,6 @@ class ChatConfig(BaseSettings):
# Default to True (SDK enabled by default)
return True if v is None else v
@field_validator("use_claude_code_subscription", mode="before")
@classmethod
def get_use_claude_code_subscription(cls, v):
"""Get use_claude_code_subscription from environment if not provided."""
env_val = os.getenv("CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "").lower()
if env_val:
return env_val in ("true", "1", "yes", "on")
return False if v is None else v
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -1,11 +0,0 @@
"""Shared constants for the CoPilot module."""
# Special message prefixes for text-based markers (parsed by frontend).
# The hex suffix makes accidental LLM generation of these strings virtually
# impossible, avoiding false-positive marker detection in normal conversation.
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
# Compaction notice messages shown to users.
COMPACTION_DONE_MSG = "Earlier messages were summarized to fit within context limits."
COMPACTION_TOOL_NAME = "context_compaction"

View File

@@ -3,9 +3,8 @@
import asyncio
import logging
from datetime import UTC, datetime
from typing import Any
from typing import Any, cast
from prisma.errors import UniqueViolationError
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import (
@@ -16,7 +15,7 @@ from prisma.types import (
)
from backend.data import db
from backend.util.json import SafeJson, sanitize_string
from backend.util.json import SafeJson
from .model import ChatMessage, ChatSession, ChatSessionInfo
@@ -93,24 +92,24 @@ async def add_chat_message(
function_call: dict[str, Any] | None = None,
) -> ChatMessage:
"""Add a message to a chat session."""
# Build ChatMessageCreateInput with only non-None values
# (Prisma TypedDict rejects optional fields set to None)
data: ChatMessageCreateInput = {
# Build input dict dynamically rather than using ChatMessageCreateInput directly
# because Prisma's TypedDict validation rejects optional fields set to None.
# We only include fields that have values, then cast at the end.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": role,
"sequence": sequence,
}
# Add optional string fields — sanitize to strip PostgreSQL-incompatible
# control characters (null bytes etc.) that may appear in tool outputs.
# Add optional string fields
if content is not None:
data["content"] = sanitize_string(content)
data["content"] = content
if name is not None:
data["name"] = name
if tool_call_id is not None:
data["toolCallId"] = tool_call_id
if refusal is not None:
data["refusal"] = sanitize_string(refusal)
data["refusal"] = refusal
# Add optional JSON fields only when they have values
if tool_calls is not None:
@@ -124,7 +123,7 @@ async def add_chat_message(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
),
PrismaChatMessage.prisma().create(data=data),
PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
)
return ChatMessage.from_db(message)
@@ -133,94 +132,58 @@ async def add_chat_messages_batch(
session_id: str,
messages: list[dict[str, Any]],
start_sequence: int,
) -> int:
) -> list[ChatMessage]:
"""Add multiple messages to a chat session in a batch.
Uses collision detection with retry: tries to create messages starting
at start_sequence. If a unique constraint violation occurs (e.g., the
streaming loop and long-running callback race), queries the latest
sequence and retries with the correct offset. This avoids unnecessary
upserts and DB queries in the common case (no collision).
Returns:
Next sequence number for the next message to be inserted. This equals
start_sequence + len(messages) and allows callers to update their
counters even when collision detection adjusts start_sequence.
Uses a transaction for atomicity - if any message creation fails,
the entire batch is rolled back.
"""
if not messages:
# No messages to add - return current count
return start_sequence
return []
max_retries = 5
for attempt in range(max_retries):
try:
# Single timestamp for all messages and session update
now = datetime.now(UTC)
created_messages = []
async with db.transaction() as tx:
# Build all message data
messages_data = []
for i, msg in enumerate(messages):
# Build ChatMessageCreateInput with only non-None values
# (Prisma TypedDict rejects optional fields set to None)
# Note: create_many doesn't support nested creates, use sessionId directly
data: ChatMessageCreateInput = {
"sessionId": session_id,
"role": msg["role"],
"sequence": start_sequence + i,
"createdAt": now,
}
async with db.transaction() as tx:
for i, msg in enumerate(messages):
# Build input dict dynamically rather than using ChatMessageCreateInput
# directly because Prisma's TypedDict validation rejects optional fields
# set to None. We only include fields that have values, then cast.
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
# Add optional string fields — sanitize to strip
# PostgreSQL-incompatible control characters.
if msg.get("content") is not None:
data["content"] = sanitize_string(msg["content"])
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = sanitize_string(msg["refusal"])
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
messages_data.append(data)
created = await PrismaChatMessage.prisma(tx).create(
data=cast(ChatMessageCreateInput, data)
)
created_messages.append(created)
# Run create_many and session update in parallel within transaction
# Both use the same timestamp for consistency
await asyncio.gather(
PrismaChatMessage.prisma(tx).create_many(data=messages_data),
PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": now},
),
)
# Update session's updatedAt timestamp within the same transaction.
# Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
# separately via update_chat_session() after streaming completes.
await PrismaChatSession.prisma(tx).update(
where={"id": session_id},
data={"updatedAt": datetime.now(UTC)},
)
# Return next sequence number for counter sync
return start_sequence + len(messages)
except UniqueViolationError:
if attempt < max_retries - 1:
# Collision detected - query MAX(sequence)+1 and retry with correct offset
logger.info(
f"Collision detected for session {session_id} at sequence "
f"{start_sequence}, querying DB for latest sequence"
)
start_sequence = await get_next_sequence(session_id)
logger.info(
f"Retrying batch insert with start_sequence={start_sequence}"
)
continue
else:
# Max retries exceeded - propagate error
raise
# Should never reach here due to raise in exception handler
raise RuntimeError(f"Failed to insert messages after {max_retries} attempts")
return [ChatMessage.from_db(m) for m in created_messages]
async def get_user_chat_sessions(
@@ -274,20 +237,10 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
return False
async def get_next_sequence(session_id: str) -> int:
"""Get the next sequence number for a new message in this session.
Uses MAX(sequence) + 1 for robustness. Returns 0 if no messages exist.
More robust than COUNT(*) because it's immune to deleted messages.
Optimized to select only the sequence column using raw SQL.
The unique index on (sessionId, sequence) makes this query fast.
"""
results = await db.query_raw_with_schema(
'SELECT "sequence" FROM {schema_prefix}"ChatMessage" WHERE "sessionId" = $1 ORDER BY "sequence" DESC LIMIT 1',
session_id,
)
return 0 if not results else results[0]["sequence"] + 1
async def get_chat_session_message_count(session_id: str) -> int:
"""Get the number of messages in a chat session."""
count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
return count
async def update_tool_message_content(
@@ -314,7 +267,7 @@ async def update_tool_message_content(
"toolCallId": tool_call_id,
},
data={
"content": sanitize_string(new_content),
"content": new_content,
},
)
if result == 0:

View File

@@ -4,7 +4,6 @@ This module contains the CoPilotExecutor class that consumes chat tasks from
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
"""
import asyncio
import logging
import os
import threading
@@ -26,7 +25,7 @@ from backend.util.process import AppProcess
from backend.util.retry import continuous_retry
from backend.util.settings import Settings
from .processor import execute_copilot_turn, init_worker
from .processor import execute_copilot_task, init_worker
from .utils import (
COPILOT_CANCEL_QUEUE_NAME,
COPILOT_EXECUTION_QUEUE_NAME,
@@ -182,13 +181,13 @@ class CoPilotExecutor(AppProcess):
self._executor.shutdown(wait=False)
# Release any remaining locks
for session_id, lock in list(self._task_locks.items()):
for task_id, lock in list(self._task_locks.items()):
try:
lock.release()
logger.info(f"[cleanup {pid}] Released lock for {session_id}")
logger.info(f"[cleanup {pid}] Released lock for {task_id}")
except Exception as e:
logger.error(
f"[cleanup {pid}] Failed to release lock for {session_id}: {e}"
f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
)
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
@@ -268,20 +267,20 @@ class CoPilotExecutor(AppProcess):
):
"""Handle cancel message from FANOUT exchange."""
request = CancelCoPilotEvent.model_validate_json(body)
session_id = request.session_id
if not session_id:
logger.warning("Cancel message missing 'session_id'")
task_id = request.task_id
if not task_id:
logger.warning("Cancel message missing 'task_id'")
return
if session_id not in self.active_tasks:
logger.debug(f"Cancel received for {session_id} but not active")
if task_id not in self.active_tasks:
logger.debug(f"Cancel received for {task_id} but not active")
return
_, cancel_event = self.active_tasks[session_id]
logger.info(f"Received cancel for {session_id}")
_, cancel_event = self.active_tasks[task_id]
logger.info(f"Received cancel for {task_id}")
if not cancel_event.is_set():
cancel_event.set()
else:
logger.debug(f"Cancel already set for {session_id}")
logger.debug(f"Cancel already set for {task_id}")
def _handle_run_message(
self,
@@ -353,12 +352,12 @@ class CoPilotExecutor(AppProcess):
ack_message(reject=True, requeue=False)
return
session_id = entry.session_id
task_id = entry.task_id
# Check for local duplicate - session is already running on this executor
if session_id in self.active_tasks:
# Check for local duplicate - task is already running on this executor
if task_id in self.active_tasks:
logger.warning(
f"Session {session_id} already running locally, rejecting duplicate"
f"Task {task_id} already running locally, rejecting duplicate"
)
ack_message(reject=True, requeue=False)
return
@@ -366,69 +365,64 @@ class CoPilotExecutor(AppProcess):
# Try to acquire cluster-wide lock
cluster_lock = ClusterLock(
redis=redis.get_redis(),
key=f"copilot:session:{session_id}:lock",
key=f"copilot:task:{task_id}:lock",
owner_id=self.executor_id,
timeout=settings.config.cluster_lock_timeout,
)
current_owner = cluster_lock.try_acquire()
if current_owner != self.executor_id:
if current_owner is not None:
logger.warning(
f"Session {session_id} already running on pod {current_owner}"
)
logger.warning(f"Task {task_id} already running on pod {current_owner}")
ack_message(reject=True, requeue=False)
else:
logger.warning(
f"Could not acquire lock for {session_id} - Redis unavailable"
f"Could not acquire lock for {task_id} - Redis unavailable"
)
ack_message(reject=True, requeue=True)
return
# Execute the task
try:
self._task_locks[session_id] = cluster_lock
self._task_locks[task_id] = cluster_lock
logger.info(
f"Acquired cluster lock for {session_id}, "
f"executor_id={self.executor_id}"
f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
)
cancel_event = threading.Event()
future = self.executor.submit(
execute_copilot_turn, entry, cancel_event, cluster_lock
execute_copilot_task, entry, cancel_event, cluster_lock
)
self.active_tasks[session_id] = (future, cancel_event)
self.active_tasks[task_id] = (future, cancel_event)
except Exception as e:
logger.warning(f"Failed to setup execution for {session_id}: {e}")
logger.warning(f"Failed to setup execution for {task_id}: {e}")
cluster_lock.release()
if session_id in self._task_locks:
del self._task_locks[session_id]
if task_id in self._task_locks:
del self._task_locks[task_id]
ack_message(reject=True, requeue=True)
return
self._update_metrics()
def on_run_done(f: Future):
logger.info(f"Run completed for {session_id}")
error_msg = None
logger.info(f"Run completed for {task_id}")
try:
if exec_error := f.exception():
error_msg = str(exec_error) or type(exec_error).__name__
logger.error(f"Execution for {session_id} failed: {error_msg}")
logger.error(f"Execution for {task_id} failed: {exec_error}")
# Don't requeue failed tasks - they've been marked as failed
# in the stream registry. Requeuing would cause infinite retries
# for deterministic failures.
ack_message(reject=True, requeue=False)
else:
ack_message(reject=False, requeue=False)
except asyncio.CancelledError:
logger.info(f"Run completion callback cancelled for {session_id}")
except BaseException as e:
error_msg = str(e) or type(e).__name__
logger.exception(f"Error in run completion callback: {error_msg}")
logger.exception(f"Error in run completion callback: {e}")
finally:
# Release the cluster lock
if session_id in self._task_locks:
logger.info(f"Releasing cluster lock for {session_id}")
self._task_locks[session_id].release()
del self._task_locks[session_id]
if task_id in self._task_locks:
logger.info(f"Releasing cluster lock for {task_id}")
self._task_locks[task_id].release()
del self._task_locks[task_id]
self._cleanup_completed_tasks()
future.add_done_callback(on_run_done)
@@ -439,11 +433,11 @@ class CoPilotExecutor(AppProcess):
"""Remove completed futures from active_tasks and update metrics."""
completed_tasks = []
with self._active_tasks_lock:
for session_id, (future, _) in list(self.active_tasks.items()):
for task_id, (future, _) in list(self.active_tasks.items()):
if future.done():
completed_tasks.append(session_id)
self.active_tasks.pop(session_id, None)
logger.info(f"Cleaned up completed session {session_id}")
completed_tasks.append(task_id)
self.active_tasks.pop(task_id, None)
logger.info(f"Cleaned up completed task {task_id}")
self._update_metrics()
return completed_tasks

View File

@@ -1,20 +1,18 @@
"""CoPilot execution processor - per-worker execution logic.
This module contains the processor class that handles CoPilot session execution
This module contains the processor class that handles CoPilot task execution
in a thread-local context, following the graph executor pattern.
"""
import asyncio
import logging
import os
import subprocess
import threading
import time
from backend.copilot import service as copilot_service
from backend.copilot import stream_registry
from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.config import ChatConfig
from backend.copilot.response_model import StreamFinish
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
from backend.copilot.sdk import service as sdk_service
from backend.executor.cluster_lock import ClusterLock
from backend.util.decorator import error_logged
@@ -34,17 +32,17 @@ logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]"
_tls = threading.local()
def execute_copilot_turn(
def execute_copilot_task(
entry: CoPilotExecutionEntry,
cancel: threading.Event,
cluster_lock: ClusterLock,
):
"""Execute a single CoPilot turn (user message → AI response).
"""Execute a CoPilot task using the thread-local processor.
This function is the entry point called by the thread pool executor.
Args:
entry: The turn payload
entry: The task payload
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock for this execution
"""
@@ -78,16 +76,16 @@ def cleanup_worker():
class CoPilotProcessor:
"""Per-worker execution logic for CoPilot sessions.
"""Per-worker execution logic for CoPilot tasks.
This class is instantiated once per worker thread and handles the execution
of CoPilot chat generation sessions. It maintains an async event loop for
of CoPilot chat generation tasks. It maintains an async event loop for
running the async service code.
The execution flow:
1. Session entry is picked from RabbitMQ queue
2. Manager submits to thread pool
3. Processor executes in its event loop
1. CoPilot task is picked from RabbitMQ queue
2. Manager submits task to thread pool
3. Processor executes the task in its event loop
4. Results are published to Redis Streams
"""
@@ -110,41 +108,8 @@ class CoPilotProcessor:
)
self.execution_thread.start()
# Skip the SDK's per-request CLI version check — the bundled CLI is
# already version-matched to the SDK package.
os.environ.setdefault("CLAUDE_AGENT_SDK_SKIP_VERSION_CHECK", "1")
# Pre-warm the bundled CLI binary so the OS page-caches the ~185 MB
# executable. First spawn pays ~1.2 s; subsequent spawns ~0.65 s.
self._prewarm_cli()
logger.info(f"[CoPilotExecutor] Worker {self.tid} started")
def _prewarm_cli(self) -> None:
"""Run the bundled CLI binary once to warm OS page caches."""
try:
from claude_agent_sdk._internal.transport.subprocess_cli import (
SubprocessCLITransport,
)
cli_path = SubprocessCLITransport._find_bundled_cli(None) # type: ignore[arg-type]
if cli_path:
result = subprocess.run(
[cli_path, "-v"],
capture_output=True,
timeout=10,
)
if result.returncode == 0:
logger.info(f"[CoPilotExecutor] CLI pre-warm done: {cli_path}")
else:
logger.warning(
"[CoPilotExecutor] CLI pre-warm failed (rc=%d): %s",
result.returncode, # type: ignore[reportCallIssue]
cli_path,
)
except Exception as e:
logger.debug(f"[CoPilotExecutor] CLI pre-warm skipped: {e}")
def cleanup(self):
"""Clean up event-loop-bound resources before the loop is destroyed.
@@ -154,16 +119,13 @@ class CoPilotProcessor:
"""
from backend.util.workspace_storage import shutdown_workspace_storage
coro = shutdown_workspace_storage()
try:
future = asyncio.run_coroutine_threadsafe(coro, self.execution_loop)
future = asyncio.run_coroutine_threadsafe(
shutdown_workspace_storage(), self.execution_loop
)
future.result(timeout=5)
except Exception as e:
coro.close() # Prevent "coroutine was never awaited" warning
error_msg = str(e) or type(e).__name__
logger.warning(
f"[CoPilotExecutor] Worker {self.tid} cleanup error: {error_msg}"
)
logger.warning(f"[CoPilotExecutor] Worker {self.tid} cleanup error: {e}")
# Stop the event loop
self.execution_loop.call_soon_threadsafe(self.execution_loop.stop)
@@ -177,17 +139,19 @@ class CoPilotProcessor:
cancel: threading.Event,
cluster_lock: ClusterLock,
):
"""Execute a CoPilot turn.
"""Execute a CoPilot task.
Runs the async logic in the worker's event loop and handles errors.
This is the main entry point for task execution. It runs the async
execution logic in the worker's event loop and handles errors.
Args:
entry: The turn payload containing session and message info
entry: The task payload containing session and message info
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock to prevent duplicate execution
"""
log = CoPilotLogMetadata(
logging.getLogger(__name__),
task_id=entry.task_id,
session_id=entry.session_id,
user_id=entry.user_id,
)
@@ -195,30 +159,38 @@ class CoPilotProcessor:
start_time = time.monotonic()
# Run the async execution in our event loop
future = asyncio.run_coroutine_threadsafe(
self._execute_async(entry, cancel, cluster_lock, log),
self.execution_loop,
)
try:
# Run the async execution in our event loop
future = asyncio.run_coroutine_threadsafe(
self._execute_async(entry, cancel, cluster_lock, log),
self.execution_loop,
)
# Wait for completion, checking cancel periodically
while not future.done():
try:
future.result(timeout=1.0)
except asyncio.TimeoutError:
if cancel.is_set():
log.info("Cancellation requested")
future.cancel()
break
# Refresh cluster lock to maintain ownership
cluster_lock.refresh()
# Wait for completion, checking cancel periodically
while not future.done():
try:
future.result(timeout=1.0)
except asyncio.TimeoutError:
if cancel.is_set():
log.info("Cancellation requested")
future.cancel()
break
# Refresh cluster lock to maintain ownership
cluster_lock.refresh()
if not future.cancelled():
# Get result to propagate any exceptions
future.result()
if not future.cancelled():
# Get result to propagate any exceptions
future.result()
elapsed = time.monotonic() - start_time
log.info(f"Execution completed in {elapsed:.2f}s")
elapsed = time.monotonic() - start_time
log.info(f"Execution completed in {elapsed:.2f}s")
except Exception as e:
elapsed = time.monotonic() - start_time
log.error(f"Execution failed after {elapsed:.2f}s: {e}")
# Note: _execute_async already marks the task as failed before re-raising,
# so we don't call _mark_task_failed here to avoid duplicate error events.
raise
async def _execute_async(
self,
@@ -227,26 +199,24 @@ class CoPilotProcessor:
cluster_lock: ClusterLock,
log: CoPilotLogMetadata,
):
"""Async execution logic for a CoPilot turn.
"""Async execution logic for CoPilot task.
Calls the chat completion service (SDK or baseline) and publishes
results to the stream registry.
This method calls the existing stream_chat_completion service function
and publishes results to the stream registry.
Args:
entry: The turn payload
entry: The task payload
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock for refresh
log: Structured logger
log: Structured logger for this task
"""
last_refresh = time.monotonic()
refresh_interval = 30.0 # Refresh lock every 30 seconds
error_msg = None
try:
# Choose service based on LaunchDarkly flag.
# Claude Code subscription forces SDK mode (CLI subprocess auth).
# Choose service based on LaunchDarkly flag
config = ChatConfig()
use_sdk = config.use_claude_code_subscription or await is_feature_enabled(
use_sdk = await is_feature_enabled(
Flag.COPILOT_SDK,
entry.user_id or "anonymous",
default=config.use_claude_agent_sdk,
@@ -254,60 +224,64 @@ class CoPilotProcessor:
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else stream_chat_completion_baseline
else copilot_service.stream_chat_completion
)
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
log.info(f"Using {'SDK' if use_sdk else 'standard'} service")
# Stream chat completion and publish chunks to Redis.
# Stream chat completion and publish chunks to Redis
async for chunk in stream_fn(
session_id=entry.session_id,
message=entry.message if entry.message else None,
is_user_message=entry.is_user_message,
user_id=entry.user_id,
context=entry.context,
file_ids=entry.file_ids,
):
# Check for cancellation
if cancel.is_set():
log.info("Cancel requested, breaking stream")
break
log.info("Cancelled during streaming")
await stream_registry.publish_chunk(
entry.task_id, StreamError(errorText="Operation cancelled")
)
await stream_registry.publish_chunk(
entry.task_id, StreamFinishStep()
)
await stream_registry.publish_chunk(entry.task_id, StreamFinish())
await stream_registry.mark_task_completed(
entry.task_id, status="failed"
)
return
# Refresh cluster lock periodically
current_time = time.monotonic()
if current_time - last_refresh >= refresh_interval:
cluster_lock.refresh()
last_refresh = current_time
# Skip StreamFinish — mark_session_completed publishes it.
if isinstance(chunk, StreamFinish):
continue
# Publish chunk to stream registry
await stream_registry.publish_chunk(entry.task_id, chunk)
try:
await stream_registry.publish_chunk(entry.turn_id, chunk)
except Exception as e:
log.error(
f"Error publishing chunk {type(chunk).__name__}: {e}",
exc_info=True,
)
# Mark task as completed
await stream_registry.mark_task_completed(entry.task_id, status="completed")
log.info("Task completed successfully")
# Stream loop completed
if cancel.is_set():
log.info("Stream cancelled by user")
except BaseException as e:
# Handle all exceptions (including CancelledError) with appropriate logging
if isinstance(e, asyncio.CancelledError):
log.info("Turn cancelled")
error_msg = "Operation cancelled"
else:
error_msg = str(e) or type(e).__name__
log.error(f"Turn failed: {error_msg}")
except asyncio.CancelledError:
log.info("Task cancelled")
await stream_registry.mark_task_completed(entry.task_id, status="failed")
raise
finally:
# If no exception but user cancelled, still mark as cancelled
if not error_msg and cancel.is_set():
error_msg = "Operation cancelled"
try:
await stream_registry.mark_session_completed(
entry.session_id, error_message=error_msg
)
except Exception as mark_err:
log.error(f"Failed to mark session completed: {mark_err}")
except Exception as e:
log.error(f"Task failed: {e}")
await self._mark_task_failed(entry.task_id, str(e))
raise
async def _mark_task_failed(self, task_id: str, error_message: str):
"""Mark a task as failed and publish error to stream registry."""
try:
await stream_registry.publish_chunk(
task_id, StreamError(errorText=error_message)
)
await stream_registry.publish_chunk(task_id, StreamFinishStep())
await stream_registry.publish_chunk(task_id, StreamFinish())
await stream_registry.mark_task_completed(task_id, status="failed")
except Exception as e:
logger.error(f"Failed to mark task {task_id} as failed: {e}")

View File

@@ -28,7 +28,7 @@ class CoPilotLogMetadata(TruncatedLogger):
Args:
logger: The underlying logger instance
max_length: Maximum log message length before truncation
**kwargs: Metadata key-value pairs (e.g., session_id="xyz", turn_id="abc")
**kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
These are added to json_fields in cloud mode, or to the prefix in local mode.
"""
@@ -135,15 +135,18 @@ class CoPilotExecutionEntry(BaseModel):
This model represents a chat generation task to be processed by the executor.
"""
session_id: str
"""Chat session ID (also used for dedup/locking)"""
task_id: str
"""Unique identifier for this task (used for stream registry)"""
turn_id: str = ""
"""Per-turn UUID for Redis stream isolation"""
session_id: str
"""Chat session ID"""
user_id: str | None
"""User ID (may be None for anonymous users)"""
operation_id: str
"""Operation ID for webhook callbacks and completion tracking"""
message: str
"""User's message to process"""
@@ -153,50 +156,47 @@ class CoPilotExecutionEntry(BaseModel):
context: dict[str, str] | None = None
"""Optional context for the message (e.g., {url: str, content: str})"""
file_ids: list[str] | None = None
"""Workspace file IDs attached to the user's message"""
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
session_id: str
"""Session ID to cancel"""
task_id: str
"""Task ID to cancel"""
# ============ Queue Publishing Helpers ============ #
async def enqueue_copilot_turn(
async def enqueue_copilot_task(
task_id: str,
session_id: str,
user_id: str | None,
operation_id: str,
message: str,
turn_id: str,
is_user_message: bool = True,
context: dict[str, str] | None = None,
file_ids: list[str] | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
Args:
session_id: Chat session ID (also used for dedup/locking)
task_id: Unique identifier for this task (used for stream registry)
session_id: Chat session ID
user_id: User ID (may be None for anonymous users)
operation_id: Operation ID for webhook callbacks and completion tracking
message: User's message to process
turn_id: Per-turn UUID for Redis stream isolation
is_user_message: Whether the message is from the user (vs system/assistant)
context: Optional context for the message (e.g., {url: str, content: str})
file_ids: Optional workspace file IDs attached to the user's message
"""
from backend.util.clients import get_async_copilot_queue
entry = CoPilotExecutionEntry(
task_id=task_id,
session_id=session_id,
turn_id=turn_id,
user_id=user_id,
operation_id=operation_id,
message=message,
is_user_message=is_user_message,
context=context,
file_ids=file_ids,
)
queue_client = await get_async_copilot_queue()
@@ -207,15 +207,15 @@ async def enqueue_copilot_turn(
)
async def enqueue_cancel_task(session_id: str) -> None:
"""Publish a cancel request for a running CoPilot session.
async def enqueue_cancel_task(task_id: str) -> None:
"""Publish a cancel request for a running CoPilot task.
Sends a ``CancelCoPilotEvent`` to the FANOUT exchange so all executor
pods receive the cancellation signal.
"""
from backend.util.clients import get_async_copilot_queue
event = CancelCoPilotEvent(session_id=session_id)
event = CancelCoPilotEvent(task_id=task_id)
queue_client = await get_async_copilot_queue()
await queue_client.publish_message(
routing_key="", # FANOUT ignores routing key

View File

@@ -432,9 +432,7 @@ async def _get_session_from_db(session_id: str) -> ChatSession | None:
return session
async def upsert_chat_session(
session: ChatSession,
) -> ChatSession:
async def upsert_chat_session(session: ChatSession) -> ChatSession:
"""Update a chat session in both cache and database.
Uses session-level locking to prevent race conditions when concurrent
@@ -451,18 +449,16 @@ async def upsert_chat_session(
lock = await _get_session_lock(session.session_id)
async with lock:
# Always query DB for existing message count to ensure consistency
existing_message_count = await chat_db().get_next_sequence(session.session_id)
# Get existing message count from DB for incremental saves
existing_message_count = await chat_db().get_chat_session_message_count(
session.session_id
)
db_error: Exception | None = None
# Save to database (primary storage)
try:
await _save_session_to_db(
session,
existing_message_count,
skip_existence_check=existing_message_count > 0,
)
await _save_session_to_db(session, existing_message_count)
except Exception as e:
logger.error(
f"Failed to save session {session.session_id} to database: {e}"
@@ -493,31 +489,21 @@ async def upsert_chat_session(
async def _save_session_to_db(
session: ChatSession,
existing_message_count: int,
*,
skip_existence_check: bool = False,
session: ChatSession, existing_message_count: int
) -> None:
"""Save or update a chat session in the database.
Args:
skip_existence_check: When True, skip the ``get_chat_session`` query
and assume the session row already exists. Saves one DB round trip
for incremental saves during streaming.
"""
"""Save or update a chat session in the database."""
db = chat_db()
if not skip_existence_check:
# Check if session exists in DB
existing = await db.get_chat_session(session.session_id)
# Check if session exists in DB
existing = await db.get_chat_session(session.session_id)
if not existing:
# Create new session
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
if not existing:
# Create new session
await db.create_chat_session(
session_id=session.session_id,
user_id=session.user_id,
)
existing_message_count = 0
# Calculate total tokens from usage
total_prompt = sum(u.prompt_tokens for u in session.usage)
@@ -576,7 +562,9 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
raise ValueError(f"Session {session_id} not found")
session.messages.append(message)
existing_message_count = await chat_db().get_next_sequence(session_id)
existing_message_count = await chat_db().get_chat_session_message_count(
session_id
)
try:
await _save_session_to_db(session, existing_message_count)
@@ -672,16 +660,6 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
async with _session_locks_mutex:
_session_locks.pop(session_id, None)
# Shut down any local browser daemon for this session (best-effort).
# Inline import required: all tool modules import ChatSession from this
# module, so any top-level import from tools.* would create a cycle.
try:
from .tools.agent_browser import close_browser_session
await close_browser_session(session_id, user_id=user_id)
except Exception as e:
logger.debug(f"Browser cleanup for session {session_id}: {e}")
return True
@@ -705,10 +683,19 @@ async def update_session_title(session_id: str, title: str) -> bool:
logger.warning(f"Session {session_id} not found for title update")
return False
# Invalidate the cache so the next access reloads from DB with the
# updated title. This avoids a read-modify-write on the full session
# blob, which could overwrite concurrent message updates.
await invalidate_session_cache(session_id)
# Update title in cache if it exists (instead of invalidating).
# This prevents race conditions where cache invalidation causes
# the frontend to see stale DB data while streaming is still in progress.
try:
cached = await _get_session_from_cache(session_id)
if cached:
cached.title = title
await cache_chat_session(cached)
except Exception as e:
# Not critical - title will be correct on next full cache refresh
logger.warning(
f"Failed to update title in cache for session {session_id}: {e}"
)
return True
except Exception as e:

View File

@@ -331,96 +331,3 @@ def test_to_openai_messages_merges_split_assistants():
tc_list = merged.get("tool_calls")
assert tc_list is not None and len(list(tc_list)) == 1
assert list(tc_list)[0]["id"] == "tc1"
# --------------------------------------------------------------------------- #
# Concurrent save collision detection #
# --------------------------------------------------------------------------- #
@pytest.mark.asyncio(loop_scope="session")
async def test_concurrent_saves_collision_detection(setup_test_user, test_user_id):
"""Test that concurrent saves from streaming loop and callback handle collisions correctly.
Simulates the race condition where:
1. Streaming loop starts with saved_msg_count=5
2. Long-running callback appends message #5 and saves
3. Streaming loop tries to save with stale count=5
The collision detection should handle this gracefully.
"""
import asyncio
# Create a session with initial messages
session = ChatSession.new(user_id=test_user_id)
for i in range(3):
session.messages.append(
ChatMessage(
role="user" if i % 2 == 0 else "assistant", content=f"Message {i}"
)
)
# Save initial messages
session = await upsert_chat_session(session)
# Simulate streaming loop and callback saving concurrently
async def streaming_loop_save():
"""Simulates streaming loop saving messages."""
# Add 2 messages
session.messages.append(ChatMessage(role="user", content="Streaming message 1"))
session.messages.append(
ChatMessage(role="assistant", content="Streaming message 2")
)
# Wait a bit to let callback potentially save first
await asyncio.sleep(0.01)
# Save (will query DB for existing count)
return await upsert_chat_session(session)
async def callback_save():
"""Simulates long-running callback saving a message."""
# Add 1 message
session.messages.append(
ChatMessage(role="tool", content="Callback result", tool_call_id="tc1")
)
# Save immediately (will query DB for existing count)
return await upsert_chat_session(session)
# Run both saves concurrently - one will hit collision detection
results = await asyncio.gather(streaming_loop_save(), callback_save())
# Both should succeed
assert all(r is not None for r in results)
# Reload session from DB to verify
from backend.data.redis_client import get_redis_async
redis_key = f"chat:session:{session.session_id}"
async_redis = await get_redis_async()
await async_redis.delete(redis_key) # Clear cache to force DB load
loaded_session = await get_chat_session(session.session_id, test_user_id)
assert loaded_session is not None
# Should have all 6 messages (3 initial + 2 streaming + 1 callback)
assert len(loaded_session.messages) == 6
# Verify no duplicate sequences
sequences = []
for i, msg in enumerate(loaded_session.messages):
# Messages should have sequential sequence numbers starting from 0
sequences.append(i)
# All sequences should be unique and sequential
assert sequences == list(range(6))
# Verify message content is preserved
contents = [m.content for m in loaded_session.messages]
assert "Message 0" in contents
assert "Message 1" in contents
assert "Message 2" in contents
assert "Streaming message 1" in contents
assert "Streaming message 2" in contents
assert "Callback result" in contents

View File

@@ -0,0 +1,272 @@
"""Tests for parallel tool call execution in CoPilot.
These tests mock _yield_tool_call to avoid importing the full copilot stack
which requires Prisma, DB connections, etc.
"""
import asyncio
import time
from typing import Any, cast
import pytest
@pytest.mark.asyncio
async def test_parallel_tool_calls_run_concurrently():
"""Multiple tool calls should complete in ~max(delays), not sum(delays)."""
# Import here to allow module-level mocking if needed
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
n_tools = 3
delay_per_tool = 0.2
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"tool_{i}", "arguments": "{}"},
}
for i in range(n_tools)
]
# Minimal session mock
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
original_yield = None
async def fake_yield(tc_list, idx, sess, lock=None):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"],
toolName=tc_list[idx]["function"]["name"],
input={},
)
await asyncio.sleep(delay_per_tool)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"],
toolName=tc_list[idx]["function"]["name"],
output="{}",
)
import backend.copilot.service as svc
original_yield = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
start = time.monotonic()
events = []
async for event in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
events.append(event)
elapsed = time.monotonic() - start
finally:
svc._yield_tool_call = original_yield
assert len(events) == n_tools * 2
# Parallel: should take ~delay, not ~n*delay
assert elapsed < delay_per_tool * (
n_tools - 0.5
), f"Took {elapsed:.2f}s, expected parallel (~{delay_per_tool}s)"
@pytest.mark.asyncio
async def test_single_tool_call_works():
"""Single tool call should work identically."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": "call_0",
"type": "function",
"function": {"name": "t", "arguments": "{}"},
}
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess, lock=None):
yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
events = [
e
async for e in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
)
]
finally:
svc._yield_tool_call = orig
assert len(events) == 2
@pytest.mark.asyncio
async def test_retryable_error_propagates():
"""Retryable errors should be raised after all tools finish."""
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(2)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
async def fake_yield(tc_list, idx, sess, lock=None):
if idx == 1:
raise KeyError("bad")
from backend.copilot.response_model import StreamToolInputAvailable
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName="t_0", input={}
)
await asyncio.sleep(0.05)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"], toolName="t_0", output="{}"
)
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
events = []
with pytest.raises(KeyError):
async for event in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
events.append(event)
# First tool's events should still be yielded
assert any(isinstance(e, StreamToolOutputAvailable) for e in events)
finally:
svc._yield_tool_call = orig
@pytest.mark.asyncio
async def test_session_lock_shared():
"""All parallel tools should receive the same lock instance."""
from backend.copilot.response_model import (
StreamToolInputAvailable,
StreamToolOutputAvailable,
)
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(3)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
observed_locks = []
async def fake_yield(tc_list, idx, sess, lock=None):
observed_locks.append(lock)
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)
yield StreamToolOutputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", output="{}"
)
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
async for _ in _execute_tool_calls_parallel(
tool_calls, cast(Any, FakeSession())
):
pass
finally:
svc._yield_tool_call = orig
assert len(observed_locks) == 3
assert observed_locks[0] is observed_locks[1] is observed_locks[2]
assert isinstance(observed_locks[0], asyncio.Lock)
@pytest.mark.asyncio
async def test_cancellation_cleans_up():
"""Generator close should cancel in-flight tasks."""
from backend.copilot.response_model import StreamToolInputAvailable
from backend.copilot.service import _execute_tool_calls_parallel
tool_calls = [
{
"id": f"call_{i}",
"type": "function",
"function": {"name": f"t_{i}", "arguments": "{}"},
}
for i in range(2)
]
class FakeSession:
session_id = "test"
user_id = "test"
def __init__(self):
self.messages = []
started = asyncio.Event()
async def fake_yield(tc_list, idx, sess, lock=None):
yield StreamToolInputAvailable(
toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
)
started.set()
await asyncio.sleep(10) # simulate long-running
import backend.copilot.service as svc
orig = svc._yield_tool_call
svc._yield_tool_call = fake_yield
try:
gen = _execute_tool_calls_parallel(tool_calls, cast(Any, FakeSession()))
await gen.__anext__() # get first event
await started.wait()
await gen.aclose() # close generator
finally:
svc._yield_tool_call = orig
# If we get here without hanging, cleanup worked

View File

@@ -5,17 +5,12 @@ This module implements the AI SDK UI Stream Protocol (v1) for streaming chat res
See: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol
"""
import json
import logging
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
from backend.util.json import dumps as json_dumps
from backend.util.truncate import truncate
logger = logging.getLogger(__name__)
class ResponseType(str, Enum):
@@ -52,8 +47,7 @@ class StreamBaseResponse(BaseModel):
def to_sse(self) -> str:
"""Convert to SSE format."""
json_str = self.model_dump_json(exclude_none=True)
return f"data: {json_str}\n\n"
return f"data: {self.model_dump_json()}\n\n"
# ========== Message Lifecycle ==========
@@ -64,13 +58,15 @@ class StreamStart(StreamBaseResponse):
type: ResponseType = ResponseType.START
messageId: str = Field(..., description="Unique message ID")
sessionId: str | None = Field(
taskId: str | None = Field(
default=None,
description="Session ID for SSE reconnection.",
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-protocol fields like sessionId."""
"""Convert to SSE format, excluding non-protocol fields like taskId."""
import json
data: dict[str, Any] = {
"type": self.type.value,
"messageId": self.messageId,
@@ -151,9 +147,6 @@ class StreamToolInputAvailable(StreamBaseResponse):
)
_MAX_TOOL_OUTPUT_SIZE = 100_000 # ~100 KB; truncate to avoid bloating SSE/DB
class StreamToolOutputAvailable(StreamBaseResponse):
"""Tool execution result."""
@@ -168,12 +161,10 @@ class StreamToolOutputAvailable(StreamBaseResponse):
default=True, description="Whether the tool execution succeeded"
)
def model_post_init(self, __context: Any) -> None:
"""Truncate oversized outputs after construction."""
self.output = truncate(self.output, _MAX_TOOL_OUTPUT_SIZE)
def to_sse(self) -> str:
"""Convert to SSE format, excluding non-spec fields."""
import json
data = {
"type": self.type.value,
"toolCallId": self.toolCallId,

View File

@@ -1,239 +0,0 @@
"""Compaction tracking for SDK-based chat sessions.
Encapsulates the state machine and event emission for context compaction,
both pre-query (history compressed before SDK query) and SDK-internal
(PreCompact hook fires mid-stream).
All compaction-related helpers live here: event builders, message filtering,
persistence, and the ``CompactionTracker`` state machine.
"""
import asyncio
import logging
import uuid
from collections.abc import Callable
from ..constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
from ..model import ChatMessage, ChatSession
from ..response_model import (
StreamBaseResponse,
StreamFinishStep,
StreamStartStep,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Event builders (private — use CompactionTracker or compaction_events)
# ---------------------------------------------------------------------------
def _start_events(tool_call_id: str) -> list[StreamBaseResponse]:
"""Build the opening events for a compaction tool call."""
return [
StreamStartStep(),
StreamToolInputStart(toolCallId=tool_call_id, toolName=COMPACTION_TOOL_NAME),
StreamToolInputAvailable(
toolCallId=tool_call_id, toolName=COMPACTION_TOOL_NAME, input={}
),
]
def _end_events(tool_call_id: str, message: str) -> list[StreamBaseResponse]:
"""Build the closing events for a compaction tool call."""
return [
StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=COMPACTION_TOOL_NAME,
output=message,
),
StreamFinishStep(),
]
def _new_tool_call_id() -> str:
return f"compaction-{uuid.uuid4().hex[:12]}"
# ---------------------------------------------------------------------------
# Public event builder
# ---------------------------------------------------------------------------
def emit_compaction(session: ChatSession) -> list[StreamBaseResponse]:
"""Create, persist, and return a self-contained compaction tool call.
Convenience for callers that don't use ``CompactionTracker`` (e.g. the
legacy non-SDK streaming path in ``service.py``).
"""
tc_id = _new_tool_call_id()
evts = compaction_events(COMPACTION_DONE_MSG, tool_call_id=tc_id)
_persist(session, tc_id, COMPACTION_DONE_MSG)
return evts
def compaction_events(
message: str, tool_call_id: str | None = None
) -> list[StreamBaseResponse]:
"""Emit a self-contained compaction tool call (already completed).
When *tool_call_id* is provided it is reused (e.g. for persistence that
must match an already-streamed start event). Otherwise a new ID is
generated.
"""
tc_id = tool_call_id or _new_tool_call_id()
return _start_events(tc_id) + _end_events(tc_id, message)
# ---------------------------------------------------------------------------
# Message filtering
# ---------------------------------------------------------------------------
def filter_compaction_messages(
messages: list[ChatMessage],
) -> list[ChatMessage]:
"""Remove synthetic compaction tool-call messages (UI-only artifacts).
Strips assistant messages whose only tool calls are compaction calls,
and their corresponding tool-result messages.
"""
compaction_ids: set[str] = set()
filtered: list[ChatMessage] = []
for msg in messages:
if msg.role == "assistant" and msg.tool_calls:
for tc in msg.tool_calls:
if tc.get("function", {}).get("name") == COMPACTION_TOOL_NAME:
compaction_ids.add(tc.get("id", ""))
real_calls = [
tc
for tc in msg.tool_calls
if tc.get("function", {}).get("name") != COMPACTION_TOOL_NAME
]
if not real_calls and not msg.content:
continue
if msg.role == "tool" and msg.tool_call_id in compaction_ids:
continue
filtered.append(msg)
return filtered
# ---------------------------------------------------------------------------
# Persistence
# ---------------------------------------------------------------------------
def _persist(session: ChatSession, tool_call_id: str, message: str) -> None:
"""Append compaction tool-call + result to session messages.
Compaction events are synthetic so they bypass the normal adapter
accumulation. This explicitly records them so they survive a page refresh.
"""
session.messages.append(
ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": tool_call_id,
"type": "function",
"function": {
"name": COMPACTION_TOOL_NAME,
"arguments": "{}",
},
}
],
)
)
session.messages.append(
ChatMessage(role="tool", content=message, tool_call_id=tool_call_id)
)
# ---------------------------------------------------------------------------
# CompactionTracker — state machine for streaming sessions
# ---------------------------------------------------------------------------
class CompactionTracker:
"""Tracks compaction state and yields UI events.
Two compaction paths:
1. **Pre-query** — history compressed before the SDK query starts.
Call :meth:`emit_pre_query` to yield a self-contained tool call.
2. **SDK-internal** — ``PreCompact`` hook fires mid-stream.
Call :meth:`emit_start_if_ready` on heartbeat ticks and
:meth:`emit_end_if_ready` when a message arrives.
"""
def __init__(self) -> None:
self._compact_start = asyncio.Event()
self._start_emitted = False
self._done = False
self._tool_call_id = ""
@property
def on_compact(self) -> Callable[[], None]:
"""Callback for the PreCompact hook."""
return self._compact_start.set
# ------------------------------------------------------------------
# Pre-query compaction
# ------------------------------------------------------------------
def emit_pre_query(self, session: ChatSession) -> list[StreamBaseResponse]:
"""Emit + persist a self-contained compaction tool call."""
self._done = True
return emit_compaction(session)
# ------------------------------------------------------------------
# SDK-internal compaction
# ------------------------------------------------------------------
def reset_for_query(self) -> None:
"""Reset per-query state before a new SDK query."""
self._done = False
self._start_emitted = False
self._tool_call_id = ""
def emit_start_if_ready(self) -> list[StreamBaseResponse]:
"""If the PreCompact hook fired, emit start events (spinning tool)."""
if self._compact_start.is_set() and not self._start_emitted and not self._done:
self._compact_start.clear()
self._start_emitted = True
self._tool_call_id = _new_tool_call_id()
return _start_events(self._tool_call_id)
return []
async def emit_end_if_ready(self, session: ChatSession) -> list[StreamBaseResponse]:
"""If compaction is in progress, emit end events and persist."""
# Yield so pending hook tasks can set compact_start
await asyncio.sleep(0)
if self._done:
return []
if not self._start_emitted and not self._compact_start.is_set():
return []
if self._start_emitted:
# Close the open spinner
done_events = _end_events(self._tool_call_id, COMPACTION_DONE_MSG)
persist_id = self._tool_call_id
else:
# PreCompact fired but start never emitted — self-contained
persist_id = _new_tool_call_id()
done_events = compaction_events(
COMPACTION_DONE_MSG, tool_call_id=persist_id
)
self._compact_start.clear()
self._start_emitted = False
self._done = True
_persist(session, persist_id, COMPACTION_DONE_MSG)
return done_events

View File

@@ -1,291 +0,0 @@
"""Tests for sdk/compaction.py — event builders, filtering, persistence, and
CompactionTracker state machine."""
import pytest
from backend.copilot.constants import COMPACTION_DONE_MSG, COMPACTION_TOOL_NAME
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.response_model import (
StreamFinishStep,
StreamStartStep,
StreamToolInputAvailable,
StreamToolInputStart,
StreamToolOutputAvailable,
)
from backend.copilot.sdk.compaction import (
CompactionTracker,
compaction_events,
emit_compaction,
filter_compaction_messages,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session() -> ChatSession:
return ChatSession.new(user_id="test-user")
# ---------------------------------------------------------------------------
# compaction_events
# ---------------------------------------------------------------------------
class TestCompactionEvents:
def test_returns_start_and_end_events(self):
evts = compaction_events("done")
assert len(evts) == 5
assert isinstance(evts[0], StreamStartStep)
assert isinstance(evts[1], StreamToolInputStart)
assert isinstance(evts[2], StreamToolInputAvailable)
assert isinstance(evts[3], StreamToolOutputAvailable)
assert isinstance(evts[4], StreamFinishStep)
def test_uses_provided_tool_call_id(self):
evts = compaction_events("msg", tool_call_id="my-id")
tool_start = evts[1]
assert isinstance(tool_start, StreamToolInputStart)
assert tool_start.toolCallId == "my-id"
def test_generates_id_when_not_provided(self):
evts = compaction_events("msg")
tool_start = evts[1]
assert isinstance(tool_start, StreamToolInputStart)
assert tool_start.toolCallId.startswith("compaction-")
def test_tool_name_is_context_compaction(self):
evts = compaction_events("msg")
tool_start = evts[1]
assert isinstance(tool_start, StreamToolInputStart)
assert tool_start.toolName == COMPACTION_TOOL_NAME
# ---------------------------------------------------------------------------
# emit_compaction
# ---------------------------------------------------------------------------
class TestEmitCompaction:
def test_persists_to_session(self):
session = _make_session()
assert len(session.messages) == 0
evts = emit_compaction(session)
assert len(evts) == 5
# Should have appended 2 messages (assistant tool call + tool result)
assert len(session.messages) == 2
assert session.messages[0].role == "assistant"
assert session.messages[0].tool_calls is not None
assert (
session.messages[0].tool_calls[0]["function"]["name"]
== COMPACTION_TOOL_NAME
)
assert session.messages[1].role == "tool"
assert session.messages[1].content == COMPACTION_DONE_MSG
# ---------------------------------------------------------------------------
# filter_compaction_messages
# ---------------------------------------------------------------------------
class TestFilterCompactionMessages:
def test_removes_compaction_tool_calls(self):
msgs = [
ChatMessage(role="user", content="hello"),
ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": "comp-1",
"type": "function",
"function": {"name": COMPACTION_TOOL_NAME, "arguments": "{}"},
}
],
),
ChatMessage(
role="tool", content=COMPACTION_DONE_MSG, tool_call_id="comp-1"
),
ChatMessage(role="assistant", content="world"),
]
filtered = filter_compaction_messages(msgs)
assert len(filtered) == 2
assert filtered[0].content == "hello"
assert filtered[1].content == "world"
def test_keeps_non_compaction_tool_calls(self):
msgs = [
ChatMessage(
role="assistant",
content="",
tool_calls=[
{
"id": "real-1",
"type": "function",
"function": {"name": "search", "arguments": "{}"},
}
],
),
ChatMessage(role="tool", content="result", tool_call_id="real-1"),
]
filtered = filter_compaction_messages(msgs)
assert len(filtered) == 2
def test_keeps_assistant_with_content_and_compaction_call(self):
"""If assistant message has both content and a compaction tool call,
the message is kept (has real content)."""
msgs = [
ChatMessage(
role="assistant",
content="I have content",
tool_calls=[
{
"id": "comp-1",
"type": "function",
"function": {"name": COMPACTION_TOOL_NAME, "arguments": "{}"},
}
],
),
]
filtered = filter_compaction_messages(msgs)
assert len(filtered) == 1
def test_empty_list(self):
assert filter_compaction_messages([]) == []
# ---------------------------------------------------------------------------
# CompactionTracker
# ---------------------------------------------------------------------------
class TestCompactionTracker:
def test_on_compact_sets_event(self):
tracker = CompactionTracker()
tracker.on_compact()
assert tracker._compact_start.is_set()
def test_emit_start_if_ready_no_event(self):
tracker = CompactionTracker()
assert tracker.emit_start_if_ready() == []
def test_emit_start_if_ready_with_event(self):
tracker = CompactionTracker()
tracker.on_compact()
evts = tracker.emit_start_if_ready()
assert len(evts) == 3
assert isinstance(evts[0], StreamStartStep)
assert isinstance(evts[1], StreamToolInputStart)
assert isinstance(evts[2], StreamToolInputAvailable)
def test_emit_start_only_once(self):
tracker = CompactionTracker()
tracker.on_compact()
evts1 = tracker.emit_start_if_ready()
assert len(evts1) == 3
# Second call should return empty
evts2 = tracker.emit_start_if_ready()
assert evts2 == []
@pytest.mark.asyncio
async def test_emit_end_after_start(self):
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact()
tracker.emit_start_if_ready()
evts = await tracker.emit_end_if_ready(session)
assert len(evts) == 2
assert isinstance(evts[0], StreamToolOutputAvailable)
assert isinstance(evts[1], StreamFinishStep)
# Should persist
assert len(session.messages) == 2
@pytest.mark.asyncio
async def test_emit_end_without_start_self_contained(self):
"""If PreCompact fired but start was never emitted, emit_end
produces a self-contained compaction event."""
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact()
# Don't call emit_start_if_ready
evts = await tracker.emit_end_if_ready(session)
assert len(evts) == 5 # Full self-contained event
assert isinstance(evts[0], StreamStartStep)
assert len(session.messages) == 2
@pytest.mark.asyncio
async def test_emit_end_no_op_when_done(self):
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact()
tracker.emit_start_if_ready()
await tracker.emit_end_if_ready(session)
# Second call should be no-op
evts = await tracker.emit_end_if_ready(session)
assert evts == []
@pytest.mark.asyncio
async def test_emit_end_no_op_when_nothing_happened(self):
tracker = CompactionTracker()
session = _make_session()
evts = await tracker.emit_end_if_ready(session)
assert evts == []
def test_emit_pre_query(self):
tracker = CompactionTracker()
session = _make_session()
evts = tracker.emit_pre_query(session)
assert len(evts) == 5
assert len(session.messages) == 2
assert tracker._done is True
def test_reset_for_query(self):
tracker = CompactionTracker()
tracker._done = True
tracker._start_emitted = True
tracker._tool_call_id = "old"
tracker.reset_for_query()
assert tracker._done is False
assert tracker._start_emitted is False
assert tracker._tool_call_id == ""
@pytest.mark.asyncio
async def test_pre_query_blocks_sdk_compaction(self):
"""After pre-query compaction, SDK compaction events are suppressed."""
tracker = CompactionTracker()
session = _make_session()
tracker.emit_pre_query(session)
tracker.on_compact()
evts = tracker.emit_start_if_ready()
assert evts == [] # _done blocks it
@pytest.mark.asyncio
async def test_reset_allows_new_compaction(self):
"""After reset_for_query, compaction can fire again."""
tracker = CompactionTracker()
session = _make_session()
tracker.emit_pre_query(session)
tracker.reset_for_query()
tracker.on_compact()
evts = tracker.emit_start_if_ready()
assert len(evts) == 3 # Start events emitted
@pytest.mark.asyncio
async def test_tool_call_id_consistency(self):
"""Start and end events use the same tool_call_id."""
tracker = CompactionTracker()
session = _make_session()
tracker.on_compact()
start_evts = tracker.emit_start_if_ready()
end_evts = await tracker.emit_end_if_ready(session)
start_evt = start_evts[1]
end_evt = end_evts[0]
assert isinstance(start_evt, StreamToolInputStart)
assert isinstance(end_evt, StreamToolOutputAvailable)
assert start_evt.toolCallId == end_evt.toolCallId
# Persisted ID should also match
tool_calls = session.messages[0].tool_calls
assert tool_calls is not None
assert tool_calls[0]["id"] == start_evt.toolCallId

View File

@@ -1,59 +0,0 @@
"""Dummy SDK service for testing copilot streaming.
Returns mock streaming responses without calling Claude Agent SDK.
Enable via COPILOT_TEST_MODE=true environment variable.
WARNING: This is for testing only. Do not use in production.
"""
import asyncio
import logging
import uuid
from collections.abc import AsyncGenerator
from typing import Any
from ..model import ChatSession
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
logger = logging.getLogger(__name__)
async def stream_chat_completion_dummy(
session_id: str,
message: str | None = None,
tool_call_response: str | None = None,
is_user_message: bool = True,
user_id: str | None = None,
retry_count: int = 0,
session: ChatSession | None = None,
context: dict[str, str] | None = None,
**_kwargs: Any,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Stream dummy chat completion for testing.
Returns a simple streaming response with text deltas to test:
- Streaming infrastructure works
- No timeout occurs
- Text arrives in chunks
- StreamFinish is sent by mark_session_completed
"""
logger.warning(
f"[TEST MODE] Using dummy copilot streaming for session {session_id}"
)
message_id = str(uuid.uuid4())
text_block_id = str(uuid.uuid4())
# Start the stream
yield StreamStart(messageId=message_id, sessionId=session_id)
# Simulate streaming text response with delays
dummy_response = "I counted: 1... 2... 3. All done!"
words = dummy_response.split()
for i, word in enumerate(words):
# Add space except for last word
text = word if i == len(words) - 1 else f"{word} "
yield StreamTextDelta(id=text_block_id, delta=text)
# Small delay to simulate real streaming
await asyncio.sleep(0.1)

View File

@@ -1,362 +0,0 @@
"""MCP file-tool handlers that route to the E2B cloud sandbox.
When E2B is active, these tools replace the SDK built-in Read/Write/Edit/
Glob/Grep so that all file operations share the same ``/home/user``
filesystem as ``bash_exec``.
SDK-internal paths (``~/.claude/projects/…/tool-results/``) are handled
by the separate ``Read`` MCP tool registered in ``tool_adapter.py``.
"""
from __future__ import annotations
import itertools
import json
import logging
import os
import shlex
from typing import Any, Callable
from backend.copilot.tools.e2b_sandbox import E2B_WORKDIR
logger = logging.getLogger(__name__)
# Lazy imports to break circular dependency with tool_adapter.
def _get_sandbox(): # type: ignore[return]
from .tool_adapter import get_current_sandbox # noqa: E402
return get_current_sandbox()
def _is_allowed_local(path: str) -> bool:
from .tool_adapter import is_allowed_local_path # noqa: E402
return is_allowed_local_path(path)
def _resolve_remote(path: str) -> str:
"""Normalise *path* to an absolute sandbox path under ``/home/user``.
Raises :class:`ValueError` if the resolved path escapes the sandbox.
"""
candidate = path if os.path.isabs(path) else os.path.join(E2B_WORKDIR, path)
normalized = os.path.normpath(candidate)
if normalized != E2B_WORKDIR and not normalized.startswith(E2B_WORKDIR + "/"):
raise ValueError(f"Path must be within {E2B_WORKDIR}: {path}")
return normalized
def _mcp(text: str, *, error: bool = False) -> dict[str, Any]:
if error:
text = json.dumps({"error": text, "type": "error"})
return {"content": [{"type": "text", "text": text}], "isError": error}
def _get_sandbox_and_path(
file_path: str,
) -> tuple[Any, str] | dict[str, Any]:
"""Common preamble: get sandbox + resolve path, or return MCP error."""
sandbox = _get_sandbox()
if sandbox is None:
return _mcp("No E2B sandbox available", error=True)
try:
remote = _resolve_remote(file_path)
except ValueError as exc:
return _mcp(str(exc), error=True)
return sandbox, remote
# Tool handlers
async def _handle_read_file(args: dict[str, Any]) -> dict[str, Any]:
file_path: str = args.get("file_path", "")
offset: int = max(0, int(args.get("offset", 0)))
limit: int = max(1, int(args.get("limit", 2000)))
if not file_path:
return _mcp("file_path is required", error=True)
# SDK-internal paths (tool-results, ephemeral working dir) stay on the host.
if _is_allowed_local(file_path):
return _read_local(file_path, offset, limit)
result = _get_sandbox_and_path(file_path)
if isinstance(result, dict):
return result
sandbox, remote = result
try:
raw: bytes = await sandbox.files.read(remote, format="bytes")
content = raw.decode("utf-8", errors="replace")
except Exception as exc:
return _mcp(f"Failed to read {remote}: {exc}", error=True)
lines = content.splitlines(keepends=True)
selected = list(itertools.islice(lines, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
)
return _mcp(numbered)
async def _handle_write_file(args: dict[str, Any]) -> dict[str, Any]:
file_path: str = args.get("file_path", "")
content: str = args.get("content", "")
if not file_path:
return _mcp("file_path is required", error=True)
result = _get_sandbox_and_path(file_path)
if isinstance(result, dict):
return result
sandbox, remote = result
try:
parent = os.path.dirname(remote)
if parent and parent != E2B_WORKDIR:
await sandbox.files.make_dir(parent)
await sandbox.files.write(remote, content)
except Exception as exc:
return _mcp(f"Failed to write {remote}: {exc}", error=True)
return _mcp(f"Successfully wrote to {remote}")
async def _handle_edit_file(args: dict[str, Any]) -> dict[str, Any]:
file_path: str = args.get("file_path", "")
old_string: str = args.get("old_string", "")
new_string: str = args.get("new_string", "")
replace_all: bool = args.get("replace_all", False)
if not file_path:
return _mcp("file_path is required", error=True)
if not old_string:
return _mcp("old_string is required", error=True)
result = _get_sandbox_and_path(file_path)
if isinstance(result, dict):
return result
sandbox, remote = result
try:
raw: bytes = await sandbox.files.read(remote, format="bytes")
content = raw.decode("utf-8", errors="replace")
except Exception as exc:
return _mcp(f"Failed to read {remote}: {exc}", error=True)
count = content.count(old_string)
if count == 0:
return _mcp(f"old_string not found in {file_path}", error=True)
if count > 1 and not replace_all:
return _mcp(
f"old_string appears {count} times in {file_path}. "
"Use replace_all=true or provide a more unique string.",
error=True,
)
updated = (
content.replace(old_string, new_string)
if replace_all
else content.replace(old_string, new_string, 1)
)
try:
await sandbox.files.write(remote, updated)
except Exception as exc:
return _mcp(f"Failed to write {remote}: {exc}", error=True)
return _mcp(f"Edited {remote} ({count} replacement{'s' if count > 1 else ''})")
async def _handle_glob(args: dict[str, Any]) -> dict[str, Any]:
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
if not pattern:
return _mcp("pattern is required", error=True)
sandbox = _get_sandbox()
if sandbox is None:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
cmd = f"find {shlex.quote(search_dir)} -name {shlex.quote(pattern)} -type f 2>/dev/null | head -500"
try:
result = await sandbox.commands.run(cmd, cwd=E2B_WORKDIR, timeout=10)
except Exception as exc:
return _mcp(f"Glob failed: {exc}", error=True)
files = [line for line in (result.stdout or "").strip().splitlines() if line]
return _mcp(json.dumps(files, indent=2))
async def _handle_grep(args: dict[str, Any]) -> dict[str, Any]:
pattern: str = args.get("pattern", "")
path: str = args.get("path", "")
include: str = args.get("include", "")
if not pattern:
return _mcp("pattern is required", error=True)
sandbox = _get_sandbox()
if sandbox is None:
return _mcp("No E2B sandbox available", error=True)
try:
search_dir = _resolve_remote(path) if path else E2B_WORKDIR
except ValueError as exc:
return _mcp(str(exc), error=True)
parts = ["grep", "-rn", "--color=never"]
if include:
parts.extend(["--include", include])
parts.extend([pattern, search_dir])
cmd = " ".join(shlex.quote(p) for p in parts) + " 2>/dev/null | head -200"
try:
result = await sandbox.commands.run(cmd, cwd=E2B_WORKDIR, timeout=15)
except Exception as exc:
return _mcp(f"Grep failed: {exc}", error=True)
output = (result.stdout or "").strip()
return _mcp(output if output else "No matches found.")
# Local read (for SDK-internal paths)
def _read_local(file_path: str, offset: int, limit: int) -> dict[str, Any]:
"""Read from the host filesystem (defence-in-depth path check)."""
if not _is_allowed_local(file_path):
return _mcp(f"Path not allowed: {file_path}", error=True)
expanded = os.path.realpath(os.path.expanduser(file_path))
try:
with open(expanded) as fh:
selected = list(itertools.islice(fh, offset, offset + limit))
numbered = "".join(
f"{i + offset + 1:>6}\t{line}" for i, line in enumerate(selected)
)
return _mcp(numbered)
except FileNotFoundError:
return _mcp(f"File not found: {file_path}", error=True)
except Exception as exc:
return _mcp(f"Error reading {file_path}: {exc}", error=True)
# Tool descriptors (name, description, schema, handler)
E2B_FILE_TOOLS: list[tuple[str, str, dict[str, Any], Callable[..., Any]]] = [
(
"read_file",
"Read a file from the cloud sandbox (/home/user). "
"Use offset and limit for large files.",
{
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path (relative to /home/user, or absolute).",
},
"offset": {
"type": "integer",
"description": "Line to start reading from (0-indexed). Default: 0.",
},
"limit": {
"type": "integer",
"description": "Number of lines to read. Default: 2000.",
},
},
"required": ["file_path"],
},
_handle_read_file,
),
(
"write_file",
"Write or create a file in the cloud sandbox (/home/user). "
"Parent directories are created automatically. "
"To copy a workspace file into the sandbox, use "
"read_workspace_file with save_to_path instead.",
{
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path (relative to /home/user, or absolute).",
},
"content": {"type": "string", "description": "Content to write."},
},
"required": ["file_path", "content"],
},
_handle_write_file,
),
(
"edit_file",
"Targeted text replacement in a sandbox file. "
"old_string must appear in the file and is replaced with new_string.",
{
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path (relative to /home/user, or absolute).",
},
"old_string": {"type": "string", "description": "Text to find."},
"new_string": {"type": "string", "description": "Replacement text."},
"replace_all": {
"type": "boolean",
"description": "Replace all occurrences (default: false).",
},
},
"required": ["file_path", "old_string", "new_string"],
},
_handle_edit_file,
),
(
"glob",
"Search for files by name pattern in the cloud sandbox.",
{
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": "Glob pattern (e.g. *.py).",
},
"path": {
"type": "string",
"description": "Directory to search. Default: /home/user.",
},
},
"required": ["pattern"],
},
_handle_glob,
),
(
"grep",
"Search file contents by regex in the cloud sandbox.",
{
"type": "object",
"properties": {
"pattern": {"type": "string", "description": "Regex pattern."},
"path": {
"type": "string",
"description": "File or directory. Default: /home/user.",
},
"include": {
"type": "string",
"description": "Glob to filter files (e.g. *.py).",
},
},
"required": ["pattern"],
},
_handle_grep,
),
]
E2B_FILE_TOOL_NAMES: list[str] = [name for name, *_ in E2B_FILE_TOOLS]

View File

@@ -1,153 +0,0 @@
"""Tests for E2B file-tool path validation and local read safety.
Pure unit tests with no external dependencies (no E2B, no sandbox).
"""
import os
import pytest
from .e2b_file_tools import _read_local, _resolve_remote
from .tool_adapter import _current_project_dir
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# ---------------------------------------------------------------------------
# _resolve_remote — sandbox path normalisation & boundary enforcement
# ---------------------------------------------------------------------------
class TestResolveRemote:
def test_relative_path_resolved(self):
assert _resolve_remote("src/main.py") == "/home/user/src/main.py"
def test_absolute_within_sandbox(self):
assert _resolve_remote("/home/user/file.txt") == "/home/user/file.txt"
def test_workdir_itself(self):
assert _resolve_remote("/home/user") == "/home/user"
def test_relative_dotslash(self):
assert _resolve_remote("./README.md") == "/home/user/README.md"
def test_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("../../etc/passwd")
def test_absolute_traversal_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/home/user/../../etc/passwd")
def test_absolute_outside_sandbox_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/etc/passwd")
def test_root_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/")
def test_home_other_user_blocked(self):
with pytest.raises(ValueError, match="must be within /home/user"):
_resolve_remote("/home/other/file.txt")
def test_deep_nested_allowed(self):
assert _resolve_remote("a/b/c/d/e.txt") == "/home/user/a/b/c/d/e.txt"
def test_trailing_slash_normalised(self):
assert _resolve_remote("src/") == "/home/user/src"
def test_double_dots_within_sandbox_ok(self):
"""Path that resolves back within /home/user is allowed."""
assert _resolve_remote("a/b/../c.txt") == "/home/user/a/c.txt"
# ---------------------------------------------------------------------------
# _read_local — host filesystem reads with allowlist enforcement
#
# In E2B mode, _read_local only allows tool-results paths (via
# is_allowed_local_path without sdk_cwd). Regular files live on the
# sandbox, not the host.
# ---------------------------------------------------------------------------
class TestReadLocal:
def _make_tool_results_file(self, encoded: str, filename: str, content: str) -> str:
"""Create a tool-results file and return its path."""
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
os.makedirs(tool_results_dir, exist_ok=True)
filepath = os.path.join(tool_results_dir, filename)
with open(filepath, "w") as f:
f.write(content)
return filepath
def test_read_tool_results_file(self):
"""Reading a tool-results file should succeed."""
encoded = "-tmp-copilot-e2b-test-read"
filepath = self._make_tool_results_file(
encoded, "result.txt", "line 1\nline 2\nline 3\n"
)
token = _current_project_dir.set(encoded)
try:
result = _read_local(filepath, offset=0, limit=2000)
assert result["isError"] is False
assert "line 1" in result["content"][0]["text"]
assert "line 2" in result["content"][0]["text"]
finally:
_current_project_dir.reset(token)
os.unlink(filepath)
def test_read_disallowed_path_blocked(self):
"""Reading /etc/passwd should be blocked by the allowlist."""
result = _read_local("/etc/passwd", offset=0, limit=10)
assert result["isError"] is True
assert "not allowed" in result["content"][0]["text"].lower()
def test_read_nonexistent_tool_results(self):
"""A tool-results path that doesn't exist returns FileNotFoundError."""
encoded = "-tmp-copilot-e2b-test-nofile"
tool_results_dir = os.path.join(_SDK_PROJECTS_DIR, encoded, "tool-results")
os.makedirs(tool_results_dir, exist_ok=True)
filepath = os.path.join(tool_results_dir, "nonexistent.txt")
token = _current_project_dir.set(encoded)
try:
result = _read_local(filepath, offset=0, limit=10)
assert result["isError"] is True
assert "not found" in result["content"][0]["text"].lower()
finally:
_current_project_dir.reset(token)
os.rmdir(tool_results_dir)
def test_read_traversal_path_blocked(self):
"""A traversal attempt that escapes allowed directories is blocked."""
result = _read_local("/tmp/copilot-abc/../../etc/shadow", offset=0, limit=10)
assert result["isError"] is True
assert "not allowed" in result["content"][0]["text"].lower()
def test_read_arbitrary_host_path_blocked(self):
"""Arbitrary host paths are blocked even if they exist."""
result = _read_local("/proc/self/environ", offset=0, limit=10)
assert result["isError"] is True
def test_read_with_offset_and_limit(self):
"""Offset and limit should control which lines are returned."""
encoded = "-tmp-copilot-e2b-test-offset"
content = "".join(f"line {i}\n" for i in range(10))
filepath = self._make_tool_results_file(encoded, "lines.txt", content)
token = _current_project_dir.set(encoded)
try:
result = _read_local(filepath, offset=3, limit=2)
assert result["isError"] is False
text = result["content"][0]["text"]
assert "line 3" in text
assert "line 4" in text
assert "line 2" not in text
assert "line 5" not in text
finally:
_current_project_dir.reset(token)
os.unlink(filepath)
def test_read_without_project_dir_blocks_all(self):
"""Without _current_project_dir set, all paths are blocked."""
result = _read_local("/tmp/anything.txt", offset=0, limit=10)
assert result["isError"] is True

View File

@@ -1,172 +0,0 @@
"""Tests for OTEL tracing setup in the SDK copilot path."""
import os
from unittest.mock import MagicMock, patch
class TestSetupLangfuseOtel:
"""Tests for _setup_langfuse_otel()."""
def test_noop_when_langfuse_not_configured(self):
"""No env vars should be set when Langfuse credentials are missing."""
with patch(
"backend.copilot.sdk.service._is_langfuse_configured", return_value=False
):
from backend.copilot.sdk.service import _setup_langfuse_otel
# Clear any previously set env vars
env_keys = [
"LANGSMITH_OTEL_ENABLED",
"LANGSMITH_OTEL_ONLY",
"LANGSMITH_TRACING",
"OTEL_EXPORTER_OTLP_ENDPOINT",
"OTEL_EXPORTER_OTLP_HEADERS",
]
saved = {k: os.environ.pop(k, None) for k in env_keys}
try:
_setup_langfuse_otel()
for key in env_keys:
assert key not in os.environ, f"{key} should not be set"
finally:
for k, v in saved.items():
if v is not None:
os.environ[k] = v
def test_sets_env_vars_when_langfuse_configured(self):
"""OTEL env vars should be set when Langfuse credentials exist."""
mock_settings = MagicMock()
mock_settings.secrets.langfuse_public_key = "pk-test-123"
mock_settings.secrets.langfuse_secret_key = "sk-test-456"
mock_settings.secrets.langfuse_host = "https://langfuse.example.com"
mock_settings.secrets.langfuse_tracing_environment = "test"
with (
patch(
"backend.copilot.sdk.service._is_langfuse_configured",
return_value=True,
),
patch("backend.copilot.sdk.service.Settings", return_value=mock_settings),
patch(
"backend.copilot.sdk.service.configure_claude_agent_sdk",
return_value=True,
) as mock_configure,
):
from backend.copilot.sdk.service import _setup_langfuse_otel
# Clear env vars so setdefault works
env_keys = [
"LANGSMITH_OTEL_ENABLED",
"LANGSMITH_OTEL_ONLY",
"LANGSMITH_TRACING",
"OTEL_EXPORTER_OTLP_ENDPOINT",
"OTEL_EXPORTER_OTLP_HEADERS",
"OTEL_RESOURCE_ATTRIBUTES",
]
saved = {k: os.environ.pop(k, None) for k in env_keys}
try:
_setup_langfuse_otel()
assert os.environ["LANGSMITH_OTEL_ENABLED"] == "true"
assert os.environ["LANGSMITH_OTEL_ONLY"] == "true"
assert os.environ["LANGSMITH_TRACING"] == "true"
assert (
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
== "https://langfuse.example.com/api/public/otel"
)
assert "Authorization=Basic" in os.environ["OTEL_EXPORTER_OTLP_HEADERS"]
assert (
os.environ["OTEL_RESOURCE_ATTRIBUTES"]
== "langfuse.environment=test"
)
mock_configure.assert_called_once_with(tags=["sdk"])
finally:
for k, v in saved.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]
def test_existing_env_vars_not_overwritten(self):
"""Explicit env-var overrides should not be clobbered."""
mock_settings = MagicMock()
mock_settings.secrets.langfuse_public_key = "pk-test"
mock_settings.secrets.langfuse_secret_key = "sk-test"
mock_settings.secrets.langfuse_host = "https://langfuse.example.com"
with (
patch(
"backend.copilot.sdk.service._is_langfuse_configured",
return_value=True,
),
patch("backend.copilot.sdk.service.Settings", return_value=mock_settings),
patch(
"backend.copilot.sdk.service.configure_claude_agent_sdk",
return_value=True,
),
):
from backend.copilot.sdk.service import _setup_langfuse_otel
saved = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
try:
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "https://custom.endpoint/v1"
_setup_langfuse_otel()
assert (
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
== "https://custom.endpoint/v1"
)
finally:
if saved is not None:
os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = saved
elif "OTEL_EXPORTER_OTLP_ENDPOINT" in os.environ:
del os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
def test_graceful_failure_on_exception(self):
"""Setup should not raise even if internal code fails."""
with (
patch(
"backend.copilot.sdk.service._is_langfuse_configured",
return_value=True,
),
patch(
"backend.copilot.sdk.service.Settings",
side_effect=RuntimeError("settings unavailable"),
),
):
from backend.copilot.sdk.service import _setup_langfuse_otel
# Should not raise — just logs and returns
_setup_langfuse_otel()
class TestPropagateAttributesImport:
"""Verify langfuse.propagate_attributes is available."""
def test_propagate_attributes_is_importable(self):
from langfuse import propagate_attributes
assert callable(propagate_attributes)
def test_propagate_attributes_returns_context_manager(self):
from langfuse import propagate_attributes
ctx = propagate_attributes(user_id="u1", session_id="s1", tags=["test"])
assert hasattr(ctx, "__enter__")
assert hasattr(ctx, "__exit__")
class TestReceiveResponseCompat:
"""Verify ClaudeSDKClient.receive_response() exists (langsmith patches it)."""
def test_receive_response_exists(self):
from claude_agent_sdk import ClaudeSDKClient
assert hasattr(ClaudeSDKClient, "receive_response")
def test_receive_response_is_async_generator(self):
import inspect
from claude_agent_sdk import ClaudeSDKClient
method = getattr(ClaudeSDKClient, "receive_response")
assert inspect.isfunction(method) or inspect.ismethod(method)

View File

@@ -1,255 +0,0 @@
"""Tests for _format_conversation_context and _build_query_message."""
from datetime import UTC, datetime
import pytest
from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.sdk.service import (
_build_query_message,
_format_conversation_context,
)
# ---------------------------------------------------------------------------
# _format_conversation_context
# ---------------------------------------------------------------------------
def test_format_empty_list():
assert _format_conversation_context([]) is None
def test_format_none_content_messages():
msgs = [ChatMessage(role="user", content=None)]
assert _format_conversation_context(msgs) is None
def test_format_user_message():
msgs = [ChatMessage(role="user", content="hello")]
result = _format_conversation_context(msgs)
assert result is not None
assert "User: hello" in result
assert result.startswith("<conversation_history>")
assert result.endswith("</conversation_history>")
def test_format_assistant_text():
msgs = [ChatMessage(role="assistant", content="hi there")]
result = _format_conversation_context(msgs)
assert result is not None
assert "You responded: hi there" in result
def test_format_assistant_tool_calls():
msgs = [
ChatMessage(
role="assistant",
content=None,
tool_calls=[{"function": {"name": "search", "arguments": '{"q": "test"}'}}],
)
]
result = _format_conversation_context(msgs)
assert result is not None
assert 'You called tool: search({"q": "test"})' in result
def test_format_tool_result():
msgs = [ChatMessage(role="tool", content='{"result": "ok"}')]
result = _format_conversation_context(msgs)
assert result is not None
assert 'Tool result: {"result": "ok"}' in result
def test_format_tool_result_none_content():
msgs = [ChatMessage(role="tool", content=None)]
result = _format_conversation_context(msgs)
assert result is not None
assert "Tool result: " in result
def test_format_full_conversation():
msgs = [
ChatMessage(role="user", content="find agents"),
ChatMessage(
role="assistant",
content="I'll search for agents.",
tool_calls=[
{"function": {"name": "find_agents", "arguments": '{"q": "test"}'}}
],
),
ChatMessage(role="tool", content='[{"id": "1", "name": "Agent1"}]'),
ChatMessage(role="assistant", content="Found Agent1."),
]
result = _format_conversation_context(msgs)
assert result is not None
assert "User: find agents" in result
assert "You responded: I'll search for agents." in result
assert "You called tool: find_agents" in result
assert "Tool result:" in result
assert "You responded: Found Agent1." in result
# ---------------------------------------------------------------------------
# _build_query_message
# ---------------------------------------------------------------------------
def _make_session(messages: list[ChatMessage]) -> ChatSession:
"""Build a minimal ChatSession with the given messages."""
now = datetime.now(UTC)
return ChatSession(
session_id="test-session",
user_id="user-1",
messages=messages,
title="test",
usage=[],
started_at=now,
updated_at=now,
)
@pytest.mark.asyncio
async def test_build_query_resume_up_to_date():
"""With --resume and transcript covers all messages, return raw message."""
session = _make_session(
[
ChatMessage(role="user", content="hello"),
ChatMessage(role="assistant", content="hi"),
ChatMessage(role="user", content="what's new?"),
]
)
result, was_compacted = await _build_query_message(
"what's new?",
session,
use_resume=True,
transcript_msg_count=2,
session_id="test-session",
)
# transcript_msg_count == msg_count - 1, so no gap
assert result == "what's new?"
assert was_compacted is False
@pytest.mark.asyncio
async def test_build_query_resume_stale_transcript():
"""With --resume and stale transcript, gap context is prepended."""
session = _make_session(
[
ChatMessage(role="user", content="turn 1"),
ChatMessage(role="assistant", content="reply 1"),
ChatMessage(role="user", content="turn 2"),
ChatMessage(role="assistant", content="reply 2"),
ChatMessage(role="user", content="turn 3"),
]
)
result, was_compacted = await _build_query_message(
"turn 3",
session,
use_resume=True,
transcript_msg_count=2,
session_id="test-session",
)
assert "<conversation_history>" in result
assert "turn 2" in result
assert "reply 2" in result
assert "Now, the user says:\nturn 3" in result
assert was_compacted is False # gap context does not compact
@pytest.mark.asyncio
async def test_build_query_resume_zero_msg_count():
"""With --resume but transcript_msg_count=0, return raw message."""
session = _make_session(
[
ChatMessage(role="user", content="hello"),
ChatMessage(role="assistant", content="hi"),
ChatMessage(role="user", content="new msg"),
]
)
result, was_compacted = await _build_query_message(
"new msg",
session,
use_resume=True,
transcript_msg_count=0,
session_id="test-session",
)
assert result == "new msg"
assert was_compacted is False
@pytest.mark.asyncio
async def test_build_query_no_resume_single_message():
"""Without --resume and only 1 message, return raw message."""
session = _make_session([ChatMessage(role="user", content="first")])
result, was_compacted = await _build_query_message(
"first",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
)
assert result == "first"
assert was_compacted is False
@pytest.mark.asyncio
async def test_build_query_no_resume_multi_message(monkeypatch):
"""Without --resume and multiple messages, compress and prepend."""
session = _make_session(
[
ChatMessage(role="user", content="older question"),
ChatMessage(role="assistant", content="older answer"),
ChatMessage(role="user", content="new question"),
]
)
# Mock _compress_messages to return the messages as-is
async def _mock_compress(msgs):
return msgs, False
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages",
_mock_compress,
)
result, was_compacted = await _build_query_message(
"new question",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
)
assert "<conversation_history>" in result
assert "older question" in result
assert "older answer" in result
assert "Now, the user says:\nnew question" in result
assert was_compacted is False # mock returns False
@pytest.mark.asyncio
async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
"""When compression actually compacts, was_compacted should be True."""
session = _make_session(
[
ChatMessage(role="user", content="old"),
ChatMessage(role="assistant", content="reply"),
ChatMessage(role="user", content="new"),
]
)
async def _mock_compress(msgs):
return msgs, True # Simulate actual compaction
monkeypatch.setattr(
"backend.copilot.sdk.service._compress_messages",
_mock_compress,
)
result, was_compacted = await _build_query_message(
"new",
session,
use_resume=False,
transcript_msg_count=0,
session_id="test-session",
)
assert was_compacted is True

View File

@@ -47,20 +47,19 @@ class SDKResponseAdapter:
text blocks, tool calls, and message lifecycle.
"""
def __init__(self, message_id: str | None = None, session_id: str | None = None):
def __init__(self, message_id: str | None = None):
self.message_id = message_id or str(uuid.uuid4())
self.session_id = session_id
self.text_block_id = str(uuid.uuid4())
self.has_started_text = False
self.has_ended_text = False
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.resolved_tool_calls: set[str] = set()
self.task_id: str | None = None
self.step_open = False
@property
def has_unresolved_tool_calls(self) -> bool:
"""True when there are tool calls that haven't received output yet."""
return bool(self.current_tool_calls.keys() - self.resolved_tool_calls)
def set_task_id(self, task_id: str) -> None:
"""Set the task ID for reconnection support."""
self.task_id = task_id
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
"""Convert a single SDK message to Vercel AI SDK format."""
@@ -69,7 +68,7 @@ class SDKResponseAdapter:
if isinstance(sdk_message, SystemMessage):
if sdk_message.subtype == "init":
responses.append(
StreamStart(messageId=self.message_id, sessionId=self.session_id)
StreamStart(messageId=self.message_id, taskId=self.task_id)
)
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
responses.append(StreamStartStep())
@@ -78,12 +77,7 @@ class SDKResponseAdapter:
elif isinstance(sdk_message, AssistantMessage):
# Flush any SDK built-in tool calls that didn't get a UserMessage
# result (e.g. WebSearch, Read handled internally by the CLI).
# BUT skip flush when this AssistantMessage is a parallel tool
# continuation (contains only ToolUseBlocks) — the prior tools
# are still executing concurrently and haven't finished yet.
is_tool_only = all(isinstance(b, ToolUseBlock) for b in sdk_message.content)
if not is_tool_only:
self._flush_unresolved_tool_calls(responses)
self._flush_unresolved_tool_calls(responses)
# After tool results, the SDK sends a new AssistantMessage for the
# next LLM turn. Open a new step if the previous one was closed.
@@ -124,24 +118,8 @@ class SDKResponseAdapter:
blocks = content if isinstance(content, list) else []
resolved_in_blocks: set[str] = set()
sid = (self.session_id or "?")[:12]
parent_id_preview = getattr(sdk_message, "parent_tool_use_id", None)
logger.info(
"[SDK] [%s] UserMessage: %d blocks, content_type=%s, "
"parent_tool_use_id=%s",
sid,
len(blocks),
type(content).__name__,
parent_id_preview[:12] if parent_id_preview else "None",
)
for block in blocks:
if isinstance(block, ToolResultBlock) and block.tool_use_id:
# Skip if already resolved (e.g. by flush) — the real
# result supersedes the empty flush, but re-emitting
# would confuse the frontend's state machine.
if block.tool_use_id in self.resolved_tool_calls:
continue
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
tool_name = tool_info.get("name", "unknown")
@@ -166,11 +144,7 @@ class SDKResponseAdapter:
# Handle SDK built-in tool results carried via parent_tool_use_id
# instead of (or in addition to) ToolResultBlock content.
parent_id = sdk_message.parent_tool_use_id
if (
parent_id
and parent_id not in resolved_in_blocks
and parent_id not in self.resolved_tool_calls
):
if parent_id and parent_id not in resolved_in_blocks:
tool_info = self.current_tool_calls.get(parent_id, {})
tool_name = tool_info.get("name", "unknown")
@@ -254,28 +228,11 @@ class SDKResponseAdapter:
output, which we pop and emit here before the next ``AssistantMessage``
starts.
"""
unresolved = [
(tid, info.get("name", "unknown"))
for tid, info in self.current_tool_calls.items()
if tid not in self.resolved_tool_calls
]
sid = (self.session_id or "?")[:12]
if not unresolved:
logger.info(
"[SDK] [%s] Flush called but all %d tool(s) already resolved",
sid,
len(self.current_tool_calls),
)
return
logger.info(
"[SDK] [%s] Flushing %d unresolved tool call(s): %s",
sid,
len(unresolved),
", ".join(f"{name}({tid[:12]})" for tid, name in unresolved),
)
flushed = False
for tool_id, tool_name in unresolved:
for tool_id, tool_info in self.current_tool_calls.items():
if tool_id in self.resolved_tool_calls:
continue
tool_name = tool_info.get("name", "unknown")
output = pop_pending_tool_output(tool_name)
if output is not None:
responses.append(
@@ -288,12 +245,9 @@ class SDKResponseAdapter:
)
self.resolved_tool_calls.add(tool_id)
flushed = True
logger.info(
"[SDK] [%s] Flushed stashed output for %s " "(call %s, %d chars)",
sid,
tool_name,
tool_id[:12],
len(output),
logger.debug(
f"Flushed pending output for built-in tool {tool_name} "
f"(call {tool_id})"
)
else:
# No output available — emit an empty output so the frontend
@@ -309,14 +263,9 @@ class SDKResponseAdapter:
)
self.resolved_tool_calls.add(tool_id)
flushed = True
logger.warning(
"[SDK] [%s] Flushed EMPTY output for unresolved tool %s "
"(call %s) — stash was empty (likely SDK hook race "
"condition: PostToolUse hook hadn't completed before "
"flush was triggered)",
sid,
tool_name,
tool_id[:12],
logger.debug(
f"Flushed empty output for unresolved tool {tool_name} "
f"(call {tool_id})"
)
if flushed and self.step_open:

View File

@@ -1,8 +1,5 @@
"""Unit tests for the SDK response adapter."""
import asyncio
import pytest
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
@@ -30,14 +27,12 @@ from backend.copilot.response_model import (
from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX
from .tool_adapter import _pending_tool_outputs as _pto
from .tool_adapter import _stash_event
from .tool_adapter import stash_pending_tool_output as _stash
from .tool_adapter import wait_for_stash
def _adapter() -> SDKResponseAdapter:
return SDKResponseAdapter(message_id="msg-1", session_id="session-1")
a = SDKResponseAdapter(message_id="msg-1")
a.set_task_id("task-1")
return a
# -- SystemMessage -----------------------------------------------------------
@@ -49,7 +44,7 @@ def test_system_init_emits_start_and_step():
assert len(results) == 2
assert isinstance(results[0], StreamStart)
assert results[0].messageId == "msg-1"
assert results[0].sessionId == "session-1"
assert results[0].taskId == "task-1"
assert isinstance(results[1], StreamStartStep)
@@ -369,310 +364,3 @@ def test_full_conversation_flow():
"StreamFinishStep", # step 2 closed
"StreamFinish",
]
# -- Flush unresolved tool calls --------------------------------------------
def test_flush_unresolved_at_result_message():
"""Built-in tools (WebSearch) without UserMessage results get flushed at ResultMessage."""
adapter = _adapter()
all_responses: list[StreamBaseResponse] = []
# 1. Init
all_responses.extend(
adapter.convert_message(SystemMessage(subtype="init", data={}))
)
# 2. Tool use (built-in tool — no MCP prefix)
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
],
model="test",
)
)
)
# 3. No UserMessage for this tool — go straight to ResultMessage
all_responses.extend(
adapter.convert_message(
ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
)
)
types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamStartStep",
"StreamToolInputStart",
"StreamToolInputAvailable",
"StreamToolOutputAvailable", # flushed with empty output
"StreamFinishStep", # step closed by flush
"StreamFinish",
]
# The flushed output should be empty (no stash available)
output_event = [
r for r in all_responses if isinstance(r, StreamToolOutputAvailable)
][0]
assert output_event.toolCallId == "ws-1"
assert output_event.toolName == "WebSearch"
assert output_event.output == ""
def test_flush_unresolved_at_next_assistant_message():
"""Built-in tools get flushed when the next AssistantMessage arrives."""
adapter = _adapter()
all_responses: list[StreamBaseResponse] = []
# 1. Init
all_responses.extend(
adapter.convert_message(SystemMessage(subtype="init", data={}))
)
# 2. Tool use (built-in — no UserMessage will come)
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
],
model="test",
)
)
)
# 3. Next AssistantMessage triggers flush before processing its blocks
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[TextBlock(text="Here are the results")], model="test"
)
)
)
types = [type(r).__name__ for r in all_responses]
assert types == [
"StreamStart",
"StreamStartStep",
"StreamToolInputStart",
"StreamToolInputAvailable",
# Flush at next AssistantMessage:
"StreamToolOutputAvailable",
"StreamFinishStep", # step closed by flush
# New step for continuation text:
"StreamStartStep",
"StreamTextStart",
"StreamTextDelta",
]
def test_flush_with_stashed_output():
"""Stashed output from PostToolUse hook is used when flushing."""
adapter = _adapter()
# Simulate PostToolUse hook stashing output
_pto.set({})
_stash("WebSearch", "Search result: 5 items found")
all_responses: list[StreamBaseResponse] = []
# Tool use
all_responses.extend(
adapter.convert_message(
AssistantMessage(
content=[
ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
],
model="test",
)
)
)
# ResultMessage triggers flush
all_responses.extend(
adapter.convert_message(
ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
)
)
output_events = [
r for r in all_responses if isinstance(r, StreamToolOutputAvailable)
]
assert len(output_events) == 1
assert output_events[0].output == "Search result: 5 items found"
# Cleanup
_pto.set({}) # type: ignore[arg-type]
# -- wait_for_stash synchronisation tests --
@pytest.mark.asyncio
async def test_wait_for_stash_signaled():
"""wait_for_stash returns True when stash_pending_tool_output signals."""
_pto.set({})
event = asyncio.Event()
_stash_event.set(event)
# Simulate a PostToolUse hook that stashes output after a short delay
async def delayed_stash():
await asyncio.sleep(0.01)
_stash("WebSearch", "result data")
asyncio.create_task(delayed_stash())
result = await wait_for_stash(timeout=1.0)
assert result is True
assert _pto.get({}).get("WebSearch") == ["result data"]
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)
@pytest.mark.asyncio
async def test_wait_for_stash_timeout():
"""wait_for_stash returns False on timeout when no stash occurs."""
_pto.set({})
event = asyncio.Event()
_stash_event.set(event)
result = await wait_for_stash(timeout=0.05)
assert result is False
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)
@pytest.mark.asyncio
async def test_wait_for_stash_already_stashed():
"""wait_for_stash picks up a stash that happened just before the wait."""
_pto.set({})
event = asyncio.Event()
_stash_event.set(event)
# Stash before waiting — simulates hook completing before message arrives
_stash("Read", "file contents")
# Event is now set; wait_for_stash detects the fast path and returns
# immediately without timing out.
result = await wait_for_stash(timeout=0.05)
assert result is True
# But the stash itself is populated
assert _pto.get({}).get("Read") == ["file contents"]
# Cleanup
_pto.set({}) # type: ignore[arg-type]
_stash_event.set(None)
# -- Parallel tool call tests --
def test_parallel_tool_calls_not_flushed_prematurely():
"""Parallel tool calls should NOT be flushed when the next AssistantMessage
only contains ToolUseBlocks (parallel continuation)."""
adapter = SDKResponseAdapter()
# Init
adapter.convert_message(SystemMessage(subtype="init", data={}))
# First AssistantMessage: tool call #1
msg1 = AssistantMessage(
content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
model="test",
)
r1 = adapter.convert_message(msg1)
assert any(isinstance(r, StreamToolInputAvailable) for r in r1)
assert adapter.has_unresolved_tool_calls
# Second AssistantMessage: tool call #2 (parallel continuation)
msg2 = AssistantMessage(
content=[ToolUseBlock(id="t2", name="WebSearch", input={"q": "bar"})],
model="test",
)
r2 = adapter.convert_message(msg2)
# No flush should have happened — t1 should NOT have StreamToolOutputAvailable
output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
assert len(output_events) == 0, (
f"Tool-only AssistantMessage should not flush prior tools, "
f"but got {len(output_events)} output events"
)
# Both t1 and t2 should still be unresolved
assert "t1" not in adapter.resolved_tool_calls
assert "t2" not in adapter.resolved_tool_calls
def test_text_assistant_message_flushes_prior_tools():
"""An AssistantMessage with text (new turn) should flush unresolved tools."""
adapter = SDKResponseAdapter()
# Init
adapter.convert_message(SystemMessage(subtype="init", data={}))
# Tool call
msg1 = AssistantMessage(
content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
model="test",
)
adapter.convert_message(msg1)
assert adapter.has_unresolved_tool_calls
# Text AssistantMessage (new turn after tools completed)
msg2 = AssistantMessage(
content=[TextBlock(text="Here are the results")],
model="test",
)
r2 = adapter.convert_message(msg2)
# Flush SHOULD have happened — t1 gets empty output
output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
assert len(output_events) == 1
assert output_events[0].toolCallId == "t1"
assert "t1" in adapter.resolved_tool_calls
def test_already_resolved_tool_skipped_in_user_message():
"""A tool result in UserMessage should be skipped if already resolved by flush."""
adapter = SDKResponseAdapter()
adapter.convert_message(SystemMessage(subtype="init", data={}))
# Tool call + flush via text message
adapter.convert_message(
AssistantMessage(
content=[ToolUseBlock(id="t1", name="WebSearch", input={})],
model="test",
)
)
adapter.convert_message(
AssistantMessage(
content=[TextBlock(text="Done")],
model="test",
)
)
assert "t1" in adapter.resolved_tool_calls
# Now UserMessage arrives with the real result — should be skipped
user_msg = UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="real")])
r = adapter.convert_message(user_msg)
output_events = [r_ for r_ in r if isinstance(r_, StreamToolOutputAvailable)]
assert (
len(output_events) == 0
), "Already-resolved tool should not emit duplicate output"

View File

@@ -1,194 +0,0 @@
"""SDK compatibility tests — verify the claude-agent-sdk public API surface we depend on.
Instead of pinning to a narrow version range, these tests verify that the
installed SDK exposes every class, function, attribute, and method the copilot
integration relies on. If an SDK upgrade removes or renames something these
tests will catch it immediately.
"""
import inspect
import pytest
# ---------------------------------------------------------------------------
# Public types & factories
# ---------------------------------------------------------------------------
def test_sdk_exports_client_and_options():
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
assert inspect.isclass(ClaudeSDKClient)
assert inspect.isclass(ClaudeAgentOptions)
def test_sdk_exports_message_types():
from claude_agent_sdk import (
AssistantMessage,
Message,
ResultMessage,
SystemMessage,
UserMessage,
)
for cls in (AssistantMessage, ResultMessage, SystemMessage, UserMessage):
assert inspect.isclass(cls), f"{cls.__name__} is not a class"
# Message is a Union type alias, just verify it's importable
assert Message is not None
def test_sdk_exports_content_block_types():
from claude_agent_sdk import TextBlock, ToolResultBlock, ToolUseBlock
for cls in (TextBlock, ToolResultBlock, ToolUseBlock):
assert inspect.isclass(cls), f"{cls.__name__} is not a class"
def test_sdk_exports_mcp_helpers():
from claude_agent_sdk import create_sdk_mcp_server, tool
assert callable(create_sdk_mcp_server)
assert callable(tool)
# ---------------------------------------------------------------------------
# ClaudeSDKClient interface
# ---------------------------------------------------------------------------
def test_client_has_required_methods():
from claude_agent_sdk import ClaudeSDKClient
required = ["connect", "disconnect", "query", "receive_messages"]
for name in required:
attr = getattr(ClaudeSDKClient, name, None)
assert attr is not None, f"ClaudeSDKClient.{name} missing"
assert callable(attr), f"ClaudeSDKClient.{name} is not callable"
def test_client_supports_async_context_manager():
from claude_agent_sdk import ClaudeSDKClient
assert hasattr(ClaudeSDKClient, "__aenter__")
assert hasattr(ClaudeSDKClient, "__aexit__")
# ---------------------------------------------------------------------------
# ClaudeAgentOptions fields
# ---------------------------------------------------------------------------
def test_agent_options_accepts_required_fields():
"""Verify ClaudeAgentOptions accepts all kwargs our code passes."""
from claude_agent_sdk import ClaudeAgentOptions
opts = ClaudeAgentOptions(
system_prompt="test",
cwd="/tmp",
)
assert opts.system_prompt == "test"
assert opts.cwd == "/tmp"
def test_agent_options_accepts_all_our_fields():
"""Comprehensive check of every field we use in service.py."""
from claude_agent_sdk import ClaudeAgentOptions
fields_we_use = [
"system_prompt",
"mcp_servers",
"allowed_tools",
"disallowed_tools",
"hooks",
"cwd",
"model",
"env",
"resume",
"max_buffer_size",
]
sig = inspect.signature(ClaudeAgentOptions)
for field in fields_we_use:
assert field in sig.parameters, (
f"ClaudeAgentOptions no longer accepts '{field}'"
f"available params: {list(sig.parameters.keys())}"
)
# ---------------------------------------------------------------------------
# Message attributes
# ---------------------------------------------------------------------------
def test_assistant_message_has_content_and_model():
from claude_agent_sdk import AssistantMessage, TextBlock
msg = AssistantMessage(content=[TextBlock(text="hi")], model="test")
assert hasattr(msg, "content")
assert hasattr(msg, "model")
def test_result_message_has_required_attrs():
from claude_agent_sdk import ResultMessage
msg = ResultMessage(
subtype="success",
duration_ms=100,
duration_api_ms=50,
is_error=False,
num_turns=1,
session_id="s1",
)
assert msg.subtype == "success"
assert hasattr(msg, "result")
def test_system_message_has_subtype_and_data():
from claude_agent_sdk import SystemMessage
msg = SystemMessage(subtype="init", data={})
assert msg.subtype == "init"
assert msg.data == {}
def test_user_message_has_parent_tool_use_id():
from claude_agent_sdk import UserMessage
msg = UserMessage(content="test")
assert hasattr(msg, "parent_tool_use_id")
assert hasattr(msg, "tool_use_result")
def test_tool_use_block_has_id_name_input():
from claude_agent_sdk import ToolUseBlock
block = ToolUseBlock(id="t1", name="test", input={"key": "val"})
assert block.id == "t1"
assert block.name == "test"
assert block.input == {"key": "val"}
def test_tool_result_block_has_required_attrs():
from claude_agent_sdk import ToolResultBlock
block = ToolResultBlock(tool_use_id="t1", content="result")
assert block.tool_use_id == "t1"
assert block.content == "result"
assert hasattr(block, "is_error")
# ---------------------------------------------------------------------------
# Hook types
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"hook_event",
["PreToolUse", "PostToolUse", "Stop"],
)
def test_sdk_exports_hook_event_type(hook_event: str):
"""Verify HookEvent literal includes the events our security_hooks use."""
from claude_agent_sdk.types import HookEvent
# HookEvent is a Literal type — check that our events are valid values.
# We can't easily inspect Literal at runtime, so just verify the type exists.
assert HookEvent is not None

View File

@@ -6,6 +6,7 @@ ensuring multi-user isolation and preventing unauthorized operations.
import json
import logging
import os
import re
from collections.abc import Callable
from typing import Any, cast
@@ -15,7 +16,6 @@ from .tool_adapter import (
DANGEROUS_PATTERNS,
MCP_TOOL_PREFIX,
WORKSPACE_SCOPED_TOOLS,
is_allowed_local_path,
stash_pending_tool_output,
)
@@ -38,20 +38,40 @@ def _validate_workspace_path(
) -> dict[str, Any]:
"""Validate that a workspace-scoped tool only accesses allowed paths.
Delegates to :func:`is_allowed_local_path` which permits:
Allowed directories:
- The SDK working directory (``/tmp/copilot-<session>/``)
- The current session's tool-results directory
(``~/.claude/projects/<encoded-cwd>/tool-results/``)
- The SDK tool-results directory (``~/.claude/projects/…/tool-results/``)
"""
path = tool_input.get("file_path") or tool_input.get("path") or ""
if not path:
# Glob/Grep without a path default to cwd which is already sandboxed
return {}
if is_allowed_local_path(path, sdk_cwd):
# Resolve relative paths against sdk_cwd (the SDK sets cwd so the LLM
# naturally uses relative paths like "test.txt" instead of absolute ones).
# Tilde paths (~/) are home-dir references, not relative — expand first.
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved.startswith(norm_cwd + os.sep) or resolved == norm_cwd:
return {}
# Allow access to ~/.claude/projects/*/tool-results/ (big tool results)
claude_dir = os.path.realpath(os.path.expanduser("~/.claude/projects"))
tool_results_seg = os.sep + "tool-results" + os.sep
if resolved.startswith(claude_dir + os.sep) and tool_results_seg in resolved:
return {}
logger.warning(f"Blocked {tool_name} outside workspace: {path}")
logger.warning(
f"Blocked {tool_name} outside workspace: {path} (resolved={resolved})"
)
workspace_hint = f" Allowed workspace: {sdk_cwd}" if sdk_cwd else ""
return _deny(
f"[SECURITY] Tool '{tool_name}' can only access files within the workspace "
@@ -104,20 +124,20 @@ def _validate_user_isolation(
"""Validate that tool calls respect user isolation."""
# For workspace file tools, ensure path doesn't escape
if "workspace" in tool_name.lower():
# The "path" param is a cloud storage key (e.g. "/ASEAN/report.md")
# where a leading "/" is normal. Only check for ".." traversal.
# Filesystem paths (source_path, save_to_path) are validated inside
# the tool itself via _validate_ephemeral_path.
path = tool_input.get("path", "") or tool_input.get("file_path", "")
if path and ".." in path:
logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": "Path traversal not allowed",
if path:
# Check for path traversal
if ".." in path or path.startswith("/"):
logger.warning(
f"Blocked path traversal attempt: {path} by user {user_id}"
)
return {
"hookSpecificOutput": {
"hookEventName": "PreToolUse",
"permissionDecision": "deny",
"permissionDecisionReason": "Path traversal not allowed",
}
}
}
return {}
@@ -126,7 +146,6 @@ def create_security_hooks(
user_id: str | None,
sdk_cwd: str | None = None,
max_subtasks: int = 3,
on_compact: Callable[[], None] | None = None,
on_stop: Callable[[str, str], None] | None = None,
) -> dict[str, Any]:
"""Create the security hooks configuration for Claude Agent SDK.
@@ -141,7 +160,7 @@ def create_security_hooks(
Args:
user_id: Current user ID for isolation validation
sdk_cwd: SDK working directory for workspace-scoped tool validation
max_subtasks: Maximum concurrent Task (sub-agent) spawns allowed per session
max_subtasks: Maximum Task (sub-agent) spawns allowed per session
on_stop: Callback ``(transcript_path, sdk_session_id)`` invoked when
the SDK finishes processing — used to read the JSONL transcript
before the CLI process exits.
@@ -153,9 +172,8 @@ def create_security_hooks(
from claude_agent_sdk import HookMatcher
from claude_agent_sdk.types import HookContext, HookInput, SyncHookJSONOutput
# Per-session tracking for Task sub-agent concurrency.
# Set of tool_use_ids that consumed a slot — len() is the active count.
task_tool_use_ids: set[str] = set()
# Per-session counter for Task sub-agent spawns
task_spawn_count = 0
async def pre_tool_use_hook(
input_data: HookInput,
@@ -163,34 +181,23 @@ def create_security_hooks(
context: HookContext,
) -> SyncHookJSONOutput:
"""Combined pre-tool-use validation hook."""
nonlocal task_spawn_count
_ = context # unused but required by signature
tool_name = cast(str, input_data.get("tool_name", ""))
tool_input = cast(dict[str, Any], input_data.get("tool_input", {}))
# Rate-limit Task (sub-agent) spawns per session
if tool_name == "Task":
# Block background task execution first — denied calls
# should not consume a subtask slot.
if tool_input.get("run_in_background"):
logger.info(f"[SDK] Blocked background Task, user={user_id}")
return cast(
SyncHookJSONOutput,
_deny(
"Background task execution is not supported. "
"Run tasks in the foreground instead "
"(remove the run_in_background parameter)."
),
)
if len(task_tool_use_ids) >= max_subtasks:
task_spawn_count += 1
if task_spawn_count > max_subtasks:
logger.warning(
f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
)
return cast(
SyncHookJSONOutput,
_deny(
f"Maximum {max_subtasks} concurrent sub-tasks. "
"Wait for running sub-tasks to finish, "
"or continue in the main conversation."
f"Maximum {max_subtasks} sub-tasks per session. "
"Please continue in the main conversation."
),
)
@@ -210,24 +217,9 @@ def create_security_hooks(
if result:
return cast(SyncHookJSONOutput, result)
# Reserve the Task slot only after all validations pass
if tool_name == "Task" and tool_use_id is not None:
task_tool_use_ids.add(tool_use_id)
logger.debug(f"[SDK] Tool start: {tool_name}, user={user_id}")
return cast(SyncHookJSONOutput, {})
def _release_task_slot(tool_name: str, tool_use_id: str | None) -> None:
"""Release a Task concurrency slot if one was reserved."""
if tool_name == "Task" and tool_use_id in task_tool_use_ids:
task_tool_use_ids.discard(tool_use_id)
logger.info(
"[SDK] Task slot released, active=%d/%d, user=%s",
len(task_tool_use_ids),
max_subtasks,
user_id,
)
async def post_tool_use_hook(
input_data: HookInput,
tool_use_id: str | None,
@@ -242,35 +234,15 @@ def create_security_hooks(
"""
_ = context
tool_name = cast(str, input_data.get("tool_name", ""))
_release_task_slot(tool_name, tool_use_id)
is_builtin = not tool_name.startswith(MCP_TOOL_PREFIX)
logger.info(
"[SDK] PostToolUse: %s (builtin=%s, tool_use_id=%s)",
tool_name,
is_builtin,
(tool_use_id or "")[:12],
)
logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
# Stash output for SDK built-in tools so the response adapter can
# emit StreamToolOutputAvailable even when the CLI doesn't surface
# a separate UserMessage with ToolResultBlock content.
if is_builtin:
if not tool_name.startswith(MCP_TOOL_PREFIX):
tool_response = input_data.get("tool_response")
if tool_response is not None:
resp_preview = str(tool_response)[:100]
logger.info(
"[SDK] Stashing builtin output for %s (%d chars): %s...",
tool_name,
len(str(tool_response)),
resp_preview,
)
stash_pending_tool_output(tool_name, tool_response)
else:
logger.warning(
"[SDK] PostToolUse for builtin %s but tool_response is None",
tool_name,
)
return cast(SyncHookJSONOutput, {})
@@ -287,9 +259,6 @@ def create_security_hooks(
f"[SDK] Tool failed: {tool_name}, error={error}, "
f"user={user_id}, tool_use_id={tool_use_id}"
)
_release_task_slot(tool_name, tool_use_id)
return cast(SyncHookJSONOutput, {})
async def pre_compact_hook(
@@ -307,8 +276,6 @@ def create_security_hooks(
logger.info(
f"[SDK] Context compaction triggered: {trigger}, user={user_id}"
)
if on_compact is not None:
on_compact()
return cast(SyncHookJSONOutput, {})
# --- Stop hook: capture transcript path for stateless resume ---

View File

@@ -7,23 +7,11 @@ tool access, and dangerous input patterns.
import os
import pytest
from .security_hooks import _validate_tool_access, _validate_user_isolation
from .service import _is_tool_error_or_denial
SDK_CWD = "/tmp/copilot-abc123"
def _sdk_available() -> bool:
try:
import claude_agent_sdk # noqa: F401
return True
except ImportError:
return False
def _is_denied(result: dict) -> bool:
hook = result.get("hookSpecificOutput", {})
return hook.get("permissionDecision") == "deny"
@@ -120,31 +108,17 @@ def test_read_no_cwd_denies_absolute():
def test_read_tool_results_allowed():
from .tool_adapter import _current_project_dir
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
# is_allowed_local_path requires the session's encoded cwd to be set
token = _current_project_dir.set("-tmp-copilot-abc123")
try:
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
finally:
_current_project_dir.reset(token)
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert result == {}
def test_read_claude_projects_session_dir_allowed():
"""Files within the current session's project dir are allowed."""
from .tool_adapter import _current_project_dir
def test_read_claude_projects_without_tool_results_denied():
home = os.path.expanduser("~")
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
token = _current_project_dir.set("-tmp-copilot-abc123")
try:
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert not _is_denied(result)
finally:
_current_project_dir.reset(token)
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
assert _is_denied(result)
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
@@ -179,12 +153,11 @@ def test_workspace_path_traversal_blocked():
assert _is_denied(result)
def test_workspace_absolute_path_allowed():
"""Workspace 'path' is a cloud storage key — leading '/' is normal."""
def test_workspace_absolute_path_blocked():
result = _validate_user_isolation(
"workspace_read", {"path": "/ASEAN/report.md"}, user_id="user-1"
"workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
)
assert result == {}
assert _is_denied(result)
def test_workspace_normal_path_allowed():
@@ -215,218 +188,3 @@ def test_bash_builtin_blocked_message_clarity():
reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"}))
assert "[SECURITY]" in reason
assert "cannot be bypassed" in reason
# -- Task sub-agent hooks (require SDK) --------------------------------------
@pytest.fixture()
def _hooks():
"""Create security hooks and return (pre, post, post_failure) handlers."""
from .security_hooks import create_security_hooks
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
pre = hooks["PreToolUse"][0].hooks[0]
post = hooks["PostToolUse"][0].hooks[0]
post_failure = hooks["PostToolUseFailure"][0].hooks[0]
return pre, post, post_failure
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_task_background_blocked(_hooks):
"""Task with run_in_background=true must be denied."""
pre, _, _ = _hooks
result = await pre(
{"tool_name": "Task", "tool_input": {"run_in_background": True, "prompt": "x"}},
tool_use_id=None,
context={},
)
assert _is_denied(result)
assert "foreground" in _reason(result).lower()
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_task_foreground_allowed(_hooks):
"""Task without run_in_background should be allowed."""
pre, _, _ = _hooks
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "do stuff"}},
tool_use_id="tu-1",
context={},
)
assert not _is_denied(result)
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_task_limit_enforced(_hooks):
"""Task spawns beyond max_subtasks should be denied."""
pre, _, _ = _hooks
# First two should pass
for i in range(2):
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
tool_use_id=f"tu-limit-{i}",
context={},
)
assert not _is_denied(result)
# Third should be denied (limit=2)
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "over limit"}},
tool_use_id="tu-limit-2",
context={},
)
assert _is_denied(result)
assert "Maximum" in _reason(result)
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_task_slot_released_on_completion(_hooks):
"""Completing a Task should free a slot so new Tasks can be spawned."""
pre, post, _ = _hooks
# Fill both slots
for i in range(2):
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
tool_use_id=f"tu-comp-{i}",
context={},
)
assert not _is_denied(result)
# Third should be denied — at capacity
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "over"}},
tool_use_id="tu-comp-2",
context={},
)
assert _is_denied(result)
# Complete first task — frees a slot
await post(
{"tool_name": "Task", "tool_input": {}},
tool_use_id="tu-comp-0",
context={},
)
# Now a new Task should be allowed
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "after release"}},
tool_use_id="tu-comp-3",
context={},
)
assert not _is_denied(result)
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
@pytest.mark.asyncio
async def test_task_slot_released_on_failure(_hooks):
"""A failed Task should also free its concurrency slot."""
pre, _, post_failure = _hooks
# Fill both slots
for i in range(2):
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
tool_use_id=f"tu-fail-{i}",
context={},
)
assert not _is_denied(result)
# At capacity
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "over"}},
tool_use_id="tu-fail-2",
context={},
)
assert _is_denied(result)
# Fail first task — should free a slot
await post_failure(
{"tool_name": "Task", "tool_input": {}, "error": "something broke"},
tool_use_id="tu-fail-0",
context={},
)
# New Task should be allowed
result = await pre(
{"tool_name": "Task", "tool_input": {"prompt": "after failure"}},
tool_use_id="tu-fail-3",
context={},
)
assert not _is_denied(result)
# -- _is_tool_error_or_denial ------------------------------------------------
class TestIsToolErrorOrDenial:
def test_none_content(self):
assert _is_tool_error_or_denial(None) is False
def test_empty_content(self):
assert _is_tool_error_or_denial("") is False
def test_benign_output(self):
assert _is_tool_error_or_denial("All good, no issues.") is False
def test_security_marker(self):
assert _is_tool_error_or_denial("[SECURITY] Tool access blocked") is True
def test_cannot_be_bypassed(self):
assert _is_tool_error_or_denial("This restriction cannot be bypassed.") is True
def test_not_allowed(self):
assert _is_tool_error_or_denial("Operation not allowed in sandbox") is True
def test_background_task_denial(self):
assert (
_is_tool_error_or_denial(
"Background task execution is not supported. "
"Run tasks in the foreground instead."
)
is True
)
def test_subtask_limit_denial(self):
assert (
_is_tool_error_or_denial(
"Maximum 2 concurrent sub-tasks. "
"Wait for running sub-tasks to finish, "
"or continue in the main conversation."
)
is True
)
def test_denied_marker(self):
assert (
_is_tool_error_or_denial("Access denied: insufficient privileges") is True
)
def test_blocked_marker(self):
assert _is_tool_error_or_denial("Request blocked by security policy") is True
def test_failed_marker(self):
assert _is_tool_error_or_denial("Failed to execute tool: timeout") is True
def test_mcp_iserror(self):
assert _is_tool_error_or_denial('{"isError": true, "content": []}') is True
def test_benign_error_in_value(self):
"""Content like '0 errors found' should not trigger — 'error' was removed."""
assert _is_tool_error_or_denial("0 errors found") is False
def test_benign_permission_field(self):
"""Schema descriptions mentioning 'permission' should not trigger."""
assert (
_is_tool_error_or_denial(
'{"fields": [{"name": "permission_level", "type": "int"}]}'
)
is False
)
def test_benign_not_found_in_listing(self):
"""File listing containing 'not found' in filenames should not trigger."""
assert _is_tool_error_or_denial("readme.md\nfile-not-found-handler.py") is False

File diff suppressed because it is too large Load Diff

View File

@@ -1,147 +0,0 @@
"""Tests for SDK service helpers."""
import base64
import os
from dataclasses import dataclass
from unittest.mock import AsyncMock, patch
import pytest
from .service import _prepare_file_attachments
@dataclass
class _FakeFileInfo:
id: str
name: str
path: str
mime_type: str
size_bytes: int
_PATCH_TARGET = "backend.copilot.sdk.service.get_manager"
class TestPrepareFileAttachments:
@pytest.mark.asyncio
async def test_empty_list_returns_empty(self, tmp_path):
result = await _prepare_file_attachments([], "u", "s", str(tmp_path))
assert result.hint == ""
assert result.image_blocks == []
@pytest.mark.asyncio
async def test_image_embedded_as_vision_block(self, tmp_path):
"""JPEG images should become vision content blocks, not files on disk."""
raw = b"\xff\xd8\xff\xe0fake-jpeg"
info = _FakeFileInfo(
id="abc",
name="photo.jpg",
path="/photo.jpg",
mime_type="image/jpeg",
size_bytes=len(raw),
)
mgr = AsyncMock()
mgr.get_file_info.return_value = info
mgr.read_file_by_id.return_value = raw
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(
["abc"], "user1", "sess1", str(tmp_path)
)
assert "1 file" in result.hint
assert "photo.jpg" in result.hint
assert "embedded as image" in result.hint
assert len(result.image_blocks) == 1
block = result.image_blocks[0]
assert block["type"] == "image"
assert block["source"]["media_type"] == "image/jpeg"
assert block["source"]["data"] == base64.b64encode(raw).decode("ascii")
# Image should NOT be written to disk (embedded instead)
assert not os.path.exists(os.path.join(tmp_path, "photo.jpg"))
@pytest.mark.asyncio
async def test_pdf_saved_to_disk(self, tmp_path):
"""PDFs should be saved to disk for Read tool access, not embedded."""
info = _FakeFileInfo("f1", "doc.pdf", "/doc.pdf", "application/pdf", 50)
mgr = AsyncMock()
mgr.get_file_info.return_value = info
mgr.read_file_by_id.return_value = b"%PDF-1.4 fake"
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(["f1"], "u", "s", str(tmp_path))
assert result.image_blocks == []
saved = tmp_path / "doc.pdf"
assert saved.exists()
assert saved.read_bytes() == b"%PDF-1.4 fake"
assert str(saved) in result.hint
@pytest.mark.asyncio
async def test_mixed_images_and_files(self, tmp_path):
"""Images become blocks, non-images go to disk."""
infos = {
"id1": _FakeFileInfo("id1", "a.png", "/a.png", "image/png", 4),
"id2": _FakeFileInfo("id2", "b.pdf", "/b.pdf", "application/pdf", 4),
"id3": _FakeFileInfo("id3", "c.txt", "/c.txt", "text/plain", 4),
}
mgr = AsyncMock()
mgr.get_file_info.side_effect = lambda fid: infos[fid]
mgr.read_file_by_id.return_value = b"data"
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(
["id1", "id2", "id3"], "u", "s", str(tmp_path)
)
assert "3 files" in result.hint
assert "a.png" in result.hint
assert "b.pdf" in result.hint
assert "c.txt" in result.hint
# Only the image should be a vision block
assert len(result.image_blocks) == 1
assert result.image_blocks[0]["source"]["media_type"] == "image/png"
# Non-image files should be on disk
assert (tmp_path / "b.pdf").exists()
assert (tmp_path / "c.txt").exists()
# Read tool hint should appear (has non-image files)
assert "Read tool" in result.hint
@pytest.mark.asyncio
async def test_singular_noun(self, tmp_path):
info = _FakeFileInfo("x", "only.txt", "/only.txt", "text/plain", 2)
mgr = AsyncMock()
mgr.get_file_info.return_value = info
mgr.read_file_by_id.return_value = b"hi"
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(["x"], "u", "s", str(tmp_path))
assert "1 file." in result.hint
@pytest.mark.asyncio
async def test_missing_file_skipped(self, tmp_path):
mgr = AsyncMock()
mgr.get_file_info.return_value = None
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(
["missing-id"], "u", "s", str(tmp_path)
)
assert result.hint == ""
assert result.image_blocks == []
@pytest.mark.asyncio
async def test_image_only_no_read_hint(self, tmp_path):
"""When all files are images, no Read tool hint should appear."""
info = _FakeFileInfo("i1", "cat.png", "/cat.png", "image/png", 4)
mgr = AsyncMock()
mgr.get_file_info.return_value = info
mgr.read_file_by_id.return_value = b"data"
with patch(_PATCH_TARGET, new_callable=AsyncMock, return_value=mgr):
result = await _prepare_file_attachments(["i1"], "u", "s", str(tmp_path))
assert "Read tool" not in result.hint
assert len(result.image_blocks) == 1

View File

@@ -2,91 +2,32 @@
This module provides the adapter layer that converts existing BaseTool implementations
into in-process MCP tools that can be used with the Claude Agent SDK.
Long-running tools (``is_long_running=True``) are delegated to the non-SDK
background infrastructure (stream_registry, Redis persistence, SSE reconnection)
via a callback provided by the service layer. This avoids wasteful SDK polling
and makes results survive page refreshes.
"""
import asyncio
import itertools
import json
import logging
import os
import re
import uuid
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
from claude_agent_sdk import create_sdk_mcp_server, tool
from typing import Any
from backend.copilot.model import ChatSession
from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate
from .e2b_file_tools import E2B_FILE_TOOL_NAMES, E2B_FILE_TOOLS
if TYPE_CHECKING:
from e2b import AsyncSandbox
logger = logging.getLogger(__name__)
# Allowed base directory for the Read tool (SDK saves oversized tool results here).
# Restricted to ~/.claude/projects/ and further validated to require "tool-results"
# in the path — prevents reading settings, credentials, or other sensitive files.
_SDK_PROJECTS_DIR = os.path.realpath(os.path.expanduser("~/.claude/projects"))
# Max MCP response size in chars — keeps tool output under the SDK's 10 MB JSON buffer.
_MCP_MAX_CHARS = 500_000
# Context variable holding the encoded project directory name for the current
# session (e.g. "-private-tmp-copilot-<uuid>"). Set by set_execution_context()
# so that path validation can scope tool-results reads to the current session.
_current_project_dir: ContextVar[str] = ContextVar("_current_project_dir", default="")
def _encode_cwd_for_cli(cwd: str) -> str:
"""Encode a working directory path the same way the Claude CLI does.
The CLI replaces all non-alphanumeric characters with ``-``.
"""
return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))
def is_allowed_local_path(path: str, sdk_cwd: str | None = None) -> bool:
"""Check whether *path* is an allowed host-filesystem path.
Allowed:
- Files under *sdk_cwd* (``/tmp/copilot-<session>/``)
- Files under ``~/.claude/projects/<encoded-cwd>/`` — the SDK's
project directory for this session (tool-results, transcripts, etc.)
Both checks are scoped to the **current session** so sessions cannot
read each other's data.
"""
if not path:
return False
if path.startswith("~"):
resolved = os.path.realpath(os.path.expanduser(path))
elif not os.path.isabs(path) and sdk_cwd:
resolved = os.path.realpath(os.path.join(sdk_cwd, path))
else:
resolved = os.path.realpath(path)
# Allow access within the SDK working directory
if sdk_cwd:
norm_cwd = os.path.realpath(sdk_cwd)
if resolved == norm_cwd or resolved.startswith(norm_cwd + os.sep):
return True
# Allow access within the current session's CLI project directory
# (~/.claude/projects/<encoded-cwd>/).
encoded = _current_project_dir.get("")
if encoded:
session_project = os.path.join(_SDK_PROJECTS_DIR, encoded)
if resolved == session_project or resolved.startswith(session_project + os.sep):
return True
return False
_SDK_PROJECTS_DIR = os.path.expanduser("~/.claude/projects/")
# MCP server naming - the SDK prefixes tool names as "mcp__{server_name}__{tool}"
MCP_SERVER_NAME = "copilot"
@@ -97,66 +38,45 @@ _current_user_id: ContextVar[str | None] = ContextVar("current_user_id", default
_current_session: ContextVar[ChatSession | None] = ContextVar(
"current_session", default=None
)
# E2B cloud sandbox for the current turn (None when E2B is not configured).
# Passed to bash_exec so commands run on E2B instead of the local bwrap sandbox.
_current_sandbox: ContextVar["AsyncSandbox | None"] = ContextVar(
"_current_sandbox", default=None
)
# Raw SDK working directory path (e.g. /tmp/copilot-<session_id>).
# Used by workspace tools to save binary files for the CLI's built-in Read.
_current_sdk_cwd: ContextVar[str] = ContextVar("_current_sdk_cwd", default="")
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
"pending_tool_outputs",
default=None, # type: ignore[arg-type]
"pending_tool_outputs", default=None # type: ignore[arg-type]
)
# Event signaled whenever stash_pending_tool_output() adds a new entry.
# Used by the streaming loop to wait for PostToolUse hooks to complete
# instead of sleeping an arbitrary duration. The SDK fires hooks via
# start_soon (fire-and-forget) so the next message can arrive before
# the hook stashes its output — this event bridges that gap.
_stash_event: ContextVar[asyncio.Event | None] = ContextVar(
"_stash_event", default=None
# Callback type for delegating long-running tools to the non-SDK infrastructure.
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
LongRunningCallback = Callable[
[str, dict[str, Any], ChatSession], Awaitable[dict[str, Any]]
]
# ContextVar so the service layer can inject the callback per-request.
_long_running_callback: ContextVar[LongRunningCallback | None] = ContextVar(
"long_running_callback", default=None
)
def set_execution_context(
user_id: str | None,
session: ChatSession,
sandbox: "AsyncSandbox | None" = None,
sdk_cwd: str | None = None,
long_running_callback: LongRunningCallback | None = None,
) -> None:
"""Set the execution context for tool calls.
This must be called before streaming begins to ensure tools have access
to user_id, session, and (optionally) an E2B sandbox for bash execution.
to user_id and session information.
Args:
user_id: Current user's ID.
session: Current chat session.
sandbox: Optional E2B sandbox; when set, bash_exec routes commands there.
sdk_cwd: SDK working directory; used to scope tool-results reads.
long_running_callback: Optional callback to delegate long-running tools
to the non-SDK background infrastructure (stream_registry + Redis).
"""
_current_user_id.set(user_id)
_current_session.set(session)
_current_sandbox.set(sandbox)
_current_sdk_cwd.set(sdk_cwd or "")
_current_project_dir.set(_encode_cwd_for_cli(sdk_cwd) if sdk_cwd else "")
_pending_tool_outputs.set({})
_stash_event.set(asyncio.Event())
def get_current_sandbox() -> "AsyncSandbox | None":
"""Return the E2B sandbox for the current turn, or None."""
return _current_sandbox.get()
def get_sdk_cwd() -> str:
"""Return the SDK ephemeral working directory for the current turn."""
return _current_sdk_cwd.get()
_long_running_callback.set(long_running_callback)
def get_execution_context() -> tuple[str | None, ChatSession | None]:
@@ -214,43 +134,6 @@ def stash_pending_tool_output(tool_name: str, output: Any) -> None:
except (TypeError, ValueError):
text = str(output)
pending.setdefault(tool_name, []).append(text)
# Signal any waiters that new output is available.
event = _stash_event.get(None)
if event is not None:
event.set()
async def wait_for_stash(timeout: float = 0.5) -> bool:
"""Wait for a PostToolUse hook to stash tool output.
The SDK fires PostToolUse hooks asynchronously via ``start_soon()`` —
the next message (AssistantMessage/ResultMessage) can arrive before the
hook completes and stashes its output. This function bridges that gap
by waiting on the ``_stash_event``, which is signaled by
:func:`stash_pending_tool_output`.
After the event fires, callers should ``await asyncio.sleep(0)`` to
give any remaining concurrent hooks a chance to complete.
Returns ``True`` if a stash signal was received, ``False`` on timeout.
The timeout is a safety net — normally the stash happens within
microseconds of yielding to the event loop.
"""
event = _stash_event.get(None)
if event is None:
return False
# Fast path: hook already completed before we got here.
if event.is_set():
event.clear()
return True
# Slow path: wait for the hook to signal.
try:
async with asyncio.timeout(timeout):
await event.wait()
event.clear()
return True
except TimeoutError:
return False
async def _execute_tool_sync(
@@ -272,12 +155,66 @@ async def _execute_tool_sync(
result.output if isinstance(result.output, str) else json.dumps(result.output)
)
# Stash the full output before the SDK potentially truncates it.
pending = _pending_tool_outputs.get(None)
if pending is not None:
pending.setdefault(base_tool.name, []).append(text)
content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]
# If the tool result contains inline image data, add an MCP image block
# so Claude can "see" the image (e.g. read_workspace_file on a small PNG).
image_block = _extract_image_block(text)
if image_block:
content_blocks.append(image_block)
return {
"content": [{"type": "text", "text": text}],
"content": content_blocks,
"isError": not result.success,
}
# MIME types that Claude can process as image content blocks.
_SUPPORTED_IMAGE_TYPES = frozenset(
{"image/png", "image/jpeg", "image/gif", "image/webp"}
)
def _extract_image_block(text: str) -> dict[str, str] | None:
"""Extract an MCP image content block from a tool result JSON string.
Detects workspace file responses with ``content_base64`` and an image
MIME type, returning an MCP-format image block that allows Claude to
"see" the image. Returns ``None`` if the result is not an inline image.
"""
try:
data = json.loads(text)
except (json.JSONDecodeError, TypeError):
return None
if not isinstance(data, dict):
return None
mime_type = data.get("mime_type", "")
base64_content = data.get("content_base64", "")
# Only inline small images — large ones would exceed Claude's limits.
# 32 KB raw ≈ ~43 KB base64.
_MAX_IMAGE_BASE64_BYTES = 43_000
if (
mime_type in _SUPPORTED_IMAGE_TYPES
and base64_content
and len(base64_content) <= _MAX_IMAGE_BASE64_BYTES
):
return {
"type": "image",
"data": base64_content,
"mimeType": mime_type,
}
return None
def _mcp_error(message: str) -> dict[str, Any]:
return {
"content": [
@@ -292,6 +229,11 @@ def create_tool_handler(base_tool: BaseTool):
This wraps the existing BaseTool._execute method to be compatible
with the Claude Agent SDK MCP tool format.
Long-running tools (``is_long_running=True``) are delegated to the
non-SDK background infrastructure via a callback set in the execution
context. The callback persists the operation in Redis (stream_registry)
so results survive page refreshes and pod restarts.
"""
async def tool_handler(args: dict[str, Any]) -> dict[str, Any]:
@@ -301,6 +243,25 @@ def create_tool_handler(base_tool: BaseTool):
if session is None:
return _mcp_error("No session context available")
# --- Long-running: delegate to non-SDK background infrastructure ---
if base_tool.is_long_running:
callback = _long_running_callback.get(None)
if callback:
try:
return await callback(base_tool.name, args, session)
except Exception as e:
logger.error(
f"Long-running callback failed for {base_tool.name}: {e}",
exc_info=True,
)
return _mcp_error(f"Failed to start {base_tool.name}: {e}")
# No callback — fall through to synchronous execution
logger.warning(
f"[SDK] No long-running callback for {base_tool.name}, "
f"executing synchronously (may block)"
)
# --- Normal (fast) tool: execute synchronously ---
try:
return await _execute_tool_sync(base_tool, user_id, session, args)
except Exception as e:
@@ -320,32 +281,29 @@ def _build_input_schema(base_tool: BaseTool) -> dict[str, Any]:
async def _read_file_handler(args: dict[str, Any]) -> dict[str, Any]:
"""Read a local file with optional offset/limit.
"""Read a file with optional offset/limit. Restricted to SDK working directory.
Only allows paths that pass :func:`is_allowed_local_path` — the current
session's tool-results directory and ephemeral working directory.
After reading, the file is deleted to prevent accumulation in long-running pods.
"""
file_path = args.get("file_path", "")
offset = args.get("offset", 0)
limit = args.get("limit", 2000)
if not is_allowed_local_path(file_path):
# Security: only allow reads under ~/.claude/projects/**/tool-results/
real_path = os.path.realpath(file_path)
if not real_path.startswith(_SDK_PROJECTS_DIR) or "tool-results" not in real_path:
return {
"content": [{"type": "text", "text": f"Access denied: {file_path}"}],
"isError": True,
}
resolved = os.path.realpath(os.path.expanduser(file_path))
try:
with open(resolved) as f:
with open(real_path) as f:
selected = list(itertools.islice(f, offset, offset + limit))
content = "".join(selected)
# Cleanup happens in _cleanup_sdk_tool_results after session ends;
# don't delete here — the SDK may read in multiple chunks.
return {
"content": [{"type": "text", "text": content}],
"isError": False,
}
return {"content": [{"type": "text", "text": content}], "isError": False}
except FileNotFoundError:
return {
"content": [{"type": "text", "text": f"File not found: {file_path}"}],
@@ -383,86 +341,50 @@ _READ_TOOL_SCHEMA = {
}
# ---------------------------------------------------------------------------
# MCP result helpers
# ---------------------------------------------------------------------------
def _text_from_mcp_result(result: dict[str, Any]) -> str:
"""Extract concatenated text from an MCP response's content blocks."""
content = result.get("content", [])
if not isinstance(content, list):
return ""
return "".join(
b.get("text", "")
for b in content
if isinstance(b, dict) and b.get("type") == "text"
)
def create_copilot_mcp_server(*, use_e2b: bool = False):
# Create the MCP server configuration
def create_copilot_mcp_server():
"""Create an in-process MCP server configuration for CoPilot tools.
When *use_e2b* is True, five additional MCP file tools are registered
that route directly to the E2B sandbox filesystem, and the caller should
disable the corresponding SDK built-in tools via
:func:`get_sdk_disallowed_tools`.
This can be passed to ClaudeAgentOptions.mcp_servers.
Note: The actual SDK MCP server creation depends on the claude-agent-sdk
package being available. This function returns the configuration that
can be used with the SDK.
"""
try:
from claude_agent_sdk import create_sdk_mcp_server, tool
def _truncating(fn, tool_name: str):
"""Wrap a tool handler so its response is truncated to stay under the
SDK's 10 MB JSON buffer, and stash the (truncated) output for the
response adapter before the SDK can apply its own head-truncation.
# Create decorated tool functions
sdk_tools = []
Applied once to every registered tool."""
async def wrapper(args: dict[str, Any]) -> dict[str, Any]:
result = await fn(args)
truncated = truncate(result, _MCP_MAX_CHARS)
# Stash the text so the response adapter can forward our
# middle-out truncated version to the frontend instead of the
# SDK's head-truncated version (for outputs >~100 KB the SDK
# persists to tool-results/ with a 2 KB head-only preview).
if not truncated.get("isError"):
text = _text_from_mcp_result(truncated)
if text:
stash_pending_tool_output(tool_name, text)
return truncated
return wrapper
sdk_tools = []
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(_truncating(handler, tool_name))
sdk_tools.append(decorated)
# E2B file tools replace SDK built-in Read/Write/Edit/Glob/Grep.
if use_e2b:
for name, desc, schema, handler in E2B_FILE_TOOLS:
decorated = tool(name, desc, schema)(_truncating(handler, name))
for tool_name, base_tool in TOOL_REGISTRY.items():
handler = create_tool_handler(base_tool)
decorated = tool(
tool_name,
base_tool.description,
_build_input_schema(base_tool),
)(handler)
sdk_tools.append(decorated)
# Read tool for SDK-truncated tool results (always needed).
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
)(_truncating(_read_file_handler, _READ_TOOL_NAME))
sdk_tools.append(read_tool)
# Add the Read tool so the SDK can read back oversized tool results
read_tool = tool(
_READ_TOOL_NAME,
_READ_TOOL_DESCRIPTION,
_READ_TOOL_SCHEMA,
)(_read_file_handler)
sdk_tools.append(read_tool)
return create_sdk_mcp_server(
name=MCP_SERVER_NAME,
version="1.0.0",
tools=sdk_tools,
)
server = create_sdk_mcp_server(
name=MCP_SERVER_NAME,
version="1.0.0",
tools=sdk_tools,
)
return server
except ImportError:
# Let ImportError propagate so service.py handles the fallback
raise
# SDK built-in tools allowed within the workspace directory.
@@ -472,11 +394,16 @@ def create_copilot_mcp_server(*, use_e2b: bool = False):
# Task allows spawning sub-agents (rate-limited by security hooks).
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
# TodoWrite manages the task checklist shown in the UI — no security concern.
# In E2B mode, all five are disabled — MCP equivalents provide direct sandbox
# access. read_file also handles local tool-results and ephemeral reads.
_SDK_BUILTIN_FILE_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep"]
_SDK_BUILTIN_ALWAYS = ["Task", "WebSearch", "TodoWrite"]
_SDK_BUILTIN_TOOLS = [*_SDK_BUILTIN_FILE_TOOLS, *_SDK_BUILTIN_ALWAYS]
_SDK_BUILTIN_TOOLS = [
"Read",
"Write",
"Edit",
"Glob",
"Grep",
"Task",
"WebSearch",
"TodoWrite",
]
# SDK built-in tools that must be explicitly blocked.
# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
@@ -523,37 +450,11 @@ DANGEROUS_PATTERNS = [
r"subprocess",
]
# Static tool name list for the non-E2B case (backward compatibility).
# List of tool names for allowed_tools configuration
# Include MCP tools, the MCP Read tool for oversized results,
# and SDK built-in file tools for workspace operations.
COPILOT_TOOL_NAMES = [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*_SDK_BUILTIN_TOOLS,
]
def get_copilot_tool_names(*, use_e2b: bool = False) -> list[str]:
"""Build the ``allowed_tools`` list for :class:`ClaudeAgentOptions`.
When *use_e2b* is True the SDK built-in file tools are replaced by MCP
equivalents that route to the E2B sandbox.
"""
if not use_e2b:
return list(COPILOT_TOOL_NAMES)
return [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*[f"{MCP_TOOL_PREFIX}{name}" for name in E2B_FILE_TOOL_NAMES],
*_SDK_BUILTIN_ALWAYS,
]
def get_sdk_disallowed_tools(*, use_e2b: bool = False) -> list[str]:
"""Build the ``disallowed_tools`` list for :class:`ClaudeAgentOptions`.
When *use_e2b* is True the SDK built-in file tools are also disabled
because MCP equivalents provide direct sandbox access.
"""
if not use_e2b:
return list(SDK_DISALLOWED_TOOLS)
return [*SDK_DISALLOWED_TOOLS, *_SDK_BUILTIN_FILE_TOOLS]

View File

@@ -1,170 +0,0 @@
"""Tests for tool_adapter helpers: truncation, stash, context vars."""
import pytest
from backend.util.truncate import truncate
from .tool_adapter import (
_MCP_MAX_CHARS,
_text_from_mcp_result,
get_sdk_cwd,
pop_pending_tool_output,
set_execution_context,
stash_pending_tool_output,
)
# ---------------------------------------------------------------------------
# _text_from_mcp_result
# ---------------------------------------------------------------------------
class TestTextFromMcpResult:
def test_single_text_block(self):
result = {"content": [{"type": "text", "text": "hello"}]}
assert _text_from_mcp_result(result) == "hello"
def test_multiple_text_blocks_concatenated(self):
result = {
"content": [
{"type": "text", "text": "one"},
{"type": "text", "text": "two"},
]
}
assert _text_from_mcp_result(result) == "onetwo"
def test_non_text_blocks_ignored(self):
result = {
"content": [
{"type": "image", "data": "..."},
{"type": "text", "text": "only this"},
]
}
assert _text_from_mcp_result(result) == "only this"
def test_empty_content_list(self):
assert _text_from_mcp_result({"content": []}) == ""
def test_missing_content_key(self):
assert _text_from_mcp_result({}) == ""
def test_non_list_content(self):
assert _text_from_mcp_result({"content": "raw string"}) == ""
def test_missing_text_field(self):
result = {"content": [{"type": "text"}]}
assert _text_from_mcp_result(result) == ""
# ---------------------------------------------------------------------------
# get_sdk_cwd
# ---------------------------------------------------------------------------
class TestGetSdkCwd:
def test_returns_empty_string_by_default(self):
set_execution_context(
user_id="test",
session=None, # type: ignore[arg-type]
sandbox=None,
)
assert get_sdk_cwd() == ""
def test_returns_set_value(self):
set_execution_context(
user_id="test",
session=None, # type: ignore[arg-type]
sandbox=None,
sdk_cwd="/tmp/copilot-test-123",
)
assert get_sdk_cwd() == "/tmp/copilot-test-123"
# ---------------------------------------------------------------------------
# stash / pop round-trip (the mechanism _truncating relies on)
# ---------------------------------------------------------------------------
class TestToolOutputStash:
@pytest.fixture(autouse=True)
def _init_context(self):
"""Initialise the context vars that stash_pending_tool_output needs."""
set_execution_context(
user_id="test",
session=None, # type: ignore[arg-type]
sandbox=None,
sdk_cwd="/tmp/test",
)
def test_stash_and_pop(self):
stash_pending_tool_output("my_tool", "output1")
assert pop_pending_tool_output("my_tool") == "output1"
def test_pop_empty_returns_none(self):
assert pop_pending_tool_output("nonexistent") is None
def test_fifo_order(self):
stash_pending_tool_output("t", "first")
stash_pending_tool_output("t", "second")
assert pop_pending_tool_output("t") == "first"
assert pop_pending_tool_output("t") == "second"
assert pop_pending_tool_output("t") is None
def test_dict_serialised_to_json(self):
stash_pending_tool_output("t", {"key": "value"})
assert pop_pending_tool_output("t") == '{"key": "value"}'
def test_separate_tool_names(self):
stash_pending_tool_output("a", "alpha")
stash_pending_tool_output("b", "beta")
assert pop_pending_tool_output("b") == "beta"
assert pop_pending_tool_output("a") == "alpha"
# ---------------------------------------------------------------------------
# _truncating wrapper (integration via create_copilot_mcp_server)
# ---------------------------------------------------------------------------
class TestTruncationAndStashIntegration:
"""Test truncation + stash behavior that _truncating relies on."""
@pytest.fixture(autouse=True)
def _init_context(self):
set_execution_context(
user_id="test",
session=None, # type: ignore[arg-type]
sandbox=None,
sdk_cwd="/tmp/test",
)
def test_small_output_stashed(self):
"""Non-error output is stashed for the response adapter."""
result = {
"content": [{"type": "text", "text": "small output"}],
"isError": False,
}
truncated = truncate(result, _MCP_MAX_CHARS)
text = _text_from_mcp_result(truncated)
assert text == "small output"
stash_pending_tool_output("test_tool", text)
assert pop_pending_tool_output("test_tool") == "small output"
def test_error_result_not_stashed(self):
"""Error results should not be stashed."""
result = {
"content": [{"type": "text", "text": "error msg"}],
"isError": True,
}
# _truncating only stashes when not result.get("isError")
if not result.get("isError"):
stash_pending_tool_output("err_tool", "should not happen")
assert pop_pending_tool_output("err_tool") is None
def test_large_output_truncated(self):
"""Output exceeding _MCP_MAX_CHARS is truncated before stashing."""
big_text = "x" * (_MCP_MAX_CHARS + 100_000)
result = {"content": [{"type": "text", "text": big_text}]}
truncated = truncate(result, _MCP_MAX_CHARS)
text = _text_from_mcp_result(truncated)
assert len(text) < len(big_text)
assert len(str(truncated)) <= _MCP_MAX_CHARS

View File

@@ -131,21 +131,18 @@ def read_transcript_file(transcript_path: str) -> str | None:
content = f.read()
if not content.strip():
logger.debug("[Transcript] File is empty: %s", transcript_path)
return None
lines = content.strip().split("\n")
# Validate that the transcript has real conversation content
# (not just metadata like queue-operation entries).
if not validate_transcript(content):
logger.debug(
"[Transcript] No conversation content (%d lines) in %s",
len(lines),
transcript_path,
)
if len(lines) < 3:
# Raw files with ≤2 lines are metadata-only
# (queue-operation + file-history-snapshot, no conversation).
return None
# Quick structural validation — parse first and last lines.
json.loads(lines[0])
json.loads(lines[-1])
logger.info(
f"[Transcript] Read {len(lines)} lines, "
f"{len(content)} bytes from {transcript_path}"
@@ -331,10 +328,10 @@ async def upload_transcript(
) -> None:
"""Strip progress entries and upload transcript to bucket storage.
The executor holds a cluster lock per session, so concurrent uploads for
the same session cannot happen. We always overwrite — with ``--resume``
the CLI may compact old tool results, so neither byte size nor line count
is a reliable proxy for "newer".
Safety: only overwrites when the new (stripped) transcript is larger than
what is already stored. Since JSONL is append-only, the latest transcript
is always the longest. This prevents a slow/stale background task from
clobbering a newer upload from a concurrent turn.
Args:
message_count: ``len(session.messages)`` at upload time — used by
@@ -353,6 +350,20 @@ async def upload_transcript(
storage = await get_workspace_storage()
wid, fid, fname = _storage_path_parts(user_id, session_id)
encoded = stripped.encode("utf-8")
new_size = len(encoded)
# Check existing transcript size to avoid overwriting newer with older
path = _build_storage_path(user_id, session_id, storage)
try:
existing = await storage.retrieve(path)
if len(existing) >= new_size:
logger.info(
f"[Transcript] Skipping upload — existing ({len(existing)}B) "
f">= new ({new_size}B) for session {session_id}"
)
return
except (FileNotFoundError, Exception):
pass # No existing transcript or retrieval error — proceed with upload
await storage.store(
workspace_id=wid,
@@ -361,8 +372,11 @@ async def upload_transcript(
content=encoded,
)
# Update metadata so message_count stays current. The gap-fill logic
# in _build_query_message relies on it to avoid re-compressing messages.
# Store metadata alongside the transcript so the next turn can detect
# staleness and only compress the gap instead of the full history.
# Wrapped in try/except so a metadata write failure doesn't orphan
# the already-uploaded transcript — the next turn will just fall back
# to full gap fill (msg_count=0).
try:
meta = {"message_count": message_count, "uploaded_at": time.time()}
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
@@ -376,7 +390,7 @@ async def upload_transcript(
logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")
logger.info(
f"[Transcript] Uploaded {len(encoded)}B "
f"[Transcript] Uploaded {new_size}B "
f"(stripped from {len(content)}B, msg_count={message_count}) "
f"for session {session_id}"
)

File diff suppressed because it is too large Load Diff

View File

@@ -4,14 +4,87 @@ from os import getenv
import pytest
from . import service as chat_service
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import StreamError, StreamTextDelta
from .response_model import (
StreamError,
StreamFinish,
StreamTextDelta,
StreamToolOutputAvailable,
)
from .sdk import service as sdk_service
from .sdk.transcript import download_transcript
logger = logging.getLogger(__name__)
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion(setup_test_user, test_user_id):
"""
Test the stream_chat_completion function.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
has_errors = False
has_ended = False
assistant_message = ""
async for chunk in chat_service.stream_chat_completion(
session.session_id, "Hello, how are you?", user_id=session.user_id
):
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamTextDelta):
assistant_message += chunk.delta
if isinstance(chunk, StreamFinish):
has_ended = True
assert has_ended, "Chat completion did not end"
assert not has_errors, "Error occurred while streaming chat completion"
assert assistant_message, "Assistant message is empty"
@pytest.mark.asyncio(loop_scope="session")
async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user_id):
"""
Test the stream_chat_completion function.
"""
api_key: str | None = getenv("OPEN_ROUTER_API_KEY")
if not api_key:
return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")
session = await create_chat_session(test_user_id)
session = await upsert_chat_session(session)
has_errors = False
has_ended = False
had_tool_calls = False
async for chunk in chat_service.stream_chat_completion(
session.session_id,
"Please find me an agent that can help me with my business. Use the query 'moneny printing agent'",
user_id=session.user_id,
):
logger.info(chunk)
if isinstance(chunk, StreamError):
has_errors = True
if isinstance(chunk, StreamFinish):
has_ended = True
if isinstance(chunk, StreamToolOutputAvailable):
had_tool_calls = True
assert has_ended, "Chat completion did not end"
assert not has_errors, "Error occurred while streaming chat completion"
assert had_tool_calls, "Tool calls did not occur"
session = await get_chat_session(session.session_id)
assert session, "Session not found"
assert session.usage, "Usage is empty"
@pytest.mark.asyncio(loop_scope="session")
async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
"""Test that the SDK --resume path captures and uses transcripts across turns.
@@ -41,6 +114,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
)
turn1_text = ""
turn1_errors: list[str] = []
turn1_ended = False
async for chunk in sdk_service.stream_chat_completion_sdk(
session.session_id,
@@ -51,27 +125,24 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
turn1_text += chunk.delta
elif isinstance(chunk, StreamError):
turn1_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
turn1_ended = True
assert turn1_ended, "Turn 1 did not finish"
assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
assert turn1_text, "Turn 1 produced no text"
# Wait for background upload task to complete (retry up to 5s).
# The CLI may not produce a usable transcript for very short
# conversations (only metadata entries) — this is environment-dependent
# (CLI version, platform). When that happens, multi-turn still works
# via conversation compression (non-resume path), but we can't test
# the --resume round-trip.
# Wait for background upload task to complete (retry up to 5s)
transcript = None
for _ in range(10):
await asyncio.sleep(0.5)
transcript = await download_transcript(test_user_id, session.session_id)
if transcript:
break
if not transcript:
return pytest.skip(
"CLI did not produce a usable transcript — "
"cannot test --resume round-trip in this environment"
)
assert transcript, (
"Transcript was not uploaded to bucket after turn 1 — "
"Stop hook may not have fired or transcript was too small"
)
logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")
# Reload session for turn 2
@@ -82,6 +153,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
turn2_msg = "What was the special keyword I asked you to remember?"
turn2_text = ""
turn2_errors: list[str] = []
turn2_ended = False
async for chunk in sdk_service.stream_chat_completion_sdk(
session.session_id,
@@ -93,7 +165,10 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
turn2_text += chunk.delta
elif isinstance(chunk, StreamError):
turn2_errors.append(chunk.errorText)
elif isinstance(chunk, StreamFinish):
turn2_ended = True
assert turn2_ended, "Turn 2 did not finish"
assert not turn2_errors, f"Turn 2 errors: {turn2_errors}"
assert turn2_text, "Turn 2 produced no text"
assert keyword in turn2_text, (

File diff suppressed because it is too large Load Diff

View File

@@ -1,401 +0,0 @@
"""End-to-end tests for Copilot streaming with dummy implementations.
These tests verify the complete copilot flow using dummy implementations
for agent generator and SDK service, allowing automated testing without
external LLM calls.
Enable test mode with COPILOT_TEST_MODE=true environment variable.
Note: StreamFinish is NOT emitted by the dummy service — it is published
by mark_session_completed in the processor layer. These tests only cover
the service-level streaming output (StreamStart + StreamTextDelta).
"""
import asyncio
import os
from uuid import uuid4
import pytest
from backend.copilot.model import ChatMessage, ChatSession, upsert_chat_session
from backend.copilot.response_model import (
StreamError,
StreamHeartbeat,
StreamStart,
StreamTextDelta,
)
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
@pytest.fixture(autouse=True)
def enable_test_mode():
"""Enable test mode for all tests in this module."""
os.environ["COPILOT_TEST_MODE"] = "true"
yield
os.environ.pop("COPILOT_TEST_MODE", None)
@pytest.mark.asyncio
async def test_dummy_streaming_basic_flow():
"""Test that dummy streaming produces correct event sequence."""
events = []
async for event in stream_chat_completion_dummy(
session_id="test-session-basic",
message="Hello",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Verify we got events
assert len(events) > 0, "Should receive events"
# Verify StreamStart
start_events = [e for e in events if isinstance(e, StreamStart)]
assert len(start_events) == 1
assert start_events[0].messageId
assert start_events[0].sessionId
# Verify StreamTextDelta events
text_events = [e for e in events if isinstance(e, StreamTextDelta)]
assert len(text_events) > 0
full_text = "".join(e.delta for e in text_events)
assert len(full_text) > 0
# Verify order: start before text
start_idx = events.index(start_events[0])
first_text_idx = events.index(text_events[0]) if text_events else -1
if first_text_idx >= 0:
assert start_idx < first_text_idx
print(f"✅ Basic flow: {len(events)} events, {len(text_events)} text deltas")
@pytest.mark.asyncio
async def test_streaming_no_timeout():
"""Test that streaming completes within reasonable time without timeout."""
import time
start_time = time.monotonic()
event_count = 0
async for _event in stream_chat_completion_dummy(
session_id="test-session-timeout",
message="count to 10",
is_user_message=True,
user_id="test-user",
):
event_count += 1
elapsed = time.monotonic() - start_time
# Should complete in < 5 seconds (dummy has 0.1s delays between words)
assert elapsed < 5.0, f"Streaming took {elapsed:.1f}s, expected < 5s"
assert event_count > 0, "Should receive events"
print(f"✅ No timeout: completed in {elapsed:.2f}s with {event_count} events")
@pytest.mark.asyncio
async def test_streaming_event_types():
"""Test that all expected event types are present."""
event_types = set()
async for event in stream_chat_completion_dummy(
session_id="test-session-types",
message="test",
is_user_message=True,
user_id="test-user",
):
event_types.add(type(event).__name__)
# Required event types (StreamFinish is published by processor, not service)
assert "StreamStart" in event_types, "Missing StreamStart"
assert "StreamTextDelta" in event_types, "Missing StreamTextDelta"
print(f"✅ Event types: {sorted(event_types)}")
@pytest.mark.asyncio
async def test_streaming_text_content():
"""Test that streamed text is coherent and complete."""
text_events = []
async for event in stream_chat_completion_dummy(
session_id="test-session-content",
message="count to 3",
is_user_message=True,
user_id="test-user",
):
if isinstance(event, StreamTextDelta):
text_events.append(event)
# Verify text deltas
assert len(text_events) > 0, "Should have text deltas"
# Reconstruct full text
full_text = "".join(e.delta for e in text_events)
assert len(full_text) > 0, "Text should not be empty"
assert (
"1" in full_text or "counted" in full_text.lower()
), "Text should contain count"
# Verify all deltas have IDs
for text_event in text_events:
assert text_event.id, "Text delta must have ID"
assert text_event.delta, "Text delta must have content"
print(f"✅ Text content: '{full_text}' ({len(text_events)} deltas)")
@pytest.mark.asyncio
async def test_streaming_heartbeat_timing():
"""Test that heartbeats are sent at correct interval during long operations."""
# This test would need a dummy that takes longer
# For now, just verify heartbeat structure if we receive one
heartbeats = []
async for event in stream_chat_completion_dummy(
session_id="test-session-heartbeat",
message="test",
is_user_message=True,
user_id="test-user",
):
if isinstance(event, StreamHeartbeat):
heartbeats.append(event)
# Dummy is fast, so we might not get heartbeats
# But if we do, verify they're valid
if heartbeats:
print(f"✅ Heartbeat structure verified ({len(heartbeats)} received)")
else:
print("✅ No heartbeats (dummy executes quickly)")
@pytest.mark.asyncio
async def test_error_handling():
"""Test that errors are properly formatted and sent."""
# This would require a dummy that can trigger errors
# For now, just verify error event structure
error = StreamError(errorText="Test error", code="test_error")
assert error.errorText == "Test error"
assert error.code == "test_error"
assert str(error.type.value) in ["error", "error"]
print("✅ Error structure verified")
@pytest.mark.asyncio
async def test_concurrent_sessions():
"""Test that multiple sessions can stream concurrently."""
async def stream_session(session_id: str) -> int:
count = 0
async for _event in stream_chat_completion_dummy(
session_id=session_id,
message="test",
is_user_message=True,
user_id="test-user",
):
count += 1
return count
# Run 3 concurrent sessions
results = await asyncio.gather(
stream_session("session-1"),
stream_session("session-2"),
stream_session("session-3"),
)
# All should complete successfully
assert all(count > 0 for count in results), "All sessions should produce events"
print(f"✅ Concurrent sessions: {results} events each")
@pytest.mark.asyncio
@pytest.mark.xfail(
reason="Event loop isolation issue with DB operations in tests - needs fixture refactoring"
)
async def test_session_state_persistence():
"""Test that session state is maintained across multiple messages."""
from datetime import datetime, timezone
session_id = f"test-session-{uuid4()}"
user_id = "test-user"
# Create session with first message
session = ChatSession(
session_id=session_id,
user_id=user_id,
messages=[
ChatMessage(role="user", content="Hello"),
ChatMessage(role="assistant", content="Hi there!"),
],
usage=[],
started_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
await upsert_chat_session(session)
# Stream second message
events = []
async for event in stream_chat_completion_dummy(
session_id=session_id,
message="How are you?",
is_user_message=True,
user_id=user_id,
session=session, # Pass existing session
):
events.append(event)
# Verify events were produced
assert len(events) > 0, "Should produce events for second message"
print(f"✅ Session persistence: {len(events)} events for second message")
@pytest.mark.asyncio
async def test_message_deduplication():
"""Test that duplicate messages are filtered out."""
# Simulate receiving duplicate events (e.g., from reconnection)
events = []
# First stream
async for event in stream_chat_completion_dummy(
session_id="test-dedup-1",
message="Hello",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Count unique message IDs in StreamStart events
start_events = [e for e in events if isinstance(e, StreamStart)]
message_ids = [e.messageId for e in start_events]
# Verify all IDs are present
assert len(message_ids) == len(set(message_ids)), "Message IDs should be unique"
print(f"✅ Deduplication: {len(events)} events, all unique")
@pytest.mark.asyncio
async def test_event_ordering():
"""Test that events arrive in correct order."""
events = []
async for event in stream_chat_completion_dummy(
session_id="test-ordering",
message="Test",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Find event indices
start_idx = next(
(i for i, e in enumerate(events) if isinstance(e, StreamStart)), None
)
text_indices = [i for i, e in enumerate(events) if isinstance(e, StreamTextDelta)]
# Verify ordering
assert start_idx is not None, "Should have StreamStart"
assert start_idx == 0, "StreamStart should be first"
if text_indices:
assert all(
start_idx < i for i in text_indices
), "Text deltas should be after start"
print(f"✅ Event ordering: start({start_idx}) < text deltas")
@pytest.mark.asyncio
async def test_stream_completeness():
"""Test that stream includes all required event types."""
events = []
async for event in stream_chat_completion_dummy(
session_id="test-completeness",
message="Complete stream test",
is_user_message=True,
user_id="test-user",
):
events.append(event)
# Check for required events (StreamFinish is published by processor)
has_start = any(isinstance(e, StreamStart) for e in events)
has_text = any(isinstance(e, StreamTextDelta) for e in events)
assert has_start, "Stream must include StreamStart"
assert has_text, "Stream must include text deltas"
# Verify exactly one start
start_count = sum(1 for e in events if isinstance(e, StreamStart))
assert start_count == 1, f"Should have exactly 1 StreamStart, got {start_count}"
print(
f"✅ Completeness: 1 start, {sum(1 for e in events if isinstance(e, StreamTextDelta))} text deltas"
)
@pytest.mark.asyncio
async def test_text_delta_consistency():
"""Test that text deltas have consistent IDs and build coherent text."""
text_events = []
async for event in stream_chat_completion_dummy(
session_id="test-consistency",
message="Test consistency",
is_user_message=True,
user_id="test-user",
):
if isinstance(event, StreamTextDelta):
text_events.append(event)
# Verify all text deltas have IDs
assert all(e.id for e in text_events), "All text deltas must have IDs"
# Verify all deltas have the same ID (same text block)
if text_events:
first_id = text_events[0].id
assert all(
e.id == first_id for e in text_events
), "All text deltas should share the same block ID"
# Verify deltas build coherent text
full_text = "".join(e.delta for e in text_events)
assert len(full_text) > 0, "Deltas should build non-empty text"
assert (
full_text == full_text.strip()
), "Text should not have leading/trailing whitespace artifacts"
print(
f"✅ Consistency: {len(text_events)} deltas with ID '{text_events[0].id if text_events else 'N/A'}', text: '{full_text}'"
)
if __name__ == "__main__":
# Run tests directly
print("Running Copilot E2E tests with dummy implementations...")
print("=" * 60)
asyncio.run(test_dummy_streaming_basic_flow())
asyncio.run(test_streaming_no_timeout())
asyncio.run(test_streaming_event_types())
asyncio.run(test_streaming_text_content())
asyncio.run(test_streaming_heartbeat_timing())
asyncio.run(test_error_handling())
asyncio.run(test_concurrent_sessions())
asyncio.run(test_session_state_persistence())
asyncio.run(test_message_deduplication())
asyncio.run(test_event_ordering())
asyncio.run(test_stream_completeness())
asyncio.run(test_text_delta_consistency())
print("=" * 60)
print("✅ All E2E tests passed!")

View File

@@ -1,17 +1,16 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from openai.types.chat import ChatCompletionToolParam
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
from .agent_output import AgentOutputTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .check_operation_status import CheckOperationStatusTool
from .create_agent import CreateAgentTool
from .customize_agent import CustomizeAgentTool
from .edit_agent import EditAgentTool
@@ -22,7 +21,6 @@ from .find_library_agent import FindLibraryAgentTool
from .get_doc_page import GetDocPageTool
from .run_agent import RunAgentTool
from .run_block import RunBlockTool
from .run_mcp_tool import RunMCPToolTool
from .search_docs import SearchDocsTool
from .web_fetch import WebFetchTool
from .workspace_files import (
@@ -33,7 +31,6 @@ from .workspace_files import (
)
if TYPE_CHECKING:
from backend.copilot.model import ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable
logger = logging.getLogger(__name__)
@@ -49,16 +46,12 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"find_library_agent": FindLibraryAgentTool(),
"run_agent": RunAgentTool(),
"run_block": RunBlockTool(),
"run_mcp_tool": RunMCPToolTool(),
"view_agent_output": AgentOutputTool(),
"check_operation_status": CheckOperationStatusTool(),
"search_docs": SearchDocsTool(),
"get_doc_page": GetDocPageTool(),
# Web fetch for safe URL retrieval
"web_fetch": WebFetchTool(),
# Agent-browser multi-step automation (navigate, act, screenshot)
"browser_navigate": BrowserNavigateTool(),
"browser_act": BrowserActTool(),
"browser_screenshot": BrowserScreenshotTool(),
# Sandboxed code execution (bubblewrap)
"bash_exec": BashExecTool(),
# Persistent workspace tools (cloud storage, survives across sessions)
@@ -76,17 +69,10 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
find_agent_tool = TOOL_REGISTRY["find_agent"]
run_agent_tool = TOOL_REGISTRY["run_agent"]
def get_available_tools() -> list[ChatCompletionToolParam]:
"""Return OpenAI tool schemas for tools available in the current environment.
Called per-request so that env-var or binary availability is evaluated
fresh each time (e.g. browser_* tools are excluded when agent-browser
CLI is not installed).
"""
return [
tool.as_openai_tool() for tool in TOOL_REGISTRY.values() if tool.is_available
]
# Generated from registry for OpenAI API
tools: list[ChatCompletionToolParam] = [
tool.as_openai_tool() for tool in TOOL_REGISTRY.values()
]
def get_tool(tool_name: str) -> BaseTool | None:

View File

@@ -1,10 +1,8 @@
import logging
import uuid
from datetime import UTC, datetime
from os import getenv
import pytest
import pytest_asyncio
from prisma.types import ProfileCreateInput
from pydantic import SecretStr
@@ -13,34 +11,12 @@ from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.llm import AITextGeneratorBlock
from backend.copilot.model import ChatSession
from backend.data import db as db_module
from backend.data.db import prisma
from backend.data.graph import Graph, Link, Node, create_graph
from backend.data.model import APIKeyCredentials
from backend.data.user import get_or_create_user
from backend.integrations.credentials_store import IntegrationCredentialsStore
_logger = logging.getLogger(__name__)
async def _ensure_db_connected() -> None:
"""Ensure the Prisma connection is alive on the current event loop.
On Python 3.11, the httpx transport inside Prisma can reference a stale
(closed) event loop when session-scoped async fixtures are evaluated long
after the initial ``server`` fixture connected Prisma. A cheap health-check
followed by a reconnect fixes this without affecting other fixtures.
"""
try:
await prisma.query_raw("SELECT 1")
except Exception:
_logger.info("Prisma connection stale reconnecting")
try:
await db_module.disconnect()
except Exception:
pass
await db_module.connect()
def make_session(user_id: str):
return ChatSession(
@@ -55,19 +31,15 @@ def make_session(user_id: str):
)
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_test_data(server):
@pytest.fixture(scope="session")
async def setup_test_data():
"""
Set up test data for run_agent tests:
1. Create a test user
2. Create a test graph (agent input -> agent output)
3. Create a store listing and store listing version
4. Approve the store listing version
Depends on ``server`` to ensure Prisma is connected.
"""
await _ensure_db_connected()
# 1. Create a test user
user_data = {
"sub": f"test-user-{uuid.uuid4()}",
@@ -178,19 +150,15 @@ async def setup_test_data(server):
}
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_llm_test_data(server):
@pytest.fixture(scope="session")
async def setup_llm_test_data():
"""
Set up test data for LLM agent tests:
1. Create a test user
2. Create test OpenAI credentials for the user
3. Create a test graph with input -> LLM block -> output
4. Create and approve a store listing
Depends on ``server`` to ensure Prisma is connected.
"""
await _ensure_db_connected()
key = getenv("OPENAI_API_KEY")
if not key:
return pytest.skip("OPENAI_API_KEY is not set")
@@ -347,18 +315,14 @@ async def setup_llm_test_data(server):
}
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def setup_firecrawl_test_data(server):
@pytest.fixture(scope="session")
async def setup_firecrawl_test_data():
"""
Set up test data for Firecrawl agent tests (missing credentials scenario):
1. Create a test user (WITHOUT Firecrawl credentials)
2. Create a test graph with input -> Firecrawl block -> output
3. Create and approve a store listing
Depends on ``server`` to ensure Prisma is connected.
"""
await _ensure_db_connected()
# 1. Create a test user
user_data = {
"sub": f"test-user-{uuid.uuid4()}",

View File

@@ -1,876 +0,0 @@
"""Agent-browser tools — multi-step browser automation for the Copilot.
Uses the agent-browser CLI (https://github.com/vercel-labs/agent-browser)
which runs a local Chromium instance managed by a persistent daemon.
- Runs locally — no cloud account required
- Full interaction support: click, fill, scroll, login flows, multi-step
- Session persistence via --session-name: cookies/auth carry across tool calls
within the same Copilot session, enabling login → navigate → extract workflows
- Screenshot with --annotate overlays @ref labels, saved to workspace for user
- The Claude Agent SDK's multi-turn loop handles orchestration — each tool call
is one browser action; the LLM chains them naturally
SSRF protection:
Uses the shared validate_url() from backend.util.request, which is the same
guard used by HTTP blocks and web_fetch. It resolves ALL DNS answers (not just
the first), blocks RFC 1918, loopback, link-local, 0.0.0.0/8, multicast,
and all relevant IPv6 ranges, and applies IDNA encoding to prevent Unicode
domain attacks.
Requires:
npm install -g agent-browser
agent-browser install (downloads Chromium, one-time per machine)
"""
import asyncio
import base64
import json
import logging
import os
import shutil
import tempfile
from typing import Any
from backend.copilot.model import ChatSession
from backend.util.request import validate_url
from .base import BaseTool
from .models import (
BrowserActResponse,
BrowserNavigateResponse,
BrowserScreenshotResponse,
ErrorResponse,
ToolResponseBase,
)
from .workspace_files import get_manager
logger = logging.getLogger(__name__)
# Per-command timeout (seconds). Navigation + networkidle wait can be slow.
_CMD_TIMEOUT = 45
# Accessibility tree can be very large; cap it to keep LLM context manageable.
_MAX_SNAPSHOT_CHARS = 20_000
# ---------------------------------------------------------------------------
# Subprocess helper
# ---------------------------------------------------------------------------
async def _run(
session_name: str,
*args: str,
timeout: int = _CMD_TIMEOUT,
) -> tuple[int, str, str]:
"""Run agent-browser for the given session and return (rc, stdout, stderr).
Uses both:
--session <name> → isolated Chromium context (no shared history/cookies
with other Copilot sessions — prevents cross-session
browser state leakage)
--session-name <name> → persist cookies/localStorage across tool calls within
the same session (enables login → navigate flows)
"""
cmd = [
"agent-browser",
"--session",
session_name,
"--session-name",
session_name,
*args,
]
proc = None
try:
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=timeout)
return proc.returncode or 0, stdout.decode(), stderr.decode()
except asyncio.TimeoutError:
# Kill the orphaned subprocess so it does not linger in the process table.
if proc is not None and proc.returncode is None:
proc.kill()
try:
await proc.communicate()
except Exception:
pass # Best-effort reap; ignore errors during cleanup.
return 1, "", f"Command timed out after {timeout}s."
except FileNotFoundError:
return (
1,
"",
"agent-browser is not installed (run: npm install -g agent-browser && agent-browser install).",
)
async def _snapshot(session_name: str) -> str:
"""Return the current page's interactive accessibility tree, truncated."""
rc, stdout, stderr = await _run(session_name, "snapshot", "-i", "-c")
if rc != 0:
return f"[snapshot failed: {stderr[:300]}]"
text = stdout.strip()
if len(text) > _MAX_SNAPSHOT_CHARS:
suffix = "\n\n[Snapshot truncated — use browser_act to navigate further]"
keep = max(0, _MAX_SNAPSHOT_CHARS - len(suffix))
text = text[:keep] + suffix
return text
# ---------------------------------------------------------------------------
# Stateless session helpers — persist / restore browser state across pods
# ---------------------------------------------------------------------------
# Module-level cache of sessions known to be alive on this pod.
# Avoids the subprocess probe on every tool call within the same pod.
_alive_sessions: set[str] = set()
# Per-session locks to prevent concurrent _ensure_session calls from
# triggering duplicate _restore_browser_state for the same session.
# Protected by _session_locks_mutex to ensure setdefault/pop are not
# interleaved across await boundaries.
_session_locks: dict[str, asyncio.Lock] = {}
_session_locks_mutex = asyncio.Lock()
# Workspace filename for persisted browser state (auto-scoped to session).
# Dot-prefixed so it is hidden from user workspace listings.
_STATE_FILENAME = "._browser_state.json"
# Maximum concurrent subprocesses during cookie/storage restore.
_RESTORE_CONCURRENCY = 10
# Maximum cookies to restore per session. Pathological sites can accumulate
# thousands of cookies; restoring them all would be slow and is rarely useful.
_MAX_RESTORE_COOKIES = 100
# Background tasks for fire-and-forget state persistence.
# Prevents GC from collecting tasks before they complete.
_background_tasks: set[asyncio.Task] = set()
def _fire_and_forget_save(
session_name: str, user_id: str, session: ChatSession
) -> None:
"""Schedule state persistence as a background task (non-blocking).
State save is already best-effort (errors are swallowed), so running it
in the background avoids adding latency to tool responses.
"""
task = asyncio.create_task(_save_browser_state(session_name, user_id, session))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
async def _has_local_session(session_name: str) -> bool:
"""Check if the local agent-browser daemon for this session is running."""
rc, _, _ = await _run(session_name, "get", "url", timeout=5)
return rc == 0
async def _save_browser_state(
session_name: str, user_id: str, session: ChatSession
) -> None:
"""Persist browser state (cookies, localStorage, URL) to workspace.
Best-effort: errors are logged but never propagate to the tool response.
"""
try:
# Gather state in parallel
(rc_url, url_out, _), (rc_ck, ck_out, _), (rc_ls, ls_out, _) = (
await asyncio.gather(
_run(session_name, "get", "url", timeout=10),
_run(session_name, "cookies", "get", "--json", timeout=10),
_run(session_name, "storage", "local", "--json", timeout=10),
)
)
state = {
"url": url_out.strip() if rc_url == 0 else "",
"cookies": (json.loads(ck_out) if rc_ck == 0 and ck_out.strip() else []),
"local_storage": (
json.loads(ls_out) if rc_ls == 0 and ls_out.strip() else {}
),
}
manager = await get_manager(user_id, session.session_id)
await manager.write_file(
content=json.dumps(state).encode("utf-8"),
filename=_STATE_FILENAME,
mime_type="application/json",
overwrite=True,
)
except Exception:
logger.warning(
"[browser] Failed to save browser state for session %s",
session_name,
exc_info=True,
)
async def _restore_browser_state(
session_name: str, user_id: str, session: ChatSession
) -> bool:
"""Restore browser state from workspace storage into a fresh daemon.
Best-effort: errors are logged but never propagate to the tool response.
Returns True on success (or no state to restore), False on failure.
"""
try:
manager = await get_manager(user_id, session.session_id)
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
if file_info is None:
return True # No saved state — first call or never saved
state_bytes = await manager.read_file(_STATE_FILENAME)
state = json.loads(state_bytes.decode("utf-8"))
url = state.get("url", "")
cookies = state.get("cookies", [])
local_storage = state.get("local_storage", {})
# Navigate first — starts daemon + sets the correct origin for cookies
if url:
# Validate the saved URL to prevent SSRF via stored redirect targets.
try:
await validate_url(url, trusted_origins=[])
except ValueError:
logger.warning(
"[browser] State restore: blocked SSRF URL %s", url[:200]
)
return False
rc, _, stderr = await _run(session_name, "open", url)
if rc != 0:
logger.warning(
"[browser] State restore: failed to open %s: %s",
url,
stderr[:200],
)
return False
await _run(session_name, "wait", "--load", "load", timeout=15)
# Restore cookies and localStorage in parallel via asyncio.gather.
# Semaphore caps concurrent subprocess spawns so we don't overwhelm the
# system when a session has hundreds of cookies.
sem = asyncio.Semaphore(_RESTORE_CONCURRENCY)
# Guard against pathological sites with thousands of cookies.
if len(cookies) > _MAX_RESTORE_COOKIES:
logger.debug(
"[browser] State restore: capping cookies from %d to %d",
len(cookies),
_MAX_RESTORE_COOKIES,
)
cookies = cookies[:_MAX_RESTORE_COOKIES]
async def _set_cookie(c: dict[str, Any]) -> None:
name = c.get("name", "")
value = c.get("value", "")
domain = c.get("domain", "")
path = c.get("path", "/")
if not (name and domain):
return
async with sem:
rc, _, stderr = await _run(
session_name,
"cookies",
"set",
name,
value,
"--domain",
domain,
"--path",
path,
timeout=5,
)
if rc != 0:
logger.debug(
"[browser] State restore: cookie set failed for %s: %s",
name,
stderr[:100],
)
async def _set_storage(key: str, val: object) -> None:
async with sem:
rc, _, stderr = await _run(
session_name,
"storage",
"local",
"set",
key,
str(val),
timeout=5,
)
if rc != 0:
logger.debug(
"[browser] State restore: localStorage set failed for %s: %s",
key,
stderr[:100],
)
await asyncio.gather(
*[_set_cookie(c) for c in cookies],
*[_set_storage(k, v) for k, v in local_storage.items()],
)
return True
except Exception:
logger.warning(
"[browser] Failed to restore browser state for session %s",
session_name,
exc_info=True,
)
return False
async def _ensure_session(
session_name: str, user_id: str, session: ChatSession
) -> None:
"""Ensure the local browser daemon has state. Restore from cloud if needed."""
if session_name in _alive_sessions:
return
async with _session_locks_mutex:
lock = _session_locks.setdefault(session_name, asyncio.Lock())
async with lock:
# Double-check after acquiring lock — another coroutine may have restored.
if session_name in _alive_sessions:
return
if await _has_local_session(session_name):
_alive_sessions.add(session_name)
return
if await _restore_browser_state(session_name, user_id, session):
_alive_sessions.add(session_name)
async def close_browser_session(session_name: str, user_id: str | None = None) -> None:
"""Shut down the local agent-browser daemon and clean up stored state.
Deletes ``._browser_state.json`` from workspace storage so cookies and
other credentials do not linger after the session is deleted.
Best-effort: errors are logged but never raised.
"""
_alive_sessions.discard(session_name)
async with _session_locks_mutex:
_session_locks.pop(session_name, None)
# Delete persisted browser state (cookies, localStorage) from workspace.
if user_id:
try:
manager = await get_manager(user_id, session_name)
file_info = await manager.get_file_info_by_path(_STATE_FILENAME)
if file_info is not None:
await manager.delete_file(file_info.id)
except Exception:
logger.debug(
"[browser] Failed to delete state file for session %s",
session_name,
exc_info=True,
)
try:
rc, _, stderr = await _run(session_name, "close", timeout=10)
if rc != 0:
logger.debug(
"[browser] close failed for session %s: %s",
session_name,
stderr[:200],
)
except Exception:
logger.debug(
"[browser] Exception closing browser session %s",
session_name,
exc_info=True,
)
# ---------------------------------------------------------------------------
# Tool: browser_navigate
# ---------------------------------------------------------------------------
class BrowserNavigateTool(BaseTool):
"""Navigate to a URL and return the page's interactive elements.
The browser session persists across tool calls within this Copilot session
(keyed to session_id), so cookies and auth state carry over. This enables
full login flows: navigate to login page → browser_act to fill credentials
→ browser_act to submit → browser_navigate to the target page.
"""
@property
def name(self) -> str:
return "browser_navigate"
@property
def description(self) -> str:
return (
"Navigate to a URL using a real browser. Returns an accessibility "
"tree snapshot listing the page's interactive elements with @ref IDs "
"(e.g. @e3) that can be used with browser_act. "
"Session persists — cookies and login state carry over between calls. "
"Use this (with browser_act) for multi-step interaction: login flows, "
"form filling, button clicks, or anything requiring page interaction. "
"For plain static pages, prefer web_fetch — no browser overhead. "
"For authenticated pages: navigate to the login page first, use browser_act "
"to fill credentials and submit, then navigate to the target page. "
"Note: for slow SPAs, the returned snapshot may reflect a partially-loaded "
"state. If elements seem missing, use browser_act with action='wait' and a "
"CSS selector or millisecond delay, then take a browser_screenshot to verify."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The HTTP/HTTPS URL to navigate to.",
},
"wait_for": {
"type": "string",
"enum": ["networkidle", "load", "domcontentloaded"],
"default": "networkidle",
"description": "When to consider navigation complete. Use 'networkidle' for SPAs (default).",
},
},
"required": ["url"],
}
@property
def requires_auth(self) -> bool:
return True
@property
def is_available(self) -> bool:
return shutil.which("agent-browser") is not None
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
"""Navigate to *url*, wait for the page to settle, and return a snapshot.
The snapshot is an accessibility-tree listing of interactive elements.
Note: for slow SPAs that never fully idle, the snapshot may reflect a
partially-loaded state (the wait is best-effort).
"""
url: str = (kwargs.get("url") or "").strip()
wait_for: str = kwargs.get("wait_for") or "networkidle"
session_name = session.session_id
if not url:
return ErrorResponse(
message="Please provide a URL to navigate to.",
error="missing_url",
session_id=session_name,
)
try:
await validate_url(url, trusted_origins=[])
except ValueError as e:
return ErrorResponse(
message=str(e),
error="blocked_url",
session_id=session_name,
)
# Restore browser state from cloud if this is a different pod
if user_id:
await _ensure_session(session_name, user_id, session)
# Navigate
rc, _, stderr = await _run(session_name, "open", url)
if rc != 0:
logger.warning(
"[browser_navigate] open failed for %s: %s", url, stderr[:300]
)
return ErrorResponse(
message="Failed to navigate to URL.",
error="navigation_failed",
session_id=session_name,
)
# Wait for page to settle (best-effort: some SPAs never reach networkidle)
wait_rc, _, wait_err = await _run(session_name, "wait", "--load", wait_for)
if wait_rc != 0:
logger.warning(
"[browser_navigate] wait(%s) failed: %s", wait_for, wait_err[:300]
)
# Get current title and URL in parallel
(_, title_out, _), (_, url_out, _) = await asyncio.gather(
_run(session_name, "get", "title"),
_run(session_name, "get", "url"),
)
snapshot = await _snapshot(session_name)
result = BrowserNavigateResponse(
message=f"Navigated to {url}",
url=url_out.strip() or url,
title=title_out.strip(),
snapshot=snapshot,
session_id=session_name,
)
# Persist browser state to cloud for cross-pod continuity
if user_id:
_fire_and_forget_save(session_name, user_id, session)
return result
# ---------------------------------------------------------------------------
# Tool: browser_act
# ---------------------------------------------------------------------------
_NO_TARGET_ACTIONS = frozenset({"back", "forward", "reload"})
_SCROLL_ACTIONS = frozenset({"scroll"})
_TARGET_ONLY_ACTIONS = frozenset({"click", "dblclick", "hover", "check", "uncheck"})
_TARGET_VALUE_ACTIONS = frozenset({"fill", "type", "select"})
# wait <selector|ms>: waits for a DOM element or a fixed delay (e.g. "1000" for 1 s)
_WAIT_ACTIONS = frozenset({"wait"})
class BrowserActTool(BaseTool):
"""Perform an action on the current browser page and return the updated snapshot.
Use @ref IDs from the snapshot returned by browser_navigate (e.g. '@e3').
The LLM orchestrates multi-step flows by chaining browser_navigate and
browser_act calls across turns of the Claude Agent SDK conversation.
"""
@property
def name(self) -> str:
return "browser_act"
@property
def description(self) -> str:
return (
"Interact with the current browser page. Use @ref IDs from the "
"snapshot (e.g. '@e3') to target elements. Returns an updated snapshot. "
"Supported actions: click, dblclick, fill, type, scroll, hover, press, "
"check, uncheck, select, wait, back, forward, reload. "
"fill clears the field before typing; type appends without clearing. "
"wait accepts a CSS selector (waits for element) or milliseconds string (e.g. '1000'). "
"Example login flow: fill @e1 with email → fill @e2 with password → "
"click @e3 (submit) → browser_navigate to the target page."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"click",
"dblclick",
"fill",
"type",
"scroll",
"hover",
"press",
"check",
"uncheck",
"select",
"wait",
"back",
"forward",
"reload",
],
"description": "The action to perform.",
},
"target": {
"type": "string",
"description": (
"Element to target. Use @ref from snapshot (e.g. '@e3'), "
"a CSS selector, or a text description. "
"Required for: click, dblclick, fill, type, hover, check, uncheck, select. "
"For wait: a CSS selector to wait for, or milliseconds as a string (e.g. '1000')."
),
},
"value": {
"type": "string",
"description": (
"For fill/type: the text to enter. "
"For press: key name (e.g. 'Enter', 'Tab', 'Control+a'). "
"For select: the option value to select."
),
},
"direction": {
"type": "string",
"enum": ["up", "down", "left", "right"],
"default": "down",
"description": "For scroll: direction to scroll.",
},
},
"required": ["action"],
}
@property
def requires_auth(self) -> bool:
return True
@property
def is_available(self) -> bool:
return shutil.which("agent-browser") is not None
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
"""Perform a browser action and return an updated page snapshot.
Validates the *action*/*target*/*value* combination, delegates to
``agent-browser``, waits for the page to settle, and returns the
accessibility-tree snapshot so the LLM can plan the next step.
"""
action: str = (kwargs.get("action") or "").strip()
target: str = (kwargs.get("target") or "").strip()
value: str = (kwargs.get("value") or "").strip()
direction: str = (kwargs.get("direction") or "down").strip()
session_name = session.session_id
if not action:
return ErrorResponse(
message="Please specify an action.",
error="missing_action",
session_id=session_name,
)
# Build the agent-browser command args
if action in _NO_TARGET_ACTIONS:
cmd_args = [action]
elif action in _SCROLL_ACTIONS:
cmd_args = ["scroll", direction]
elif action == "press":
if not value:
return ErrorResponse(
message="'press' requires a 'value' (key name, e.g. 'Enter').",
error="missing_value",
session_id=session_name,
)
cmd_args = ["press", value]
elif action in _TARGET_ONLY_ACTIONS:
if not target:
return ErrorResponse(
message=f"'{action}' requires a 'target' element.",
error="missing_target",
session_id=session_name,
)
cmd_args = [action, target]
elif action in _TARGET_VALUE_ACTIONS:
if not target or not value:
return ErrorResponse(
message=f"'{action}' requires both 'target' and 'value'.",
error="missing_params",
session_id=session_name,
)
cmd_args = [action, target, value]
elif action in _WAIT_ACTIONS:
if not target:
return ErrorResponse(
message=(
"'wait' requires a 'target': a CSS selector to wait for, "
"or milliseconds as a string (e.g. '1000')."
),
error="missing_target",
session_id=session_name,
)
cmd_args = ["wait", target]
else:
return ErrorResponse(
message=f"Unsupported action: {action}",
error="invalid_action",
session_id=session_name,
)
# Restore browser state from cloud if this is a different pod
if user_id:
await _ensure_session(session_name, user_id, session)
rc, _, stderr = await _run(session_name, *cmd_args)
if rc != 0:
logger.warning("[browser_act] %s failed: %s", action, stderr[:300])
return ErrorResponse(
message=f"Action '{action}' failed.",
error="action_failed",
session_id=session_name,
)
# Allow the page to settle after interaction (best-effort: SPAs may not idle)
settle_rc, _, settle_err = await _run(
session_name, "wait", "--load", "networkidle"
)
if settle_rc != 0:
logger.warning(
"[browser_act] post-action wait failed: %s", settle_err[:300]
)
snapshot = await _snapshot(session_name)
_, url_out, _ = await _run(session_name, "get", "url")
result = BrowserActResponse(
message=f"Performed '{action}'" + (f" on '{target}'" if target else ""),
action=action,
current_url=url_out.strip(),
snapshot=snapshot,
session_id=session_name,
)
# Persist browser state to cloud for cross-pod continuity
if user_id:
_fire_and_forget_save(session_name, user_id, session)
return result
# ---------------------------------------------------------------------------
# Tool: browser_screenshot
# ---------------------------------------------------------------------------
class BrowserScreenshotTool(BaseTool):
"""Capture a screenshot of the current browser page and save it to the workspace."""
@property
def name(self) -> str:
return "browser_screenshot"
@property
def description(self) -> str:
return (
"Take a screenshot of the current browser page and save it to the workspace. "
"IMPORTANT: After calling this tool, immediately call read_workspace_file "
"with the returned file_id to display the image inline to the user — "
"the screenshot is not visible until you do this. "
"With annotate=true (default), @ref labels are overlaid on interactive "
"elements, making it easy to see which @ref ID maps to which element on screen."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"annotate": {
"type": "boolean",
"default": True,
"description": "Overlay @ref labels on interactive elements (default: true).",
},
"filename": {
"type": "string",
"default": "screenshot.png",
"description": "Filename to save in the workspace.",
},
},
}
@property
def requires_auth(self) -> bool:
return True
@property
def is_available(self) -> bool:
return shutil.which("agent-browser") is not None
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
"""Capture a PNG screenshot and upload it to the workspace.
Handles string-to-bool coercion for *annotate* (OpenAI function-call
payloads sometimes deliver ``"true"``/``"false"`` as strings).
Returns a :class:`BrowserScreenshotResponse` with the workspace
``file_id`` the LLM should pass to ``read_workspace_file``.
"""
raw_annotate = kwargs.get("annotate", True)
if isinstance(raw_annotate, str):
annotate = raw_annotate.strip().lower() in {"1", "true", "yes", "on"}
else:
annotate = bool(raw_annotate)
filename: str = (kwargs.get("filename") or "screenshot.png").strip()
session_name = session.session_id
# Restore browser state from cloud if this is a different pod
if user_id:
await _ensure_session(session_name, user_id, session)
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".png")
os.close(tmp_fd)
try:
cmd_args = ["screenshot"]
if annotate:
cmd_args.append("--annotate")
cmd_args.append(tmp_path)
rc, _, stderr = await _run(session_name, *cmd_args)
if rc != 0:
logger.warning("[browser_screenshot] failed: %s", stderr[:300])
return ErrorResponse(
message="Failed to take screenshot.",
error="screenshot_failed",
session_id=session_name,
)
with open(tmp_path, "rb") as f:
png_bytes = f.read()
finally:
try:
os.unlink(tmp_path)
except OSError:
pass # Best-effort temp file cleanup; not critical if it fails.
# Upload to workspace so the user can view it
png_b64 = base64.b64encode(png_bytes).decode()
# Import here to avoid circular deps — workspace_files imports from .models
from .workspace_files import WorkspaceWriteResponse, WriteWorkspaceFileTool
write_resp = await WriteWorkspaceFileTool()._execute(
user_id=user_id,
session=session,
filename=filename,
content_base64=png_b64,
)
if not isinstance(write_resp, WorkspaceWriteResponse):
return ErrorResponse(
message="Screenshot taken but failed to save to workspace.",
error="workspace_write_failed",
session_id=session_name,
)
result = BrowserScreenshotResponse(
message=f"Screenshot saved to workspace as '{filename}'. Use read_workspace_file with file_id='{write_resp.file_id}' to retrieve it.",
file_id=write_resp.file_id,
filename=filename,
session_id=session_name,
)
# Persist browser state to cloud for cross-pod continuity
if user_id:
_fire_and_forget_save(session_name, user_id, session)
return result

View File

@@ -19,7 +19,6 @@ from .core import (
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
get_library_agent_by_id,
get_library_agents_by_ids,
get_library_agents_for_generation,
graph_to_json,
json_to_graph,
@@ -50,7 +49,6 @@ __all__ = [
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
"get_library_agent_by_id",
"get_library_agents_by_ids",
"get_library_agents_for_generation",
"get_user_message_for_error",
"graph_to_json",

View File

@@ -3,7 +3,6 @@
import logging
import re
import uuid
from collections.abc import Sequence
from typing import Any, NotRequired, TypedDict
from backend.data.db_accessors import graph_db, library_db, store_db
@@ -79,7 +78,7 @@ AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
def _to_dict_list(
agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None,
agents: list[AgentSummary] | list[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
@@ -191,36 +190,6 @@ async def get_library_agent_by_id(
get_library_agent_by_graph_id = get_library_agent_by_id
async def get_library_agents_by_ids(
user_id: str,
agent_ids: list[str],
) -> list[LibraryAgentSummary]:
"""Fetch multiple library agents by their IDs.
Args:
user_id: The user ID
agent_ids: List of agent IDs (can be graph_ids or library agent IDs)
Returns:
List of LibraryAgentSummary for found agents (silently skips not found)
"""
agents: list[LibraryAgentSummary] = []
for agent_id in agent_ids:
try:
agent = await get_library_agent_by_id(user_id, agent_id)
if agent:
agents.append(agent)
logger.debug(f"Fetched library agent by ID: {agent['name']}")
else:
logger.warning(f"Library agent not found for ID: {agent_id}")
except Exception as e:
logger.warning(f"Failed to fetch library agent {agent_id}: {e}")
continue
logger.info(f"Fetched {len(agents)}/{len(agent_ids)} library agents by ID")
return agents
async def get_library_agents_for_generation(
user_id: str,
search_query: str | None = None,
@@ -245,17 +214,10 @@ async def get_library_agents_for_generation(
Returns:
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
"""
search_term = search_query.strip() if search_query else None
if search_term and len(search_term) > 100:
raise ValueError(
f"Search query is too long ({len(search_term)} chars, max 100). "
f"Please use a shorter, more specific search term."
)
try:
response = await library_db().list_library_agents(
user_id=user_id,
search_term=search_term,
search_term=search_query,
page=1,
page_size=max_results,
include_executions=True,
@@ -309,16 +271,9 @@ async def search_marketplace_agents_for_generation(
Returns:
List of LibraryAgentSummary with full input/output schemas
"""
search_term = search_query.strip()
if len(search_term) > 100:
raise ValueError(
f"Search query is too long ({len(search_term)} chars, max 100). "
f"Please use a shorter, more specific search term."
)
try:
response = await store_db().get_store_agents(
search_query=search_term,
search_query=search_query,
page=1,
page_size=max_results,
)
@@ -469,7 +424,7 @@ def extract_search_terms_from_steps(
async def enrich_library_agents_from_steps(
user_id: str,
decomposition_result: DecompositionResult | dict[str, Any],
existing_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]],
existing_agents: list[AgentSummary] | list[dict[str, Any]],
exclude_graph_id: str | None = None,
include_marketplace: bool = True,
max_additional_results: int = 10,
@@ -493,7 +448,7 @@ async def enrich_library_agents_from_steps(
search_terms = extract_search_terms_from_steps(decomposition_result)
if not search_terms:
return list(existing_agents)
return existing_agents
existing_ids: set[str] = set()
existing_names: set[str] = set()
@@ -556,7 +511,7 @@ async def enrich_library_agents_from_steps(
async def decompose_goal(
description: str,
context: str = "",
library_agents: Sequence[AgentSummary] | None = None,
library_agents: list[AgentSummary] | None = None,
) -> DecompositionResult | None:
"""Break down a goal into steps or return clarifying questions.
@@ -584,16 +539,22 @@ async def decompose_goal(
async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: Sequence[AgentSummary] | Sequence[dict[str, Any]] | None = None,
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams
completion notification)
task_id: Task ID for async processing (enables Redis Streams persistence
and SSE delivery)
Returns:
Agent JSON dict, error dict {"type": "error", ...}, or None on error
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -601,9 +562,13 @@ async def generate_agent(
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents)
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
)
# Don't modify async response
if result and result.get("status") == "accepted":
return result
if result:
if isinstance(result, dict) and result.get("type") == "error":
return result
@@ -793,7 +758,9 @@ async def get_agent_as_json(
async def generate_agent_patch(
update_request: str,
current_agent: dict[str, Any],
library_agents: Sequence[AgentSummary] | None = None,
library_agents: list[AgentSummary] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Update an existing agent using natural language.
@@ -806,10 +773,12 @@ async def generate_agent_patch(
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
error dict {"type": "error", ...}, or None on error
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
Raises:
AgentGeneratorNotConfiguredError: If the external service is not configured.
@@ -820,6 +789,8 @@ async def generate_agent_patch(
update_request,
current_agent,
_to_dict_list(library_agents),
operation_id,
task_id,
)

View File

@@ -102,15 +102,10 @@ async def generate_agent_dummy(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy agent synchronously (blocks for 30s, returns agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator (sync mode): returning agent JSON after 30s"
)
"""Return dummy agent JSON after a simulated delay."""
logger.info("Using dummy agent generator for generate_agent (30s delay)")
await asyncio.sleep(30)
return _generate_dummy_agent_json()
@@ -120,16 +115,10 @@ async def generate_agent_patch_dummy(
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
session_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any]:
"""Return dummy patched agent synchronously (blocks for 30s, returns patched agent JSON).
Note: operation_id and session_id parameters are ignored - we always use synchronous mode.
"""
logger.info(
"Using dummy agent generator patch (sync mode): returning patched agent after 30s"
)
await asyncio.sleep(30)
"""Return dummy patched agent (returns the current agent with updated description)."""
logger.info("Using dummy agent generator for generate_agent_patch")
patched = current_agent.copy()
patched["description"] = (
f"{current_agent.get('description', '')} (updated: {update_request})"

View File

@@ -242,18 +242,24 @@ async def decompose_goal_external(
async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.
Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Agent JSON dict or error dict {"type": "error", ...} on error
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
"""
if _is_dummy_mode():
return await generate_agent_dummy(instructions, library_agents)
return await generate_agent_dummy(
instructions, library_agents, operation_id, task_id
)
client = _get_client()
@@ -261,9 +267,25 @@ async def generate_agent_external(
payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/generate-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
@@ -295,6 +317,8 @@ async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
operation_id: str | None = None,
task_id: str | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.
@@ -303,14 +327,14 @@ async def generate_agent_patch_external(
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
task_id: Task ID for async processing (enables Redis Streams callback)
Returns:
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
"""
if _is_dummy_mode():
return await generate_agent_patch_dummy(
update_request, current_agent, library_agents
update_request, current_agent, library_agents, operation_id, task_id
)
client = _get_client()
@@ -322,9 +346,25 @@ async def generate_agent_patch_external(
}
if library_agents:
payload["library_agents"] = library_agents
if operation_id and task_id:
payload["operation_id"] = operation_id
payload["task_id"] = task_id
try:
response = await client.post("/api/update-agent", json=payload)
# Handle 202 Accepted for async processing
if response.status_code == 202:
logger.info(
f"Agent Generator accepted async update request "
f"(operation_id={operation_id}, task_id={task_id})"
)
return {
"status": "accepted",
"operation_id": operation_id,
"task_id": task_id,
}
response.raise_for_status()
data = response.json()
@@ -379,8 +419,6 @@ async def customize_template_external(
template_agent: The template agent JSON to customize
modification_request: Natural language description of customizations
context: Additional context (e.g., answers to previous questions)
operation_id: Operation ID for async processing (enables Redis Streams callback)
session_id: Session ID for async processing (enables Redis Streams callback)
Returns:
Customized agent JSON, clarifying questions dict, or error dict on error

View File

@@ -5,7 +5,7 @@ import re
from datetime import datetime, timedelta, timezone
from typing import Any
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, field_validator
from backend.api.features.library.model import LibraryAgent
from backend.copilot.model import ChatSession
@@ -13,7 +13,6 @@ from backend.data.db_accessors import execution_db, library_db
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
from .base import BaseTool
from .execution_utils import TERMINAL_STATUSES, wait_for_execution
from .models import (
AgentOutputResponse,
ErrorResponse,
@@ -34,7 +33,6 @@ class AgentOutputInput(BaseModel):
store_slug: str = ""
execution_id: str = ""
run_time: str = "latest"
wait_if_running: int = Field(default=0, ge=0, le=300)
@field_validator(
"agent_name",
@@ -118,11 +116,6 @@ class AgentOutputTool(BaseTool):
Select which run to retrieve using:
- execution_id: Specific execution ID
- run_time: 'latest' (default), 'yesterday', 'last week', or ISO date 'YYYY-MM-DD'
Wait for completion (optional):
- wait_if_running: Max seconds to wait if execution is still running (0-300).
If the execution is running/queued, waits up to this many seconds for completion.
Returns current status on timeout. If already finished, returns immediately.
"""
@property
@@ -152,13 +145,6 @@ class AgentOutputTool(BaseTool):
"Time filter: 'latest', 'yesterday', 'last week', or 'YYYY-MM-DD'"
),
},
"wait_if_running": {
"type": "integer",
"description": (
"Max seconds to wait if execution is still running (0-300). "
"If running, waits for completion. Returns current state on timeout."
),
},
},
"required": [],
}
@@ -238,14 +224,10 @@ class AgentOutputTool(BaseTool):
execution_id: str | None,
time_start: datetime | None,
time_end: datetime | None,
include_running: bool = False,
) -> tuple[GraphExecution | None, list[GraphExecutionMeta], str | None]:
"""
Fetch execution(s) based on filters.
Returns (single_execution, available_executions_meta, error_message).
Args:
include_running: If True, also look for running/queued executions (for waiting)
"""
exec_db = execution_db()
@@ -260,25 +242,11 @@ class AgentOutputTool(BaseTool):
return None, [], f"Execution '{execution_id}' not found"
return execution, [], None
# Determine which statuses to query
statuses = [ExecutionStatus.COMPLETED]
if include_running:
statuses.extend(
[
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
ExecutionStatus.REVIEW,
ExecutionStatus.FAILED,
ExecutionStatus.TERMINATED,
]
)
# Get executions with time filters
# Get completed executions with time filters
executions = await exec_db.get_graph_executions(
graph_id=graph_id,
user_id=user_id,
statuses=statuses,
statuses=[ExecutionStatus.COMPLETED],
created_time_gte=time_start,
created_time_lte=time_end,
limit=10,
@@ -345,33 +313,10 @@ class AgentOutputTool(BaseTool):
for e in available_executions[:5]
]
# Build appropriate message based on execution status
if execution.status == ExecutionStatus.COMPLETED:
message = f"Found execution outputs for agent '{agent.name}'"
elif execution.status == ExecutionStatus.FAILED:
message = f"Execution for agent '{agent.name}' failed"
elif execution.status == ExecutionStatus.TERMINATED:
message = f"Execution for agent '{agent.name}' was terminated"
elif execution.status == ExecutionStatus.REVIEW:
message = (
f"Execution for agent '{agent.name}' is awaiting human review. "
"The user needs to approve it before it can continue."
)
elif execution.status in (
ExecutionStatus.RUNNING,
ExecutionStatus.QUEUED,
ExecutionStatus.INCOMPLETE,
):
message = (
f"Execution for agent '{agent.name}' is still {execution.status.value}. "
"Results may be incomplete. Use wait_if_running to wait for completion."
)
else:
message = f"Found execution for agent '{agent.name}' (status: {execution.status.value})"
message = f"Found execution outputs for agent '{agent.name}'"
if len(available_executions) > 1:
message += (
f" Showing latest of {len(available_executions)} matching executions."
f". Showing latest of {len(available_executions)} matching executions."
)
return AgentOutputResponse(
@@ -486,17 +431,13 @@ class AgentOutputTool(BaseTool):
# Parse time expression
time_start, time_end = parse_time_expression(input_data.run_time)
# Check if we should wait for running executions
wait_timeout = input_data.wait_if_running
# Fetch execution(s) - include running if we're going to wait
# Fetch execution(s)
execution, available_executions, exec_error = await self._get_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=input_data.execution_id or None,
time_start=time_start,
time_end=time_end,
include_running=wait_timeout > 0,
)
if exec_error:
@@ -505,17 +446,4 @@ class AgentOutputTool(BaseTool):
session_id=session_id,
)
# If we have an execution that's still running and we should wait
if execution and wait_timeout > 0 and execution.status not in TERMINAL_STATUSES:
logger.info(
f"Execution {execution.id} is {execution.status}, "
f"waiting up to {wait_timeout}s for completion"
)
execution = await wait_for_execution(
user_id=user_id,
graph_id=agent.graph_id,
execution_id=execution.id,
timeout_seconds=wait_timeout,
)
return self._build_response(agent, execution, available_executions, session_id)

View File

@@ -1,13 +1,8 @@
"""Shared agent search functionality for find_agent and find_library_agent tools."""
from __future__ import annotations
import logging
import re
from typing import TYPE_CHECKING, Literal
if TYPE_CHECKING:
from backend.api.features.library.model import LibraryAgent
from typing import Literal
from backend.data.db_accessors import library_db, store_db
from backend.util.exceptions import DatabaseError, NotFoundError
@@ -29,24 +24,94 @@ _UUID_PATTERN = re.compile(
re.IGNORECASE,
)
# Keywords that should be treated as "list all" rather than a literal search
_LIST_ALL_KEYWORDS = frozenset({"all", "*", "everything", "any", ""})
def _is_uuid(text: str) -> bool:
"""Check if text is a valid UUID v4."""
return bool(_UUID_PATTERN.match(text.strip()))
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
"""Fetch a library agent by ID (library agent ID or graph_id).
Tries multiple lookup strategies:
1. First by graph_id (AgentGraph primary key)
2. Then by library agent ID (LibraryAgent primary key)
Args:
user_id: The user ID
agent_id: The ID to look up (can be graph_id or library agent ID)
Returns:
AgentInfo if found, None otherwise
"""
lib_db = library_db()
try:
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by graph_id {agent_id}: {e}",
exc_info=True,
)
try:
agent = await lib_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None
async def search_agents(
query: str,
source: SearchSource,
session_id: str | None = None,
session_id: str | None,
user_id: str | None = None,
) -> ToolResponseBase:
"""
Search for agents in marketplace or user library.
For library searches, keywords like "all", "*", "everything", or an empty
query will list all agents without filtering.
Args:
query: Search query string. Special keywords list all library agents.
query: Search query string
source: "marketplace" or "library"
session_id: Chat session ID
user_id: User ID (required for library search)
@@ -54,11 +119,7 @@ async def search_agents(
Returns:
AgentsFoundResponse, NoResultsResponse, or ErrorResponse
"""
# Normalize list-all keywords to empty string for library searches
if source == "library" and query.lower().strip() in _LIST_ALL_KEYWORDS:
query = ""
if source == "marketplace" and not query:
if not query:
return ErrorResponse(
message="Please provide a search query", session_id=session_id
)
@@ -98,18 +159,28 @@ async def search_agents(
logger.info(f"Found agent by direct ID lookup: {agent.name}")
if not agents:
search_term = query or None
logger.info(
f"{'Listing all agents in' if not query else 'Searching'} "
f"user library{'' if not query else f' for: {query}'}"
)
logger.info(f"Searching user library for: {query}")
results = await library_db().list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=search_term,
page_size=50 if not query else 10,
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(_library_agent_to_info(agent))
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
)
logger.info(f"Found {len(agents)} agents in {source}")
except NotFoundError:
pass
@@ -122,62 +193,42 @@ async def search_agents(
)
if not agents:
if source == "marketplace":
suggestions = [
suggestions = (
[
"Try more general terms",
"Browse categories in the marketplace",
"Check spelling",
]
no_results_msg = (
f"No agents found matching '{query}'. Let the user know they can "
"try different keywords or browse the marketplace. Also let them "
"know you can create a custom agent for them based on their needs."
)
elif not query:
# User asked to list all but library is empty
suggestions = [
"Browse the marketplace to find and add agents",
"Use find_agent to search the marketplace",
]
no_results_msg = (
"Your library is empty. Let the user know they can browse the "
"marketplace to find agents, or you can create a custom agent "
"for them based on their needs."
)
else:
suggestions = [
if source == "marketplace"
else [
"Try different keywords",
"Use find_agent to search the marketplace",
"Check your library at /library",
]
no_results_msg = (
f"No agents matching '{query}' found in your library. Let the "
"user know you can create a custom agent for them based on "
"their needs."
)
)
no_results_msg = (
f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs."
if source == "marketplace"
else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs."
)
return NoResultsResponse(
message=no_results_msg, session_id=session_id, suggestions=suggestions
)
if source == "marketplace":
title = (
f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} for '{query}'"
)
elif not query:
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library"
else:
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} in your library for '{query}'"
title = f"Found {len(agents)} agent{'s' if len(agents) != 1 else ''} "
title += (
f"for '{query}'"
if source == "marketplace"
else f"in your library for '{query}'"
)
message = (
"Now you have found some options for the user to choose from. "
"You can add a link to a recommended agent at: /marketplace/agent/agent_id "
"Please ask the user if they would like to use any of these agents. "
"Let the user know we can create a custom agent for them based on their needs."
"Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs."
if source == "marketplace"
else "Found agents in the user's library. You can provide a link to view "
"an agent at: /library/agents/{agent_id}. Use agent_output to get "
"execution results, or run_agent to execute. Let the user know we can "
"create a custom agent for them based on their needs."
else "Found agents in the user's library. You can provide a link to view an agent at: "
"/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs."
)
return AgentsFoundResponse(
@@ -187,67 +238,3 @@ async def search_agents(
count=len(agents),
session_id=session_id,
)
def _is_uuid(text: str) -> bool:
"""Check if text is a valid UUID v4."""
return bool(_UUID_PATTERN.match(text.strip()))
def _library_agent_to_info(agent: LibraryAgent) -> AgentInfo:
"""Convert a library agent model to an AgentInfo."""
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
"""Fetch a library agent by ID (library agent ID or graph_id).
Tries multiple lookup strategies:
1. First by graph_id (AgentGraph primary key)
2. Then by library agent ID (LibraryAgent primary key)
"""
lib_db = library_db()
try:
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return _library_agent_to_info(agent)
except NotFoundError:
logger.debug(f"Library agent not found by graph_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by graph_id {agent_id}: {e}",
exc_info=True,
)
try:
agent = await lib_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return _library_agent_to_info(agent)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)
return None

View File

@@ -1,6 +1,5 @@
"""Base classes and shared utilities for chat tools."""
import json
import logging
from typing import Any
@@ -8,98 +7,11 @@ from openai.types.chat import ChatCompletionToolParam
from backend.copilot.model import ChatSession
from backend.copilot.response_model import StreamToolOutputAvailable
from backend.data.db_accessors import workspace_db
from backend.util.truncate import truncate
from backend.util.workspace import WorkspaceManager
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
logger = logging.getLogger(__name__)
# Persist full tool output to workspace when it exceeds this threshold.
# Must be below _MAX_TOOL_OUTPUT_SIZE (100K) in response_model.py so we
# capture the data before model_post_init middle-out truncation discards it.
_LARGE_OUTPUT_THRESHOLD = 80_000
# Character budget for the middle-out preview. The total preview + wrapper
# must stay below BOTH:
# - _MAX_TOOL_OUTPUT_SIZE (100K) in response_model.py (our own truncation)
# - Claude SDK's ~100 KB tool-result spill-to-disk threshold
# to avoid double truncation/spilling. 95K + ~300 wrapper = ~95.3K, under both.
_PREVIEW_CHARS = 95_000
# Fields whose values are binary/base64 data — truncating them produces
# garbage, so we replace them with a human-readable size summary instead.
_BINARY_FIELD_NAMES = {"content_base64"}
def _summarize_binary_fields(raw_json: str) -> str:
"""Replace known binary fields with a size summary so truncate() doesn't
produce garbled base64 in the middle-out preview."""
try:
data = json.loads(raw_json)
except (json.JSONDecodeError, TypeError):
return raw_json
if not isinstance(data, dict):
return raw_json
changed = False
for key in _BINARY_FIELD_NAMES:
if key in data and isinstance(data[key], str) and len(data[key]) > 1_000:
byte_size = len(data[key]) * 3 // 4 # approximate decoded size
data[key] = f"<binary, ~{byte_size:,} bytes>"
changed = True
return json.dumps(data, ensure_ascii=False) if changed else raw_json
async def _persist_and_summarize(
raw_output: str,
user_id: str,
session_id: str,
tool_call_id: str,
) -> str:
"""Persist full output to workspace and return a middle-out preview with retrieval instructions.
On failure, returns the original ``raw_output`` unchanged so that the
existing ``model_post_init`` middle-out truncation handles it as before.
"""
file_path = f"tool-outputs/{tool_call_id}.json"
try:
workspace = await workspace_db().get_or_create_workspace(user_id)
manager = WorkspaceManager(user_id, workspace.id, session_id)
await manager.write_file(
content=raw_output.encode("utf-8"),
filename=f"{tool_call_id}.json",
path=file_path,
mime_type="application/json",
overwrite=True,
)
except Exception:
logger.warning(
"Failed to persist large tool output for %s",
tool_call_id,
exc_info=True,
)
return raw_output # fall back to normal truncation
total = len(raw_output)
preview = truncate(_summarize_binary_fields(raw_output), _PREVIEW_CHARS)
retrieval = (
f"\nFull output ({total:,} chars) saved to workspace. "
f"Use read_workspace_file("
f'path="{file_path}", offset=<char_offset>, length=50000) '
f"to read any section."
)
return (
f'<tool-output-truncated total_chars={total} path="{file_path}">\n'
f"{preview}\n"
f"{retrieval}\n"
f"</tool-output-truncated>"
)
class BaseTool:
"""Base class for all chat tools."""
@@ -125,14 +37,14 @@ class BaseTool:
return False
@property
def is_available(self) -> bool:
"""Whether this tool is available in the current environment.
def is_long_running(self) -> bool:
"""Whether this tool is long-running and should execute in background.
Override to check required env vars, binaries, or other dependencies.
Unavailable tools are excluded from the LLM tool list so the model is
never offered an option that will immediately fail.
Long-running tools (like agent generation) are executed via background
tasks to survive SSE disconnections. The result is persisted to chat
history and visible when the user refreshes.
"""
return True
return False
def as_openai_tool(self) -> ChatCompletionToolParam:
"""Convert to OpenAI tool format."""
@@ -155,7 +67,7 @@ class BaseTool:
"""Execute the tool with authentication check.
Args:
user_id: User ID (None for anonymous users)
user_id: User ID (may be anonymous like "anon_123")
session_id: Chat session ID
**kwargs: Tool-specific parameters
@@ -179,21 +91,10 @@ class BaseTool:
try:
result = await self._execute(user_id, session, **kwargs)
raw_output = result.model_dump_json()
if (
len(raw_output) > _LARGE_OUTPUT_THRESHOLD
and user_id
and session.session_id
):
raw_output = await _persist_and_summarize(
raw_output, user_id, session.session_id, tool_call_id
)
return StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=self.name,
output=raw_output,
output=result.model_dump_json(),
)
except Exception as e:
logger.error(f"Error in {self.name}: {e}", exc_info=True)

View File

@@ -1,194 +0,0 @@
"""Tests for BaseTool large-output persistence in execute()."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.copilot.tools.base import (
_LARGE_OUTPUT_THRESHOLD,
BaseTool,
_persist_and_summarize,
_summarize_binary_fields,
)
from backend.copilot.tools.models import ResponseType, ToolResponseBase
class _HugeOutputTool(BaseTool):
"""Fake tool that returns an arbitrarily large output."""
def __init__(self, output_size: int) -> None:
self._output_size = output_size
@property
def name(self) -> str:
return "huge_output_tool"
@property
def description(self) -> str:
return "Returns a huge output"
@property
def parameters(self) -> dict:
return {"type": "object", "properties": {}}
async def _execute(self, user_id, session, **kwargs) -> ToolResponseBase:
return ToolResponseBase(
type=ResponseType.ERROR,
message="x" * self._output_size,
)
# ---------------------------------------------------------------------------
# _persist_and_summarize
# ---------------------------------------------------------------------------
class TestPersistAndSummarize:
@pytest.mark.asyncio
async def test_returns_middle_out_preview_with_retrieval_instructions(self):
raw = "A" * 200_000
mock_workspace = MagicMock()
mock_workspace.id = "ws-1"
mock_db = AsyncMock()
mock_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
mock_manager = AsyncMock()
with (
patch("backend.copilot.tools.base.workspace_db", return_value=mock_db),
patch(
"backend.copilot.tools.base.WorkspaceManager",
return_value=mock_manager,
),
):
result = await _persist_and_summarize(raw, "user-1", "session-1", "tc-123")
assert "<tool-output-truncated" in result
assert "</tool-output-truncated>" in result
assert "total_chars=200000" in result
assert 'path="tool-outputs/tc-123.json"' in result
assert "read_workspace_file" in result
# Middle-out sentinel from truncate()
assert "omitted" in result
# Total result is much shorter than the raw output
assert len(result) < len(raw)
# Verify write_file was called with full content
mock_manager.write_file.assert_awaited_once()
call_kwargs = mock_manager.write_file.call_args
assert call_kwargs.kwargs["content"] == raw.encode("utf-8")
assert call_kwargs.kwargs["path"] == "tool-outputs/tc-123.json"
@pytest.mark.asyncio
async def test_fallback_on_workspace_error(self):
"""If workspace write fails, return raw output for normal truncation."""
raw = "B" * 200_000
mock_db = AsyncMock()
mock_db.get_or_create_workspace = AsyncMock(side_effect=RuntimeError("boom"))
with patch("backend.copilot.tools.base.workspace_db", return_value=mock_db):
result = await _persist_and_summarize(raw, "user-1", "session-1", "tc-fail")
assert result == raw # unchanged — fallback to normal truncation
# ---------------------------------------------------------------------------
# BaseTool.execute — integration with persistence
# ---------------------------------------------------------------------------
class TestBaseToolExecuteLargeOutput:
@pytest.mark.asyncio
async def test_small_output_not_persisted(self):
"""Outputs under the threshold go through without persistence."""
tool = _HugeOutputTool(output_size=100)
session = MagicMock()
session.session_id = "s-1"
with patch(
"backend.copilot.tools.base._persist_and_summarize",
new_callable=AsyncMock,
) as persist_mock:
result = await tool.execute("user-1", session, "tc-small")
persist_mock.assert_not_awaited()
assert "<tool-output-truncated" not in str(result.output)
@pytest.mark.asyncio
async def test_large_output_persisted(self):
"""Outputs over the threshold trigger persistence + preview."""
tool = _HugeOutputTool(output_size=_LARGE_OUTPUT_THRESHOLD + 10_000)
session = MagicMock()
session.session_id = "s-1"
mock_workspace = MagicMock()
mock_workspace.id = "ws-1"
mock_db = AsyncMock()
mock_db.get_or_create_workspace = AsyncMock(return_value=mock_workspace)
mock_manager = AsyncMock()
with (
patch("backend.copilot.tools.base.workspace_db", return_value=mock_db),
patch(
"backend.copilot.tools.base.WorkspaceManager",
return_value=mock_manager,
),
):
result = await tool.execute("user-1", session, "tc-big")
assert "<tool-output-truncated" in str(result.output)
assert "read_workspace_file" in str(result.output)
mock_manager.write_file.assert_awaited_once()
@pytest.mark.asyncio
async def test_no_persistence_without_user_id(self):
"""Anonymous users skip persistence (no workspace)."""
tool = _HugeOutputTool(output_size=_LARGE_OUTPUT_THRESHOLD + 10_000)
session = MagicMock()
session.session_id = "s-1"
# user_id=None → should not attempt persistence
with patch(
"backend.copilot.tools.base._persist_and_summarize",
new_callable=AsyncMock,
) as persist_mock:
result = await tool.execute(None, session, "tc-anon")
persist_mock.assert_not_awaited()
# Output is set but not wrapped in <tool-output-truncated> tags
# (it will be middle-out truncated by model_post_init instead)
assert "<tool-output-truncated" not in str(result.output)
# ---------------------------------------------------------------------------
# _summarize_binary_fields
# ---------------------------------------------------------------------------
class TestSummarizeBinaryFields:
def test_replaces_large_content_base64(self):
import json
data = {"content_base64": "A" * 10_000, "name": "file.png"}
result = json.loads(_summarize_binary_fields(json.dumps(data)))
assert result["name"] == "file.png"
assert "<binary" in result["content_base64"]
assert "bytes>" in result["content_base64"]
def test_preserves_small_content_base64(self):
import json
data = {"content_base64": "AQID", "name": "tiny.bin"}
result_str = _summarize_binary_fields(json.dumps(data))
result = json.loads(result_str)
assert result["content_base64"] == "AQID" # unchanged
def test_non_json_passthrough(self):
raw = "not json at all"
assert _summarize_binary_fields(raw) == raw
def test_no_binary_fields_unchanged(self):
import json
data = {"message": "hello", "type": "info"}
raw = json.dumps(data)
assert _summarize_binary_fields(raw) == raw

View File

@@ -1,30 +1,19 @@
"""Bash execution tool — run shell commands on E2B or in a bubblewrap sandbox.
"""Bash execution tool — run shell commands in a bubblewrap sandbox.
When an E2B sandbox is available in the current execution context the command
runs directly on the remote E2B cloud environment. This means:
Full Bash scripting is allowed (loops, conditionals, pipes, functions, etc.).
Safety comes from OS-level isolation (bubblewrap): only system dirs visible
read-only, writable workspace only, clean env, no network.
- **Persistent filesystem**: files survive across turns via HTTP-based sync
with the sandbox's ``/home/user`` directory (E2B files API), shared with
SDK Read/Write/Edit tools.
- **Full internet access**: E2B sandboxes have unrestricted outbound network.
- **Execution isolation**: E2B provides a fresh, containerised Linux environment.
When E2B is *not* configured the tool falls back to **bubblewrap** (bwrap):
OS-level isolation with a whitelist-only filesystem, no network, and resource
limits. Requires bubblewrap to be installed (Linux only).
Requires bubblewrap (``bwrap``) — the tool is disabled when bwrap is not
available (e.g. macOS development).
"""
import logging
import shlex
from typing import Any
from e2b import AsyncSandbox
from e2b.exceptions import TimeoutException
from backend.copilot.model import ChatSession
from .base import BaseTool
from .e2b_sandbox import E2B_WORKDIR
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
@@ -32,7 +21,7 @@ logger = logging.getLogger(__name__)
class BashExecTool(BaseTool):
"""Execute Bash commands on E2B or in a bubblewrap sandbox."""
"""Execute Bash commands in a bubblewrap sandbox."""
@property
def name(self) -> str:
@@ -40,16 +29,28 @@ class BashExecTool(BaseTool):
@property
def description(self) -> str:
if not has_full_sandbox():
return (
"Bash execution is DISABLED — bubblewrap sandbox is not "
"available on this platform. Do not call this tool."
)
return (
"Execute a Bash command or script. "
"Execute a Bash command or script in a bubblewrap sandbox. "
"Full Bash scripting is supported (loops, conditionals, pipes, "
"functions, etc.). "
"The working directory is shared with the SDK Read/Write/Edit/Glob/Grep "
"tools — files created by either are immediately visible to both. "
"The sandbox shares the same working directory as the SDK Read/Write "
"tools — files created by either are accessible to both. "
"SECURITY: Only system directories (/usr, /bin, /lib, /etc) are "
"visible read-only, the per-session workspace is the only writable "
"path, environment variables are wiped (no secrets), all network "
"access is blocked at the kernel level, and resource limits are "
"enforced (max 64 processes, 512MB memory, 50MB file size). "
"Application code, configs, and other directories are NOT accessible. "
"To fetch web content, use the web_fetch tool instead. "
"Execution is killed after the timeout (default 30s, max 120s). "
"Returns stdout and stderr. "
"Useful for file manipulation, data processing, running scripts, "
"and installing packages."
"Useful for file manipulation, data processing with Unix tools "
"(grep, awk, sed, jq, etc.), and running shell scripts."
)
@property
@@ -84,8 +85,15 @@ class BashExecTool(BaseTool):
) -> ToolResponseBase:
session_id = session.session_id if session else None
if not has_full_sandbox():
return ErrorResponse(
message="bash_exec requires bubblewrap sandbox (Linux only).",
error="sandbox_unavailable",
session_id=session_id,
)
command: str = (kwargs.get("command") or "").strip()
timeout: int = int(kwargs.get("timeout", 30))
timeout: int = kwargs.get("timeout", 30)
if not command:
return ErrorResponse(
@@ -94,21 +102,6 @@ class BashExecTool(BaseTool):
session_id=session_id,
)
# E2B path: run on remote cloud sandbox when available.
from backend.copilot.sdk.tool_adapter import get_current_sandbox
sandbox = get_current_sandbox()
if sandbox is not None:
return await self._execute_on_e2b(sandbox, command, timeout, session_id)
# Bubblewrap fallback: local isolated execution.
if not has_full_sandbox():
return ErrorResponse(
message="bash_exec requires bubblewrap sandbox (Linux only).",
error="sandbox_unavailable",
session_id=session_id,
)
workspace = get_workspace_dir(session_id or "default")
stdout, stderr, exit_code, timed_out = await run_sandboxed(
@@ -129,43 +122,3 @@ class BashExecTool(BaseTool):
timed_out=timed_out,
session_id=session_id,
)
async def _execute_on_e2b(
self,
sandbox: AsyncSandbox,
command: str,
timeout: int,
session_id: str | None,
) -> ToolResponseBase:
"""Execute *command* on the E2B sandbox via commands.run()."""
try:
result = await sandbox.commands.run(
f"bash -c {shlex.quote(command)}",
cwd=E2B_WORKDIR,
timeout=timeout,
envs={"PATH": "/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin"},
)
return BashExecResponse(
message=f"Command executed on E2B (exit {result.exit_code})",
stdout=result.stdout or "",
stderr=result.stderr or "",
exit_code=result.exit_code,
timed_out=False,
session_id=session_id,
)
except Exception as exc:
if isinstance(exc, TimeoutException):
return BashExecResponse(
message="Execution timed out",
stdout="",
stderr=f"Timed out after {timeout}s",
exit_code=-1,
timed_out=True,
session_id=session_id,
)
logger.error("[E2B] bash_exec failed: %s", exc, exc_info=True)
return ErrorResponse(
message=f"E2B execution failed: {exc}",
error="e2b_execution_error",
session_id=session_id,
)

View File

@@ -0,0 +1,124 @@
"""CheckOperationStatusTool — query the status of a long-running operation."""
import logging
from typing import Any
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
logger = logging.getLogger(__name__)
class OperationStatusResponse(ToolResponseBase):
"""Response for check_operation_status tool."""
type: ResponseType = ResponseType.OPERATION_STATUS
task_id: str
operation_id: str
status: str # "running", "completed", "failed"
tool_name: str | None = None
message: str = ""
class CheckOperationStatusTool(BaseTool):
"""Check the status of a long-running operation (create_agent, edit_agent, etc.).
The CoPilot uses this tool to report back to the user whether an
operation that was started earlier has completed, failed, or is still
running.
"""
@property
def name(self) -> str:
return "check_operation_status"
@property
def description(self) -> str:
return (
"Check the current status of a long-running operation such as "
"create_agent or edit_agent. Accepts either an operation_id or "
"task_id from a previous operation_started response. "
"Returns the current status: running, completed, or failed."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"operation_id": {
"type": "string",
"description": (
"The operation_id from an operation_started response."
),
},
"task_id": {
"type": "string",
"description": (
"The task_id from an operation_started response. "
"Used as fallback if operation_id is not provided."
),
},
},
"required": [],
}
@property
def requires_auth(self) -> bool:
return False
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs,
) -> ToolResponseBase:
from backend.copilot import stream_registry
operation_id = (kwargs.get("operation_id") or "").strip()
task_id = (kwargs.get("task_id") or "").strip()
if not operation_id and not task_id:
return ErrorResponse(
message="Please provide an operation_id or task_id.",
error="missing_parameter",
)
task = None
if operation_id:
task = await stream_registry.find_task_by_operation_id(operation_id)
if task is None and task_id:
task = await stream_registry.get_task(task_id)
if task is None:
# Task not in Redis — it may have already expired (TTL).
# Check conversation history for the result instead.
return ErrorResponse(
message=(
"Operation not found — it may have already completed and "
"expired from the status tracker. Check the conversation "
"history for the result."
),
error="not_found",
)
status_messages = {
"running": (
f"The {task.tool_name or 'operation'} is still running. "
"Please wait for it to complete."
),
"completed": (
f"The {task.tool_name or 'operation'} has completed successfully."
),
"failed": f"The {task.tool_name or 'operation'} has failed.",
}
return OperationStatusResponse(
task_id=task.task_id,
operation_id=task.operation_id,
status=task.status,
tool_name=task.tool_name,
message=status_messages.get(task.status, f"Status: {task.status}"),
)

View File

@@ -10,6 +10,7 @@ from .agent_generator import (
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
@@ -17,6 +18,7 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -38,16 +40,17 @@ class CreateAgentTool(BaseTool):
def description(self) -> str:
return (
"Create a new agent workflow from a natural language description. "
"First generates a preview, then saves to library if save=true. "
"\n\nIMPORTANT: Before calling this tool, search for relevant existing agents "
"using find_library_agent that could be used as building blocks. "
"Pass their IDs in the library_agent_ids parameter so the generator can compose them."
"First generates a preview, then saves to library if save=true."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -67,15 +70,6 @@ class CreateAgentTool(BaseTool):
"Include any preferences or constraints mentioned by the user."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks. "
"Search for relevant agents using find_library_agent first, "
"then pass their IDs here so they can be composed into the new agent."
),
},
"save": {
"type": "boolean",
"description": (
@@ -103,14 +97,12 @@ class CreateAgentTool(BaseTool):
"""
description = kwargs.get("description", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
session_id = session.session_id if session else None
logger.info(
f"[AGENT_CREATE_DEBUG] START - description_len={len(description)}, "
f"library_agent_ids={library_agent_ids}, save={save}, user_id={user_id}, session_id={session_id}"
)
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not description:
return ErrorResponse(
@@ -119,34 +111,25 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id and library_agent_ids:
if user_id:
try:
from .agent_generator import get_library_agents_by_ids
library_agents = await get_library_agents_by_ids(
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
agent_ids=library_agent_ids,
search_query=description,
include_marketplace=True,
)
logger.debug(
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
logger.warning(f"Failed to fetch library agents: {e}")
try:
decomposition_result = await decompose_goal(
description, context, library_agents
)
logger.info(
f"[AGENT_CREATE_DEBUG] DECOMPOSE - type={decomposition_result.get('type') if decomposition_result else None}, "
f"session_id={session_id}"
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
@@ -247,17 +230,10 @@ class CreateAgentTool(BaseTool):
agent_json = await generate_agent(
decomposition_result,
library_agents,
)
logger.info(
f"[AGENT_CREATE_DEBUG] GENERATE - "
f"success={agent_json is not None}, "
f"is_error={isinstance(agent_json, dict) and agent_json.get('type') == 'error'}, "
f"session_id={session_id}"
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - AgentGeneratorNotConfigured during generation, session_id={session_id}"
)
return ErrorResponse(
message=(
"Agent generation is not available. "
@@ -300,20 +276,25 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if agent_json.get("status") == "accepted":
logger.info(
f"Agent generation delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent generation started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
agent_name = agent_json.get("name", "Generated Agent")
agent_description = agent_json.get("description", "")
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))
logger.info(
f"[AGENT_CREATE_DEBUG] AGENT_JSON - name={agent_name}, "
f"nodes={node_count}, links={link_count}, save={save}, session_id={session_id}"
)
if not save:
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentPreviewResponse, session_id={session_id}"
)
return AgentPreviewResponse(
message=(
f"I've generated an agent called '{agent_name}' with {node_count} blocks. "
@@ -339,13 +320,6 @@ class CreateAgentTool(BaseTool):
agent_json, user_id
)
logger.info(
f"[AGENT_CREATE_DEBUG] SAVED - graph_id={created_graph.id}, "
f"library_agent_id={library_agent.id}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - AgentSavedResponse, session_id={session_id}"
)
return AgentSavedResponse(
message=f"Agent '{created_graph.name}' has been saved to your library!",
agent_id=created_graph.id,
@@ -356,12 +330,6 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)
except Exception as e:
logger.error(
f"[AGENT_CREATE_DEBUG] ERROR - save_failed: {str(e)}, session_id={session_id}"
)
logger.info(
f"[AGENT_CREATE_DEBUG] RETURN - ErrorResponse (save_failed), session_id={session_id}"
)
return ErrorResponse(
message=f"Failed to save the agent: {str(e)}",
error="save_failed",

View File

@@ -43,6 +43,11 @@ async def test_vague_goal_returns_suggested_goal_response(tool, session):
}
with (
patch(
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=[],
),
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
@@ -73,6 +78,11 @@ async def test_unachievable_goal_returns_suggested_goal_response(tool, session):
}
with (
patch(
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=[],
),
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,
@@ -110,6 +120,11 @@ async def test_clarifying_questions_returns_clarification_needed_response(
}
with (
patch(
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
new_callable=AsyncMock,
return_value=[],
),
patch(
"backend.copilot.tools.create_agent.decompose_goal",
new_callable=AsyncMock,

View File

@@ -46,6 +46,10 @@ class CustomizeAgentTool(BaseTool):
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {

View File

@@ -1,170 +0,0 @@
"""E2B sandbox lifecycle for CoPilot: persistent cloud execution.
Each session gets a long-lived E2B cloud sandbox. ``bash_exec`` runs commands
directly on the sandbox via ``sandbox.commands.run()``. SDK file tools
(read_file/write_file/edit_file/glob/grep) route to the sandbox's
``/home/user`` directory via E2B's HTTP-based filesystem API — all tools
share a single coherent filesystem with no local sync required.
Lifecycle
---------
1. **Turn start** connect to the existing sandbox (sandbox_id in Redis) or
create a new one via ``get_or_create_sandbox()``.
2. **Execution** ``bash_exec`` and MCP file tools operate directly on the
sandbox's ``/home/user`` filesystem.
3. **Session expiry** E2B sandbox is killed by its own timeout (session_ttl).
"""
import asyncio
import logging
from e2b import AsyncSandbox
from backend.data.redis_client import get_redis_async
logger = logging.getLogger(__name__)
_SANDBOX_REDIS_PREFIX = "copilot:e2b:sandbox:"
E2B_WORKDIR = "/home/user"
_CREATING = "__creating__"
_CREATION_LOCK_TTL = 60
_MAX_WAIT_ATTEMPTS = 20 # 20 * 0.5s = 10s max wait
async def _try_reconnect(
sandbox_id: str, api_key: str, redis_key: str, timeout: int
) -> "AsyncSandbox | None":
"""Try to reconnect to an existing sandbox. Returns None on failure."""
try:
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
if await sandbox.is_running():
redis = await get_redis_async()
await redis.expire(redis_key, timeout)
return sandbox
except Exception as exc:
logger.warning("[E2B] Reconnect to %.12s failed: %s", sandbox_id, exc)
# Stale — clear Redis so a new sandbox can be created.
redis = await get_redis_async()
await redis.delete(redis_key)
return None
async def get_or_create_sandbox(
session_id: str,
api_key: str,
template: str = "base",
timeout: int = 43200,
) -> AsyncSandbox:
"""Return the existing E2B sandbox for *session_id* or create a new one.
The sandbox_id is persisted in Redis so the same sandbox is reused
across turns. Concurrent calls for the same session are serialised
via a Redis ``SET NX`` creation lock.
"""
redis = await get_redis_async()
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
# 1. Try reconnecting to an existing sandbox.
raw = await redis.get(redis_key)
if raw:
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
if sandbox:
logger.info(
"[E2B] Reconnected to %.12s for session %.12s",
sandbox_id,
session_id,
)
return sandbox
# 2. Claim creation lock. If another request holds it, wait for the result.
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
if not claimed:
for _ in range(_MAX_WAIT_ATTEMPTS):
await asyncio.sleep(0.5)
raw = await redis.get(redis_key)
if not raw:
break # Lock expired — fall through to retry creation
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(sandbox_id, api_key, redis_key, timeout)
if sandbox:
return sandbox
break # Stale sandbox cleared — fall through to create
# Try to claim creation lock again after waiting.
claimed = await redis.set(redis_key, _CREATING, nx=True, ex=_CREATION_LOCK_TTL)
if not claimed:
# Another process may have created a sandbox — try to use it.
raw = await redis.get(redis_key)
if raw:
sandbox_id = raw if isinstance(raw, str) else raw.decode()
if sandbox_id != _CREATING:
sandbox = await _try_reconnect(
sandbox_id, api_key, redis_key, timeout
)
if sandbox:
return sandbox
raise RuntimeError(
f"Could not acquire E2B creation lock for session {session_id[:12]}"
)
# 3. Create a new sandbox.
try:
sandbox = await AsyncSandbox.create(
template=template, api_key=api_key, timeout=timeout
)
except Exception:
await redis.delete(redis_key)
raise
await redis.setex(redis_key, timeout, sandbox.sandbox_id)
logger.info(
"[E2B] Created sandbox %.12s for session %.12s",
sandbox.sandbox_id,
session_id,
)
return sandbox
async def kill_sandbox(session_id: str, api_key: str) -> bool:
"""Kill the E2B sandbox for *session_id* and clean up its Redis entry.
Returns ``True`` if a sandbox was found and killed, ``False`` otherwise.
Safe to call even when no sandbox exists for the session.
"""
redis = await get_redis_async()
redis_key = f"{_SANDBOX_REDIS_PREFIX}{session_id}"
raw = await redis.get(redis_key)
if not raw:
return False
sandbox_id = raw if isinstance(raw, str) else raw.decode()
await redis.delete(redis_key)
if sandbox_id == _CREATING:
return False
try:
async def _connect_and_kill():
sandbox = await AsyncSandbox.connect(sandbox_id, api_key=api_key)
await sandbox.kill()
await asyncio.wait_for(_connect_and_kill(), timeout=10)
logger.info(
"[E2B] Killed sandbox %.12s for session %.12s",
sandbox_id,
session_id,
)
return True
except Exception as exc:
logger.warning(
"[E2B] Failed to kill sandbox %.12s for session %.12s: %s",
sandbox_id,
session_id,
exc,
)
return False

View File

@@ -1,272 +0,0 @@
"""Tests for e2b_sandbox: get_or_create_sandbox, _try_reconnect, kill_sandbox.
Uses mock Redis and mock AsyncSandbox — no external dependencies.
Tests are synchronous (using asyncio.run) to avoid conflicts with the
session-scoped event loop in conftest.py.
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from .e2b_sandbox import (
_CREATING,
_SANDBOX_REDIS_PREFIX,
_try_reconnect,
get_or_create_sandbox,
kill_sandbox,
)
_KEY = f"{_SANDBOX_REDIS_PREFIX}sess-123"
_API_KEY = "test-api-key"
_TIMEOUT = 300
def _mock_sandbox(sandbox_id: str = "sb-abc", running: bool = True) -> MagicMock:
sb = MagicMock()
sb.sandbox_id = sandbox_id
sb.is_running = AsyncMock(return_value=running)
return sb
def _mock_redis(get_val: str | bytes | None = None, set_nx_result: bool = True):
r = AsyncMock()
r.get = AsyncMock(return_value=get_val)
r.set = AsyncMock(return_value=set_nx_result)
r.setex = AsyncMock()
r.delete = AsyncMock()
r.expire = AsyncMock()
return r
def _patch_redis(redis):
return patch(
"backend.copilot.tools.e2b_sandbox.get_redis_async",
new_callable=AsyncMock,
return_value=redis,
)
# ---------------------------------------------------------------------------
# _try_reconnect
# ---------------------------------------------------------------------------
class TestTryReconnect:
def test_reconnect_success(self):
sb = _mock_sandbox()
redis = _mock_redis()
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
assert result is sb
redis.expire.assert_awaited_once_with(_KEY, _TIMEOUT)
redis.delete.assert_not_awaited()
def test_reconnect_not_running_clears_key(self):
sb = _mock_sandbox(running=False)
redis = _mock_redis()
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
assert result is None
redis.delete.assert_awaited_once_with(_KEY)
def test_reconnect_exception_clears_key(self):
redis = _mock_redis()
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
result = asyncio.run(_try_reconnect("sb-abc", _API_KEY, _KEY, _TIMEOUT))
assert result is None
redis.delete.assert_awaited_once_with(_KEY)
# ---------------------------------------------------------------------------
# get_or_create_sandbox
# ---------------------------------------------------------------------------
class TestGetOrCreateSandbox:
def test_reconnect_existing(self):
"""When Redis has a valid sandbox_id, reconnect to it."""
sb = _mock_sandbox()
redis = _mock_redis(get_val="sb-abc")
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
mock_cls.create.assert_not_called()
def test_create_new_when_no_key(self):
"""When Redis is empty, claim lock and create a new sandbox."""
sb = _mock_sandbox("sb-new")
redis = _mock_redis(get_val=None, set_nx_result=True)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.create = AsyncMock(return_value=sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
redis.setex.assert_awaited_once_with(_KEY, _TIMEOUT, "sb-new")
def test_create_failure_clears_lock(self):
"""If sandbox creation fails, the Redis lock is deleted."""
redis = _mock_redis(get_val=None, set_nx_result=True)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.create = AsyncMock(side_effect=RuntimeError("quota"))
with pytest.raises(RuntimeError, match="quota"):
asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
)
redis.delete.assert_awaited_once_with(_KEY)
def test_wait_for_lock_then_reconnect(self):
"""When another process holds the lock, wait and reconnect."""
sb = _mock_sandbox("sb-other")
redis = _mock_redis()
redis.get = AsyncMock(side_effect=[_CREATING, "sb-other"])
redis.set = AsyncMock(return_value=False)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
patch(
"backend.copilot.tools.e2b_sandbox.asyncio.sleep",
new_callable=AsyncMock,
),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
)
assert result is sb
def test_stale_reconnect_clears_and_creates(self):
"""When stored sandbox is stale, clear key and create a new one."""
stale_sb = _mock_sandbox("sb-stale", running=False)
new_sb = _mock_sandbox("sb-fresh")
redis = _mock_redis(get_val="sb-stale", set_nx_result=True)
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=stale_sb)
mock_cls.create = AsyncMock(return_value=new_sb)
result = asyncio.run(
get_or_create_sandbox("sess-123", _API_KEY, timeout=_TIMEOUT)
)
assert result is new_sb
redis.delete.assert_awaited()
# ---------------------------------------------------------------------------
# kill_sandbox
# ---------------------------------------------------------------------------
class TestKillSandbox:
def test_kill_existing_sandbox(self):
"""Kill a running sandbox and clean up Redis."""
sb = _mock_sandbox()
sb.kill = AsyncMock()
redis = _mock_redis(get_val="sb-abc")
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is True
redis.delete.assert_awaited_once_with(_KEY)
sb.kill.assert_awaited_once()
def test_kill_no_sandbox(self):
"""No-op when no sandbox exists in Redis."""
redis = _mock_redis(get_val=None)
with _patch_redis(redis):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is False
redis.delete.assert_not_awaited()
def test_kill_creating_state(self):
"""Clears Redis key but returns False when sandbox is still being created."""
redis = _mock_redis(get_val=_CREATING)
with _patch_redis(redis):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is False
redis.delete.assert_awaited_once_with(_KEY)
def test_kill_connect_failure(self):
"""Returns False and cleans Redis if connect/kill fails."""
redis = _mock_redis(get_val="sb-abc")
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(side_effect=ConnectionError("gone"))
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is False
redis.delete.assert_awaited_once_with(_KEY)
def test_kill_with_bytes_redis_value(self):
"""Redis may return bytes — kill_sandbox should decode correctly."""
sb = _mock_sandbox()
sb.kill = AsyncMock()
redis = _mock_redis(get_val=b"sb-abc")
with (
patch("backend.copilot.tools.e2b_sandbox.AsyncSandbox") as mock_cls,
_patch_redis(redis),
):
mock_cls.connect = AsyncMock(return_value=sb)
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is True
sb.kill.assert_awaited_once()
def test_kill_timeout_returns_false(self):
"""Returns False when E2B API calls exceed the 10s timeout."""
redis = _mock_redis(get_val="sb-abc")
with (
_patch_redis(redis),
patch(
"backend.copilot.tools.e2b_sandbox.asyncio.wait_for",
new_callable=AsyncMock,
side_effect=asyncio.TimeoutError,
),
):
result = asyncio.run(kill_sandbox("sess-123", _API_KEY))
assert result is False
redis.delete.assert_awaited_once_with(_KEY)

View File

@@ -9,6 +9,7 @@ from .agent_generator import (
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
@@ -16,6 +17,7 @@ from .base import BaseTool
from .models import (
AgentPreviewResponse,
AgentSavedResponse,
AsyncProcessingResponse,
ClarificationNeededResponse,
ClarifyingQuestion,
ErrorResponse,
@@ -36,16 +38,17 @@ class EditAgentTool(BaseTool):
def description(self) -> str:
return (
"Edit an existing agent from the user's library using natural language. "
"Generates updates to the agent while preserving unchanged parts. "
"\n\nIMPORTANT: Before calling this tool, if the changes involve adding new "
"functionality, search for relevant existing agents using find_library_agent "
"that could be used as building blocks. Pass their IDs in library_agent_ids."
"Generates updates to the agent while preserving unchanged parts."
)
@property
def requires_auth(self) -> bool:
return True
@property
def is_long_running(self) -> bool:
return True
@property
def parameters(self) -> dict[str, Any]:
return {
@@ -71,15 +74,6 @@ class EditAgentTool(BaseTool):
"Additional context or answers to previous clarifying questions."
),
},
"library_agent_ids": {
"type": "array",
"items": {"type": "string"},
"description": (
"List of library agent IDs to use as building blocks for the changes. "
"If adding new functionality, search for relevant agents using "
"find_library_agent first, then pass their IDs here."
),
},
"save": {
"type": "boolean",
"description": (
@@ -108,10 +102,13 @@ class EditAgentTool(BaseTool):
agent_id = kwargs.get("agent_id", "").strip()
changes = kwargs.get("changes", "").strip()
context = kwargs.get("context", "")
library_agent_ids = kwargs.get("library_agent_ids", [])
save = kwargs.get("save", True)
session_id = session.session_id if session else None
# Extract async processing params (passed by long-running tool handler)
operation_id = kwargs.get("_operation_id")
task_id = kwargs.get("_task_id")
if not agent_id:
return ErrorResponse(
message="Please provide the agent ID to edit.",
@@ -135,25 +132,21 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Fetch library agents by IDs if provided
library_agents = None
if user_id and library_agent_ids:
if user_id:
try:
from .agent_generator import get_library_agents_by_ids
graph_id = current_agent.get("id")
# Filter out the current agent being edited
filtered_ids = [id for id in library_agent_ids if id != graph_id]
library_agents = await get_library_agents_by_ids(
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
agent_ids=filtered_ids,
search_query=changes,
exclude_graph_id=graph_id,
include_marketplace=True,
)
logger.debug(
f"Fetched {len(library_agents)} library agents by ID for sub-agent composition"
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents by IDs: {e}")
logger.warning(f"Failed to fetch library agents: {e}")
update_request = changes
if context:
@@ -164,6 +157,8 @@ class EditAgentTool(BaseTool):
update_request,
current_agent,
library_agents,
operation_id=operation_id,
task_id=task_id,
)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
@@ -183,6 +178,19 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)
# Check if Agent Generator accepted for async processing
if result.get("status") == "accepted":
logger.info(
f"Agent edit delegated to async processing "
f"(operation_id={operation_id}, task_id={task_id})"
)
return AsyncProcessingResponse(
message="Agent edit started. You'll be notified when it's complete.",
operation_id=operation_id,
task_id=task_id,
session_id=session_id,
)
# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")

View File

@@ -1,186 +0,0 @@
"""Shared utilities for execution waiting and status handling."""
import asyncio
import logging
from typing import Any
from backend.data.db_accessors import execution_db
from backend.data.execution import (
AsyncRedisExecutionEventBus,
ExecutionStatus,
GraphExecution,
GraphExecutionEvent,
)
logger = logging.getLogger(__name__)
# Terminal statuses that indicate execution is complete
TERMINAL_STATUSES = frozenset(
{
ExecutionStatus.COMPLETED,
ExecutionStatus.FAILED,
ExecutionStatus.TERMINATED,
}
)
# Statuses where execution is paused but not finished (e.g. human-in-the-loop)
PAUSED_STATUSES = frozenset(
{
ExecutionStatus.REVIEW,
}
)
# Statuses that mean "stop waiting" (terminal or paused)
STOP_WAITING_STATUSES = TERMINAL_STATUSES | PAUSED_STATUSES
_POST_SUBSCRIBE_RECHECK_DELAY = 0.1 # seconds to wait for subscription to establish
async def wait_for_execution(
user_id: str,
graph_id: str,
execution_id: str,
timeout_seconds: int,
) -> GraphExecution | None:
"""
Wait for an execution to reach a terminal or paused status using Redis pubsub.
Handles the race condition between checking status and subscribing by
re-checking the DB after the subscription is established.
Args:
user_id: User ID
graph_id: Graph ID
execution_id: Execution ID to wait for
timeout_seconds: Max seconds to wait
Returns:
The execution with current status, or None if not found
"""
exec_db = execution_db()
# Quick check — maybe it's already done
execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
if not execution:
return None
if execution.status in STOP_WAITING_STATUSES:
logger.debug(
f"Execution {execution_id} already in stop-waiting state: "
f"{execution.status}"
)
return execution
logger.info(
f"Waiting up to {timeout_seconds}s for execution {execution_id} "
f"(current status: {execution.status})"
)
event_bus = AsyncRedisExecutionEventBus()
channel_key = f"{user_id}/{graph_id}/{execution_id}"
# Mutable container so _subscribe_and_wait can surface the task even if
# asyncio.wait_for cancels the coroutine before it returns.
task_holder: list[asyncio.Task] = []
try:
result = await asyncio.wait_for(
_subscribe_and_wait(
event_bus, channel_key, user_id, execution_id, exec_db, task_holder
),
timeout=timeout_seconds,
)
return result
except asyncio.TimeoutError:
logger.info(f"Timeout waiting for execution {execution_id}")
except Exception as e:
logger.error(f"Error waiting for execution: {e}", exc_info=True)
finally:
for task in task_holder:
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
await event_bus.close()
# Return current state on timeout/error
return await exec_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
async def _subscribe_and_wait(
event_bus: AsyncRedisExecutionEventBus,
channel_key: str,
user_id: str,
execution_id: str,
exec_db: Any,
task_holder: list[asyncio.Task],
) -> GraphExecution | None:
"""
Subscribe to execution events and wait for a terminal/paused status.
Appends the consumer task to ``task_holder`` so the caller can clean it up
even if this coroutine is cancelled by ``asyncio.wait_for``.
To avoid the race condition where the execution completes between the
initial DB check and the Redis subscription, we:
1. Start listening (which subscribes internally)
2. Re-check the DB after subscription is active
3. If still running, wait for pubsub events
"""
listen_iter = event_bus.listen_events(channel_key).__aiter__()
done = asyncio.Event()
result_execution: GraphExecution | None = None
async def _consume() -> None:
nonlocal result_execution
try:
async for event in listen_iter:
if isinstance(event, GraphExecutionEvent):
logger.debug(f"Received execution update: {event.status}")
if event.status in STOP_WAITING_STATUSES:
result_execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
done.set()
return
except Exception as e:
logger.error(f"Error in execution consumer: {e}", exc_info=True)
done.set()
consume_task = asyncio.create_task(_consume())
task_holder.append(consume_task)
# Give the subscription a moment to establish, then re-check DB
await asyncio.sleep(_POST_SUBSCRIBE_RECHECK_DELAY)
execution = await exec_db.get_graph_execution(
user_id=user_id,
execution_id=execution_id,
include_node_executions=False,
)
if execution and execution.status in STOP_WAITING_STATUSES:
return execution
# Wait for the pubsub consumer to find a terminal event
await done.wait()
return result_execution
def get_execution_outputs(execution: GraphExecution | None) -> dict[str, Any] | None:
"""Extract outputs from an execution, or return None."""
if execution is None:
return None
return execution.outputs

View File

@@ -32,7 +32,6 @@ COPILOT_EXCLUDED_BLOCK_TYPES = {
BlockType.NOTE, # Visual annotation only - no runtime behavior
BlockType.HUMAN_IN_THE_LOOP, # Pauses for human approval - CoPilot IS human-in-the-loop
BlockType.AGENT, # AgentExecutorBlock requires execution_context - use run_agent tool
BlockType.MCP_TOOL, # Has dedicated run_mcp_tool tool with proper discovery + auth flow
}
# Specific block IDs excluded from CoPilot (STANDARD type but still require graph context)

Some files were not shown because too many files have changed in this diff Show More