Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-24 03:00:28 -05:00)

Compare commits: `pr-12177...add-llm-ma` (84 commits)
| SHA1 |
|---|
| b54022bded |
| 987712dac1 |
| e01526cf52 |
| 1704812f50 |
| 29f95e5b61 |
| 266526f08c |
| 26490e32d8 |
| d6bf54281b |
| a7835056c9 |
| cf3390d192 |
| d8007f74e9 |
| 4d341c55c5 |
| 01ef7e1925 |
| 5baf1a0f60 |
| 9fc5d465da |
| c797f4e1f2 |
| 05033610bb |
| 76f3a89be8 |
| df7bb57c83 |
| b11d46d246 |
| 8e6bc5eb48 |
| 8b2b0c853a |
| ffb86cced4 |
| fea46a6d28 |
| f2f779e54f |
| dda9a9b010 |
| c1d3604682 |
| dfbfbdf696 |
| 994ebc2cf8 |
| 2245d115d3 |
| 5238b1b71c |
| 4fb86b2738 |
| e10128e9f0 |
| b205d5863e |
| 6da2dee62f |
| 324ebc1e06 |
| ce2ebee838 |
| 0597573b6c |
| 9496b33a1c |
| 8e3aabd558 |
| fbef81c0c9 |
| 226d2ef4a0 |
| 42f8a26ee1 |
| 8d021fe76c |
| cb10907bf6 |
| 54084fe597 |
| 8f5d851908 |
| 358a21c6fc |
| 336fc43b24 |
| cfb1613877 |
| 386eea741c |
| e5c6809d9c |
| 963b8090cc |
| eab93aba2b |
| 47a70cdbd0 |
| 69c9136060 |
| 6ed8bb4f14 |
| 6cf28e58d3 |
| 632ef24408 |
| 6dc767aafa |
| 23e37fd163 |
| 63869fe710 |
| 90ae75d475 |
| 9b6dc3be12 |
| 9b8b6252c5 |
| 0d321323f5 |
| 3ee3ea8f02 |
| 7a842d35ae |
| 07e8568f57 |
| 13a0caa5d8 |
| 664523a721 |
| 33b103d09b |
| 2e3fc99caa |
| 52c7b223df |
| 24d86fde30 |
| df7be39724 |
| 8c7b1af409 |
| b6e2f05b63 |
| 7435739053 |
| a97fdba554 |
| ec705bbbcf |
| 7fe6b576ae |
| dfc42003a1 |
| 6bbeb22943 |
`.github/workflows/platform-backend-ci.yml` (vendored, 9 changed lines)
```diff
@@ -41,18 +41,13 @@ jobs:
         ports:
           - 6379:6379
       rabbitmq:
-        image: rabbitmq:4.1.4
+        image: rabbitmq:3.12-management
         ports:
           - 5672:5672
+          - 15672:15672
         env:
           RABBITMQ_DEFAULT_USER: ${{ env.RABBITMQ_DEFAULT_USER }}
           RABBITMQ_DEFAULT_PASS: ${{ env.RABBITMQ_DEFAULT_PASS }}
-        options: >-
-          --health-cmd "rabbitmq-diagnostics -q ping"
-          --health-interval 30s
-          --health-timeout 10s
-          --health-retries 5
-          --health-start-period 10s
       clamav:
         image: clamav/clamav-debian:latest
         ports:
```
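With the container-level health check options gone, anything that needs the broker ready must probe it itself. A hedged sketch using `pika` (an assumed AMQP client, not part of this diff) and the same `RABBITMQ_DEFAULT_USER`/`RABBITMQ_DEFAULT_PASS` environment variables:

```python
import os
import time

import pika  # assumed dependency; any AMQP client works the same way
from pika.exceptions import AMQPConnectionError


def wait_for_rabbitmq(host: str = "localhost", port: int = 5672, timeout: float = 60.0) -> None:
    """Poll the broker until a connection succeeds or the deadline passes."""
    creds = pika.PlainCredentials(
        os.environ.get("RABBITMQ_DEFAULT_USER", "guest"),
        os.environ.get("RABBITMQ_DEFAULT_PASS", "guest"),
    )
    deadline = time.monotonic() + timeout
    while True:
        try:
            pika.BlockingConnection(
                pika.ConnectionParameters(host=host, port=port, credentials=creds)
            ).close()
            return  # broker accepted a connection; it is ready
        except AMQPConnectionError:
            if time.monotonic() > deadline:
                raise
            time.sleep(2)
```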
`.github/workflows/platform-frontend-ci.yml` (vendored, 6 changed lines)
```diff
@@ -6,16 +6,10 @@ on:
     paths:
       - ".github/workflows/platform-frontend-ci.yml"
       - "autogpt_platform/frontend/**"
-      - "autogpt_platform/backend/Dockerfile"
-      - "autogpt_platform/docker-compose.yml"
-      - "autogpt_platform/docker-compose.platform.yml"
   pull_request:
     paths:
       - ".github/workflows/platform-frontend-ci.yml"
       - "autogpt_platform/frontend/**"
-      - "autogpt_platform/backend/Dockerfile"
-      - "autogpt_platform/docker-compose.yml"
-      - "autogpt_platform/docker-compose.platform.yml"
   merge_group:
   workflow_dispatch:
```
`autogpt_platform/backend/Dockerfile` (filename inferred from the `COPY autogpt_platform/backend/...` paths)

```diff
@@ -53,6 +53,63 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
 COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
 RUN poetry run prisma generate && poetry run gen-prisma-stub
 
+# ============================== BACKEND SERVER ============================== #
+
+FROM debian:13-slim AS server
+
+WORKDIR /app
+
+ENV POETRY_HOME=/opt/poetry \
+    POETRY_NO_INTERACTION=1 \
+    POETRY_VIRTUALENVS_CREATE=true \
+    POETRY_VIRTUALENVS_IN_PROJECT=true \
+    DEBIAN_FRONTEND=noninteractive
+ENV PATH=/opt/poetry/bin:$PATH
+
+# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
+# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
+# for the bash_exec MCP tool.
+# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3.13 \
+    python3-pip \
+    ffmpeg \
+    imagemagick \
+    jq \
+    ripgrep \
+    tree \
+    bubblewrap \
+    && rm -rf /var/lib/apt/lists/*
+
+COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
+COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
+# Copy Node.js installation for Prisma
+COPY --from=builder /usr/bin/node /usr/bin/node
+COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
+COPY --from=builder /usr/bin/npm /usr/bin/npm
+COPY --from=builder /usr/bin/npx /usr/bin/npx
+COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
+
+WORKDIR /app/autogpt_platform/backend
+
+# Copy only the .venv from builder (not the entire /app directory)
+# The .venv includes the generated Prisma client
+COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
+ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
+
+# Copy dependency files + autogpt_libs (path dependency)
+COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
+COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
+
+# Copy backend code + docs (for Copilot docs search)
+COPY autogpt_platform/backend ./
+COPY docs /app/docs
+RUN poetry install --no-ansi --only-root
+
+ENV PORT=8000
+
+CMD ["poetry", "run", "rest"]
+
 # =============================== DB MIGRATOR =============================== #
 
 # Lightweight migrate stage - only needs Prisma CLI, not full Python environment
@@ -84,59 +141,3 @@ COPY autogpt_platform/backend/schema.prisma ./
 COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/partial_types.py
 COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
+COPY autogpt_platform/backend/migrations ./migrations
-
-# ============================== BACKEND SERVER ============================== #
-
-FROM debian:13-slim AS server
-
-WORKDIR /app
-
-ENV DEBIAN_FRONTEND=noninteractive
-
-# Install Python, FFmpeg, ImageMagick, and CLI tools for agent use.
-# bubblewrap provides OS-level sandbox (whitelist-only FS + no network)
-# for the bash_exec MCP tool.
-# Using --no-install-recommends saves ~650MB by skipping unnecessary deps like llvm, mesa, etc.
-RUN apt-get update && apt-get install -y --no-install-recommends \
-    python3.13 \
-    python3-pip \
-    ffmpeg \
-    imagemagick \
-    jq \
-    ripgrep \
-    tree \
-    bubblewrap \
-    && rm -rf /var/lib/apt/lists/*
-
-# Copy poetry (build-time only, for `poetry install --only-root` to create entry points)
-COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
-COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
-# Copy Node.js installation for Prisma
-COPY --from=builder /usr/bin/node /usr/bin/node
-COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
-COPY --from=builder /usr/bin/npm /usr/bin/npm
-COPY --from=builder /usr/bin/npx /usr/bin/npx
-COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-python/binaries
-
-WORKDIR /app/autogpt_platform/backend
-
-# Copy only the .venv from builder (not the entire /app directory)
-# The .venv includes the generated Prisma client
-COPY --from=builder /app/autogpt_platform/backend/.venv ./.venv
-ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
-
-# Copy dependency files + autogpt_libs (path dependency)
-COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
-COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml ./
-
-# Copy backend code + docs (for Copilot docs search)
-COPY autogpt_platform/backend ./
-COPY docs /app/docs
-# Install the project package to create entry point scripts in .venv/bin/
-# (e.g., rest, executor, ws, db, scheduler, notification - see [tool.poetry.scripts])
-RUN POETRY_VIRTUALENVS_CREATE=true POETRY_VIRTUALENVS_IN_PROJECT=true \
-    poetry install --no-ansi --only-root
-
-ENV PORT=8000
-
-CMD ["rest"]
```
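The removed stage's comment explains why a bare `CMD ["rest"]` works there: `poetry install --only-root` materializes each `[tool.poetry.scripts]` entry as a launcher script in `.venv/bin`, which is on `PATH`. A self-contained stand-in for what such a launcher amounts to (the real target function is configured in `pyproject.toml`, not shown in this diff):

```python
import sys


def main() -> int:  # stand-in for the configured entry-point function
    """A generated launcher imports the mapped function and calls it."""
    print("rest service would start here")
    return 0


if __name__ == "__main__":
    sys.exit(main())
```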
```diff
@@ -1,9 +1,4 @@
-"""Common test fixtures for server tests.
-
-Note: Common fixtures like test_user_id, admin_user_id, target_user_id,
-setup_test_user, and setup_admin_user are defined in the parent conftest.py
-(backend/conftest.py) and are available here automatically.
-"""
+"""Common test fixtures for server tests."""
 
 import pytest
 from pytest_snapshot.plugin import Snapshot
@@ -16,6 +11,54 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot:
     return snapshot
 
 
+@pytest.fixture
+def test_user_id() -> str:
+    """Test user ID fixture."""
+    return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
+
+
+@pytest.fixture
+def admin_user_id() -> str:
+    """Admin user ID fixture."""
+    return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
+
+
+@pytest.fixture
+def target_user_id() -> str:
+    """Target user ID fixture."""
+    return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
+
+
+@pytest.fixture
+async def setup_test_user(test_user_id):
+    """Create test user in database before tests."""
+    from backend.data.user import get_or_create_user
+
+    # Create the test user in the database using JWT token format
+    user_data = {
+        "sub": test_user_id,
+        "email": "test@example.com",
+        "user_metadata": {"name": "Test User"},
+    }
+    await get_or_create_user(user_data)
+    return test_user_id
+
+
+@pytest.fixture
+async def setup_admin_user(admin_user_id):
+    """Create admin user in database before tests."""
+    from backend.data.user import get_or_create_user
+
+    # Create the admin user in the database using JWT token format
+    user_data = {
+        "sub": admin_user_id,
+        "email": "test-admin@example.com",
+        "user_metadata": {"name": "Test Admin"},
+    }
+    await get_or_create_user(user_data)
+    return admin_user_id
+
+
 @pytest.fixture
 def mock_jwt_user(test_user_id):
     """Provide mock JWT payload for regular user testing."""
```
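With the fixtures defined locally again, a test depends on them by name. A minimal sketch, assuming the suite's existing pytest-asyncio configuration:

```python
import pytest


@pytest.mark.asyncio
async def test_user_is_provisioned(setup_test_user):
    # setup_test_user creates the user via get_or_create_user and
    # returns the same ID as the test_user_id fixture
    assert setup_test_user == "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
```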
```diff
@@ -122,6 +122,24 @@ class ConnectionManager:
 
         return len(connections)
 
+    async def broadcast_to_all(self, *, method: WSMethod, data: dict) -> int:
+        """Broadcast a message to all active websocket connections."""
+        message = WSMessage(
+            method=method,
+            data=data,
+        ).model_dump_json()
+
+        connections = tuple(self.active_connections)
+        if not connections:
+            return 0
+
+        await asyncio.gather(
+            *(connection.send_text(message) for connection in connections),
+            return_exceptions=True,
+        )
+
+        return len(connections)
+
     async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
         if channel_key not in self.subscriptions:
             self.subscriptions[channel_key] = set()
```
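The `return_exceptions=True` flag in `broadcast_to_all` is what keeps one dead socket from aborting the whole broadcast: failed sends come back as exception objects instead of propagating. A self-contained illustration:

```python
import asyncio


async def send_ok() -> str:
    return "sent"


async def send_broken() -> str:
    raise ConnectionError("client went away")


async def main() -> None:
    results = await asyncio.gather(send_ok(), send_broken(), return_exceptions=True)
    # The failure is returned, not raised, so the healthy send still completes.
    assert results[0] == "sent"
    assert isinstance(results[1], ConnectionError)


asyncio.run(main())
```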
```diff
@@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission
 from pydantic import BaseModel, Field
 
 from backend.api.external.middleware import require_permission
-from backend.copilot.model import ChatSession
-from backend.copilot.tools import find_agent_tool, run_agent_tool
-from backend.copilot.tools.models import ToolResponseBase
+from backend.api.features.chat.model import ChatSession
+from backend.api.features.chat.tools import find_agent_tool, run_agent_tool
+from backend.api.features.chat.tools.models import ToolResponseBase
 from backend.data.auth.base import APIAuthorizationInfo
 
 logger = logging.getLogger(__name__)
```
```diff
@@ -176,30 +176,64 @@ async def get_execution_analytics_config(
         # Return with provider prefix for clarity
         return f"{provider_name}: {model_name}"
 
-    # Include all LlmModel values (no more filtering by hardcoded list)
-    recommended_model = LlmModel.GPT4O_MINI.value
-    for model in LlmModel:
+    # Get all models from the registry (dynamic, not hardcoded enum)
+    from backend.data import llm_registry
+    from backend.server.v2.llm import db as llm_db
+
+    # Get the recommended model from the database (configurable via admin UI)
+    recommended_model_slug = await llm_db.get_recommended_model_slug()
+
+    # Build the available models list
+    first_enabled_slug = None
+    for registry_model in llm_registry.iter_dynamic_models():
+        # Only include enabled models in the list
+        if not registry_model.is_enabled:
+            continue
+
+        # Track first enabled model as fallback
+        if first_enabled_slug is None:
+            first_enabled_slug = registry_model.slug
+
+        model = LlmModel(registry_model.slug)
         label = generate_model_label(model)
         # Add "(Recommended)" suffix to the recommended model
-        if model.value == recommended_model:
+        if registry_model.slug == recommended_model_slug:
             label += " (Recommended)"
 
         available_models.append(
             ModelInfo(
-                value=model.value,
+                value=registry_model.slug,
                 label=label,
-                provider=model.provider,
+                provider=registry_model.metadata.provider,
             )
         )
 
     # Sort models by provider and name for better UX
     available_models.sort(key=lambda x: (x.provider, x.label))
 
+    # Handle case where no models are available
+    if not available_models:
+        logger.warning(
+            "No enabled LLM models found in registry. "
+            "Ensure models are configured and enabled in the LLM Registry."
+        )
+        # Provide a placeholder entry so admins see meaningful feedback
+        available_models.append(
+            ModelInfo(
+                value="",
+                label="No models available - configure in LLM Registry",
+                provider="none",
+            )
+        )
+
+    # Use the DB recommended model, or fallback to first enabled model
+    final_recommended = recommended_model_slug or first_enabled_slug or ""
+
     return ExecutionAnalyticsConfig(
         available_models=available_models,
         default_system_prompt=DEFAULT_SYSTEM_PROMPT,
         default_user_prompt=DEFAULT_USER_PROMPT,
-        recommended_model=recommended_model,
+        recommended_model=final_recommended,
     )
```
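The fallback chain at the end of that hunk is worth spelling out: the database value wins, then the first enabled model, then an empty string. The same logic in isolation:

```python
def pick_recommended(db_slug: str | None, first_enabled: str | None) -> str:
    # DB-configured recommendation, else first enabled model, else empty.
    return db_slug or first_enabled or ""


assert pick_recommended("gpt-4o", "claude-3-5-sonnet") == "gpt-4o"
assert pick_recommended(None, "claude-3-5-sonnet") == "claude-3-5-sonnet"
assert pick_recommended(None, None) == ""
```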
New file, `@@ -0,0 +1,593 @@` (imported by the test module below as `backend.api.features.admin.llm_routes`):

```python
import logging

import autogpt_libs.auth
import fastapi

from backend.data import llm_registry
from backend.data.block_cost_config import refresh_llm_costs
from backend.server.v2.llm import db as llm_db
from backend.server.v2.llm import model as llm_model

logger = logging.getLogger(__name__)

router = fastapi.APIRouter(
    tags=["llm", "admin"],
    dependencies=[fastapi.Security(autogpt_libs.auth.requires_admin_user)],
)


async def _refresh_runtime_state() -> None:
    """Refresh the LLM registry and clear all related caches to ensure real-time updates."""
    logger.info("Refreshing LLM registry runtime state...")
    try:
        # Refresh registry from database
        await llm_registry.refresh_llm_registry()
        await refresh_llm_costs()

        # Clear block schema caches so they're regenerated with updated model options
        from backend.blocks._base import BlockSchema

        BlockSchema.clear_all_schema_caches()
        logger.info("Cleared all block schema caches")

        # Clear the /blocks endpoint cache so frontend gets updated schemas
        try:
            from backend.api.features.v1 import _get_cached_blocks

            _get_cached_blocks.cache_clear()
            logger.info("Cleared /blocks endpoint cache")
        except Exception as e:
            logger.warning("Failed to clear /blocks cache: %s", e)

        # Clear the v2 builder caches
        try:
            from backend.api.features.builder import db as builder_db

            builder_db._get_all_providers.cache_clear()
            logger.info("Cleared v2 builder providers cache")
            builder_db._build_cached_search_results.cache_clear()
            logger.info("Cleared v2 builder search results cache")
        except Exception as e:
            logger.debug("Could not clear v2 builder cache: %s", e)

        # Notify all executor services to refresh their registry cache
        from backend.data.llm_registry import publish_registry_refresh_notification

        await publish_registry_refresh_notification()
        logger.info("Published registry refresh notification")
    except Exception as exc:
        logger.exception(
            "LLM runtime state refresh failed; caches may be stale: %s", exc
        )


@router.get(
    "/providers",
    summary="List LLM providers",
    response_model=llm_model.LlmProvidersResponse,
)
async def list_llm_providers(include_models: bool = True):
    providers = await llm_db.list_providers(include_models=include_models)
    return llm_model.LlmProvidersResponse(providers=providers)


@router.post(
    "/providers",
    summary="Create LLM provider",
    response_model=llm_model.LlmProvider,
)
async def create_llm_provider(request: llm_model.UpsertLlmProviderRequest):
    provider = await llm_db.upsert_provider(request=request)
    await _refresh_runtime_state()
    return provider


@router.patch(
    "/providers/{provider_id}",
    summary="Update LLM provider",
    response_model=llm_model.LlmProvider,
)
async def update_llm_provider(
    provider_id: str,
    request: llm_model.UpsertLlmProviderRequest,
):
    provider = await llm_db.upsert_provider(request=request, provider_id=provider_id)
    await _refresh_runtime_state()
    return provider


@router.delete(
    "/providers/{provider_id}",
    summary="Delete LLM provider",
    response_model=dict,
)
async def delete_llm_provider(provider_id: str):
    """
    Delete an LLM provider.

    A provider can only be deleted if it has no associated models.
    Delete all models from the provider first before deleting the provider.
    """
    try:
        await llm_db.delete_provider(provider_id)
        await _refresh_runtime_state()
        logger.info("Deleted LLM provider '%s'", provider_id)
        return {"success": True, "message": "Provider deleted successfully"}
    except ValueError as e:
        logger.warning("Failed to delete provider '%s': %s", provider_id, e)
        raise fastapi.HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception("Failed to delete provider '%s': %s", provider_id, e)
        raise fastapi.HTTPException(status_code=500, detail=str(e))


@router.get(
    "/models",
    summary="List LLM models",
    response_model=llm_model.LlmModelsResponse,
)
async def list_llm_models(
    provider_id: str | None = fastapi.Query(default=None),
    page: int = fastapi.Query(default=1, ge=1, description="Page number (1-indexed)"),
    page_size: int = fastapi.Query(
        default=50, ge=1, le=100, description="Number of models per page"
    ),
):
    return await llm_db.list_models(
        provider_id=provider_id, page=page, page_size=page_size
    )


@router.post(
    "/models",
    summary="Create LLM model",
    response_model=llm_model.LlmModel,
)
async def create_llm_model(request: llm_model.CreateLlmModelRequest):
    model = await llm_db.create_model(request=request)
    await _refresh_runtime_state()
    return model


@router.patch(
    "/models/{model_id}",
    summary="Update LLM model",
    response_model=llm_model.LlmModel,
)
async def update_llm_model(
    model_id: str,
    request: llm_model.UpdateLlmModelRequest,
):
    model = await llm_db.update_model(model_id=model_id, request=request)
    await _refresh_runtime_state()
    return model


@router.patch(
    "/models/{model_id}/toggle",
    summary="Toggle LLM model availability",
    response_model=llm_model.ToggleLlmModelResponse,
)
async def toggle_llm_model(
    model_id: str,
    request: llm_model.ToggleLlmModelRequest,
):
    """
    Toggle a model's enabled status, optionally migrating workflows when disabling.

    If disabling a model and `migrate_to_slug` is provided, all workflows using
    this model will be migrated to the specified replacement model before disabling.
    A migration record is created which can be reverted later using the revert endpoint.

    Optional fields:
    - `migration_reason`: Reason for the migration (e.g., "Provider outage")
    - `custom_credit_cost`: Custom pricing override for billing during migration
    """
    try:
        result = await llm_db.toggle_model(
            model_id=model_id,
            is_enabled=request.is_enabled,
            migrate_to_slug=request.migrate_to_slug,
            migration_reason=request.migration_reason,
            custom_credit_cost=request.custom_credit_cost,
        )
        await _refresh_runtime_state()
        if result.nodes_migrated > 0:
            logger.info(
                "Toggled model '%s' to %s and migrated %d nodes to '%s' (migration_id=%s)",
                result.model.slug,
                "enabled" if request.is_enabled else "disabled",
                result.nodes_migrated,
                result.migrated_to_slug,
                result.migration_id,
            )
        return result
    except ValueError as exc:
        logger.warning("Model toggle validation failed: %s", exc)
        raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Failed to toggle LLM model %s: %s", model_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to toggle model availability",
        ) from exc


@router.get(
    "/models/{model_id}/usage",
    summary="Get model usage count",
    response_model=llm_model.LlmModelUsageResponse,
)
async def get_llm_model_usage(model_id: str):
    """Get the number of workflow nodes using this model."""
    try:
        return await llm_db.get_model_usage(model_id=model_id)
    except ValueError as exc:
        raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Failed to get model usage %s: %s", model_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to get model usage",
        ) from exc


@router.delete(
    "/models/{model_id}",
    summary="Delete LLM model and migrate workflows",
    response_model=llm_model.DeleteLlmModelResponse,
)
async def delete_llm_model(
    model_id: str,
    replacement_model_slug: str | None = fastapi.Query(
        default=None,
        description="Slug of the model to migrate existing workflows to (required only if workflows use this model)",
    ),
):
    """
    Delete a model and optionally migrate workflows using it to a replacement model.

    If no workflows are using this model, it can be deleted without providing a
    replacement. If workflows exist, replacement_model_slug is required.

    This endpoint:
    1. Counts how many workflow nodes use the model being deleted
    2. If nodes exist, validates the replacement model and migrates them
    3. Deletes the model record
    4. Refreshes all caches and notifies executors

    Example: DELETE /api/llm/admin/models/{id}?replacement_model_slug=gpt-4o
    Example (no usage): DELETE /api/llm/admin/models/{id}
    """
    try:
        result = await llm_db.delete_model(
            model_id=model_id, replacement_model_slug=replacement_model_slug
        )
        await _refresh_runtime_state()
        logger.info(
            "Deleted model '%s' and migrated %d nodes to '%s'",
            result.deleted_model_slug,
            result.nodes_migrated,
            result.replacement_model_slug,
        )
        return result
    except ValueError as exc:
        # Validation errors (model not found, replacement invalid, etc.)
        logger.warning("Model deletion validation failed: %s", exc)
        raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Failed to delete LLM model %s: %s", model_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to delete model and migrate workflows",
        ) from exc


# ============================================================================
# Migration Management Endpoints
# ============================================================================


@router.get(
    "/migrations",
    summary="List model migrations",
    response_model=llm_model.LlmMigrationsResponse,
)
async def list_llm_migrations(
    include_reverted: bool = fastapi.Query(
        default=False, description="Include reverted migrations in the list"
    ),
):
    """
    List all model migrations.

    Migrations are created when disabling a model with the migrate_to_slug option.
    They can be reverted to restore the original model configuration.
    """
    try:
        migrations = await llm_db.list_migrations(include_reverted=include_reverted)
        return llm_model.LlmMigrationsResponse(migrations=migrations)
    except Exception as exc:
        logger.exception("Failed to list migrations: %s", exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to list migrations",
        ) from exc


@router.get(
    "/migrations/{migration_id}",
    summary="Get migration details",
    response_model=llm_model.LlmModelMigration,
)
async def get_llm_migration(migration_id: str):
    """Get details of a specific migration."""
    try:
        migration = await llm_db.get_migration(migration_id)
        if not migration:
            raise fastapi.HTTPException(
                status_code=404, detail=f"Migration '{migration_id}' not found"
            )
        return migration
    except fastapi.HTTPException:
        raise
    except Exception as exc:
        logger.exception("Failed to get migration %s: %s", migration_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to get migration",
        ) from exc


@router.post(
    "/migrations/{migration_id}/revert",
    summary="Revert a model migration",
    response_model=llm_model.RevertMigrationResponse,
)
async def revert_llm_migration(
    migration_id: str,
    request: llm_model.RevertMigrationRequest | None = None,
):
    """
    Revert a model migration, restoring affected workflows to their original model.

    This only reverts the specific nodes that were part of the migration.
    The source model must exist for the revert to succeed.

    Options:
    - `re_enable_source_model`: Whether to re-enable the source model if disabled (default: True)

    Response includes:
    - `nodes_reverted`: Number of nodes successfully reverted
    - `nodes_already_changed`: Number of nodes that were modified since migration (not reverted)
    - `source_model_re_enabled`: Whether the source model was re-enabled

    Requirements:
    - Migration must not already be reverted
    - Source model must exist
    """
    try:
        re_enable = request.re_enable_source_model if request else True
        result = await llm_db.revert_migration(
            migration_id,
            re_enable_source_model=re_enable,
        )
        await _refresh_runtime_state()
        logger.info(
            "Reverted migration '%s': %d nodes restored from '%s' to '%s' "
            "(%d already changed, source re-enabled=%s)",
            migration_id,
            result.nodes_reverted,
            result.target_model_slug,
            result.source_model_slug,
            result.nodes_already_changed,
            result.source_model_re_enabled,
        )
        return result
    except ValueError as exc:
        logger.warning("Migration revert validation failed: %s", exc)
        raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Failed to revert migration %s: %s", migration_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to revert migration",
        ) from exc


# ============================================================================
# Creator Management Endpoints
# ============================================================================


@router.get(
    "/creators",
    summary="List model creators",
    response_model=llm_model.LlmCreatorsResponse,
)
async def list_llm_creators():
    """
    List all model creators.

    Creators are organizations that create/train models (e.g., OpenAI, Meta, Anthropic).
    This is distinct from providers who host/serve the models (e.g., OpenRouter).
    """
    try:
        creators = await llm_db.list_creators()
        return llm_model.LlmCreatorsResponse(creators=creators)
    except Exception as exc:
        logger.exception("Failed to list creators: %s", exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to list creators",
        ) from exc


@router.get(
    "/creators/{creator_id}",
    summary="Get creator details",
    response_model=llm_model.LlmModelCreator,
)
async def get_llm_creator(creator_id: str):
    """Get details of a specific model creator."""
    try:
        creator = await llm_db.get_creator(creator_id)
        if not creator:
            raise fastapi.HTTPException(
                status_code=404, detail=f"Creator '{creator_id}' not found"
            )
        return creator
    except fastapi.HTTPException:
        raise
    except Exception as exc:
        logger.exception("Failed to get creator %s: %s", creator_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to get creator",
        ) from exc


@router.post(
    "/creators",
    summary="Create model creator",
    response_model=llm_model.LlmModelCreator,
)
async def create_llm_creator(request: llm_model.UpsertLlmCreatorRequest):
    """
    Create a new model creator.

    A creator represents an organization that creates/trains AI models,
    such as OpenAI, Anthropic, Meta, or Google.
    """
    try:
        creator = await llm_db.upsert_creator(request=request)
        await _refresh_runtime_state()
        logger.info("Created model creator '%s' (%s)", creator.display_name, creator.id)
        return creator
    except Exception as exc:
        logger.exception("Failed to create creator: %s", exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to create creator",
        ) from exc


@router.patch(
    "/creators/{creator_id}",
    summary="Update model creator",
    response_model=llm_model.LlmModelCreator,
)
async def update_llm_creator(
    creator_id: str,
    request: llm_model.UpsertLlmCreatorRequest,
):
    """Update an existing model creator."""
    try:
        creator = await llm_db.upsert_creator(request=request, creator_id=creator_id)
        await _refresh_runtime_state()
        logger.info("Updated model creator '%s' (%s)", creator.display_name, creator_id)
        return creator
    except Exception as exc:
        logger.exception("Failed to update creator %s: %s", creator_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to update creator",
        ) from exc


@router.delete(
    "/creators/{creator_id}",
    summary="Delete model creator",
    response_model=dict,
)
async def delete_llm_creator(creator_id: str):
    """
    Delete a model creator.

    This will remove the creator association from all models that reference it
    (sets creatorId to NULL), but will not delete the models themselves.
    """
    try:
        await llm_db.delete_creator(creator_id)
        await _refresh_runtime_state()
        logger.info("Deleted model creator '%s'", creator_id)
        return {"success": True, "message": f"Creator '{creator_id}' deleted"}
    except ValueError as exc:
        logger.warning("Creator deletion validation failed: %s", exc)
        raise fastapi.HTTPException(status_code=404, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Failed to delete creator %s: %s", creator_id, exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to delete creator",
        ) from exc


# ============================================================================
# Recommended Model Endpoints
# ============================================================================


@router.get(
    "/recommended-model",
    summary="Get recommended model",
    response_model=llm_model.RecommendedModelResponse,
)
async def get_recommended_model():
    """
    Get the currently recommended LLM model.

    The recommended model is shown to users as the default/suggested option
    in model selection dropdowns.
    """
    try:
        model = await llm_db.get_recommended_model()
        return llm_model.RecommendedModelResponse(
            model=model,
            slug=model.slug if model else None,
        )
    except Exception as exc:
        logger.exception("Failed to get recommended model: %s", exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to get recommended model",
        ) from exc


@router.post(
    "/recommended-model",
    summary="Set recommended model",
    response_model=llm_model.SetRecommendedModelResponse,
)
async def set_recommended_model(request: llm_model.SetRecommendedModelRequest):
    """
    Set a model as the recommended model.

    This clears the recommended flag from any other model and sets it on
    the specified model. The model must be enabled to be set as recommended.

    The recommended model is displayed to users as the default/suggested
    option in model selection dropdowns throughout the platform.
    """
    try:
        model, previous_slug = await llm_db.set_recommended_model(request.model_id)
        await _refresh_runtime_state()
        logger.info(
            "Set recommended model to '%s' (previous: %s)",
            model.slug,
            previous_slug or "none",
        )
        return llm_model.SetRecommendedModelResponse(
            model=model,
            previous_recommended_slug=previous_slug,
            message=f"Model '{model.display_name}' is now the recommended model",
        )
    except ValueError as exc:
        logger.warning("Set recommended model validation failed: %s", exc)
        raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        logger.exception("Failed to set recommended model: %s", exc)
        raise fastapi.HTTPException(
            status_code=500,
            detail="Failed to set recommended model",
        ) from exc
```
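A sketch of exercising the toggle endpoint from outside, assuming the `/admin/llm` mount prefix used by the test module below and a valid admin JWT (placeholder here):

```python
import httpx

# Disable a model and migrate its workflow nodes in one call.
response = httpx.patch(
    "http://localhost:8000/admin/llm/models/model-1/toggle",
    json={"is_enabled": False, "migrate_to_slug": "gpt-4o-mini"},
    headers={"Authorization": "Bearer ADMIN_JWT_PLACEHOLDER"},
)
response.raise_for_status()
# The response carries the migration record ID, which the revert
# endpoint accepts later.
print(response.json()["nodes_migrated"], response.json()["migration_id"])
```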
New test file, `@@ -0,0 +1,491 @@`:

```python
import json
from unittest.mock import AsyncMock

import fastapi
import fastapi.testclient
import pytest
import pytest_mock
from autogpt_libs.auth.jwt_utils import get_jwt_payload
from pytest_snapshot.plugin import Snapshot

import backend.api.features.admin.llm_routes as llm_routes
from backend.server.v2.llm import model as llm_model
from backend.util.models import Pagination

app = fastapi.FastAPI()
app.include_router(llm_routes.router, prefix="/admin/llm")

client = fastapi.testclient.TestClient(app)


@pytest.fixture(autouse=True)
def setup_app_admin_auth(mock_jwt_admin):
    """Setup admin auth overrides for all tests in this module"""
    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
    yield
    app.dependency_overrides.clear()


def test_list_llm_providers_success(
    mocker: pytest_mock.MockFixture,
    configured_snapshot: Snapshot,
) -> None:
    """Test successful listing of LLM providers"""
    # Mock the database function
    mock_providers = [
        {
            "id": "provider-1",
            "name": "openai",
            "display_name": "OpenAI",
            "description": "OpenAI LLM provider",
            "supports_tools": True,
            "supports_json_output": True,
            "supports_reasoning": False,
            "supports_parallel_tool": True,
            "metadata": {},
            "models": [],
        },
        {
            "id": "provider-2",
            "name": "anthropic",
            "display_name": "Anthropic",
            "description": "Anthropic LLM provider",
            "supports_tools": True,
            "supports_json_output": True,
            "supports_reasoning": False,
            "supports_parallel_tool": True,
            "metadata": {},
            "models": [],
        },
    ]

    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.list_providers",
        new=AsyncMock(return_value=mock_providers),
    )

    response = client.get("/admin/llm/providers")

    assert response.status_code == 200
    response_data = response.json()
    assert len(response_data["providers"]) == 2
    assert response_data["providers"][0]["name"] == "openai"

    # Snapshot test the response (must be string)
    configured_snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "list_llm_providers_success.json",
    )


def test_list_llm_models_success(
    mocker: pytest_mock.MockFixture,
    configured_snapshot: Snapshot,
) -> None:
    """Test successful listing of LLM models with pagination"""
    # Mock the database function - now returns LlmModelsResponse
    mock_model = llm_model.LlmModel(
        id="model-1",
        slug="gpt-4o",
        display_name="GPT-4o",
        description="GPT-4 Optimized",
        provider_id="provider-1",
        context_window=128000,
        max_output_tokens=16384,
        is_enabled=True,
        capabilities={},
        metadata={},
        costs=[
            llm_model.LlmModelCost(
                id="cost-1",
                credit_cost=10,
                credential_provider="openai",
                metadata={},
            )
        ],
    )

    mock_response = llm_model.LlmModelsResponse(
        models=[mock_model],
        pagination=Pagination(
            total_items=1,
            total_pages=1,
            current_page=1,
            page_size=50,
        ),
    )

    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.list_models",
        new=AsyncMock(return_value=mock_response),
    )

    response = client.get("/admin/llm/models")

    assert response.status_code == 200
    response_data = response.json()
    assert len(response_data["models"]) == 1
    assert response_data["models"][0]["slug"] == "gpt-4o"
    assert response_data["pagination"]["total_items"] == 1
    assert response_data["pagination"]["page_size"] == 50

    # Snapshot test the response (must be string)
    configured_snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "list_llm_models_success.json",
    )


def test_create_llm_provider_success(
    mocker: pytest_mock.MockFixture,
    configured_snapshot: Snapshot,
) -> None:
    """Test successful creation of LLM provider"""
    mock_provider = {
        "id": "new-provider-id",
        "name": "groq",
        "display_name": "Groq",
        "description": "Groq LLM provider",
        "supports_tools": True,
        "supports_json_output": True,
        "supports_reasoning": False,
        "supports_parallel_tool": False,
        "metadata": {},
    }

    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.upsert_provider",
        new=AsyncMock(return_value=mock_provider),
    )

    mock_refresh = mocker.patch(
        "backend.api.features.admin.llm_routes._refresh_runtime_state",
        new=AsyncMock(),
    )

    request_data = {
        "name": "groq",
        "display_name": "Groq",
        "description": "Groq LLM provider",
        "supports_tools": True,
        "supports_json_output": True,
        "supports_reasoning": False,
        "supports_parallel_tool": False,
        "metadata": {},
    }

    response = client.post("/admin/llm/providers", json=request_data)

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["name"] == "groq"
    assert response_data["display_name"] == "Groq"

    # Verify refresh was called
    mock_refresh.assert_called_once()

    # Snapshot test the response (must be string)
    configured_snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "create_llm_provider_success.json",
    )


def test_create_llm_model_success(
    mocker: pytest_mock.MockFixture,
    configured_snapshot: Snapshot,
) -> None:
    """Test successful creation of LLM model"""
    mock_model = {
        "id": "new-model-id",
        "slug": "gpt-4.1-mini",
        "display_name": "GPT-4.1 Mini",
        "description": "Latest GPT-4.1 Mini model",
        "provider_id": "provider-1",
        "context_window": 128000,
        "max_output_tokens": 16384,
        "is_enabled": True,
        "capabilities": {},
        "metadata": {},
        "costs": [
            {
                "id": "cost-id",
                "credit_cost": 5,
                "credential_provider": "openai",
                "metadata": {},
            }
        ],
    }

    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.create_model",
        new=AsyncMock(return_value=mock_model),
    )

    mock_refresh = mocker.patch(
        "backend.api.features.admin.llm_routes._refresh_runtime_state",
        new=AsyncMock(),
    )

    request_data = {
        "slug": "gpt-4.1-mini",
        "display_name": "GPT-4.1 Mini",
        "description": "Latest GPT-4.1 Mini model",
        "provider_id": "provider-1",
        "context_window": 128000,
        "max_output_tokens": 16384,
        "is_enabled": True,
        "capabilities": {},
        "metadata": {},
        "costs": [
            {
                "credit_cost": 5,
                "credential_provider": "openai",
                "metadata": {},
            }
        ],
    }

    response = client.post("/admin/llm/models", json=request_data)

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["slug"] == "gpt-4.1-mini"
    assert response_data["is_enabled"] is True

    # Verify refresh was called
    mock_refresh.assert_called_once()

    # Snapshot test the response (must be string)
    configured_snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "create_llm_model_success.json",
    )


def test_update_llm_model_success(
    mocker: pytest_mock.MockFixture,
    configured_snapshot: Snapshot,
) -> None:
    """Test successful update of LLM model"""
    mock_model = {
        "id": "model-1",
        "slug": "gpt-4o",
        "display_name": "GPT-4o Updated",
        "description": "Updated description",
        "provider_id": "provider-1",
        "context_window": 256000,
        "max_output_tokens": 32768,
        "is_enabled": True,
        "capabilities": {},
        "metadata": {},
        "costs": [
            {
                "id": "cost-1",
                "credit_cost": 15,
                "credential_provider": "openai",
                "metadata": {},
            }
        ],
    }

    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.update_model",
        new=AsyncMock(return_value=mock_model),
    )

    mock_refresh = mocker.patch(
        "backend.api.features.admin.llm_routes._refresh_runtime_state",
        new=AsyncMock(),
    )

    request_data = {
        "display_name": "GPT-4o Updated",
        "description": "Updated description",
        "context_window": 256000,
        "max_output_tokens": 32768,
    }

    response = client.patch("/admin/llm/models/model-1", json=request_data)

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["display_name"] == "GPT-4o Updated"
    assert response_data["context_window"] == 256000

    # Verify refresh was called
    mock_refresh.assert_called_once()

    # Snapshot test the response (must be string)
    configured_snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "update_llm_model_success.json",
    )


def test_toggle_llm_model_success(
    mocker: pytest_mock.MockFixture,
    configured_snapshot: Snapshot,
) -> None:
    """Test successful toggling of LLM model enabled status"""
    # Create a proper mock model object
    mock_model = llm_model.LlmModel(
        id="model-1",
        slug="gpt-4o",
        display_name="GPT-4o",
        description="GPT-4 Optimized",
        provider_id="provider-1",
        context_window=128000,
        max_output_tokens=16384,
        is_enabled=False,
        capabilities={},
        metadata={},
        costs=[],
    )

    # Create a proper ToggleLlmModelResponse
    mock_response = llm_model.ToggleLlmModelResponse(
        model=mock_model,
        nodes_migrated=0,
        migrated_to_slug=None,
        migration_id=None,
    )

    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.toggle_model",
        new=AsyncMock(return_value=mock_response),
    )

    mock_refresh = mocker.patch(
        "backend.api.features.admin.llm_routes._refresh_runtime_state",
        new=AsyncMock(),
    )

    request_data = {"is_enabled": False}

    response = client.patch("/admin/llm/models/model-1/toggle", json=request_data)

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["model"]["is_enabled"] is False

    # Verify refresh was called
    mock_refresh.assert_called_once()

    # Snapshot test the response (must be string)
    configured_snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "toggle_llm_model_success.json",
    )


def test_delete_llm_model_success(
    mocker: pytest_mock.MockFixture,
    configured_snapshot: Snapshot,
) -> None:
    """Test successful deletion of LLM model with migration"""
    # Create a proper DeleteLlmModelResponse
    mock_response = llm_model.DeleteLlmModelResponse(
        deleted_model_slug="gpt-3.5-turbo",
        deleted_model_display_name="GPT-3.5 Turbo",
        replacement_model_slug="gpt-4o-mini",
        nodes_migrated=42,
        message="Successfully deleted model 'GPT-3.5 Turbo' (gpt-3.5-turbo) "
        "and migrated 42 workflow node(s) to 'gpt-4o-mini'.",
    )

    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.delete_model",
        new=AsyncMock(return_value=mock_response),
    )

    mock_refresh = mocker.patch(
        "backend.api.features.admin.llm_routes._refresh_runtime_state",
        new=AsyncMock(),
    )

    response = client.delete(
        "/admin/llm/models/model-1?replacement_model_slug=gpt-4o-mini"
    )

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["deleted_model_slug"] == "gpt-3.5-turbo"
    assert response_data["nodes_migrated"] == 42
    assert response_data["replacement_model_slug"] == "gpt-4o-mini"

    # Verify refresh was called
    mock_refresh.assert_called_once()

    # Snapshot test the response (must be string)
    configured_snapshot.assert_match(
        json.dumps(response_data, indent=2, sort_keys=True),
        "delete_llm_model_success.json",
    )


def test_delete_llm_model_validation_error(
    mocker: pytest_mock.MockFixture,
) -> None:
    """Test deletion fails with proper error when validation fails"""
    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.delete_model",
        new=AsyncMock(side_effect=ValueError("Replacement model 'invalid' not found")),
    )

    response = client.delete("/admin/llm/models/model-1?replacement_model_slug=invalid")

    assert response.status_code == 400
    assert "Replacement model 'invalid' not found" in response.json()["detail"]


def test_delete_llm_model_no_replacement_with_usage(
    mocker: pytest_mock.MockFixture,
) -> None:
    """Test deletion fails when nodes exist but no replacement is provided"""
    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.delete_model",
        new=AsyncMock(
            side_effect=ValueError(
                "Cannot delete model 'test-model': 5 workflow node(s) are using it. "
                "Please provide a replacement_model_slug to migrate them."
            )
        ),
    )

    response = client.delete("/admin/llm/models/model-1")

    assert response.status_code == 400
    assert "workflow node(s) are using it" in response.json()["detail"]


def test_delete_llm_model_no_replacement_no_usage(
    mocker: pytest_mock.MockFixture,
) -> None:
    """Test deletion succeeds when no nodes use the model and no replacement is provided"""
    mock_response = llm_model.DeleteLlmModelResponse(
        deleted_model_slug="unused-model",
        deleted_model_display_name="Unused Model",
        replacement_model_slug=None,
        nodes_migrated=0,
        message="Successfully deleted model 'Unused Model' (unused-model). No workflows were using this model.",
    )

    mocker.patch(
        "backend.api.features.admin.llm_routes.llm_db.delete_model",
        new=AsyncMock(return_value=mock_response),
    )

    mock_refresh = mocker.patch(
        "backend.api.features.admin.llm_routes._refresh_runtime_state",
        new=AsyncMock(),
    )

    response = client.delete("/admin/llm/models/model-1")

    assert response.status_code == 200
    response_data = response.json()
    assert response_data["deleted_model_slug"] == "unused-model"
    assert response_data["nodes_migrated"] == 0
    assert response_data["replacement_model_slug"] is None
    mock_refresh.assert_called_once()
```
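The `json.dumps(..., sort_keys=True)` step before every `assert_match` is what makes the snapshots stable: the serialized form is deterministic regardless of dict insertion order. A small illustration:

```python
import json

a = {"b": 1, "a": 2}
b = {"a": 2, "b": 1}
# Same content, different insertion order, identical snapshot text.
assert json.dumps(a, sort_keys=True) == json.dumps(b, sort_keys=True)
```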
```diff
@@ -20,6 +20,7 @@ from backend.blocks._base import (
 )
 from backend.blocks.llm import LlmModel
 from backend.data.db import query_raw_with_schema
+from backend.data.llm_registry import get_all_model_slugs_for_validation
 from backend.integrations.providers import ProviderName
 from backend.util.cache import cached
 from backend.util.models import Pagination
@@ -36,7 +37,14 @@ from .model import (
 )
 
 logger = logging.getLogger(__name__)
-llm_models = [name.name.lower().replace("_", " ") for name in LlmModel]
+
+
+def _get_llm_models() -> list[str]:
+    """Get LLM model names for search matching from the registry."""
+    return [
+        slug.lower().replace("-", " ") for slug in get_all_model_slugs_for_validation()
+    ]
+
 
 MAX_LIBRARY_AGENT_RESULTS = 100
 MAX_MARKETPLACE_AGENT_RESULTS = 100
@@ -501,8 +509,10 @@ async def _get_static_counts():
 def _matches_llm_model(schema_cls: type[BlockSchema], query: str) -> bool:
     for field in schema_cls.model_fields.values():
         if field.annotation == LlmModel:
-            # Check if query matches any value in llm_models
-            if any(query in name for name in llm_models):
+            # Normalize query same as model slugs (lowercase, hyphens to spaces)
+            normalized_model_query = query.lower().replace("-", " ")
+            # Check if query matches any value in llm_models from registry
+            if any(normalized_model_query in name for name in _get_llm_models()):
                 return True
    return False
```
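The normalization rule shared by `_get_llm_models` and `_matches_llm_model`, shown in isolation: lowercase both sides, turn hyphens into spaces, then do a substring match.

```python
def normalize(text: str) -> str:
    # Same transform applied to registry slugs and to the search query.
    return text.lower().replace("-", " ")


names = [normalize(slug) for slug in ("gpt-4o", "claude-3-5-sonnet")]
assert any(normalize("GPT-4o") in name for name in names)
assert not any(normalize("mistral") in name for name in names)
```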
```diff
@@ -37,10 +37,12 @@ stale pending messages from dead consumers.
 
 import asyncio
 import logging
+import os
 import uuid
 from typing import Any
 
 import orjson
+from prisma import Prisma
 from pydantic import BaseModel
 from redis.exceptions import ResponseError
 
@@ -67,8 +69,8 @@ class OperationCompleteMessage(BaseModel):
 class ChatCompletionConsumer:
     """Consumer for chat operation completion messages from Redis Streams.
 
-    Database operations are handled through the chat_db() accessor, which
-    routes through DatabaseManager RPC when Prisma is not directly connected.
+    This consumer initializes its own Prisma client in start() to ensure
+    database operations work correctly within this async context.
 
     Uses Redis consumer groups to allow multiple platform pods to consume
     messages reliably with automatic redelivery on failure.
@@ -77,6 +79,7 @@ class ChatCompletionConsumer:
     def __init__(self):
         self._consumer_task: asyncio.Task | None = None
         self._running = False
+        self._prisma: Prisma | None = None
         self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
 
     async def start(self) -> None:
@@ -112,6 +115,15 @@ class ChatCompletionConsumer:
             f"Chat completion consumer started (consumer: {self._consumer_name})"
         )
 
+    async def _ensure_prisma(self) -> Prisma:
+        """Lazily initialize Prisma client on first use."""
+        if self._prisma is None:
+            database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
+            self._prisma = Prisma(datasource={"url": database_url})
+            await self._prisma.connect()
+            logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
+        return self._prisma
+
     async def stop(self) -> None:
         """Stop the completion consumer."""
         self._running = False
@@ -124,6 +136,11 @@ class ChatCompletionConsumer:
                 pass
             self._consumer_task = None
 
+        if self._prisma:
+            await self._prisma.disconnect()
+            self._prisma = None
+            logger.info("[COMPLETION] Consumer Prisma client disconnected")
+
         logger.info("Chat completion consumer stopped")
 
     async def _consume_messages(self) -> None:
@@ -235,7 +252,7 @@ class ChatCompletionConsumer:
             # XAUTOCLAIM after min_idle_time expires
 
     async def _handle_message(self, body: bytes) -> None:
-        """Handle a completion message."""
+        """Handle a completion message using our own Prisma client."""
         try:
             data = orjson.loads(body)
             message = OperationCompleteMessage(**data)
@@ -285,7 +302,8 @@ class ChatCompletionConsumer:
         message: OperationCompleteMessage,
     ) -> None:
         """Handle successful operation completion."""
-        await process_operation_success(task, message.result)
+        prisma = await self._ensure_prisma()
+        await process_operation_success(task, message.result, prisma)
 
     async def _handle_failure(
         self,
@@ -293,7 +311,8 @@ class ChatCompletionConsumer:
         message: OperationCompleteMessage,
     ) -> None:
         """Handle failed operation completion."""
-        await process_operation_failure(task, message.error)
+        prisma = await self._ensure_prisma()
+        await process_operation_failure(task, message.error, prisma)
 
 
 # Module-level consumer instance
```
@@ -9,8 +9,7 @@ import logging
from typing import Any

import orjson

from backend.data.db_accessors import chat_db
from prisma import Prisma

from . import service as chat_service
from . import stream_registry
@@ -73,40 +72,48 @@ async def _update_tool_message(
    session_id: str,
    tool_call_id: str,
    content: str,
    prisma_client: Prisma | None,
) -> None:
    """Update tool message in database using the chat_db accessor.

    Routes through DatabaseManager RPC when Prisma is not directly
    connected (e.g. in the CoPilot Executor microservice).
    """Update tool message in database.

    Args:
        session_id: The session ID
        tool_call_id: The tool call ID to update
        content: The new content for the message
        prisma_client: Optional Prisma client. If None, uses chat_service.

    Raises:
        ToolMessageUpdateError: If the database update fails.
        ToolMessageUpdateError: If the database update fails. The caller should
            handle this to avoid marking the task as completed with inconsistent state.
    """
    try:
        updated = await chat_db().update_tool_message_content(
            session_id=session_id,
            tool_call_id=tool_call_id,
            new_content=content,
        )
        if not updated:
            raise ToolMessageUpdateError(
                f"No message found with tool_call_id="
                f"{tool_call_id} in session {session_id}"
        if prisma_client:
            # Use provided Prisma client (for consumer with its own connection)
            updated_count = await prisma_client.chatmessage.update_many(
                where={
                    "sessionId": session_id,
                    "toolCallId": tool_call_id,
                },
                data={"content": content},
            )
            # Check if any rows were updated - 0 means message not found
            if updated_count == 0:
                raise ToolMessageUpdateError(
                    f"No message found with tool_call_id={tool_call_id} in session {session_id}"
                )
        else:
            # Use service function (for webhook endpoint)
            await chat_service._update_pending_operation(
                session_id=session_id,
                tool_call_id=tool_call_id,
                result=content,
            )
    except ToolMessageUpdateError:
        raise
    except Exception as e:
        logger.error(
            f"[COMPLETION] Failed to update tool message: {e}",
            exc_info=True,
        )
        logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
        raise ToolMessageUpdateError(
            f"Failed to update tool message for tool call #{tool_call_id}: {e}"
            f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
        ) from e


@@ -195,6 +202,7 @@ async def _save_agent_from_result(
async def process_operation_success(
    task: stream_registry.ActiveTask,
    result: dict | str | None,
    prisma_client: Prisma | None = None,
) -> None:
    """Handle successful operation completion.

@@ -204,10 +212,12 @@ async def process_operation_success(
    Args:
        task: The active task that completed
        result: The result data from the operation
        prisma_client: Optional Prisma client for database operations.
            If None, uses chat_service._update_pending_operation instead.

    Raises:
        ToolMessageUpdateError: If the database update fails. The task
            will be marked as failed instead of completed.
        ToolMessageUpdateError: If the database update fails. The task will be
            marked as failed instead of completed to avoid inconsistent state.
    """
    # For agent generation tools, save the agent to library
    if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
@@ -240,6 +250,7 @@ async def process_operation_success(
            session_id=task.session_id,
            tool_call_id=task.tool_call_id,
            content=result_str,
            prisma_client=prisma_client,
        )
    except ToolMessageUpdateError:
        # DB update failed - mark task as failed to avoid inconsistent state
@@ -282,15 +293,18 @@ async def process_operation_success(
async def process_operation_failure(
    task: stream_registry.ActiveTask,
    error: str | None,
    prisma_client: Prisma | None = None,
) -> None:
    """Handle failed operation completion.

    Publishes the error to the stream registry, updates the database
    with the error response, and marks the task as failed.
    Publishes the error to the stream registry, updates the database with
    the error response, and marks the task as failed.

    Args:
        task: The active task that failed
        error: The error message from the operation
        prisma_client: Optional Prisma client for database operations.
            If None, uses chat_service._update_pending_operation instead.
    """
    error_msg = error or "Operation failed"

@@ -311,6 +325,7 @@ async def process_operation_failure(
            session_id=task.session_id,
            tool_call_id=task.tool_call_id,
            content=error_response.model_dump_json(),
            prisma_client=prisma_client,
        )
    except ToolMessageUpdateError:
        # DB update failed - log but continue with cleanup
@@ -27,6 +27,7 @@ class ChatConfig(BaseSettings):
    session_ttl: int = Field(default=43200, description="Session TTL in seconds")

    # Streaming Configuration
    stream_timeout: int = Field(default=300, description="Stream timeout in seconds")
    max_retries: int = Field(
        default=3,
        description="Max retries for fallback path (SDK handles retries internally)",
@@ -38,10 +39,8 @@ class ChatConfig(BaseSettings):

    # Long-running operation configuration
    long_running_operation_ttl: int = Field(
        default=3600,
        description="TTL in seconds for long-running operation deduplication lock "
        "(1 hour, matches stream_ttl). Prevents duplicate operations if pod dies. "
        "For longer operations, the stream_registry heartbeat keeps them alive.",
        default=600,
        description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
    )

    # Stream registry configuration for SSE reconnection
@@ -14,27 +14,29 @@ from prisma.types import (
    ChatSessionWhereInput,
)

from backend.data import db
from backend.data.db import transaction
from backend.util.json import SafeJson

from .model import ChatMessage, ChatSession, ChatSessionInfo

logger = logging.getLogger(__name__)


async def get_chat_session(session_id: str) -> ChatSession | None:
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
    """Get a chat session by ID from the database."""
    session = await PrismaChatSession.prisma().find_unique(
        where={"id": session_id},
        include={"Messages": {"order_by": {"sequence": "asc"}}},
        include={"Messages": True},
    )
    return ChatSession.from_db(session) if session else None
    if session and session.Messages:
        # Sort messages by sequence in Python - Prisma Python client doesn't support
        # order_by in include clauses (unlike Prisma JS), so we sort after fetching
        session.Messages.sort(key=lambda m: m.sequence)
    return session


async def create_chat_session(
    session_id: str,
    user_id: str,
) -> ChatSessionInfo:
) -> PrismaChatSession:
    """Create a new chat session in the database."""
    data = ChatSessionCreateInput(
        id=session_id,
@@ -43,8 +45,7 @@ async def create_chat_session(
        successfulAgentRuns=SafeJson({}),
        successfulAgentSchedules=SafeJson({}),
    )
    prisma_session = await PrismaChatSession.prisma().create(data=data)
    return ChatSessionInfo.from_db(prisma_session)
    return await PrismaChatSession.prisma().create(data=data)


async def update_chat_session(
@@ -55,7 +56,7 @@ async def update_chat_session(
    total_prompt_tokens: int | None = None,
    total_completion_tokens: int | None = None,
    title: str | None = None,
) -> ChatSession | None:
) -> PrismaChatSession | None:
    """Update a chat session's metadata."""
    data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}

@@ -75,9 +76,12 @@ async def update_chat_session(
    session = await PrismaChatSession.prisma().update(
        where={"id": session_id},
        data=data,
        include={"Messages": {"order_by": {"sequence": "asc"}}},
        include={"Messages": True},
    )
    return ChatSession.from_db(session) if session else None
    if session and session.Messages:
        # Sort in Python - Prisma Python doesn't support order_by in include clauses
        session.Messages.sort(key=lambda m: m.sequence)
    return session


async def add_chat_message(
@@ -90,7 +94,7 @@ async def add_chat_message(
    refusal: str | None = None,
    tool_calls: list[dict[str, Any]] | None = None,
    function_call: dict[str, Any] | None = None,
) -> ChatMessage:
) -> PrismaChatMessage:
    """Add a message to a chat session."""
    # Build input dict dynamically rather than using ChatMessageCreateInput directly
    # because Prisma's TypedDict validation rejects optional fields set to None.
@@ -125,119 +129,79 @@ async def add_chat_message(
        ),
        PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
    )
    return ChatMessage.from_db(message)
    return message


async def add_chat_messages_batch(
    session_id: str,
    messages: list[dict[str, Any]],
    start_sequence: int,
) -> tuple[list[ChatMessage], int]:
) -> list[PrismaChatMessage]:
    """Add multiple messages to a chat session in a batch.

    Uses collision detection with retry: tries to create messages starting
    at start_sequence. If a unique constraint violation occurs (e.g., the
    streaming loop and long-running callback race), queries MAX(sequence)
    and retries with the correct next sequence number. This avoids
    unnecessary upserts and DB queries in the common case (no collision).

    Returns:
        Tuple of (messages, final_message_count) where final_message_count
        is the total number of messages in the session after insertion.
        This allows callers to update their counters even when collision
        detection adjusts start_sequence.
    Uses a transaction for atomicity - if any message creation fails,
    the entire batch is rolled back.
    """
    if not messages:
        # No messages to add - return current count
        return [], start_sequence
        return []

    max_retries = 3
    for attempt in range(max_retries):
        try:
            created_messages = []
            async with db.transaction() as tx:
                for i, msg in enumerate(messages):
                    # Build input dict dynamically rather than using ChatMessageCreateInput
                    # directly because Prisma's TypedDict validation rejects optional fields
                    # set to None. We only include fields that have values, then cast.
                    data: dict[str, Any] = {
                        "Session": {"connect": {"id": session_id}},
                        "role": msg["role"],
                        "sequence": start_sequence + i,
                    }
    created_messages = []

                    # Add optional string fields
                    if msg.get("content") is not None:
                        data["content"] = msg["content"]
                    if msg.get("name") is not None:
                        data["name"] = msg["name"]
                    if msg.get("tool_call_id") is not None:
                        data["toolCallId"] = msg["tool_call_id"]
                    if msg.get("refusal") is not None:
                        data["refusal"] = msg["refusal"]
    async with transaction() as tx:
        for i, msg in enumerate(messages):
            # Build input dict dynamically rather than using ChatMessageCreateInput
            # directly because Prisma's TypedDict validation rejects optional fields
            # set to None. We only include fields that have values, then cast.
            data: dict[str, Any] = {
                "Session": {"connect": {"id": session_id}},
                "role": msg["role"],
                "sequence": start_sequence + i,
            }

                    # Add optional JSON fields only when they have values
                    if msg.get("tool_calls") is not None:
                        data["toolCalls"] = SafeJson(msg["tool_calls"])
                    if msg.get("function_call") is not None:
                        data["functionCall"] = SafeJson(msg["function_call"])
            # Add optional string fields
            if msg.get("content") is not None:
                data["content"] = msg["content"]
            if msg.get("name") is not None:
                data["name"] = msg["name"]
            if msg.get("tool_call_id") is not None:
                data["toolCallId"] = msg["tool_call_id"]
            if msg.get("refusal") is not None:
                data["refusal"] = msg["refusal"]

                    created = await PrismaChatMessage.prisma(tx).create(
                        data=cast(ChatMessageCreateInput, data)
                    )
                    created_messages.append(created)
            # Add optional JSON fields only when they have values
            if msg.get("tool_calls") is not None:
                data["toolCalls"] = SafeJson(msg["tool_calls"])
            if msg.get("function_call") is not None:
                data["functionCall"] = SafeJson(msg["function_call"])

                # Update session's updatedAt timestamp within the same transaction.
                # Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
                # separately via update_chat_session() after streaming completes.
                await PrismaChatSession.prisma(tx).update(
                    where={"id": session_id},
                    data={"updatedAt": datetime.now(UTC)},
                )

            # Return messages and final message count (for shared counter sync)
            final_count = start_sequence + len(messages)
            return [ChatMessage.from_db(m) for m in created_messages], final_count

        except Exception as e:
            # Check if it's a unique constraint violation
            error_msg = str(e).lower()
            is_unique_constraint = (
                "unique constraint" in error_msg or "duplicate key" in error_msg
            created = await PrismaChatMessage.prisma(tx).create(
                data=cast(ChatMessageCreateInput, data)
            )
            created_messages.append(created)

            if is_unique_constraint and attempt < max_retries - 1:
                # Collision detected - query MAX(sequence)+1 and retry with correct offset
                logger.info(
                    f"Collision detected for session {session_id} at sequence "
                    f"{start_sequence}, querying DB for latest sequence"
                )
                start_sequence = await get_next_sequence(session_id)
                logger.info(
                    f"Retrying batch insert with start_sequence={start_sequence}"
                )
                continue
            else:
                # Not a collision or max retries exceeded - propagate error
                raise
        # Update session's updatedAt timestamp within the same transaction.
        # Note: Token usage (total_prompt_tokens, total_completion_tokens) is updated
        # separately via update_chat_session() after streaming completes.
        await PrismaChatSession.prisma(tx).update(
            where={"id": session_id},
            data={"updatedAt": datetime.now(UTC)},
        )

    # Should never reach here due to raise in exception handler
    raise RuntimeError(f"Failed to insert messages after {max_retries} attempts")
    return created_messages


async def get_user_chat_sessions(
    user_id: str,
    limit: int = 50,
    offset: int = 0,
) -> list[ChatSessionInfo]:
) -> list[PrismaChatSession]:
    """Get chat sessions for a user, ordered by most recent."""
    prisma_sessions = await PrismaChatSession.prisma().find_many(
    return await PrismaChatSession.prisma().find_many(
        where={"userId": user_id},
        order={"updatedAt": "desc"},
        take=limit,
        skip=offset,
    )
    return [ChatSessionInfo.from_db(s) for s in prisma_sessions]


async def get_user_session_count(user_id: str) -> int:
@@ -276,23 +240,10 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
        return False


async def get_next_sequence(session_id: str) -> int:
    """Get the next sequence number for a new message in this session.

    Uses MAX(sequence) + 1 for robustness. Returns 0 if no messages exist.
    More robust than COUNT(*) because it's immune to deleted messages.
    """
    result = await db.prisma.query_raw(
        """
        SELECT COALESCE(MAX(sequence) + 1, 0) as next_seq
        FROM "ChatMessage"
        WHERE "sessionId" = $1
        """,
        session_id,
    )
    if not result or len(result) == 0:
        return 0
    return int(result[0]["next_seq"])
async def get_chat_session_message_count(session_id: str) -> int:
    """Get the number of messages in a chat session."""
    count = await PrismaChatMessage.prisma().count(where={"sessionId": session_id})
    return count


async def update_tool_message_content(
@@ -2,7 +2,7 @@ import asyncio
import logging
import uuid
from datetime import UTC, datetime
from typing import Any, Self, cast
from typing import Any, cast
from weakref import WeakValueDictionary

from openai.types.chat import (
@@ -23,17 +23,26 @@ from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from pydantic import BaseModel

from backend.data.db_accessors import chat_db
from backend.data.redis_client import get_redis_async
from backend.util import json
from backend.util.exceptions import DatabaseError, RedisError

from . import db as chat_db
from .config import ChatConfig

logger = logging.getLogger(__name__)
config = ChatConfig()


def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
    """Parse a JSON field that may be stored as string or already parsed."""
    if value is None:
        return default
    if isinstance(value, str):
        return json.loads(value)
    return value


# Redis cache key prefix for chat sessions
CHAT_SESSION_CACHE_PREFIX = "chat:session:"

@@ -43,7 +52,28 @@ def _get_session_cache_key(session_id: str) -> str:
    return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}"


# ===================== Chat data models ===================== #
# Session-level locks to prevent race conditions during concurrent upserts.
# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced,
# preventing unbounded memory growth while maintaining lock semantics for active sessions.
# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after
# async with lock: completes). Explicit cleanup also occurs in delete_chat_session().
_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()


async def _get_session_lock(session_id: str) -> asyncio.Lock:
    """Get or create a lock for a specific session to prevent concurrent upserts.

    Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
    when no coroutine holds a reference to them, preventing memory leaks from
    unbounded growth of session locks.
    """
    async with _session_locks_mutex:
        lock = _session_locks.get(session_id)
        if lock is None:
            lock = asyncio.Lock()
            _session_locks[session_id] = lock
        return lock


class ChatMessage(BaseModel):
@@ -55,19 +85,6 @@ class ChatMessage(BaseModel):
    tool_calls: list[dict] | None = None
    function_call: dict | None = None

    @staticmethod
    def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage":
        """Convert a Prisma ChatMessage to a Pydantic ChatMessage."""
        return ChatMessage(
            role=prisma_message.role,
            content=prisma_message.content,
            name=prisma_message.name,
            tool_call_id=prisma_message.toolCallId,
            refusal=prisma_message.refusal,
            tool_calls=_parse_json_field(prisma_message.toolCalls),
            function_call=_parse_json_field(prisma_message.functionCall),
        )


class Usage(BaseModel):
    prompt_tokens: int
@@ -75,10 +92,11 @@ class Usage(BaseModel):
    total_tokens: int


class ChatSessionInfo(BaseModel):
class ChatSession(BaseModel):
    session_id: str
    user_id: str
    title: str | None = None
    messages: list[ChatMessage]
    usage: list[Usage]
    credentials: dict[str, dict] = {}  # Map of provider -> credential metadata
    started_at: datetime
@@ -86,9 +104,60 @@ class ChatSessionInfo(BaseModel):
    successful_agent_runs: dict[str, int] = {}
    successful_agent_schedules: dict[str, int] = {}

    @classmethod
    def from_db(cls, prisma_session: PrismaChatSession) -> Self:
        """Convert Prisma ChatSession to Pydantic ChatSession."""
    def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
        """Attach a tool_call to the current turn's assistant message.

        Searches backwards for the most recent assistant message (stopping at
        any user message boundary). If found, appends the tool_call to it.
        Otherwise creates a new assistant message with the tool_call.
        """
        for msg in reversed(self.messages):
            if msg.role == "user":
                break
            if msg.role == "assistant":
                if not msg.tool_calls:
                    msg.tool_calls = []
                msg.tool_calls.append(tool_call)
                return

        self.messages.append(
            ChatMessage(role="assistant", content="", tool_calls=[tool_call])
        )

    @staticmethod
    def new(user_id: str) -> "ChatSession":
        return ChatSession(
            session_id=str(uuid.uuid4()),
            user_id=user_id,
            title=None,
            messages=[],
            usage=[],
            credentials={},
            started_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )

    @staticmethod
    def from_db(
        prisma_session: PrismaChatSession,
        prisma_messages: list[PrismaChatMessage] | None = None,
    ) -> "ChatSession":
        """Convert Prisma models to Pydantic ChatSession."""
        messages = []
        if prisma_messages:
            for msg in prisma_messages:
                messages.append(
                    ChatMessage(
                        role=msg.role,
                        content=msg.content,
                        name=msg.name,
                        tool_call_id=msg.toolCallId,
                        refusal=msg.refusal,
                        tool_calls=_parse_json_field(msg.toolCalls),
                        function_call=_parse_json_field(msg.functionCall),
                    )
                )

        # Parse JSON fields from Prisma
        credentials = _parse_json_field(prisma_session.credentials, default={})
        successful_agent_runs = _parse_json_field(
@@ -110,10 +179,11 @@ class ChatSessionInfo(BaseModel):
            )
        )

        return cls(
        return ChatSession(
            session_id=prisma_session.id,
            user_id=prisma_session.userId,
            title=prisma_session.title,
            messages=messages,
            usage=usage,
            credentials=credentials,
            started_at=prisma_session.createdAt,
@@ -122,55 +192,46 @@ class ChatSessionInfo(BaseModel):
            successful_agent_schedules=successful_agent_schedules,
        )

    @staticmethod
    def _merge_consecutive_assistant_messages(
        messages: list[ChatCompletionMessageParam],
    ) -> list[ChatCompletionMessageParam]:
        """Merge consecutive assistant messages into single messages.

class ChatSession(ChatSessionInfo):
    messages: list[ChatMessage]

    @classmethod
    def new(cls, user_id: str) -> Self:
        return cls(
            session_id=str(uuid.uuid4()),
            user_id=user_id,
            title=None,
            messages=[],
            usage=[],
            credentials={},
            started_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )

    @classmethod
    def from_db(cls, prisma_session: PrismaChatSession) -> Self:
        """Convert Prisma ChatSession to Pydantic ChatSession."""
        if prisma_session.Messages is None:
            raise ValueError(
                f"Prisma session {prisma_session.id} is missing Messages relation"
            )

        return cls(
            **ChatSessionInfo.from_db(prisma_session).model_dump(),
            messages=[ChatMessage.from_db(m) for m in prisma_session.Messages],
        )

    def add_tool_call_to_current_turn(self, tool_call: dict) -> None:
        """Attach a tool_call to the current turn's assistant message.

        Searches backwards for the most recent assistant message (stopping at
        any user message boundary). If found, appends the tool_call to it.
        Otherwise creates a new assistant message with the tool_call.
        Long-running tool flows can create split assistant messages: one with
        text content and another with tool_calls. Anthropic's API requires
        tool_result blocks to reference a tool_use in the immediately preceding
        assistant message, so these splits cause 400 errors via OpenRouter.
        """
        for msg in reversed(self.messages):
            if msg.role == "user":
                break
            if msg.role == "assistant":
                if not msg.tool_calls:
                    msg.tool_calls = []
                msg.tool_calls.append(tool_call)
                return
        if len(messages) < 2:
            return messages

        self.messages.append(
            ChatMessage(role="assistant", content="", tool_calls=[tool_call])
        )
        result: list[ChatCompletionMessageParam] = [messages[0]]
        for msg in messages[1:]:
            prev = result[-1]
            if prev.get("role") != "assistant" or msg.get("role") != "assistant":
                result.append(msg)
                continue

            prev = cast(ChatCompletionAssistantMessageParam, prev)
            curr = cast(ChatCompletionAssistantMessageParam, msg)

            curr_content = curr.get("content") or ""
            if curr_content:
                prev_content = prev.get("content") or ""
                prev["content"] = (
                    f"{prev_content}\n{curr_content}" if prev_content else curr_content
                )

            curr_tool_calls = curr.get("tool_calls")
            if curr_tool_calls:
                prev_tool_calls = prev.get("tool_calls")
                prev["tool_calls"] = (
                    list(prev_tool_calls) + list(curr_tool_calls)
                    if prev_tool_calls
                    else list(curr_tool_calls)
                )
        return result

    def to_openai_messages(self) -> list[ChatCompletionMessageParam]:
        messages = []
@@ -260,70 +321,40 @@ class ChatSession(ChatSessionInfo):
        )
        return self._merge_consecutive_assistant_messages(messages)

    @staticmethod
    def _merge_consecutive_assistant_messages(
        messages: list[ChatCompletionMessageParam],
    ) -> list[ChatCompletionMessageParam]:
        """Merge consecutive assistant messages into single messages.

        Long-running tool flows can create split assistant messages: one with
        text content and another with tool_calls. Anthropic's API requires
        tool_result blocks to reference a tool_use in the immediately preceding
        assistant message, so these splits cause 400 errors via OpenRouter.
        """
        if len(messages) < 2:
            return messages
async def _get_session_from_cache(session_id: str) -> ChatSession | None:
    """Get a chat session from Redis cache."""
    redis_key = _get_session_cache_key(session_id)
    async_redis = await get_redis_async()
    raw_session: bytes | None = await async_redis.get(redis_key)

        result: list[ChatCompletionMessageParam] = [messages[0]]
        for msg in messages[1:]:
            prev = result[-1]
            if prev.get("role") != "assistant" or msg.get("role") != "assistant":
                result.append(msg)
                continue
    if raw_session is None:
        return None

            prev = cast(ChatCompletionAssistantMessageParam, prev)
            curr = cast(ChatCompletionAssistantMessageParam, msg)

            curr_content = curr.get("content") or ""
            if curr_content:
                prev_content = prev.get("content") or ""
                prev["content"] = (
                    f"{prev_content}\n{curr_content}" if prev_content else curr_content
                )

            curr_tool_calls = curr.get("tool_calls")
            if curr_tool_calls:
                prev_tool_calls = prev.get("tool_calls")
                prev["tool_calls"] = (
                    list(prev_tool_calls) + list(curr_tool_calls)
                    if prev_tool_calls
                    else list(curr_tool_calls)
                )
        return result
    try:
        session = ChatSession.model_validate_json(raw_session)
        logger.info(
            f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, "
            f"last_roles={[m.role for m in session.messages[-3:]]}"  # Last 3 roles
        )
        return session
    except Exception as e:
        logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
        raise RedisError(f"Corrupted session data for {session_id}") from e


def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any:
    """Parse a JSON field that may be stored as string or already parsed."""
    if value is None:
        return default
    if isinstance(value, str):
        return json.loads(value)
    return value


# ================ Chat cache + DB operations ================ #

# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not
# connected directly.


async def cache_chat_session(session: ChatSession) -> None:
    """Cache a chat session in Redis (without persisting to the database)."""
async def _cache_session(session: ChatSession) -> None:
    """Cache a chat session in Redis."""
    redis_key = _get_session_cache_key(session.session_id)
    async_redis = await get_redis_async()
    await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json())


async def cache_chat_session(session: ChatSession) -> None:
    """Cache a chat session without persisting to the database."""
    await _cache_session(session)


async def invalidate_session_cache(session_id: str) -> None:
    """Invalidate a chat session from Redis cache.

@@ -339,6 +370,77 @@ async def invalidate_session_cache(session_id: str) -> None:
        logger.warning(f"Failed to invalidate session cache for {session_id}: {e}")


async def _get_session_from_db(session_id: str) -> ChatSession | None:
    """Get a chat session from the database."""
    prisma_session = await chat_db.get_chat_session(session_id)
    if not prisma_session:
        return None

    messages = prisma_session.Messages
    logger.debug(
        f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, "
        f"roles={[m.role for m in messages[-3:]] if messages else []}"  # Last 3 roles
    )

    return ChatSession.from_db(prisma_session, messages)


async def _save_session_to_db(
    session: ChatSession, existing_message_count: int
) -> None:
    """Save or update a chat session in the database."""
    # Check if session exists in DB
    existing = await chat_db.get_chat_session(session.session_id)

    if not existing:
        # Create new session
        await chat_db.create_chat_session(
            session_id=session.session_id,
            user_id=session.user_id,
        )
        existing_message_count = 0

    # Calculate total tokens from usage
    total_prompt = sum(u.prompt_tokens for u in session.usage)
    total_completion = sum(u.completion_tokens for u in session.usage)

    # Update session metadata
    await chat_db.update_chat_session(
        session_id=session.session_id,
        credentials=session.credentials,
        successful_agent_runs=session.successful_agent_runs,
        successful_agent_schedules=session.successful_agent_schedules,
        total_prompt_tokens=total_prompt,
        total_completion_tokens=total_completion,
    )

    # Add new messages (only those after existing count)
    new_messages = session.messages[existing_message_count:]
    if new_messages:
        messages_data = []
        for msg in new_messages:
            messages_data.append(
                {
                    "role": msg.role,
                    "content": msg.content,
                    "name": msg.name,
                    "tool_call_id": msg.tool_call_id,
                    "refusal": msg.refusal,
                    "tool_calls": msg.tool_calls,
                    "function_call": msg.function_call,
                }
            )
        logger.debug(
            f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, "
            f"roles={[m['role'] for m in messages_data]}"
        )
        await chat_db.add_chat_messages_batch(
            session_id=session.session_id,
            messages=messages_data,
            start_sequence=existing_message_count,
        )


async def get_chat_session(
    session_id: str,
    user_id: str | None = None,
@@ -386,73 +488,22 @@ async def get_chat_session(

    # Cache the session from DB
    try:
        await cache_chat_session(session)
        logger.info(f"Cached session {session_id} from database")
        await _cache_session(session)
    except Exception as e:
        logger.warning(f"Failed to cache session {session_id}: {e}")

    return session


async def _get_session_from_cache(session_id: str) -> ChatSession | None:
    """Get a chat session from Redis cache."""
    redis_key = _get_session_cache_key(session_id)
    async_redis = await get_redis_async()
    raw_session: bytes | None = await async_redis.get(redis_key)

    if raw_session is None:
        return None

    try:
        session = ChatSession.model_validate_json(raw_session)
        logger.info(
            f"Loading session {session_id} from cache: "
            f"message_count={len(session.messages)}, "
            f"roles={[m.role for m in session.messages]}"
        )
        return session
    except Exception as e:
        logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True)
        raise RedisError(f"Corrupted session data for {session_id}") from e


async def _get_session_from_db(session_id: str) -> ChatSession | None:
    """Get a chat session from the database."""
    session = await chat_db().get_chat_session(session_id)
    if not session:
        return None

    logger.info(
        f"Loaded session {session_id} from DB: "
        f"has_messages={bool(session.messages)}, "
        f"message_count={len(session.messages)}, "
        f"roles={[m.role for m in session.messages]}"
    )

    return session


async def upsert_chat_session(
    session: ChatSession,
    *,
    existing_message_count: int | None = None,
) -> tuple[ChatSession, int]:
) -> ChatSession:
    """Update a chat session in both cache and database.

    Uses session-level locking to prevent race conditions when concurrent
    operations (e.g., background title update and main stream handler)
    attempt to upsert the same session simultaneously.

    Args:
        existing_message_count: If provided, skip the DB query to count
            existing messages. The caller is responsible for tracking this
            accurately. Useful for incremental saves in a streaming loop
            where the caller already knows how many messages are persisted.

    Returns:
        Tuple of (session, final_message_count) where final_message_count is
        the actual persisted message count after collision detection adjustments.

    Raises:
        DatabaseError: If the database write fails. The cache is still updated
            as a best-effort optimization, but the error is propagated to ensure
@@ -464,21 +515,15 @@ async def upsert_chat_session(

    async with lock:
        # Get existing message count from DB for incremental saves
        if existing_message_count is None:
            existing_message_count = await chat_db().get_next_sequence(
                session.session_id
            )
        existing_message_count = await chat_db.get_chat_session_message_count(
            session.session_id
        )

        db_error: Exception | None = None
        final_count = existing_message_count

        # Save to database (primary storage)
        try:
            final_count = await _save_session_to_db(
                session,
                existing_message_count,
                skip_existence_check=existing_message_count > 0,
            )
            await _save_session_to_db(session, existing_message_count)
        except Exception as e:
            logger.error(
                f"Failed to save session {session.session_id} to database: {e}"
@@ -487,7 +532,7 @@ async def upsert_chat_session(

        # Save to cache (best-effort, even if DB failed)
        try:
            await cache_chat_session(session)
            await _cache_session(session)
        except Exception as e:
            # If DB succeeded but cache failed, raise cache error
            if db_error is None:
@@ -505,82 +550,7 @@ async def upsert_chat_session(
            f"Failed to persist chat session {session.session_id} to database"
        ) from db_error

    return session, final_count


async def _save_session_to_db(
    session: ChatSession,
    existing_message_count: int,
    *,
    skip_existence_check: bool = False,
) -> int:
    """Save or update a chat session in the database.

    Args:
        skip_existence_check: When True, skip the ``get_chat_session`` query
            and assume the session row already exists. Saves one DB round trip
            for incremental saves during streaming.

    Returns:
        Final message count after save (accounting for collision detection).
    """
    db = chat_db()

    if not skip_existence_check:
        # Check if session exists in DB
        existing = await db.get_chat_session(session.session_id)

        if not existing:
            # Create new session
            await db.create_chat_session(
                session_id=session.session_id,
                user_id=session.user_id,
            )
            existing_message_count = 0

    # Calculate total tokens from usage
    total_prompt = sum(u.prompt_tokens for u in session.usage)
    total_completion = sum(u.completion_tokens for u in session.usage)

    # Update session metadata
    await db.update_chat_session(
        session_id=session.session_id,
        credentials=session.credentials,
        successful_agent_runs=session.successful_agent_runs,
        successful_agent_schedules=session.successful_agent_schedules,
        total_prompt_tokens=total_prompt,
        total_completion_tokens=total_completion,
    )

    # Add new messages (only those after existing count)
    new_messages = session.messages[existing_message_count:]
    final_count = existing_message_count
    if new_messages:
        messages_data = []
        for msg in new_messages:
            messages_data.append(
                {
                    "role": msg.role,
                    "content": msg.content,
                    "name": msg.name,
                    "tool_call_id": msg.tool_call_id,
                    "refusal": msg.refusal,
                    "tool_calls": msg.tool_calls,
                    "function_call": msg.function_call,
                }
            )
        logger.info(
            f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: "
            f"roles={[m['role'] for m in messages_data]}, "
            f"start_sequence={existing_message_count}"
        )
        _, final_count = await db.add_chat_messages_batch(
            session_id=session.session_id,
            messages=messages_data,
            start_sequence=existing_message_count,
        )

    return final_count
    return session


async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession:
@@ -598,7 +568,9 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
        raise ValueError(f"Session {session_id} not found")

    session.messages.append(message)
    existing_message_count = await chat_db().get_next_sequence(session_id)
    existing_message_count = await chat_db.get_chat_session_message_count(
        session_id
    )

    try:
        await _save_session_to_db(session, existing_message_count)
@@ -608,7 +580,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat
        ) from e

    try:
        await cache_chat_session(session)
        await _cache_session(session)
    except Exception as e:
        logger.warning(f"Cache write failed for session {session_id}: {e}")

@@ -627,7 +599,7 @@ async def create_chat_session(user_id: str) -> ChatSession:

    # Create in database first - fail fast if this fails
    try:
        await chat_db().create_chat_session(
        await chat_db.create_chat_session(
            session_id=session.session_id,
            user_id=user_id,
        )
@@ -639,7 +611,7 @@ async def create_chat_session(user_id: str) -> ChatSession:

    # Cache the session (best-effort optimization, DB is source of truth)
    try:
        await cache_chat_session(session)
        await _cache_session(session)
    except Exception as e:
        logger.warning(f"Failed to cache new session {session.session_id}: {e}")

@@ -650,16 +622,20 @@ async def get_user_sessions(
    user_id: str,
    limit: int = 50,
    offset: int = 0,
) -> tuple[list[ChatSessionInfo], int]:
) -> tuple[list[ChatSession], int]:
    """Get chat sessions for a user from the database with total count.

    Returns:
        A tuple of (sessions, total_count) where total_count is the overall
        number of sessions for the user (not just the current page).
    """
    db = chat_db()
    sessions = await db.get_user_chat_sessions(user_id, limit, offset)
    total_count = await db.get_user_session_count(user_id)
    prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset)
    total_count = await chat_db.get_user_session_count(user_id)

    sessions = []
    for prisma_session in prisma_sessions:
        # Convert without messages for listing (lighter weight)
        sessions.append(ChatSession.from_db(prisma_session, None))

    return sessions, total_count

@@ -677,7 +653,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo
    """
    # Delete from database first (with optional user_id validation)
    # This confirms ownership before invalidating cache
    deleted = await chat_db().delete_chat_session(session_id, user_id)
    deleted = await chat_db.delete_chat_session(session_id, user_id)

    if not deleted:
        return False
@@ -712,7 +688,7 @@ async def update_session_title(session_id: str, title: str) -> bool:
        True if updated successfully, False otherwise.
    """
    try:
        result = await chat_db().update_chat_session(session_id=session_id, title=title)
        result = await chat_db.update_chat_session(session_id=session_id, title=title)
        if result is None:
            logger.warning(f"Session {session_id} not found for title update")
            return False
@@ -724,7 +700,7 @@ async def update_session_title(session_id: str, title: str) -> bool:
            cached = await _get_session_from_cache(session_id)
            if cached:
                cached.title = title
                await cache_chat_session(cached)
                await _cache_session(cached)
        except Exception as e:
            # Not critical - title will be correct on next full cache refresh
            logger.warning(
@@ -735,29 +711,3 @@ async def update_session_title(session_id: str, title: str) -> bool:
    except Exception as e:
        logger.error(f"Failed to update title for session {session_id}: {e}")
        return False


# ==================== Chat session locks ==================== #

_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()
_session_locks_mutex = asyncio.Lock()


async def _get_session_lock(session_id: str) -> asyncio.Lock:
    """Get or create a lock for a specific session to prevent concurrent upserts.

    This was originally added to solve the specific problem of race conditions between
    the session title thread and the conversation thread, which always occurs on the
    same instance as we prevent rapid request sends on the frontend.

    Uses WeakValueDictionary for automatic cleanup: locks are garbage collected
    when no coroutine holds a reference to them, preventing memory leaks from
    unbounded growth of session locks. Explicit cleanup also occurs
    in `delete_chat_session()`.
    """
    async with _session_locks_mutex:
        lock = _session_locks.get(session_id)
        if lock is None:
            lock = asyncio.Lock()
            _session_locks[session_id] = lock
        return lock
@@ -60,7 +60,7 @@ async def test_chatsession_redis_storage(setup_test_user, test_user_id):
    s = ChatSession.new(user_id=test_user_id)
    s.messages = messages

    s, _ = await upsert_chat_session(s)
    s = await upsert_chat_session(s)

    s2 = await get_chat_session(
        session_id=s.session_id,
@@ -77,7 +77,7 @@ async def test_chatsession_redis_storage_user_id_mismatch(

    s = ChatSession.new(user_id=test_user_id)
    s.messages = messages
    s, _ = await upsert_chat_session(s)
    s = await upsert_chat_session(s)

    s2 = await get_chat_session(s.session_id, "different_user_id")

@@ -94,7 +94,7 @@ async def test_chatsession_db_storage(setup_test_user, test_user_id):
    s.messages = messages  # Contains user, assistant, and tool messages
    assert s.session_id is not None, "Session id is not set"
    # Upsert to save to both cache and DB
    s, _ = await upsert_chat_session(s)
    s = await upsert_chat_session(s)

    # Clear the Redis cache to force DB load
    redis_key = f"chat:session:{s.session_id}"
@@ -11,25 +11,24 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response,
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.completion_handler import (
|
||||
process_operation_failure,
|
||||
process_operation_success,
|
||||
)
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_task
|
||||
from backend.copilot.model import (
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
from .model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
append_and_save_message,
|
||||
create_chat_session,
|
||||
delete_chat_session,
|
||||
get_chat_session,
|
||||
get_user_sessions,
|
||||
)
|
||||
from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat
|
||||
from backend.copilot.tools.models import (
|
||||
from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart
|
||||
from .sdk import service as sdk_service
|
||||
from .tools.models import (
|
||||
AgentDetailsResponse,
|
||||
AgentOutputResponse,
|
||||
AgentPreviewResponse,
|
||||
@@ -50,11 +49,9 @@ from backend.copilot.tools.models import (
|
||||
OperationPendingResponse,
|
||||
OperationStartedResponse,
|
||||
SetupRequirementsResponse,
|
||||
SuggestedGoalResponse,
|
||||
UnderstandingUpdatedResponse,
|
||||
)
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.util.exceptions import NotFoundError
|
||||
from .tracking import track_user_message
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -132,14 +129,6 @@ class ListSessionsResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class CancelTaskResponse(BaseModel):
|
||||
"""Response model for the cancel task endpoint."""
|
||||
|
||||
cancelled: bool
|
||||
task_id: str | None = None
|
||||
reason: str | None = None
|
||||
|
||||
|
||||
class OperationCompleteRequest(BaseModel):
|
||||
"""Request model for external completion webhook."""
|
||||
|
||||
@@ -222,43 +211,6 @@ async def create_session(
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/sessions/{session_id}",
|
||||
dependencies=[Security(auth.requires_user)],
|
||||
status_code=204,
|
||||
responses={404: {"description": "Session not found or access denied"}},
|
||||
)
|
||||
async def delete_session(
|
||||
session_id: str,
|
||||
user_id: Annotated[str, Security(auth.get_user_id)],
|
||||
) -> Response:
|
||||
"""
|
||||
Delete a chat session.
|
||||
|
||||
Permanently removes a chat session and all its messages.
|
||||
Only the owner can delete their sessions.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to delete.
|
||||
user_id: The authenticated user's ID.
|
||||
|
||||
Returns:
|
||||
204 No Content on success.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if session not found or not owned by user.
|
||||
"""
|
||||
deleted = await delete_chat_session(session_id, user_id)
|
||||
|
||||
if not deleted:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Session {session_id} not found or access denied",
|
||||
)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/sessions/{session_id}",
|
||||
)
|
||||
@@ -322,57 +274,6 @@ async def get_session(
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/cancel",
|
||||
status_code=200,
|
||||
)
|
||||
async def cancel_session_task(
|
||||
session_id: str,
|
||||
user_id: Annotated[str | None, Depends(auth.get_user_id)],
|
||||
) -> CancelTaskResponse:
|
||||
"""Cancel the active streaming task for a session.
|
||||
|
||||
Publishes a cancel event to the executor via RabbitMQ FANOUT, then
|
||||
polls Redis until the task status flips from ``running`` or a timeout
|
||||
(5 s) is reached. Returns only after the cancellation is confirmed.
|
||||
"""
|
||||
await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
active_task, _ = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
)
|
||||
if not active_task:
|
||||
return CancelTaskResponse(cancelled=False, reason="no_active_task")
|
||||
|
||||
task_id = active_task.task_id
|
||||
await enqueue_cancel_task(task_id)
|
||||
logger.info(
|
||||
f"[CANCEL] Published cancel for task ...{task_id[-8:]} "
|
||||
f"session ...{session_id[-8:]}"
|
||||
)
|
||||
|
||||
# Poll until the executor confirms the task is no longer running.
|
||||
# Keep max_wait below typical reverse-proxy read timeouts.
|
||||
poll_interval = 0.5
|
||||
max_wait = 5.0
|
||||
waited = 0.0
|
||||
while waited < max_wait:
|
||||
await asyncio.sleep(poll_interval)
|
||||
waited += poll_interval
|
||||
task = await stream_registry.get_task(task_id)
|
||||
if task is None or task.status != "running":
|
||||
logger.info(
|
||||
f"[CANCEL] Task ...{task_id[-8:]} confirmed stopped "
|
||||
f"(status={task.status if task else 'gone'}) after {waited:.1f}s"
|
||||
)
|
||||
return CancelTaskResponse(cancelled=True, task_id=task_id)
|
||||
|
||||
logger.warning(f"[CANCEL] Task ...{task_id[-8:]} not confirmed after {max_wait}s")
|
||||
return CancelTaskResponse(
|
||||
cancelled=True, task_id=task_id, reason="cancel_published_not_confirmed"
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{session_id}/stream",
|
||||
)
|
||||
@@ -415,7 +316,7 @@ async def stream_chat_post(
        f"user={user_id}, message_len={len(request.message)}",
        extra={"json_fields": log_meta},
    )
    await _validate_and_get_session(session_id, user_id)
    session = await _validate_and_get_session(session_id, user_id)
    logger.info(
        f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms",
        extra={
@@ -442,7 +343,7 @@ async def stream_chat_post(
        message_length=len(request.message),
    )
    logger.info(f"[STREAM] Saving user message to session {session_id}")
    await append_and_save_message(session_id, message)
    session = await append_and_save_message(session_id, message)
    logger.info(f"[STREAM] User message saved for session {session_id}")

    # Create a task in the stream registry for reconnection support
@@ -469,19 +370,125 @@ async def stream_chat_post(
        },
    )

    await enqueue_copilot_task(
        task_id=task_id,
        session_id=session_id,
        user_id=user_id,
        operation_id=operation_id,
        message=request.message,
        is_user_message=request.is_user_message,
        context=request.context,
    )
    # Background task that runs the AI generation independently of SSE connection
    async def run_ai_generation():
        import time as time_module

        gen_start_time = time_module.perf_counter()
        logger.info(
            f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}",
            extra={"json_fields": log_meta},
        )
        first_chunk_time, ttfc = None, None
        chunk_count = 0
        try:
            # Emit a start event with task_id for reconnection
            start_chunk = StreamStart(messageId=task_id, taskId=task_id)
            await stream_registry.publish_chunk(task_id, start_chunk)
            logger.info(
                f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms",
                extra={
                    "json_fields": {
                        **log_meta,
                        "elapsed_ms": (time_module.perf_counter() - gen_start_time)
                        * 1000,
                    }
                },
            )

            # Choose service based on LaunchDarkly flag (falls back to config default)
            use_sdk = await is_feature_enabled(
                Flag.COPILOT_SDK,
                user_id or "anonymous",
                default=config.use_claude_agent_sdk,
            )
            stream_fn = (
                sdk_service.stream_chat_completion_sdk
                if use_sdk
                else chat_service.stream_chat_completion
            )
            logger.info(
                f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion",
                extra={"json_fields": log_meta},
            )
            # Pass message=None since we already added it to the session above
            async for chunk in stream_fn(
                session_id,
                None,  # Message already in session
                is_user_message=request.is_user_message,
                user_id=user_id,
                session=session,  # Pass session with message already added
                context=request.context,
            ):
                # Skip duplicate StreamStart — we already published one above
                if isinstance(chunk, StreamStart):
                    continue
                chunk_count += 1
                if first_chunk_time is None:
                    first_chunk_time = time_module.perf_counter()
                    ttfc = first_chunk_time - gen_start_time
                    logger.info(
                        f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}",
                        extra={
                            "json_fields": {
                                **log_meta,
                                "chunk_type": type(chunk).__name__,
                                "time_to_first_chunk_ms": ttfc * 1000,
                            }
                        },
                    )
                # Write to Redis (subscribers will receive via XREAD)
                await stream_registry.publish_chunk(task_id, chunk)

            gen_end_time = time_module.perf_counter()
            total_time = (gen_end_time - gen_start_time) * 1000
            logger.info(
                f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; "
                f"task={task_id}, session={session_id}, "
                f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}",
                extra={
                    "json_fields": {
                        **log_meta,
                        "total_time_ms": total_time,
                        "time_to_first_chunk_ms": (
                            ttfc * 1000 if ttfc is not None else None
                        ),
                        "n_chunks": chunk_count,
                    }
                },
            )
            await stream_registry.mark_task_completed(task_id, "completed")
        except Exception as e:
            elapsed = time_module.perf_counter() - gen_start_time
            logger.error(
                f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}",
                extra={
                    "json_fields": {
                        **log_meta,
                        "elapsed_ms": elapsed * 1000,
                        "error": str(e),
                    }
                },
            )
            # Publish a StreamError so the frontend can display an error message
            try:
                await stream_registry.publish_chunk(
                    task_id,
                    StreamError(
                        errorText="An error occurred. Please try again.",
                        code="stream_error",
                    ),
                )
            except Exception:
                pass  # Best-effort; mark_task_completed will publish StreamFinish
            await stream_registry.mark_task_completed(task_id, "failed")

    # Start the AI generation in a background task
    bg_task = asyncio.create_task(run_ai_generation())
    await stream_registry.set_task_asyncio_task(task_id, bg_task)
    setup_time = (time.perf_counter() - stream_start_time) * 1000
    logger.info(
        f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms",
        f"[TIMING] Background task started, setup={setup_time:.1f}ms",
        extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}},
    )

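# -- Illustrative sketch (not part of the changeset) -------------------------
# The stream_registry used above is not shown in this diff. A minimal version
# of the publish/replay pattern it relies on — assuming redis-py's asyncio
# client, a Redis Stream per task, and JSON-serialisable chunks (all names
# here are illustrative, not the project's actual API):

import json

import redis.asyncio as redis


async def publish_chunk_sketch(r: redis.Redis, task_id: str, chunk: dict) -> None:
    # XADD appends the chunk to a per-task stream and auto-generates the ID.
    await r.xadd(f"chat:stream:{task_id}", {"data": json.dumps(chunk)})


async def replay_chunks_sketch(r: redis.Redis, task_id: str) -> list[dict]:
    # Reading from the start of the stream is what makes SSE reconnection
    # possible: a late subscriber replays everything published so far.
    entries = await r.xrange(f"chat:stream:{task_id}", min="-", max="+")
    return [json.loads(fields[b"data"]) for _, fields in entries]
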
@@ -1044,7 +1051,6 @@ ToolResponseUnion = (
    | AgentPreviewResponse
    | AgentSavedResponse
    | ClarificationNeededResponse
    | SuggestedGoalResponse
    | BlockListResponse
    | BlockDetailsResponse
    | BlockOutputResponse

@@ -0,0 +1,203 @@
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.

This module provides the adapter layer that converts streaming messages from
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
the frontend expects.
"""

import json
import logging
import uuid

from claude_agent_sdk import (
    AssistantMessage,
    Message,
    ResultMessage,
    SystemMessage,
    TextBlock,
    ToolResultBlock,
    ToolUseBlock,
    UserMessage,
)

from backend.api.features.chat.response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamFinishStep,
    StreamStart,
    StreamStartStep,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
    StreamToolInputAvailable,
    StreamToolInputStart,
    StreamToolOutputAvailable,
)
from backend.api.features.chat.sdk.tool_adapter import (
    MCP_TOOL_PREFIX,
    pop_pending_tool_output,
)

logger = logging.getLogger(__name__)


class SDKResponseAdapter:
    """Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.

    This class maintains state during a streaming session to properly track
    text blocks, tool calls, and message lifecycle.
    """

    def __init__(self, message_id: str | None = None):
        self.message_id = message_id or str(uuid.uuid4())
        self.text_block_id = str(uuid.uuid4())
        self.has_started_text = False
        self.has_ended_text = False
        self.current_tool_calls: dict[str, dict[str, str]] = {}
        self.task_id: str | None = None
        self.step_open = False

    def set_task_id(self, task_id: str) -> None:
        """Set the task ID for reconnection support."""
        self.task_id = task_id

    def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
        """Convert a single SDK message to Vercel AI SDK format."""
        responses: list[StreamBaseResponse] = []

        if isinstance(sdk_message, SystemMessage):
            if sdk_message.subtype == "init":
                responses.append(
                    StreamStart(messageId=self.message_id, taskId=self.task_id)
                )
                # Open the first step (matches non-SDK: StreamStart then StreamStartStep)
                responses.append(StreamStartStep())
                self.step_open = True

        elif isinstance(sdk_message, AssistantMessage):
            # After tool results, the SDK sends a new AssistantMessage for the
            # next LLM turn. Open a new step if the previous one was closed.
            if not self.step_open:
                responses.append(StreamStartStep())
                self.step_open = True

            for block in sdk_message.content:
                if isinstance(block, TextBlock):
                    if block.text:
                        self._ensure_text_started(responses)
                        responses.append(
                            StreamTextDelta(id=self.text_block_id, delta=block.text)
                        )

                elif isinstance(block, ToolUseBlock):
                    self._end_text_if_open(responses)

                    # Strip MCP prefix so frontend sees "find_block"
                    # instead of "mcp__copilot__find_block".
                    tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)

                    responses.append(
                        StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
                    )
                    responses.append(
                        StreamToolInputAvailable(
                            toolCallId=block.id,
                            toolName=tool_name,
                            input=block.input,
                        )
                    )
                    self.current_tool_calls[block.id] = {"name": tool_name}

        elif isinstance(sdk_message, UserMessage):
            # UserMessage carries tool results back from tool execution.
            content = sdk_message.content
            blocks = content if isinstance(content, list) else []
            for block in blocks:
                if isinstance(block, ToolResultBlock) and block.tool_use_id:
                    tool_info = self.current_tool_calls.get(block.tool_use_id, {})
                    tool_name = tool_info.get("name", "unknown")

                    # Prefer the stashed full output over the SDK's
                    # (potentially truncated) ToolResultBlock content.
                    # The SDK truncates large results, writing them to disk,
                    # which breaks frontend widget parsing.
                    output = pop_pending_tool_output(tool_name) or (
                        _extract_tool_output(block.content)
                    )

                    responses.append(
                        StreamToolOutputAvailable(
                            toolCallId=block.tool_use_id,
                            toolName=tool_name,
                            output=output,
                            success=not (block.is_error or False),
                        )
                    )

            # Close the current step after tool results — the next
            # AssistantMessage will open a new step for the continuation.
            if self.step_open:
                responses.append(StreamFinishStep())
                self.step_open = False

        elif isinstance(sdk_message, ResultMessage):
            self._end_text_if_open(responses)
            # Close the step before finishing.
            if self.step_open:
                responses.append(StreamFinishStep())
                self.step_open = False

            if sdk_message.subtype == "success":
                responses.append(StreamFinish())
            elif sdk_message.subtype in ("error", "error_during_execution"):
                error_msg = getattr(sdk_message, "result", None) or "Unknown error"
                responses.append(
                    StreamError(errorText=str(error_msg), code="sdk_error")
                )
                responses.append(StreamFinish())
            else:
                logger.warning(
                    f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
                )
                responses.append(StreamFinish())

        else:
            logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")

        return responses

    def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
        """Start (or restart) a text block if needed."""
        if not self.has_started_text or self.has_ended_text:
            if self.has_ended_text:
                self.text_block_id = str(uuid.uuid4())
                self.has_ended_text = False
            responses.append(StreamTextStart(id=self.text_block_id))
            self.has_started_text = True

    def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
        """End the current text block if one is open."""
        if self.has_started_text and not self.has_ended_text:
            responses.append(StreamTextEnd(id=self.text_block_id))
            self.has_ended_text = True


def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
    """Extract a string output from a ToolResultBlock's content field."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = [item.get("text", "") for item in content if item.get("type") == "text"]
        if parts:
            return "".join(parts)
        try:
            return json.dumps(content)
        except (TypeError, ValueError):
            return str(content)
    if content is None:
        return ""
    try:
        return json.dumps(content)
    except (TypeError, ValueError):
        return str(content)
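

# -- Illustrative sketch (not part of the changeset) -------------------------
# Feeding the adapter an init message followed by a text turn yields the
# Vercel-style event sequence the frontend consumes (mirrors the unit tests
# in the next file):

from claude_agent_sdk import AssistantMessage, SystemMessage, TextBlock

adapter = SDKResponseAdapter(message_id="demo")
events = adapter.convert_message(SystemMessage(subtype="init", data={}))
events += adapter.convert_message(
    AssistantMessage(content=[TextBlock(text="Hello!")], model="test")
)
print([type(e).__name__ for e in events])
# -> ['StreamStart', 'StreamStartStep', 'StreamTextStart', 'StreamTextDelta']
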
@@ -1,8 +1,5 @@
"""Unit tests for the SDK response adapter."""

import asyncio

import pytest
from claude_agent_sdk import (
    AssistantMessage,
    ResultMessage,
@@ -13,7 +10,7 @@ from claude_agent_sdk import (
    UserMessage,
)

from backend.copilot.response_model import (
from backend.api.features.chat.response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
@@ -30,10 +27,6 @@ from backend.copilot.response_model import (

from .response_adapter import SDKResponseAdapter
from .tool_adapter import MCP_TOOL_PREFIX
from .tool_adapter import _pending_tool_outputs as _pto
from .tool_adapter import _stash_event
from .tool_adapter import stash_pending_tool_output as _stash
from .tool_adapter import wait_for_stash


def _adapter() -> SDKResponseAdapter:
@@ -371,310 +364,3 @@ def test_full_conversation_flow():
        "StreamFinishStep",  # step 2 closed
        "StreamFinish",
    ]


# -- Flush unresolved tool calls --------------------------------------------


def test_flush_unresolved_at_result_message():
    """Built-in tools (WebSearch) without UserMessage results get flushed at ResultMessage."""
    adapter = _adapter()
    all_responses: list[StreamBaseResponse] = []

    # 1. Init
    all_responses.extend(
        adapter.convert_message(SystemMessage(subtype="init", data={}))
    )
    # 2. Tool use (built-in tool — no MCP prefix)
    all_responses.extend(
        adapter.convert_message(
            AssistantMessage(
                content=[
                    ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
                ],
                model="test",
            )
        )
    )
    # 3. No UserMessage for this tool — go straight to ResultMessage
    all_responses.extend(
        adapter.convert_message(
            ResultMessage(
                subtype="success",
                duration_ms=100,
                duration_api_ms=50,
                is_error=False,
                num_turns=1,
                session_id="s1",
            )
        )
    )

    types = [type(r).__name__ for r in all_responses]
    assert types == [
        "StreamStart",
        "StreamStartStep",
        "StreamToolInputStart",
        "StreamToolInputAvailable",
        "StreamToolOutputAvailable",  # flushed with empty output
        "StreamFinishStep",  # step closed by flush
        "StreamFinish",
    ]
    # The flushed output should be empty (no stash available)
    output_event = [
        r for r in all_responses if isinstance(r, StreamToolOutputAvailable)
    ][0]
    assert output_event.toolCallId == "ws-1"
    assert output_event.toolName == "WebSearch"
    assert output_event.output == ""


def test_flush_unresolved_at_next_assistant_message():
    """Built-in tools get flushed when the next AssistantMessage arrives."""
    adapter = _adapter()
    all_responses: list[StreamBaseResponse] = []

    # 1. Init
    all_responses.extend(
        adapter.convert_message(SystemMessage(subtype="init", data={}))
    )
    # 2. Tool use (built-in — no UserMessage will come)
    all_responses.extend(
        adapter.convert_message(
            AssistantMessage(
                content=[
                    ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
                ],
                model="test",
            )
        )
    )
    # 3. Next AssistantMessage triggers flush before processing its blocks
    all_responses.extend(
        adapter.convert_message(
            AssistantMessage(
                content=[TextBlock(text="Here are the results")], model="test"
            )
        )
    )

    types = [type(r).__name__ for r in all_responses]
    assert types == [
        "StreamStart",
        "StreamStartStep",
        "StreamToolInputStart",
        "StreamToolInputAvailable",
        # Flush at next AssistantMessage:
        "StreamToolOutputAvailable",
        "StreamFinishStep",  # step closed by flush
        # New step for continuation text:
        "StreamStartStep",
        "StreamTextStart",
        "StreamTextDelta",
    ]


def test_flush_with_stashed_output():
    """Stashed output from PostToolUse hook is used when flushing."""
    adapter = _adapter()

    # Simulate PostToolUse hook stashing output
    _pto.set({})
    _stash("WebSearch", "Search result: 5 items found")

    all_responses: list[StreamBaseResponse] = []

    # Tool use
    all_responses.extend(
        adapter.convert_message(
            AssistantMessage(
                content=[
                    ToolUseBlock(id="ws-1", name="WebSearch", input={"query": "test"})
                ],
                model="test",
            )
        )
    )
    # ResultMessage triggers flush
    all_responses.extend(
        adapter.convert_message(
            ResultMessage(
                subtype="success",
                duration_ms=100,
                duration_api_ms=50,
                is_error=False,
                num_turns=1,
                session_id="s1",
            )
        )
    )

    output_events = [
        r for r in all_responses if isinstance(r, StreamToolOutputAvailable)
    ]
    assert len(output_events) == 1
    assert output_events[0].output == "Search result: 5 items found"

    # Cleanup
    _pto.set({})  # type: ignore[arg-type]


# -- wait_for_stash synchronisation tests --


@pytest.mark.asyncio
async def test_wait_for_stash_signaled():
    """wait_for_stash returns True when stash_pending_tool_output signals."""
    _pto.set({})
    event = asyncio.Event()
    _stash_event.set(event)

    # Simulate a PostToolUse hook that stashes output after a short delay
    async def delayed_stash():
        await asyncio.sleep(0.01)
        _stash("WebSearch", "result data")

    asyncio.create_task(delayed_stash())
    result = await wait_for_stash(timeout=1.0)

    assert result is True
    assert _pto.get({}).get("WebSearch") == ["result data"]

    # Cleanup
    _pto.set({})  # type: ignore[arg-type]
    _stash_event.set(None)


@pytest.mark.asyncio
async def test_wait_for_stash_timeout():
    """wait_for_stash returns False on timeout when no stash occurs."""
    _pto.set({})
    event = asyncio.Event()
    _stash_event.set(event)

    result = await wait_for_stash(timeout=0.05)
    assert result is False

    # Cleanup
    _pto.set({})  # type: ignore[arg-type]
    _stash_event.set(None)


@pytest.mark.asyncio
async def test_wait_for_stash_already_stashed():
    """wait_for_stash picks up a stash that happened just before the wait."""
    _pto.set({})
    event = asyncio.Event()
    _stash_event.set(event)

    # Stash before waiting — simulates hook completing before message arrives
    _stash("Read", "file contents")
    # Event is now set; wait_for_stash detects the fast path and returns
    # immediately without timing out.
    result = await wait_for_stash(timeout=0.05)
    assert result is True

    # And the stash itself is populated
    assert _pto.get({}).get("Read") == ["file contents"]

    # Cleanup
    _pto.set({})  # type: ignore[arg-type]
    _stash_event.set(None)


# -- Parallel tool call tests --


def test_parallel_tool_calls_not_flushed_prematurely():
    """Parallel tool calls should NOT be flushed when the next AssistantMessage
    only contains ToolUseBlocks (parallel continuation)."""
    adapter = SDKResponseAdapter()

    # Init
    adapter.convert_message(SystemMessage(subtype="init", data={}))

    # First AssistantMessage: tool call #1
    msg1 = AssistantMessage(
        content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
        model="test",
    )
    r1 = adapter.convert_message(msg1)
    assert any(isinstance(r, StreamToolInputAvailable) for r in r1)
    assert adapter.has_unresolved_tool_calls

    # Second AssistantMessage: tool call #2 (parallel continuation)
    msg2 = AssistantMessage(
        content=[ToolUseBlock(id="t2", name="WebSearch", input={"q": "bar"})],
        model="test",
    )
    r2 = adapter.convert_message(msg2)

    # No flush should have happened — t1 should NOT have StreamToolOutputAvailable
    output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
    assert len(output_events) == 0, (
        f"Tool-only AssistantMessage should not flush prior tools, "
        f"but got {len(output_events)} output events"
    )

    # Both t1 and t2 should still be unresolved
    assert "t1" not in adapter.resolved_tool_calls
    assert "t2" not in adapter.resolved_tool_calls


def test_text_assistant_message_flushes_prior_tools():
    """An AssistantMessage with text (new turn) should flush unresolved tools."""
    adapter = SDKResponseAdapter()

    # Init
    adapter.convert_message(SystemMessage(subtype="init", data={}))

    # Tool call
    msg1 = AssistantMessage(
        content=[ToolUseBlock(id="t1", name="WebSearch", input={"q": "foo"})],
        model="test",
    )
    adapter.convert_message(msg1)
    assert adapter.has_unresolved_tool_calls

    # Text AssistantMessage (new turn after tools completed)
    msg2 = AssistantMessage(
        content=[TextBlock(text="Here are the results")],
        model="test",
    )
    r2 = adapter.convert_message(msg2)

    # Flush SHOULD have happened — t1 gets empty output
    output_events = [r for r in r2 if isinstance(r, StreamToolOutputAvailable)]
    assert len(output_events) == 1
    assert output_events[0].toolCallId == "t1"
    assert "t1" in adapter.resolved_tool_calls


def test_already_resolved_tool_skipped_in_user_message():
    """A tool result in UserMessage should be skipped if already resolved by flush."""
    adapter = SDKResponseAdapter()

    adapter.convert_message(SystemMessage(subtype="init", data={}))

    # Tool call + flush via text message
    adapter.convert_message(
        AssistantMessage(
            content=[ToolUseBlock(id="t1", name="WebSearch", input={})],
            model="test",
        )
    )
    adapter.convert_message(
        AssistantMessage(
            content=[TextBlock(text="Done")],
            model="test",
        )
    )
    assert "t1" in adapter.resolved_tool_calls

    # Now UserMessage arrives with the real result — should be skipped
    user_msg = UserMessage(content=[ToolResultBlock(tool_use_id="t1", content="real")])
    r = adapter.convert_message(user_msg)
    output_events = [r_ for r_ in r if isinstance(r_, StreamToolOutputAvailable)]
    assert (
        len(output_events) == 0
    ), "Already-resolved tool should not emit duplicate output"
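

# -- Illustrative sketch (not part of the changeset) -------------------------
# wait_for_stash itself is not shown in this diff; a plausible shape for the
# Event-based handshake the tests above exercise (an assumption, not the
# actual implementation):

import asyncio
from contextvars import ContextVar

_event_var: ContextVar[asyncio.Event | None] = ContextVar("_event_var", default=None)


async def wait_for_stash_sketch(timeout: float) -> bool:
    event = _event_var.get()
    if event is None:
        return False
    try:
        # Fast path: if a hook already stashed output, the event is set and
        # wait() returns immediately; otherwise block until signal or timeout.
        await asyncio.wait_for(event.wait(), timeout=timeout)
    except asyncio.TimeoutError:
        return False
    event.clear()  # re-arm for the next tool call
    return True
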
@@ -11,12 +11,11 @@ import re
from collections.abc import Callable
from typing import Any, cast

from .tool_adapter import (
from backend.api.features.chat.sdk.tool_adapter import (
    BLOCKED_TOOLS,
    DANGEROUS_PATTERNS,
    MCP_TOOL_PREFIX,
    WORKSPACE_SCOPED_TOOLS,
    stash_pending_tool_output,
)

logger = logging.getLogger(__name__)
@@ -124,20 +123,20 @@ def _validate_user_isolation(
    """Validate that tool calls respect user isolation."""
    # For workspace file tools, ensure path doesn't escape
    if "workspace" in tool_name.lower():
        # The "path" param is a cloud storage key (e.g. "/ASEAN/report.md")
        # where a leading "/" is normal. Only check for ".." traversal.
        # Filesystem paths (source_path, save_to_path) are validated inside
        # the tool itself via _validate_ephemeral_path.
        path = tool_input.get("path", "") or tool_input.get("file_path", "")
        if path and ".." in path:
            logger.warning(f"Blocked path traversal attempt: {path} by user {user_id}")
            return {
                "hookSpecificOutput": {
                    "hookEventName": "PreToolUse",
                    "permissionDecision": "deny",
                    "permissionDecisionReason": "Path traversal not allowed",
        if path:
            # Check for path traversal
            if ".." in path or path.startswith("/"):
                logger.warning(
                    f"Blocked path traversal attempt: {path} by user {user_id}"
                )
                return {
                    "hookSpecificOutput": {
                        "hookEventName": "PreToolUse",
                        "permissionDecision": "deny",
                        "permissionDecisionReason": "Path traversal not allowed",
                    }
                }
            }

    return {}

@@ -188,19 +187,8 @@ def create_security_hooks(

        # Rate-limit Task (sub-agent) spawns per session
        if tool_name == "Task":
            # Block background task execution first — denied calls
            # should not consume a subtask slot.
            if tool_input.get("run_in_background"):
                logger.info(f"[SDK] Blocked background Task, user={user_id}")
                return cast(
                    SyncHookJSONOutput,
                    _deny(
                        "Background task execution is not supported. "
                        "Run tasks in the foreground instead "
                        "(remove the run_in_background parameter)."
                    ),
                )
            if task_spawn_count >= max_subtasks:
            task_spawn_count += 1
            if task_spawn_count > max_subtasks:
                logger.warning(
                    f"[SDK] Task limit reached ({max_subtasks}), user={user_id}"
                )
@@ -211,7 +199,6 @@ def create_security_hooks(
                        "Please continue in the main conversation."
                    ),
                )
            task_spawn_count += 1

        # Strip MCP prefix for consistent validation
        is_copilot_tool = tool_name.startswith(MCP_TOOL_PREFIX)
@@ -237,43 +224,10 @@ def create_security_hooks(
        tool_use_id: str | None,
        context: HookContext,
    ) -> SyncHookJSONOutput:
        """Log successful tool executions and stash SDK built-in tool outputs.

        MCP tools stash their output in ``_execute_tool_sync`` before the
        SDK can truncate it. SDK built-in tools (WebSearch, Read, etc.)
        are executed by the CLI internally — this hook captures their
        output so the response adapter can forward it to the frontend.
        """
        """Log successful tool executions for observability."""
        _ = context
        tool_name = cast(str, input_data.get("tool_name", ""))
        is_builtin = not tool_name.startswith(MCP_TOOL_PREFIX)
        logger.info(
            "[SDK] PostToolUse: %s (builtin=%s, tool_use_id=%s)",
            tool_name,
            is_builtin,
            (tool_use_id or "")[:12],
        )

        # Stash output for SDK built-in tools so the response adapter can
        # emit StreamToolOutputAvailable even when the CLI doesn't surface
        # a separate UserMessage with ToolResultBlock content.
        if is_builtin:
            tool_response = input_data.get("tool_response")
            if tool_response is not None:
                resp_preview = str(tool_response)[:100]
                logger.info(
                    "[SDK] Stashing builtin output for %s (%d chars): %s...",
                    tool_name,
                    len(str(tool_response)),
                    resp_preview,
                )
                stash_pending_tool_output(tool_name, tool_response)
            else:
                logger.warning(
                    "[SDK] PostToolUse for builtin %s but tool_response is None",
                    tool_name,
                )

        logger.debug(f"[SDK] Tool success: {tool_name}, tool_use_id={tool_use_id}")
        return cast(SyncHookJSONOutput, {})

    async def post_tool_failure_hook(
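

# -- Illustrative sketch (not part of the changeset) -------------------------
# The _deny helper used by the hooks above is not shown in this hunk; denial
# is expressed as a hookSpecificOutput payload of this shape (reconstructed
# from the dicts visible above):


def _deny_sketch(reason: str) -> dict:
    return {
        "hookSpecificOutput": {
            "hookEventName": "PreToolUse",
            "permissionDecision": "deny",
            "permissionDecisionReason": reason,
        }
    }
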
@@ -0,0 +1,165 @@
"""Unit tests for SDK security hooks."""

import os

from .security_hooks import _validate_tool_access, _validate_user_isolation

SDK_CWD = "/tmp/copilot-abc123"


def _is_denied(result: dict) -> bool:
    hook = result.get("hookSpecificOutput", {})
    return hook.get("permissionDecision") == "deny"


# -- Blocked tools -----------------------------------------------------------


def test_blocked_tools_denied():
    for tool in ("bash", "shell", "exec", "terminal", "command"):
        result = _validate_tool_access(tool, {})
        assert _is_denied(result), f"{tool} should be blocked"


def test_unknown_tool_allowed():
    result = _validate_tool_access("SomeCustomTool", {})
    assert result == {}


# -- Workspace-scoped tools --------------------------------------------------


def test_read_within_workspace_allowed():
    result = _validate_tool_access(
        "Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
    )
    assert result == {}


def test_write_within_workspace_allowed():
    result = _validate_tool_access(
        "Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
    )
    assert result == {}


def test_edit_within_workspace_allowed():
    result = _validate_tool_access(
        "Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
    )
    assert result == {}


def test_glob_within_workspace_allowed():
    result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
    assert result == {}


def test_grep_within_workspace_allowed():
    result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
    assert result == {}


def test_read_outside_workspace_denied():
    result = _validate_tool_access(
        "Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_write_outside_workspace_denied():
    result = _validate_tool_access(
        "Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
    )
    assert _is_denied(result)


def test_traversal_attack_denied():
    result = _validate_tool_access(
        "Read",
        {"file_path": f"{SDK_CWD}/../../etc/passwd"},
        sdk_cwd=SDK_CWD,
    )
    assert _is_denied(result)


def test_no_path_allowed():
    """Glob/Grep without a path argument defaults to cwd — should pass."""
    result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
    assert result == {}


def test_read_no_cwd_denies_absolute():
    """If no sdk_cwd is set, absolute paths are denied."""
    result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
    assert _is_denied(result)


# -- Tool-results directory --------------------------------------------------


def test_read_tool_results_allowed():
    home = os.path.expanduser("~")
    path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
    result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
    assert result == {}


def test_read_claude_projects_without_tool_results_denied():
    home = os.path.expanduser("~")
    path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
    result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
    assert _is_denied(result)


# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------


def test_bash_builtin_always_blocked():
    """SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead."""
    result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD)
    assert _is_denied(result)


# -- Dangerous patterns ------------------------------------------------------


def test_dangerous_pattern_blocked():
    result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
    assert _is_denied(result)


def test_subprocess_pattern_blocked():
    result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
    assert _is_denied(result)


# -- User isolation ----------------------------------------------------------


def test_workspace_path_traversal_blocked():
    result = _validate_user_isolation(
        "workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
    )
    assert _is_denied(result)


def test_workspace_absolute_path_blocked():
    result = _validate_user_isolation(
        "workspace_read", {"path": "/etc/passwd"}, user_id="user-1"
    )
    assert _is_denied(result)


def test_workspace_normal_path_allowed():
    result = _validate_user_isolation(
        "workspace_read", {"path": "src/main.py"}, user_id="user-1"
    )
    assert result == {}


def test_non_workspace_tool_passes_isolation():
    result = _validate_user_isolation(
        "find_agent", {"query": "email"}, user_id="user-1"
    )
    assert result == {}
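

# -- Illustrative sketch (not part of the changeset) -------------------------
# The containment rule these tests pin down: resolve the candidate path,
# then require it to stay under the session cwd. A sketch of the semantics
# (not the real _validate_tool_access):

import os


def is_within_workspace_sketch(file_path: str, sdk_cwd: str) -> bool:
    # normpath collapses ".." segments, so f"{sdk_cwd}/../../etc/passwd"
    # resolves outside the prefix and is rejected; absolute paths replace
    # sdk_cwd entirely in os.path.join and are rejected the same way.
    resolved = os.path.normpath(os.path.join(sdk_cwd, file_path))
    return resolved.startswith(os.path.normpath(sdk_cwd) + os.sep)
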
@@ -0,0 +1,752 @@
"""Claude Agent SDK service layer for CoPilot chat completions."""

import asyncio
import json
import logging
import os
import uuid
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import Any

from backend.util.exceptions import NotFoundError

from .. import stream_registry
from ..config import ChatConfig
from ..model import (
    ChatMessage,
    ChatSession,
    get_chat_session,
    update_session_title,
    upsert_chat_session,
)
from ..response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamStart,
    StreamTextDelta,
    StreamToolInputAvailable,
    StreamToolOutputAvailable,
)
from ..service import (
    _build_system_prompt,
    _execute_long_running_tool_with_streaming,
    _generate_session_title,
)
from ..tools.models import OperationPendingResponse, OperationStartedResponse
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
    COPILOT_TOOL_NAMES,
    SDK_DISALLOWED_TOOLS,
    LongRunningCallback,
    create_copilot_mcp_server,
    set_execution_context,
)
from .transcript import (
    download_transcript,
    read_transcript_file,
    upload_transcript,
    validate_transcript,
    write_transcript_to_tempfile,
)

logger = logging.getLogger(__name__)
config = ChatConfig()

# Set to hold background tasks to prevent garbage collection
_background_tasks: set[asyncio.Task[Any]] = set()


@dataclass
class CapturedTranscript:
    """Info captured by the SDK Stop hook for stateless --resume."""

    path: str = ""
    sdk_session_id: str = ""

    @property
    def available(self) -> bool:
        return bool(self.path)


_SDK_CWD_PREFIX = WORKSPACE_PREFIX

# Appended to the system prompt to inform the agent about available tools.
# The SDK built-in Bash is NOT available — use mcp__copilot__bash_exec instead,
# which has kernel-level network isolation (unshare --net).
_SDK_TOOL_SUPPLEMENT = """

## Tool notes

- The SDK built-in Bash tool is NOT available. Use the `bash_exec` MCP tool
  for shell commands — it runs in a network-isolated sandbox.
- **Shared workspace**: The SDK Read/Write tools and `bash_exec` share the
  same working directory. Files created by one are readable by the other.
  These files are **ephemeral** — they exist only for the current session.
- **Persistent storage**: Use `write_workspace_file` / `read_workspace_file`
  for files that should persist across sessions (stored in cloud storage).
- Long-running tools (create_agent, edit_agent, etc.) are handled
  asynchronously. You will receive an immediate response; the actual result
  is delivered to the user via a background stream.
"""


def _build_long_running_callback(user_id: str | None) -> LongRunningCallback:
    """Build a callback that delegates long-running tools to the non-SDK infrastructure.

    Long-running tools (create_agent, edit_agent, etc.) are delegated to the
    existing background infrastructure: stream_registry (Redis Streams),
    database persistence, and SSE reconnection. This means results survive
    page refreshes / pod restarts, and the frontend shows the proper loading
    widget with progress updates.

    The returned callback matches the ``LongRunningCallback`` signature:
    ``(tool_name, args, session) -> MCP response dict``.
    """

    async def _callback(
        tool_name: str, args: dict[str, Any], session: ChatSession
    ) -> dict[str, Any]:
        operation_id = str(uuid.uuid4())
        task_id = str(uuid.uuid4())
        tool_call_id = f"sdk-{uuid.uuid4().hex[:12]}"
        session_id = session.session_id

        # --- Build user-friendly messages (matches non-SDK service) ---
        if tool_name == "create_agent":
            desc = args.get("description", "")
            desc_preview = (desc[:100] + "...") if len(desc) > 100 else desc
            pending_msg = (
                f"Creating your agent: {desc_preview}"
                if desc_preview
                else "Creating agent... This may take a few minutes."
            )
            started_msg = (
                "Agent creation started. You can close this tab - "
                "check your library in a few minutes."
            )
        elif tool_name == "edit_agent":
            changes = args.get("changes", "")
            changes_preview = (changes[:100] + "...") if len(changes) > 100 else changes
            pending_msg = (
                f"Editing agent: {changes_preview}"
                if changes_preview
                else "Editing agent... This may take a few minutes."
            )
            started_msg = (
                "Agent edit started. You can close this tab - "
                "check your library in a few minutes."
            )
        else:
            pending_msg = f"Running {tool_name}... This may take a few minutes."
            started_msg = (
                f"{tool_name} started. You can close this tab - "
                "check back in a few minutes."
            )

        # --- Register task in Redis for SSE reconnection ---
        await stream_registry.create_task(
            task_id=task_id,
            session_id=session_id,
            user_id=user_id,
            tool_call_id=tool_call_id,
            tool_name=tool_name,
            operation_id=operation_id,
        )

        # --- Save OperationPendingResponse to chat history ---
        pending_message = ChatMessage(
            role="tool",
            content=OperationPendingResponse(
                message=pending_msg,
                operation_id=operation_id,
                tool_name=tool_name,
            ).model_dump_json(),
            tool_call_id=tool_call_id,
        )
        session.messages.append(pending_message)
        await upsert_chat_session(session)

        # --- Spawn background task (reuses non-SDK infrastructure) ---
        bg_task = asyncio.create_task(
            _execute_long_running_tool_with_streaming(
                tool_name=tool_name,
                parameters=args,
                tool_call_id=tool_call_id,
                operation_id=operation_id,
                task_id=task_id,
                session_id=session_id,
                user_id=user_id,
            )
        )
        _background_tasks.add(bg_task)
        bg_task.add_done_callback(_background_tasks.discard)
        await stream_registry.set_task_asyncio_task(task_id, bg_task)

        logger.info(
            f"[SDK] Long-running tool {tool_name} delegated to background "
            f"(operation_id={operation_id}, task_id={task_id})"
        )

        # --- Return OperationStartedResponse as MCP tool result ---
        # This flows through SDK → response adapter → frontend, triggering
        # the loading widget with SSE reconnection support.
        started_json = OperationStartedResponse(
            message=started_msg,
            operation_id=operation_id,
            tool_name=tool_name,
            task_id=task_id,
        ).model_dump_json()

        return {
            "content": [{"type": "text", "text": started_json}],
            "isError": False,
        }

    return _callback


def _resolve_sdk_model() -> str | None:
    """Resolve the model name for the Claude Agent SDK CLI.

    Uses ``config.claude_agent_model`` if set, otherwise derives from
    ``config.model`` by stripping the OpenRouter provider prefix (e.g.,
    ``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
    """
    if config.claude_agent_model:
        return config.claude_agent_model
    model = config.model
    if "/" in model:
        return model.split("/", 1)[1]
    return model


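# -- Illustrative sketch (not part of the changeset) -------------------------
# The provider-prefix strip in _resolve_sdk_model is a single split:

assert "anthropic/claude-opus-4.6".split("/", 1)[1] == "claude-opus-4.6"
assert "claude-opus-4.6".split("/", 1)[0] == "claude-opus-4.6"  # no prefix

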
def _build_sdk_env() -> dict[str, str]:
    """Build env vars for the SDK CLI process.

    Routes API calls through OpenRouter (or a custom base_url) using
    the same ``config.api_key`` / ``config.base_url`` as the non-SDK path.
    This gives per-call token and cost tracking on the OpenRouter dashboard.

    Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth
    token are both present — otherwise returns an empty dict so the SDK
    falls back to its default credentials.
    """
    env: dict[str, str] = {}
    if config.api_key and config.base_url:
        # Strip /v1 suffix — SDK expects the base URL without a version path
        base = config.base_url.rstrip("/")
        if base.endswith("/v1"):
            base = base[:-3]
        if not base or not base.startswith("http"):
            # Invalid base_url — don't override SDK defaults
            return env
        env["ANTHROPIC_BASE_URL"] = base
        env["ANTHROPIC_AUTH_TOKEN"] = config.api_key
        # Must be explicitly empty so the CLI uses AUTH_TOKEN instead
        env["ANTHROPIC_API_KEY"] = ""
    return env


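# -- Illustrative sketch (not part of the changeset) -------------------------
# With an OpenRouter-style base URL (example value assumed), the "/v1"
# suffix is stripped before being handed to the CLI:

base = "https://openrouter.ai/api/v1".rstrip("/")
if base.endswith("/v1"):
    base = base[:-3]
assert base == "https://openrouter.ai/api"

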
def _make_sdk_cwd(session_id: str) -> str:
    """Create a safe, session-specific working directory path.

    Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path`
    (single source of truth for path sanitization) and adds a defence-in-depth
    assertion.
    """
    cwd = make_session_path(session_id)
    # Defence-in-depth: normpath + startswith is a CodeQL-recognised sanitizer
    cwd = os.path.normpath(cwd)
    if not cwd.startswith(_SDK_CWD_PREFIX):
        raise ValueError(f"SDK cwd escaped prefix: {cwd}")
    return cwd


def _cleanup_sdk_tool_results(cwd: str) -> None:
    """Remove SDK tool-result files for a specific session working directory.

    The SDK creates tool-result files under ~/.claude/projects/<encoded-cwd>/tool-results/.
    We clean only the specific cwd's results to avoid race conditions between
    concurrent sessions.

    Security: cwd MUST be created by _make_sdk_cwd() which sanitizes session_id.
    """
    import shutil

    # Validate cwd is under the expected prefix
    normalized = os.path.normpath(cwd)
    if not normalized.startswith(_SDK_CWD_PREFIX):
        logger.warning(f"[SDK] Rejecting cleanup for path outside workspace: {cwd}")
        return

    # SDK encodes the cwd path by replacing '/' with '-'
    encoded_cwd = normalized.replace("/", "-")

    # Construct the project directory path (known-safe home expansion)
    claude_projects = os.path.expanduser("~/.claude/projects")
    project_dir = os.path.join(claude_projects, encoded_cwd)

    # Security check 3: Validate project_dir is under ~/.claude/projects
    project_dir = os.path.normpath(project_dir)
    if not project_dir.startswith(claude_projects):
        logger.warning(
            f"[SDK] Rejecting cleanup for escaped project path: {project_dir}"
        )
        return

    results_dir = os.path.join(project_dir, "tool-results")
    if os.path.isdir(results_dir):
        for filename in os.listdir(results_dir):
            file_path = os.path.join(results_dir, filename)
            try:
                if os.path.isfile(file_path):
                    os.remove(file_path)
            except OSError:
                pass

    # Also clean up the temp cwd directory itself
    try:
        shutil.rmtree(normalized, ignore_errors=True)
    except OSError:
        pass


async def _compress_conversation_history(
    session: ChatSession,
) -> list[ChatMessage]:
    """Compress prior conversation messages if they exceed the token threshold.

    Uses the shared compress_context() from prompt.py which supports:
    - LLM summarization of old messages (keeps recent ones intact)
    - Progressive content truncation as fallback
    - Middle-out deletion as last resort

    Returns the compressed prior messages (everything except the current message).
    """
    prior = session.messages[:-1]
    if len(prior) < 2:
        return prior

    from backend.util.prompt import compress_context

    # Convert ChatMessages to dicts for compress_context
    messages_dict = []
    for msg in prior:
        msg_dict: dict[str, Any] = {"role": msg.role}
        if msg.content:
            msg_dict["content"] = msg.content
        if msg.tool_calls:
            msg_dict["tool_calls"] = msg.tool_calls
        if msg.tool_call_id:
            msg_dict["tool_call_id"] = msg.tool_call_id
        messages_dict.append(msg_dict)

    try:
        import openai

        async with openai.AsyncOpenAI(
            api_key=config.api_key, base_url=config.base_url, timeout=30.0
        ) as client:
            result = await compress_context(
                messages=messages_dict,
                model=config.model,
                client=client,
            )
    except Exception as e:
        logger.warning(f"[SDK] Context compression with LLM failed: {e}")
        # Fall back to truncation-only (no LLM summarization)
        result = await compress_context(
            messages=messages_dict,
            model=config.model,
            client=None,
        )

    if result.was_compacted:
        logger.info(
            f"[SDK] Context compacted: {result.original_token_count} -> "
            f"{result.token_count} tokens "
            f"({result.messages_summarized} summarized, "
            f"{result.messages_dropped} dropped)"
        )
        # Convert compressed dicts back to ChatMessages
        return [
            ChatMessage(
                role=m["role"],
                content=m.get("content"),
                tool_calls=m.get("tool_calls"),
                tool_call_id=m.get("tool_call_id"),
            )
            for m in result.messages
        ]

    return prior


def _format_conversation_context(messages: list[ChatMessage]) -> str | None:
    """Format conversation messages into a context prefix for the user message.

    Returns a string like:
        <conversation_history>
        User: hello
        You responded: Hi! How can I help?
        </conversation_history>

    Returns None if there are no messages to format.
    """
    if not messages:
        return None

    lines: list[str] = []
    for msg in messages:
        if not msg.content:
            continue
        if msg.role == "user":
            lines.append(f"User: {msg.content}")
        elif msg.role == "assistant":
            lines.append(f"You responded: {msg.content}")
        # Skip tool messages — they're internal details

    if not lines:
        return None

    return "<conversation_history>\n" + "\n".join(lines) + "\n</conversation_history>"


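# -- Illustrative sketch (not part of the changeset) -------------------------
# What the compression fallback prepends when --resume is unavailable, using
# the ChatMessage model imported above:

_demo_msgs = [
    ChatMessage(role="user", content="hello"),
    ChatMessage(role="assistant", content="Hi! How can I help?"),
]
print(_format_conversation_context(_demo_msgs))
# <conversation_history>
# User: hello
# You responded: Hi! How can I help?
# </conversation_history>

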
async def stream_chat_completion_sdk(
    session_id: str,
    message: str | None = None,
    tool_call_response: str | None = None,  # noqa: ARG001
    is_user_message: bool = True,
    user_id: str | None = None,
    retry_count: int = 0,  # noqa: ARG001
    session: ChatSession | None = None,
    context: dict[str, str] | None = None,  # noqa: ARG001
) -> AsyncGenerator[StreamBaseResponse, None]:
    """Stream chat completion using Claude Agent SDK.

    Drop-in replacement for stream_chat_completion with improved reliability.
    """

    if session is None:
        session = await get_chat_session(session_id, user_id)

    if not session:
        raise NotFoundError(
            f"Session {session_id} not found. Please create a new session first."
        )

    if message:
        session.messages.append(
            ChatMessage(
                role="user" if is_user_message else "assistant", content=message
            )
        )
        if is_user_message:
            track_user_message(
                user_id=user_id, session_id=session_id, message_length=len(message)
            )

    session = await upsert_chat_session(session)

    # Generate title for new sessions (first user message)
    if is_user_message and not session.title:
        user_messages = [m for m in session.messages if m.role == "user"]
        if len(user_messages) == 1:
            first_message = user_messages[0].content or message or ""
            if first_message:
                task = asyncio.create_task(
                    _update_title_async(session_id, first_message, user_id)
                )
                _background_tasks.add(task)
                task.add_done_callback(_background_tasks.discard)

    # Build system prompt (reuses non-SDK path with Langfuse support)
    has_history = len(session.messages) > 1
    system_prompt, _ = await _build_system_prompt(
        user_id, has_conversation_history=has_history
    )
    system_prompt += _SDK_TOOL_SUPPLEMENT
    message_id = str(uuid.uuid4())
    task_id = str(uuid.uuid4())

    yield StreamStart(messageId=message_id, taskId=task_id)

    stream_completed = False
    # Initialise sdk_cwd before the try so the finally can reference it
    # even if _make_sdk_cwd raises (in that case it stays as "").
    sdk_cwd = ""
    use_resume = False

    try:
        # Use a session-specific temp dir to avoid cleanup race conditions
        # between concurrent sessions.
        sdk_cwd = _make_sdk_cwd(session_id)
        os.makedirs(sdk_cwd, exist_ok=True)

        set_execution_context(
            user_id,
            session,
            long_running_callback=_build_long_running_callback(user_id),
        )
        try:
            from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient

            # Fail fast when no API credentials are available at all
            sdk_env = _build_sdk_env()
            if not sdk_env and not os.environ.get("ANTHROPIC_API_KEY"):
                raise RuntimeError(
                    "No API key configured. Set OPEN_ROUTER_API_KEY "
                    "(or CHAT_API_KEY) for OpenRouter routing, "
                    "or ANTHROPIC_API_KEY for direct Anthropic access."
                )

            mcp_server = create_copilot_mcp_server()

            sdk_model = _resolve_sdk_model()

            # --- Transcript capture via Stop hook ---
            captured_transcript = CapturedTranscript()

            def _on_stop(transcript_path: str, sdk_session_id: str) -> None:
                captured_transcript.path = transcript_path
                captured_transcript.sdk_session_id = sdk_session_id

            security_hooks = create_security_hooks(
                user_id,
                sdk_cwd=sdk_cwd,
                max_subtasks=config.claude_agent_max_subtasks,
                on_stop=_on_stop if config.claude_agent_use_resume else None,
            )

            # --- Resume strategy: download transcript from bucket ---
            resume_file: str | None = None
            use_resume = False

            if config.claude_agent_use_resume and user_id and len(session.messages) > 1:
                transcript_content = await download_transcript(user_id, session_id)
                if transcript_content and validate_transcript(transcript_content):
                    resume_file = write_transcript_to_tempfile(
                        transcript_content, session_id, sdk_cwd
                    )
                    if resume_file:
                        use_resume = True
                        logger.info(
                            f"[SDK] Using --resume with transcript "
                            f"({len(transcript_content)} bytes)"
                        )

            sdk_options_kwargs: dict[str, Any] = {
                "system_prompt": system_prompt,
                "mcp_servers": {"copilot": mcp_server},
                "allowed_tools": COPILOT_TOOL_NAMES,
                "disallowed_tools": SDK_DISALLOWED_TOOLS,
                "hooks": security_hooks,
                "cwd": sdk_cwd,
                "max_buffer_size": config.claude_agent_max_buffer_size,
            }
            if sdk_env:
                sdk_options_kwargs["model"] = sdk_model
                sdk_options_kwargs["env"] = sdk_env
            if use_resume and resume_file:
                sdk_options_kwargs["resume"] = resume_file

            options = ClaudeAgentOptions(**sdk_options_kwargs)  # type: ignore[arg-type]

            adapter = SDKResponseAdapter(message_id=message_id)
            adapter.set_task_id(task_id)

            async with ClaudeSDKClient(options=options) as client:
                current_message = message or ""
                if not current_message and session.messages:
                    last_user = [m for m in session.messages if m.role == "user"]
                    if last_user:
                        current_message = last_user[-1].content or ""

                if not current_message.strip():
                    yield StreamError(
                        errorText="Message cannot be empty.",
                        code="empty_prompt",
                    )
                    yield StreamFinish()
                    return

                # Build query: with --resume the CLI already has full
                # context, so we only send the new message. Without
                # resume, compress history into a context prefix.
                query_message = current_message
                if not use_resume and len(session.messages) > 1:
                    logger.warning(
                        f"[SDK] Using compression fallback for session "
                        f"{session_id} ({len(session.messages)} messages) — "
                        f"no transcript available for --resume"
                    )
                    compressed = await _compress_conversation_history(session)
                    history_context = _format_conversation_context(compressed)
                    if history_context:
                        query_message = (
                            f"{history_context}\n\n"
                            f"Now, the user says:\n{current_message}"
                        )

                logger.info(
                    f"[SDK] Sending query ({len(session.messages)} msgs in session)"
                )
                logger.debug(f"[SDK] Query preview: {current_message[:80]!r}")
                await client.query(query_message, session_id=session_id)

                assistant_response = ChatMessage(role="assistant", content="")
                accumulated_tool_calls: list[dict[str, Any]] = []
                has_appended_assistant = False
                has_tool_results = False

                async for sdk_msg in client.receive_messages():
                    logger.debug(
                        f"[SDK] Received: {type(sdk_msg).__name__} "
                        f"{getattr(sdk_msg, 'subtype', '')}"
                    )
                    for response in adapter.convert_message(sdk_msg):
                        if isinstance(response, StreamStart):
                            continue

                        yield response

                        if isinstance(response, StreamTextDelta):
                            delta = response.delta or ""
                            # After tool results, start a new assistant
                            # message for the post-tool text.
                            if has_tool_results and has_appended_assistant:
                                assistant_response = ChatMessage(
                                    role="assistant", content=delta
                                )
                                accumulated_tool_calls = []
                                has_appended_assistant = False
                                has_tool_results = False
                                session.messages.append(assistant_response)
                                has_appended_assistant = True
                            else:
                                assistant_response.content = (
                                    assistant_response.content or ""
                                ) + delta
                                if not has_appended_assistant:
                                    session.messages.append(assistant_response)
                                    has_appended_assistant = True

                        elif isinstance(response, StreamToolInputAvailable):
                            accumulated_tool_calls.append(
                                {
                                    "id": response.toolCallId,
                                    "type": "function",
                                    "function": {
                                        "name": response.toolName,
                                        "arguments": json.dumps(response.input or {}),
                                    },
                                }
                            )
                            assistant_response.tool_calls = accumulated_tool_calls
                            if not has_appended_assistant:
                                session.messages.append(assistant_response)
                                has_appended_assistant = True

                        elif isinstance(response, StreamToolOutputAvailable):
                            session.messages.append(
                                ChatMessage(
                                    role="tool",
                                    content=(
                                        response.output
                                        if isinstance(response.output, str)
                                        else str(response.output)
                                    ),
                                    tool_call_id=response.toolCallId,
                                )
                            )
                            has_tool_results = True

                        elif isinstance(response, StreamFinish):
                            stream_completed = True

                    if stream_completed:
                        break

                if (
                    assistant_response.content or assistant_response.tool_calls
                ) and not has_appended_assistant:
                    session.messages.append(assistant_response)

                # --- Capture transcript while CLI is still alive ---
                # Must happen INSIDE async with: close() sends SIGTERM
                # which kills the CLI before it can flush the JSONL.
                if (
                    config.claude_agent_use_resume
                    and user_id
                    and captured_transcript.available
                ):
                    # Give CLI time to flush JSONL writes before we read
                    await asyncio.sleep(0.5)
                    raw_transcript = read_transcript_file(captured_transcript.path)
                    if raw_transcript:
                        task = asyncio.create_task(
                            _upload_transcript_bg(user_id, session_id, raw_transcript)
                        )
                        _background_tasks.add(task)
                        task.add_done_callback(_background_tasks.discard)
                    else:
                        logger.debug("[SDK] Stop hook fired but transcript not usable")

        except ImportError:
            raise RuntimeError(
                "claude-agent-sdk is not installed. "
                "Disable SDK mode (CHAT_USE_CLAUDE_AGENT_SDK=false) "
                "to use the OpenAI-compatible fallback."
            )

        await upsert_chat_session(session)
        logger.debug(
            f"[SDK] Session {session_id} saved with {len(session.messages)} messages"
        )
        if not stream_completed:
            yield StreamFinish()

    except Exception as e:
        logger.error(f"[SDK] Error: {e}", exc_info=True)
        try:
            await upsert_chat_session(session)
        except Exception as save_err:
            logger.error(f"[SDK] Failed to save session on error: {save_err}")
        yield StreamError(
            errorText="An error occurred. Please try again.",
            code="sdk_error",
        )
        yield StreamFinish()
    finally:
        if sdk_cwd:
            _cleanup_sdk_tool_results(sdk_cwd)


async def _upload_transcript_bg(
|
||||
user_id: str, session_id: str, raw_content: str
|
||||
) -> None:
|
||||
"""Background task to strip progress entries and upload transcript."""
|
||||
try:
|
||||
await upload_transcript(user_id, session_id, raw_content)
|
||||
except Exception as e:
|
||||
logger.error(f"[SDK] Failed to upload transcript for {session_id}: {e}")
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
session_id: str, message: str, user_id: str | None = None
|
||||
) -> None:
|
||||
"""Background task to update session title."""
|
||||
try:
|
||||
title = await _generate_session_title(
|
||||
message, user_id=user_id, session_id=session_id
|
||||
)
|
||||
if title:
|
||||
await update_session_title(session_id, title)
|
||||
logger.debug(f"[SDK] Generated title for {session_id}: {title}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[SDK] Failed to update session title: {e}")
@@ -9,7 +9,6 @@ via a callback provided by the service layer. This avoids wasteful SDK polling
and makes results survive page refreshes.
"""

import asyncio
import itertools
import json
import logging
@@ -19,9 +18,9 @@ from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from typing import Any

from backend.copilot.model import ChatSession
from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools.base import BaseTool
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools import TOOL_REGISTRY
from backend.api.features.chat.tools.base import BaseTool

logger = logging.getLogger(__name__)

@@ -42,17 +41,9 @@ _current_session: ContextVar[ChatSession | None] = ContextVar(
# Stash for MCP tool outputs before the SDK potentially truncates them.
# Keyed by tool_name → full output string. Consumed (popped) by the
# response adapter when it builds StreamToolOutputAvailable.
_pending_tool_outputs: ContextVar[dict[str, list[str]]] = ContextVar(
_pending_tool_outputs: ContextVar[dict[str, str]] = ContextVar(
    "pending_tool_outputs", default=None  # type: ignore[arg-type]
)
# Event signaled whenever stash_pending_tool_output() adds a new entry.
# Used by the streaming loop to wait for PostToolUse hooks to complete
# instead of sleeping an arbitrary duration. The SDK fires hooks via
# start_soon (fire-and-forget) so the next message can arrive before
# the hook stashes its output — this event bridges that gap.
_stash_event: ContextVar[asyncio.Event | None] = ContextVar(
    "_stash_event", default=None
)

# Callback type for delegating long-running tools to the non-SDK infrastructure.
# Args: (tool_name, arguments, session) → MCP-formatted response dict.
@@ -85,7 +76,6 @@ def set_execution_context(
    _current_user_id.set(user_id)
    _current_session.set(session)
    _pending_tool_outputs.set({})
    _stash_event.set(asyncio.Event())
    _long_running_callback.set(long_running_callback)


@@ -98,89 +88,19 @@ def get_execution_context() -> tuple[str | None, ChatSession | None]:


def pop_pending_tool_output(tool_name: str) -> str | None:
    """Pop and return the oldest stashed output for *tool_name*.
    """Pop and return the stashed full output for *tool_name*.

    The SDK CLI may truncate large tool results (writing them to disk and
    replacing the content with a file reference). This stash keeps the
    original MCP output so the response adapter can forward it to the
    frontend for proper widget rendering.

    Uses a FIFO queue per tool name so duplicate calls to the same tool
    in one turn each get their own output.

    Returns ``None`` if nothing was stashed for *tool_name*.
    """
    pending = _pending_tool_outputs.get(None)
    if pending is None:
        return None
    queue = pending.get(tool_name)
    if not queue:
        pending.pop(tool_name, None)
        return None
    value = queue.pop(0)
    if not queue:
        del pending[tool_name]
    return value


def stash_pending_tool_output(tool_name: str, output: Any) -> None:
    """Stash tool output for later retrieval by the response adapter.

    Used by the PostToolUse hook to capture SDK built-in tool outputs
    (WebSearch, Read, etc.) that aren't available through the MCP stash
    mechanism in ``_execute_tool_sync``.

    Appends to a FIFO queue per tool name so multiple calls to the same
    tool in one turn are all preserved.
    """
    pending = _pending_tool_outputs.get(None)
    if pending is None:
        return
    if isinstance(output, str):
        text = output
    else:
        try:
            text = json.dumps(output)
        except (TypeError, ValueError):
            text = str(output)
    pending.setdefault(tool_name, []).append(text)
    # Signal any waiters that new output is available.
    event = _stash_event.get(None)
    if event is not None:
        event.set()


async def wait_for_stash(timeout: float = 0.5) -> bool:
    """Wait for a PostToolUse hook to stash tool output.

    The SDK fires PostToolUse hooks asynchronously via ``start_soon()`` —
    the next message (AssistantMessage/ResultMessage) can arrive before the
    hook completes and stashes its output. This function bridges that gap
    by waiting on the ``_stash_event``, which is signaled by
    :func:`stash_pending_tool_output`.

    After the event fires, callers should ``await asyncio.sleep(0)`` to
    give any remaining concurrent hooks a chance to complete.

    Returns ``True`` if a stash signal was received, ``False`` on timeout.
    The timeout is a safety net — normally the stash happens within
    microseconds of yielding to the event loop.
    """
    event = _stash_event.get(None)
    if event is None:
        return False
    # Fast path: hook already completed before we got here.
    if event.is_set():
        event.clear()
        return True
    # Slow path: wait for the hook to signal.
    try:
        async with asyncio.timeout(timeout):
            await event.wait()
            event.clear()
            return True
    except TimeoutError:
        return False
    return pending.pop(tool_name, None)
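Aside: the removed wait_for_stash helper shows a reusable pattern: bridging a fire-and-forget callback to an awaiting consumer with an asyncio.Event, a fast path for the already-set case, and a timeout as a safety net. A self-contained sketch of that pattern under those assumptions (illustrative names, not part of the diff):

import asyncio

_event = asyncio.Event()  # safe at module level on Python 3.10+
_stash: list[str] = []

def producer_hook(value: str) -> None:
    # Called fire-and-forget by some framework; stash and signal.
    _stash.append(value)
    _event.set()

async def wait_for_value(timeout: float = 0.5) -> str | None:
    # Fast path: the hook may already have fired before we got here.
    if not _event.is_set():
        try:
            async with asyncio.timeout(timeout):  # Python 3.11+
                await _event.wait()
        except TimeoutError:
            return None  # safety net: never block the stream forever
    _event.clear()
    return _stash.pop(0) if _stash else None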


async def _execute_tool_sync(
@@ -205,63 +125,14 @@ async def _execute_tool_sync(
    # Stash the full output before the SDK potentially truncates it.
    pending = _pending_tool_outputs.get(None)
    if pending is not None:
        pending.setdefault(base_tool.name, []).append(text)

    content_blocks: list[dict[str, str]] = [{"type": "text", "text": text}]

    # If the tool result contains inline image data, add an MCP image block
    # so Claude can "see" the image (e.g. read_workspace_file on a small PNG).
    image_block = _extract_image_block(text)
    if image_block:
        content_blocks.append(image_block)
        pending[base_tool.name] = text

    return {
        "content": content_blocks,
        "content": [{"type": "text", "text": text}],
        "isError": not result.success,
    }


# MIME types that Claude can process as image content blocks.
_SUPPORTED_IMAGE_TYPES = frozenset(
    {"image/png", "image/jpeg", "image/gif", "image/webp"}
)


def _extract_image_block(text: str) -> dict[str, str] | None:
    """Extract an MCP image content block from a tool result JSON string.

    Detects workspace file responses with ``content_base64`` and an image
    MIME type, returning an MCP-format image block that allows Claude to
    "see" the image. Returns ``None`` if the result is not an inline image.
    """
    try:
        data = json.loads(text)
    except (json.JSONDecodeError, TypeError):
        return None

    if not isinstance(data, dict):
        return None

    mime_type = data.get("mime_type", "")
    base64_content = data.get("content_base64", "")

    # Only inline small images — large ones would exceed Claude's limits.
    # 32 KB raw ≈ ~43 KB base64.
    _MAX_IMAGE_BASE64_BYTES = 43_000
    if (
        mime_type in _SUPPORTED_IMAGE_TYPES
        and base64_content
        and len(base64_content) <= _MAX_IMAGE_BASE64_BYTES
    ):
        return {
            "type": "image",
            "data": base64_content,
            "mimeType": mime_type,
        }

    return None
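Aside: the 43 KB cap follows from base64's 4-characters-per-3-bytes expansion, so 32 KB of raw bytes encodes to roughly 43 KB of text. A quick standalone check of the arithmetic:

import base64
import os

raw = os.urandom(32_000)          # 32 KB of raw bytes
encoded = base64.b64encode(raw)   # 4 output chars per 3 input bytes
print(len(encoded))               # 42668, just under the 43_000 cap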


def _mcp_error(message: str) -> dict[str, Any]:
    return {
        "content": [
@@ -440,29 +311,14 @@ def create_copilot_mcp_server():
# which provides kernel-level network isolation via unshare --net.
# Task allows spawning sub-agents (rate-limited by security hooks).
# WebSearch uses Brave Search via Anthropic's API — safe, no SSRF risk.
# TodoWrite manages the task checklist shown in the UI — no security concern.
_SDK_BUILTIN_TOOLS = [
    "Read",
    "Write",
    "Edit",
    "Glob",
    "Grep",
    "Task",
    "WebSearch",
    "TodoWrite",
]
_SDK_BUILTIN_TOOLS = ["Read", "Write", "Edit", "Glob", "Grep", "Task", "WebSearch"]

# SDK built-in tools that must be explicitly blocked.
# Bash: dangerous — agent uses mcp__copilot__bash_exec with kernel-level
# network isolation (unshare --net) instead.
# WebFetch: SSRF risk — can reach internal network (localhost, 10.x, etc.).
# Agent uses the SSRF-protected mcp__copilot__web_fetch tool instead.
# AskUserQuestion: interactive CLI tool — no terminal in copilot context.
SDK_DISALLOWED_TOOLS = [
    "Bash",
    "WebFetch",
    "AskUserQuestion",
]
SDK_DISALLOWED_TOOLS = ["Bash", "WebFetch"]

# Tools that are blocked entirely in security hooks (defence-in-depth).
# Includes SDK_DISALLOWED_TOOLS plus common aliases/synonyms.
@@ -14,8 +14,6 @@ import json
import logging
import os
import re
import time
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@@ -33,16 +31,6 @@ STRIPPABLE_TYPES = frozenset(
    {"progress", "file-history-snapshot", "queue-operation", "summary", "pr-link"}
)


@dataclass
class TranscriptDownload:
    """Result of downloading a transcript with its metadata."""

    content: str
    message_count: int = 0  # session.messages length when uploaded
    uploaded_at: float = 0.0  # epoch timestamp of upload


# Workspace storage constants — deterministic path from session_id.
TRANSCRIPT_STORAGE_PREFIX = "chat-transcripts"

@@ -131,21 +119,22 @@ def read_transcript_file(transcript_path: str) -> str | None:
        content = f.read()

    if not content.strip():
        logger.debug("[Transcript] File is empty: %s", transcript_path)
        logger.debug(f"[Transcript] Empty file: {transcript_path}")
        return None

    lines = content.strip().split("\n")

    # Validate that the transcript has real conversation content
    # (not just metadata like queue-operation entries).
    if not validate_transcript(content):
    if len(lines) < 3:
        # Raw files with ≤2 lines are metadata-only
        # (queue-operation + file-history-snapshot, no conversation).
        logger.debug(
            "[Transcript] No conversation content (%d lines) in %s",
            len(lines),
            transcript_path,
            f"[Transcript] Too few lines ({len(lines)}): {transcript_path}"
        )
        return None

    # Quick structural validation — parse first and last lines.
    json.loads(lines[0])
    json.loads(lines[-1])

    logger.info(
        f"[Transcript] Read {len(lines)} lines, "
        f"{len(content)} bytes from {transcript_path}"
@@ -171,41 +160,6 @@ def _sanitize_id(raw_id: str, max_len: int = 36) -> str:
_SAFE_CWD_PREFIX = os.path.realpath("/tmp/copilot-")


def _encode_cwd_for_cli(cwd: str) -> str:
    """Encode a working directory path the same way the Claude CLI does.

    The CLI replaces all non-alphanumeric characters with ``-``.
    """
    return re.sub(r"[^a-zA-Z0-9]", "-", os.path.realpath(cwd))


def cleanup_cli_project_dir(sdk_cwd: str) -> None:
    """Remove the CLI's project directory for a specific working directory.

    The CLI stores session data under ``~/.claude/projects/<encoded_cwd>/``.
    Each SDK turn uses a unique ``sdk_cwd``, so the project directory is
    safe to remove entirely after the transcript has been uploaded.
    """
    import shutil

    cwd_encoded = _encode_cwd_for_cli(sdk_cwd)
    config_dir = os.environ.get("CLAUDE_CONFIG_DIR") or os.path.expanduser("~/.claude")
    projects_base = os.path.realpath(os.path.join(config_dir, "projects"))
    project_dir = os.path.realpath(os.path.join(projects_base, cwd_encoded))

    if not project_dir.startswith(projects_base + os.sep):
        logger.warning(
            f"[Transcript] Cleanup path escaped projects base: {project_dir}"
        )
        return

    if os.path.isdir(project_dir):
        shutil.rmtree(project_dir, ignore_errors=True)
        logger.debug(f"[Transcript] Cleaned up CLI project dir: {project_dir}")
    else:
        logger.debug(f"[Transcript] Project dir not found: {project_dir}")


def write_transcript_to_tempfile(
    transcript_content: str,
    session_id: str,
@@ -294,15 +248,6 @@ def _storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
    )


def _meta_storage_path_parts(user_id: str, session_id: str) -> tuple[str, str, str]:
    """Return (workspace_id, file_id, filename) for a session's transcript metadata."""
    return (
        TRANSCRIPT_STORAGE_PREFIX,
        _sanitize_id(user_id),
        f"{_sanitize_id(session_id)}.meta.json",
    )


def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
    """Build the full storage path string that ``retrieve()`` expects.

@@ -323,30 +268,21 @@ def _build_storage_path(user_id: str, session_id: str, backend: object) -> str:
    return f"local://{wid}/{fid}/{fname}"


async def upload_transcript(
    user_id: str,
    session_id: str,
    content: str,
    message_count: int = 0,
) -> None:
async def upload_transcript(user_id: str, session_id: str, content: str) -> None:
    """Strip progress entries and upload transcript to bucket storage.

    Safety: only overwrites when the new (stripped) transcript is larger than
    what is already stored. Since JSONL is append-only, the latest transcript
    is always the longest. This prevents a slow/stale background task from
    clobbering a newer upload from a concurrent turn.

    Args:
        message_count: ``len(session.messages)`` at upload time — used by
            the next turn to detect staleness and compress only the gap.
    """
    from backend.util.workspace_storage import get_workspace_storage

    stripped = strip_progress_entries(content)
    if not validate_transcript(stripped):
        logger.warning(
            f"[Transcript] Skipping upload — stripped content not valid "
            f"for session {session_id}"
            f"[Transcript] Skipping upload — stripped content is not a valid "
            f"transcript for session {session_id}"
        )
        return

@@ -361,8 +297,9 @@ async def upload_transcript(
        existing = await storage.retrieve(path)
        if len(existing) >= new_size:
            logger.info(
                f"[Transcript] Skipping upload — existing ({len(existing)}B) "
                f">= new ({new_size}B) for session {session_id}"
                f"[Transcript] Skipping upload — existing transcript "
                f"({len(existing)}B) >= new ({new_size}B) for session "
                f"{session_id}"
            )
            return
    except (FileNotFoundError, Exception):
@@ -374,38 +311,16 @@ async def upload_transcript(
        filename=fname,
        content=encoded,
    )

    # Store metadata alongside the transcript so the next turn can detect
    # staleness and only compress the gap instead of the full history.
    # Wrapped in try/except so a metadata write failure doesn't orphan
    # the already-uploaded transcript — the next turn will just fall back
    # to full gap fill (msg_count=0).
    try:
        meta = {"message_count": message_count, "uploaded_at": time.time()}
        mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
        await storage.store(
            workspace_id=mwid,
            file_id=mfid,
            filename=mfname,
            content=json.dumps(meta).encode("utf-8"),
        )
    except Exception as e:
        logger.warning(f"[Transcript] Failed to write metadata for {session_id}: {e}")

    logger.info(
        f"[Transcript] Uploaded {new_size}B "
        f"(stripped from {len(content)}B, msg_count={message_count}) "
        f"for session {session_id}"
        f"[Transcript] Uploaded {new_size} bytes "
        f"(stripped from {len(content)}) for session {session_id}"
    )


async def download_transcript(
    user_id: str, session_id: str
) -> TranscriptDownload | None:
    """Download transcript and metadata from bucket storage.
async def download_transcript(user_id: str, session_id: str) -> str | None:
    """Download transcript from bucket storage.

    Returns a ``TranscriptDownload`` with the JSONL content and the
    ``message_count`` watermark from the upload, or ``None`` if not found.
    Returns the JSONL content string, or ``None`` if not found.
    """
    from backend.util.workspace_storage import get_workspace_storage

@@ -415,6 +330,10 @@ async def download_transcript(
    try:
        data = await storage.retrieve(path)
        content = data.decode("utf-8")
        logger.info(
            f"[Transcript] Downloaded {len(content)} bytes for session {session_id}"
        )
        return content
    except FileNotFoundError:
        logger.debug(f"[Transcript] No transcript in storage for {session_id}")
        return None
@@ -422,36 +341,6 @@ async def download_transcript(
        logger.warning(f"[Transcript] Failed to download transcript: {e}")
        return None

    # Try to load metadata (best-effort — old transcripts won't have it)
    message_count = 0
    uploaded_at = 0.0
    try:
        from backend.util.workspace_storage import GCSWorkspaceStorage

        mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
        if isinstance(storage, GCSWorkspaceStorage):
            blob = f"workspaces/{mwid}/{mfid}/{mfname}"
            meta_path = f"gcs://{storage.bucket_name}/{blob}"
        else:
            meta_path = f"local://{mwid}/{mfid}/{mfname}"

        meta_data = await storage.retrieve(meta_path)
        meta = json.loads(meta_data.decode("utf-8"))
        message_count = meta.get("message_count", 0)
        uploaded_at = meta.get("uploaded_at", 0.0)
    except (FileNotFoundError, json.JSONDecodeError, Exception):
        pass  # No metadata — treat as unknown (msg_count=0 → always fill gap)

    logger.info(
        f"[Transcript] Downloaded {len(content)}B "
        f"(msg_count={message_count}) for session {session_id}"
    )
    return TranscriptDownload(
        content=content,
        message_count=message_count,
        uploaded_at=uploaded_at,
    )


async def delete_transcript(user_id: str, session_id: str) -> None:
    """Delete transcript from bucket storage (e.g. after resume failure)."""
@@ -27,18 +27,20 @@ from openai.types.chat import (
    ChatCompletionToolParam,
)

from backend.data.db_accessors import chat_db, understanding_db
from backend.data.redis_client import get_redis_async
from backend.data.understanding import format_understanding_for_prompt
from backend.data.understanding import (
    format_understanding_for_prompt,
    get_business_understanding,
)
from backend.util.exceptions import NotFoundError
from backend.util.settings import AppEnvironment, Settings

from . import db as chat_db
from . import stream_registry
from .config import ChatConfig
from .model import (
    ChatMessage,
    ChatSession,
    ChatSessionInfo,
    Usage,
    cache_chat_session,
    get_chat_session,
@@ -118,8 +120,6 @@ Adapt flexibly to the conversation context. Not every interaction requires all s
- Find reusable components with `find_block`
- Create custom solutions with `create_agent` if nothing suitable exists
- Modify existing library agents with `edit_agent`
- **When `create_agent` returns `suggested_goal`**: Present the suggestion to the user and ask "Would you like me to proceed with this refined goal?" If they accept, call `create_agent` again with the suggested goal.
- **When `create_agent` returns `clarifying_questions`**: After the user answers, call `create_agent` again with the original description AND the answers in the `context` parameter.

5. **Execute**: Run automations immediately, schedule them, or set up webhooks using `run_agent`. Test specific components with `run_block`.

@@ -166,11 +166,6 @@ Adapt flexibly to the conversation context. Not every interaction requires all s
- Use `add_understanding` to capture valuable business context
- When tool calls fail, try alternative approaches

**Handle Feedback Loops:**
- When a tool returns a suggested alternative (like a refined goal), present it clearly and ask the user for confirmation before proceeding
- When clarifying questions are answered, immediately re-call the tool with the accumulated context
- Don't ask redundant questions if the user has already provided context in the conversation

## CRITICAL REMINDER

You are NOT a chatbot. You are NOT documentation. You are a partner who helps busy business owners get value quickly by showing proof through working automations. Bias toward action over explanation."""
@@ -268,7 +263,7 @@ async def _build_system_prompt(
    understanding = None
    if user_id:
        try:
            understanding = await understanding_db().get_business_understanding(user_id)
            understanding = await get_business_understanding(user_id)
        except Exception as e:
            logger.warning(f"Failed to fetch business understanding: {e}")
            understanding = None
@@ -344,7 +339,7 @@ async def _generate_session_title(
async def assign_user_to_session(
    session_id: str,
    user_id: str,
) -> ChatSessionInfo:
) -> ChatSession:
    """
    Assign a user to a chat session.
    """
@@ -352,8 +347,7 @@ async def assign_user_to_session(
    if not session:
        raise NotFoundError(f"Session {session_id} not found")
    session.user_id = user_id
    session, _ = await upsert_chat_session(session)
    return session
    return await upsert_chat_session(session)


async def stream_chat_completion(
@@ -434,16 +428,12 @@ async def stream_chat_completion(
            f"Session {session_id} not found. Please create a new session first."
        )

    # Append the new message to the session if it's not already there
    new_message_role = "user" if is_user_message else "assistant"
    if message and (
        len(session.messages) == 0
        or not (
            session.messages[-1].role == new_message_role
            and session.messages[-1].content == message
    if message:
        session.messages.append(
            ChatMessage(
                role="user" if is_user_message else "assistant", content=message
            )
        )
    ):
        session.messages.append(ChatMessage(role=new_message_role, content=message))
        logger.info(
            f"Appended message (role={'user' if is_user_message else 'assistant'}), "
            f"new message_count={len(session.messages)}"
@@ -464,7 +454,7 @@ async def stream_chat_completion(
    )

    upsert_start = time.monotonic()
    session, _ = await upsert_chat_session(session)
    session = await upsert_chat_session(session)
    upsert_time = (time.monotonic() - upsert_start) * 1000
    logger.info(
        f"[TIMING] upsert_chat_session took {upsert_time:.1f}ms",
@@ -690,7 +680,7 @@ async def stream_chat_completion(
                f"tool_responses={len(tool_response_messages)}"
            )
            if messages_to_save_early or has_appended_streaming_message:
                _ = await upsert_chat_session(session)
                await upsert_chat_session(session)
                has_saved_assistant_message = True

            has_yielded_end = True
@@ -729,7 +719,7 @@ async def stream_chat_completion(
        if tool_response_messages:
            session.messages.extend(tool_response_messages)
        try:
            _ = await upsert_chat_session(session)
            await upsert_chat_session(session)
        except Exception as e:
            logger.warning(
                f"Failed to save interrupted session {session.session_id}: {e}"
@@ -770,7 +760,7 @@ async def stream_chat_completion(
        if messages_to_save:
            session.messages.extend(messages_to_save)
        if messages_to_save or has_appended_streaming_message:
            _ = await upsert_chat_session(session)
            await upsert_chat_session(session)

        if not has_yielded_error:
            error_message = str(e)
@@ -854,7 +844,7 @@ async def stream_chat_completion(
            not has_long_running_tool_call
            and (messages_to_save or has_appended_streaming_message)
        ):
            _ = await upsert_chat_session(session)
            await upsert_chat_session(session)
        else:
            logger.info(
                "Assistant message already saved when StreamFinish was received, "
@@ -1233,10 +1223,23 @@ async def _stream_chat_chunks(
        },
    )

    # Execute all accumulated tool calls in parallel
    # Events are yielded as they arrive from each concurrent tool
    async for event in _execute_tool_calls_parallel(tool_calls, session):
        yield event
    # Yield all accumulated tool calls after the stream is complete
    # This ensures all tool call arguments have been fully received
    for idx, tool_call in enumerate(tool_calls):
        try:
            async for tc in _yield_tool_call(tool_calls, idx, session):
                yield tc
        except (orjson.JSONDecodeError, KeyError, TypeError) as e:
            logger.error(
                f"Failed to parse tool call {idx}: {e}",
                exc_info=True,
                extra={"tool_call": tool_call},
            )
            yield StreamError(
                errorText=f"Invalid tool call arguments for tool {tool_call.get('function', {}).get('name', 'unknown')}: {e}",
            )
            # Re-raise to trigger retry logic in the parent function
            raise

    total_time = (time_module.perf_counter() - stream_chunks_start) * 1000
    logger.info(
@@ -1314,91 +1317,10 @@ async def _stream_chat_chunks(
        return


async def _with_optional_lock(
    lock: asyncio.Lock | None,
    coro_fn: Any,
) -> Any:
    """Run *coro_fn()* under *lock* when provided, otherwise run directly."""
    if lock:
        async with lock:
            return await coro_fn()
    return await coro_fn()


async def _execute_tool_calls_parallel(
    tool_calls: list[dict[str, Any]],
    session: ChatSession,
) -> AsyncGenerator[StreamBaseResponse, None]:
    """Execute all tool calls concurrently, yielding stream events as they arrive.

    Each tool runs as an ``asyncio.Task``, pushing events into a shared queue.
    A ``session_lock`` serialises session-state mutations (long-running tool
    bookkeeping, ``run_agent`` counters).
    """
    queue: asyncio.Queue[StreamBaseResponse | None] = asyncio.Queue()
    session_lock = asyncio.Lock()
    n_tools = len(tool_calls)
    retryable_errors: list[Exception] = []

    async def _run_tool(idx: int) -> None:
        tool_name = tool_calls[idx].get("function", {}).get("name", "unknown")
        tool_call_id = tool_calls[idx].get("id", f"unknown_{idx}")
        try:
            async for event in _yield_tool_call(tool_calls, idx, session, session_lock):
                await queue.put(event)
        except (orjson.JSONDecodeError, KeyError, TypeError) as e:
            logger.error(
                f"Failed to parse tool call {idx} ({tool_name}): {e}",
                exc_info=True,
            )
            retryable_errors.append(e)
        except Exception as e:
            # Infrastructure / setup errors — emit an error output so the
            # client always sees a terminal event and doesn't hang.
            logger.error(f"Tool call {idx} ({tool_name}) failed: {e}", exc_info=True)
            await queue.put(
                StreamToolOutputAvailable(
                    toolCallId=tool_call_id,
                    toolName=tool_name,
                    output=ErrorResponse(
                        message=f"Tool execution failed: {e!s}",
                        error=type(e).__name__,
                        session_id=session.session_id,
                    ).model_dump_json(),
                    success=False,
                )
            )
        finally:
            await queue.put(None)  # sentinel

    tasks = [asyncio.create_task(_run_tool(idx)) for idx in range(n_tools)]
    try:
        finished = 0
        while finished < n_tools:
            event = await queue.get()
            if event is None:
                finished += 1
            else:
                yield event
        if retryable_errors:
            if len(retryable_errors) > 1:
                logger.warning(
                    f"{len(retryable_errors)} tool calls had retryable errors; "
                    f"re-raising first to trigger retry"
                )
            raise retryable_errors[0]
    finally:
        for t in tasks:
            if not t.done():
                t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)


async def _yield_tool_call(
    tool_calls: list[dict[str, Any]],
    yield_idx: int,
    session: ChatSession,
    session_lock: asyncio.Lock | None = None,
) -> AsyncGenerator[StreamBaseResponse, None]:
    """
    Yield a tool call and its execution result.
@@ -1496,7 +1418,8 @@ async def _yield_tool_call(
        "check back in a few minutes."
    )

    # Track appended message for rollback on failure
    # Track appended messages for rollback on failure
    assistant_message: ChatMessage | None = None
    pending_message: ChatMessage | None = None

    # Wrap session save and task creation in try-except to release lock on failure
@@ -1511,24 +1434,22 @@ async def _yield_tool_call(
            operation_id=operation_id,
        )

        # Attach tool_call and save pending result — lock serialises
        # concurrent session mutations during parallel execution.
        async def _save_pending() -> None:
            nonlocal pending_message
            session.add_tool_call_to_current_turn(tool_calls[yield_idx])
            pending_message = ChatMessage(
                role="tool",
                content=OperationPendingResponse(
                    message=pending_msg,
                    operation_id=operation_id,
                    tool_name=tool_name,
                ).model_dump_json(),
                tool_call_id=tool_call_id,
            )
            session.messages.append(pending_message)
            _ = await upsert_chat_session(session)
        # Attach the tool_call to the current turn's assistant message
        # (or create one if this is a tool-only response with no text).
        session.add_tool_call_to_current_turn(tool_calls[yield_idx])

        await _with_optional_lock(session_lock, _save_pending)
        # Then save pending tool result
        pending_message = ChatMessage(
            role="tool",
            content=OperationPendingResponse(
                message=pending_msg,
                operation_id=operation_id,
                tool_name=tool_name,
            ).model_dump_json(),
            tool_call_id=tool_call_id,
        )
        session.messages.append(pending_message)
        await upsert_chat_session(session)
        logger.info(
            f"Saved pending operation {operation_id} (task_id={task_id}) "
            f"for tool {tool_name} in session {session.session_id}"
@@ -1552,25 +1473,27 @@ async def _yield_tool_call(
        # Associate the asyncio task with the stream registry task
        await stream_registry.set_task_asyncio_task(task_id, bg_task)
    except Exception as e:
        # Roll back appended messages — use identity-based removal so
        # it works even when other parallel tools have appended after us.
        async def _rollback() -> None:
            if pending_message and pending_message in session.messages:
                session.messages.remove(pending_message)

        await _with_optional_lock(session_lock, _rollback)
        # Roll back appended messages to prevent data corruption on subsequent saves
        if (
            pending_message
            and session.messages
            and session.messages[-1] == pending_message
        ):
            session.messages.pop()
        if (
            assistant_message
            and session.messages
            and session.messages[-1] == assistant_message
        ):
            session.messages.pop()

        # Release the Redis lock since the background task won't be spawned
        await _mark_operation_completed(tool_call_id)
        # Mark stream registry task as failed if it was created
        try:
            await stream_registry.mark_task_completed(
                task_id,
                status="failed",
                error_message=f"Failed to setup tool {tool_name}: {e}",
            )
        except Exception as mark_err:
            logger.warning(f"Failed to mark task {task_id} as failed: {mark_err}")
            await stream_registry.mark_task_completed(task_id, status="failed")
        except Exception:
            pass
        logger.error(
            f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True
        )
@@ -1736,11 +1659,7 @@ async def _execute_long_running_tool_with_streaming(
    session = await get_chat_session(session_id, user_id)
    if not session:
        logger.error(f"Session {session_id} not found for background tool")
        await stream_registry.mark_task_completed(
            task_id,
            status="failed",
            error_message=f"Session {session_id} not found",
        )
        await stream_registry.mark_task_completed(task_id, status="failed")
        return

    # Pass operation_id and task_id to the tool for async processing
@@ -1851,7 +1770,7 @@ async def _update_pending_operation(
    This is called by background tasks when long-running operations complete.
    """
    # Update the message in database
    updated = await chat_db().update_tool_message_content(
    updated = await chat_db.update_tool_message_content(
        session_id=session_id,
        tool_call_id=tool_call_id,
        new_content=result,
@@ -2020,7 +1939,7 @@ async def _generate_llm_continuation(
    fresh_session.messages.append(assistant_message)

    # Save to database (not cache) to persist the response
    _ = await upsert_chat_session(fresh_session)
    await upsert_chat_session(fresh_session)

    # Invalidate cache so next poll/refresh gets fresh data
    await invalidate_session_cache(session_id)
@@ -2226,7 +2145,7 @@ async def _generate_llm_continuation_with_streaming(
    fresh_session.messages.append(assistant_message)

    # Save to database (not cache) to persist the response
    _ = await upsert_chat_session(fresh_session)
    await upsert_chat_session(fresh_session)

    # Invalidate cache so next poll/refresh gets fresh data
    await invalidate_session_cache(session_id)
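Aside: _execute_tool_calls_parallel above uses a queue-plus-sentinel fan-in: each worker task pushes its events and then a None sentinel, and the consumer stops after one sentinel per worker. A self-contained sketch of the shape (illustrative and simplified, not the diff's exact code):

import asyncio
from collections.abc import AsyncGenerator

async def fan_in(workers: list) -> AsyncGenerator[object, None]:
    """Run async generators concurrently; yield items as they arrive."""
    queue: asyncio.Queue[object | None] = asyncio.Queue()

    async def run(gen) -> None:
        try:
            async for item in gen:
                await queue.put(item)
        finally:
            await queue.put(None)  # sentinel: this worker is done

    tasks = [asyncio.create_task(run(g)) for g in workers]
    try:
        finished = 0
        while finished < len(workers):
            item = await queue.get()
            if item is None:
                finished += 1
            else:
                yield item
    finally:
        # Cancel any stragglers and swallow their cancellation errors.
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)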
@@ -58,7 +58,7 @@ async def test_stream_chat_completion_with_tool_calls(setup_test_user, test_user
        return pytest.skip("OPEN_ROUTER_API_KEY is not set, skipping test")

    session = await create_chat_session(test_user_id)
    session, _ = await upsert_chat_session(session)
    session = await upsert_chat_session(session)

    has_errors = False
    has_ended = False
@@ -104,7 +104,7 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
        return pytest.skip("CLAUDE_AGENT_USE_RESUME is not enabled, skipping test")

    session = await create_chat_session(test_user_id)
    session, _ = await upsert_chat_session(session)
    session = await upsert_chat_session(session)

    # --- Turn 1: send a message with a unique keyword ---
    keyword = "ZEPHYR42"
@@ -132,24 +132,18 @@ async def test_sdk_resume_multi_turn(setup_test_user, test_user_id):
    assert not turn1_errors, f"Turn 1 errors: {turn1_errors}"
    assert turn1_text, "Turn 1 produced no text"

    # Wait for background upload task to complete (retry up to 5s).
    # The CLI may not produce a usable transcript for very short
    # conversations (only metadata entries) — this is environment-dependent
    # (CLI version, platform). When that happens, multi-turn still works
    # via conversation compression (non-resume path), but we can't test
    # the --resume round-trip.
    # Wait for background upload task to complete (retry up to 5s)
    transcript = None
    for _ in range(10):
        await asyncio.sleep(0.5)
        transcript = await download_transcript(test_user_id, session.session_id)
        if transcript:
            break
    if not transcript:
        return pytest.skip(
            "CLI did not produce a usable transcript — "
            "cannot test --resume round-trip in this environment"
        )
    logger.info(f"Turn 1 transcript uploaded: {len(transcript.content)} bytes")
    assert transcript, (
        "Transcript was not uploaded to bucket after turn 1 — "
        "Stop hook may not have fired or transcript was too small"
    )
    logger.info(f"Turn 1 transcript uploaded: {len(transcript)} bytes")

    # Reload session for turn 2
    session = await get_chat_session(session.session_id, test_user_id)
@@ -227,14 +227,7 @@ async def publish_chunk(
    # Only log timing for significant chunks or slow operations
    if (
        chunk_type
        in (
            "StreamStart",
            "StreamFinish",
            "StreamTextStart",
            "StreamTextEnd",
            "StreamToolInputAvailable",
            "StreamToolOutputAvailable",
        )
        in ("StreamStart", "StreamFinish", "StreamTextStart", "StreamTextEnd")
        or total_time > 50
    ):
        logger.info(
@@ -644,8 +637,6 @@ async def _stream_listener(
async def mark_task_completed(
    task_id: str,
    status: Literal["completed", "failed"] = "completed",
    *,
    error_message: str | None = None,
) -> bool:
    """Mark a task as completed and publish finish event.

@@ -656,10 +647,6 @@ async def mark_task_completed(
    Args:
        task_id: Task ID to mark as completed
        status: Final status ("completed" or "failed")
        error_message: If provided and status="failed", publish a StreamError
            before StreamFinish so connected clients see why the task ended.
            If not provided, no StreamError is published (caller should publish
            manually if needed to avoid duplicates).

    Returns:
        True if task was newly marked completed, False if already completed/failed
@@ -675,17 +662,6 @@ async def mark_task_completed(
        logger.debug(f"Task {task_id} already completed/failed, skipping")
        return False

    # Publish error event before finish so connected clients know WHY the
    # task ended. Only publish if caller provided an explicit error message
    # to avoid duplicates with code paths that manually publish StreamError.
    # This is best-effort — if it fails, the StreamFinish still ensures
    # listeners clean up.
    if status == "failed" and error_message:
        try:
            await publish_chunk(task_id, StreamError(errorText=error_message))
        except Exception as e:
            logger.warning(f"Failed to publish error event for task {task_id}: {e}")

    # THEN publish finish event (best-effort - listeners can detect via status polling)
    try:
        await publish_chunk(task_id, StreamFinish())
@@ -838,6 +814,24 @@ async def get_active_task_for_session(
        if task_user_id and user_id != task_user_id:
            continue

        # Auto-expire stale tasks that exceeded stream_timeout
        created_at_str = meta.get("created_at", "")
        if created_at_str:
            try:
                created_at = datetime.fromisoformat(created_at_str)
                age_seconds = (
                    datetime.now(timezone.utc) - created_at
                ).total_seconds()
                if age_seconds > config.stream_timeout:
                    logger.warning(
                        f"[TASK_LOOKUP] Auto-expiring stale task {task_id[:8]}... "
                        f"(age={age_seconds:.0f}s > timeout={config.stream_timeout}s)"
                    )
                    await mark_task_completed(task_id, "failed")
                    continue
            except (ValueError, TypeError):
                pass

        logger.info(
            f"[TASK_LOOKUP] Found running task {task_id[:8]}... for session {session_id[:8]}..."
        )
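Aside: the auto-expiry block above parses the stored created_at ISO timestamp and fails the task when its age exceeds the stream timeout, treating unparseable metadata as not stale. A standalone sketch of that check (illustrative helper, not from this diff):

from datetime import datetime, timezone

def is_stale(created_at_iso: str, timeout_s: float) -> bool:
    try:
        created_at = datetime.fromisoformat(created_at_iso)
        age = (datetime.now(timezone.utc) - created_at).total_seconds()
    except (ValueError, TypeError):
        # Unparseable or naive timestamps: treat as not stale (fail open).
        return False
    return age > timeout_s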
|
||||
@@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tracking import track_tool_called
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.tracking import track_tool_called
|
||||
|
||||
from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
@@ -31,7 +31,7 @@ from .workspace_files import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -6,11 +6,11 @@ import pytest
|
||||
from prisma.types import ProfileCreateInput
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock
|
||||
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
|
||||
from backend.blocks.llm import AITextGeneratorBlock
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db import prisma
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.model import APIKeyCredentials
|
||||
@@ -3,9 +3,11 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import understanding_db
|
||||
from backend.data.understanding import BusinessUnderstandingInput
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.understanding import (
|
||||
BusinessUnderstandingInput,
|
||||
upsert_business_understanding,
|
||||
)
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse
|
||||
@@ -97,9 +99,7 @@ and automations for the user's specific needs."""
|
||||
]
|
||||
|
||||
# Upsert with merge
|
||||
understanding = await understanding_db().upsert_business_understanding(
|
||||
user_id, input_data
|
||||
)
|
||||
understanding = await upsert_business_understanding(user_id, input_data)
|
||||
|
||||
# Build current understanding summary (filter out empty values)
|
||||
current_understanding = {
|
||||
@@ -5,8 +5,9 @@ import re
|
||||
import uuid
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from backend.data.db_accessors import graph_db, library_db, store_db
|
||||
from backend.data.graph import Graph, Link, Node
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .service import (
|
||||
@@ -144,9 +145,8 @@ async def get_library_agent_by_id(
|
||||
Returns:
|
||||
LibraryAgentSummary if found, None otherwise
|
||||
"""
|
||||
db = library_db()
|
||||
try:
|
||||
agent = await db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return LibraryAgentSummary(
|
||||
@@ -163,7 +163,7 @@ async def get_library_agent_by_id(
|
||||
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
|
||||
|
||||
try:
|
||||
agent = await db.get_library_agent(agent_id, user_id)
|
||||
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return LibraryAgentSummary(
|
||||
@@ -215,7 +215,7 @@ async def get_library_agents_for_generation(
|
||||
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
|
||||
"""
|
||||
try:
|
||||
response = await library_db().list_library_agents(
|
||||
response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_query,
|
||||
page=1,
|
||||
@@ -272,7 +272,7 @@ async def search_marketplace_agents_for_generation(
|
||||
List of LibraryAgentSummary with full input/output schemas
|
||||
"""
|
||||
try:
|
||||
response = await store_db().get_store_agents(
|
||||
response = await store_db.get_store_agents(
|
||||
search_query=search_query,
|
||||
page=1,
|
||||
page_size=max_results,
|
||||
@@ -286,7 +286,7 @@ async def search_marketplace_agents_for_generation(
|
||||
return []
|
||||
|
||||
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
|
||||
graphs = await graph_db().get_store_listed_graphs(graph_ids)
|
||||
graphs = await get_store_listed_graphs(*graph_ids)
|
||||
|
||||
results: list[LibraryAgentSummary] = []
|
||||
for agent in agents_with_graphs:
|
||||
@@ -673,10 +673,9 @@ async def save_agent_to_library(
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
"""
|
||||
graph = json_to_graph(agent_json)
|
||||
db = library_db()
|
||||
if is_update:
|
||||
return await db.update_graph_in_library(graph, user_id)
|
||||
return await db.create_graph_in_library(graph, user_id)
|
||||
return await library_db.update_graph_in_library(graph, user_id)
|
||||
return await library_db.create_graph_in_library(graph, user_id)
|
||||
|
||||
|
||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||
@@ -736,14 +735,12 @@ async def get_agent_as_json(
|
||||
Returns:
|
||||
Agent as JSON dict or None if not found
|
||||
"""
|
||||
db = graph_db()
|
||||
|
||||
graph = await db.get_graph(agent_id, version=None, user_id=user_id)
|
||||
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
||||
|
||||
if not graph and user_id:
|
||||
try:
|
||||
library_agent = await library_db().get_library_agent(agent_id, user_id)
|
||||
graph = await db.get_graph(
|
||||
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id, version=None, user_id=user_id
|
||||
)
|
||||
except NotFoundError:
|
||||
@@ -7,9 +7,10 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.library.model import LibraryAgent
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.data.db_accessors import execution_db, library_db
|
||||
from backend.data import execution as execution_db
|
||||
from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta
|
||||
|
||||
from .base import BaseTool
|
||||
@@ -164,12 +165,10 @@ class AgentOutputTool(BaseTool):
|
||||
Resolve agent from provided identifiers.
|
||||
Returns (library_agent, error_message).
|
||||
"""
|
||||
lib_db = library_db()
|
||||
|
||||
# Priority 1: Exact library agent ID
|
||||
if library_agent_id:
|
||||
try:
|
||||
agent = await lib_db.get_library_agent(library_agent_id, user_id)
|
||||
agent = await library_db.get_library_agent(library_agent_id, user_id)
|
||||
return agent, None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get library agent by ID: {e}")
|
||||
@@ -183,7 +182,7 @@ class AgentOutputTool(BaseTool):
|
||||
return None, f"Agent '{store_slug}' not found in marketplace"
|
||||
|
||||
# Find in user's library by graph_id
|
||||
agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id)
|
||||
if not agent:
|
||||
return (
|
||||
None,
|
||||
@@ -195,7 +194,7 @@ class AgentOutputTool(BaseTool):
|
||||
# Priority 3: Fuzzy name search in library
|
||||
if agent_name:
|
||||
try:
|
||||
response = await lib_db.list_library_agents(
|
||||
response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=agent_name,
|
||||
page_size=5,
|
||||
@@ -229,11 +228,9 @@ class AgentOutputTool(BaseTool):
|
||||
Fetch execution(s) based on filters.
|
||||
Returns (single_execution, available_executions_meta, error_message).
|
||||
"""
|
||||
exec_db = execution_db()
|
||||
|
||||
# If specific execution_id provided, fetch it directly
|
||||
if execution_id:
|
||||
execution = await exec_db.get_graph_execution(
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=execution_id,
|
||||
include_node_executions=False,
|
||||
@@ -243,7 +240,7 @@ class AgentOutputTool(BaseTool):
|
||||
return execution, [], None
|
||||
|
||||
# Get completed executions with time filters
|
||||
executions = await exec_db.get_graph_executions(
|
||||
executions = await execution_db.get_graph_executions(
|
||||
graph_id=graph_id,
|
||||
user_id=user_id,
|
||||
statuses=[ExecutionStatus.COMPLETED],
|
||||
@@ -257,7 +254,7 @@ class AgentOutputTool(BaseTool):
|
||||
|
||||
# If only one execution, fetch full details
|
||||
if len(executions) == 1:
|
||||
full_execution = await exec_db.get_graph_execution(
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
@@ -265,7 +262,7 @@ class AgentOutputTool(BaseTool):
|
||||
return full_execution, [], None
|
||||
|
||||
# Multiple executions - return latest with full details, plus list of available
|
||||
full_execution = await exec_db.get_graph_execution(
|
||||
full_execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=executions[0].id,
|
||||
include_node_executions=False,
|
||||
@@ -383,7 +380,7 @@ class AgentOutputTool(BaseTool):
|
||||
and not input_data.store_slug
|
||||
):
|
||||
# Fetch execution directly to get graph_id
|
||||
execution = await execution_db().get_graph_execution(
|
||||
execution = await execution_db.get_graph_execution(
|
||||
user_id=user_id,
|
||||
execution_id=input_data.execution_id,
|
||||
include_node_executions=False,
|
||||
@@ -395,7 +392,7 @@ class AgentOutputTool(BaseTool):
|
||||
)
|
||||
|
||||
# Find library agent by graph_id
|
||||
agent = await library_db().get_library_agent_by_graph_id(
|
||||
agent = await library_db.get_library_agent_by_graph_id(
|
||||
user_id, execution.graph_id
|
||||
)
|
||||
if not agent:
|
||||
@@ -4,7 +4,8 @@ import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from backend.data.db_accessors import library_db, store_db
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
|
||||
from .models import (
|
||||
@@ -44,10 +45,8 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
|
||||
Returns:
|
||||
AgentInfo if found, None otherwise
|
||||
"""
|
||||
lib_db = library_db()
|
||||
|
||||
try:
|
||||
agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return AgentInfo(
|
||||
@@ -72,7 +71,7 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N
|
||||
)
|
||||
|
||||
try:
|
||||
agent = await lib_db.get_library_agent(agent_id, user_id)
|
||||
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return AgentInfo(
|
||||
@@ -134,7 +133,7 @@ async def search_agents(
|
||||
try:
|
||||
if source == "marketplace":
|
||||
logger.info(f"Searching marketplace for: {query}")
|
||||
results = await store_db().get_store_agents(search_query=query, page_size=5)
|
||||
results = await store_db.get_store_agents(search_query=query, page_size=5)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
@@ -160,7 +159,7 @@ async def search_agents(
|
||||
|
||||
if not agents:
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
results = await library_db().list_library_agents(
|
||||
results = await library_db.list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
@@ -5,8 +5,8 @@ from typing import Any
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.response_model import StreamToolOutputAvailable
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.chat.response_model import StreamToolOutputAvailable
|
||||
|
||||
from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase
|
||||
|
||||
@@ -11,11 +11,18 @@ available (e.g. macOS development).
import logging
from typing import Any

from backend.copilot.model import ChatSession

from .base import BaseTool
from .models import BashExecResponse, ErrorResponse, ToolResponseBase
from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
    BashExecResponse,
    ErrorResponse,
    ToolResponseBase,
)
from backend.api.features.chat.tools.sandbox import (
    get_workspace_dir,
    has_full_sandbox,
    run_sandboxed,
)

logger = logging.getLogger(__name__)

@@ -3,10 +3,13 @@
import logging
from typing import Any

from backend.copilot.model import ChatSession

from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
    ErrorResponse,
    ResponseType,
    ToolResponseBase,
)

logger = logging.getLogger(__name__)

@@ -75,7 +78,7 @@ class CheckOperationStatusTool(BaseTool):
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        from backend.copilot import stream_registry
        from backend.api.features.chat import stream_registry

        operation_id = (kwargs.get("operation_id") or "").strip()
        task_id = (kwargs.get("task_id") or "").strip()
@@ -3,7 +3,7 @@
import logging
from typing import Any

from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession

from .agent_generator import (
    AgentGeneratorNotConfiguredError,
@@ -22,7 +22,6 @@ from .models import (
    ClarificationNeededResponse,
    ClarifyingQuestion,
    ErrorResponse,
    SuggestedGoalResponse,
    ToolResponseBase,
)

@@ -187,28 +186,26 @@ class CreateAgentTool(BaseTool):
        if decomposition_result.get("type") == "unachievable_goal":
            suggested = decomposition_result.get("suggested_goal", "")
            reason = decomposition_result.get("reason", "")
            return SuggestedGoalResponse(
            return ErrorResponse(
                message=(
                    f"This goal cannot be accomplished with the available blocks. {reason}"
                    f"This goal cannot be accomplished with the available blocks. "
                    f"{reason} "
                    f"Suggestion: {suggested}"
                ),
                suggested_goal=suggested,
                reason=reason,
                original_goal=description,
                goal_type="unachievable",
                error="unachievable_goal",
                details={"suggested_goal": suggested, "reason": reason},
                session_id=session_id,
            )

        if decomposition_result.get("type") == "vague_goal":
            suggested = decomposition_result.get("suggested_goal", "")
            reason = decomposition_result.get(
                "reason", "The goal needs more specific details"
            )
            return SuggestedGoalResponse(
                message="The goal is too vague to create a specific workflow.",
                suggested_goal=suggested,
                reason=reason,
                original_goal=description,
                goal_type="vague",
            return ErrorResponse(
                message=(
                    f"The goal is too vague to create a specific workflow. "
                    f"Suggestion: {suggested}"
                ),
                error="vague_goal",
                details={"suggested_goal": suggested},
                session_id=session_id,
            )
@@ -3,9 +3,9 @@
import logging
from typing import Any

from backend.api.features.chat.model import ChatSession
from backend.api.features.store import db as store_db
from backend.api.features.store.exceptions import AgentNotFoundError
from backend.copilot.model import ChatSession
from backend.data.db_accessors import store_db as get_store_db

from .agent_generator import (
    AgentGeneratorNotConfiguredError,
@@ -137,8 +137,6 @@ class CustomizeAgentTool(BaseTool):

        creator_username, agent_slug = parts

        store_db = get_store_db()

        # Fetch the marketplace agent details
        try:
            agent_details = await store_db.get_store_agent_details(
@@ -3,7 +3,7 @@
import logging
from typing import Any

from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession

from .agent_generator import (
    AgentGeneratorNotConfiguredError,
@@ -5,14 +5,9 @@ from typing import Any

from pydantic import SecretStr

from backend.blocks.linear._api import LinearClient
from backend.copilot.model import ChatSession
from backend.data.db_accessors import user_db
from backend.data.model import APIKeyCredentials
from backend.util.settings import Settings

from .base import BaseTool
from .models import (
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
    ErrorResponse,
    FeatureRequestCreatedResponse,
    FeatureRequestInfo,
@@ -20,6 +15,10 @@ from .models import (
    NoResultsResponse,
    ToolResponseBase,
)
from backend.blocks.linear._api import LinearClient
from backend.data.model import APIKeyCredentials
from backend.data.user import get_user_email_by_id
from backend.util.settings import Settings

logger = logging.getLogger(__name__)

@@ -33,6 +32,7 @@ query SearchFeatureRequests($term: String!, $filter: IssueFilter, $first: Int) {
      id
      identifier
      title
      description
    }
  }
}
@@ -104,8 +104,8 @@ def _get_linear_config() -> tuple[LinearClient, str, str]:
    Raises RuntimeError if any required setting is missing.
    """
    secrets = _get_settings().secrets
    if not secrets.copilot_linear_api_key:
        raise RuntimeError("COPILOT_LINEAR_API_KEY is not configured")
    if not secrets.linear_api_key:
        raise RuntimeError("LINEAR_API_KEY is not configured")
    if not secrets.linear_feature_request_project_id:
        raise RuntimeError("LINEAR_FEATURE_REQUEST_PROJECT_ID is not configured")
    if not secrets.linear_feature_request_team_id:
@@ -114,7 +114,7 @@ def _get_linear_config() -> tuple[LinearClient, str, str]:
    credentials = APIKeyCredentials(
        id="system-linear",
        provider="linear",
        api_key=SecretStr(secrets.copilot_linear_api_key),
        api_key=SecretStr(secrets.linear_api_key),
        title="System Linear API Key",
    )
    client = LinearClient(credentials=credentials)
@@ -204,6 +204,7 @@ class SearchFeatureRequestsTool(BaseTool):
                    id=node["id"],
                    identifier=node["identifier"],
                    title=node["title"],
                    description=node.get("description"),
                )
                for node in nodes
            ]
@@ -237,11 +238,7 @@ class CreateFeatureRequestTool(BaseTool):
            "Create a new feature request or add a customer need to an existing one. "
            "Always search first with search_feature_requests to avoid duplicates. "
            "If a matching request exists, pass its ID as existing_issue_id to add "
            "the user's need to it instead of creating a duplicate. "
            "IMPORTANT: Never include personally identifiable information (PII) in "
            "the title or description — no names, emails, phone numbers, company "
            "names, or other identifying details. Write titles and descriptions in "
            "generic, feature-focused language."
            "the user's need to it instead of creating a duplicate."
        )

    @property
@@ -251,20 +248,11 @@ class CreateFeatureRequestTool(BaseTool):
            "properties": {
                "title": {
                    "type": "string",
                    "description": (
                        "Title for the feature request. Must be generic and "
                        "feature-focused — do not include any user names, emails, "
                        "company names, or other PII."
                    ),
                    "description": "Title for the feature request.",
                },
                "description": {
                    "type": "string",
                    "description": (
                        "Detailed description of what the user wants and why. "
                        "Must not contain any personally identifiable information "
                        "(PII) — describe the feature need generically without "
                        "referencing specific users, companies, or contact details."
                    ),
                    "description": "Detailed description of what the user wants and why.",
                },
                "existing_issue_id": {
                    "type": "string",
@@ -344,9 +332,7 @@ class CreateFeatureRequestTool(BaseTool):
        # Resolve a human-readable name (email) for the Linear customer record.
        # Fall back to user_id if the lookup fails or returns None.
        try:
            customer_display_name = (
                await user_db().get_user_email_by_id(user_id) or user_id
            )
            customer_display_name = await get_user_email_by_id(user_id) or user_id
        except Exception:
            customer_display_name = user_id
@@ -1,18 +1,22 @@
"""Tests for SearchFeatureRequestsTool and CreateFeatureRequestTool."""

from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, patch

import pytest

from ._test_data import make_session
from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool
from .models import (
from backend.api.features.chat.tools.feature_requests import (
    CreateFeatureRequestTool,
    SearchFeatureRequestsTool,
)
from backend.api.features.chat.tools.models import (
    ErrorResponse,
    FeatureRequestCreatedResponse,
    FeatureRequestSearchResponse,
    NoResultsResponse,
)

from ._test_data import make_session

_TEST_USER_ID = "test-user-feature-requests"
_TEST_USER_EMAIL = "testuser@example.com"

@@ -35,7 +39,7 @@ def _mock_linear_config(*, query_return=None, mutate_return=None):
    client.mutate.return_value = mutate_return
    return (
        patch(
            "backend.copilot.tools.feature_requests._get_linear_config",
            "backend.api.features.chat.tools.feature_requests._get_linear_config",
            return_value=(client, _FAKE_PROJECT_ID, _FAKE_TEAM_ID),
        ),
        client,
@@ -117,11 +121,13 @@ class TestSearchFeatureRequestsTool:
                "id": "id-1",
                "identifier": "FR-1",
                "title": "Dark mode",
                "description": "Add dark mode support",
            },
            {
                "id": "id-2",
                "identifier": "FR-2",
                "title": "Dark theme",
                "description": None,
            },
        ]
        patcher, _ = _mock_linear_config(query_return=_search_response(nodes))
@@ -202,7 +208,7 @@ class TestSearchFeatureRequestsTool:
    async def test_linear_client_init_failure(self):
        session = make_session(user_id=_TEST_USER_ID)
        with patch(
            "backend.copilot.tools.feature_requests._get_linear_config",
            "backend.api.features.chat.tools.feature_requests._get_linear_config",
            side_effect=RuntimeError("No API key"),
        ):
            tool = SearchFeatureRequestsTool()
@@ -225,11 +231,10 @@ class TestCreateFeatureRequestTool:

    @pytest.fixture(autouse=True)
    def _patch_email_lookup(self):
        mock_user_db = MagicMock()
        mock_user_db.get_user_email_by_id = AsyncMock(return_value=_TEST_USER_EMAIL)
        with patch(
            "backend.copilot.tools.feature_requests.user_db",
            return_value=mock_user_db,
            "backend.api.features.chat.tools.feature_requests.get_user_email_by_id",
            new_callable=AsyncMock,
            return_value=_TEST_USER_EMAIL,
        ):
            yield

@@ -342,7 +347,7 @@ class TestCreateFeatureRequestTool:
    async def test_linear_client_init_failure(self):
        session = make_session(user_id=_TEST_USER_ID)
        with patch(
            "backend.copilot.tools.feature_requests._get_linear_config",
            "backend.api.features.chat.tools.feature_requests._get_linear_config",
            side_effect=RuntimeError("No API key"),
        ):
            tool = CreateFeatureRequestTool()
@@ -2,7 +2,7 @@

from typing import Any

from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession

from .agent_search import search_agents
from .base import BaseTool
@@ -3,18 +3,17 @@ from typing import Any

from prisma.enums import ContentType

from backend.blocks import get_block
from backend.blocks._base import BlockType
from backend.copilot.model import ChatSession
from backend.data.db_accessors import search

from .base import BaseTool, ToolResponseBase
from .models import (
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase
from backend.api.features.chat.tools.models import (
    BlockInfoSummary,
    BlockListResponse,
    ErrorResponse,
    NoResultsResponse,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search
from backend.blocks import get_block
from backend.blocks._base import BlockType

logger = logging.getLogger(__name__)

@@ -108,7 +107,7 @@ class FindBlockTool(BaseTool):

        try:
            # Search for blocks using hybrid search
            results, total = await search().unified_hybrid_search(
            results, total = await unified_hybrid_search(
                query=query,
                content_types=[ContentType.BLOCK],
                page=1,
@@ -4,15 +4,15 @@ from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.blocks._base import BlockType

from ._test_data import make_session
from .find_block import (
from backend.api.features.chat.tools.find_block import (
    COPILOT_EXCLUDED_BLOCK_IDS,
    COPILOT_EXCLUDED_BLOCK_TYPES,
    FindBlockTool,
)
from .models import BlockListResponse
from backend.api.features.chat.tools.models import BlockListResponse
from backend.blocks._base import BlockType

from ._test_data import make_session

_TEST_USER_ID = "test-user-find-block"

@@ -84,17 +84,13 @@ class TestFindBlockFiltering:
            "standard-block-id": standard_block,
        }.get(block_id)

        mock_search_db = MagicMock()
        mock_search_db.unified_hybrid_search = AsyncMock(
            return_value=(search_results, 2)
        )

        with patch(
            "backend.copilot.tools.find_block.search",
            return_value=mock_search_db,
            "backend.api.features.chat.tools.find_block.unified_hybrid_search",
            new_callable=AsyncMock,
            return_value=(search_results, 2),
        ):
            with patch(
                "backend.copilot.tools.find_block.get_block",
                "backend.api.features.chat.tools.find_block.get_block",
                side_effect=mock_get_block,
            ):
                tool = FindBlockTool()
@@ -132,17 +128,13 @@ class TestFindBlockFiltering:
            "normal-block-id": normal_block,
        }.get(block_id)

        mock_search_db = MagicMock()
        mock_search_db.unified_hybrid_search = AsyncMock(
            return_value=(search_results, 2)
        )

        with patch(
            "backend.copilot.tools.find_block.search",
            return_value=mock_search_db,
            "backend.api.features.chat.tools.find_block.unified_hybrid_search",
            new_callable=AsyncMock,
            return_value=(search_results, 2),
        ):
            with patch(
                "backend.copilot.tools.find_block.get_block",
                "backend.api.features.chat.tools.find_block.get_block",
                side_effect=mock_get_block,
            ):
                tool = FindBlockTool()
@@ -361,16 +353,12 @@ class TestFindBlockFiltering:
            for d in block_defs
        }

        mock_search_db = MagicMock()
        mock_search_db.unified_hybrid_search = AsyncMock(
            return_value=(search_results, len(search_results))
        )

        with patch(
            "backend.copilot.tools.find_block.search",
            return_value=mock_search_db,
            "backend.api.features.chat.tools.find_block.unified_hybrid_search",
            new_callable=AsyncMock,
            return_value=(search_results, len(search_results)),
        ), patch(
            "backend.copilot.tools.find_block.get_block",
            "backend.api.features.chat.tools.find_block.get_block",
            side_effect=lambda bid: mock_blocks.get(bid),
        ):
            tool = FindBlockTool()
@@ -2,7 +2,7 @@

from typing import Any

from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession

from .agent_search import search_agents
from .base import BaseTool
@@ -4,10 +4,13 @@ import logging
from pathlib import Path
from typing import Any

from backend.copilot.model import ChatSession

from .base import BaseTool
from .models import DocPageResponse, ErrorResponse, ToolResponseBase
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
    DocPageResponse,
    ErrorResponse,
    ToolResponseBase,
)

logger = logging.getLogger(__name__)
@@ -2,7 +2,7 @@

from datetime import datetime
from enum import Enum
from typing import Any, Literal
from typing import Any

from pydantic import BaseModel, Field

@@ -50,8 +50,6 @@ class ResponseType(str, Enum):
    # Feature request types
    FEATURE_REQUEST_SEARCH = "feature_request_search"
    FEATURE_REQUEST_CREATED = "feature_request_created"
    # Goal refinement
    SUGGESTED_GOAL = "suggested_goal"


# Base response model
@@ -298,22 +296,6 @@ class ClarificationNeededResponse(ToolResponseBase):
    questions: list[ClarifyingQuestion] = Field(default_factory=list)


class SuggestedGoalResponse(ToolResponseBase):
    """Response when the goal needs refinement with a suggested alternative."""

    type: ResponseType = ResponseType.SUGGESTED_GOAL
    suggested_goal: str = Field(description="The suggested alternative goal")
    reason: str = Field(
        default="", description="Why the original goal needs refinement"
    )
    original_goal: str = Field(
        default="", description="The user's original goal for context"
    )
    goal_type: Literal["vague", "unachievable"] = Field(
        default="vague", description="Type: 'vague' or 'unachievable'"
    )


# Documentation search models
class DocSearchResult(BaseModel):
    """A single documentation search result."""
@@ -504,6 +486,7 @@ class FeatureRequestInfo(BaseModel):
    id: str
    identifier: str
    title: str
    description: str | None = None


class FeatureRequestSearchResponse(ToolResponseBase):
@@ -5,12 +5,16 @@ from typing import Any

from pydantic import BaseModel, Field, field_validator

from backend.copilot.config import ChatConfig
from backend.copilot.model import ChatSession
from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled
from backend.data.db_accessors import graph_db, library_db, user_db
from backend.api.features.chat.config import ChatConfig
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tracking import (
    track_agent_run_success,
    track_agent_scheduled,
)
from backend.api.features.library import db as library_db
from backend.data.graph import GraphModel
from backend.data.model import CredentialsMetaInput
from backend.data.user import get_user_by_id
from backend.executor import utils as execution_utils
from backend.util.clients import get_scheduler_client
from backend.util.exceptions import DatabaseError, NotFoundError
@@ -196,7 +200,7 @@ class RunAgentTool(BaseTool):

        # Priority: library_agent_id if provided
        if has_library_id:
            library_agent = await library_db().get_library_agent(
            library_agent = await library_db.get_library_agent(
                params.library_agent_id, user_id
            )
            if not library_agent:
@@ -205,7 +209,9 @@ class RunAgentTool(BaseTool):
                    session_id=session_id,
                )
            # Get the graph from the library agent
            graph = await graph_db().get_graph(
            from backend.data.graph import get_graph

            graph = await get_graph(
                library_agent.graph_id,
                library_agent.graph_version,
                user_id=user_id,
@@ -516,7 +522,7 @@ class RunAgentTool(BaseTool):
        library_agent = await get_or_create_library_agent(graph, user_id)

        # Get user timezone
        user = await user_db().get_user_by_id(user_id)
        user = await get_user_by_id(user_id)
        user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone)

        # Create schedule
@@ -7,17 +7,20 @@ from typing import Any

from pydantic_core import PydanticUndefined

from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.find_block import (
    COPILOT_EXCLUDED_BLOCK_IDS,
    COPILOT_EXCLUDED_BLOCK_TYPES,
)
from backend.blocks import get_block
from backend.blocks._base import AnyBlockSchema
from backend.copilot.model import ChatSession
from backend.data.db_accessors import workspace_db
from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError

from .base import BaseTool
from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES
from .helpers import get_inputs_from_schema
from .models import (
    BlockDetails,
@@ -273,7 +276,7 @@ class RunBlockTool(BaseTool):

        try:
            # Get or create user's workspace for CoPilot file operations
            workspace = await workspace_db().get_or_create_workspace(user_id)
            workspace = await get_or_create_workspace(user_id)

            # Generate synthetic IDs for CoPilot context
            # Each chat session is treated as its own agent with one continuous run
@@ -4,16 +4,16 @@ from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.blocks._base import BlockType

from ._test_data import make_session
from .models import (
from backend.api.features.chat.tools.models import (
    BlockDetailsResponse,
    BlockOutputResponse,
    ErrorResponse,
    InputValidationErrorResponse,
)
from .run_block import RunBlockTool
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.blocks._base import BlockType

from ._test_data import make_session

_TEST_USER_ID = "test-user-run-block"

@@ -77,7 +77,7 @@ class TestRunBlockFiltering:
        input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT)

        with patch(
            "backend.copilot.tools.run_block.get_block",
            "backend.api.features.chat.tools.run_block.get_block",
            return_value=input_block,
        ):
            tool = RunBlockTool()
@@ -103,7 +103,7 @@ class TestRunBlockFiltering:
        )

        with patch(
            "backend.copilot.tools.run_block.get_block",
            "backend.api.features.chat.tools.run_block.get_block",
            return_value=smart_block,
        ):
            tool = RunBlockTool()
@@ -127,7 +127,7 @@ class TestRunBlockFiltering:
        )

        with patch(
            "backend.copilot.tools.run_block.get_block",
            "backend.api.features.chat.tools.run_block.get_block",
            return_value=standard_block,
        ):
            tool = RunBlockTool()
@@ -183,7 +183,7 @@ class TestRunBlockInputValidation:
        )

        with patch(
            "backend.copilot.tools.run_block.get_block",
            "backend.api.features.chat.tools.run_block.get_block",
            return_value=mock_block,
        ):
            tool = RunBlockTool()
@@ -222,7 +222,7 @@ class TestRunBlockInputValidation:
        )

        with patch(
            "backend.copilot.tools.run_block.get_block",
            "backend.api.features.chat.tools.run_block.get_block",
            return_value=mock_block,
        ):
            tool = RunBlockTool()
@@ -263,7 +263,7 @@ class TestRunBlockInputValidation:
        )

        with patch(
            "backend.copilot.tools.run_block.get_block",
            "backend.api.features.chat.tools.run_block.get_block",
            return_value=mock_block,
        ):
            tool = RunBlockTool()
@@ -302,19 +302,15 @@ class TestRunBlockInputValidation:

        mock_block.execute = mock_execute

        mock_workspace_db = MagicMock()
        mock_workspace_db.get_or_create_workspace = AsyncMock(
            return_value=MagicMock(id="test-workspace-id")
        )

        with (
            patch(
                "backend.copilot.tools.run_block.get_block",
                "backend.api.features.chat.tools.run_block.get_block",
                return_value=mock_block,
            ),
            patch(
                "backend.copilot.tools.run_block.workspace_db",
                return_value=mock_workspace_db,
                "backend.api.features.chat.tools.run_block.get_or_create_workspace",
                new_callable=AsyncMock,
                return_value=MagicMock(id="test-workspace-id"),
            ),
        ):
            tool = RunBlockTool()
@@ -348,7 +344,7 @@ class TestRunBlockInputValidation:
        )

        with patch(
            "backend.copilot.tools.run_block.get_block",
            "backend.api.features.chat.tools.run_block.get_block",
            return_value=mock_block,
        ):
            tool = RunBlockTool()
@@ -13,7 +13,6 @@ import logging
import os
import platform
import shutil
import signal

logger = logging.getLogger(__name__)

@@ -246,7 +245,6 @@ async def run_sandboxed(
        stderr=asyncio.subprocess.PIPE,
        cwd=cwd,
        env=safe_env,
        start_new_session=True,  # Own process group for clean kill
    )

    try:
@@ -257,18 +255,7 @@ async def run_sandboxed(
        stderr = stderr_bytes.decode("utf-8", errors="replace")
        return stdout, stderr, proc.returncode or 0, False
    except asyncio.TimeoutError:
        # Kill entire process group (bwrap + all children).
        # proc.kill() alone only kills the bwrap parent, leaving
        # children running until they finish naturally.
        try:
            os.killpg(proc.pid, signal.SIGKILL)
        except ProcessLookupError:
            pass  # Already exited
        except OSError as kill_err:
            logger.warning(
                "Failed to kill process group %d: %s", proc.pid, kill_err
            )
        # Always reap the subprocess regardless of killpg outcome.
        proc.kill()
        await proc.communicate()
        return "", f"Execution timed out after {timeout}s", -1, True
@@ -5,17 +5,16 @@ from typing import Any

from prisma.enums import ContentType

from backend.copilot.model import ChatSession
from backend.data.db_accessors import search

from .base import BaseTool
from .models import (
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
    DocSearchResult,
    DocSearchResultsResponse,
    ErrorResponse,
    NoResultsResponse,
    ToolResponseBase,
)
from backend.api.features.store.hybrid_search import unified_hybrid_search

logger = logging.getLogger(__name__)

@@ -118,7 +117,7 @@ class SearchDocsTool(BaseTool):

        try:
            # Search using hybrid search for DOCUMENTATION content type only
            results, total = await search().unified_hybrid_search(
            results, total = await unified_hybrid_search(
                query=query,
                content_types=[ContentType.DOCUMENTATION],
                page=1,
@@ -4,13 +4,13 @@ from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.api.features.chat.tools.models import BlockDetailsResponse
from backend.api.features.chat.tools.run_block import RunBlockTool
from backend.blocks._base import BlockType
from backend.data.model import CredentialsMetaInput
from backend.integrations.providers import ProviderName

from ._test_data import make_session
from .models import BlockDetailsResponse
from .run_block import RunBlockTool

_TEST_USER_ID = "test-user-run-block-details"

@@ -61,7 +61,7 @@ async def test_run_block_returns_details_when_no_input_provided():
    )

    with patch(
        "backend.copilot.tools.run_block.get_block",
        "backend.api.features.chat.tools.run_block.get_block",
        return_value=http_block,
    ):
        # Mock credentials check to return no missing credentials
@@ -120,7 +120,7 @@ async def test_run_block_returns_details_when_only_credentials_provided():
    }

    with patch(
        "backend.copilot.tools.run_block.get_block",
        "backend.api.features.chat.tools.run_block.get_block",
        return_value=mock,
    ):
        with patch.object(
@@ -3,8 +3,9 @@
import logging
from typing import Any

from backend.api.features.library import db as library_db
from backend.api.features.library import model as library_model
from backend.data.db_accessors import library_db, store_db
from backend.api.features.store import db as store_db
from backend.data.graph import GraphModel
from backend.data.model import (
    Credentials,
@@ -38,14 +39,13 @@ async def fetch_graph_from_store_slug(
    Raises:
        DatabaseError: If there's a database error during lookup.
    """
    sdb = store_db()
    try:
        store_agent = await sdb.get_store_agent_details(username, agent_name)
        store_agent = await store_db.get_store_agent_details(username, agent_name)
    except NotFoundError:
        return None, None

    # Get the graph from store listing version
    graph = await sdb.get_available_graph(
    graph = await store_db.get_available_graph(
        store_agent.store_listing_version_id, hide_nodes=False
    )
    return graph, store_agent
@@ -210,13 +210,13 @@ async def get_or_create_library_agent(
    Returns:
        LibraryAgent instance
    """
    existing = await library_db().get_library_agent_by_graph_id(
    existing = await library_db.get_library_agent_by_graph_id(
        graph_id=graph.id, user_id=user_id
    )
    if existing:
        return existing

    library_agents = await library_db().create_library_agent(
    library_agents = await library_db.create_library_agent(
        graph=graph,
        user_id=user_id,
        create_library_agents_for_sub_graphs=False,
@@ -6,12 +6,15 @@ from typing import Any
import aiohttp
import html2text

from backend.copilot.model import ChatSession
from backend.api.features.chat.model import ChatSession
from backend.api.features.chat.tools.base import BaseTool
from backend.api.features.chat.tools.models import (
    ErrorResponse,
    ToolResponseBase,
    WebFetchResponse,
)
from backend.util.request import Requests

from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, WebFetchResponse

logger = logging.getLogger(__name__)

# Limits
@@ -0,0 +1,626 @@
"""CoPilot tools for workspace file operations."""

import base64
import logging
from typing import Any, Optional

from pydantic import BaseModel

from backend.api.features.chat.model import ChatSession
from backend.data.workspace import get_or_create_workspace
from backend.util.settings import Config
from backend.util.virus_scanner import scan_content_safe
from backend.util.workspace import WorkspaceManager

from .base import BaseTool
from .models import ErrorResponse, ResponseType, ToolResponseBase

logger = logging.getLogger(__name__)


class WorkspaceFileInfoData(BaseModel):
    """Data model for workspace file information (not a response itself)."""

    file_id: str
    name: str
    path: str
    mime_type: str
    size_bytes: int


class WorkspaceFileListResponse(ToolResponseBase):
    """Response containing list of workspace files."""

    type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
    files: list[WorkspaceFileInfoData]
    total_count: int


class WorkspaceFileContentResponse(ToolResponseBase):
    """Response containing workspace file content (legacy, for small text files)."""

    type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
    file_id: str
    name: str
    path: str
    mime_type: str
    content_base64: str


class WorkspaceFileMetadataResponse(ToolResponseBase):
    """Response containing workspace file metadata and download URL (prevents context bloat)."""

    type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
    file_id: str
    name: str
    path: str
    mime_type: str
    size_bytes: int
    download_url: str
    preview: str | None = None  # First 500 chars for text files


class WorkspaceWriteResponse(ToolResponseBase):
    """Response after writing a file to workspace."""

    type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
    file_id: str
    name: str
    path: str
    size_bytes: int


class WorkspaceDeleteResponse(ToolResponseBase):
    """Response after deleting a file from workspace."""

    type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
    file_id: str
    success: bool


class ListWorkspaceFilesTool(BaseTool):
    """Tool for listing files in user's workspace."""

    @property
    def name(self) -> str:
        return "list_workspace_files"

    @property
    def description(self) -> str:
        return (
            "List files in the user's persistent workspace (cloud storage). "
            "These files survive across sessions. "
            "For ephemeral session files, use the SDK Read/Glob tools instead. "
            "Returns file names, paths, sizes, and metadata. "
            "Optionally filter by path prefix."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "path_prefix": {
                    "type": "string",
                    "description": (
                        "Optional path prefix to filter files "
                        "(e.g., '/documents/' to list only files in documents folder). "
                        "By default, only files from the current session are listed."
                    ),
                },
                "limit": {
                    "type": "integer",
                    "description": "Maximum number of files to return (default 50, max 100)",
                    "minimum": 1,
                    "maximum": 100,
                },
                "include_all_sessions": {
                    "type": "boolean",
                    "description": (
                        "If true, list files from all sessions. "
                        "Default is false (only current session's files)."
                    ),
                },
            },
            "required": [],
        }

    @property
    def requires_auth(self) -> bool:
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        session_id = session.session_id

        if not user_id:
            return ErrorResponse(
                message="Authentication required",
                session_id=session_id,
            )

        path_prefix: Optional[str] = kwargs.get("path_prefix")
        limit = min(kwargs.get("limit", 50), 100)
        include_all_sessions: bool = kwargs.get("include_all_sessions", False)

        try:
            workspace = await get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

            files = await manager.list_files(
                path=path_prefix,
                limit=limit,
                include_all_sessions=include_all_sessions,
            )
            total = await manager.get_file_count(
                path=path_prefix,
                include_all_sessions=include_all_sessions,
            )

            file_infos = [
                WorkspaceFileInfoData(
                    file_id=f.id,
                    name=f.name,
                    path=f.path,
                    mime_type=f.mimeType,
                    size_bytes=f.sizeBytes,
                )
                for f in files
            ]

            scope_msg = "all sessions" if include_all_sessions else "current session"
            return WorkspaceFileListResponse(
                files=file_infos,
                total_count=total,
                message=f"Found {len(files)} files in workspace ({scope_msg})",
                session_id=session_id,
            )

        except Exception as e:
            logger.error(f"Error listing workspace files: {e}", exc_info=True)
            return ErrorResponse(
                message=f"Failed to list workspace files: {str(e)}",
                error=str(e),
                session_id=session_id,
            )


class ReadWorkspaceFileTool(BaseTool):
    """Tool for reading file content from workspace."""

    # Size threshold for returning full content vs metadata+URL
    # Files larger than this return metadata with download URL to prevent context bloat
    MAX_INLINE_SIZE_BYTES = 32 * 1024  # 32KB
    # Preview size for text files
    PREVIEW_SIZE = 500

    @property
    def name(self) -> str:
        return "read_workspace_file"

    @property
    def description(self) -> str:
        return (
            "Read a file from the user's persistent workspace (cloud storage). "
            "These files survive across sessions. "
            "For ephemeral session files, use the SDK Read tool instead. "
            "Specify either file_id or path to identify the file. "
            "For small text files, returns content directly. "
            "For large or binary files, returns metadata and a download URL. "
            "Paths are scoped to the current session by default. "
            "Use /sessions/<session_id>/... for cross-session access."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "file_id": {
                    "type": "string",
                    "description": "The file's unique ID (from list_workspace_files)",
                },
                "path": {
                    "type": "string",
                    "description": (
                        "The virtual file path (e.g., '/documents/report.pdf'). "
                        "Scoped to current session by default."
                    ),
                },
                "force_download_url": {
                    "type": "boolean",
                    "description": (
                        "If true, always return metadata+URL instead of inline content. "
                        "Default is false (auto-selects based on file size/type)."
                    ),
                },
            },
            "required": [],  # At least one must be provided
        }

    @property
    def requires_auth(self) -> bool:
        return True

    def _is_text_mime_type(self, mime_type: str) -> bool:
        """Check if the MIME type is a text-based type."""
        text_types = [
            "text/",
            "application/json",
            "application/xml",
            "application/javascript",
            "application/x-python",
            "application/x-sh",
        ]
        return any(mime_type.startswith(t) for t in text_types)

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        session_id = session.session_id

        if not user_id:
            return ErrorResponse(
                message="Authentication required",
                session_id=session_id,
            )

        file_id: Optional[str] = kwargs.get("file_id")
        path: Optional[str] = kwargs.get("path")
        force_download_url: bool = kwargs.get("force_download_url", False)

        if not file_id and not path:
            return ErrorResponse(
                message="Please provide either file_id or path",
                session_id=session_id,
            )

        try:
            workspace = await get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

            # Get file info
            if file_id:
                file_info = await manager.get_file_info(file_id)
                if file_info is None:
                    return ErrorResponse(
                        message=f"File not found: {file_id}",
                        session_id=session_id,
                    )
                target_file_id = file_id
            else:
                # path is guaranteed to be non-None here due to the check above
                assert path is not None
                file_info = await manager.get_file_info_by_path(path)
                if file_info is None:
                    return ErrorResponse(
                        message=f"File not found at path: {path}",
                        session_id=session_id,
                    )
                target_file_id = file_info.id

            # Decide whether to return inline content or metadata+URL
            is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES
            is_text_file = self._is_text_mime_type(file_info.mimeType)

            # Return inline content for small text files (unless force_download_url)
            if is_small_file and is_text_file and not force_download_url:
                content = await manager.read_file_by_id(target_file_id)
                content_b64 = base64.b64encode(content).decode("utf-8")

                return WorkspaceFileContentResponse(
                    file_id=file_info.id,
                    name=file_info.name,
                    path=file_info.path,
                    mime_type=file_info.mimeType,
                    content_base64=content_b64,
                    message=f"Successfully read file: {file_info.name}",
                    session_id=session_id,
                )

            # Return metadata + workspace:// reference for large or binary files
            # This prevents context bloat (100KB file = ~133KB as base64)
            # Use workspace:// format so frontend urlTransform can add proxy prefix
            download_url = f"workspace://{target_file_id}"

            # Generate preview for text files
            preview: str | None = None
            if is_text_file:
                try:
                    content = await manager.read_file_by_id(target_file_id)
                    preview_text = content[: self.PREVIEW_SIZE].decode(
                        "utf-8", errors="replace"
                    )
                    if len(content) > self.PREVIEW_SIZE:
                        preview_text += "..."
                    preview = preview_text
                except Exception:
                    pass  # Preview is optional

            return WorkspaceFileMetadataResponse(
                file_id=file_info.id,
                name=file_info.name,
                path=file_info.path,
                mime_type=file_info.mimeType,
                size_bytes=file_info.sizeBytes,
                download_url=download_url,
                preview=preview,
                message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.",
                session_id=session_id,
            )

        except FileNotFoundError as e:
            return ErrorResponse(
                message=str(e),
                session_id=session_id,
            )
        except Exception as e:
            logger.error(f"Error reading workspace file: {e}", exc_info=True)
            return ErrorResponse(
                message=f"Failed to read workspace file: {str(e)}",
                error=str(e),
                session_id=session_id,
            )


class WriteWorkspaceFileTool(BaseTool):
    """Tool for writing files to workspace."""

    @property
    def name(self) -> str:
        return "write_workspace_file"

    @property
    def description(self) -> str:
        return (
            "Write or create a file in the user's persistent workspace (cloud storage). "
            "These files survive across sessions. "
            "For ephemeral session files, use the SDK Write tool instead. "
            "Provide the content as a base64-encoded string. "
            f"Maximum file size is {Config().max_file_size_mb}MB. "
            "Files are saved to the current session's folder by default. "
            "Use /sessions/<session_id>/... for cross-session access."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "filename": {
                    "type": "string",
                    "description": "Name for the file (e.g., 'report.pdf')",
                },
                "content_base64": {
                    "type": "string",
                    "description": "Base64-encoded file content",
                },
                "path": {
                    "type": "string",
                    "description": (
                        "Optional virtual path where to save the file "
                        "(e.g., '/documents/report.pdf'). "
                        "Defaults to '/{filename}'. Scoped to current session."
                    ),
                },
                "mime_type": {
                    "type": "string",
                    "description": (
                        "Optional MIME type of the file. "
                        "Auto-detected from filename if not provided."
                    ),
                },
                "overwrite": {
                    "type": "boolean",
                    "description": "Whether to overwrite if file exists at path (default: false)",
                },
            },
            "required": ["filename", "content_base64"],
        }

    @property
    def requires_auth(self) -> bool:
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        session_id = session.session_id

        if not user_id:
            return ErrorResponse(
                message="Authentication required",
                session_id=session_id,
            )

        filename: str = kwargs.get("filename", "")
        content_b64: str = kwargs.get("content_base64", "")
        path: Optional[str] = kwargs.get("path")
        mime_type: Optional[str] = kwargs.get("mime_type")
        overwrite: bool = kwargs.get("overwrite", False)

        if not filename:
            return ErrorResponse(
                message="Please provide a filename",
                session_id=session_id,
            )

        if not content_b64:
            return ErrorResponse(
                message="Please provide content_base64",
                session_id=session_id,
            )

        # Decode content
        try:
            content = base64.b64decode(content_b64)
        except Exception:
            return ErrorResponse(
                message="Invalid base64-encoded content",
                session_id=session_id,
            )

        # Check size
        max_file_size = Config().max_file_size_mb * 1024 * 1024
        if len(content) > max_file_size:
            return ErrorResponse(
                message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
                session_id=session_id,
            )

        try:
            # Virus scan
            await scan_content_safe(content, filename=filename)

            workspace = await get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

            file_record = await manager.write_file(
                content=content,
                filename=filename,
                path=path,
                mime_type=mime_type,
                overwrite=overwrite,
            )

            return WorkspaceWriteResponse(
                file_id=file_record.id,
                name=file_record.name,
                path=file_record.path,
                size_bytes=file_record.sizeBytes,
                message=f"Successfully wrote file: {file_record.name}",
                session_id=session_id,
            )

        except ValueError as e:
            return ErrorResponse(
                message=str(e),
                session_id=session_id,
            )
        except Exception as e:
            logger.error(f"Error writing workspace file: {e}", exc_info=True)
            return ErrorResponse(
                message=f"Failed to write workspace file: {str(e)}",
                error=str(e),
                session_id=session_id,
            )


class DeleteWorkspaceFileTool(BaseTool):
    """Tool for deleting files from workspace."""

    @property
    def name(self) -> str:
        return "delete_workspace_file"

    @property
    def description(self) -> str:
        return (
            "Delete a file from the user's persistent workspace (cloud storage). "
            "Specify either file_id or path to identify the file. "
            "Paths are scoped to the current session by default. "
            "Use /sessions/<session_id>/... for cross-session access."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "file_id": {
                    "type": "string",
                    "description": "The file's unique ID (from list_workspace_files)",
                },
                "path": {
                    "type": "string",
                    "description": (
                        "The virtual file path (e.g., '/documents/report.pdf'). "
                        "Scoped to current session by default."
                    ),
                },
            },
            "required": [],  # At least one must be provided
        }

    @property
    def requires_auth(self) -> bool:
        return True

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs,
    ) -> ToolResponseBase:
        session_id = session.session_id

        if not user_id:
            return ErrorResponse(
                message="Authentication required",
                session_id=session_id,
            )

        file_id: Optional[str] = kwargs.get("file_id")
        path: Optional[str] = kwargs.get("path")

        if not file_id and not path:
            return ErrorResponse(
                message="Please provide either file_id or path",
                session_id=session_id,
            )

        try:
            workspace = await get_or_create_workspace(user_id)
            # Pass session_id for session-scoped file access
            manager = WorkspaceManager(user_id, workspace.id, session_id)

            # Determine the file_id to delete
            target_file_id: str
            if file_id:
                target_file_id = file_id
            else:
                # path is guaranteed to be non-None here due to the check above
                assert path is not None
                file_info = await manager.get_file_info_by_path(path)
                if file_info is None:
                    return ErrorResponse(
                        message=f"File not found at path: {path}",
                        session_id=session_id,
                    )
                target_file_id = file_info.id

            success = await manager.delete_file(target_file_id)

            if not success:
                return ErrorResponse(
                    message=f"File not found: {target_file_id}",
                    session_id=session_id,
                )

            return WorkspaceDeleteResponse(
                file_id=target_file_id,
                success=True,
                message="File deleted successfully",
                session_id=session_id,
            )

        except Exception as e:
            logger.error(f"Error deleting workspace file: {e}", exc_info=True)
            return ErrorResponse(
                message=f"Failed to delete workspace file: {str(e)}",
                error=str(e),
                session_id=session_id,
            )
@@ -393,6 +393,7 @@ async def get_creators(
@router.get(
    "/creator/{username}",
    summary="Get creator details",
    operation_id="getV2GetCreatorDetails",
    tags=["store", "public"],
    response_model=store_model.CreatorDetails,
)

@@ -11,7 +11,7 @@ import fastapi
from autogpt_libs.auth.dependencies import get_user_id, requires_user
from fastapi.responses import Response

from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file
from backend.data.workspace import get_workspace, get_workspace_file
from backend.util.workspace_storage import get_workspace_storage


@@ -44,11 +44,11 @@ router = fastapi.APIRouter(
)


def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
def _create_streaming_response(content: bytes, file) -> Response:
    """Create a streaming response for file content."""
    return Response(
        content=content,
        media_type=file.mime_type,
        media_type=file.mimeType,
        headers={
            "Content-Disposition": _sanitize_filename_for_header(file.name),
            "Content-Length": str(len(content)),
@@ -56,7 +56,7 @@ def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response:
    )


async def _create_file_download_response(file: WorkspaceFile) -> Response:
async def _create_file_download_response(file) -> Response:
    """
    Create a download response for a workspace file.

@@ -66,33 +66,33 @@ async def _create_file_download_response(file: WorkspaceFile) -> Response:
    storage = await get_workspace_storage()

    # For local storage, stream the file directly
    if file.storage_path.startswith("local://"):
        content = await storage.retrieve(file.storage_path)
    if file.storagePath.startswith("local://"):
        content = await storage.retrieve(file.storagePath)
        return _create_streaming_response(content, file)

    # For GCS, try to redirect to signed URL, fall back to streaming
    try:
        url = await storage.get_download_url(file.storage_path, expires_in=300)
        url = await storage.get_download_url(file.storagePath, expires_in=300)
        # If we got back an API path (fallback), stream directly instead
        if url.startswith("/api/"):
            content = await storage.retrieve(file.storage_path)
            content = await storage.retrieve(file.storagePath)
            return _create_streaming_response(content, file)
        return fastapi.responses.RedirectResponse(url=url, status_code=302)
    except Exception as e:
        # Log the signed URL failure with context
        logger.error(
            f"Failed to get signed URL for file {file.id} "
            f"(storagePath={file.storage_path}): {e}",
            f"(storagePath={file.storagePath}): {e}",
            exc_info=True,
        )
        # Fall back to streaming directly from GCS
        try:
            content = await storage.retrieve(file.storage_path)
            content = await storage.retrieve(file.storagePath)
            return _create_streaming_response(content, file)
        except Exception as fallback_error:
            logger.error(
                f"Fallback streaming also failed for file {file.id} "
                f"(storagePath={file.storage_path}): {fallback_error}",
                f"(storagePath={file.storagePath}): {fallback_error}",
                exc_info=True,
            )
            raise
@@ -18,6 +18,7 @@ from prisma.errors import PrismaError
|
||||
|
||||
import backend.api.features.admin.credit_admin_routes
|
||||
import backend.api.features.admin.execution_analytics_routes
|
||||
import backend.api.features.admin.llm_routes
|
||||
import backend.api.features.admin.store_admin_routes
|
||||
import backend.api.features.builder
|
||||
import backend.api.features.builder.routes
|
||||
@@ -39,13 +40,15 @@ import backend.data.db
|
||||
import backend.data.graph
|
||||
import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.server.v2.llm.routes as public_llm_routes
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.copilot.completion_consumer import (
|
||||
from backend.api.features.chat.completion_consumer import (
|
||||
start_completion_consumer,
|
||||
stop_completion_consumer,
|
||||
)
|
||||
from backend.data import llm_registry
|
||||
from backend.data.block_cost_config import refresh_llm_costs
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.monitoring.instrumentation import instrument_fastapi
|
||||
@@ -116,11 +119,27 @@ async def lifespan_context(app: fastapi.FastAPI):

     AutoRegistry.patch_integrations()

+    # Refresh LLM registry before initializing blocks so blocks can use registry data
+    await llm_registry.refresh_llm_registry()
+    await refresh_llm_costs()
+
+    # Clear block schema caches so they're regenerated with updated discriminator_mapping
+    from backend.blocks._base import BlockSchema
+
+    BlockSchema.clear_all_schema_caches()
+
     await backend.data.block.initialize_blocks()

     await backend.data.user.migrate_and_encrypt_user_integrations()
     await backend.data.graph.fix_llm_provider_credentials()
-    await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
+    # migrate_llm_models uses registry default model
+    from backend.blocks.llm import LlmModel
+
+    default_model_slug = llm_registry.get_default_model_slug()
+    if default_model_slug:
+        await backend.data.graph.migrate_llm_models(LlmModel(default_model_slug))
+    else:
+        logger.warning("Skipping LLM model migration: no default model available")
     await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()

     # Start chat completion consumer for Redis Streams notifications
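The startup ordering in this hunk is the point: the registry and cost refresh run before `initialize_blocks()`, so block schemas are generated from fresh registry data. A minimal, self-contained sketch of that refresh-before-initialize lifespan shape (the coroutines here are stand-ins, not the project's real functions):

import contextlib

import fastapi


async def refresh_registry() -> None:
    print("registry refreshed")  # stand-in for refresh_llm_registry()


async def initialize_blocks() -> None:
    print("blocks initialized")  # stand-in for block initialization


@contextlib.asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    # Refresh first so anything built during block init sees current data.
    await refresh_registry()
    await initialize_blocks()
    yield  # the application serves requests here


app = fastapi.FastAPI(lifespan=lifespan)
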
@@ -322,6 +341,16 @@ app.include_router(
     tags=["v2", "executions", "review"],
     prefix="/api/review",
 )
+app.include_router(
+    backend.api.features.admin.llm_routes.router,
+    tags=["v2", "admin", "llm"],
+    prefix="/api/llm/admin",
+)
+app.include_router(
+    public_llm_routes.router,
+    tags=["v2", "llm"],
+    prefix="/api",
+)
 app.include_router(
     backend.api.features.library.routes.router, tags=["v2"], prefix="/api/library"
 )

@@ -79,11 +79,49 @@ async def event_broadcaster(manager: ConnectionManager):
                     payload=notification.payload,
                 )

-        await asyncio.gather(execution_worker(), notification_worker())
+        # Track registry pubsub for cleanup
+        registry_pubsub = None
+
+        async def registry_refresh_worker():
+            """Listen for LLM registry refresh notifications and broadcast to all clients."""
+            nonlocal registry_pubsub
+            from backend.data.llm_registry import REGISTRY_REFRESH_CHANNEL
+            from backend.data.redis_client import connect_async
+
+            redis = await connect_async()
+            registry_pubsub = redis.pubsub()
+            await registry_pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
+            logger.info(
+                "Subscribed to LLM registry refresh notifications for WebSocket broadcast"
+            )
+
+            async for message in registry_pubsub.listen():
+                if (
+                    message["type"] == "message"
+                    and message["channel"] == REGISTRY_REFRESH_CHANNEL
+                ):
+                    logger.info(
+                        "Broadcasting LLM registry refresh to all WebSocket clients"
+                    )
+                    await manager.broadcast_to_all(
+                        method=WSMethod.NOTIFICATION,
+                        data={
+                            "type": "LLM_REGISTRY_REFRESH",
+                            "event": "registry_updated",
+                        },
+                    )
+
+        await asyncio.gather(
+            execution_worker(),
+            notification_worker(),
+            registry_refresh_worker(),
+        )
     finally:
         # Ensure PubSub connections are closed on any exit to prevent leaks
         await execution_bus.close()
         await notification_bus.close()
+        if registry_pubsub:
+            await registry_pubsub.close()
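For reference, the worker above follows the standard redis-py asyncio pub/sub shape: subscribe once, then filter the output of `listen()` down to real messages. A self-contained sketch under that assumption (the channel name and payload are illustrative, not the project's actual constants or helpers):

import asyncio

import redis.asyncio as aioredis

REGISTRY_REFRESH_CHANNEL = "llm_registry_refresh"  # hypothetical channel name


async def listen_for_refresh(redis_url: str = "redis://localhost:6379") -> None:
    client = aioredis.from_url(redis_url, decode_responses=True)
    pubsub = client.pubsub()
    await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
    try:
        async for message in pubsub.listen():
            # listen() also yields subscribe confirmations; keep real messages only
            if message["type"] == "message":
                print("registry refreshed:", message["data"])
    finally:
        await pubsub.close()  # mirror the cleanup in the finally block above


if __name__ == "__main__":
    asyncio.run(listen_for_refresh())
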


 async def authenticate_websocket(websocket: WebSocket) -> str:

@@ -38,9 +38,7 @@ def main(**kwargs):

     from backend.api.rest_api import AgentServer
     from backend.api.ws_api import WebsocketServer
-    from backend.copilot.executor.manager import CoPilotExecutor
-    from backend.data.db_manager import DatabaseManager
-    from backend.executor import ExecutionManager, Scheduler
+    from backend.executor import DatabaseManager, ExecutionManager, Scheduler
     from backend.notifications import NotificationManager

     run_processes(
@@ -50,7 +48,6 @@ def main(**kwargs):
         WebsocketServer(),
         AgentServer(),
         ExecutionManager(),
-        CoPilotExecutor(),
         **kwargs,
     )


@@ -134,7 +134,26 @@ class BlockInfo(BaseModel):


 class BlockSchema(BaseModel):
-    cached_jsonschema: ClassVar[dict[str, Any]]
+    cached_jsonschema: ClassVar[dict[str, Any] | None] = None
+
+    @classmethod
+    def clear_schema_cache(cls) -> None:
+        """Clear the cached JSON schema for this class."""
+        # Use None (not {}) as the "unset" sentinel so an empty dict is never
+        # mistaken for an already-generated schema, which would prevent regeneration
+        cls.cached_jsonschema = None  # type: ignore
+
+    @staticmethod
+    def clear_all_schema_caches() -> None:
+        """Clear cached JSON schemas for all BlockSchema subclasses."""
+
+        def clear_recursive(cls: type) -> None:
+            """Recursively clear cache for class and all subclasses."""
+            if hasattr(cls, "clear_schema_cache"):
+                cls.clear_schema_cache()
+            for subclass in cls.__subclasses__():
+                clear_recursive(subclass)
+
+        clear_recursive(BlockSchema)

     @classmethod
     def jsonschema(cls) -> dict[str, Any]:
@@ -225,7 +244,8 @@ class BlockSchema(BaseModel):
         super().__pydantic_init_subclass__(**kwargs)

         # Reset cached JSON schema to prevent inheriting it from parent class
-        cls.cached_jsonschema = {}
+        # Use None (not {}) as the "unset" sentinel so a stale value is never
+        # mistaken for an already-generated schema
+        cls.cached_jsonschema = None

         credentials_fields = cls.get_credentials_fields()

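Two details in this hunk are easy to miss: `None` (not `{}`) marks an ungenerated cache so the guard can be an explicit `is None` check, and clearing has to walk `__subclasses__()` recursively because each subclass carries its own class-level cache. A stripped-down, self-contained sketch of both ideas (class and attribute names are illustrative, not the project's):

from typing import Any, ClassVar


class SchemaBase:
    cached: ClassVar[dict[str, Any] | None] = None  # None means "not generated yet"

    @classmethod
    def schema(cls) -> dict[str, Any]:
        if cls.cached is None:  # explicit sentinel check, not truthiness
            cls.cached = {"title": cls.__name__}
        return cls.cached

    @staticmethod
    def clear_all() -> None:
        def clear(klass: type) -> None:
            if issubclass(klass, SchemaBase):
                klass.cached = None
            for sub in klass.__subclasses__():
                clear(sub)  # depth-first over the whole subclass tree

        clear(SchemaBase)


class ChildSchema(SchemaBase):
    pass


ChildSchema.schema()    # populates ChildSchema.cached
SchemaBase.clear_all()  # resets the base class and every subclass
assert ChildSchema.cached is None
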
@@ -7,7 +7,6 @@ from backend.blocks._base import (
     BlockSchemaOutput,
 )
 from backend.blocks.llm import (
-    DEFAULT_LLM_MODEL,
     TEST_CREDENTIALS,
     TEST_CREDENTIALS_INPUT,
     AIBlockBase,
@@ -16,6 +15,7 @@ from backend.blocks.llm import (
     LlmModel,
     LLMResponse,
     llm_call,
+    llm_model_schema_extra,
 )
 from backend.data.model import APIKeyCredentials, NodeExecutionStats, SchemaField

@@ -50,9 +50,10 @@ class AIConditionBlock(AIBlockBase):
     )
     model: LlmModel = SchemaField(
         title="LLM Model",
-        default=DEFAULT_LLM_MODEL,
+        default_factory=LlmModel.default,
         description="The language model to use for evaluating the condition.",
         advanced=False,
+        json_schema_extra=llm_model_schema_extra(),
     )
     credentials: AICredentials = AICredentialsField()

@@ -82,7 +83,7 @@ class AIConditionBlock(AIBlockBase):
                 "condition": "the input is an email address",
                 "yes_value": "Valid email",
                 "no_value": "Not an email",
-                "model": DEFAULT_LLM_MODEL,
+                "model": LlmModel.default(),
                 "credentials": TEST_CREDENTIALS_INPUT,
             },
             test_credentials=TEST_CREDENTIALS,

@@ -4,16 +4,18 @@ import logging
 import re
 import secrets
 from abc import ABC
-from enum import Enum, EnumMeta
+from dataclasses import dataclass
+from enum import Enum
 from json import JSONDecodeError
-from typing import Any, Iterable, List, Literal, NamedTuple, Optional
+from typing import Any, Iterable, List, Literal, Optional

 import anthropic
 import ollama
 import openai
 from anthropic.types import ToolParam
 from groq import AsyncGroq
-from pydantic import BaseModel, SecretStr
+from pydantic import BaseModel, GetCoreSchemaHandler, SecretStr
+from pydantic_core import CoreSchema, core_schema

 from backend.blocks._base import (
     Block,
@@ -22,6 +24,8 @@ from backend.blocks._base import (
     BlockSchemaInput,
     BlockSchemaOutput,
 )
+from backend.data import llm_registry
+from backend.data.llm_registry import ModelMetadata
 from backend.data.model import (
     APIKeyCredentials,
     CredentialsField,
@@ -66,114 +70,123 @@ TEST_CREDENTIALS_INPUT = {


 def AICredentialsField() -> AICredentials:
+    """
+    Returns a CredentialsField for LLM providers.
+    The discriminator_mapping will be refreshed when the schema is generated
+    if it's empty, ensuring the LLM registry is loaded.
+    """
+    # Get the mapping now - it may be empty initially, but will be refreshed
+    # when the schema is generated via CredentialsMetaInput._add_json_schema_extra
+    mapping = llm_registry.get_llm_discriminator_mapping()
+
     return CredentialsField(
         description="API key for the LLM provider.",
         discriminator="model",
-        discriminator_mapping={
-            model.value: model.metadata.provider for model in LlmModel
-        },
+        discriminator_mapping=mapping,  # May be empty initially, refreshed later
     )


-class ModelMetadata(NamedTuple):
-    provider: str
-    context_window: int
-    max_output_tokens: int | None
-    display_name: str
-    provider_name: str
-    creator_name: str
-    price_tier: Literal[1, 2, 3]
+def llm_model_schema_extra() -> dict[str, Any]:
+    return {"options": llm_registry.get_llm_model_schema_options()}


-class LlmModelMeta(EnumMeta):
-    pass
+class LlmModelMeta(type):
+    """
+    Metaclass for LlmModel that enables attribute-style access to dynamic models.
+
+    This allows code like `LlmModel.GPT4O` to work by converting the attribute
+    name to a slug format:
+    - GPT4O -> gpt-4o
+    - GPT4O_MINI -> gpt-4o-mini
+    - CLAUDE_3_5_SONNET -> claude-3-5-sonnet
+    """
+
+    def __getattr__(cls, name: str):
+        # Don't intercept private/dunder attributes
+        if name.startswith("_"):
+            raise AttributeError(f"type object 'LlmModel' has no attribute '{name}'")
+
+        # Convert attribute name to slug format:
+        # 1. Lowercase: GPT4O -> gpt4o
+        # 2. Underscores to hyphens: GPT4O_MINI -> gpt4o-mini
+        slug = name.lower().replace("_", "-")
+
+        # Check for exact match in registry first (e.g., "o1" stays "o1")
+        registry_slugs = llm_registry.get_dynamic_model_slugs()
+        if slug in registry_slugs:
+            return cls(slug)
+
+        # If no exact match, try inserting hyphen between letter and digit
+        # e.g., gpt4o -> gpt-4o
+        transformed_slug = re.sub(r"([a-z])(\d)", r"\1-\2", slug)
+        return cls(transformed_slug)
+
+    def __iter__(cls):
+        """Iterate over all models from the registry.
+
+        Yields LlmModel instances for each model in the dynamic registry.
+        Used by __get_pydantic_json_schema__ to build model metadata.
+        """
+        for model in llm_registry.iter_dynamic_models():
+            yield cls(model.slug)


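The conversion rule in `__getattr__` is easiest to pin down with concrete cases. Here are the same three steps isolated as a pure function so they can be exercised without the registry; `known_slugs` stands in for the registry lookup:

import re


def attr_to_slug(name: str, known_slugs: set[str]) -> str:
    slug = name.lower().replace("_", "-")   # GPT4O_MINI -> gpt4o-mini
    if slug in known_slugs:                 # exact match wins (e.g. "o1")
        return slug
    return re.sub(r"([a-z])(\d)", r"\1-\2", slug)  # gpt4o-mini -> gpt-4o-mini


known = {"o1", "gpt-4o", "gpt-4o-mini", "claude-3-5-sonnet"}
assert attr_to_slug("O1", known) == "o1"                         # exact hit, no hyphen inserted
assert attr_to_slug("GPT4O", known) == "gpt-4o"                  # hyphen inserted before digit
assert attr_to_slug("GPT4O_MINI", known) == "gpt-4o-mini"
assert attr_to_slug("CLAUDE_3_5_SONNET", known) == "claude-3-5-sonnet"
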
-class LlmModel(str, Enum, metaclass=LlmModelMeta):
-    # OpenAI models
-    O3_MINI = "o3-mini"
-    O3 = "o3-2025-04-16"
-    O1 = "o1"
-    O1_MINI = "o1-mini"
-    # GPT-5 models
-    GPT5_2 = "gpt-5.2-2025-12-11"
-    GPT5_1 = "gpt-5.1-2025-11-13"
-    GPT5 = "gpt-5-2025-08-07"
-    GPT5_MINI = "gpt-5-mini-2025-08-07"
-    GPT5_NANO = "gpt-5-nano-2025-08-07"
-    GPT5_CHAT = "gpt-5-chat-latest"
-    GPT41 = "gpt-4.1-2025-04-14"
-    GPT41_MINI = "gpt-4.1-mini-2025-04-14"
-    GPT4O_MINI = "gpt-4o-mini"
-    GPT4O = "gpt-4o"
-    GPT4_TURBO = "gpt-4-turbo"
-    GPT3_5_TURBO = "gpt-3.5-turbo"
-    # Anthropic models
-    CLAUDE_4_1_OPUS = "claude-opus-4-1-20250805"
-    CLAUDE_4_OPUS = "claude-opus-4-20250514"
-    CLAUDE_4_SONNET = "claude-sonnet-4-20250514"
-    CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
-    CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
-    CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
-    CLAUDE_4_6_OPUS = "claude-opus-4-6"
-    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
-    # AI/ML API models
-    AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
-    AIML_API_LLAMA3_1_70B = "nvidia/llama-3.1-nemotron-70b-instruct"
-    AIML_API_LLAMA3_3_70B = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
-    AIML_API_META_LLAMA_3_1_70B = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
-    AIML_API_LLAMA_3_2_3B = "meta-llama/Llama-3.2-3B-Instruct-Turbo"
-    # Groq models
-    LLAMA3_3_70B = "llama-3.3-70b-versatile"
-    LLAMA3_1_8B = "llama-3.1-8b-instant"
-    # Ollama models
-    OLLAMA_LLAMA3_3 = "llama3.3"
-    OLLAMA_LLAMA3_2 = "llama3.2"
-    OLLAMA_LLAMA3_8B = "llama3"
-    OLLAMA_LLAMA3_405B = "llama3.1:405b"
-    OLLAMA_DOLPHIN = "dolphin-mistral:latest"
-    # OpenRouter models
-    OPENAI_GPT_OSS_120B = "openai/gpt-oss-120b"
-    OPENAI_GPT_OSS_20B = "openai/gpt-oss-20b"
-    GEMINI_2_5_PRO = "google/gemini-2.5-pro-preview-03-25"
-    GEMINI_3_PRO_PREVIEW = "google/gemini-3-pro-preview"
-    GEMINI_2_5_FLASH = "google/gemini-2.5-flash"
-    GEMINI_2_0_FLASH = "google/gemini-2.0-flash-001"
-    GEMINI_2_5_FLASH_LITE_PREVIEW = "google/gemini-2.5-flash-lite-preview-06-17"
-    GEMINI_2_0_FLASH_LITE = "google/gemini-2.0-flash-lite-001"
-    MISTRAL_NEMO = "mistralai/mistral-nemo"
-    COHERE_COMMAND_R_08_2024 = "cohere/command-r-08-2024"
-    COHERE_COMMAND_R_PLUS_08_2024 = "cohere/command-r-plus-08-2024"
-    DEEPSEEK_CHAT = "deepseek/deepseek-chat"  # Actually: DeepSeek V3
-    DEEPSEEK_R1_0528 = "deepseek/deepseek-r1-0528"
-    PERPLEXITY_SONAR = "perplexity/sonar"
-    PERPLEXITY_SONAR_PRO = "perplexity/sonar-pro"
-    PERPLEXITY_SONAR_DEEP_RESEARCH = "perplexity/sonar-deep-research"
-    NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B = "nousresearch/hermes-3-llama-3.1-405b"
-    NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B = "nousresearch/hermes-3-llama-3.1-70b"
-    AMAZON_NOVA_LITE_V1 = "amazon/nova-lite-v1"
-    AMAZON_NOVA_MICRO_V1 = "amazon/nova-micro-v1"
-    AMAZON_NOVA_PRO_V1 = "amazon/nova-pro-v1"
-    MICROSOFT_WIZARDLM_2_8X22B = "microsoft/wizardlm-2-8x22b"
-    GRYPHE_MYTHOMAX_L2_13B = "gryphe/mythomax-l2-13b"
-    META_LLAMA_4_SCOUT = "meta-llama/llama-4-scout"
-    META_LLAMA_4_MAVERICK = "meta-llama/llama-4-maverick"
-    GROK_4 = "x-ai/grok-4"
-    GROK_4_FAST = "x-ai/grok-4-fast"
-    GROK_4_1_FAST = "x-ai/grok-4.1-fast"
-    GROK_CODE_FAST_1 = "x-ai/grok-code-fast-1"
-    KIMI_K2 = "moonshotai/kimi-k2"
-    QWEN3_235B_A22B_THINKING = "qwen/qwen3-235b-a22b-thinking-2507"
-    QWEN3_CODER = "qwen/qwen3-coder"
-    # Llama API models
-    LLAMA_API_LLAMA_4_SCOUT = "Llama-4-Scout-17B-16E-Instruct-FP8"
-    LLAMA_API_LLAMA4_MAVERICK = "Llama-4-Maverick-17B-128E-Instruct-FP8"
-    LLAMA_API_LLAMA3_3_8B = "Llama-3.3-8B-Instruct"
-    LLAMA_API_LLAMA3_3_70B = "Llama-3.3-70B-Instruct"
-    # v0 by Vercel models
-    V0_1_5_MD = "v0-1.5-md"
-    V0_1_5_LG = "v0-1.5-lg"
-    V0_1_0_MD = "v0-1.0-md"
+class LlmModel(str, metaclass=LlmModelMeta):
+    """
+    Dynamic LLM model type that accepts any model slug from the registry.
+
+    This is a string subclass (not an Enum) that allows any model slug value.
+    All models are managed via the LLM Registry in the database.
+
+    Usage:
+        model = LlmModel("gpt-4o")  # Direct construction
+        model = LlmModel.GPT4O  # Attribute access (converted to "gpt-4o")
+        model.value  # Returns the slug string
+        model.provider  # Returns the provider from registry
+    """
+
+    def __new__(cls, value: str):
+        if isinstance(value, LlmModel):
+            return value
+        return str.__new__(cls, value)
+
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls, source_type: Any, handler: GetCoreSchemaHandler
+    ) -> CoreSchema:
+        """
+        Tell Pydantic how to validate LlmModel.
+
+        Accepts strings and converts them to LlmModel instances.
+        """
+        return core_schema.no_info_after_validator_function(
+            cls,  # The validator function (LlmModel constructor)
+            core_schema.str_schema(),  # Accept string input
+            serialization=core_schema.to_string_ser_schema(),  # Serialize as string
+        )
+
+    @property
+    def value(self) -> str:
+        """Return the model slug (for compatibility with enum-style access)."""
+        return str(self)
+
+    @classmethod
+    def default(cls) -> "LlmModel":
+        """
+        Get the default model from the registry.
+
+        Returns the recommended model if set, otherwise gpt-4o if available
+        and enabled, otherwise the first enabled model from the registry.
+        Falls back to "gpt-4o" if registry is empty (e.g., at module import time).
+        """
+        from backend.data.llm_registry import get_default_model_slug
+
+        slug = get_default_model_slug()
+        if slug is None:
+            # Registry is empty (e.g., at module import time before DB connection).
+            # Fall back to gpt-4o for backward compatibility.
+            slug = "gpt-4o"
+        return cls(slug)

     @classmethod
     def __get_pydantic_json_schema__(cls, schema, handler):
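What the core-schema hook buys in practice: any pydantic model using such a field accepts a plain string, coerces it into the subclass, and serializes it back out as a string. A minimal sketch of that behavior using a generic `str` subclass (not the project's class):

from typing import Any

from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema


class Slug(str):
    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        return core_schema.no_info_after_validator_function(
            cls,  # run the constructor after the plain-str validation
            core_schema.str_schema(),
            serialization=core_schema.to_string_ser_schema(),
        )


class Config(BaseModel):
    model: Slug


cfg = Config(model="gpt-4o")        # a plain str input is accepted
assert isinstance(cfg.model, Slug)  # ...and converted to the subclass
assert cfg.model_dump() == {"model": "gpt-4o"}  # serializes as a string
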
@@ -181,7 +194,15 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
         llm_model_metadata = {}
         for model in cls:
             model_name = model.value
-            metadata = model.metadata
+            # Skip disabled models - only show enabled models in the picker
+            if not llm_registry.is_model_enabled(model_name):
+                continue
+            # Use registry directly with None check to gracefully handle
+            # missing metadata during startup/import before registry is populated
+            metadata = llm_registry.get_llm_model_metadata(model_name)
+            if metadata is None:
+                # Skip models without metadata (registry not yet populated)
+                continue
             llm_model_metadata[model_name] = {
-                "creator": metadata.creator_name,
+                "creator_name": metadata.creator_name,
@@ -197,7 +218,12 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):

     @property
     def metadata(self) -> ModelMetadata:
-        return MODEL_METADATA[self]
+        metadata = llm_registry.get_llm_model_metadata(self.value)
+        if metadata:
+            return metadata
+        raise ValueError(
+            f"Missing metadata for model: {self.value}. Model not found in LLM registry."
+        )

     @property
     def provider(self) -> str:
@@ -212,300 +238,125 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
         return self.metadata.max_output_tokens


-MODEL_METADATA = {
-    # https://platform.openai.com/docs/models
-    LlmModel.O3: ModelMetadata("openai", 200000, 100000, "O3", "OpenAI", "OpenAI", 2),
-    LlmModel.O3_MINI: ModelMetadata(
-        "openai", 200000, 100000, "O3 Mini", "OpenAI", "OpenAI", 1
-    ),  # o3-mini-2025-01-31
-    LlmModel.O1: ModelMetadata(
-        "openai", 200000, 100000, "O1", "OpenAI", "OpenAI", 3
-    ),  # o1-2024-12-17
-    LlmModel.O1_MINI: ModelMetadata(
-        "openai", 128000, 65536, "O1 Mini", "OpenAI", "OpenAI", 2
-    ),  # o1-mini-2024-09-12
-    # GPT-5 models
-    LlmModel.GPT5_2: ModelMetadata(
-        "openai", 400000, 128000, "GPT-5.2", "OpenAI", "OpenAI", 3
-    ),
-    LlmModel.GPT5_1: ModelMetadata(
-        "openai", 400000, 128000, "GPT-5.1", "OpenAI", "OpenAI", 2
-    ),
-    LlmModel.GPT5: ModelMetadata(
-        "openai", 400000, 128000, "GPT-5", "OpenAI", "OpenAI", 1
-    ),
-    LlmModel.GPT5_MINI: ModelMetadata(
-        "openai", 400000, 128000, "GPT-5 Mini", "OpenAI", "OpenAI", 1
-    ),
-    LlmModel.GPT5_NANO: ModelMetadata(
-        "openai", 400000, 128000, "GPT-5 Nano", "OpenAI", "OpenAI", 1
-    ),
-    LlmModel.GPT5_CHAT: ModelMetadata(
-        "openai", 400000, 16384, "GPT-5 Chat Latest", "OpenAI", "OpenAI", 2
-    ),
-    LlmModel.GPT41: ModelMetadata(
-        "openai", 1047576, 32768, "GPT-4.1", "OpenAI", "OpenAI", 1
-    ),
-    LlmModel.GPT41_MINI: ModelMetadata(
-        "openai", 1047576, 32768, "GPT-4.1 Mini", "OpenAI", "OpenAI", 1
-    ),
-    LlmModel.GPT4O_MINI: ModelMetadata(
-        "openai", 128000, 16384, "GPT-4o Mini", "OpenAI", "OpenAI", 1
-    ),  # gpt-4o-mini-2024-07-18
-    LlmModel.GPT4O: ModelMetadata(
-        "openai", 128000, 16384, "GPT-4o", "OpenAI", "OpenAI", 2
-    ),  # gpt-4o-2024-08-06
-    LlmModel.GPT4_TURBO: ModelMetadata(
-        "openai", 128000, 4096, "GPT-4 Turbo", "OpenAI", "OpenAI", 3
-    ),  # gpt-4-turbo-2024-04-09
-    LlmModel.GPT3_5_TURBO: ModelMetadata(
-        "openai", 16385, 4096, "GPT-3.5 Turbo", "OpenAI", "OpenAI", 1
-    ),  # gpt-3.5-turbo-0125
-    # https://docs.anthropic.com/en/docs/about-claude/models
-    LlmModel.CLAUDE_4_1_OPUS: ModelMetadata(
-        "anthropic", 200000, 32000, "Claude Opus 4.1", "Anthropic", "Anthropic", 3
-    ),  # claude-opus-4-1-20250805
-    LlmModel.CLAUDE_4_OPUS: ModelMetadata(
-        "anthropic", 200000, 32000, "Claude Opus 4", "Anthropic", "Anthropic", 3
-    ),  # claude-4-opus-20250514
-    LlmModel.CLAUDE_4_SONNET: ModelMetadata(
-        "anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2
-    ),  # claude-4-sonnet-20250514
-    LlmModel.CLAUDE_4_6_OPUS: ModelMetadata(
-        "anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3
-    ),  # claude-opus-4-6
-    LlmModel.CLAUDE_4_5_OPUS: ModelMetadata(
-        "anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3
-    ),  # claude-opus-4-5-20251101
-    LlmModel.CLAUDE_4_5_SONNET: ModelMetadata(
-        "anthropic", 200000, 64000, "Claude Sonnet 4.5", "Anthropic", "Anthropic", 3
-    ),  # claude-sonnet-4-5-20250929
-    LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
-        "anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
-    ),  # claude-haiku-4-5-20251001
-    LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
-        "anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
-    ),  # claude-3-haiku-20240307
-    # https://docs.aimlapi.com/api-overview/model-database/text-models
-    LlmModel.AIML_API_QWEN2_5_72B: ModelMetadata(
-        "aiml_api", 32000, 8000, "Qwen 2.5 72B Instruct Turbo", "AI/ML", "Qwen", 1
-    ),
-    LlmModel.AIML_API_LLAMA3_1_70B: ModelMetadata(
-        "aiml_api",
-        128000,
-        40000,
-        "Llama 3.1 Nemotron 70B Instruct",
-        "AI/ML",
-        "Nvidia",
-        1,
-    ),
-    LlmModel.AIML_API_LLAMA3_3_70B: ModelMetadata(
-        "aiml_api", 128000, None, "Llama 3.3 70B Instruct Turbo", "AI/ML", "Meta", 1
-    ),
-    LlmModel.AIML_API_META_LLAMA_3_1_70B: ModelMetadata(
-        "aiml_api", 131000, 2000, "Llama 3.1 70B Instruct Turbo", "AI/ML", "Meta", 1
-    ),
-    LlmModel.AIML_API_LLAMA_3_2_3B: ModelMetadata(
-        "aiml_api", 128000, None, "Llama 3.2 3B Instruct Turbo", "AI/ML", "Meta", 1
-    ),
-    # https://console.groq.com/docs/models
-    LlmModel.LLAMA3_3_70B: ModelMetadata(
-        "groq", 128000, 32768, "Llama 3.3 70B Versatile", "Groq", "Meta", 1
-    ),
-    LlmModel.LLAMA3_1_8B: ModelMetadata(
-        "groq", 128000, 8192, "Llama 3.1 8B Instant", "Groq", "Meta", 1
-    ),
-    # https://ollama.com/library
-    LlmModel.OLLAMA_LLAMA3_3: ModelMetadata(
-        "ollama", 8192, None, "Llama 3.3", "Ollama", "Meta", 1
-    ),
-    LlmModel.OLLAMA_LLAMA3_2: ModelMetadata(
-        "ollama", 8192, None, "Llama 3.2", "Ollama", "Meta", 1
-    ),
-    LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata(
-        "ollama", 8192, None, "Llama 3", "Ollama", "Meta", 1
-    ),
-    LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata(
-        "ollama", 8192, None, "Llama 3.1 405B", "Ollama", "Meta", 1
-    ),
-    LlmModel.OLLAMA_DOLPHIN: ModelMetadata(
-        "ollama", 32768, None, "Dolphin Mistral Latest", "Ollama", "Mistral AI", 1
-    ),
-    # https://openrouter.ai/models
-    LlmModel.GEMINI_2_5_PRO: ModelMetadata(
-        "open_router",
-        1050000,
-        8192,
-        "Gemini 2.5 Pro Preview 03.25",
-        "OpenRouter",
-        "Google",
-        2,
-    ),
-    LlmModel.GEMINI_3_PRO_PREVIEW: ModelMetadata(
-        "open_router", 1048576, 65535, "Gemini 3 Pro Preview", "OpenRouter", "Google", 2
-    ),
-    LlmModel.GEMINI_2_5_FLASH: ModelMetadata(
-        "open_router", 1048576, 65535, "Gemini 2.5 Flash", "OpenRouter", "Google", 1
-    ),
-    LlmModel.GEMINI_2_0_FLASH: ModelMetadata(
-        "open_router", 1048576, 8192, "Gemini 2.0 Flash 001", "OpenRouter", "Google", 1
-    ),
-    LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: ModelMetadata(
-        "open_router",
-        1048576,
-        65535,
-        "Gemini 2.5 Flash Lite Preview 06.17",
-        "OpenRouter",
-        "Google",
-        1,
-    ),
-    LlmModel.GEMINI_2_0_FLASH_LITE: ModelMetadata(
-        "open_router",
-        1048576,
-        8192,
-        "Gemini 2.0 Flash Lite 001",
-        "OpenRouter",
-        "Google",
-        1,
-    ),
-    LlmModel.MISTRAL_NEMO: ModelMetadata(
-        "open_router", 128000, 4096, "Mistral Nemo", "OpenRouter", "Mistral AI", 1
-    ),
-    LlmModel.COHERE_COMMAND_R_08_2024: ModelMetadata(
-        "open_router", 128000, 4096, "Command R 08.2024", "OpenRouter", "Cohere", 1
-    ),
-    LlmModel.COHERE_COMMAND_R_PLUS_08_2024: ModelMetadata(
-        "open_router", 128000, 4096, "Command R Plus 08.2024", "OpenRouter", "Cohere", 2
-    ),
-    LlmModel.DEEPSEEK_CHAT: ModelMetadata(
-        "open_router", 64000, 2048, "DeepSeek Chat", "OpenRouter", "DeepSeek", 1
-    ),
-    LlmModel.DEEPSEEK_R1_0528: ModelMetadata(
-        "open_router", 163840, 163840, "DeepSeek R1 0528", "OpenRouter", "DeepSeek", 1
-    ),
-    LlmModel.PERPLEXITY_SONAR: ModelMetadata(
-        "open_router", 127000, 8000, "Sonar", "OpenRouter", "Perplexity", 1
-    ),
-    LlmModel.PERPLEXITY_SONAR_PRO: ModelMetadata(
-        "open_router", 200000, 8000, "Sonar Pro", "OpenRouter", "Perplexity", 2
-    ),
-    LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: ModelMetadata(
-        "open_router",
-        128000,
-        16000,
-        "Sonar Deep Research",
-        "OpenRouter",
-        "Perplexity",
-        3,
-    ),
-    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: ModelMetadata(
-        "open_router",
-        131000,
-        4096,
-        "Hermes 3 Llama 3.1 405B",
-        "OpenRouter",
-        "Nous Research",
-        1,
-    ),
-    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: ModelMetadata(
-        "open_router",
-        12288,
-        12288,
-        "Hermes 3 Llama 3.1 70B",
-        "OpenRouter",
-        "Nous Research",
-        1,
-    ),
-    LlmModel.OPENAI_GPT_OSS_120B: ModelMetadata(
-        "open_router", 131072, 131072, "GPT-OSS 120B", "OpenRouter", "OpenAI", 1
-    ),
-    LlmModel.OPENAI_GPT_OSS_20B: ModelMetadata(
-        "open_router", 131072, 32768, "GPT-OSS 20B", "OpenRouter", "OpenAI", 1
-    ),
-    LlmModel.AMAZON_NOVA_LITE_V1: ModelMetadata(
-        "open_router", 300000, 5120, "Nova Lite V1", "OpenRouter", "Amazon", 1
-    ),
-    LlmModel.AMAZON_NOVA_MICRO_V1: ModelMetadata(
-        "open_router", 128000, 5120, "Nova Micro V1", "OpenRouter", "Amazon", 1
-    ),
-    LlmModel.AMAZON_NOVA_PRO_V1: ModelMetadata(
-        "open_router", 300000, 5120, "Nova Pro V1", "OpenRouter", "Amazon", 1
-    ),
-    LlmModel.MICROSOFT_WIZARDLM_2_8X22B: ModelMetadata(
-        "open_router", 65536, 4096, "WizardLM 2 8x22B", "OpenRouter", "Microsoft", 1
-    ),
-    LlmModel.GRYPHE_MYTHOMAX_L2_13B: ModelMetadata(
-        "open_router", 4096, 4096, "MythoMax L2 13B", "OpenRouter", "Gryphe", 1
-    ),
-    LlmModel.META_LLAMA_4_SCOUT: ModelMetadata(
-        "open_router", 131072, 131072, "Llama 4 Scout", "OpenRouter", "Meta", 1
-    ),
-    LlmModel.META_LLAMA_4_MAVERICK: ModelMetadata(
-        "open_router", 1048576, 1000000, "Llama 4 Maverick", "OpenRouter", "Meta", 1
-    ),
-    LlmModel.GROK_4: ModelMetadata(
-        "open_router", 256000, 256000, "Grok 4", "OpenRouter", "xAI", 3
-    ),
-    LlmModel.GROK_4_FAST: ModelMetadata(
-        "open_router", 2000000, 30000, "Grok 4 Fast", "OpenRouter", "xAI", 1
-    ),
-    LlmModel.GROK_4_1_FAST: ModelMetadata(
-        "open_router", 2000000, 30000, "Grok 4.1 Fast", "OpenRouter", "xAI", 1
-    ),
-    LlmModel.GROK_CODE_FAST_1: ModelMetadata(
-        "open_router", 256000, 10000, "Grok Code Fast 1", "OpenRouter", "xAI", 1
-    ),
-    LlmModel.KIMI_K2: ModelMetadata(
-        "open_router", 131000, 131000, "Kimi K2", "OpenRouter", "Moonshot AI", 1
-    ),
-    LlmModel.QWEN3_235B_A22B_THINKING: ModelMetadata(
-        "open_router",
-        262144,
-        262144,
-        "Qwen 3 235B A22B Thinking 2507",
-        "OpenRouter",
-        "Qwen",
-        1,
-    ),
-    LlmModel.QWEN3_CODER: ModelMetadata(
-        "open_router", 262144, 262144, "Qwen 3 Coder", "OpenRouter", "Qwen", 3
-    ),
-    # Llama API models
-    LlmModel.LLAMA_API_LLAMA_4_SCOUT: ModelMetadata(
-        "llama_api",
-        128000,
-        4028,
-        "Llama 4 Scout 17B 16E Instruct FP8",
-        "Llama API",
-        "Meta",
-        1,
-    ),
-    LlmModel.LLAMA_API_LLAMA4_MAVERICK: ModelMetadata(
-        "llama_api",
-        128000,
-        4028,
-        "Llama 4 Maverick 17B 128E Instruct FP8",
-        "Llama API",
-        "Meta",
-        1,
-    ),
-    LlmModel.LLAMA_API_LLAMA3_3_8B: ModelMetadata(
-        "llama_api", 128000, 4028, "Llama 3.3 8B Instruct", "Llama API", "Meta", 1
-    ),
-    LlmModel.LLAMA_API_LLAMA3_3_70B: ModelMetadata(
-        "llama_api", 128000, 4028, "Llama 3.3 70B Instruct", "Llama API", "Meta", 1
-    ),
-    # v0 by Vercel models
-    LlmModel.V0_1_5_MD: ModelMetadata("v0", 128000, 64000, "v0 1.5 MD", "V0", "V0", 1),
-    LlmModel.V0_1_5_LG: ModelMetadata("v0", 512000, 64000, "v0 1.5 LG", "V0", "V0", 1),
-    LlmModel.V0_1_0_MD: ModelMetadata("v0", 128000, 64000, "v0 1.0 MD", "V0", "V0", 1),
-}
-DEFAULT_LLM_MODEL = LlmModel.GPT5_2
-
-for model in LlmModel:
-    if model not in MODEL_METADATA:
-        raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
+# Default model constant for backward compatibility
+# Uses the dynamic registry to get the default model
+DEFAULT_LLM_MODEL = LlmModel.default()
+
+
+class ModelUnavailableError(ValueError):
+    """Raised when a requested LLM model cannot be resolved for use."""
+
+    pass
+
+
+@dataclass
+class ResolvedModel:
+    """Result of resolving a model for an LLM call."""
+
+    slug: str  # The actual model slug to use (may differ from requested if fallback)
+    provider: str
+    context_window: int
+    max_output_tokens: int
+    used_fallback: bool = False
+    original_slug: str | None = None  # Set if fallback was used
+
+
+async def resolve_model_for_call(llm_model: LlmModel) -> ResolvedModel:
+    """
+    Resolve a model for use in an LLM call.
+
+    Handles:
+    - Checking if the model exists in the registry
+    - Falling back to an enabled model from the same provider if disabled
+    - Refreshing the registry cache if model not found (with DB access)
+
+    Args:
+        llm_model: The requested LlmModel
+
+    Returns:
+        ResolvedModel with all necessary metadata for the call
+
+    Raises:
+        ModelUnavailableError: If model cannot be resolved (not found, disabled with no fallback)
+    """
+    from backend.data.llm_registry import (
+        get_fallback_model_for_disabled,
+        get_model_info,
+    )
+
+    model_info = get_model_info(llm_model.value)
+
+    # Case 1: Model found and disabled - try fallback
+    if model_info and not model_info.is_enabled:
+        fallback = get_fallback_model_for_disabled(llm_model.value)
+        if fallback:
+            logger.warning(
+                f"Model '{llm_model.value}' is disabled. Using fallback "
+                f"'{fallback.slug}' from same provider ({fallback.metadata.provider})."
+            )
+            return ResolvedModel(
+                slug=fallback.slug,
+                provider=fallback.metadata.provider,
+                context_window=fallback.metadata.context_window,
+                max_output_tokens=fallback.metadata.max_output_tokens or 2**15,
+                used_fallback=True,
+                original_slug=llm_model.value,
+            )
+        raise ModelUnavailableError(
+            f"Model '{llm_model.value}' is disabled and no fallback from the same "
+            f"provider is available. Enable the model or select a different one."
+        )
+
+    # Case 2: Model found and enabled - use it directly
+    if model_info:
+        return ResolvedModel(
+            slug=llm_model.value,
+            provider=model_info.metadata.provider,
+            context_window=model_info.metadata.context_window,
+            max_output_tokens=model_info.metadata.max_output_tokens or 2**15,
+        )
+
+    # Case 3: Model not in registry - try refresh if DB available
+    logger.warning(f"Model '{llm_model.value}' not found in registry cache")
+
+    from backend.data.db import is_connected
+
+    if not is_connected():
+        raise ModelUnavailableError(
+            f"Model '{llm_model.value}' not found in registry. "
+            f"The registry may need to be refreshed via the admin UI."
+        )
+
+    # Try refreshing the registry
+    try:
+        logger.info(f"Refreshing LLM registry for model '{llm_model.value}'")
+        await llm_registry.refresh_llm_registry()
+    except Exception as e:
+        raise ModelUnavailableError(
+            f"Model '{llm_model.value}' not found and registry refresh failed: {e}"
+        ) from e
+
+    # Check again after refresh
+    model_info = get_model_info(llm_model.value)
+    if not model_info:
+        raise ModelUnavailableError(
+            f"Model '{llm_model.value}' not found in registry. "
+            f"Add it via the admin UI at /admin/llms."
+        )
+
+    if not model_info.is_enabled:
+        raise ModelUnavailableError(
+            f"Model '{llm_model.value}' exists but is disabled. "
+            f"Enable it via the admin UI at /admin/llms."
+        )
+
+    logger.info(f"Model '{llm_model.value}' loaded after registry refresh")
+    return ResolvedModel(
+        slug=llm_model.value,
+        provider=model_info.metadata.provider,
+        context_window=model_info.metadata.context_window,
+        max_output_tokens=model_info.metadata.max_output_tokens or 2**15,
+    )
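A hedged sketch of how a caller sits on top of this resolution flow: resolve first, branch on `used_fallback` for telemetry, and treat `ModelUnavailableError` as a user-facing failure. `send_request` is a stand-in for a real provider client, not project code:

async def send_request(**kwargs):
    return kwargs  # stand-in for an actual provider API call


async def call_with_resolution(llm_model, prompt):
    try:
        resolved = await resolve_model_for_call(llm_model)
    except ModelUnavailableError as exc:
        # Fail fast with a clear message instead of sending a doomed API call.
        raise RuntimeError(f"Cannot run prompt: {exc}") from exc
    if resolved.used_fallback:
        # Telemetry hook: the requested slug was silently remapped.
        print(f"fallback: {resolved.original_slug} -> {resolved.slug}")
    return await send_request(
        model=resolved.slug,
        prompt=prompt,
        max_tokens=resolved.max_output_tokens,
    )
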


 class ToolCall(BaseModel):
@@ -531,12 +382,12 @@ class LLMResponse(BaseModel):

 def convert_openai_tool_fmt_to_anthropic(
     openai_tools: list[dict] | None = None,
-) -> Iterable[ToolParam] | anthropic.Omit:
+) -> Iterable[ToolParam] | anthropic.NotGiven:
     """
     Convert OpenAI tool format to Anthropic tool format.
     """
     if not openai_tools or len(openai_tools) == 0:
-        return anthropic.omit
+        return anthropic.NOT_GIVEN

     anthropic_tools = []
     for tool in openai_tools:
@@ -598,7 +449,12 @@ def get_parallel_tool_calls_param(
     llm_model: LlmModel, parallel_tool_calls: bool | None
 ) -> bool | openai.Omit:
     """Get the appropriate parallel_tool_calls parameter for OpenAI-compatible APIs."""
-    if llm_model.startswith("o") or parallel_tool_calls is None:
+    # Check for o-series models (o1, o1-mini, o3-mini, etc.) which don't support
+    # parallel tool calls. Handle both bare slugs ("o1-mini") and provider-prefixed
+    # slugs ("openai/o1-mini"). The pattern matches "o" followed by a digit at the
+    # start of the string or after a "/" separator.
+    is_o_series = re.search(r"(^|/)o\d", llm_model) is not None
+    if is_o_series or parallel_tool_calls is None:
         return openai.omit
     return parallel_tool_calls

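The boundary regex is the subtle part of this hunk; a few assertions pin down what it does and does not match:

import re

O_SERIES = re.compile(r"(^|/)o\d")

assert O_SERIES.search("o1") is not None              # bare slug
assert O_SERIES.search("o3-mini") is not None         # bare slug with suffix
assert O_SERIES.search("openai/o1-mini") is not None  # provider-prefixed slug
assert O_SERIES.search("gpt-4o") is None              # "o" not at a boundary
assert O_SERIES.search("ollama/llama3") is None       # "o" not followed by a digit
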
@@ -634,15 +490,22 @@ async def llm_call(
     - prompt_tokens: The number of tokens used in the prompt.
     - completion_tokens: The number of tokens used in the completion.
     """
-    provider = llm_model.metadata.provider
-    context_window = llm_model.context_window
+    # Resolve the model - handles disabled models, fallbacks, and cache misses
+    resolved = await resolve_model_for_call(llm_model)
+
+    model_to_use = resolved.slug
+    provider = resolved.provider
+    context_window = resolved.context_window
+    model_max_output = resolved.max_output_tokens
+
+    # Create effective model for model-specific parameter resolution (e.g., o-series check)
+    effective_model = LlmModel(model_to_use)
+
     if compress_prompt_to_fit:
         result = await compress_context(
             messages=prompt,
-            target_tokens=llm_model.context_window // 2,
+            target_tokens=context_window // 2,
             client=None,  # Truncation-only, no LLM summarization
             reserve=0,  # Caller handles response token budget separately
         )
         if result.error:
             logger.warning(
@@ -653,7 +516,7 @@ async def llm_call(

     # Calculate available tokens based on context window and input length
     estimated_input_tokens = estimate_token_count(prompt)
-    model_max_output = llm_model.max_output_tokens or int(2**15)
+    # model_max_output already set above
     user_max = max_tokens or model_max_output
     available_tokens = max(context_window - estimated_input_tokens, 0)
     max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
@@ -664,14 +527,14 @@ async def llm_call(
         response_format = None

         parallel_tool_calls = get_parallel_tool_calls_param(
-            llm_model, parallel_tool_calls
+            effective_model, parallel_tool_calls
         )

         if force_json_output:
             response_format = {"type": "json_object"}

         response = await oai_client.chat.completions.create(
-            model=llm_model.value,
+            model=model_to_use,
             messages=prompt,  # type: ignore
             response_format=response_format,  # type: ignore
             max_completion_tokens=max_tokens,
@@ -718,7 +581,7 @@ async def llm_call(
         )
         try:
             resp = await client.messages.create(
-                model=llm_model.value,
+                model=model_to_use,
                 system=sysprompt,
                 messages=messages,
                 max_tokens=max_tokens,
@@ -782,7 +645,7 @@ async def llm_call(
         client = AsyncGroq(api_key=credentials.api_key.get_secret_value())
         response_format = {"type": "json_object"} if force_json_output else None
         response = await client.chat.completions.create(
-            model=llm_model.value,
+            model=model_to_use,
             messages=prompt,  # type: ignore
             response_format=response_format,  # type: ignore
             max_tokens=max_tokens,
@@ -804,7 +667,7 @@ async def llm_call(
         sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
         usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
         response = await client.generate(
-            model=llm_model.value,
+            model=model_to_use,
             prompt=f"{sys_messages}\n\n{usr_messages}",
             stream=False,
             options={"num_ctx": max_tokens},
@@ -826,7 +689,7 @@ async def llm_call(
         )

         parallel_tool_calls_param = get_parallel_tool_calls_param(
-            llm_model, parallel_tool_calls
+            effective_model, parallel_tool_calls
         )

         response = await client.chat.completions.create(
@@ -834,7 +697,7 @@ async def llm_call(
                 "HTTP-Referer": "https://agpt.co",
                 "X-Title": "AutoGPT",
             },
-            model=llm_model.value,
+            model=model_to_use,
             messages=prompt,  # type: ignore
             max_tokens=max_tokens,
             tools=tools_param,  # type: ignore
@@ -868,7 +731,7 @@ async def llm_call(
         )

         parallel_tool_calls_param = get_parallel_tool_calls_param(
-            llm_model, parallel_tool_calls
+            effective_model, parallel_tool_calls
         )

         response = await client.chat.completions.create(
@@ -876,7 +739,7 @@ async def llm_call(
                 "HTTP-Referer": "https://agpt.co",
                 "X-Title": "AutoGPT",
             },
-            model=llm_model.value,
+            model=model_to_use,
             messages=prompt,  # type: ignore
             max_tokens=max_tokens,
             tools=tools_param,  # type: ignore
@@ -903,7 +766,7 @@ async def llm_call(
             reasoning=reasoning,
         )
     elif provider == "aiml_api":
-        client = openai.OpenAI(
+        client = openai.AsyncOpenAI(
             base_url="https://api.aimlapi.com/v2",
             api_key=credentials.api_key.get_secret_value(),
             default_headers={
@@ -913,8 +776,8 @@ async def llm_call(
             },
         )

-        completion = client.chat.completions.create(
-            model=llm_model.value,
+        completion = await client.chat.completions.create(
+            model=model_to_use,
             messages=prompt,  # type: ignore
             max_tokens=max_tokens,
         )
@@ -942,11 +805,11 @@ async def llm_call(
             response_format = {"type": "json_object"}

         parallel_tool_calls_param = get_parallel_tool_calls_param(
-            llm_model, parallel_tool_calls
+            effective_model, parallel_tool_calls
        )

         response = await client.chat.completions.create(
-            model=llm_model.value,
+            model=model_to_use,
             messages=prompt,  # type: ignore
             response_format=response_format,  # type: ignore
             max_tokens=max_tokens,
@@ -997,9 +860,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
     )
     model: LlmModel = SchemaField(
         title="LLM Model",
-        default=DEFAULT_LLM_MODEL,
+        default_factory=LlmModel.default,
         description="The language model to use for answering the prompt.",
         advanced=False,
+        json_schema_extra=llm_model_schema_extra(),
     )
     force_json_output: bool = SchemaField(
         title="Restrict LLM to pure JSON output",
@@ -1062,7 +926,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
             input_schema=AIStructuredResponseGeneratorBlock.Input,
             output_schema=AIStructuredResponseGeneratorBlock.Output,
             test_input={
-                "model": DEFAULT_LLM_MODEL,
+                "model": "gpt-4o",  # Using string value - enum accepts any model slug dynamically
                 "credentials": TEST_CREDENTIALS_INPUT,
                 "expected_format": {
                     "key1": "value1",
@@ -1428,9 +1292,10 @@ class AITextGeneratorBlock(AIBlockBase):
     )
     model: LlmModel = SchemaField(
         title="LLM Model",
-        default=DEFAULT_LLM_MODEL,
+        default_factory=LlmModel.default,
         description="The language model to use for answering the prompt.",
         advanced=False,
+        json_schema_extra=llm_model_schema_extra(),
     )
     credentials: AICredentials = AICredentialsField()
     sys_prompt: str = SchemaField(
@@ -1524,8 +1389,9 @@ class AITextSummarizerBlock(AIBlockBase):
     )
     model: LlmModel = SchemaField(
         title="LLM Model",
-        default=DEFAULT_LLM_MODEL,
+        default_factory=LlmModel.default,
         description="The language model to use for summarizing the text.",
+        json_schema_extra=llm_model_schema_extra(),
     )
     focus: str = SchemaField(
         title="Focus",
@@ -1741,8 +1607,9 @@ class AIConversationBlock(AIBlockBase):
     )
     model: LlmModel = SchemaField(
         title="LLM Model",
-        default=DEFAULT_LLM_MODEL,
+        default_factory=LlmModel.default,
         description="The language model to use for the conversation.",
+        json_schema_extra=llm_model_schema_extra(),
     )
     credentials: AICredentials = AICredentialsField()
     max_tokens: int | None = SchemaField(
@@ -1779,7 +1646,7 @@ class AIConversationBlock(AIBlockBase):
                     },
                     {"role": "user", "content": "Where was it played?"},
                 ],
-                "model": DEFAULT_LLM_MODEL,
+                "model": "gpt-4o",  # Using string value - enum accepts any model slug dynamically
                 "credentials": TEST_CREDENTIALS_INPUT,
             },
             test_credentials=TEST_CREDENTIALS,
@@ -1842,9 +1709,10 @@ class AIListGeneratorBlock(AIBlockBase):
     )
     model: LlmModel = SchemaField(
         title="LLM Model",
-        default=DEFAULT_LLM_MODEL,
+        default_factory=LlmModel.default,
        description="The language model to use for generating the list.",
         advanced=True,
+        json_schema_extra=llm_model_schema_extra(),
     )
     credentials: AICredentials = AICredentialsField()
     max_retries: int = SchemaField(
@@ -1899,7 +1767,7 @@ class AIListGeneratorBlock(AIBlockBase):
                     "drawing explorers to uncover its mysteries. Each planet showcases the limitless possibilities of "
                     "fictional worlds."
                 ),
-                "model": DEFAULT_LLM_MODEL,
+                "model": "gpt-4o",  # Using string value - enum accepts any model slug dynamically
                 "credentials": TEST_CREDENTIALS_INPUT,
                 "max_retries": 3,
                 "force_json_output": False,

@@ -226,9 +226,10 @@ class SmartDecisionMakerBlock(Block):
     )
     model: llm.LlmModel = SchemaField(
         title="LLM Model",
-        default=llm.DEFAULT_LLM_MODEL,
+        default_factory=llm.LlmModel.default,
         description="The language model to use for answering the prompt.",
         advanced=False,
+        json_schema_extra=llm.llm_model_schema_extra(),
     )
     credentials: llm.AICredentials = llm.AICredentialsField()
     multiple_tool_calls: bool = SchemaField(

@@ -10,13 +10,13 @@ import stagehand.main
 from stagehand import Stagehand

 from backend.blocks.llm import (
-    MODEL_METADATA,
     AICredentials,
     AICredentialsField,
     LlmModel,
     ModelMetadata,
 )
 from backend.blocks.stagehand._config import stagehand as stagehand_provider
+from backend.data import llm_registry
 from backend.sdk import (
     APIKeyCredentials,
     Block,
@@ -91,7 +91,7 @@ class StagehandRecommendedLlmModel(str, Enum):
         Returns the provider name for the model in the required format for Stagehand:
         provider/model_name
         """
-        model_metadata = MODEL_METADATA[LlmModel(self.value)]
+        model_metadata = self.metadata
         model_name = self.value

         if len(model_name.split("/")) == 1 and not self.value.startswith(
@@ -102,24 +102,28 @@ class StagehandRecommendedLlmModel(str, Enum):
            ), "Logic failed and open_router provider attempted to be prepended to model name! in stagehand/_config.py"
             model_name = f"{model_metadata.provider}/{model_name}"

-        logger.error(f"Model name: {model_name}")
+        logger.debug(f"Model name: {model_name}")
         return model_name

     @property
     def provider(self) -> str:
-        return MODEL_METADATA[LlmModel(self.value)].provider
+        return self.metadata.provider

     @property
     def metadata(self) -> ModelMetadata:
-        return MODEL_METADATA[LlmModel(self.value)]
+        metadata = llm_registry.get_llm_model_metadata(self.value)
+        if metadata:
+            return metadata
+        # Fallback to LlmModel enum if registry lookup fails
+        return LlmModel(self.value).metadata

     @property
     def context_window(self) -> int:
-        return MODEL_METADATA[LlmModel(self.value)].context_window
+        return self.metadata.context_window

     @property
     def max_output_tokens(self) -> int | None:
-        return MODEL_METADATA[LlmModel(self.value)].max_output_tokens
+        return self.metadata.max_output_tokens


 class StagehandObserveBlock(Block):

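The prefixing rule used by `model_name` reduces to: leave names that already carry a namespace alone, otherwise prepend the provider. A pure-function sketch of that rule (illustrative only; it omits the open_router assertion in the real property):

def qualify_model_name(provider: str, model_name: str) -> str:
    if "/" in model_name:
        return model_name  # already in "provider/model" form
    return f"{provider}/{model_name}"


assert qualify_model_name("openai", "gpt-4o") == "openai/gpt-4o"
assert qualify_model_name("anthropic", "claude-3-haiku-20240307") == "anthropic/claude-3-haiku-20240307"
assert qualify_model_name("open_router", "google/gemini-2.5-flash") == "google/gemini-2.5-flash"
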
@@ -1,7 +1,6 @@
 import logging
-import os

 import pytest
 import pytest_asyncio
 from dotenv import load_dotenv

@@ -28,54 +27,6 @@ async def server():
     yield server


-@pytest.fixture
-def test_user_id() -> str:
-    """Test user ID fixture."""
-    return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
-
-
-@pytest.fixture
-def admin_user_id() -> str:
-    """Admin user ID fixture."""
-    return "4e53486c-cf57-477e-ba2a-cb02dc828e1b"
-
-
-@pytest.fixture
-def target_user_id() -> str:
-    """Target user ID fixture."""
-    return "5e53486c-cf57-477e-ba2a-cb02dc828e1c"
-
-
-@pytest.fixture
-async def setup_test_user(test_user_id):
-    """Create test user in database before tests."""
-    from backend.data.user import get_or_create_user
-
-    # Create the test user in the database using JWT token format
-    user_data = {
-        "sub": test_user_id,
-        "email": "test@example.com",
-        "user_metadata": {"name": "Test User"},
-    }
-    await get_or_create_user(user_data)
-    return test_user_id
-
-
-@pytest.fixture
-async def setup_admin_user(admin_user_id):
-    """Create admin user in database before tests."""
-    from backend.data.user import get_or_create_user
-
-    # Create the admin user in the database using JWT token format
-    user_data = {
-        "sub": admin_user_id,
-        "email": "test-admin@example.com",
-        "user_metadata": {"name": "Test Admin"},
-    }
-    await get_or_create_user(user_data)
-    return admin_user_id
-
-
 @pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True)
 async def graph_cleanup(server):
     created_graph_ids = []

@@ -1 +0,0 @@
-
@@ -1,18 +0,0 @@
-"""Entry point for running the CoPilot Executor service.
-
-Usage:
-    python -m backend.copilot.executor
-"""
-
-from backend.app import run_processes
-
-from .manager import CoPilotExecutor
-
-
-def main():
-    """Run the CoPilot Executor service."""
-    run_processes(CoPilotExecutor())
-
-
-if __name__ == "__main__":
-    main()
@@ -1,520 +0,0 @@
|
||||
"""CoPilot Executor Manager - main service for CoPilot task execution.
|
||||
|
||||
This module contains the CoPilotExecutor class that consumes chat tasks from
|
||||
RabbitMQ and processes them using a thread pool, following the graph executor pattern.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
|
||||
from pika.adapters.blocking_connection import BlockingChannel
|
||||
from pika.exceptions import AMQPChannelError, AMQPConnectionError
|
||||
from pika.spec import Basic, BasicProperties
|
||||
from prometheus_client import Gauge, start_http_server
|
||||
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.process import AppProcess
|
||||
from backend.util.retry import continuous_retry
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .processor import execute_copilot_task, init_worker
|
||||
from .utils import (
|
||||
COPILOT_CANCEL_QUEUE_NAME,
|
||||
COPILOT_EXECUTION_QUEUE_NAME,
|
||||
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS,
|
||||
CancelCoPilotEvent,
|
||||
CoPilotExecutionEntry,
|
||||
create_copilot_queue_config,
|
||||
)
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
|
||||
settings = Settings()
|
||||
|
||||
# Prometheus metrics
|
||||
active_tasks_gauge = Gauge(
|
||||
"copilot_executor_active_tasks",
|
||||
"Number of active CoPilot tasks",
|
||||
)
|
||||
pool_size_gauge = Gauge(
|
||||
"copilot_executor_pool_size",
|
||||
"Maximum number of CoPilot executor workers",
|
||||
)
|
||||
utilization_gauge = Gauge(
|
||||
"copilot_executor_utilization_ratio",
|
||||
"Ratio of active tasks to pool size",
|
||||
)
|
||||
|
||||
|
||||
class CoPilotExecutor(AppProcess):
|
||||
"""CoPilot Executor service for processing chat generation tasks.
|
||||
|
||||
This service consumes tasks from RabbitMQ, processes them using a thread pool,
|
||||
and publishes results to Redis Streams. It follows the graph executor pattern
|
||||
for reliable message handling and graceful shutdown.
|
||||
|
||||
Key features:
|
||||
- RabbitMQ-based task distribution with manual acknowledgment
|
||||
- Thread pool executor for concurrent task processing
|
||||
- Cluster lock for duplicate prevention across pods
|
||||
- Graceful shutdown with timeout for in-flight tasks
|
||||
- FANOUT exchange for cancellation broadcast
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pool_size = settings.config.num_copilot_workers
|
||||
self.active_tasks: dict[str, tuple[Future, threading.Event]] = {}
|
||||
self.executor_id = str(uuid.uuid4())
|
||||
|
||||
self._executor = None
|
||||
self._stop_consuming = None
|
||||
|
||||
self._cancel_thread = None
|
||||
self._cancel_client = None
|
||||
self._run_thread = None
|
||||
self._run_client = None
|
||||
|
||||
self._task_locks: dict[str, ClusterLock] = {}
|
||||
self._active_tasks_lock = threading.Lock()
|
||||
|
||||
# ============ Main Entry Points (AppProcess interface) ============ #
|
||||
|
||||
def run(self):
|
||||
"""Main service loop - consume from RabbitMQ."""
|
||||
logger.info(f"Pod assigned executor_id: {self.executor_id}")
|
||||
logger.info(f"Spawn max-{self.pool_size} workers...")
|
||||
|
||||
pool_size_gauge.set(self.pool_size)
|
||||
self._update_metrics()
|
||||
start_http_server(settings.config.copilot_executor_port)
|
||||
|
||||
self.cancel_thread.start()
|
||||
self.run_thread.start()
|
||||
|
||||
while True:
|
||||
time.sleep(1e5)
|
||||
|
||||
def cleanup(self):
|
||||
"""Graceful shutdown with active execution waiting."""
|
||||
pid = os.getpid()
|
||||
logger.info(f"[cleanup {pid}] Starting graceful shutdown...")
|
||||
|
||||
# Signal the consumer thread to stop
|
||||
try:
|
||||
self.stop_consuming.set()
|
||||
run_channel = self.run_client.get_channel()
|
||||
run_channel.connection.add_callback_threadsafe(
|
||||
lambda: run_channel.stop_consuming()
|
||||
)
|
||||
logger.info(f"[cleanup {pid}] Consumer has been signaled to stop")
|
||||
except Exception as e:
|
||||
logger.error(f"[cleanup {pid}] Error stopping consumer: {e}")
|
||||
|
||||
# Wait for active executions to complete
|
||||
if self.active_tasks:
|
||||
logger.info(
|
||||
f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
|
||||
)
|
||||
|
||||
start_time = time.monotonic()
|
||||
last_refresh = start_time
|
||||
lock_refresh_interval = settings.config.cluster_lock_timeout / 10
|
||||
|
||||
while (
|
||||
self.active_tasks
|
||||
and (time.monotonic() - start_time) < GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS
|
||||
):
|
||||
self._cleanup_completed_tasks()
|
||||
if not self.active_tasks:
|
||||
break
|
||||
|
||||
# Refresh cluster locks periodically
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_refresh >= lock_refresh_interval:
|
||||
for lock in list(self._task_locks.values()):
|
||||
try:
|
||||
lock.refresh()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[cleanup {pid}] Failed to refresh lock: {e}"
|
||||
)
|
||||
last_refresh = current_time
|
||||
|
||||
logger.info(
|
||||
f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..."
|
||||
)
|
||||
time.sleep(10.0)
|
||||
|
||||
# Stop message consumers
|
||||
if self._run_thread:
|
||||
self._stop_message_consumers(
|
||||
self._run_thread, self.run_client, "[cleanup][run]"
|
||||
)
|
||||
if self._cancel_thread:
|
||||
self._stop_message_consumers(
|
||||
self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
|
||||
)
|
||||
|
||||
# Clean up worker threads (closes per-loop workspace storage sessions)
|
||||
if self._executor:
|
||||
from .processor import cleanup_worker
|
||||
|
||||
logger.info(f"[cleanup {pid}] Cleaning up workers...")
|
||||
futures = []
|
||||
for _ in range(self._executor._max_workers):
|
||||
futures.append(self._executor.submit(cleanup_worker))
|
||||
for f in futures:
|
||||
try:
|
||||
f.result(timeout=10)
|
||||
except Exception as e:
|
||||
logger.warning(f"[cleanup {pid}] Worker cleanup error: {e}")
|
||||
|
||||
logger.info(f"[cleanup {pid}] Shutting down executor...")
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
# Release any remaining locks
|
||||
for task_id, lock in list(self._task_locks.items()):
|
||||
try:
|
||||
lock.release()
|
||||
logger.info(f"[cleanup {pid}] Released lock for {task_id}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[cleanup {pid}] Failed to release lock for {task_id}: {e}"
|
||||
)
|
||||
|
||||
logger.info(f"[cleanup {pid}] Graceful shutdown completed")

    # ============ RabbitMQ Consumer Methods ============ #

    @continuous_retry()
    def _consume_cancel(self):
        """Consume cancellation messages from FANOUT exchange."""
        if self.stop_consuming.is_set() and not self.active_tasks:
            logger.info("Stop reconnecting cancel consumer - service cleaned up")
            return

        if not self.cancel_client.is_ready:
            self.cancel_client.disconnect()
            self.cancel_client.connect()

        # Check again after connect - shutdown may have been requested
        if self.stop_consuming.is_set() and not self.active_tasks:
            logger.info("Stop consuming requested during reconnect - disconnecting")
            self.cancel_client.disconnect()
            return

        cancel_channel = self.cancel_client.get_channel()
        cancel_channel.basic_consume(
            queue=COPILOT_CANCEL_QUEUE_NAME,
            on_message_callback=self._handle_cancel_message,
            auto_ack=True,
        )
        logger.info("Starting to consume cancel messages...")
        cancel_channel.start_consuming()
        if not self.stop_consuming.is_set() or self.active_tasks:
            raise RuntimeError("Cancel message consumer stopped unexpectedly")
        logger.info("Cancel message consumer stopped gracefully")

    @continuous_retry()
    def _consume_run(self):
        """Consume run messages from DIRECT exchange."""
        if self.stop_consuming.is_set():
            logger.info("Stop reconnecting run consumer - service cleaned up")
            return

        if not self.run_client.is_ready:
            self.run_client.disconnect()
            self.run_client.connect()

        # Check again after connect - shutdown may have been requested
        if self.stop_consuming.is_set():
            logger.info("Stop consuming requested during reconnect - disconnecting")
            self.run_client.disconnect()
            return

        run_channel = self.run_client.get_channel()
        run_channel.basic_qos(prefetch_count=self.pool_size)

        run_channel.basic_consume(
            queue=COPILOT_EXECUTION_QUEUE_NAME,
            on_message_callback=self._handle_run_message,
            auto_ack=False,
            consumer_tag="copilot_execution_consumer",
        )
        logger.info("Starting to consume run messages...")
        run_channel.start_consuming()
        if not self.stop_consuming.is_set():
            raise RuntimeError("Run message consumer stopped unexpectedly")
        logger.info("Run message consumer stopped gracefully")

    # ============ Message Handlers ============ #

    @error_logged(swallow=True)
    def _handle_cancel_message(
        self,
        _channel: BlockingChannel,
        _method: Basic.Deliver,
        _properties: BasicProperties,
        body: bytes,
    ):
        """Handle cancel message from FANOUT exchange."""
        request = CancelCoPilotEvent.model_validate_json(body)
        task_id = request.task_id
        if not task_id:
            logger.warning("Cancel message missing 'task_id'")
            return
        if task_id not in self.active_tasks:
            logger.debug(f"Cancel received for {task_id} but not active")
            return

        _, cancel_event = self.active_tasks[task_id]
        logger.info(f"Received cancel for {task_id}")
        if not cancel_event.is_set():
            cancel_event.set()
        else:
            logger.debug(f"Cancel already set for {task_id}")

    def _handle_run_message(
        self,
        _channel: BlockingChannel,
        method: Basic.Deliver,
        _properties: BasicProperties,
        body: bytes,
    ):
        """Handle run message from DIRECT exchange."""
        delivery_tag = method.delivery_tag
        # Capture the channel used at message delivery time to ensure we ack
        # on the correct channel. Delivery tags are channel-scoped and become
        # invalid if the channel is recreated after reconnection.
        delivery_channel = _channel

        def ack_message(reject: bool, requeue: bool):
            """Acknowledge or reject the message.

            Uses the channel from the original message delivery. If the channel
            is no longer open (e.g., after reconnection), logs a warning and
            skips the ack - RabbitMQ will redeliver the message automatically.
            """
            try:
                if not delivery_channel.is_open:
                    logger.warning(
                        f"Channel closed, cannot ack delivery_tag={delivery_tag}. "
                        "Message will be redelivered by RabbitMQ."
                    )
                    return

                if reject:
                    delivery_channel.connection.add_callback_threadsafe(
                        lambda: delivery_channel.basic_nack(
                            delivery_tag, requeue=requeue
                        )
                    )
                else:
                    delivery_channel.connection.add_callback_threadsafe(
                        lambda: delivery_channel.basic_ack(delivery_tag)
                    )
            except (AMQPChannelError, AMQPConnectionError) as e:
                # Channel/connection errors indicate stale delivery tag - don't retry
                logger.warning(
                    f"Cannot ack delivery_tag={delivery_tag} due to channel/connection "
                    f"error: {e}. Message will be redelivered by RabbitMQ."
                )
            except Exception as e:
                # Other errors might be transient, but log and skip to avoid blocking
                logger.error(
                    f"Unexpected error acking delivery_tag={delivery_tag}: {e}"
                )

        # Check if we're shutting down
        if self.stop_consuming.is_set():
            logger.info("Rejecting new task during shutdown")
            ack_message(reject=True, requeue=True)
            return

        # Check if we can accept more tasks
        self._cleanup_completed_tasks()
        if len(self.active_tasks) >= self.pool_size:
            ack_message(reject=True, requeue=True)
            return

        try:
            entry = CoPilotExecutionEntry.model_validate_json(body)
        except Exception as e:
            logger.error(f"Could not parse run message: {e}, body={body}")
            ack_message(reject=True, requeue=False)
            return

        task_id = entry.task_id

        # Check for local duplicate - task is already running on this executor
        if task_id in self.active_tasks:
            logger.warning(
                f"Task {task_id} already running locally, rejecting duplicate"
            )
            ack_message(reject=True, requeue=False)
            return

        # Try to acquire cluster-wide lock
        cluster_lock = ClusterLock(
            redis=redis.get_redis(),
            key=f"copilot:task:{task_id}:lock",
            owner_id=self.executor_id,
            timeout=settings.config.cluster_lock_timeout,
        )
        current_owner = cluster_lock.try_acquire()
        if current_owner != self.executor_id:
            if current_owner is not None:
                logger.warning(f"Task {task_id} already running on pod {current_owner}")
                ack_message(reject=True, requeue=False)
            else:
                logger.warning(
                    f"Could not acquire lock for {task_id} - Redis unavailable"
                )
                ack_message(reject=True, requeue=True)
            return

        # Execute the task
        try:
            self._task_locks[task_id] = cluster_lock

            logger.info(
                f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}"
            )

            cancel_event = threading.Event()
            future = self.executor.submit(
                execute_copilot_task, entry, cancel_event, cluster_lock
            )
            self.active_tasks[task_id] = (future, cancel_event)
        except Exception as e:
            logger.warning(f"Failed to setup execution for {task_id}: {e}")
            cluster_lock.release()
            if task_id in self._task_locks:
                del self._task_locks[task_id]
            ack_message(reject=True, requeue=True)
            return

        self._update_metrics()

        def on_run_done(f: Future):
            logger.info(f"Run completed for {task_id}")
            try:
                if exec_error := f.exception():
                    logger.error(f"Execution for {task_id} failed: {exec_error}")
                    # Don't requeue failed tasks - they've been marked as failed
                    # in the stream registry. Requeuing would cause infinite retries
                    # for deterministic failures.
                    ack_message(reject=True, requeue=False)
                else:
                    ack_message(reject=False, requeue=False)
            except BaseException as e:
                logger.exception(f"Error in run completion callback: {e}")
            finally:
                # Release the cluster lock
                if task_id in self._task_locks:
                    logger.info(f"Releasing cluster lock for {task_id}")
                    self._task_locks[task_id].release()
                    del self._task_locks[task_id]
                self._cleanup_completed_tasks()

        future.add_done_callback(on_run_done)
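
        # All acks/nacks funnel through connection.add_callback_threadsafe()
        # because the underlying pika channel is not thread-safe: the
        # callbacks execute on the connection's I/O thread, while on_run_done
        # typically fires on a pool worker thread when the future completes.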

    # ============ Helper Methods ============ #

    def _cleanup_completed_tasks(self) -> list[str]:
        """Remove completed futures from active_tasks and update metrics."""
        completed_tasks = []
        with self._active_tasks_lock:
            for task_id, (future, _) in list(self.active_tasks.items()):
                if future.done():
                    completed_tasks.append(task_id)
                    self.active_tasks.pop(task_id, None)
                    logger.info(f"Cleaned up completed task {task_id}")

        self._update_metrics()
        return completed_tasks

    def _update_metrics(self):
        """Update Prometheus metrics."""
        active_count = len(self.active_tasks)
        active_tasks_gauge.set(active_count)
        if self.stop_consuming.is_set():
            utilization_gauge.set(1.0)
        else:
            utilization_gauge.set(
                active_count / self.pool_size if self.pool_size > 0 else 0
            )

    def _stop_message_consumers(
        self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str
    ):
        """Stop a message consumer thread."""
        try:
            channel = client.get_channel()
            channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming())

            thread.join(timeout=300)
            if thread.is_alive():
                logger.error(
                    f"{prefix} Thread did not finish in time, forcing disconnect"
                )

            client.disconnect()
            logger.info(f"{prefix} Client disconnected")
        except Exception as e:
            logger.error(f"{prefix} Error disconnecting client: {e}")

    # ============ Lazy-initialized Properties ============ #

    @property
    def cancel_thread(self) -> threading.Thread:
        if self._cancel_thread is None:
            self._cancel_thread = threading.Thread(
                target=lambda: self._consume_cancel(),
                daemon=True,
            )
        return self._cancel_thread

    @property
    def run_thread(self) -> threading.Thread:
        if self._run_thread is None:
            self._run_thread = threading.Thread(
                target=lambda: self._consume_run(),
                daemon=True,
            )
        return self._run_thread

    @property
    def stop_consuming(self) -> threading.Event:
        if self._stop_consuming is None:
            self._stop_consuming = threading.Event()
        return self._stop_consuming

    @property
    def executor(self) -> ThreadPoolExecutor:
        if self._executor is None:
            self._executor = ThreadPoolExecutor(
                max_workers=self.pool_size,
                initializer=init_worker,
            )
        return self._executor

    @property
    def cancel_client(self) -> SyncRabbitMQ:
        if self._cancel_client is None:
            self._cancel_client = SyncRabbitMQ(create_copilot_queue_config())
        return self._cancel_client

    @property
    def run_client(self) -> SyncRabbitMQ:
        if self._run_client is None:
            self._run_client = SyncRabbitMQ(create_copilot_queue_config())
        return self._run_client
@@ -1,291 +0,0 @@
"""CoPilot execution processor - per-worker execution logic.

This module contains the processor class that handles CoPilot task execution
in a thread-local context, following the graph executor pattern.
"""

import asyncio
import logging
import threading
import time

from backend.copilot import service as copilot_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig
from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep
from backend.copilot.sdk import service as sdk_service
from backend.executor.cluster_lock import ClusterLock
from backend.util.decorator import error_logged
from backend.util.feature_flag import Flag, is_feature_enabled
from backend.util.logging import TruncatedLogger, configure_logging
from backend.util.process import set_service_name
from backend.util.retry import func_retry

from .utils import CoPilotExecutionEntry, CoPilotLogMetadata

logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")


# ============ Module Entry Points ============ #

# Thread-local storage for processor instances
_tls = threading.local()


def execute_copilot_task(
    entry: CoPilotExecutionEntry,
    cancel: threading.Event,
    cluster_lock: ClusterLock,
):
    """Execute a CoPilot task using the thread-local processor.

    This function is the entry point called by the thread pool executor.

    Args:
        entry: The task payload
        cancel: Threading event to signal cancellation
        cluster_lock: Distributed lock for this execution
    """
    processor: CoPilotProcessor = _tls.processor
    return processor.execute(entry, cancel, cluster_lock)


def init_worker():
    """Initialize the processor for the current worker thread.

    This function is called by the thread pool executor when a new worker
    thread is created. It ensures each worker has its own processor instance.
    """
    _tls.processor = CoPilotProcessor()
    _tls.processor.on_executor_start()
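
# The manager creates the pool with ThreadPoolExecutor(initializer=init_worker)
# (see its `executor` property), so each pool thread gets exactly one
# CoPilotProcessor in `_tls`. A minimal sketch of the same pattern, using only
# the standard library (hypothetical names, not part of this module):
#
#     import threading
#     from concurrent.futures import ThreadPoolExecutor
#
#     tls = threading.local()
#
#     def init():
#         tls.worker_id = threading.get_ident()  # one-time per-thread setup
#
#     with ThreadPoolExecutor(max_workers=2, initializer=init) as pool:
#         pool.submit(lambda: print(tls.worker_id)).result()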


def cleanup_worker():
    """Clean up the processor for the current worker thread.

    Should be called before the worker thread's event loop is destroyed so
    that event-loop-bound resources (e.g. ``aiohttp.ClientSession``) are
    closed on the correct loop.
    """
    processor: CoPilotProcessor | None = getattr(_tls, "processor", None)
    if processor is not None:
        processor.cleanup()


# ============ Processor Class ============ #


class CoPilotProcessor:
    """Per-worker execution logic for CoPilot tasks.

    This class is instantiated once per worker thread and handles the execution
    of CoPilot chat generation tasks. It maintains an async event loop for
    running the async service code.

    The execution flow:
    1. CoPilot task is picked from RabbitMQ queue
    2. Manager submits task to thread pool
    3. Processor executes the task in its event loop
    4. Results are published to Redis Streams
    """

    @func_retry
    def on_executor_start(self):
        """Initialize the processor when the worker thread starts.

        This method is called once per worker thread to set up the async event
        loop and initialize any required resources.

        Database is accessed only through DatabaseManager, so we don't need to connect
        to Prisma directly.
        """
        configure_logging()
        set_service_name("CoPilotExecutor")
        self.tid = threading.get_ident()
        self.execution_loop = asyncio.new_event_loop()
        self.execution_thread = threading.Thread(
            target=self.execution_loop.run_forever, daemon=True
        )
        self.execution_thread.start()

        logger.info(f"[CoPilotExecutor] Worker {self.tid} started")

    def cleanup(self):
        """Clean up event-loop-bound resources before the loop is destroyed.

        Shuts down the workspace storage instance that belongs to this
        worker's event loop, ensuring ``aiohttp.ClientSession.close()``
        runs on the same loop that created the session.
        """
        from backend.util.workspace_storage import shutdown_workspace_storage

        try:
            future = asyncio.run_coroutine_threadsafe(
                shutdown_workspace_storage(), self.execution_loop
            )
            future.result(timeout=5)
        except Exception as e:
            logger.warning(f"[CoPilotExecutor] Worker {self.tid} cleanup error: {e}")

        # Stop the event loop
        self.execution_loop.call_soon_threadsafe(self.execution_loop.stop)
        self.execution_thread.join(timeout=5)
        logger.info(f"[CoPilotExecutor] Worker {self.tid} cleaned up")

    @error_logged(swallow=False)
    def execute(
        self,
        entry: CoPilotExecutionEntry,
        cancel: threading.Event,
        cluster_lock: ClusterLock,
    ):
        """Execute a CoPilot task.

        This is the main entry point for task execution. It runs the async
        execution logic in the worker's event loop and handles errors.

        Args:
            entry: The task payload containing session and message info
            cancel: Threading event to signal cancellation
            cluster_lock: Distributed lock to prevent duplicate execution
        """
        log = CoPilotLogMetadata(
            logging.getLogger(__name__),
            task_id=entry.task_id,
            session_id=entry.session_id,
            user_id=entry.user_id,
        )
        log.info("Starting execution")

        start_time = time.monotonic()

        try:
            # Run the async execution in our event loop
            future = asyncio.run_coroutine_threadsafe(
                self._execute_async(entry, cancel, cluster_lock, log),
                self.execution_loop,
            )

            # Wait for completion, checking cancel periodically
            while not future.done():
                try:
                    future.result(timeout=1.0)
                except asyncio.TimeoutError:
                    if cancel.is_set():
                        log.info("Cancellation requested")
                        future.cancel()
                        break
                    # Refresh cluster lock to maintain ownership
                    cluster_lock.refresh()

            if not future.cancelled():
                # Get result to propagate any exceptions
                future.result()

            elapsed = time.monotonic() - start_time
            log.info(f"Execution completed in {elapsed:.2f}s")

        except Exception as e:
            elapsed = time.monotonic() - start_time
            log.error(f"Execution failed after {elapsed:.2f}s: {e}")
            # Note: _execute_async already marks the task as failed before re-raising,
            # so we don't call _mark_task_failed here to avoid duplicate error events.
            raise

    async def _execute_async(
        self,
        entry: CoPilotExecutionEntry,
        cancel: threading.Event,
        cluster_lock: ClusterLock,
        log: CoPilotLogMetadata,
    ):
        """Async execution logic for CoPilot task.

        This method calls the existing stream_chat_completion service function
        and publishes results to the stream registry.

        Args:
            entry: The task payload
            cancel: Threading event to signal cancellation
            cluster_lock: Distributed lock for refresh
            log: Structured logger for this task
        """
        last_refresh = time.monotonic()
        refresh_interval = 30.0  # Refresh lock every 30 seconds

        try:
            # Choose service based on LaunchDarkly flag
            config = ChatConfig()
            use_sdk = await is_feature_enabled(
                Flag.COPILOT_SDK,
                entry.user_id or "anonymous",
                default=config.use_claude_agent_sdk,
            )
            stream_fn = (
                sdk_service.stream_chat_completion_sdk
                if use_sdk
                else copilot_service.stream_chat_completion
            )
            log.info(f"Using {'SDK' if use_sdk else 'standard'} service")

            # Stream chat completion and publish chunks to Redis
            async for chunk in stream_fn(
                session_id=entry.session_id,
                message=entry.message if entry.message else None,
                is_user_message=entry.is_user_message,
                user_id=entry.user_id,
                context=entry.context,
            ):
                # Check for cancellation
                if cancel.is_set():
                    log.info("Cancelled during streaming")
                    await stream_registry.publish_chunk(
                        entry.task_id, StreamError(errorText="Operation cancelled")
                    )
                    await stream_registry.publish_chunk(
                        entry.task_id, StreamFinishStep()
                    )
                    await stream_registry.publish_chunk(entry.task_id, StreamFinish())
                    await stream_registry.mark_task_completed(
                        entry.task_id, status="failed"
                    )
                    return

                # Refresh cluster lock periodically
                current_time = time.monotonic()
                if current_time - last_refresh >= refresh_interval:
                    cluster_lock.refresh()
                    last_refresh = current_time

                # Publish chunk to stream registry
                await stream_registry.publish_chunk(entry.task_id, chunk)

            # Mark task as completed
            await stream_registry.mark_task_completed(entry.task_id, status="completed")
            log.info("Task completed successfully")

        except asyncio.CancelledError:
            log.info("Task cancelled")
            await stream_registry.mark_task_completed(
                entry.task_id,
                status="failed",
                error_message="Task was cancelled",
            )
            raise

        except Exception as e:
            log.error(f"Task failed: {e}")
            await self._mark_task_failed(entry.task_id, str(e))
            raise

    async def _mark_task_failed(self, task_id: str, error_message: str):
        """Mark a task as failed and publish error to stream registry."""
        try:
            await stream_registry.publish_chunk(
                task_id, StreamError(errorText=error_message)
            )
            await stream_registry.publish_chunk(task_id, StreamFinishStep())
            await stream_registry.publish_chunk(task_id, StreamFinish())
            await stream_registry.mark_task_completed(task_id, status="failed")
        except Exception as e:
            logger.error(f"Failed to mark task {task_id} as failed: {e}")
@@ -1,224 +0,0 @@
"""RabbitMQ queue configuration for CoPilot executor.

Defines two exchanges and queues following the graph executor pattern:
- 'copilot_execution' (DIRECT) for chat generation tasks
- 'copilot_cancel' (FANOUT) for cancellation requests
"""

import logging

from pydantic import BaseModel

from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled

logger = logging.getLogger(__name__)


# ============ Logging Helper ============ #


class CoPilotLogMetadata(TruncatedLogger):
    """Structured logging helper for CoPilot executor.

    In cloud environments (structured logging enabled), uses a simple prefix
    and passes metadata via json_fields. In local environments, uses a detailed
    prefix with all metadata key-value pairs for easier debugging.

    Args:
        logger: The underlying logger instance
        max_length: Maximum log message length before truncation
        **kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz")
            These are added to json_fields in cloud mode, or to the prefix in local mode.
    """

    def __init__(
        self,
        logger: logging.Logger,
        max_length: int = 1000,
        **kwargs: str | None,
    ):
        # Filter out None values
        metadata = {k: v for k, v in kwargs.items() if v is not None}
        metadata["component"] = "CoPilotExecutor"

        if is_structured_logging_enabled():
            prefix = "[CoPilotExecutor]"
        else:
            # Build prefix from metadata key-value pairs
            meta_parts = "|".join(
                f"{k}:{v}" for k, v in metadata.items() if k != "component"
            )
            prefix = (
                f"[CoPilotExecutor|{meta_parts}]" if meta_parts else "[CoPilotExecutor]"
            )

        super().__init__(
            logger,
            max_length=max_length,
            prefix=prefix,
            metadata=metadata,
        )


# ============ Exchange and Queue Configuration ============ #

COPILOT_EXECUTION_EXCHANGE = Exchange(
    name="copilot_execution",
    type=ExchangeType.DIRECT,
    durable=True,
    auto_delete=False,
)
COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue"
COPILOT_EXECUTION_ROUTING_KEY = "copilot.run"

COPILOT_CANCEL_EXCHANGE = Exchange(
    name="copilot_cancel",
    type=ExchangeType.FANOUT,
    durable=True,
    auto_delete=False,
)
COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue"

# CoPilot operations can include extended thinking and agent generation
# which may take 30+ minutes to complete
COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60  # 1 hour

# Graceful shutdown timeout - allow in-flight operations to complete
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60  # 30 minutes


def create_copilot_queue_config() -> RabbitMQConfig:
    """Create RabbitMQ configuration for CoPilot executor.

    Defines two exchanges and queues:
    - 'copilot_execution' (DIRECT) for chat generation tasks
    - 'copilot_cancel' (FANOUT) for cancellation requests

    Returns:
        RabbitMQConfig with exchanges and queues defined
    """
    run_queue = Queue(
        name=COPILOT_EXECUTION_QUEUE_NAME,
        exchange=COPILOT_EXECUTION_EXCHANGE,
        routing_key=COPILOT_EXECUTION_ROUTING_KEY,
        durable=True,
        auto_delete=False,
        arguments={
            # Extended consumer timeout for long-running LLM operations
            # Default 30-minute timeout is insufficient for extended thinking
            # and agent generation which can take 30+ minutes
            "x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS
            * 1000,
        },
    )
    cancel_queue = Queue(
        name=COPILOT_CANCEL_QUEUE_NAME,
        exchange=COPILOT_CANCEL_EXCHANGE,
        routing_key="",  # not used for FANOUT
        durable=True,
        auto_delete=False,
    )
    return RabbitMQConfig(
        vhost="/",
        exchanges=[COPILOT_EXECUTION_EXCHANGE, COPILOT_CANCEL_EXCHANGE],
        queues=[run_queue, cancel_queue],
    )
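
# Note: RabbitMQ's "x-consumer-timeout" queue argument is expressed in
# milliseconds, hence the `* 1000` conversion from
# COPILOT_CONSUMER_TIMEOUT_SECONDS above.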


# ============ Message Models ============ #


class CoPilotExecutionEntry(BaseModel):
    """Task payload for CoPilot AI generation.

    This model represents a chat generation task to be processed by the executor.
    """

    task_id: str
    """Unique identifier for this task (used for stream registry)"""

    session_id: str
    """Chat session ID"""

    user_id: str | None
    """User ID (may be None for anonymous users)"""

    operation_id: str
    """Operation ID for webhook callbacks and completion tracking"""

    message: str
    """User's message to process"""

    is_user_message: bool = True
    """Whether the message is from the user (vs system/assistant)"""

    context: dict[str, str] | None = None
    """Optional context for the message (e.g., {url: str, content: str})"""


class CancelCoPilotEvent(BaseModel):
    """Event to cancel a CoPilot operation."""

    task_id: str
    """Task ID to cancel"""


# ============ Queue Publishing Helpers ============ #


async def enqueue_copilot_task(
    task_id: str,
    session_id: str,
    user_id: str | None,
    operation_id: str,
    message: str,
    is_user_message: bool = True,
    context: dict[str, str] | None = None,
) -> None:
    """Enqueue a CoPilot task for processing by the executor service.

    Args:
        task_id: Unique identifier for this task (used for stream registry)
        session_id: Chat session ID
        user_id: User ID (may be None for anonymous users)
        operation_id: Operation ID for webhook callbacks and completion tracking
        message: User's message to process
        is_user_message: Whether the message is from the user (vs system/assistant)
        context: Optional context for the message (e.g., {url: str, content: str})
    """
    from backend.util.clients import get_async_copilot_queue

    entry = CoPilotExecutionEntry(
        task_id=task_id,
        session_id=session_id,
        user_id=user_id,
        operation_id=operation_id,
        message=message,
        is_user_message=is_user_message,
        context=context,
    )

    queue_client = await get_async_copilot_queue()
    await queue_client.publish_message(
        routing_key=COPILOT_EXECUTION_ROUTING_KEY,
        message=entry.model_dump_json(),
        exchange=COPILOT_EXECUTION_EXCHANGE,
    )


async def enqueue_cancel_task(task_id: str) -> None:
    """Publish a cancel request for a running CoPilot task.

    Sends a ``CancelCoPilotEvent`` to the FANOUT exchange so all executor
    pods receive the cancellation signal.
    """
    from backend.util.clients import get_async_copilot_queue

    event = CancelCoPilotEvent(task_id=task_id)
    queue_client = await get_async_copilot_queue()
    await queue_client.publish_message(
        routing_key="",  # FANOUT ignores routing key
        message=event.model_dump_json(),
        exchange=COPILOT_CANCEL_EXCHANGE,
    )
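
# A minimal usage sketch (hypothetical IDs and message text, not part of this
# module):
#
#     import asyncio
#     import uuid
#
#     async def main():
#         task_id = str(uuid.uuid4())
#         await enqueue_copilot_task(
#             task_id=task_id,
#             session_id="session-123",
#             user_id=None,
#             operation_id=str(uuid.uuid4()),
#             message="Summarize today's AI news",
#         )
#         await enqueue_cancel_task(task_id)  # fans out to every executor pod
#
#     asyncio.run(main())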
@@ -1,272 +0,0 @@
"""Tests for parallel tool call execution in CoPilot.

These tests mock _yield_tool_call to avoid importing the full copilot stack
which requires Prisma, DB connections, etc.
"""

import asyncio
import time
from typing import Any, cast

import pytest


@pytest.mark.asyncio
async def test_parallel_tool_calls_run_concurrently():
    """Multiple tool calls should complete in ~max(delays), not sum(delays)."""
    # Import here to allow module-level mocking if needed
    from backend.copilot.response_model import (
        StreamToolInputAvailable,
        StreamToolOutputAvailable,
    )
    from backend.copilot.service import _execute_tool_calls_parallel

    n_tools = 3
    delay_per_tool = 0.2
    tool_calls = [
        {
            "id": f"call_{i}",
            "type": "function",
            "function": {"name": f"tool_{i}", "arguments": "{}"},
        }
        for i in range(n_tools)
    ]

    # Minimal session mock
    class FakeSession:
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    original_yield = None

    async def fake_yield(tc_list, idx, sess, lock=None):
        yield StreamToolInputAvailable(
            toolCallId=tc_list[idx]["id"],
            toolName=tc_list[idx]["function"]["name"],
            input={},
        )
        await asyncio.sleep(delay_per_tool)
        yield StreamToolOutputAvailable(
            toolCallId=tc_list[idx]["id"],
            toolName=tc_list[idx]["function"]["name"],
            output="{}",
        )

    import backend.copilot.service as svc

    original_yield = svc._yield_tool_call
    svc._yield_tool_call = fake_yield
    try:
        start = time.monotonic()
        events = []
        async for event in _execute_tool_calls_parallel(
            tool_calls, cast(Any, FakeSession())
        ):
            events.append(event)
        elapsed = time.monotonic() - start
    finally:
        svc._yield_tool_call = original_yield

    assert len(events) == n_tools * 2
    # Parallel: should take ~delay, not ~n*delay
    assert elapsed < delay_per_tool * (
        n_tools - 0.5
    ), f"Took {elapsed:.2f}s, expected parallel (~{delay_per_tool}s)"


@pytest.mark.asyncio
async def test_single_tool_call_works():
    """Single tool call should work identically."""
    from backend.copilot.response_model import (
        StreamToolInputAvailable,
        StreamToolOutputAvailable,
    )
    from backend.copilot.service import _execute_tool_calls_parallel

    tool_calls = [
        {
            "id": "call_0",
            "type": "function",
            "function": {"name": "t", "arguments": "{}"},
        }
    ]

    class FakeSession:
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    async def fake_yield(tc_list, idx, sess, lock=None):
        yield StreamToolInputAvailable(toolCallId="call_0", toolName="t", input={})
        yield StreamToolOutputAvailable(toolCallId="call_0", toolName="t", output="{}")

    import backend.copilot.service as svc

    orig = svc._yield_tool_call
    svc._yield_tool_call = fake_yield
    try:
        events = [
            e
            async for e in _execute_tool_calls_parallel(
                tool_calls, cast(Any, FakeSession())
            )
        ]
    finally:
        svc._yield_tool_call = orig

    assert len(events) == 2


@pytest.mark.asyncio
async def test_retryable_error_propagates():
    """Retryable errors should be raised after all tools finish."""
    from backend.copilot.response_model import StreamToolOutputAvailable
    from backend.copilot.service import _execute_tool_calls_parallel

    tool_calls = [
        {
            "id": f"call_{i}",
            "type": "function",
            "function": {"name": f"t_{i}", "arguments": "{}"},
        }
        for i in range(2)
    ]

    class FakeSession:
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    async def fake_yield(tc_list, idx, sess, lock=None):
        if idx == 1:
            raise KeyError("bad")
        from backend.copilot.response_model import StreamToolInputAvailable

        yield StreamToolInputAvailable(
            toolCallId=tc_list[idx]["id"], toolName="t_0", input={}
        )
        await asyncio.sleep(0.05)
        yield StreamToolOutputAvailable(
            toolCallId=tc_list[idx]["id"], toolName="t_0", output="{}"
        )

    import backend.copilot.service as svc

    orig = svc._yield_tool_call
    svc._yield_tool_call = fake_yield
    try:
        events = []
        with pytest.raises(KeyError):
            async for event in _execute_tool_calls_parallel(
                tool_calls, cast(Any, FakeSession())
            ):
                events.append(event)
        # First tool's events should still be yielded
        assert any(isinstance(e, StreamToolOutputAvailable) for e in events)
    finally:
        svc._yield_tool_call = orig


@pytest.mark.asyncio
async def test_session_lock_shared():
    """All parallel tools should receive the same lock instance."""
    from backend.copilot.response_model import (
        StreamToolInputAvailable,
        StreamToolOutputAvailable,
    )
    from backend.copilot.service import _execute_tool_calls_parallel

    tool_calls = [
        {
            "id": f"call_{i}",
            "type": "function",
            "function": {"name": f"t_{i}", "arguments": "{}"},
        }
        for i in range(3)
    ]

    class FakeSession:
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    observed_locks = []

    async def fake_yield(tc_list, idx, sess, lock=None):
        observed_locks.append(lock)
        yield StreamToolInputAvailable(
            toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
        )
        yield StreamToolOutputAvailable(
            toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", output="{}"
        )

    import backend.copilot.service as svc

    orig = svc._yield_tool_call
    svc._yield_tool_call = fake_yield
    try:
        async for _ in _execute_tool_calls_parallel(
            tool_calls, cast(Any, FakeSession())
        ):
            pass
    finally:
        svc._yield_tool_call = orig

    assert len(observed_locks) == 3
    assert observed_locks[0] is observed_locks[1] is observed_locks[2]
    assert isinstance(observed_locks[0], asyncio.Lock)


@pytest.mark.asyncio
async def test_cancellation_cleans_up():
    """Generator close should cancel in-flight tasks."""
    from backend.copilot.response_model import StreamToolInputAvailable
    from backend.copilot.service import _execute_tool_calls_parallel

    tool_calls = [
        {
            "id": f"call_{i}",
            "type": "function",
            "function": {"name": f"t_{i}", "arguments": "{}"},
        }
        for i in range(2)
    ]

    class FakeSession:
        session_id = "test"
        user_id = "test"

        def __init__(self):
            self.messages = []

    started = asyncio.Event()

    async def fake_yield(tc_list, idx, sess, lock=None):
        yield StreamToolInputAvailable(
            toolCallId=tc_list[idx]["id"], toolName=f"t_{idx}", input={}
        )
        started.set()
        await asyncio.sleep(10)  # simulate long-running

    import backend.copilot.service as svc

    orig = svc._yield_tool_call
    svc._yield_tool_call = fake_yield
    try:
        gen = _execute_tool_calls_parallel(tool_calls, cast(Any, FakeSession()))
        await gen.__anext__()  # get first event
        await started.wait()
        await gen.aclose()  # close generator
    finally:
        svc._yield_tool_call = orig
    # If we get here without hanging, cleanup worked
@@ -1,221 +0,0 @@
"""Tests for _format_conversation_context and _build_query_message."""

from datetime import UTC, datetime

import pytest

from backend.copilot.model import ChatMessage, ChatSession
from backend.copilot.sdk.service import (
    _build_query_message,
    _format_conversation_context,
)

# ---------------------------------------------------------------------------
# _format_conversation_context
# ---------------------------------------------------------------------------


def test_format_empty_list():
    assert _format_conversation_context([]) is None


def test_format_none_content_messages():
    msgs = [ChatMessage(role="user", content=None)]
    assert _format_conversation_context(msgs) is None


def test_format_user_message():
    msgs = [ChatMessage(role="user", content="hello")]
    result = _format_conversation_context(msgs)
    assert result is not None
    assert "User: hello" in result
    assert result.startswith("<conversation_history>")
    assert result.endswith("</conversation_history>")


def test_format_assistant_text():
    msgs = [ChatMessage(role="assistant", content="hi there")]
    result = _format_conversation_context(msgs)
    assert result is not None
    assert "You responded: hi there" in result


def test_format_assistant_tool_calls():
    msgs = [
        ChatMessage(
            role="assistant",
            content=None,
            tool_calls=[{"function": {"name": "search", "arguments": '{"q": "test"}'}}],
        )
    ]
    result = _format_conversation_context(msgs)
    assert result is not None
    assert 'You called tool: search({"q": "test"})' in result


def test_format_tool_result():
    msgs = [ChatMessage(role="tool", content='{"result": "ok"}')]
    result = _format_conversation_context(msgs)
    assert result is not None
    assert 'Tool result: {"result": "ok"}' in result


def test_format_tool_result_none_content():
    msgs = [ChatMessage(role="tool", content=None)]
    result = _format_conversation_context(msgs)
    assert result is not None
    assert "Tool result: " in result


def test_format_full_conversation():
    msgs = [
        ChatMessage(role="user", content="find agents"),
        ChatMessage(
            role="assistant",
            content="I'll search for agents.",
            tool_calls=[
                {"function": {"name": "find_agents", "arguments": '{"q": "test"}'}}
            ],
        ),
        ChatMessage(role="tool", content='[{"id": "1", "name": "Agent1"}]'),
        ChatMessage(role="assistant", content="Found Agent1."),
    ]
    result = _format_conversation_context(msgs)
    assert result is not None
    assert "User: find agents" in result
    assert "You responded: I'll search for agents." in result
    assert "You called tool: find_agents" in result
    assert "Tool result:" in result
    assert "You responded: Found Agent1." in result


# ---------------------------------------------------------------------------
# _build_query_message
# ---------------------------------------------------------------------------


def _make_session(messages: list[ChatMessage]) -> ChatSession:
    """Build a minimal ChatSession with the given messages."""
    now = datetime.now(UTC)
    return ChatSession(
        session_id="test-session",
        user_id="user-1",
        messages=messages,
        title="test",
        usage=[],
        started_at=now,
        updated_at=now,
    )


@pytest.mark.asyncio
async def test_build_query_resume_up_to_date():
    """With --resume and transcript covers all messages, return raw message."""
    session = _make_session(
        [
            ChatMessage(role="user", content="hello"),
            ChatMessage(role="assistant", content="hi"),
            ChatMessage(role="user", content="what's new?"),
        ]
    )
    result = await _build_query_message(
        "what's new?",
        session,
        use_resume=True,
        transcript_msg_count=2,
        session_id="test-session",
    )
    # transcript_msg_count == msg_count - 1, so no gap
    assert result == "what's new?"


@pytest.mark.asyncio
async def test_build_query_resume_stale_transcript():
    """With --resume and stale transcript, gap context is prepended."""
    session = _make_session(
        [
            ChatMessage(role="user", content="turn 1"),
            ChatMessage(role="assistant", content="reply 1"),
            ChatMessage(role="user", content="turn 2"),
            ChatMessage(role="assistant", content="reply 2"),
            ChatMessage(role="user", content="turn 3"),
        ]
    )
    result = await _build_query_message(
        "turn 3",
        session,
        use_resume=True,
        transcript_msg_count=2,
        session_id="test-session",
    )
    assert "<conversation_history>" in result
    assert "turn 2" in result
    assert "reply 2" in result
    assert "Now, the user says:\nturn 3" in result


@pytest.mark.asyncio
async def test_build_query_resume_zero_msg_count():
    """With --resume but transcript_msg_count=0, return raw message."""
    session = _make_session(
        [
            ChatMessage(role="user", content="hello"),
            ChatMessage(role="assistant", content="hi"),
            ChatMessage(role="user", content="new msg"),
        ]
    )
    result = await _build_query_message(
        "new msg",
        session,
        use_resume=True,
        transcript_msg_count=0,
        session_id="test-session",
    )
    assert result == "new msg"


@pytest.mark.asyncio
async def test_build_query_no_resume_single_message():
    """Without --resume and only 1 message, return raw message."""
    session = _make_session([ChatMessage(role="user", content="first")])
    result = await _build_query_message(
        "first",
        session,
        use_resume=False,
        transcript_msg_count=0,
        session_id="test-session",
    )
    assert result == "first"


@pytest.mark.asyncio
async def test_build_query_no_resume_multi_message(monkeypatch):
    """Without --resume and multiple messages, compress and prepend."""
    session = _make_session(
        [
            ChatMessage(role="user", content="older question"),
            ChatMessage(role="assistant", content="older answer"),
            ChatMessage(role="user", content="new question"),
        ]
    )

    # Mock _compress_conversation_history to return the messages as-is
    async def _mock_compress(sess):
        return sess.messages[:-1]

    monkeypatch.setattr(
        "backend.copilot.sdk.service._compress_conversation_history",
        _mock_compress,
    )

    result = await _build_query_message(
        "new question",
        session,
        use_resume=False,
        transcript_msg_count=0,
        session_id="test-session",
    )
    assert "<conversation_history>" in result
    assert "older question" in result
    assert "older answer" in result
    assert "Now, the user says:\nnew question" in result
@@ -1,376 +0,0 @@
"""Response adapter for converting Claude Agent SDK messages to Vercel AI SDK format.

This module provides the adapter layer that converts streaming messages from
the Claude Agent SDK into the Vercel AI SDK UI Stream Protocol format that
the frontend expects.
"""

import json
import logging
import uuid

from claude_agent_sdk import (
    AssistantMessage,
    Message,
    ResultMessage,
    SystemMessage,
    TextBlock,
    ToolResultBlock,
    ToolUseBlock,
    UserMessage,
)

from backend.copilot.response_model import (
    StreamBaseResponse,
    StreamError,
    StreamFinish,
    StreamFinishStep,
    StreamStart,
    StreamStartStep,
    StreamTextDelta,
    StreamTextEnd,
    StreamTextStart,
    StreamToolInputAvailable,
    StreamToolInputStart,
    StreamToolOutputAvailable,
)

from .tool_adapter import MCP_TOOL_PREFIX, pop_pending_tool_output

logger = logging.getLogger(__name__)


class SDKResponseAdapter:
    """Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.

    This class maintains state during a streaming session to properly track
    text blocks, tool calls, and message lifecycle.
    """

    def __init__(self, message_id: str | None = None, session_id: str | None = None):
        self.message_id = message_id or str(uuid.uuid4())
        self.session_id = session_id
        self.text_block_id = str(uuid.uuid4())
        self.has_started_text = False
        self.has_ended_text = False
        self.current_tool_calls: dict[str, dict[str, str]] = {}
        self.resolved_tool_calls: set[str] = set()
        self.task_id: str | None = None
        self.step_open = False
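        # Expected emission order (see convert_message below): StreamStart ->
        # StreamStartStep, then per-step text and tool events, with
        # StreamFinishStep closing each step and StreamFinish (or StreamError
        # followed by StreamFinish) ending the stream.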
|
||||
|
||||
def set_task_id(self, task_id: str) -> None:
|
||||
"""Set the task ID for reconnection support."""
|
||||
self.task_id = task_id
|
||||
|
||||
@property
|
||||
def has_unresolved_tool_calls(self) -> bool:
|
||||
"""True when there are tool calls that haven't received output yet."""
|
||||
return bool(self.current_tool_calls.keys() - self.resolved_tool_calls)
|
||||
|
||||
def convert_message(self, sdk_message: Message) -> list[StreamBaseResponse]:
|
||||
"""Convert a single SDK message to Vercel AI SDK format."""
|
||||
responses: list[StreamBaseResponse] = []
|
||||
|
||||
if isinstance(sdk_message, SystemMessage):
|
||||
if sdk_message.subtype == "init":
|
||||
responses.append(
|
||||
StreamStart(messageId=self.message_id, taskId=self.task_id)
|
||||
)
|
||||
# Open the first step (matches non-SDK: StreamStart then StreamStartStep)
|
||||
responses.append(StreamStartStep())
|
||||
self.step_open = True
|
||||
|
||||
elif isinstance(sdk_message, AssistantMessage):
|
||||
# Flush any SDK built-in tool calls that didn't get a UserMessage
|
||||
# result (e.g. WebSearch, Read handled internally by the CLI).
|
||||
# BUT skip flush when this AssistantMessage is a parallel tool
|
||||
# continuation (contains only ToolUseBlocks) — the prior tools
|
||||
# are still executing concurrently and haven't finished yet.
|
||||
is_tool_only = all(isinstance(b, ToolUseBlock) for b in sdk_message.content)
|
||||
if not is_tool_only:
|
||||
self._flush_unresolved_tool_calls(responses)
|
||||
|
||||
# After tool results, the SDK sends a new AssistantMessage for the
|
||||
# next LLM turn. Open a new step if the previous one was closed.
|
||||
if not self.step_open:
|
||||
responses.append(StreamStartStep())
|
||||
self.step_open = True
|
||||
|
||||
for block in sdk_message.content:
|
||||
if isinstance(block, TextBlock):
|
||||
if block.text:
|
||||
self._ensure_text_started(responses)
|
||||
responses.append(
|
||||
StreamTextDelta(id=self.text_block_id, delta=block.text)
|
||||
)
|
||||
|
||||
elif isinstance(block, ToolUseBlock):
|
||||
self._end_text_if_open(responses)
|
||||
|
||||
# Strip MCP prefix so frontend sees "find_block"
|
||||
# instead of "mcp__copilot__find_block".
|
||||
tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
|
||||
|
||||
responses.append(
|
||||
StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
|
||||
)
|
||||
responses.append(
|
||||
StreamToolInputAvailable(
|
||||
toolCallId=block.id,
|
||||
toolName=tool_name,
|
||||
input=block.input,
|
||||
)
|
||||
)
|
||||
self.current_tool_calls[block.id] = {"name": tool_name}
|
||||
|
||||
elif isinstance(sdk_message, UserMessage):
|
||||
# UserMessage carries tool results back from tool execution.
|
||||
content = sdk_message.content
|
||||
blocks = content if isinstance(content, list) else []
|
||||
resolved_in_blocks: set[str] = set()
|
||||
|
||||
sid = (self.session_id or "?")[:12]
|
||||
parent_id_preview = getattr(sdk_message, "parent_tool_use_id", None)
|
||||
logger.info(
|
||||
"[SDK] [%s] UserMessage: %d blocks, content_type=%s, "
|
||||
"parent_tool_use_id=%s",
|
||||
sid,
|
||||
len(blocks),
|
||||
type(content).__name__,
|
||||
parent_id_preview[:12] if parent_id_preview else "None",
|
||||
)
|
||||
|
||||
for block in blocks:
|
||||
if isinstance(block, ToolResultBlock) and block.tool_use_id:
|
||||
# Skip if already resolved (e.g. by flush) — the real
|
||||
# result supersedes the empty flush, but re-emitting
|
||||
# would confuse the frontend's state machine.
|
||||
if block.tool_use_id in self.resolved_tool_calls:
|
||||
continue
|
||||
tool_info = self.current_tool_calls.get(block.tool_use_id, {})
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
|
||||
# Prefer the stashed full output over the SDK's
|
||||
# (potentially truncated) ToolResultBlock content.
|
||||
# The SDK truncates large results, writing them to disk,
|
||||
# which breaks frontend widget parsing.
|
||||
output = pop_pending_tool_output(tool_name) or (
|
||||
_extract_tool_output(block.content)
|
||||
)
|
||||
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=block.tool_use_id,
|
||||
toolName=tool_name,
|
||||
output=output,
|
||||
success=not (block.is_error or False),
|
||||
)
|
||||
)
|
||||
resolved_in_blocks.add(block.tool_use_id)
|
||||
|
||||
# Handle SDK built-in tool results carried via parent_tool_use_id
|
||||
# instead of (or in addition to) ToolResultBlock content.
|
||||
parent_id = sdk_message.parent_tool_use_id
|
||||
if (
|
||||
parent_id
|
||||
and parent_id not in resolved_in_blocks
|
||||
and parent_id not in self.resolved_tool_calls
|
||||
):
|
||||
tool_info = self.current_tool_calls.get(parent_id, {})
|
||||
tool_name = tool_info.get("name", "unknown")
|
||||
|
||||
# Try stashed output first (from PostToolUse hook),
|
||||
# then tool_use_result dict, then string content.
|
||||
output = pop_pending_tool_output(tool_name)
|
||||
if not output:
|
||||
tur = sdk_message.tool_use_result
|
||||
if tur is not None:
|
||||
output = _extract_tool_use_result(tur)
|
||||
if not output and isinstance(content, str) and content.strip():
|
||||
output = content.strip()
|
||||
|
||||
if output:
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=parent_id,
|
||||
toolName=tool_name,
|
||||
output=output,
|
||||
success=True,
|
||||
)
|
||||
)
|
||||
resolved_in_blocks.add(parent_id)
|
||||
|
||||
self.resolved_tool_calls.update(resolved_in_blocks)
|
||||
|
||||
# Close the current step after tool results — the next
|
||||
# AssistantMessage will open a new step for the continuation.
|
||||
if self.step_open:
|
||||
responses.append(StreamFinishStep())
|
||||
self.step_open = False
|
||||
|
||||
elif isinstance(sdk_message, ResultMessage):
|
||||
self._flush_unresolved_tool_calls(responses)
|
||||
self._end_text_if_open(responses)
|
||||
# Close the step before finishing.
|
||||
if self.step_open:
|
||||
responses.append(StreamFinishStep())
|
||||
self.step_open = False
|
||||
|
||||
if sdk_message.subtype == "success":
|
||||
responses.append(StreamFinish())
|
||||
elif sdk_message.subtype in ("error", "error_during_execution"):
|
||||
error_msg = sdk_message.result or "Unknown error"
|
||||
responses.append(
|
||||
StreamError(errorText=str(error_msg), code="sdk_error")
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unexpected ResultMessage subtype: {sdk_message.subtype}"
|
||||
)
|
||||
responses.append(StreamFinish())
|
||||
|
||||
else:
|
||||
logger.debug(f"Unhandled SDK message type: {type(sdk_message).__name__}")
|
||||
|
||||
return responses
|
||||
|
||||
def _ensure_text_started(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""Start (or restart) a text block if needed."""
|
||||
if not self.has_started_text or self.has_ended_text:
|
||||
if self.has_ended_text:
|
||||
self.text_block_id = str(uuid.uuid4())
|
||||
self.has_ended_text = False
|
||||
responses.append(StreamTextStart(id=self.text_block_id))
|
||||
self.has_started_text = True
|
||||
|
||||
def _end_text_if_open(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""End the current text block if one is open."""
|
||||
if self.has_started_text and not self.has_ended_text:
|
||||
responses.append(StreamTextEnd(id=self.text_block_id))
|
||||
self.has_ended_text = True
|
||||
|
||||
def _flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None:
|
||||
"""Emit outputs for tool calls that didn't receive a UserMessage result.
|
||||
|
||||
SDK built-in tools (WebSearch, Read, etc.) may be executed by the CLI
|
||||
internally without surfacing a separate ``UserMessage`` with
|
||||
``ToolResultBlock`` content. The ``PostToolUse`` hook stashes their
|
||||
output, which we pop and emit here before the next ``AssistantMessage``
|
||||
starts.
|
||||
"""
|
||||
unresolved = [
|
||||
(tid, info.get("name", "unknown"))
|
||||
for tid, info in self.current_tool_calls.items()
|
||||
if tid not in self.resolved_tool_calls
|
||||
]
|
||||
sid = (self.session_id or "?")[:12]
|
||||
if not unresolved:
|
||||
logger.info(
|
||||
"[SDK] [%s] Flush called but all %d tool(s) already resolved",
|
||||
sid,
|
||||
len(self.current_tool_calls),
|
||||
)
|
||||
return
|
||||
logger.info(
|
||||
"[SDK] [%s] Flushing %d unresolved tool call(s): %s",
|
||||
sid,
|
||||
len(unresolved),
|
||||
", ".join(f"{name}({tid[:12]})" for tid, name in unresolved),
|
||||
)
|
||||
|
||||
flushed = False
|
||||
for tool_id, tool_name in unresolved:
|
||||
output = pop_pending_tool_output(tool_name)
|
||||
if output is not None:
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_id,
|
||||
toolName=tool_name,
|
||||
output=output,
|
||||
success=True,
|
||||
)
|
||||
)
|
||||
self.resolved_tool_calls.add(tool_id)
|
||||
flushed = True
|
||||
logger.info(
|
||||
"[SDK] [%s] Flushed stashed output for %s " "(call %s, %d chars)",
|
||||
sid,
|
||||
tool_name,
|
||||
tool_id[:12],
|
||||
len(output),
|
||||
)
|
||||
else:
|
||||
# No output available — emit an empty output so the frontend
|
||||
# transitions the tool from input-available to output-available
|
||||
# (stops the spinner).
|
||||
responses.append(
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=tool_id,
|
||||
toolName=tool_name,
|
||||
output="",
|
||||
success=True,
|
||||
)
|
||||
)
|
||||
self.resolved_tool_calls.add(tool_id)
|
||||
flushed = True
|
||||
logger.warning(
|
||||
"[SDK] [%s] Flushed EMPTY output for unresolved tool %s "
|
||||
"(call %s) — stash was empty (likely SDK hook race "
|
||||
"condition: PostToolUse hook hadn't completed before "
|
||||
"flush was triggered)",
|
||||
sid,
|
||||
tool_name,
|
||||
tool_id[:12],
|
||||
)
|
||||
|
||||
if flushed and self.step_open:
|
||||
responses.append(StreamFinishStep())
|
||||
self.step_open = False
|
||||
|
||||
|
||||
def _extract_tool_output(content: str | list[dict[str, str]] | None) -> str:
    """Extract a string output from a ToolResultBlock's content field."""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = [item.get("text", "") for item in content if item.get("type") == "text"]
        if parts:
            return "".join(parts)
        try:
            return json.dumps(content)
        except (TypeError, ValueError):
            return str(content)
    if content is None:
        return ""
    try:
        return json.dumps(content)
    except (TypeError, ValueError):
        return str(content)


def _extract_tool_use_result(result: object) -> str:
    """Extract a string from a UserMessage's ``tool_use_result`` dict.

    SDK built-in tools may store their result in ``tool_use_result``
    instead of (or in addition to) ``ToolResultBlock`` content blocks.
    """
    if isinstance(result, str):
        return result
    if isinstance(result, dict):
        # Try common result keys
        for key in ("content", "text", "output", "stdout", "result"):
            val = result.get(key)
            if isinstance(val, str) and val:
                return val
        # Fall back to JSON serialization of the whole dict
        try:
            return json.dumps(result)
        except (TypeError, ValueError):
            return str(result)
    if result is None:
        return ""
    try:
        return json.dumps(result)
    except (TypeError, ValueError):
        return str(result)
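
# Illustrative sketch: expected behaviour of the extraction helpers above,
# mirrored in a standalone copy so the examples run without the SDK.
# extract_tool_output_demo is a hypothetical stand-in for _extract_tool_output.
import json as _json

def extract_tool_output_demo(content):
    # Prefer concatenated text blocks; otherwise fall back to JSON.
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = [i.get("text", "") for i in content if i.get("type") == "text"]
        return "".join(parts) if parts else _json.dumps(content)
    return "" if content is None else _json.dumps(content)

assert extract_tool_output_demo("hi") == "hi"
assert extract_tool_output_demo([{"type": "text", "text": "a"}, {"type": "text", "text": "b"}]) == "ab"
assert extract_tool_output_demo(None) == ""
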
@@ -1,194 +0,0 @@
|
||||
"""SDK compatibility tests — verify the claude-agent-sdk public API surface we depend on.
|
||||
|
||||
Instead of pinning to a narrow version range, these tests verify that the
|
||||
installed SDK exposes every class, function, attribute, and method the copilot
|
||||
integration relies on. If an SDK upgrade removes or renames something these
|
||||
tests will catch it immediately.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public types & factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_sdk_exports_client_and_options():
|
||||
from claude_agent_sdk import ClaudeAgentOptions, ClaudeSDKClient
|
||||
|
||||
assert inspect.isclass(ClaudeSDKClient)
|
||||
assert inspect.isclass(ClaudeAgentOptions)
|
||||
|
||||
|
||||
def test_sdk_exports_message_types():
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
Message,
|
||||
ResultMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
for cls in (AssistantMessage, ResultMessage, SystemMessage, UserMessage):
|
||||
assert inspect.isclass(cls), f"{cls.__name__} is not a class"
|
||||
# Message is a Union type alias, just verify it's importable
|
||||
assert Message is not None
|
||||
|
||||
|
||||
def test_sdk_exports_content_block_types():
|
||||
from claude_agent_sdk import TextBlock, ToolResultBlock, ToolUseBlock
|
||||
|
||||
for cls in (TextBlock, ToolResultBlock, ToolUseBlock):
|
||||
assert inspect.isclass(cls), f"{cls.__name__} is not a class"
|
||||
|
||||
|
||||
def test_sdk_exports_mcp_helpers():
|
||||
from claude_agent_sdk import create_sdk_mcp_server, tool
|
||||
|
||||
assert callable(create_sdk_mcp_server)
|
||||
assert callable(tool)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ClaudeSDKClient interface
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_client_has_required_methods():
|
||||
from claude_agent_sdk import ClaudeSDKClient
|
||||
|
||||
required = ["connect", "disconnect", "query", "receive_messages"]
|
||||
for name in required:
|
||||
attr = getattr(ClaudeSDKClient, name, None)
|
||||
assert attr is not None, f"ClaudeSDKClient.{name} missing"
|
||||
assert callable(attr), f"ClaudeSDKClient.{name} is not callable"
|
||||
|
||||
|
||||
def test_client_supports_async_context_manager():
|
||||
from claude_agent_sdk import ClaudeSDKClient
|
||||
|
||||
assert hasattr(ClaudeSDKClient, "__aenter__")
|
||||
assert hasattr(ClaudeSDKClient, "__aexit__")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ClaudeAgentOptions fields
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_agent_options_accepts_required_fields():
|
||||
"""Verify ClaudeAgentOptions accepts all kwargs our code passes."""
|
||||
from claude_agent_sdk import ClaudeAgentOptions
|
||||
|
||||
opts = ClaudeAgentOptions(
|
||||
system_prompt="test",
|
||||
cwd="/tmp",
|
||||
)
|
||||
assert opts.system_prompt == "test"
|
||||
assert opts.cwd == "/tmp"
|
||||
|
||||
|
||||
def test_agent_options_accepts_all_our_fields():
|
||||
"""Comprehensive check of every field we use in service.py."""
|
||||
from claude_agent_sdk import ClaudeAgentOptions
|
||||
|
||||
fields_we_use = [
|
||||
"system_prompt",
|
||||
"mcp_servers",
|
||||
"allowed_tools",
|
||||
"disallowed_tools",
|
||||
"hooks",
|
||||
"cwd",
|
||||
"model",
|
||||
"env",
|
||||
"resume",
|
||||
"max_buffer_size",
|
||||
]
|
||||
sig = inspect.signature(ClaudeAgentOptions)
|
||||
for field in fields_we_use:
|
||||
assert field in sig.parameters, (
|
||||
f"ClaudeAgentOptions no longer accepts '{field}' — "
|
||||
f"available params: {list(sig.parameters.keys())}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Message attributes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_assistant_message_has_content_and_model():
|
||||
from claude_agent_sdk import AssistantMessage, TextBlock
|
||||
|
||||
msg = AssistantMessage(content=[TextBlock(text="hi")], model="test")
|
||||
assert hasattr(msg, "content")
|
||||
assert hasattr(msg, "model")
|
||||
|
||||
|
||||
def test_result_message_has_required_attrs():
|
||||
from claude_agent_sdk import ResultMessage
|
||||
|
||||
msg = ResultMessage(
|
||||
subtype="success",
|
||||
duration_ms=100,
|
||||
duration_api_ms=50,
|
||||
is_error=False,
|
||||
num_turns=1,
|
||||
session_id="s1",
|
||||
)
|
||||
assert msg.subtype == "success"
|
||||
assert hasattr(msg, "result")
|
||||
|
||||
|
||||
def test_system_message_has_subtype_and_data():
|
||||
from claude_agent_sdk import SystemMessage
|
||||
|
||||
msg = SystemMessage(subtype="init", data={})
|
||||
assert msg.subtype == "init"
|
||||
assert msg.data == {}
|
||||
|
||||
|
||||
def test_user_message_has_parent_tool_use_id():
|
||||
from claude_agent_sdk import UserMessage
|
||||
|
||||
msg = UserMessage(content="test")
|
||||
assert hasattr(msg, "parent_tool_use_id")
|
||||
assert hasattr(msg, "tool_use_result")
|
||||
|
||||
|
||||
def test_tool_use_block_has_id_name_input():
|
||||
from claude_agent_sdk import ToolUseBlock
|
||||
|
||||
block = ToolUseBlock(id="t1", name="test", input={"key": "val"})
|
||||
assert block.id == "t1"
|
||||
assert block.name == "test"
|
||||
assert block.input == {"key": "val"}
|
||||
|
||||
|
||||
def test_tool_result_block_has_required_attrs():
|
||||
from claude_agent_sdk import ToolResultBlock
|
||||
|
||||
block = ToolResultBlock(tool_use_id="t1", content="result")
|
||||
assert block.tool_use_id == "t1"
|
||||
assert block.content == "result"
|
||||
assert hasattr(block, "is_error")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hook types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"hook_event",
|
||||
["PreToolUse", "PostToolUse", "Stop"],
|
||||
)
|
||||
def test_sdk_exports_hook_event_type(hook_event: str):
|
||||
"""Verify HookEvent literal includes the events our security_hooks use."""
|
||||
from claude_agent_sdk.types import HookEvent
|
||||
|
||||
# HookEvent is a Literal type — check that our events are valid values.
|
||||
# We can't easily inspect Literal at runtime, so just verify the type exists.
|
||||
assert HookEvent is not None
|
||||
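
# Illustrative sketch: the signature-probing pattern the tests above rely on,
# applied to a local dataclass so it runs without claude_agent_sdk installed.
# _ExampleOptions and its fields are hypothetical stand-ins for ClaudeAgentOptions.
import inspect as _inspect
from dataclasses import dataclass as _dataclass

@_dataclass
class _ExampleOptions:
    system_prompt: str = ""
    cwd: str = "."

_sig = _inspect.signature(_ExampleOptions)
assert "system_prompt" in _sig.parameters  # same check style as the tests above
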
@@ -1,335 +0,0 @@
|
||||
"""Tests for SDK security hooks — workspace paths, tool access, and deny messages.
|
||||
|
||||
These are pure unit tests with no external dependencies (no SDK, no DB, no server).
|
||||
They validate that the security hooks correctly block unauthorized paths,
|
||||
tool access, and dangerous input patterns.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from .security_hooks import _validate_tool_access, _validate_user_isolation
|
||||
from .service import _is_tool_error_or_denial
|
||||
|
||||
SDK_CWD = "/tmp/copilot-abc123"
|
||||
|
||||
|
||||
def _sdk_available() -> bool:
|
||||
try:
|
||||
import claude_agent_sdk # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def _is_denied(result: dict) -> bool:
|
||||
hook = result.get("hookSpecificOutput", {})
|
||||
return hook.get("permissionDecision") == "deny"
|
||||
|
||||
|
||||
def _reason(result: dict) -> str:
|
||||
return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "")
|
||||
|
||||
|
||||
# -- Blocked tools -----------------------------------------------------------
|
||||
|
||||
|
||||
def test_blocked_tools_denied():
|
||||
for tool in ("bash", "shell", "exec", "terminal", "command"):
|
||||
result = _validate_tool_access(tool, {})
|
||||
assert _is_denied(result), f"{tool} should be blocked"
|
||||
|
||||
|
||||
def test_unknown_tool_allowed():
|
||||
result = _validate_tool_access("SomeCustomTool", {})
|
||||
assert result == {}
|
||||
|
||||
|
||||
# -- Workspace-scoped tools --------------------------------------------------
|
||||
|
||||
|
||||
def test_read_within_workspace_allowed():
|
||||
result = _validate_tool_access(
|
||||
"Read", {"file_path": f"{SDK_CWD}/file.txt"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_write_within_workspace_allowed():
|
||||
result = _validate_tool_access(
|
||||
"Write", {"file_path": f"{SDK_CWD}/output.json"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_edit_within_workspace_allowed():
|
||||
result = _validate_tool_access(
|
||||
"Edit", {"file_path": f"{SDK_CWD}/src/main.py"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_glob_within_workspace_allowed():
|
||||
result = _validate_tool_access("Glob", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_grep_within_workspace_allowed():
|
||||
result = _validate_tool_access("Grep", {"path": f"{SDK_CWD}/src"}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_outside_workspace_denied():
|
||||
result = _validate_tool_access(
|
||||
"Read", {"file_path": "/etc/passwd"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_write_outside_workspace_denied():
|
||||
result = _validate_tool_access(
|
||||
"Write", {"file_path": "/home/user/secrets.txt"}, sdk_cwd=SDK_CWD
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_traversal_attack_denied():
|
||||
result = _validate_tool_access(
|
||||
"Read",
|
||||
{"file_path": f"{SDK_CWD}/../../etc/passwd"},
|
||||
sdk_cwd=SDK_CWD,
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_no_path_allowed():
|
||||
"""Glob/Grep without a path argument defaults to cwd — should pass."""
|
||||
result = _validate_tool_access("Glob", {}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_no_cwd_denies_absolute():
|
||||
"""If no sdk_cwd is set, absolute paths are denied."""
|
||||
result = _validate_tool_access("Read", {"file_path": "/tmp/anything"})
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- Tool-results directory --------------------------------------------------
|
||||
|
||||
|
||||
def test_read_tool_results_allowed():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/tool-results/12345.txt"
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_read_claude_projects_without_tool_results_denied():
|
||||
home = os.path.expanduser("~")
|
||||
path = f"{home}/.claude/projects/-tmp-copilot-abc123/settings.json"
|
||||
result = _validate_tool_access("Read", {"file_path": path}, sdk_cwd=SDK_CWD)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- Built-in Bash is blocked (use bash_exec MCP tool instead) ---------------
|
||||
|
||||
|
||||
def test_bash_builtin_always_blocked():
|
||||
"""SDK built-in Bash is blocked — bash_exec MCP tool with bubblewrap is used instead."""
|
||||
result = _validate_tool_access("Bash", {"command": "echo hello"}, sdk_cwd=SDK_CWD)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- Dangerous patterns ------------------------------------------------------
|
||||
|
||||
|
||||
def test_dangerous_pattern_blocked():
|
||||
result = _validate_tool_access("SomeTool", {"cmd": "sudo rm -rf /"})
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_subprocess_pattern_blocked():
|
||||
result = _validate_tool_access("SomeTool", {"code": "subprocess.run(...)"})
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
# -- User isolation ----------------------------------------------------------
|
||||
|
||||
|
||||
def test_workspace_path_traversal_blocked():
|
||||
result = _validate_user_isolation(
|
||||
"workspace_read", {"path": "../../../etc/shadow"}, user_id="user-1"
|
||||
)
|
||||
assert _is_denied(result)
|
||||
|
||||
|
||||
def test_workspace_absolute_path_allowed():
|
||||
"""Workspace 'path' is a cloud storage key — leading '/' is normal."""
|
||||
result = _validate_user_isolation(
|
||||
"workspace_read", {"path": "/ASEAN/report.md"}, user_id="user-1"
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_workspace_normal_path_allowed():
|
||||
result = _validate_user_isolation(
|
||||
"workspace_read", {"path": "src/main.py"}, user_id="user-1"
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_non_workspace_tool_passes_isolation():
|
||||
result = _validate_user_isolation(
|
||||
"find_agent", {"query": "email"}, user_id="user-1"
|
||||
)
|
||||
assert result == {}
|
||||
|
||||
|
||||
# -- Deny message quality ----------------------------------------------------
|
||||
|
||||
|
||||
def test_blocked_tool_message_clarity():
|
||||
"""Deny messages must include [SECURITY] and 'cannot be bypassed'."""
|
||||
reason = _reason(_validate_tool_access("bash", {}))
|
||||
assert "[SECURITY]" in reason
|
||||
assert "cannot be bypassed" in reason
|
||||
|
||||
|
||||
def test_bash_builtin_blocked_message_clarity():
|
||||
reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"}))
|
||||
assert "[SECURITY]" in reason
|
||||
assert "cannot be bypassed" in reason
|
||||
|
||||
|
||||
# -- Task sub-agent hooks (require SDK) --------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _hooks():
|
||||
"""Create security hooks and return the PreToolUse handler."""
|
||||
from .security_hooks import create_security_hooks
|
||||
|
||||
hooks = create_security_hooks(user_id="u1", sdk_cwd=SDK_CWD, max_subtasks=2)
|
||||
pre = hooks["PreToolUse"][0].hooks[0]
|
||||
return pre
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_background_blocked(_hooks):
|
||||
"""Task with run_in_background=true must be denied."""
|
||||
result = await _hooks(
|
||||
{"tool_name": "Task", "tool_input": {"run_in_background": True, "prompt": "x"}},
|
||||
tool_use_id=None,
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
assert "foreground" in _reason(result).lower()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_foreground_allowed(_hooks):
|
||||
"""Task without run_in_background should be allowed."""
|
||||
result = await _hooks(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "do stuff"}},
|
||||
tool_use_id=None,
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _sdk_available(), reason="claude_agent_sdk not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_limit_enforced(_hooks):
|
||||
"""Task spawns beyond max_subtasks should be denied."""
|
||||
# First two should pass
|
||||
for _ in range(2):
|
||||
result = await _hooks(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "ok"}},
|
||||
tool_use_id=None,
|
||||
context={},
|
||||
)
|
||||
assert not _is_denied(result)
|
||||
|
||||
# Third should be denied (limit=2)
|
||||
result = await _hooks(
|
||||
{"tool_name": "Task", "tool_input": {"prompt": "over limit"}},
|
||||
tool_use_id=None,
|
||||
context={},
|
||||
)
|
||||
assert _is_denied(result)
|
||||
assert "Maximum" in _reason(result)
|
||||
|
||||
|
||||
# -- _is_tool_error_or_denial ------------------------------------------------
|
||||
|
||||
|
||||
class TestIsToolErrorOrDenial:
|
||||
def test_none_content(self):
|
||||
assert _is_tool_error_or_denial(None) is False
|
||||
|
||||
def test_empty_content(self):
|
||||
assert _is_tool_error_or_denial("") is False
|
||||
|
||||
def test_benign_output(self):
|
||||
assert _is_tool_error_or_denial("All good, no issues.") is False
|
||||
|
||||
def test_security_marker(self):
|
||||
assert _is_tool_error_or_denial("[SECURITY] Tool access blocked") is True
|
||||
|
||||
def test_cannot_be_bypassed(self):
|
||||
assert _is_tool_error_or_denial("This restriction cannot be bypassed.") is True
|
||||
|
||||
def test_not_allowed(self):
|
||||
assert _is_tool_error_or_denial("Operation not allowed in sandbox") is True
|
||||
|
||||
def test_background_task_denial(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
"Background task execution is not supported. "
|
||||
"Run tasks in the foreground instead."
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_subtask_limit_denial(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
"Maximum 2 sub-tasks per session. Please continue in the main conversation."
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_denied_marker(self):
|
||||
assert (
|
||||
_is_tool_error_or_denial("Access denied: insufficient privileges") is True
|
||||
)
|
||||
|
||||
def test_blocked_marker(self):
|
||||
assert _is_tool_error_or_denial("Request blocked by security policy") is True
|
||||
|
||||
def test_failed_marker(self):
|
||||
assert _is_tool_error_or_denial("Failed to execute tool: timeout") is True
|
||||
|
||||
def test_mcp_iserror(self):
|
||||
assert _is_tool_error_or_denial('{"isError": true, "content": []}') is True
|
||||
|
||||
def test_benign_error_in_value(self):
|
||||
"""Content like '0 errors found' should not trigger — 'error' was removed."""
|
||||
assert _is_tool_error_or_denial("0 errors found") is False
|
||||
|
||||
def test_benign_permission_field(self):
|
||||
"""Schema descriptions mentioning 'permission' should not trigger."""
|
||||
assert (
|
||||
_is_tool_error_or_denial(
|
||||
'{"fields": [{"name": "permission_level", "type": "int"}]}'
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_benign_not_found_in_listing(self):
|
||||
"""File listing containing 'not found' in filenames should not trigger."""
|
||||
assert _is_tool_error_or_denial("readme.md\nfile-not-found-handler.py") is False
|
||||
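
# Illustrative sketch: the hook-result shape that _is_denied and _reason parse.
# The key names come from the helpers above; the reason text here is made up.
_deny_example = {
    "hookSpecificOutput": {
        "permissionDecision": "deny",
        "permissionDecisionReason": "[SECURITY] Tool access blocked: cannot be bypassed",
    }
}
assert _deny_example.get("hookSpecificOutput", {}).get("permissionDecision") == "deny"
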
File diff suppressed because it is too large
@@ -1,142 +0,0 @@
|
||||
"""Tests for CreateAgentTool response types."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools.create_agent import CreateAgentTool
|
||||
from backend.copilot.tools.models import (
|
||||
ClarificationNeededResponse,
|
||||
ErrorResponse,
|
||||
SuggestedGoalResponse,
|
||||
)
|
||||
|
||||
from ._test_data import make_session
|
||||
|
||||
_TEST_USER_ID = "test-user-create-agent"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool():
|
||||
return CreateAgentTool()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
return make_session(_TEST_USER_ID)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_description_returns_error(tool, session):
|
||||
"""Missing description returns ErrorResponse."""
|
||||
result = await tool._execute(user_id=_TEST_USER_ID, session=session, description="")
|
||||
assert isinstance(result, ErrorResponse)
|
||||
assert result.error == "Missing description parameter"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vague_goal_returns_suggested_goal_response(tool, session):
|
||||
"""vague_goal decomposition result returns SuggestedGoalResponse, not ErrorResponse."""
|
||||
vague_result = {
|
||||
"type": "vague_goal",
|
||||
"suggested_goal": "Monitor Twitter mentions for a specific keyword and send a daily digest email",
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
return_value=vague_result,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
description="monitor social media",
|
||||
)
|
||||
|
||||
assert isinstance(result, SuggestedGoalResponse)
|
||||
assert result.goal_type == "vague"
|
||||
assert result.suggested_goal == vague_result["suggested_goal"]
|
||||
assert result.original_goal == "monitor social media"
|
||||
assert result.reason == "The goal needs more specific details"
|
||||
assert not isinstance(result, ErrorResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unachievable_goal_returns_suggested_goal_response(tool, session):
|
||||
"""unachievable_goal decomposition result returns SuggestedGoalResponse, not ErrorResponse."""
|
||||
unachievable_result = {
|
||||
"type": "unachievable_goal",
|
||||
"suggested_goal": "Summarize the latest news articles on a topic and send them by email",
|
||||
"reason": "There are no blocks for mind-reading.",
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
return_value=unachievable_result,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
description="read my mind",
|
||||
)
|
||||
|
||||
assert isinstance(result, SuggestedGoalResponse)
|
||||
assert result.goal_type == "unachievable"
|
||||
assert result.suggested_goal == unachievable_result["suggested_goal"]
|
||||
assert result.original_goal == "read my mind"
|
||||
assert result.reason == unachievable_result["reason"]
|
||||
assert not isinstance(result, ErrorResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clarifying_questions_returns_clarification_needed_response(
|
||||
tool, session
|
||||
):
|
||||
"""clarifying_questions decomposition result returns ClarificationNeededResponse."""
|
||||
clarifying_result = {
|
||||
"type": "clarifying_questions",
|
||||
"questions": [
|
||||
{
|
||||
"question": "What platform should be monitored?",
|
||||
"keyword": "platform",
|
||||
"example": "Twitter, Reddit",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.tools.create_agent.decompose_goal",
|
||||
new_callable=AsyncMock,
|
||||
return_value=clarifying_result,
|
||||
),
|
||||
):
|
||||
result = await tool._execute(
|
||||
user_id=_TEST_USER_ID,
|
||||
session=session,
|
||||
description="monitor social media and alert me",
|
||||
)
|
||||
|
||||
assert isinstance(result, ClarificationNeededResponse)
|
||||
assert len(result.questions) == 1
|
||||
assert result.questions[0].keyword == "platform"
|
||||
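
# Illustrative sketch: the AsyncMock pattern used above, reduced to a
# standalone example that runs without any backend imports.
import asyncio as _asyncio
from unittest.mock import AsyncMock as _AsyncMock

# _decompose_stub is a hypothetical stand-in for the patched decompose_goal.
_decompose_stub = _AsyncMock(return_value={"type": "vague_goal"})
_result = _asyncio.run(_decompose_stub("monitor social media"))
assert _result["type"] == "vague_goal"
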
@@ -1,789 +0,0 @@
|
||||
"""CoPilot tools for workspace file operations."""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.model import ChatSession
|
||||
from backend.copilot.tools.sandbox import make_session_path
|
||||
from backend.data.db_accessors import workspace_db
|
||||
from backend.util.settings import Config
|
||||
from backend.util.virus_scanner import scan_content_safe
|
||||
from backend.util.workspace import WorkspaceManager
|
||||
|
||||
from .base import BaseTool
|
||||
from .models import ErrorResponse, ResponseType, ToolResponseBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_write_content(
|
||||
content_text: str | None,
|
||||
content_b64: str | None,
|
||||
source_path: str | None,
|
||||
session_id: str,
|
||||
) -> bytes | ErrorResponse:
|
||||
"""Resolve file content from exactly one of three input sources.
|
||||
|
||||
Returns the raw bytes on success, or an ``ErrorResponse`` on validation
|
||||
failure (wrong number of sources, invalid path, file not found, etc.).
|
||||
"""
|
||||
# Normalise empty strings to None so counting and dispatch stay in sync.
|
||||
if content_text is not None and content_text == "":
|
||||
content_text = None
|
||||
if content_b64 is not None and content_b64 == "":
|
||||
content_b64 = None
|
||||
if source_path is not None and source_path == "":
|
||||
source_path = None
|
||||
|
||||
sources_provided = sum(
|
||||
x is not None for x in [content_text, content_b64, source_path]
|
||||
)
|
||||
if sources_provided == 0:
|
||||
return ErrorResponse(
|
||||
message="Please provide one of: content, content_base64, or source_path",
|
||||
session_id=session_id,
|
||||
)
|
||||
if sources_provided > 1:
|
||||
return ErrorResponse(
|
||||
message="Provide only one of: content, content_base64, or source_path",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if source_path is not None:
|
||||
validated = _validate_ephemeral_path(
|
||||
source_path, param_name="source_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated, ErrorResponse):
|
||||
return validated
|
||||
try:
|
||||
with open(validated, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=f"Source file not found: {source_path}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read source file: {e}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if content_b64 is not None:
|
||||
try:
|
||||
return base64.b64decode(content_b64)
|
||||
except Exception:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Invalid base64 encoding in content_base64. "
|
||||
"Please encode the file content with standard base64, "
|
||||
"or use the 'content' parameter for plain text, "
|
||||
"or 'source_path' to copy from the working directory."
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
assert content_text is not None
|
||||
return content_text.encode("utf-8")
|
||||
|
||||
|
||||
def _validate_ephemeral_path(
|
||||
path: str, *, param_name: str, session_id: str
|
||||
) -> ErrorResponse | str:
|
||||
"""Validate that *path* is inside the session's ephemeral directory.
|
||||
|
||||
Uses the session-specific directory (``make_session_path(session_id)``)
|
||||
rather than the bare prefix, so ``/tmp/copilot-evil/...`` is rejected.
|
||||
|
||||
Returns the resolved real path on success, or an ``ErrorResponse`` when the
|
||||
path escapes the session directory.
|
||||
"""
|
||||
session_dir = os.path.realpath(make_session_path(session_id)) + os.sep
|
||||
real = os.path.realpath(path)
|
||||
if not real.startswith(session_dir):
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"{param_name} must be within the ephemeral working "
|
||||
f"directory ({make_session_path(session_id)})"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
return real
|
||||
|
||||
|
||||
_TEXT_MIME_PREFIXES = (
|
||||
"text/",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/javascript",
|
||||
"application/x-python",
|
||||
"application/x-sh",
|
||||
)
|
||||
|
||||
_IMAGE_MIME_TYPES = {"image/png", "image/jpeg", "image/gif", "image/webp"}
|
||||
|
||||
|
||||
def _is_text_mime(mime_type: str) -> bool:
|
||||
return any(mime_type.startswith(t) for t in _TEXT_MIME_PREFIXES)
|
||||
|
||||
|
||||
async def _get_manager(user_id: str, session_id: str) -> WorkspaceManager:
|
||||
"""Create a session-scoped WorkspaceManager."""
|
||||
workspace = await workspace_db().get_or_create_workspace(user_id)
|
||||
return WorkspaceManager(user_id, workspace.id, session_id)
|
||||
|
||||
|
||||
async def _resolve_file(
|
||||
manager: WorkspaceManager,
|
||||
file_id: str | None,
|
||||
path: str | None,
|
||||
session_id: str,
|
||||
) -> tuple[str, Any] | ErrorResponse:
|
||||
"""Resolve a file by file_id or path.
|
||||
|
||||
Returns ``(target_file_id, file_info)`` on success, or an
|
||||
``ErrorResponse`` if the file was not found.
|
||||
"""
|
||||
if file_id:
|
||||
file_info = await manager.get_file_info(file_id)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {file_id}", session_id=session_id
|
||||
)
|
||||
return file_id, file_info
|
||||
|
||||
assert path is not None
|
||||
file_info = await manager.get_file_info_by_path(path)
|
||||
if file_info is None:
|
||||
return ErrorResponse(
|
||||
message=f"File not found at path: {path}", session_id=session_id
|
||||
)
|
||||
return file_info.id, file_info
|
||||
|
||||
|
||||
class WorkspaceFileInfoData(BaseModel):
|
||||
"""Data model for workspace file information (not a response itself)."""
|
||||
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
|
||||
|
||||
class WorkspaceFileListResponse(ToolResponseBase):
|
||||
"""Response containing list of workspace files."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_LIST
|
||||
files: list[WorkspaceFileInfoData]
|
||||
total_count: int
|
||||
|
||||
|
||||
class WorkspaceFileContentResponse(ToolResponseBase):
|
||||
"""Response containing workspace file content (legacy, for small text files)."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_CONTENT
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
content_base64: str
|
||||
|
||||
|
||||
class WorkspaceFileMetadataResponse(ToolResponseBase):
|
||||
"""Response containing workspace file metadata and download URL (prevents context bloat)."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_METADATA
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
mime_type: str
|
||||
size_bytes: int
|
||||
download_url: str
|
||||
preview: str | None = None # First 500 chars for text files
|
||||
|
||||
|
||||
class WorkspaceWriteResponse(ToolResponseBase):
|
||||
"""Response after writing a file to workspace."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_WRITTEN
|
||||
file_id: str
|
||||
name: str
|
||||
path: str
|
||||
size_bytes: int
|
||||
source: str | None = None # "content", "base64", or "copied from <path>"
|
||||
content_preview: str | None = None # First 200 chars for text files
|
||||
|
||||
|
||||
class WorkspaceDeleteResponse(ToolResponseBase):
|
||||
"""Response after deleting a file from workspace."""
|
||||
|
||||
type: ResponseType = ResponseType.WORKSPACE_FILE_DELETED
|
||||
file_id: str
|
||||
success: bool
|
||||
|
||||
|
||||
class ListWorkspaceFilesTool(BaseTool):
|
||||
"""Tool for listing files in user's workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "list_workspace_files"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"List files in the user's persistent workspace (cloud storage). "
|
||||
"These files survive across sessions. "
|
||||
"For ephemeral session files, use the SDK Read/Glob tools instead. "
|
||||
"Returns file names, paths, sizes, and metadata. "
|
||||
"Optionally filter by path prefix."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path_prefix": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional path prefix to filter files "
|
||||
"(e.g., '/documents/' to list only files in documents folder). "
|
||||
"By default, only files from the current session are listed."
|
||||
),
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of files to return (default 50, max 100)",
|
||||
"minimum": 1,
|
||||
"maximum": 100,
|
||||
},
|
||||
"include_all_sessions": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, list files from all sessions. "
|
||||
"Default is false (only current session's files)."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
path_prefix: Optional[str] = kwargs.get("path_prefix")
|
||||
limit = min(kwargs.get("limit", 50), 100)
|
||||
include_all_sessions: bool = kwargs.get("include_all_sessions", False)
|
||||
|
||||
try:
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
files = await manager.list_files(
|
||||
path=path_prefix, limit=limit, include_all_sessions=include_all_sessions
|
||||
)
|
||||
total = await manager.get_file_count(
|
||||
path=path_prefix, include_all_sessions=include_all_sessions
|
||||
)
|
||||
file_infos = [
|
||||
WorkspaceFileInfoData(
|
||||
file_id=f.id,
|
||||
name=f.name,
|
||||
path=f.path,
|
||||
mime_type=f.mime_type,
|
||||
size_bytes=f.size_bytes,
|
||||
)
|
||||
for f in files
|
||||
]
|
||||
scope = "all sessions" if include_all_sessions else "current session"
|
||||
total_size = sum(f.size_bytes for f in file_infos)
|
||||
|
||||
# Build a human-readable summary so the agent can relay details.
|
||||
lines = [f"Found {len(files)} file(s) in workspace ({scope}):"]
|
||||
for f in file_infos:
|
||||
lines.append(f" - {f.path} ({f.size_bytes:,} bytes, {f.mime_type})")
|
||||
if total > len(files):
|
||||
lines.append(f" ... and {total - len(files)} more")
|
||||
lines.append(f"Total size: {total_size:,} bytes")
|
||||
|
||||
return WorkspaceFileListResponse(
|
||||
files=file_infos,
|
||||
total_count=total,
|
||||
message="\n".join(lines),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing workspace files: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to list workspace files: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class ReadWorkspaceFileTool(BaseTool):
|
||||
"""Tool for reading file content from workspace."""
|
||||
|
||||
MAX_INLINE_SIZE_BYTES = 32 * 1024 # 32KB
|
||||
PREVIEW_SIZE = 500
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "read_workspace_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Read a file from the user's persistent workspace (cloud storage). "
|
||||
"These files survive across sessions. "
|
||||
"For ephemeral session files, use the SDK Read tool instead. "
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"For small text files, returns content directly. "
|
||||
"For large or binary files, returns metadata and a download URL. "
|
||||
"Optionally use 'save_to_path' to copy the file to the ephemeral "
|
||||
"working directory for processing with bash_exec or SDK tools. "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_id": {
|
||||
"type": "string",
|
||||
"description": "The file's unique ID (from list_workspace_files)",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||
"Scoped to current session by default."
|
||||
),
|
||||
},
|
||||
"save_to_path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"If provided, save the file to this path in the ephemeral "
|
||||
"working directory (e.g., '/tmp/copilot-.../data.csv') "
|
||||
"so it can be processed with bash_exec or SDK tools. "
|
||||
"The file content is still returned in the response."
|
||||
),
|
||||
},
|
||||
"force_download_url": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"If true, always return metadata+URL instead of inline content. "
|
||||
"Default is false (auto-selects based on file size/type)."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [], # At least one must be provided
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
file_id: Optional[str] = kwargs.get("file_id")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
save_to_path: Optional[str] = kwargs.get("save_to_path")
|
||||
force_download_url: bool = kwargs.get("force_download_url", False)
|
||||
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path", session_id=session_id
|
||||
)
|
||||
|
||||
# Validate and resolve save_to_path (use sanitized real path).
|
||||
if save_to_path:
|
||||
validated_save = _validate_ephemeral_path(
|
||||
save_to_path, param_name="save_to_path", session_id=session_id
|
||||
)
|
||||
if isinstance(validated_save, ErrorResponse):
|
||||
return validated_save
|
||||
save_to_path = validated_save
|
||||
|
||||
try:
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
target_file_id, file_info = resolved
|
||||
|
||||
# If save_to_path, read + save; cache bytes for possible inline reuse.
|
||||
cached_content: bytes | None = None
|
||||
if save_to_path:
|
||||
cached_content = await manager.read_file_by_id(target_file_id)
|
||||
dir_path = os.path.dirname(save_to_path)
|
||||
if dir_path:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
with open(save_to_path, "wb") as f:
|
||||
f.write(cached_content)
|
||||
|
||||
is_small = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES
|
||||
is_text = _is_text_mime(file_info.mime_type)
|
||||
is_image = file_info.mime_type in _IMAGE_MIME_TYPES
|
||||
|
||||
# Inline content for small text/image files
|
||||
if is_small and (is_text or is_image) and not force_download_url:
|
||||
content = cached_content or await manager.read_file_by_id(
|
||||
target_file_id
|
||||
)
|
||||
msg = (
|
||||
f"Read {file_info.name} from workspace:{file_info.path} "
|
||||
f"({file_info.size_bytes:,} bytes, {file_info.mime_type})"
|
||||
)
|
||||
if save_to_path:
|
||||
msg += f" — also saved to {save_to_path}"
|
||||
return WorkspaceFileContentResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type=file_info.mime_type,
|
||||
content_base64=base64.b64encode(content).decode("utf-8"),
|
||||
message=msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Metadata + download URL for large/binary files
|
||||
preview: str | None = None
|
||||
if is_text:
|
||||
try:
|
||||
raw = cached_content or await manager.read_file_by_id(
|
||||
target_file_id
|
||||
)
|
||||
preview = raw[: self.PREVIEW_SIZE].decode("utf-8", errors="replace")
|
||||
if len(raw) > self.PREVIEW_SIZE:
|
||||
preview += "..."
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
msg = (
|
||||
f"File: {file_info.name} at workspace:{file_info.path} "
|
||||
f"({file_info.size_bytes:,} bytes, {file_info.mime_type})"
|
||||
)
|
||||
if save_to_path:
|
||||
msg += f" — saved to {save_to_path}"
|
||||
else:
|
||||
msg += (
|
||||
" — use read_workspace_file with this file_id to retrieve content"
|
||||
)
|
||||
return WorkspaceFileMetadataResponse(
|
||||
file_id=file_info.id,
|
||||
name=file_info.name,
|
||||
path=file_info.path,
|
||||
mime_type=file_info.mime_type,
|
||||
size_bytes=file_info.size_bytes,
|
||||
download_url=f"workspace://{target_file_id}",
|
||||
preview=preview,
|
||||
message=msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
except FileNotFoundError as e:
|
||||
return ErrorResponse(message=str(e), session_id=session_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to read workspace file: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class WriteWorkspaceFileTool(BaseTool):
|
||||
"""Tool for writing files to workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "write_workspace_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Write or create a file in the user's persistent workspace (cloud storage). "
|
||||
"These files survive across sessions. "
|
||||
"For ephemeral session files, use the SDK Write tool instead. "
|
||||
"Provide content as plain text via 'content', OR base64-encoded via "
|
||||
"'content_base64', OR copy a file from the ephemeral working directory "
|
||||
"via 'source_path'. Exactly one of these three is required. "
|
||||
f"Maximum file size is {Config().max_file_size_mb}MB. "
|
||||
"Files are saved to the current session's folder by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "Name for the file (e.g., 'report.pdf')",
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Plain text content to write. Use this for text files "
|
||||
"(code, configs, documents, etc.). "
|
||||
"Mutually exclusive with content_base64 and source_path."
|
||||
),
|
||||
},
|
||||
"content_base64": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Base64-encoded file content. Use this for binary files "
|
||||
"(images, PDFs, etc.). "
|
||||
"Mutually exclusive with content and source_path."
|
||||
),
|
||||
},
|
||||
"source_path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Path to a file in the ephemeral working directory to "
|
||||
"copy to workspace (e.g., '/tmp/copilot-.../output.csv'). "
|
||||
"Use this to persist files created by bash_exec or SDK Write. "
|
||||
"Mutually exclusive with content and content_base64."
|
||||
),
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional virtual path where to save the file "
|
||||
"(e.g., '/documents/report.pdf'). "
|
||||
"Defaults to '/{filename}'. Scoped to current session."
|
||||
),
|
||||
},
|
||||
"mime_type": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional MIME type of the file. "
|
||||
"Auto-detected from filename if not provided."
|
||||
),
|
||||
},
|
||||
"overwrite": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to overwrite if file exists at path (default: false)",
|
||||
},
|
||||
},
|
||||
"required": ["filename"],
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
filename: str = kwargs.get("filename", "")
|
||||
if not filename:
|
||||
return ErrorResponse(
|
||||
message="Please provide a filename", session_id=session_id
|
||||
)
|
||||
|
||||
source_path_arg: str | None = kwargs.get("source_path")
|
||||
content_text: str | None = kwargs.get("content")
|
||||
content_b64: str | None = kwargs.get("content_base64")
|
||||
|
||||
resolved = _resolve_write_content(
|
||||
content_text,
|
||||
content_b64,
|
||||
source_path_arg,
|
||||
session_id,
|
||||
)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
content: bytes = resolved
|
||||
|
||||
max_size = Config().max_file_size_mb * 1024 * 1024
|
||||
if len(content) > max_size:
|
||||
return ErrorResponse(
|
||||
message=f"File too large. Maximum size is {Config().max_file_size_mb}MB",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
await scan_content_safe(content, filename=filename)
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
rec = await manager.write_file(
|
||||
content=content,
|
||||
filename=filename,
|
||||
path=kwargs.get("path"),
|
||||
mime_type=kwargs.get("mime_type"),
|
||||
overwrite=kwargs.get("overwrite", False),
|
||||
)
|
||||
|
||||
# Build informative source label and message.
|
||||
if source_path_arg:
|
||||
source = f"copied from {source_path_arg}"
|
||||
msg = (
|
||||
f"Copied {source_path_arg} → workspace:{rec.path} "
|
||||
f"({rec.size_bytes:,} bytes)"
|
||||
)
|
||||
elif content_b64:
|
||||
source = "base64"
|
||||
msg = (
|
||||
f"Wrote {rec.name} to workspace ({rec.size_bytes:,} bytes, "
|
||||
f"decoded from base64)"
|
||||
)
|
||||
else:
|
||||
source = "content"
|
||||
msg = f"Wrote {rec.name} to workspace ({rec.size_bytes:,} bytes)"
|
||||
|
||||
# Include a short preview for text content.
|
||||
preview: str | None = None
|
||||
if _is_text_mime(rec.mime_type):
|
||||
try:
|
||||
preview = content[:200].decode("utf-8", errors="replace")
|
||||
if len(content) > 200:
|
||||
preview += "..."
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return WorkspaceWriteResponse(
|
||||
file_id=rec.id,
|
||||
name=rec.name,
|
||||
path=rec.path,
|
||||
size_bytes=rec.size_bytes,
|
||||
source=source,
|
||||
content_preview=preview,
|
||||
message=msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
return ErrorResponse(message=str(e), session_id=session_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to write workspace file: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
|
||||
class DeleteWorkspaceFileTool(BaseTool):
|
||||
"""Tool for deleting files from workspace."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "delete_workspace_file"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Delete a file from the user's persistent workspace (cloud storage). "
|
||||
"Specify either file_id or path to identify the file. "
|
||||
"Paths are scoped to the current session by default. "
|
||||
"Use /sessions/<session_id>/... for cross-session access."
|
||||
)
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_id": {
|
||||
"type": "string",
|
||||
"description": "The file's unique ID (from list_workspace_files)",
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The virtual file path (e.g., '/documents/report.pdf'). "
|
||||
"Scoped to current session by default."
|
||||
),
|
||||
},
|
||||
},
|
||||
"required": [], # At least one must be provided
|
||||
}
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
session_id = session.session_id
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="Authentication required", session_id=session_id
|
||||
)
|
||||
|
||||
file_id: Optional[str] = kwargs.get("file_id")
|
||||
path: Optional[str] = kwargs.get("path")
|
||||
if not file_id and not path:
|
||||
return ErrorResponse(
|
||||
message="Please provide either file_id or path", session_id=session_id
|
||||
)
|
||||
|
||||
try:
|
||||
manager = await _get_manager(user_id, session_id)
|
||||
resolved = await _resolve_file(manager, file_id, path, session_id)
|
||||
if isinstance(resolved, ErrorResponse):
|
||||
return resolved
|
||||
target_file_id, file_info = resolved
|
||||
|
||||
if not await manager.delete_file(target_file_id):
|
||||
return ErrorResponse(
|
||||
message=f"File not found: {target_file_id}", session_id=session_id
|
||||
)
|
||||
return WorkspaceDeleteResponse(
|
||||
file_id=target_file_id,
|
||||
success=True,
|
||||
message=(
|
||||
f"Deleted {file_info.name} from workspace:{file_info.path} "
|
||||
f"({file_info.size_bytes:,} bytes)"
|
||||
),
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting workspace file: {e}", exc_info=True)
|
||||
return ErrorResponse(
|
||||
message=f"Failed to delete workspace file: {e}",
|
||||
error=str(e),
|
||||
session_id=session_id,
|
||||
)
|
||||
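
# Illustrative sketch: the realpath-plus-prefix containment check used by
# _validate_ephemeral_path above, as a standalone predicate. _is_inside_demo
# is a hypothetical helper for illustration, not part of the module.
import os as _os

def _is_inside_demo(path: str, root: str) -> bool:
    # realpath resolves symlinks and ".." segments before the prefix check,
    # so traversal and symlink escapes both fail the startswith test.
    real_root = _os.path.realpath(root) + _os.sep
    return _os.path.realpath(path).startswith(real_root)

assert _is_inside_demo("/tmp/copilot-s1/a.txt", "/tmp/copilot-s1")
assert not _is_inside_demo("/tmp/copilot-s1/../../etc/passwd", "/tmp/copilot-s1")
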
@@ -1,267 +0,0 @@
|
||||
"""Tests for workspace file tool helpers and path validation."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.tools._test_data import make_session, setup_test_data
|
||||
from backend.copilot.tools.workspace_files import (
|
||||
DeleteWorkspaceFileTool,
|
||||
ListWorkspaceFilesTool,
|
||||
ReadWorkspaceFileTool,
|
||||
WorkspaceDeleteResponse,
|
||||
WorkspaceFileListResponse,
|
||||
WorkspaceWriteResponse,
|
||||
WriteWorkspaceFileTool,
|
||||
_resolve_write_content,
|
||||
_validate_ephemeral_path,
|
||||
)
|
||||
|
||||
# Re-export so pytest discovers the session-scoped fixture
|
||||
setup_test_data = setup_test_data
|
||||
|
||||
# We need to mock make_session_path to return a known temp dir for tests.
|
||||
# The real one uses WORKSPACE_PREFIX = "/tmp/copilot-"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ephemeral_dir(tmp_path, monkeypatch):
|
||||
"""Create a temp dir that acts as the ephemeral session directory."""
|
||||
session_dir = tmp_path / "copilot-test-session"
|
||||
session_dir.mkdir()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.tools.workspace_files.make_session_path",
|
||||
lambda session_id: str(session_dir),
|
||||
)
|
||||
return session_dir
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_ephemeral_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateEphemeralPath:
|
||||
def test_valid_path(self, ephemeral_dir):
|
||||
target = ephemeral_dir / "file.txt"
|
||||
target.touch()
|
||||
result = _validate_ephemeral_path(
|
||||
str(target), param_name="test", session_id="s1"
|
||||
)
|
||||
assert isinstance(result, str)
|
||||
assert result == os.path.realpath(str(target))
|
||||
|
||||
def test_path_traversal_rejected(self, ephemeral_dir):
|
||||
evil_path = str(ephemeral_dir / ".." / "etc" / "passwd")
|
||||
result = _validate_ephemeral_path(evil_path, param_name="test", session_id="s1")
|
||||
# Should return ErrorResponse
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_different_session_rejected(self, ephemeral_dir, tmp_path):
|
||||
other_dir = tmp_path / "copilot-evil-session"
|
||||
other_dir.mkdir()
|
||||
target = other_dir / "steal.txt"
|
||||
target.touch()
|
||||
result = _validate_ephemeral_path(
|
||||
str(target), param_name="test", session_id="s1"
|
||||
)
|
||||
from backend.copilot.tools.models import ErrorResponse
|
||||
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
def test_symlink_escape_rejected(self, ephemeral_dir, tmp_path):
|
||||
"""Symlink inside session dir pointing outside should be rejected."""
|
||||
outside_file = tmp_path / "secret.txt"
        outside_file.write_text("secret")
        symlink = ephemeral_dir / "link.txt"
        symlink.symlink_to(outside_file)
        result = _validate_ephemeral_path(
            str(symlink), param_name="test", session_id="s1"
        )
        from backend.copilot.tools.models import ErrorResponse

        assert isinstance(result, ErrorResponse)

    def test_nested_path_valid(self, ephemeral_dir):
        nested = ephemeral_dir / "subdir" / "deep"
        nested.mkdir(parents=True)
        target = nested / "data.csv"
        target.touch()
        result = _validate_ephemeral_path(
            str(target), param_name="test", session_id="s1"
        )
        assert isinstance(result, str)


# ---------------------------------------------------------------------------
# _resolve_write_content
# ---------------------------------------------------------------------------


class TestResolveWriteContent:
    def test_no_sources_returns_error(self):
        from backend.copilot.tools.models import ErrorResponse

        result = _resolve_write_content(None, None, None, "s1")
        assert isinstance(result, ErrorResponse)

    def test_multiple_sources_returns_error(self):
        from backend.copilot.tools.models import ErrorResponse

        result = _resolve_write_content("text", "b64data", None, "s1")
        assert isinstance(result, ErrorResponse)

    def test_plain_text_content(self):
        result = _resolve_write_content("hello world", None, None, "s1")
        assert result == b"hello world"

    def test_base64_content(self):
        raw = b"binary data"
        b64 = base64.b64encode(raw).decode()
        result = _resolve_write_content(None, b64, None, "s1")
        assert result == raw

    def test_invalid_base64_returns_error(self):
        from backend.copilot.tools.models import ErrorResponse

        result = _resolve_write_content(None, "not-valid-b64!!!", None, "s1")
        assert isinstance(result, ErrorResponse)
        assert "base64" in result.message.lower()

    def test_source_path(self, ephemeral_dir):
        target = ephemeral_dir / "input.txt"
        target.write_bytes(b"file content")
        result = _resolve_write_content(None, None, str(target), "s1")
        assert result == b"file content"

    def test_source_path_not_found(self, ephemeral_dir):
        from backend.copilot.tools.models import ErrorResponse

        missing = str(ephemeral_dir / "nope.txt")
        result = _resolve_write_content(None, None, missing, "s1")
        assert isinstance(result, ErrorResponse)

    def test_source_path_outside_ephemeral(self, ephemeral_dir, tmp_path):
        from backend.copilot.tools.models import ErrorResponse

        outside = tmp_path / "outside.txt"
        outside.write_text("nope")
        result = _resolve_write_content(None, None, str(outside), "s1")
        assert isinstance(result, ErrorResponse)

    def test_empty_string_sources_treated_as_none(self):
        from backend.copilot.tools.models import ErrorResponse

        # All empty strings → same as no sources
        result = _resolve_write_content("", "", "", "s1")
        assert isinstance(result, ErrorResponse)

    def test_empty_string_source_path_with_text(self):
        # source_path="" should be normalised to None, so only content counts
        result = _resolve_write_content("hello", "", "", "s1")
        assert result == b"hello"


# ---------------------------------------------------------------------------
# E2E: workspace file tool round-trip (write → list → read → delete)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio(loop_scope="session")
async def test_workspace_file_round_trip(setup_test_data):
    """E2E: write a file, list it, read it back (with save_to_path), then delete it."""
    user = setup_test_data["user"]
    session = make_session(user.id)
    session_id = session.session_id

    # ---- Write ----
    write_tool = WriteWorkspaceFileTool()
    write_resp = await write_tool._execute(
        user_id=user.id,
        session=session,
        filename="test_round_trip.txt",
        content="Hello from e2e test!",
    )
    assert isinstance(write_resp, WorkspaceWriteResponse), write_resp.message
    file_id = write_resp.file_id

    # ---- List ----
    list_tool = ListWorkspaceFilesTool()
    list_resp = await list_tool._execute(user_id=user.id, session=session)
    assert isinstance(list_resp, WorkspaceFileListResponse), list_resp.message
    assert any(f.file_id == file_id for f in list_resp.files)

    # ---- Read (inline) ----
    read_tool = ReadWorkspaceFileTool()
    read_resp = await read_tool._execute(
        user_id=user.id, session=session, file_id=file_id
    )
    from backend.copilot.tools.workspace_files import WorkspaceFileContentResponse

    assert isinstance(read_resp, WorkspaceFileContentResponse), read_resp.message
    decoded = base64.b64decode(read_resp.content_base64).decode()
    assert decoded == "Hello from e2e test!"

    # ---- Read with save_to_path ----
    from backend.copilot.tools.sandbox import make_session_path

    ephemeral_dir = make_session_path(session_id)
    os.makedirs(ephemeral_dir, exist_ok=True)
    save_path = os.path.join(ephemeral_dir, "saved_copy.txt")

    read_resp2 = await read_tool._execute(
        user_id=user.id, session=session, file_id=file_id, save_to_path=save_path
    )
    assert read_resp2 is not None
    assert os.path.exists(save_path)
    with open(save_path) as f:
        assert f.read() == "Hello from e2e test!"

    # ---- Delete ----
    delete_tool = DeleteWorkspaceFileTool()
    del_resp = await delete_tool._execute(
        user_id=user.id, session=session, file_id=file_id
    )
    assert isinstance(del_resp, WorkspaceDeleteResponse), del_resp.message
    assert del_resp.success is True

    # Verify file is gone
    list_resp2 = await list_tool._execute(user_id=user.id, session=session)
    assert isinstance(list_resp2, WorkspaceFileListResponse)
    assert not any(f.file_id == file_id for f in list_resp2.files)


@pytest.mark.asyncio(loop_scope="session")
async def test_write_workspace_file_source_path(setup_test_data):
    """E2E: write a file from ephemeral source_path to workspace."""
    user = setup_test_data["user"]
    session = make_session(user.id)
    session_id = session.session_id

    # Create a file in the ephemeral dir
    from backend.copilot.tools.sandbox import make_session_path

    ephemeral_dir = make_session_path(session_id)
    os.makedirs(ephemeral_dir, exist_ok=True)
    source = os.path.join(ephemeral_dir, "generated_output.csv")
    with open(source, "w") as f:
        f.write("col1,col2\n1,2\n")

    write_tool = WriteWorkspaceFileTool()
    write_resp = await write_tool._execute(
        user_id=user.id,
        session=session,
        filename="output.csv",
        source_path=source,
    )
    assert isinstance(write_resp, WorkspaceWriteResponse), write_resp.message

    # Clean up
    delete_tool = DeleteWorkspaceFileTool()
    await delete_tool._execute(
        user_id=user.id, session=session, file_id=write_resp.file_id
    )

@@ -19,6 +19,30 @@ CompletedBlockOutput = dict[str, list[Any]]  # Completed stream, collected as a


async def initialize_blocks() -> None:
    # Refresh LLM registry before initializing blocks so blocks can use registry data
    # This ensures the registry cache is populated even in executor context
    try:
        from backend.data import llm_registry
        from backend.data.block_cost_config import refresh_llm_costs

        # Only refresh if we have DB access (check if Prisma is connected)
        from backend.data.db import is_connected

        if is_connected():
            await llm_registry.refresh_llm_registry()
            await refresh_llm_costs()
            logger.info("LLM registry refreshed during block initialization")
        else:
            logger.warning(
                "Prisma not connected, skipping LLM registry refresh during block initialization"
            )
    except Exception as exc:
        logger.warning(
            "Failed to refresh LLM registry during block initialization: %s", exc
        )

    # First, sync all provider costs to blocks
    # Imported here to avoid circular import
    from backend.blocks import get_blocks
    from backend.sdk.cost_integration import sync_all_provider_costs
    from backend.util.retry import func_retry

@@ -1,5 +1,8 @@
import logging
from typing import Type

import prisma.models

from backend.blocks._base import Block, BlockCost, BlockCostType
from backend.blocks.ai_image_customizer import AIImageCustomizerBlock, GeminiImageModel
from backend.blocks.ai_image_generator_block import AIImageGeneratorBlock, ImageGenModel
@@ -24,13 +27,11 @@ from backend.blocks.ideogram import IdeogramModelBlock
from backend.blocks.jina.embeddings import JinaEmbeddingBlock
from backend.blocks.jina.search import ExtractWebsiteContentBlock, SearchTheWebBlock
from backend.blocks.llm import (
    MODEL_METADATA,
    AIConversationBlock,
    AIListGeneratorBlock,
    AIStructuredResponseGeneratorBlock,
    AITextGeneratorBlock,
    AITextSummarizerBlock,
    LlmModel,
)
from backend.blocks.replicate.flux_advanced import ReplicateFluxAdvancedModelBlock
from backend.blocks.replicate.replicate_block import ReplicateModelBlock
@@ -38,6 +39,7 @@ from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.blocks.video.narration import VideoNarrationBlock
from backend.data import llm_registry
from backend.integrations.credentials_store import (
    aiml_api_credentials,
    anthropic_credentials,
@@ -57,210 +59,116 @@ from backend.integrations.credentials_store import (
    v0_credentials,
)

# =============== Configure the cost for each LLM Model call =============== #
logger = logging.getLogger(__name__)

MODEL_COST: dict[LlmModel, int] = {
    LlmModel.O3: 4,
    LlmModel.O3_MINI: 2,
    LlmModel.O1: 16,
    LlmModel.O1_MINI: 4,
    # GPT-5 models
    LlmModel.GPT5_2: 6,
    LlmModel.GPT5_1: 5,
    LlmModel.GPT5: 2,
    LlmModel.GPT5_MINI: 1,
    LlmModel.GPT5_NANO: 1,
    LlmModel.GPT5_CHAT: 5,
    LlmModel.GPT41: 2,
    LlmModel.GPT41_MINI: 1,
    LlmModel.GPT4O_MINI: 1,
    LlmModel.GPT4O: 3,
    LlmModel.GPT4_TURBO: 10,
    LlmModel.GPT3_5_TURBO: 1,
    LlmModel.CLAUDE_4_1_OPUS: 21,
    LlmModel.CLAUDE_4_OPUS: 21,
    LlmModel.CLAUDE_4_SONNET: 5,
    LlmModel.CLAUDE_4_6_OPUS: 14,
    LlmModel.CLAUDE_4_5_HAIKU: 4,
    LlmModel.CLAUDE_4_5_OPUS: 14,
    LlmModel.CLAUDE_4_5_SONNET: 9,
    LlmModel.CLAUDE_3_HAIKU: 1,
    LlmModel.AIML_API_QWEN2_5_72B: 1,
    LlmModel.AIML_API_LLAMA3_1_70B: 1,
    LlmModel.AIML_API_LLAMA3_3_70B: 1,
    LlmModel.AIML_API_META_LLAMA_3_1_70B: 1,
    LlmModel.AIML_API_LLAMA_3_2_3B: 1,
    LlmModel.LLAMA3_3_70B: 1,
    LlmModel.LLAMA3_1_8B: 1,
    LlmModel.OLLAMA_LLAMA3_3: 1,
    LlmModel.OLLAMA_LLAMA3_2: 1,
    LlmModel.OLLAMA_LLAMA3_8B: 1,
    LlmModel.OLLAMA_LLAMA3_405B: 1,
    LlmModel.OLLAMA_DOLPHIN: 1,
    LlmModel.OPENAI_GPT_OSS_120B: 1,
    LlmModel.OPENAI_GPT_OSS_20B: 1,
    LlmModel.GEMINI_2_5_PRO: 4,
    LlmModel.GEMINI_3_PRO_PREVIEW: 5,
    LlmModel.GEMINI_2_5_FLASH: 1,
    LlmModel.GEMINI_2_0_FLASH: 1,
    LlmModel.GEMINI_2_5_FLASH_LITE_PREVIEW: 1,
    LlmModel.GEMINI_2_0_FLASH_LITE: 1,
    LlmModel.MISTRAL_NEMO: 1,
    LlmModel.COHERE_COMMAND_R_08_2024: 1,
    LlmModel.COHERE_COMMAND_R_PLUS_08_2024: 3,
    LlmModel.DEEPSEEK_CHAT: 2,
    LlmModel.DEEPSEEK_R1_0528: 1,
    LlmModel.PERPLEXITY_SONAR: 1,
    LlmModel.PERPLEXITY_SONAR_PRO: 5,
    LlmModel.PERPLEXITY_SONAR_DEEP_RESEARCH: 10,
    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_405B: 1,
    LlmModel.NOUSRESEARCH_HERMES_3_LLAMA_3_1_70B: 1,
    LlmModel.AMAZON_NOVA_LITE_V1: 1,
    LlmModel.AMAZON_NOVA_MICRO_V1: 1,
    LlmModel.AMAZON_NOVA_PRO_V1: 1,
    LlmModel.MICROSOFT_WIZARDLM_2_8X22B: 1,
    LlmModel.GRYPHE_MYTHOMAX_L2_13B: 1,
    LlmModel.META_LLAMA_4_SCOUT: 1,
    LlmModel.META_LLAMA_4_MAVERICK: 1,
    LlmModel.LLAMA_API_LLAMA_4_SCOUT: 1,
    LlmModel.LLAMA_API_LLAMA4_MAVERICK: 1,
    LlmModel.LLAMA_API_LLAMA3_3_8B: 1,
    LlmModel.LLAMA_API_LLAMA3_3_70B: 1,
    LlmModel.GROK_4: 9,
    LlmModel.GROK_4_FAST: 1,
    LlmModel.GROK_4_1_FAST: 1,
    LlmModel.GROK_CODE_FAST_1: 1,
    LlmModel.KIMI_K2: 1,
    LlmModel.QWEN3_235B_A22B_THINKING: 1,
    LlmModel.QWEN3_CODER: 9,
    # v0 by Vercel models
    LlmModel.V0_1_5_MD: 1,
    LlmModel.V0_1_5_LG: 2,
    LlmModel.V0_1_0_MD: 1,
PROVIDER_CREDENTIALS = {
    "openai": openai_credentials,
    "anthropic": anthropic_credentials,
    "groq": groq_credentials,
    "open_router": open_router_credentials,
    "llama_api": llama_api_credentials,
    "aiml_api": aiml_api_credentials,
    "v0": v0_credentials,
}

for model in LlmModel:
    if model not in MODEL_COST:
        raise ValueError(f"Missing MODEL_COST for model: {model}")
# =============== Configure the cost for each LLM Model call =============== #
# All LLM costs now come from the database via llm_registry

LLM_COST: list[BlockCost] = []


LLM_COST = (
    # Anthropic Models
    [
        BlockCost(
            cost_type=BlockCostType.RUN,
            cost_filter={
                "model": model,
async def _build_llm_costs_from_registry() -> list[BlockCost]:
    """
    Build BlockCost list from all models in the LLM registry.

    This function checks for active model migrations with customCreditCost overrides.
    When a model has been migrated with a custom price, that price is used instead
    of the target model's default cost.
    """
    # Query active migrations with custom pricing overrides.
    # Note: LlmModelMigration is system-level data (no userId field) and this function
    # is only called during app startup and admin operations, so no user ID filter needed.
    migration_overrides: dict[str, int] = {}
    try:
        active_migrations = await prisma.models.LlmModelMigration.prisma().find_many(
            where={
                "isReverted": False,
                "customCreditCost": {"not": None},
            }
        )
        # Key by targetModelSlug since that's the model nodes are now using
        # after migration. The custom cost applies to the target model.
        migration_overrides = {
            migration.targetModelSlug: migration.customCreditCost
            for migration in active_migrations
            if migration.customCreditCost is not None
        }
        if migration_overrides:
            logger.info(
                "Found %d active model migrations with custom pricing overrides",
                len(migration_overrides),
            )
    except Exception as exc:
        logger.warning(
            "Failed to query model migration overrides: %s. Proceeding with default costs.",
            exc,
            exc_info=True,
        )

    costs: list[BlockCost] = []
    for model in llm_registry.iter_dynamic_models():
        for cost in model.costs:
            credentials = PROVIDER_CREDENTIALS.get(cost.credential_provider)
            if not credentials:
                logger.warning(
                    "Skipping cost entry for %s due to unknown credentials provider %s",
                    model.slug,
                    cost.credential_provider,
                )
                continue

            # Check if this model has a custom cost override from migration
            cost_amount = migration_overrides.get(model.slug, cost.credit_cost)

            if model.slug in migration_overrides:
                logger.debug(
                    "Applying custom cost override for model %s: %d credits (default: %d)",
                    model.slug,
                    cost_amount,
                    cost.credit_cost,
                )

            cost_filter = {
                "model": model.slug,
                "credentials": {
                    "id": anthropic_credentials.id,
                    "provider": anthropic_credentials.provider,
                    "type": anthropic_credentials.type,
                    "id": credentials.id,
                    "provider": credentials.provider,
                    "type": credentials.type,
                },
            },
            cost_amount=cost,
        )
        for model, cost in MODEL_COST.items()
        if MODEL_METADATA[model].provider == "anthropic"
    ]
    # OpenAI Models
    + [
        BlockCost(
            cost_type=BlockCostType.RUN,
            cost_filter={
                "model": model,
                "credentials": {
                    "id": openai_credentials.id,
                    "provider": openai_credentials.provider,
                    "type": openai_credentials.type,
                },
            },
            cost_amount=cost,
        )
        for model, cost in MODEL_COST.items()
        if MODEL_METADATA[model].provider == "openai"
    ]
    # Groq Models
    + [
        BlockCost(
            cost_type=BlockCostType.RUN,
            cost_filter={
                "model": model,
                "credentials": {"id": groq_credentials.id},
            },
            cost_amount=cost,
        )
        for model, cost in MODEL_COST.items()
        if MODEL_METADATA[model].provider == "groq"
    ]
    # Open Router Models
    + [
        BlockCost(
            cost_type=BlockCostType.RUN,
            cost_filter={
                "model": model,
                "credentials": {
                    "id": open_router_credentials.id,
                    "provider": open_router_credentials.provider,
                    "type": open_router_credentials.type,
                },
            },
            cost_amount=cost,
        )
        for model, cost in MODEL_COST.items()
        if MODEL_METADATA[model].provider == "open_router"
    ]
    # Llama API Models
    + [
        BlockCost(
            cost_type=BlockCostType.RUN,
            cost_filter={
                "model": model,
                "credentials": {
                    "id": llama_api_credentials.id,
                    "provider": llama_api_credentials.provider,
                    "type": llama_api_credentials.type,
                },
            },
            cost_amount=cost,
        )
        for model, cost in MODEL_COST.items()
        if MODEL_METADATA[model].provider == "llama_api"
    ]
    # v0 by Vercel Models
    + [
        BlockCost(
            cost_type=BlockCostType.RUN,
            cost_filter={
                "model": model,
                "credentials": {
                    "id": v0_credentials.id,
                    "provider": v0_credentials.provider,
                    "type": v0_credentials.type,
                },
            },
            cost_amount=cost,
        )
        for model, cost in MODEL_COST.items()
        if MODEL_METADATA[model].provider == "v0"
    ]
    # AI/ML Api Models
    + [
        BlockCost(
            cost_type=BlockCostType.RUN,
            cost_filter={
                "model": model,
                "credentials": {
                    "id": aiml_api_credentials.id,
                    "provider": aiml_api_credentials.provider,
                    "type": aiml_api_credentials.type,
                },
            },
            cost_amount=cost,
        )
        for model, cost in MODEL_COST.items()
        if MODEL_METADATA[model].provider == "aiml_api"
    ]
)
            }
            costs.append(
                BlockCost(
                    cost_type=BlockCostType.RUN,
                    cost_filter=cost_filter,
                    cost_amount=cost_amount,
                )
            )
    return costs


async def refresh_llm_costs() -> None:
    """
    Refresh LLM costs from the registry. All costs now come from the database.

    This function also checks for active model migrations with custom pricing overrides
    and applies them to ensure accurate billing.
    """
    LLM_COST.clear()
    LLM_COST.extend(await _build_llm_costs_from_registry())


# Initial load will happen after registry is refreshed at startup
# Don't call refresh_llm_costs() here - it will be called after registry refresh

# =============== This is the exhaustive list of cost for each Block =============== #

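The migration override above is just a dict lookup with the registry cost as fallback. A minimal sketch of that behaviour, using hypothetical slug and credit values:

migration_overrides = {"gpt-5-mini": 2}  # built from LlmModelMigration rows
slug, default_cost = "gpt-5-mini", 5     # hypothetical registry cost entry
cost_amount = migration_overrides.get(slug, default_cost)
assert cost_amount == 2  # the migration's customCreditCost wins over the default
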
@@ -1,118 +0,0 @@
from backend.data import db


def chat_db():
    if db.is_connected():
        from backend.copilot import db as _chat_db

        chat_db = _chat_db
    else:
        from backend.util.clients import get_database_manager_async_client

        chat_db = get_database_manager_async_client()

    return chat_db


def graph_db():
    if db.is_connected():
        from backend.data import graph as _graph_db

        graph_db = _graph_db
    else:
        from backend.util.clients import get_database_manager_async_client

        graph_db = get_database_manager_async_client()

    return graph_db


def library_db():
    if db.is_connected():
        from backend.api.features.library import db as _library_db

        library_db = _library_db
    else:
        from backend.util.clients import get_database_manager_async_client

        library_db = get_database_manager_async_client()

    return library_db


def store_db():
    if db.is_connected():
        from backend.api.features.store import db as _store_db

        store_db = _store_db
    else:
        from backend.util.clients import get_database_manager_async_client

        store_db = get_database_manager_async_client()

    return store_db


def search():
    if db.is_connected():
        from backend.api.features.store import hybrid_search as _search

        search = _search
    else:
        from backend.util.clients import get_database_manager_async_client

        search = get_database_manager_async_client()

    return search


def execution_db():
    if db.is_connected():
        from backend.data import execution as _execution_db

        execution_db = _execution_db
    else:
        from backend.util.clients import get_database_manager_async_client

        execution_db = get_database_manager_async_client()

    return execution_db


def user_db():
    if db.is_connected():
        from backend.data import user as _user_db

        user_db = _user_db
    else:
        from backend.util.clients import get_database_manager_async_client

        user_db = get_database_manager_async_client()

    return user_db


def understanding_db():
    if db.is_connected():
        from backend.data import understanding as _understanding_db

        understanding_db = _understanding_db
    else:
        from backend.util.clients import get_database_manager_async_client

        understanding_db = get_database_manager_async_client()

    return understanding_db


def workspace_db():
    if db.is_connected():
        from backend.data import workspace as _workspace_db

        workspace_db = _workspace_db
    else:
        from backend.util.clients import get_database_manager_async_client

        workspace_db = get_database_manager_async_client()

    return workspace_db

@@ -1147,14 +1147,14 @@ async def get_graph(
    return GraphModel.from_db(graph, for_export)


async def get_store_listed_graphs(graph_ids: list[str]) -> dict[str, GraphModel]:
async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]:
    """Batch-fetch multiple store-listed graphs by their IDs.

    Only returns graphs that have approved store listings (publicly available).
    Does not require permission checks since store-listed graphs are public.

    Args:
        graph_ids: List of graph IDs to fetch
        *graph_ids: Variable number of graph IDs to fetch

    Returns:
        Dict mapping graph_id to GraphModel for graphs with approved store listings
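Because the signature is now variadic, call sites that pass a list must unpack it. A minimal sketch with hypothetical IDs, inside an async context:

graph_ids = ["graph-a", "graph-b"]                   # hypothetical IDs
graphs = await get_store_listed_graphs(*graph_ids)   # new: unpack the list
# previously: graphs = await get_store_listed_graphs(graph_ids)
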
@@ -1663,8 +1663,10 @@ async def migrate_llm_models(migrate_to: LlmModel):
        if field.annotation == LlmModel:
            llm_model_fields[block.id] = field_name

    # Convert enum values to a list of strings for the SQL query
    enum_values = [v.value for v in LlmModel]
    # Get all model slugs from the registry (dynamic, not hardcoded enum)
    from backend.data import llm_registry

    enum_values = list(llm_registry.get_all_model_slugs_for_validation())
    escaped_enum_values = repr(tuple(enum_values))  # hack but works

    # Update each block

@@ -0,0 +1,72 @@
"""
LLM Registry module for managing LLM models, providers, and costs dynamically.

This module provides a database-driven registry system for LLM models,
replacing hardcoded model configurations with a flexible admin-managed system.
"""

from backend.data.llm_registry.model import ModelMetadata

# Re-export for backwards compatibility
from backend.data.llm_registry.notifications import (
    REGISTRY_REFRESH_CHANNEL,
    publish_registry_refresh_notification,
    subscribe_to_registry_refresh,
)
from backend.data.llm_registry.registry import (
    RegistryModel,
    RegistryModelCost,
    RegistryModelCreator,
    get_all_model_slugs_for_validation,
    get_default_model_slug,
    get_dynamic_model_slugs,
    get_fallback_model_for_disabled,
    get_llm_discriminator_mapping,
    get_llm_model_cost,
    get_llm_model_metadata,
    get_llm_model_schema_options,
    get_model_info,
    is_model_enabled,
    iter_dynamic_models,
    refresh_llm_registry,
    register_static_costs,
    register_static_metadata,
)
from backend.data.llm_registry.schema_utils import (
    is_llm_model_field,
    refresh_llm_discriminator_mapping,
    refresh_llm_model_options,
    update_schema_with_llm_registry,
)

__all__ = [
    # Types
    "ModelMetadata",
    "RegistryModel",
    "RegistryModelCost",
    "RegistryModelCreator",
    # Registry functions
    "get_all_model_slugs_for_validation",
    "get_default_model_slug",
    "get_dynamic_model_slugs",
    "get_fallback_model_for_disabled",
    "get_llm_discriminator_mapping",
    "get_llm_model_cost",
    "get_llm_model_metadata",
    "get_llm_model_schema_options",
    "get_model_info",
    "is_model_enabled",
    "iter_dynamic_models",
    "refresh_llm_registry",
    "register_static_costs",
    "register_static_metadata",
    # Notifications
    "REGISTRY_REFRESH_CHANNEL",
    "publish_registry_refresh_notification",
    "subscribe_to_registry_refresh",
    # Schema utilities
    "is_llm_model_field",
    "refresh_llm_discriminator_mapping",
    "refresh_llm_model_options",
    "update_schema_with_llm_registry",
]

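Everything above is re-exported at the package root, so callers can depend on backend.data.llm_registry alone. A minimal usage sketch (assumes a connected database; warm_registry is a hypothetical name):

from backend.data import llm_registry

async def warm_registry() -> None:
    # Populate the in-process cache from the database, then query it.
    await llm_registry.refresh_llm_registry()
    default = llm_registry.get_default_model_slug()
    meta = llm_registry.get_llm_model_metadata(default) if default else None
    print(default, meta)
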
25 autogpt_platform/backend/backend/data/llm_registry/model.py Normal file
@@ -0,0 +1,25 @@
"""Type definitions for LLM model metadata."""

from typing import Literal, NamedTuple


class ModelMetadata(NamedTuple):
    """Metadata for an LLM model.

    Attributes:
        provider: The provider identifier (e.g., "openai", "anthropic")
        context_window: Maximum context window size in tokens
        max_output_tokens: Maximum output tokens (None if unlimited)
        display_name: Human-readable name for the model
        provider_name: Human-readable provider name (e.g., "OpenAI", "Anthropic")
        creator_name: Name of the organization that created the model
        price_tier: Relative cost tier (1=cheapest, 2=medium, 3=expensive)
    """

    provider: str
    context_window: int
    max_output_tokens: int | None
    display_name: str
    provider_name: str
    creator_name: str
    price_tier: Literal[1, 2, 3]
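As a NamedTuple, ModelMetadata is immutable and can be built by keyword. A minimal sketch with hypothetical values throughout:

meta = ModelMetadata(
    provider="openai",          # hypothetical values, not real registry data
    context_window=128_000,
    max_output_tokens=16_384,
    display_name="GPT-4o",
    provider_name="OpenAI",
    creator_name="OpenAI",
    price_tier=2,
)
assert meta.context_window == 128_000  # fields read like attributes
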
@@ -0,0 +1,89 @@
"""
Redis pub/sub notifications for LLM registry updates.

When models are added/updated/removed via the admin UI, this module
publishes notifications to Redis that all executor services subscribe to,
ensuring they refresh their registry cache in real-time.
"""

import asyncio
import logging
from typing import Any

from backend.data.redis_client import connect_async

logger = logging.getLogger(__name__)

# Redis channel name for LLM registry refresh notifications
REGISTRY_REFRESH_CHANNEL = "llm_registry:refresh"


async def publish_registry_refresh_notification() -> None:
    """
    Publish a notification to Redis that the LLM registry has been updated.
    All executor services subscribed to this channel will refresh their registry.
    """
    try:
        redis = await connect_async()
        await redis.publish(REGISTRY_REFRESH_CHANNEL, "refresh")
        logger.info("Published LLM registry refresh notification to Redis")
    except Exception as exc:
        logger.warning(
            "Failed to publish LLM registry refresh notification: %s",
            exc,
            exc_info=True,
        )


async def subscribe_to_registry_refresh(
    on_refresh: Any,  # Async callable that takes no args
) -> None:
    """
    Subscribe to Redis notifications for LLM registry updates.
    This runs in a loop and processes messages as they arrive.

    Args:
        on_refresh: Async callable to execute when a refresh notification is received
    """
    try:
        redis = await connect_async()
        pubsub = redis.pubsub()
        await pubsub.subscribe(REGISTRY_REFRESH_CHANNEL)
        logger.info(
            "Subscribed to LLM registry refresh notifications on channel: %s",
            REGISTRY_REFRESH_CHANNEL,
        )

        # Process messages in a loop
        while True:
            try:
                message = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=1.0
                )
                if (
                    message
                    and message["type"] == "message"
                    and message["channel"] == REGISTRY_REFRESH_CHANNEL
                ):
                    logger.info("Received LLM registry refresh notification")
                    try:
                        await on_refresh()
                    except Exception as exc:
                        logger.error(
                            "Error refreshing LLM registry from notification: %s",
                            exc,
                            exc_info=True,
                        )
            except Exception as exc:
                logger.warning(
                    "Error processing registry refresh message: %s", exc, exc_info=True
                )
                # Continue listening even if one message fails
                await asyncio.sleep(1)
    except Exception as exc:
        logger.error(
            "Failed to subscribe to LLM registry refresh notifications: %s",
            exc,
            exc_info=True,
        )
        raise
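A minimal wiring sketch: the admin side publishes after a model change, and each executor runs the subscriber as a background task with refresh_llm_registry as its callback (function names on the caller side are hypothetical; both are assumed to run inside the platform's event loop):

import asyncio

from backend.data import llm_registry

async def admin_updated_a_model() -> None:
    # After writing the change to the DB, fan the refresh out to executors.
    await llm_registry.publish_registry_refresh_notification()

def start_registry_listener() -> asyncio.Task:
    # on_refresh is any async no-arg callable; refresh_llm_registry fits.
    return asyncio.ensure_future(
        llm_registry.subscribe_to_registry_refresh(llm_registry.refresh_llm_registry)
    )
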
388 autogpt_platform/backend/backend/data/llm_registry/registry.py Normal file
@@ -0,0 +1,388 @@
"""Core LLM registry implementation for managing models dynamically."""

from __future__ import annotations

import asyncio
import logging
from dataclasses import dataclass, field
from typing import Any, Iterable

import prisma.models

from backend.data.llm_registry.model import ModelMetadata

logger = logging.getLogger(__name__)


def _json_to_dict(value: Any) -> dict[str, Any]:
    """Convert Prisma Json type to dict, with fallback to empty dict."""
    if value is None:
        return {}
    if isinstance(value, dict):
        return value
    # Prisma Json type should always be a dict at runtime
    return dict(value) if value else {}


@dataclass(frozen=True)
class RegistryModelCost:
    """Cost configuration for an LLM model."""

    credit_cost: int
    credential_provider: str
    credential_id: str | None
    credential_type: str | None
    currency: str | None
    metadata: dict[str, Any]


@dataclass(frozen=True)
class RegistryModelCreator:
    """Creator information for an LLM model."""

    id: str
    name: str
    display_name: str
    description: str | None
    website_url: str | None
    logo_url: str | None


@dataclass(frozen=True)
class RegistryModel:
    """Represents a model in the LLM registry."""

    slug: str
    display_name: str
    description: str | None
    metadata: ModelMetadata
    capabilities: dict[str, Any]
    extra_metadata: dict[str, Any]
    provider_display_name: str
    is_enabled: bool
    is_recommended: bool = False
    costs: tuple[RegistryModelCost, ...] = field(default_factory=tuple)
    creator: RegistryModelCreator | None = None


_static_metadata: dict[str, ModelMetadata] = {}
_static_costs: dict[str, int] = {}
_dynamic_models: dict[str, RegistryModel] = {}
_schema_options: list[dict[str, str]] = []
_discriminator_mapping: dict[str, str] = {}
_lock = asyncio.Lock()


def register_static_metadata(metadata: dict[Any, ModelMetadata]) -> None:
    """Register static metadata for legacy models (deprecated)."""
    _static_metadata.update({str(key): value for key, value in metadata.items()})
    _refresh_cached_schema()


def register_static_costs(costs: dict[Any, int]) -> None:
    """Register static costs for legacy models (deprecated)."""
    _static_costs.update({str(key): value for key, value in costs.items()})


def _build_schema_options() -> list[dict[str, str]]:
    """Build schema options for model selection dropdown. Only includes enabled models."""
    options: list[dict[str, str]] = []
    # Only include enabled models in the dropdown options
    for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
        if model.is_enabled:
            options.append(
                {
                    "label": model.display_name,
                    "value": model.slug,
                    "group": model.metadata.provider,
                    "description": model.description or "",
                }
            )

    for slug, metadata in _static_metadata.items():
        if slug in _dynamic_models:
            continue
        options.append(
            {
                "label": slug,
                "value": slug,
                "group": metadata.provider,
                "description": "",
            }
        )
    return options


async def refresh_llm_registry() -> None:
    """Refresh the LLM registry from the database. Loads all models (enabled and disabled)."""
    async with _lock:
        try:
            records = await prisma.models.LlmModel.prisma().find_many(
                include={
                    "Provider": True,
                    "Costs": True,
                    "Creator": True,
                }
            )
            logger.debug("Found %d LLM model records in database", len(records))
        except Exception as exc:
            logger.error(
                "Failed to refresh LLM registry from DB: %s", exc, exc_info=True
            )
            return

        dynamic: dict[str, RegistryModel] = {}
        for record in records:
            provider_name = (
                record.Provider.name if record.Provider else record.providerId
            )
            provider_display_name = (
                record.Provider.displayName if record.Provider else record.providerId
            )
            # Creator name: prefer Creator.name, fallback to provider display name
            creator_name = (
                record.Creator.name if record.Creator else provider_display_name
            )
            # Price tier: default to 1 (cheapest) if not set
            price_tier = getattr(record, "priceTier", 1) or 1
            # Clamp to valid range 1-3
            price_tier = max(1, min(3, price_tier))

            metadata = ModelMetadata(
                provider=provider_name,
                context_window=record.contextWindow,
                max_output_tokens=record.maxOutputTokens,
                display_name=record.displayName,
                provider_name=provider_display_name,
                creator_name=creator_name,
                price_tier=price_tier,  # type: ignore[arg-type]
            )
            costs = tuple(
                RegistryModelCost(
                    credit_cost=cost.creditCost,
                    credential_provider=cost.credentialProvider,
                    credential_id=cost.credentialId,
                    credential_type=cost.credentialType,
                    currency=cost.currency,
                    metadata=_json_to_dict(cost.metadata),
                )
                for cost in (record.Costs or [])
            )

            # Map creator if present
            creator = None
            if record.Creator:
                creator = RegistryModelCreator(
                    id=record.Creator.id,
                    name=record.Creator.name,
                    display_name=record.Creator.displayName,
                    description=record.Creator.description,
                    website_url=record.Creator.websiteUrl,
                    logo_url=record.Creator.logoUrl,
                )

            dynamic[record.slug] = RegistryModel(
                slug=record.slug,
                display_name=record.displayName,
                description=record.description,
                metadata=metadata,
                capabilities=_json_to_dict(record.capabilities),
                extra_metadata=_json_to_dict(record.metadata),
                provider_display_name=(
                    record.Provider.displayName
                    if record.Provider
                    else record.providerId
                ),
                is_enabled=record.isEnabled,
                is_recommended=record.isRecommended,
                costs=costs,
                creator=creator,
            )

        # Atomic swap - build new structures then replace references
        # This ensures readers never see partially updated state
        global _dynamic_models
        _dynamic_models = dynamic
        _refresh_cached_schema()
        logger.info(
            "LLM registry refreshed with %s dynamic models (enabled: %s, disabled: %s)",
            len(dynamic),
            sum(1 for m in dynamic.values() if m.is_enabled),
            sum(1 for m in dynamic.values() if not m.is_enabled),
        )


def _refresh_cached_schema() -> None:
    """Refresh cached schema options and discriminator mapping."""
    global _schema_options, _discriminator_mapping

    # Build new structures
    new_options = _build_schema_options()
    new_mapping = {
        slug: entry.metadata.provider for slug, entry in _dynamic_models.items()
    }
    for slug, metadata in _static_metadata.items():
        new_mapping.setdefault(slug, metadata.provider)

    # Atomic swap - replace references to ensure readers see consistent state
    _schema_options = new_options
    _discriminator_mapping = new_mapping


def get_llm_model_metadata(slug: str) -> ModelMetadata | None:
    """Get model metadata by slug. Checks dynamic models first, then static metadata."""
    if slug in _dynamic_models:
        return _dynamic_models[slug].metadata
    return _static_metadata.get(slug)


def get_llm_model_cost(slug: str) -> tuple[RegistryModelCost, ...]:
    """Get model cost configuration by slug."""
    if slug in _dynamic_models:
        return _dynamic_models[slug].costs
    cost_value = _static_costs.get(slug)
    if cost_value is None:
        return tuple()
    return (
        RegistryModelCost(
            credit_cost=cost_value,
            credential_provider="static",
            credential_id=None,
            credential_type=None,
            currency=None,
            metadata={},
        ),
    )


def get_llm_model_schema_options() -> list[dict[str, str]]:
    """
    Get schema options for LLM model selection dropdown.

    Returns a copy of cached schema options that are refreshed when the registry is
    updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
    """
    # Return a copy to prevent external mutation
    return list(_schema_options)


def get_llm_discriminator_mapping() -> dict[str, str]:
    """
    Get discriminator mapping for LLM models.

    Returns a copy of cached discriminator mapping that is refreshed when the registry
    is updated via refresh_llm_registry() (called on startup and via Redis pub/sub).
    """
    # Return a copy to prevent external mutation
    return dict(_discriminator_mapping)


def get_dynamic_model_slugs() -> set[str]:
    """Get all dynamic model slugs from the registry."""
    return set(_dynamic_models.keys())


def get_all_model_slugs_for_validation() -> set[str]:
    """
    Get ALL model slugs (both enabled and disabled) for validation purposes.

    This is used for JSON schema enum validation - we need to accept any known
    model value (even disabled ones) so that existing graphs don't fail validation.
    The actual fallback/enforcement happens at runtime in llm_call().
    """
    all_slugs = set(_dynamic_models.keys())
    all_slugs.update(_static_metadata.keys())
    return all_slugs


def iter_dynamic_models() -> Iterable[RegistryModel]:
    """Iterate over all dynamic models in the registry."""
    return tuple(_dynamic_models.values())


def get_fallback_model_for_disabled(disabled_model_slug: str) -> RegistryModel | None:
    """
    Find a fallback model when the requested model is disabled.

    Looks for an enabled model from the same provider. Prefers models with
    similar names or capabilities if possible.

    Args:
        disabled_model_slug: The slug of the disabled model

    Returns:
        An enabled RegistryModel from the same provider, or None if no fallback found
    """
    disabled_model = _dynamic_models.get(disabled_model_slug)
    if not disabled_model:
        return None

    provider = disabled_model.metadata.provider

    # Find all enabled models from the same provider
    candidates = [
        model
        for model in _dynamic_models.values()
        if model.is_enabled and model.metadata.provider == provider
    ]

    if not candidates:
        return None

    # Sort by: prefer models with similar context window, then by name
    candidates.sort(
        key=lambda m: (
            abs(m.metadata.context_window - disabled_model.metadata.context_window),
            m.display_name.lower(),
        )
    )

    return candidates[0]


def is_model_enabled(model_slug: str) -> bool:
    """Check if a model is enabled in the registry."""
    model = _dynamic_models.get(model_slug)
    if not model:
        # Model not in registry - assume it's a static/legacy model and allow it
        return True
    return model.is_enabled


def get_model_info(model_slug: str) -> RegistryModel | None:
    """Get model info from the registry."""
    return _dynamic_models.get(model_slug)


def get_default_model_slug() -> str | None:
    """
    Get the default model slug to use for block defaults.

    Returns the recommended model if set (configured via admin UI),
    otherwise returns the first enabled model alphabetically.
    Returns None if no models are available or enabled.
    """
    # Return the recommended model if one is set and enabled
    for model in _dynamic_models.values():
        if model.is_recommended and model.is_enabled:
            return model.slug

    # No recommended model set - find first enabled model alphabetically
    for model in sorted(_dynamic_models.values(), key=lambda m: m.display_name.lower()):
        if model.is_enabled:
            logger.warning(
                "No recommended model set, using '%s' as default",
                model.slug,
            )
            return model.slug

    # No enabled models available
    if _dynamic_models:
        logger.error(
            "No enabled models found in registry (%d models registered but all disabled)",
            len(_dynamic_models),
        )
    else:
        logger.error("No models registered in LLM registry")

    return None
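Putting the read-side API together, a hedged sketch of how a caller might resolve a usable slug (the helper name pick_model is hypothetical):

from backend.data import llm_registry

async def pick_model(requested: str) -> str | None:
    await llm_registry.refresh_llm_registry()  # load models from the DB
    if llm_registry.is_model_enabled(requested):
        return requested
    fallback = llm_registry.get_fallback_model_for_disabled(requested)
    if fallback is not None:
        return fallback.slug  # same provider, closest context window
    return llm_registry.get_default_model_slug()  # recommended, else first enabled
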
@@ -0,0 +1,130 @@
"""
Helper utilities for LLM registry integration with block schemas.

This module handles the dynamic injection of discriminator mappings
and model options from the LLM registry into block schemas.
"""

import logging
from typing import Any

from backend.data.llm_registry.registry import (
    get_all_model_slugs_for_validation,
    get_default_model_slug,
    get_llm_discriminator_mapping,
    get_llm_model_schema_options,
)

logger = logging.getLogger(__name__)


def is_llm_model_field(field_name: str, field_info: Any) -> bool:
    """
    Check if a field is an LLM model selection field.

    Returns True if the field has 'options' in json_schema_extra
    (set by llm_model_schema_extra() in blocks/llm.py).
    """
    if not hasattr(field_info, "json_schema_extra"):
        return False

    extra = field_info.json_schema_extra
    if isinstance(extra, dict):
        return "options" in extra

    return False


def refresh_llm_model_options(field_schema: dict[str, Any]) -> None:
    """
    Refresh LLM model options from the registry.

    Updates 'options' (for frontend dropdown) to show only enabled models,
    but keeps the 'enum' (for validation) inclusive of ALL known models.

    This is important because:
    - Options: What users see in the dropdown (enabled models only)
    - Enum: What values pass validation (all known models, including disabled)

    Existing graphs may have disabled models selected - they should pass validation
    and the fallback logic in llm_call() will handle using an alternative model.
    """
    fresh_options = get_llm_model_schema_options()
    if not fresh_options:
        return

    # Update options array (UI dropdown) - only enabled models
    if "options" in field_schema:
        field_schema["options"] = fresh_options

    all_known_slugs = get_all_model_slugs_for_validation()
    if all_known_slugs and "enum" in field_schema:
        existing_enum = set(field_schema.get("enum", []))
        combined_enum = existing_enum | all_known_slugs
        field_schema["enum"] = sorted(combined_enum)

    # Set the default value from the registry (recommended model if set, else first enabled)
    # This ensures new blocks have a sensible default pre-selected
    default_slug = get_default_model_slug()
    if default_slug:
        field_schema["default"] = default_slug


def refresh_llm_discriminator_mapping(field_schema: dict[str, Any]) -> None:
    """
    Refresh discriminator_mapping for fields that use model-based discrimination.

    The discriminator is already set when AICredentialsField() creates the field.
    We only need to refresh the mapping when models are added/removed.
    """
    if field_schema.get("discriminator") != "model":
        return

    # Always refresh the mapping to get latest models
    fresh_mapping = get_llm_discriminator_mapping()
    if fresh_mapping is not None:
        field_schema["discriminator_mapping"] = fresh_mapping


def update_schema_with_llm_registry(
    schema: dict[str, Any], model_class: type | None = None
) -> None:
    """
    Update a JSON schema with current LLM registry data.

    Refreshes:
    1. Model options for LLM model selection fields (dropdown choices)
    2. Discriminator mappings for credentials fields (model → provider)

    Args:
        schema: The JSON schema to update (mutated in-place)
        model_class: The Pydantic model class (optional, for field introspection)
    """
    properties = schema.get("properties", {})

    for field_name, field_schema in properties.items():
        if not isinstance(field_schema, dict):
            continue

        # Refresh model options for LLM model fields
        if model_class and hasattr(model_class, "model_fields"):
            field_info = model_class.model_fields.get(field_name)
            if field_info and is_llm_model_field(field_name, field_info):
                try:
                    refresh_llm_model_options(field_schema)
                except Exception as exc:
                    logger.warning(
                        "Failed to refresh LLM options for field %s: %s",
                        field_name,
                        exc,
                    )

        # Refresh discriminator mapping for fields that use model discrimination
        try:
            refresh_llm_discriminator_mapping(field_schema)
        except Exception as exc:
            logger.warning(
                "Failed to refresh discriminator mapping for field %s: %s",
                field_name,
                exc,
            )
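A minimal sketch of calling the schema updater on a hand-built schema (the "credentials" field below is hypothetical). Note that without model_class, only the discriminator mapping is refreshed; options refresh requires field introspection via the Pydantic model:

from backend.data.llm_registry.schema_utils import update_schema_with_llm_registry

schema = {
    "properties": {
        "credentials": {"discriminator": "model", "discriminator_mapping": {}},
    }
}
update_schema_with_llm_registry(schema)  # no model_class: only the mapping refreshes
# schema["properties"]["credentials"]["discriminator_mapping"] now maps slug -> provider
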
Some files were not shown because too many files have changed in this diff.